見出し画像

infix 関数適用演算子で複数のモナドをつなぐ

競プロ典型 90 問 65 日目 - RGB Balls 2(★7)

import Control.Monad
import Control.Monad.ST
import Data.Bits
import Data.Bool
import qualified Data.ByteString.Char8 as C
import Data.Function
import Data.List
import Data.Vector.Unboxed ((!))
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as UM

-- | Read two lines of integers, solve, and print the answer; the pure 'sol'
-- is lifted over both reads with infix applicative operators, and the result
-- is handed to 'print' with the right-to-left bind '=<<'.
main = print =<< sol <$> ints <*> ints

-- | Read one line from stdin and parse it as space-separated integers.
-- Parsing stops at the first token 'C.readInt' cannot read.
ints = fmap (unfoldr nextInt) C.getLine
  where
    -- skip leading blanks, then read one integer and return the rest
    nextInt = C.readInt . C.dropWhile (== ' ')

-- | Answer for counts @[r,g,b,k]@ and limits @[x,y,z]@, modulo 'p'.
--
-- Computes the coefficient of @t^k@ in the product of three polynomials.
-- The first is @sum C(r,i) t^i@ restricted to @i >= k-y@. The second is
-- @sum C(g,j) t^j@ restricted to @j >= k-z@. The third is @sum C(b,l) t^l@
-- restricted to @l >= k-x@.
sol [r,g,b,k] [x,y,z] = convAt (conv red green) blue k
  where
    facTables = tbl (1 + maximum [r,g,b])
    -- coefficient vector (length total+1) of C(total, i) for i >= lo, else 0
    choose total lo = U.generate (1 + total) $ \i ->
      if i >= lo then comb facTables total i else 0
    red   = choose r (k - y)
    green = choose g (k - z)
    blue  = choose b (k - x)

-- | Single coefficient of a product, modulo 'p'.
--
-- Returns the sum of @a!(k-i) * b!i@ over every index @i@ of @b@ with
-- @0 <= k-i < length a@, i.e. coefficient @k@ of the product polynomial.
-- Entries of @b@ that are not positive are skipped.
convAt a b k = foldl' step 0 [lo .. hi]
  where
    -- the index window where both a!(k-i) and b!i exist
    lo = max 0 (k - U.length a + 1)
    hi = min k (U.length b - 1)
    step acc i
      | x > 0     = acc .+. a ! (k - i) .*. x
      | otherwise = acc
      where
        x = b ! i

-- | Linear convolution of two coefficient vectors modulo 'p', computed with a
-- number-theoretic transform. Entry @i@ of the result is the sum of
-- @a!x * b!y@ over @x + y == i@.
--
-- The result has length @n@, the smallest power of two that is at least
-- @max 1 (length a + length b - 1)@. Entries beyond the linear-convolution
-- length are zero.
--
-- Because @p - 1 = 119 * 2^23@, @n@ must not exceed @2^23@; this is not
-- checked here.
--
-- NOTE(review): the inputs are presumably already reduced into @[0, p)@;
-- larger values could overflow inside '.*.'.
conv a b = U.create $ do
  let
    m = U.length a + U.length b - 1
    -- Smallest power of two >= m (1 for degenerate/empty inputs). A
    -- bit-length-of-m formulation doubles n whenever m is itself a power of
    -- two, and overflows to n = 0 when m = -1.
    n = until (>= m) (* 2) 1
    -- w^0 .. w^(n/2), the table 'fft' expects for the root w
    roots w = U.iterateN (1 + n `div` 2) (.*. w) 1
    pad v = U.thaw $ v U.++ U.replicate (n - U.length v) 0
  ma <- pad a
  mb <- pad b
  -- forward transform with w = 3^((p-1)/n)
  let fwdRoots = roots (pwr 3 ((p - 1) `div` n))
  fft ma n fwdRoots
  fft mb n fwdRoots
  -- pointwise product in the transformed domain: ma[i] <- ma[i] * mb[i]
  forM_ [0 .. n - 1] $ \i -> do
    y <- UM.unsafeRead mb i
    UM.unsafeModify ma (.*. y) i
  -- inverse transform with w^-1 = 3^(p-1-(p-1)/n), then divide by n
  fft ma n $ roots (pwr 3 (p - 1 - (p - 1) `div` n))
  let invN = inv n
  forM_ [0 .. n - 1] $ UM.unsafeModify ma (.*. invN)
  return ma

-- | In-place, natural-order transform of the first @n@ entries of @ma@.
--
-- Computes @X_k = sum_j x_j * w^(j*k) mod p@, using out-of-place radix-2
-- stages (Stockham-style, so no bit-reversal pass is needed).
--
-- Arguments and assumptions:
--
-- * @r@ must hold @w^0, w^1, ..., w^(n/2)@ for a primitive @n@-th root of
--   unity @w@ modulo 'p'. Passing the inverse root gives the unscaled
--   inverse transform.
-- * @n@ is assumed to be a power of two, and @ma@ to have length @n@ (it is
--   copied into with 'UM.unsafeCopy').
-- * Input entries are assumed to be in @[0, p)@. Every value written,
--   including intermediates, is reduced into @[0, p)@.
fft :: UM.MVector s Int -> Int -> U.Vector Int -> ST s ()
fft ma n r = when (n > 1) $ do
  -- One scratch buffer ping-ponged with ma instead of a fresh allocation per
  -- stage; each stage overwrites all n entries of its destination.
  scratch <- UM.replicate n 0
  let half = n `div` 2
      -- Invariant: i * w == half. Reads src, writes dst, then swaps roles.
      stage i w src dst srcIsMa
        | w >= n    = unless srcIsMa $ UM.unsafeCopy ma src
        | otherwise = do
            forM_ [0 .. i - 1] $ \j -> do
              let base    = w * j
                  twiddle = r ! base
              forM_ [0 .. w - 1] $ \k -> do
                v1 <- UM.unsafeRead src (base + k)
                v2 <- UM.unsafeRead src (base + k + half)
                UM.unsafeWrite dst (2 * base + k) (v1 .+. v2)
                -- reduced with .*. (not raw *) so no stage stores values
                -- outside [0, p); correctness no longer hinges on the last
                -- stage multiplying by r!0 == 1
                UM.unsafeWrite dst (2 * base + k + w) (twiddle .*. (v1 .-. v2))
            stage (i `div` 2) (w * 2) dst src (not srcIsMa)
  stage half 1 ma scratch True

-- | Factorial and inverse-factorial tables modulo 'p', each of length @n+1@.
-- In the result @(facts, invFacts)@, @facts!i == i! mod p@ and @invFacts!i@
-- is its modular inverse, for @0 <= i <= n@.
tbl n = (facts, invFacts)
  where
    ks = U.enumFromN 1 n
    -- running products 0!, 1!, ..., n!
    facts = U.scanl' (.*.) 1 ks
    -- start from (n!)^-1 and multiply by i walking right-to-left:
    -- (i-1)!^-1 = i!^-1 * i
    invFacts = U.scanr' (.*.) (inv (facts ! n)) ks

-- | Binomial coefficient @C(n, m) mod p@, looked up from a
-- (factorials, inverse factorials) table pair. Yields 0 when @m@ lies
-- outside @0..n@.
comb (facts, invFacts) n m
  | 0 <= m && m <= n = facts ! n .*. invFacts ! m .*. invFacts ! (n - m)
  | otherwise        = 0

-- | Prime modulus for all arithmetic in this file. @p - 1 = 119 * 2^23@, so
-- power-of-two roots of unity of order up to @2^23@ exist modulo @p@.
p :: Int
p = 998244353

-- | Modular multiplication, result in @[0, p)@. Operands should already be
-- in @[0, p)@. Much larger operands can overflow the 'Int' product before
-- reduction (assuming a 64-bit 'Int').
(.*.) :: Int -> Int -> Int
a .*. b = (a * b) `mod` p
infixl 7 .*.

-- | Modular addition, result in @[0, p)@.
(.+.) :: Int -> Int -> Int
a .+. b = (a + b) `mod` p
infixl 6 .+.

-- | Modular subtraction, result in @[0, p)@. Haskell's 'mod' takes the
-- sign of the divisor, so a negative difference wraps to a non-negative
-- value.
(.-.) :: Int -> Int -> Int
a .-. b = (a - b) `mod` p
infixl 6 .-.

-- | Modular inverse by Fermat's little theorem: @a^(p-2) mod p@, valid
-- because 'p' is prime. The input 1 short-circuits the exponentiation.
-- @a@ must not be divisible by 'p'; such values have no inverse.
inv !a
  | a == 1    = 1
  | otherwise = pwr a (p - 2)

-- | @pwr a n@ is @a^n mod p@, by recursive binary exponentiation: square the
-- power for @n `div` 2@, then multiply by the base once more when @n@ is odd.
--
-- The base is first reduced into @[0, p)@, so every product inside '.*.'
-- stays small enough for 'Int'.
--
-- A negative exponent is read as a power of the modular inverse
-- @a^(p-2)@, which requires @a@ not divisible by 'p'. Without this case,
-- @(-1) `div` 2 == -1@ would recurse forever.
pwr :: Integral b => Int -> b -> Int
pwr !a !n
  | n < 0     = pwr (pwr a (p - 2)) (negate n)
  | otherwise = go (a `mod` p) n
  where
    go !_ 0 = 1
    go !x e
      | odd e     = x .*. h
      | otherwise = h
      where
        q = go x (e `div` 2)
        h = q .*. q

-- | @n! mod p@ as a strict left fold (constant stack, unlike the naive
-- recursion). Fails with an explicit error for negative @n@, which would
-- otherwise never reach a base case. Not called by the code shown here;
-- 'tbl' builds its own factorial table.
fac n
  | n < 0     = error "fac: negative argument"
  | otherwise = foldl' (.*.) 1 [1 .. n]

いいなと思ったら応援しよう!

karoyakani
ありがとう