Instantly share code, notes, and snippets.
Created
November 22, 2025 18:50
-
Star
0
(0)
You must be signed in to star a gist -
Fork
0
(0)
You must be signed in to fork a gist
-
-
Save thelissimus/f3daecaeb7bbf33b465e65744ababcaf to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| {-# LANGUAGE TemplateHaskell #-} | |
| {-# LANGUAGE MagicHash #-} | |
| {-# LANGUAGE OverloadedRecordDot #-} | |
| {-# LANGUAGE UnboxedTuples #-} | |
| module RelatedPostGen (module RelatedPostGen) where | |
| import Control.DeepSeq (NFData) | |
| import Control.Monad (when) | |
| import Data.Aeson.TH | |
| import Data.Primitive (MutableByteArray (..)) | |
| import Data.Primitive.ByteArray (newByteArray) | |
| import Data.Text.Short (ShortText) | |
| import Data.Vector (Vector, indexed, (!)) | |
| import Data.Vector qualified as V | |
| import Data.Vector.Hashtables qualified as H | |
| import Data.Vector.Mutable qualified as VM | |
| import Data.Vector.Storable.Mutable (STVector) | |
| import Data.Vector.Storable.Mutable qualified as VSM | |
| import Foreign.Ptr (castPtr) | |
| import Foreign.Storable (Storable (..)) | |
| import Foreign.Storable.Tuple () | |
| import GHC.Exts | |
| import GHC.Generics (Generic) | |
| import GHC.ST (ST (ST)) | |
| import GHC.Word | |
-- | A mutable hash table for the 'ST' state thread @s@, storing both its keys
-- and its values in boxed mutable vectors ('VM.MVector').
--
-- (@H.PrimState (ST s)@ reduces to @s@, so the state token is written
-- directly.)
type HashTable s k v = H.Dictionary s VM.MVector k VM.MVector v

-- | Tag -> positions (in the indexed input vector) of the posts carrying that
-- tag, as filled in by 'populateTagMap'.
type TagMap s = HashTable s ShortText (STVector s Word32)
-- | One input post. The JSON instances come from 'deriveJSON' with
-- 'defaultOptions', so the JSON keys are exactly the field names
-- (@"_id"@, @"tags"@, @"title"@).
data Post = MkPost
  { _id :: {-# UNPACK #-} !ShortText
  , tags :: {-# UNPACK #-} !(Vector ShortText)
  , title :: {-# UNPACK #-} !ShortText
  }
  deriving stock (Generic, Show)
  deriving anyclass (NFData)

deriveJSON defaultOptions ''Post
-- | Output record for one post: its id and tags, copied from the input
-- 'Post', plus the posts chosen as related to it. JSON keys are the field
-- names ('defaultOptions').
data RelatedPosts = MkRelatedPosts
  { _id :: {-# UNPACK #-} !ShortText
  , tags :: {-# UNPACK #-} !(Vector ShortText)
  , related :: {-# UNPACK #-} !(Vector Post)
  }
  deriving stock (Generic, Show)
  deriving anyclass (NFData)

deriveJSON defaultOptions ''RelatedPosts
-- | One ranking slot: a post index and that post's shared-tag count.
data TopEntry = TopEntry
  { ix :: {-# UNPACK #-} !Word32
  , count :: {-# UNPACK #-} !Word8
  }

-- | Memory layout: the 'Word32' index at offset 0, the 'Word8' count directly
-- after it, then padding up to the next multiple of the alignment.
instance Storable TopEntry where
  -- 'sizeOf' doubles as the stride between array elements ('peekElemOff',
  -- 'pokeElemOff', 'peekArray', ...). Leaving it at the raw 5-byte payload
  -- would place every element after the first at a misaligned Word32
  -- address, so round it up to a multiple of 'alignment'.
  sizeOf _ = ((payload + align - 1) `div` align) * align
    where
      payload = sizeOf (undefined :: Word32) + sizeOf (undefined :: Word8)
      align = alignment (undefined :: Word32)
  alignment _ = alignment (undefined :: Word32)
  peek ptr = do
    i <- peek (castPtr ptr)
    c <- peekByteOff ptr (sizeOf (undefined :: Word32))
    pure (TopEntry i c)
  poke ptr (TopEntry i c) = do
    poke (castPtr ptr) i
    pokeByteOff ptr (sizeOf (undefined :: Word32)) c
-- | 'limitTopN' ranking slots stored as two parallel mutable byte arrays.
-- Slot @i@ is read and written as a unit by 'readTopEntry' and
-- 'writeTopEntry'.
data TopN s = TopN
  { ix :: !(MutableByteArray s)
    -- ^ post index of each slot, one 'Word32' per slot
  , count :: !(MutableByteArray s)
    -- ^ shared-tag count of each slot, one 'Word8' per slot
  }
-- | Allocate a 'TopN' with 'limitTopN' slots, every index and count set to 0.
newTopN :: ST s (TopN s)
newTopN = do
  -- size the arrays from the element types instead of hard-coding 4 and 1
  ixBA <- newByteArray (limitTopN * sizeOf (undefined :: Word32))
  countBA <- newByteArray (limitTopN * sizeOf (undefined :: Word8))
  let topN = TopN ixBA countBA
  -- the fresh arrays are not zeroed; reuse resetTopN rather than duplicating
  -- its zeroing loops here
  resetTopN topN
  pure topN
-- | Zero every slot of a 'TopN' (index and count) so the buffer can be
-- reused for the next post.
resetTopN :: TopN s -> ST s ()
resetTopN (TopN ixBA countBA) = mapM_ clearSlot [0 .. limitTopN - 1]
  where
    clearSlot slot = do
      writeWord32BA ixBA slot 0
      writeWord8BA countBA slot 0
-- | Read one slot of a 'TopN' back as a 'TopEntry'. The slot index is not
-- bounds-checked; it must lie in @[0, limitTopN)@.
readTopEntry :: TopN s -> Int -> ST s TopEntry
readTopEntry (TopN ixBA countBA) slot =
  TopEntry <$> readWord32BA ixBA slot <*> readWord8BA countBA slot
-- | Overwrite one slot of a 'TopN' with a post index and its shared-tag
-- count. Both values are forced before anything is written. The slot index
-- is not bounds-checked.
writeTopEntry :: TopN s -> Int -> Word32 -> Word8 -> ST s ()
writeTopEntry (TopN ixBA countBA) slot !postIx !sharedCount =
  writeWord32BA ixBA slot postIx *> writeWord8BA countBA slot sharedCount
-- | Number of related posts produced for each post. This is also the number
-- of slots in every 'TopN'.
limitTopN :: Int
limitTopN = 5
{-# INLINE limitTopN #-}
-- | For every post, in input order, build its 'RelatedPosts': the
-- 'limitTopN' posts ranked highest by number of tags shared with it.
computeRelatedPosts :: Vector Post -> ST s (Vector RelatedPosts)
computeRelatedPosts posts = do
  let !numbered = indexed posts
  -- tag -> positions of the posts carrying that tag
  !tagMap :: TagMap s <- H.initialize 0
  populateTagMap tagMap numbered >> buildRelatedPosts tagMap numbered
{-# INLINE computeRelatedPosts #-}
-- | Record, for every tag, the position of each post that carries it.
-- Positions are appended in post order; a post that lists the same tag twice
-- is recorded twice.
populateTagMap :: TagMap s -> Vector (Int, Post) -> ST s ()
populateTagMap tagMap postsIdx =
  V.forM_ postsIdx \(!postIx, MkPost{tags}) -> do
    let !entry = fromIntegral postIx
    V.forM_ tags \tag -> appendTo tag entry
  where
    -- Append one post position to a tag's vector, creating it on first use.
    appendTo tag entry =
      H.lookup tagMap tag >>= \case
        Nothing -> H.insert tagMap tag =<< VSM.replicate 1 entry
        Just known -> do
          let !oldLen = VSM.length known
          -- NOTE: growing by exactly one element per append; geometric growth
          -- would need a separate capacity, which the vector does not expose.
          !grown <- VSM.grow known 1
          VSM.write grown oldLen entry
          H.insert tagMap tag grown
{-# INLINE populateTagMap #-}
-- | Produce one 'RelatedPosts' per input post, in input order.
--
-- A single shared-tag counter vector and a single 'TopN' buffer are reused
-- for every post; both are cleared after each post is handled.
buildRelatedPosts :: TagMap s -> Vector (Int, Post) -> ST s (Vector RelatedPosts)
buildRelatedPosts tagMap postsIdx = do
  -- counts.[p] = tags post p shares with the post currently being handled
  !counts :: STVector s Word8 <- VSM.replicate (V.length postsIdx) 0
  !topN <- newTopN
  let relateOne (!self, MkPost{_id, tags}) = do
        collectSharedTags counts tagMap tags
        -- zero the post's own count so it is not ranked against itself
        VSM.write counts self 0
        rankTopN topN counts
        !related <- buildRelated postsIdx topN
        -- clear both buffers before the next post
        resetTopN topN
        VSM.set counts 0
        pure MkRelatedPosts{_id, tags, related}
  V.mapM relateOne postsIdx
{-# INLINE buildRelatedPosts #-}
-- | For each of the given tags, add one to the counter of every post that
-- carries that tag.
--
-- Uses 'H.lookup'', so every tag is expected to be present in @tagMap@.
-- Counters are 'Word8', so more than 255 increments for one post would wrap.
collectSharedTags :: STVector s Word8 -> TagMap s -> Vector ShortText -> ST s ()
collectSharedTags sharedTags tagMap tags = V.mapM_ bumpTag tags
  where
    bumpTag !tag = do
      !postIxs <- H.lookup' tagMap tag
      VSM.forM_ postIxs \postIx ->
        VSM.modify sharedTags (+ 1) (fromIntegral postIx)
{-# INLINE collectSharedTags #-}
-- | Scan the shared-tag counters and keep the 'limitTopN' posts with the
-- highest counts in @topN@, sorted by count, highest first.
--
-- A candidate is admitted only when its count is strictly greater than the
-- count in the last slot. It is inserted after any existing entries with an
-- equal count, so among ties the lower post index wins. The admission
-- threshold starts at 0, so this expects @topN@ to be zeroed first
-- ('newTopN' / 'resetTopN'); posts with a count of 0 are never admitted.
rankTopN :: TopN s -> STVector s Word8 -> ST s ()
rankTopN topN sharedTags = go 0 0
  where
    len = VSM.length sharedTags

    -- @minTags@ mirrors the count in the last slot of @topN@. It is threaded
    -- through the loop as an accumulator rather than being kept in a one-byte
    -- mutable array allocated on every call.
    go !i !minTags
      | i >= len = pure ()
      | otherwise = do
          !shared <- VSM.read sharedTags i
          if shared > minTags
            then do
              !slot <- shiftDown (limitTopN - 2) shared
              writeTopEntry topN (slot + 1) (fromIntegral i) shared
              weakest <- readTopEntry topN (limitTopN - 1)
              go (i + 1) weakest.count
            else go (i + 1) minTags

    -- Walk from slot @j@ towards slot 0, moving every entry whose count is
    -- strictly below @shared@ one slot to the right. Returns the first slot
    -- left in place (-1 if every slot moved); the new entry goes just after it.
    shiftDown !j !shared
      | j < 0 = pure j
      | otherwise = do
          entry <- readTopEntry topN j
          if shared > entry.count
            then do
              writeTopEntry topN (j + 1) entry.ix entry.count
              shiftDown (j - 1) shared
            else pure j
{-# INLINE rankTopN #-}
-- | Resolve the 'limitTopN' slots of @topN@, in slot order, to the posts
-- whose indices they hold.
--
-- NOTE(review): every slot is used, including any that 'rankTopN' never
-- filled. After 'resetTopN' such slots hold index 0, so they resolve to the
-- first post (possibly the post being processed). Confirm this is intended.
buildRelated :: Vector (Int, Post) -> TopN s -> ST s (Vector Post)
buildRelated posts topN =
  V.generateM limitTopN \slot -> do
    entry <- readTopEntry topN slot
    pure (snd (posts ! fromIntegral entry.ix))
{-# INLINE buildRelated #-}
-- | Read the 'Word8' at element index @i@ of a mutable byte array, via the
-- @readWord8Array#@ primop (no bounds check).
readWord8BA :: MutableByteArray s -> Int -> ST s Word8
readWord8BA (MutableByteArray arr#) (I# off#) = ST step
  where
    step s0 =
      case readWord8Array# arr# off# s0 of
        (# s1, x# #) -> (# s1, W8# x# #)
{-# INLINE readWord8BA #-}
-- | Store a 'Word8' at element index @i@ of a mutable byte array, via the
-- @writeWord8Array#@ primop (no bounds check).
writeWord8BA :: MutableByteArray s -> Int -> Word8 -> ST s ()
writeWord8BA (MutableByteArray arr#) (I# off#) (W8# x#) = ST step
  where
    step s0 =
      case writeWord8Array# arr# off# x# s0 of
        s1 -> (# s1, () #)
{-# INLINE writeWord8BA #-}
-- | Read the 'Word32' at element index @i@ (a multiple of 4 bytes) of a
-- mutable byte array, via the @readWord32Array#@ primop (no bounds check).
readWord32BA :: MutableByteArray s -> Int -> ST s Word32
readWord32BA (MutableByteArray arr#) (I# off#) = ST step
  where
    step s0 =
      case readWord32Array# arr# off# s0 of
        (# s1, x# #) -> (# s1, W32# x# #)
{-# INLINE readWord32BA #-}
-- | Store a 'Word32' at element index @i@ (a multiple of 4 bytes) of a
-- mutable byte array, via the @writeWord32Array#@ primop (no bounds check).
writeWord32BA :: MutableByteArray s -> Int -> Word32 -> ST s ()
writeWord32BA (MutableByteArray arr#) (I# off#) (W32# x#) = ST step
  where
    step s0 =
      case writeWord32Array# arr# off# x# s0 of
        s1 -> (# s1, () #)
{-# INLINE writeWord32BA #-}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment