This repository has been archived on 2025-04-04. You can view files and clone it, but cannot push or open issues or pull requests.
mathematical-algorithms/src/Poly.hs

96 lines
3.2 KiB
Haskell
Raw Normal View History

2025-03-23 16:56:30 +09:00
module Poly where
import Data.List
import Data.Maybe
2025-03-23 17:35:20 +09:00
import Data.Vector.Unboxed qualified as V
-- | Combine two vectors element-wise with @f@, treating every position past
-- the end of the shorter vector as 0. The result is as long as the longer
-- input.
vecZipPad0With :: (V.Unbox a, Num a) => (a -> a -> a) -> V.Vector a -> V.Vector a -> V.Vector a
vecZipPad0With f xs ys = V.generate len combine
  where
    len = V.length xs `max` V.length ys
    -- Out-of-range lookups read as 0, which pads the shorter vector.
    at v i = fromMaybe 0 (v V.!? i)
    combine i = at xs i `f` at ys i
-- | A polynomial stored as its coefficient vector: the element at index @i@
-- is the coefficient of @X^i@ (lowest power first). Trailing zero
-- coefficients may be present (see 'normalize'), and the derived 'Eq'
-- compares the coefficient vectors exactly, so the same polynomial with a
-- different amount of zero padding is not considered equal.
--
-- >>> Poly (V.fromList [1 :: Int .. 5])
-- 1 X^0 + 2 X^1 + 3 X^2 + 4 X^3 + 5 X^4
--
-- >>> Poly (V.fromList [1 :: Int, 2]) * Poly (V.fromList [3 :: Int, 4, 5])
-- 3 X^0 + 10 X^1 + 13 X^2 + 10 X^3
--
-- >>> Poly (V.fromList [1 :: Int, 2]) * Poly (V.fromList [])
-- 0 X^0
newtype Poly a = Poly (V.Vector a) deriving (Eq)
-- | Build a polynomial from a list of coefficients, lowest power of @X@ first.
makePoly :: (V.Unbox a) => [a] -> Poly a
makePoly coeffs = Poly (V.fromList coeffs)
-- | Index of the last stored coefficient. This is the true degree only when
-- the top coefficient is nonzero ('normalize' first if unsure); an empty
-- coefficient vector yields -1.
degree :: (V.Unbox a) => Poly a -> Int
degree (Poly coeffs) = pred (V.length coeffs)
-- | Multiply a polynomial by @X^n@ by prepending @n@ zero coefficients.
shiftUp :: (V.Unbox a, Num a) => Int -> Poly a -> Poly a
shiftUp n (Poly coeffs) = Poly (V.replicate n 0 V.++ coeffs)
-- | Quotient of division by @X^n@: drops the @n@ lowest coefficients and
-- shifts the remaining ones down by @n@ powers of @X@.
shiftDown :: (V.Unbox a) => Int -> Poly a -> Poly a
shiftDown n (Poly coeffs) = Poly (snd (V.splitAt n coeffs))
-- | Remainder of division by @X^n@: keeps only the @n@ lowest coefficients.
remXn :: (V.Unbox a) => Int -> Poly a -> Poly a
remXn n (Poly coeffs) = Poly (fst (V.splitAt n coeffs))
-- | Normalize a polynomial by stripping the zero coefficients at the end of
-- the vector (the zero coefficients of the highest powers of @X@). A
-- polynomial whose coefficients are all zero becomes an empty vector.
--
-- >>> normalize $ Poly (V.fromList [1 :: Int, 0, 0])
-- 1 X^0
--
-- >>> normalize $ Poly (V.fromList [1 :: Int, 2, 3, 0])
-- 1 X^0 + 2 X^1 + 3 X^2
normalize :: (Eq a, Num a, V.Unbox a) => Poly a -> Poly a
normalize (Poly coeffs) = Poly (V.take (lastNonZero + 1) coeffs)
  where
    -- Index of the highest nonzero coefficient, or -1 if there is none.
    lastNonZero = scanDown (V.length coeffs - 1)
    scanDown i
      | i < 0 = -1
      | coeffs V.! i /= 0 = i
      | otherwise = scanDown (i - 1)
-- | Polynomial arithmetic on coefficient vectors.
--
-- Addition, subtraction and negation act coefficient-wise, padding the shorter
-- operand with zeros. Multiplication is the classical schoolbook product: one
-- shifted, scaled copy of the right operand per coefficient of the left
-- operand, all summed. 'abs' and 'signum' are not meaningful for polynomials
-- and raise an error.
instance (Num a, V.Unbox a) => Num (Poly a) where
  (+) :: Poly a -> Poly a -> Poly a
  (+) (Poly xs) (Poly ys) = Poly (vecZipPad0With (+) xs ys)

  (-) :: Poly a -> Poly a -> Poly a
  (-) (Poly xs) (Poly ys) = Poly (vecZipPad0With (-) xs ys)

  (*) :: Poly a -> Poly a -> Poly a
  (*) (Poly xs) (Poly ys) = sum partials
    where
      -- For each coefficient x_k of the left operand: x_k * X^k * ys.
      partials =
        [ Poly (V.map (x *) (V.replicate k 0 <> ys))
        | (k, x) <- zip [0 ..] (V.toList xs)
        ]

  negate :: Poly a -> Poly a
  negate (Poly xs) = Poly (V.map negate xs)

  abs :: Poly a -> Poly a
  abs = error "abs: invalid on poly"

  signum :: Poly a -> Poly a
  signum = error "signum: invalid on poly"

  -- A constant polynomial with a single coefficient.
  fromInteger :: Integer -> Poly a
  fromInteger n = Poly (V.singleton (fromInteger n))
-- | Render as a sum of terms @c X^i@ in increasing powers of @X@, including
-- zero coefficients. An empty coefficient vector renders as the empty string.
instance (V.Unbox a, Show a) => Show (Poly a) where
  show (Poly coeffs) = intercalate " + " (V.ifoldr addTerm [] coeffs)
    where
      addTerm i c terms = (show c <> " X^" <> show i) : terms
-- | Multiply two polynomials with Karatsuba's algorithm.
--
-- Both operands are split at half of a power-of-two bound @d@ on their
-- lengths, @f = f0 + X^(d/2) f1@. The product is assembled from three
-- half-size products, @f0 g0@, @f1 g1@ and @(f0 + f1) (g0 + g1)@, recursing
-- down to @d = 1@, where the classical '(*)' is used.
--
-- The recursion leaves extra zero coefficients at the top, so the result is
-- trimmed to @length a + length b - 1@ coefficients. This is the length of the
-- classical product @a * b@; for exact coefficient types (e.g. 'Int') the
-- result then equals @a * b@ under the derived 'Eq'. If either operand is
-- empty, the classical product is returned directly.
--
-- >>> karatsubaMult (Poly (V.fromList [1 :: Int, 2])) (Poly (V.fromList [3 :: Int, 4, 5]))
-- 3 X^0 + 10 X^1 + 13 X^2 + 10 X^3
--
-- >>> karatsubaMult (Poly (V.fromList [1 :: Int, 2])) (Poly (V.fromList []))
-- 0 X^0
karatsubaMult :: (Num a, V.Unbox a) => Poly a -> Poly a -> Poly a
karatsubaMult a@(Poly fa) b@(Poly gb)
  -- The product is zero; defer to '(*)' so the representation matches it.
  | V.null fa || V.null gb = a * b
  | otherwise = remXn (V.length fa + V.length gb - 1) (atLog degBound a b)
  where
    -- Smallest power of two strictly greater than both degrees, i.e. a
    -- power-of-two bound on the number of coefficients of either operand.
    -- 'until' is total, unlike the 'fromJust . find' search it replaces.
    degBound = until (> max (degree a) (degree b)) (* 2) 1
    -- degBnd: power-of-two bound on the number of coefficients of f and g.
    atLog degBnd f g
      | degBnd <= 1 = f * g
      | otherwise =
          shiftUp degBnd prod1
            + shiftUp half (prodAdd - prod0 - prod1)
            + prod0
      where
        half = degBnd `div` 2
        f1 = shiftDown half f
        f0 = remXn half f
        g1 = shiftDown half g
        g0 = remXn half g
        prod0 = atLog half f0 g0
        prod1 = atLog half f1 g1
        -- (f0 + f1)(g0 + g1) - f0 g0 - f1 g1 = f0 g1 + f1 g0, the middle term.
        prodAdd = atLog half (f0 + f1) (g0 + g1)