Skip to content

Instantly share code, notes, and snippets.

@chezbgone
Last active January 8, 2023 09:41
Show Gist options
  • Select an option

  • Save chezbgone/ccab2732bd219bbf1ff5838d154131d7 to your computer and use it in GitHub Desktop.

Select an option

Save chezbgone/ccab2732bd219bbf1ff5838d154131d7 to your computer and use it in GitHub Desktop.
module Matmul where
import Data.Vector qualified as V
import Data.Vector.Mutable qualified as VM
-- | A mutable matrix: a mutable vector of mutable row vectors, paired with
-- the dimensions the matrix is supposed to have.
--
-- @s@ is the state token of the underlying mutable vectors and @r@ is the
-- element type. The constructor does not check that the recorded dimensions
-- agree with the actual shape of 'matrixCells'.
data Matrix s r = MkMatrix
  { matrixRows  :: Int
    -- ^ Recorded number of rows.
  , matrixCols  :: Int
    -- ^ Recorded number of columns.
  , matrixCells :: VM.MVector s (VM.MVector s r)
    -- ^ Outer vector of rows; each element is one row of cells.
  }
-- | @plus dest a b@ stores the element-wise sum @a + b@ into @dest@.
--
-- The recorded dimensions of all three matrices must be equal; otherwise the
-- result is @error "dimension mismatch"@.
--
-- Each sum is forced before it is written, so repeated in-place use
-- (e.g. @plus acc acc x@) does not pile up unevaluated thunks in the
-- boxed destination vector.
plus :: (VM.PrimMonad m, Num r)
     => Matrix (VM.PrimState m) r
     -> Matrix (VM.PrimState m) r
     -> Matrix (VM.PrimState m) r
     -> m ()
plus dest a b
  | aRows /= bRows || bRows /= destRows ||
    aCols /= bCols || bCols /= destCols = error "dimension mismatch"
  | otherwise =
      VM.iforM_ destCells $ \i row -> do
        -- Fetch the i-th source rows once per row instead of once per cell.
        a_i <- aCells `VM.read` i
        b_i <- bCells `VM.read` i
        VM.iforM_ row $ \j _ -> do
          aVal <- a_i `VM.read` j
          bVal <- b_i `VM.read` j
          -- ($!) evaluates the sum to WHNF before storing it.
          VM.write row j $! aVal + bVal
  where
    MkMatrix aRows aCols aCells = a
    MkMatrix bRows bCols bCells = b
    MkMatrix destRows destCols destCells = dest
-- | @times dest a b@ overwrites @dest@ with the matrix product @a * b@.
--
-- The recorded dimensions must satisfy @cols a == rows b@,
-- @rows a == rows dest@ and @cols b == cols dest@; otherwise the result is
-- @error "dimension mismatch"@.
--
-- @dest@ must not share storage with @a@ or @b@: every destination row is
-- zeroed and then updated while the inputs are still being read.
--
-- Each partial sum is forced before it is written back, so a cell never
-- holds a chain of pending additions.
times :: (VM.PrimMonad m, Num r)
      => Matrix (VM.PrimState m) r
      -> Matrix (VM.PrimState m) r
      -> Matrix (VM.PrimState m) r
      -> m ()
times dest a b
  | aCols /= bRows || aRows /= destRows || bCols /= destCols = error "dimension mismatch"
  | otherwise =
      VM.iforM_ destCells $ \i dest_i -> do               -- for each row i of dest
        VM.set dest_i 0                                   -- reset every cell of row i to 0
        a_i <- aCells `VM.read` i                         -- a_i: the i-th row of a
        VM.iforM_ dest_i $ \j _ ->                        -- for each column j
          VM.iforM_ a_i $ \k a_ik -> do                   -- for each a_ik in a_i
            b_kj <- bCells `VM.read` k >>= (`VM.read` j)  -- get b_kj
            acc  <- dest_i `VM.read` j
            VM.write dest_i j $! acc + a_ik * b_kj        -- dest_ij += a_ik * b_kj (strict)
  where
    MkMatrix aRows aCols aCells = a
    MkMatrix bRows bCols bCells = b
    MkMatrix destRows destCols destCells = dest
-----------
-- utils --
-----------
-- | @newMatrix m n x@ allocates a fresh @m@-by-@n@ matrix with every cell
-- set to @x@. Each row is allocated as its own mutable vector.
--
-- Negative sizes are treated as zero. The underlying vector allocators
-- already clamp a negative length to 0, so the recorded dimensions are
-- clamped the same way to keep them consistent with the actual cells.
newMatrix :: VM.PrimMonad m => Int -> Int -> r -> m (Matrix (VM.PrimState m) r)
newMatrix m n init_value =
  MkMatrix rows cols <$> VM.replicateM rows (VM.replicate cols init_value)
  where
    rows = max 0 m
    cols = max 0 n
-- | Freeze a mutable vector of mutable vectors into an immutable vector of
-- immutable vectors. The outer vector is frozen first, then each row.
freeze2 :: VM.PrimMonad m
        => VM.MVector (VM.PrimState m) (VM.MVector (VM.PrimState m) a)
        -> m (V.Vector (V.Vector a))
freeze2 mut_mat = do
  frozenOuter <- V.freeze mut_mat
  V.forM frozenOuter V.freeze
-- | Thaw an immutable vector of immutable vectors into a mutable vector of
-- mutable vectors. Each row is thawed first, then the outer vector.
thaw2 :: VM.PrimMonad m
      => V.Vector (V.Vector a)
      -> m (VM.MVector (VM.PrimState m) (VM.MVector (VM.PrimState m) a))
thaw2 mat = do
  thawedRows <- V.forM mat V.thaw
  V.thaw thawedRows
-- | Demo: multiply two fixed 3x3 'Int' matrices into a freshly allocated
-- destination and print the frozen result
-- (@[[10,3,3],[22,3,12],[34,3,21]]@).
example :: IO ()
example = do
  a <- square3 [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
  b <- square3 [[1, -2, 2], [0, 1, 2], [3, 1, -1]]
  dest <- newMatrix 3 3 0
  times dest a b
  freeze2 (matrixCells dest) >>= print
  where
    -- Build a mutable matrix recorded as 3x3 from a list of rows.
    square3 :: [[Int]] -> IO (Matrix (VM.PrimState IO) Int)
    square3 rowLists =
      MkMatrix 3 3 <$> thaw2 (V.fromList (map V.fromList rowLists))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment