Fast polynomial evaluation

This commit is contained in:
Abastro 2025-03-30 00:03:39 +09:00
parent e150bafc86
commit eff690198a
4 changed files with 143 additions and 6 deletions

View file

@ -5,6 +5,7 @@ import Poly
import System.Random
import System.TimeIt
import Control.Monad
import PolyFast
main :: IO ()
main = do
@ -16,13 +17,21 @@ main = do
putStrLn $ "Naive f * g: " <> show (f * g)
putStrLn $ "Karatsuba f * g: " <> show (normalize $ karatsubaMult f g)
putStrLn ""
experimentFor 250
experimentFor 500
experimentFor 1000
karatsubaFor 2000
karatsubaFor 4000
-- experimentFor 250
-- experimentFor 500
-- experimentFor 1000
-- karatsubaFor 2000
-- karatsubaFor 4000
fastKaratsubaFor 512
fastKaratsubaFor 1024
fastKaratsubaFor 2048
fastKaratsubaFor 4096
fastKaratsubaFor 8192
fastKaratsubaFor 16384
fastKaratsubaFor 32768
fastKaratsubaFor 65536
where
experimentFor n = do
setStdGen $ mkStdGen 10
@ -45,3 +54,13 @@ main = do
putStrLn "Karatsuba:"
_ <- timeIt $ evaluate (karatsubaMult f g)
putStrLn "Finished"
-- Time one run of 'karatsuba' on two random polynomials with @n@ entries each.
-- The global generator is reseeded with the same fixed seed on every call, so
-- repeated runs of a given size work on the same inputs.
fastKaratsubaFor n = do
  setStdGen (mkStdGen 10)
  putStrLn ("Size " <> show n)
  f :: Poly Int <- randomPoly n
  g :: Poly Int <- randomPoly n
  putStrLn "Fast Karatsuba:"
  -- Only the timing printed by 'timeIt' is of interest; the product is discarded.
  void . timeIt . evaluate $
    karatsuba 0 (+) (-) (*) (unwrapPoly f) (unwrapPoly g)
  putStrLn "Finished"
  where
    -- A polynomial built from @size@ random values drawn with @randomRIO (-100, 100)@.
    randomPoly size = makePoly <$> replicateM size (randomRIO (-100, 100))

View file

@ -27,6 +27,7 @@ library
import: warnings, deps
exposed-modules:
Poly
PolyFast
hs-source-dirs: src
default-language: GHC2021

View file

@ -22,6 +22,9 @@ vecZipPad0With f xs ys = V.generate (max (V.length xs) (V.length ys)) $
-- | A polynomial stored as a vector of values of type @a@.
--
-- NOTE(review): the meaning and order of the entries (presumably coefficients
-- by ascending degree) is not visible in this hunk — confirm against the rest
-- of the module.
newtype Poly a = Poly (V.Vector a)
deriving (Eq)
-- | Expose the vector stored inside a 'Poly'.
unwrapPoly :: Poly a -> V.Vector a
unwrapPoly poly = case poly of
  Poly entries -> entries
-- | Build a 'Poly' from a list; the list elements become the vector entries in
-- the same order.
makePoly :: (V.Unbox a) => [a] -> Poly a
makePoly entries = Poly (V.fromList entries)

114
src/PolyFast.hs Normal file
View file

@ -0,0 +1,114 @@
module PolyFast where
import Control.Monad
import Control.Monad.ST
import Data.Bits
import Data.List
import Data.Vector.Generic qualified as G
import Data.Vector.Generic.Mutable qualified as MG
import GHC.Base
-- | Combine two vectors of possibly different lengths entry by entry.
--
-- The result is as long as the longer input. Positions that exist in both
-- inputs hold @add x y@; the remaining positions are copied unchanged from the
-- longer input.
plusPoly ::
  (G.Vector v a) =>
  -- | Combining function, applied wherever both inputs have an entry.
  (a -> a -> a) ->
  v a ->
  v a ->
  v a
plusPoly add xs ys =
  G.create
    ( do
        out <- MG.unsafeNew longLen
        -- Shared prefix: both inputs contribute.
        forM_ [0 .. shortLen - 1] $ \i ->
          MG.unsafeWrite out i (add (G.unsafeIndex xs i) (G.unsafeIndex ys i))
        -- Tail: only the longer input has entries here (empty when the lengths match).
        G.unsafeCopy
          (MG.unsafeSlice shortLen tailLen out)
          (G.unsafeSlice shortLen tailLen longer)
        pure out
    )
  where
    lenXs = G.length xs
    lenYs = G.length ys
    shortLen = min lenXs lenYs
    longLen = max lenXs lenYs
    tailLen = longLen - shortLen
    longer
      | lenXs <= lenYs = ys
      | otherwise = xs
{-# INLINEABLE plusPoly #-}
-- | Size cutoff for 'karatsuba': when either operand has at most this many
-- entries, the product is computed directly by 'convolution' instead of
-- splitting further.
karatsubaThreshold :: Int
karatsubaThreshold = 32
-- | Multiply two coefficient vectors using Karatsuba's divide-and-conquer scheme.
--
-- Each operand is split at index @m@ (half of the shorter length, rounded up)
-- into a low and a high part. Three recursive products are formed — low*low,
-- high*high and (low+high)*(low+high) — and recombined as
--
-- > low*low + ((low+high)*(low+high) - (low*low + high*high)) * t^m + high*high * t^(2m)
--
-- When either operand has at most 'karatsubaThreshold' entries, the product is
-- delegated to 'convolution'. In the recursive case the result has
-- @length xs + length ys - 1@ entries.
karatsuba ::
  (G.Vector v a) =>
  -- | Value used wherever a partial product has no entry (the additive zero).
  a ->
  -- | Addition.
  (a -> a -> a) ->
  -- | Subtraction.
  (a -> a -> a) ->
  -- | Multiplication.
  (a -> a -> a) ->
  v a ->
  v a ->
  v a
karatsuba zer add sub mul = go
  where
    conv = inline convolution zer add mul

    -- Entry @i@ of @v@, reading as 'zer' outside the vector's bounds.
    coeffAt v i
      | i >= 0 && i < G.length v = G.unsafeIndex v i
      | otherwise = zer

    go xs ys
      | min lenXs lenYs <= karatsubaThreshold = conv xs ys
      | otherwise = runST $ do
          out <- MG.unsafeNew lenOut
          forM_ [0 .. lenOut - 1] $ \k -> do
            let shifted = k - half
                low = coeffAt lowProd k
                -- Middle term: the product of the sums minus both outer
                -- products, all shifted up by @half@ positions.
                middle =
                  sub
                    (coeffAt sumProd shifted)
                    (add (coeffAt lowProd shifted) (coeffAt highProd shifted))
                high = coeffAt highProd (k - 2 * half)
            MG.unsafeWrite out k (add (add low middle) high)
          G.unsafeFreeze out
      where
        lenXs = G.length xs
        lenYs = G.length ys
        lenOut = lenXs + lenYs - 1
        -- Split point: half of the shorter operand, rounded up.
        half = (min lenXs lenYs + 1) `shiftR` 1
        (xsLow, xsHigh) = G.splitAt half xs
        (ysLow, ysHigh) = G.splitAt half ys
        lowProd = go xsLow ysLow
        highProd = go xsHigh ysHigh
        sumProd = go (plusPoly add xsLow xsHigh) (plusPoly add ysLow ysHigh)
{-# INLINEABLE karatsuba #-}
-- | Direct (quadratic) product of two coefficient vectors.
--
-- Entry @k@ of the result is the sum, starting from @zer@ and accumulated with
-- @add@, of @mul (xs ! i) (ys ! (k - i))@ over every @i@ for which both indices
-- are in range. The result has @length xs + length ys - 1@ entries, or is empty
-- when either input is empty.
convolution ::
  (G.Vector v a) =>
  -- | Starting value of every sum (the additive zero).
  a ->
  -- | Addition.
  (a -> a -> a) ->
  -- | Multiplication.
  (a -> a -> a) ->
  v a ->
  v a ->
  v a
-- NOTE(review): the vectors are taken through a lambda rather than as further
-- left-hand-side parameters, presumably so that applying the function to just
-- the three ring arguments (as 'karatsuba' does via 'inline') counts as a
-- saturated call — confirm before changing the arity.
convolution zer add mul = \xs ys ->
  let nXs = G.length xs
      nYs = G.length ys
      entry k = foldl' step zer [firstI .. lastI]
        where
          -- Keep both @i@ and @k - i@ inside their vectors.
          firstI = max 0 (k - nYs + 1)
          lastI = min k (nXs - 1)
          step acc i = add acc (mul (G.unsafeIndex xs i) (G.unsafeIndex ys (k - i)))
   in if nXs > 0 && nYs > 0
        then G.generate (nXs + nYs - 1) entry
        else G.empty
{-# INLINEABLE convolution #-}