Last active
January 8, 2023 09:41
-
-
Save chezbgone/ccab2732bd219bbf1ff5838d154131d7 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| module Matmul where | |
| import Data.Vector qualified as V | |
| import Data.Vector.Mutable qualified as VM | |
-- | A dense @matrixRows × matrixCols@ matrix of elements of type @r@, stored as
-- a mutable vector of mutable row vectors in state thread @s@.
--
-- The dimension fields are not checked against the actual vector lengths; the
-- arithmetic operations in this module validate arguments using these fields
-- only.
data Matrix s r =
  MkMatrix
    { matrixRows  :: !Int
      -- ^ Number of rows.
    , matrixCols  :: !Int
      -- ^ Number of columns.
    , matrixCells :: !(VM.MVector s (VM.MVector s r))
      -- ^ Row-major storage: the outer index selects a row, the inner index a
      -- column within that row.
    }
-- | @plus dest a b@ overwrites @dest@ with the element-wise sum @a + b@.
--
-- Calls 'error' with @"dimension mismatch"@ unless all three matrices have the
-- same stored row and column counts. @dest@ may be the very same matrix as @a@
-- or @b@: each entry is read from both operands before it is written.
plus :: (VM.PrimMonad m, Num r)
     => Matrix (VM.PrimState m) r
     -> Matrix (VM.PrimState m) r
     -> Matrix (VM.PrimState m) r
     -> m ()
plus dest a b
  | aRows /= bRows || bRows /= destRows ||
    aCols /= bCols || bCols /= destCols = error "dimension mismatch"
  | otherwise =
      VM.iforM_ destCells $ \i destRow -> do
        -- Fetch row i of each operand once, rather than once per column.
        aRow <- VM.read aCells i
        bRow <- VM.read bCells i
        VM.iforM_ destRow $ \j _ -> do
          x <- VM.read aRow j
          y <- VM.read bRow j
          -- Store the evaluated sum instead of an unevaluated thunk in the
          -- boxed row vector.
          VM.write destRow j $! x + y
  where
    MkMatrix aRows aCols aCells = a
    MkMatrix bRows bCols bCells = b
    MkMatrix destRows destCols destCells = dest
-- | @times dest a b@ overwrites @dest@ with the matrix product @a * b@.
--
-- Calls 'error' with @"dimension mismatch"@ unless @a@ is @m×k@, @b@ is @k×n@
-- and @dest@ is @m×n@, according to the stored row/column counts.
--
-- Each entry is a dot product accumulated with a strict fold and written once.
-- This avoids building a chain of @(+)@ thunks per cell, which a lazy in-place
-- @modify@ would do.
--
-- NOTE: @dest@ must not share rows with @a@ or @b@. Entries of @dest@ are
-- written while other entries are still being computed from @a@ and @b@.
times :: (VM.PrimMonad m, Num r)
      => Matrix (VM.PrimState m) r
      -> Matrix (VM.PrimState m) r
      -> Matrix (VM.PrimState m) r
      -> m ()
times dest a b
  | aCols /= bRows || aRows /= destRows || bCols /= destCols = error "dimension mismatch"
  | otherwise =
      VM.iforM_ destCells $ \i destRow -> do
        aRow <- VM.read aCells i -- a_i, the i-th row of a
        VM.iforM_ destRow $ \j _ -> do
          -- One step of dest_ij = sum over k of a_ik * b_kj.
          let step acc k aik = do
                bkj <- VM.read bCells k >>= (`VM.read` j)
                pure (acc + aik * bkj)
          -- ifoldM' forces the accumulator at every step.
          total <- VM.ifoldM' step 0 aRow
          VM.write destRow j total
  where
    MkMatrix aRows aCols aCells = a
    MkMatrix bRows bCols bCells = b
    MkMatrix destRows destCols destCells = dest
| ----------- | |
| -- utils -- | |
| ----------- | |
-- | Allocate a fresh @rows × cols@ matrix with every entry set to @fill@.
--
-- The row-allocating action is run once per row, so every row is its own
-- mutable vector and writing to one row never affects another.
newMatrix :: VM.PrimMonad m => Int -> Int -> r -> m (Matrix (VM.PrimState m) r)
newMatrix rows cols fill = do
  cells <- VM.replicateM rows (VM.replicate cols fill)
  pure (MkMatrix rows cols cells)
-- | Take an immutable snapshot of a nested mutable vector.
--
-- Both the outer vector and every inner row are copied, so later mutation of
-- the input does not affect the result.
freeze2 :: VM.PrimMonad m
        => VM.MVector (VM.PrimState m) (VM.MVector (VM.PrimState m) a)
        -> m (V.Vector (V.Vector a))
freeze2 nested = do
  outerCopy <- V.freeze nested
  V.mapM V.freeze outerCopy
-- | Build a nested mutable vector from an immutable one.
--
-- Every inner row and the outer vector are copied, so mutating the result
-- never touches the input.
thaw2 :: VM.PrimMonad m
      => V.Vector (V.Vector a)
      -> m (VM.MVector (VM.PrimState m) (VM.MVector (VM.PrimState m) a))
thaw2 frozen = do
  mutableRows <- V.mapM V.thaw frozen
  V.thaw mutableRows
-- | Multiply two fixed 3×3 'Int' matrices with 'times' and print the frozen
-- product.
example :: IO ()
example = do
  -- Turn a list of rows into a mutable 3x3 Int matrix.
  let fromRows rows =
        MkMatrix @_ @Int 3 3 <$> thaw2 (V.fromList (map V.fromList rows))
  a <- fromRows [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
  b <- fromRows [[1, -2, 2], [0, 1, 2], [3, 1, -1]]
  dest <- newMatrix 3 3 0
  times dest a b
  product' <- freeze2 (matrixCells dest)
  print product'
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment