infix 関数適用演算子で複数のモナドをつなぐ
競プロ典型 90 問 65 日目 - RGB Balls 2(★7)
import Control.Monad
import Control.Monad.ST
import Data.Bits
import Data.Bool
import qualified Data.ByteString.Char8 as C
import Data.Function
import Data.List
import Data.Vector.Unboxed ((!))
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as UM
-- | Reads two lines of integers from stdin, passes them to 'sol' (first line
-- as the first argument) and prints the resulting number.
main :: IO ()
main = do
  firstLine <- ints
  secondLine <- ints
  print (sol firstLine secondLine)
-- | Reads one line from stdin and parses the integers on it.
--
-- Leading spaces are skipped before each number; parsing stops at the first
-- position where no integer can be read.
ints :: IO [Int]
ints = do
  line <- C.getLine
  return (unfoldr nextInt line)
  where
    -- Drop separating spaces, then try to read one integer and the remainder.
    nextInt = C.readInt . C.dropWhile (== ' ')
-- | Computes, modulo 'p', the sum over all @i + j + l == k@ of
-- @C(r,i) * C(g,j) * C(b,l)@, restricted to @i >= k - y@, @j >= k - z@ and
-- @l >= k - x@.
--
-- The first list must hold exactly @[r, g, b, k]@ and the second exactly
-- @[x, y, z]@; any other shape fails the pattern match.
-- NOTE(review): presumably r/g/b are ball counts and x/y/z pairwise limits
-- from the "RGB Balls 2" task -- confirm against the problem statement.
sol :: [Int] -> [Int] -> Int
sol [r, g, b, k] [x, y, z] = convAt redGreen blue k
  where
    -- Factorial tables large enough for every binomial used below.
    facts = tbl (1 + maximum [r, g, b])
    -- Coefficient i is C(total, i) when i reaches the lower bound, else 0.
    ways total lowest = U.generate (total + 1) pick
      where
        pick i
          | i >= lowest = comb facts total i
          | otherwise = 0
    redGreen = conv (ways r (k - y)) (ways g (k - z))
    blue = ways b (k - x)
-- | Single coefficient of a product: the sum of @a!(k-j) * b!j@ modulo 'p'
-- over every index @j@ of @b@ for which @k - j@ is a valid index of @a@.
--
-- Entries of @b@ that are not positive are skipped.
convAt :: U.Vector Int -> U.Vector Int -> Int -> Int
convAt a b k = U.ifoldl' step 0 b
  where
    lenA = U.length a
    step acc j bj
      | j > k = acc            -- k - j would be negative
      | k - j >= lenA = acc    -- k - j would run past the end of a
      | bj <= 0 = acc          -- nothing to add for this entry
      | otherwise = acc .+. a ! (k - j) .*. bj
-- | Multiplies two coefficient vectors modulo 'p' using the transform in
-- 'fft'.
--
-- Both inputs are zero-padded to a length @n@ that is a power of two strictly
-- greater than @length a + length b - 1@. The returned vector has that padded
-- length @n@, not the exact product length.
--
-- Raises an error when @n@ does not divide @p - 1@: in that case no n-th root
-- of unity can be derived from @(p - 1) `div` n@ and the result would be
-- silently wrong.
conv :: U.Vector Int -> U.Vector Int -> U.Vector Int
conv a b = U.create $ do
  let m = U.length a + U.length b - 1
      -- Bit length of m, so that 2^k exceeds m.
      k = finiteBitSize m - countLeadingZeros m
      n = 2 ^ k
      tableLen = 1 + n `div` 2
      -- Powers of 3 serve as the n-th root of unity and its inverse.
      root = pwr 3 ((p - 1) `div` n)
      rootInv = pwr 3 (p - 1 - (p - 1) `div` n)
      -- Loop-invariant scaling factor for the inverse transform.
      nInv = inv n
  when (n > 1 && (p - 1) `rem` n /= 0) $
    error "conv: transform length does not divide p - 1"
  ma <- U.thaw $ a U.++ U.replicate (n - U.length a) 0
  mb <- U.thaw $ b U.++ U.replicate (n - U.length b) 0
  let twiddles = U.iterateN tableLen (.*. root) 1
  fft ma n twiddles
  fft mb n twiddles
  -- Pointwise product of the two transformed sequences, stored in ma.
  forM_ [0 .. n - 1] $ \i -> do
    y <- UM.unsafeRead mb i
    UM.unsafeModify ma (.*. y) i
  -- Inverse transform: same routine with the inverse root, then divide by n.
  fft ma n $ U.iterateN tableLen (.*. rootInv) 1
  forM_ [0 .. n - 1] $ UM.unsafeModify ma (.*. nInv)
  return ma
-- | In-place number-theoretic transform of @ma@ (length @n@) modulo 'p'.
--
-- @r@ is a table of twiddle factors; indices @0 .. n/2 - 1@ are read. Each
-- pass reads pairs that are @n/2@ apart from the current buffer and writes
-- sums and twiddled differences into a fresh buffer; the final buffer is
-- copied back into @ma@. Nothing happens when @n <= 1@.
-- NOTE(review): the access pattern looks like a Stockham-style self-sorting
-- transform -- confirm before relying on the output ordering.
fft :: UM.MVector s Int -> Int -> U.Vector Int -> ST s ()
fft ma n r = when (n > 1) $ go half 1 ma >>= UM.unsafeCopy ma
  where
    half = n `div` 2
    -- groups: number of butterfly groups in this pass; width: elements per
    -- group; src: buffer produced by the previous pass.
    go groups width src
      | width >= n = return src
      | otherwise = do
          dst <- UM.replicate n 0
          forM_ [0 .. groups - 1] $ \j -> do
            let base = width * j
                twiddle = r ! base
            forM_ [0 .. width - 1] $ \k -> do
              v1 <- UM.unsafeRead src (k + base)
              v2 <- UM.unsafeRead src (k + base + half)
              UM.unsafeWrite dst (k + 2 * base) (v1 .+. v2)
              -- Reduce the product with (.*.) so every stored value stays
              -- below p; a raw (*) would leave values near p^2 in the buffer
              -- and rely on later passes not overflowing.
              UM.unsafeWrite dst (k + 2 * base + width) (twiddle .*. (v1 .-. v2))
          go (groups `div` 2) (width * 2) dst
-- | Builds factorial tables modulo 'p' for arguments @0 .. n@.
--
-- The first vector holds @i!@ at index @i@. The second holds the
-- corresponding inverse factorials: its last entry is @inv (n!)@ and each
-- earlier entry is the following one multiplied by its index. Intended for
-- @n >= 0@.
tbl :: Int -> (U.Vector Int, U.Vector Int)
tbl n = (facts, invFacts)
  where
    -- The multipliers 1 .. n.
    upTo = U.enumFromN 1 n
    -- Running product from the left: 0!, 1!, ..., n!.
    facts = U.scanl' (.*.) 1 upTo
    -- Running product from the right, seeded with the inverse of n!.
    invFacts = U.scanr' (\i acc -> acc .*. i) (inv (facts ! n)) upTo
-- | Binomial coefficient @C(n, m)@ modulo 'p', looked up from a pair of
-- (factorial, inverse factorial) tables. Yields 0 when @m@ lies outside
-- @0 .. n@.
comb :: (U.Vector Int, U.Vector Int) -> Int -> Int -> Int
comb (fc, ifc) n m =
  if 0 <= m && m <= n
    then fc ! n .*. ifc ! m .*. ifc ! (n - m)
    else 0
-- | Modulus used by all arithmetic helpers below.
p :: Int
p = 998244353

-- | Product of two numbers, reduced modulo 'p'.
infixl 7 .*.
(.*.) :: Int -> Int -> Int
a .*. b = (a * b) `mod` p

-- | Sum of two numbers, reduced modulo 'p'.
infixl 6 .+.
(.+.) :: Int -> Int -> Int
a .+. b = (a + b) `mod` p

-- | Difference of two numbers, reduced modulo 'p' (never negative, because
-- 'mod' takes the sign of the positive divisor).
infixl 6 .-.
(.-.) :: Int -> Int -> Int
a .-. b = (a - b) `mod` p
-- | Modular inverse computed as @a ^ (p - 2)@ modulo 'p', with a shortcut
-- for 1.
-- NOTE(review): this is only a true inverse when @a@ is not a multiple of
-- 'p' -- callers are assumed to guarantee that.
inv :: Int -> Int
inv a
  | a == 1 = 1
  | otherwise = pwr a (p - 2)
-- | @pwr a n@ is @a@ raised to the power @n@ modulo 'p', computed by
-- repeated squaring. The recursion halves @n@ until it reaches 0, so @n@ is
-- expected to be non-negative.
pwr :: Integral e => Int -> e -> Int
pwr !a !n
  | n == 0 = 1
  | otherwise = if odd n then a .*. squared else squared
  where
    -- a ^ (n div 2), squared to cover the even part of the exponent.
    halfPower = pwr a (n `div` 2)
    squared = halfPower .*. halfPower
-- | Factorial of @n@ modulo 'p'.
--
-- Implemented as a strict left fold, so it runs in constant stack space.
-- A negative argument raises an error instead of recursing without bound.
fac :: Int -> Int
fac n
  | n < 0 = error "fac: negative argument"
  | otherwise = foldl' (.*.) 1 [1 .. n]
いいなと思ったら応援しよう!
ありがとう