From eff690198a70a60f33a00455ef6c51ac58415677 Mon Sep 17 00:00:00 2001
From: Abastro
Date: Sun, 30 Mar 2025 00:03:39 +0900
Subject: [PATCH] Fast polynomial evaluation

---
NOTE(review): this patch was recovered from a copy whose line breaks and
leading whitespace had been collapsed.  Indentation and blank context lines
are a best-effort reconstruction, and the blob ids on the `index` lines
predate the edits made to the added code -- regenerate the patch with
`git format-patch` from the real tree before applying it.

 app/Main.hs                   |  27 ++++++--
 mathematical-algorithms.cabal |   1 +
 src/Poly.hs                   |   4 +
 src/PolyFast.hs               | 134 ++++++++++++++++++++++++++++++++++
 4 files changed, 160 insertions(+), 6 deletions(-)
 create mode 100644 src/PolyFast.hs

diff --git a/app/Main.hs b/app/Main.hs
index a25aded..2055a86 100644
--- a/app/Main.hs
+++ b/app/Main.hs
@@ -5,6 +5,7 @@ import Poly
 import System.Random
 import System.TimeIt
 import Control.Monad
+import PolyFast
 
 main :: IO ()
 main = do
@@ -16,13 +17,14 @@ main = do
 
   putStrLn $ "Naive f * g: " <> show (f * g)
   putStrLn $ "Karatsuba f * g: " <> show (normalize $ karatsubaMult f g)
-  putStrLn ""
 
-  experimentFor 250
-  experimentFor 500
-  experimentFor 1000
-  karatsubaFor 2000
-  karatsubaFor 4000
+  -- experimentFor 250
+  -- experimentFor 500
+  -- experimentFor 1000
+  -- karatsubaFor 2000
+  -- karatsubaFor 4000
+
+  mapM_ fastKaratsubaFor [512, 1024, 2048, 4096, 8192, 16384, 32768, 65536]
   where
     experimentFor n = do
       setStdGen $ mkStdGen 10
@@ -45,3 +47,16 @@ main = do
       putStrLn "Karatsuba:"
       _ <- timeIt $ evaluate (karatsubaMult f g)
       putStrLn "Finished"
+
+    -- Times one 'karatsuba' product of two random polynomials built from @n@
+    -- coefficients each.
+    fastKaratsubaFor n = do
+      -- Reseed the global generator with a constant before drawing inputs.
+      setStdGen (mkStdGen 10)
+      putStrLn ("Size " <> show n)
+      let randomCoeffs = replicateM n (randomRIO (-100, 100))
+      f :: Poly Int <- makePoly <$> randomCoeffs
+      g :: Poly Int <- makePoly <$> randomCoeffs
+      putStrLn "Fast Karatsuba:"
+      _ <- timeIt . evaluate $ karatsuba 0 (+) (-) (*) (unwrapPoly f) (unwrapPoly g)
+      putStrLn "Finished"
diff --git a/mathematical-algorithms.cabal b/mathematical-algorithms.cabal
index e3474cf..171dc3f 100644
--- a/mathematical-algorithms.cabal
+++ b/mathematical-algorithms.cabal
@@ -27,6 +27,7 @@ library
     import:           warnings, deps
     exposed-modules:
         Poly
+        PolyFast
     hs-source-dirs:   src
     default-language: GHC2021
 
diff --git a/src/Poly.hs b/src/Poly.hs
index 4214554..d7f060c 100644
--- a/src/Poly.hs
+++ b/src/Poly.hs
@@ -22,6 +22,10 @@ vecZipPad0With f xs ys = V.generate (max (V.length xs) (V.length ys)) $
 newtype Poly a = Poly (V.Vector a)
   deriving (Eq)
 
+-- | The vector wrapped by the 'Poly' constructor.
+unwrapPoly :: Poly a -> V.Vector a
+unwrapPoly (Poly v) = v
+
 makePoly :: (V.Unbox a) => [a] -> Poly a
 makePoly = Poly . V.fromList
 
diff --git a/src/PolyFast.hs b/src/PolyFast.hs
new file mode 100644
index 0000000..e7ee280
--- /dev/null
+++ b/src/PolyFast.hs
@@ -0,0 +1,134 @@
+-- | Polynomial addition and multiplication on bare coefficient vectors.
+--
+-- Each function receives the coefficient operations it needs as explicit
+-- arguments and works on any "Data.Vector.Generic" vector type.
+module PolyFast where
+
+import Control.Monad (forM_)
+import Control.Monad.ST (runST)
+import Data.Bits (shiftR)
+import Data.List (foldl')
+import Data.Vector.Generic qualified as G
+import Data.Vector.Generic.Mutable qualified as MG
+import GHC.Base (inline)
+
+-- | Combines two vectors position by position with @add@.
+--
+-- The result is as long as the longer input.  Positions present in both
+-- inputs hold @add x y@; the rest is copied unchanged from the longer input.
+plusPoly ::
+  (G.Vector v a) =>
+  (a -> a -> a) ->
+  v a ->
+  v a ->
+  v a
+plusPoly add xs ys = runST $ do
+  out <- MG.unsafeNew longLen
+  -- Shared prefix: both inputs contribute.
+  forM_ [0 .. shortLen - 1] $ \i ->
+    MG.unsafeWrite out i (G.unsafeIndex xs i `add` G.unsafeIndex ys i)
+  -- Tail: only the longer input has entries here, so copy them over.
+  G.unsafeCopy
+    (MG.unsafeSlice shortLen tailLen out)
+    (G.unsafeSlice shortLen tailLen longer)
+  G.unsafeFreeze out
+  where
+    nXs = G.length xs
+    nYs = G.length ys
+    shortLen = min nXs nYs
+    longLen = max nXs nYs
+    tailLen = longLen - shortLen
+    longer
+      | nXs <= nYs = ys
+      | otherwise = xs
+{-# INLINEABLE plusPoly #-}
+
+-- | Input length at or below which 'karatsuba' stops splitting and calls
+-- 'convolution' directly.
+karatsubaThreshold :: Int
+karatsubaThreshold = 32
+
+-- | Product of two coefficient vectors by Karatsuba-style splitting.
+--
+-- Arguments, in order: the zero coefficient, coefficient addition,
+-- subtraction and multiplication, then the two factors.  Once either factor
+-- has at most 'karatsubaThreshold' entries the product is delegated to
+-- 'convolution'.
+karatsuba ::
+  (G.Vector v a) =>
+  a ->
+  (a -> a -> a) ->
+  (a -> a -> a) ->
+  (a -> a -> a) ->
+  v a ->
+  v a ->
+  v a
+karatsuba zer add sub mul = go
+  where
+    conv = inline convolution zer add mul
+
+    -- Entry @i@ of @v@, or @zer@ when @i@ lies outside the vector.
+    coeff v i
+      | i >= 0 && i < G.length v = G.unsafeIndex v i
+      | otherwise = zer
+
+    go xs ys
+      | lenXs <= karatsubaThreshold || lenYs <= karatsubaThreshold =
+          conv xs ys
+      | otherwise = runST $ do
+          zs <- MG.unsafeNew lenZs
+          -- Entry k is lowProd[k] + cross[k - m] + highProd[k - 2 * m], where
+          -- cross = sumProd - (lowProd + highProd); out-of-range reads give zer.
+          forM_ [0 .. lenZs - 1] $ \k -> do
+            let shifted = k - m
+                low = coeff lowProd k
+                cross =
+                  coeff sumProd shifted
+                    `sub` (coeff lowProd shifted `add` coeff highProd shifted)
+                high = coeff highProd (k - 2 * m)
+            MG.unsafeWrite zs k (low `add` cross `add` high)
+          G.unsafeFreeze zs
+      where
+        lenXs = G.length xs
+        lenYs = G.length ys
+        lenZs = lenXs + lenYs - 1
+
+        -- Split point: half the shorter length, rounded up.
+        m = (min lenXs lenYs + 1) `shiftR` 1
+
+        xsLo = G.take m xs
+        xsHi = G.drop m xs
+        ysLo = G.take m ys
+        ysHi = G.drop m ys
+
+        lowProd = go xsLo ysLo
+        highProd = go xsHi ysHi
+        sumProd = go (plusPoly add xsLo xsHi) (plusPoly add ysLo ysHi)
+{-# INLINEABLE karatsuba #-}
+
+-- | Direct product of two coefficient vectors.
+--
+-- Entry @k@ of the result folds @xs[i] * ys[k - i]@ onto @zer@ with @add@,
+-- over every @i@ that keeps both indices in range.  The result is empty when
+-- either input is empty, and has @length xs + length ys - 1@ entries
+-- otherwise.
+convolution ::
+  (G.Vector v a) =>
+  a ->
+  (a -> a -> a) ->
+  (a -> a -> a) ->
+  v a ->
+  v a ->
+  v a
+convolution zer add mul = \xs ys ->
+  let nXs = G.length xs
+      nYs = G.length ys
+      entry k =
+        foldl'
+          (\acc i -> acc `add` (G.unsafeIndex xs i `mul` G.unsafeIndex ys (k - i)))
+          zer
+          [max 0 (k - nYs + 1) .. min k (nXs - 1)]
+   in if G.null xs || G.null ys
+        then G.empty
+        else G.generate (nXs + nYs - 1) entry
+{-# INLINEABLE convolution #-}