Skip to content

Instantly share code, notes, and snippets.

@thelissimus
Created November 22, 2025 18:50
Show Gist options
  • Select an option

  • Save thelissimus/f3daecaeb7bbf33b465e65744ababcaf to your computer and use it in GitHub Desktop.

Select an option

Save thelissimus/f3daecaeb7bbf33b465e65744ababcaf to your computer and use it in GitHub Desktop.
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE UnboxedTuples #-}
module RelatedPostGen (module RelatedPostGen) where
import Control.DeepSeq (NFData)
import Control.Monad (when)
import Data.Aeson.TH
import Data.Primitive (MutableByteArray (..))
import Data.Primitive.ByteArray (newByteArray)
import Data.Text.Short (ShortText)
import Data.Vector (Vector, indexed, (!))
import Data.Vector qualified as V
import Data.Vector.Hashtables qualified as H
import Data.Vector.Mutable qualified as VM
import Data.Vector.Storable.Mutable (STVector)
import Data.Vector.Storable.Mutable qualified as VSM
import Foreign.Ptr (castPtr)
import Foreign.Storable (Storable (..))
import Foreign.Storable.Tuple ()
import GHC.Exts
import GHC.Generics (Generic)
import GHC.ST (ST (ST))
import GHC.Word
-- | Mutable hash table living in @'ST' s@, with both keys and values held in
-- boxed mutable vectors ("Data.Vector.Mutable").
type HashTable s k v = H.Dictionary (H.PrimState (ST s)) VM.MVector k VM.MVector v
-- | Map from a tag to a storable vector of 'Word32' values; 'populateTagMap'
-- fills each vector with the indices of the posts carrying that tag.
type TagMap s = HashTable s ShortText (STVector s Word32)
-- | A single input post: identifier, tag list and title. All fields are
-- strict and unpacked.
data Post = MkPost
  { _id :: {-# UNPACK #-} !ShortText
  -- ^ Post identifier.
  , tags :: {-# UNPACK #-} !(Vector ShortText)
  -- ^ Tags attached to the post.
  , title :: {-# UNPACK #-} !ShortText
  -- ^ Post title.
  }
  deriving stock (Generic, Show)
  deriving anyclass (NFData)
-- JSON instances are generated by Template Haskell with aeson's default options.
$(deriveJSON defaultOptions ''Post)
-- | Output record for one post: its identifier and tags (copied from the
-- source 'Post' by 'buildRelatedPosts') together with the posts selected as
-- related to it.
data RelatedPosts = MkRelatedPosts
  { _id :: {-# UNPACK #-} !ShortText
  -- ^ Identifier of the post this record describes.
  , tags :: {-# UNPACK #-} !(Vector ShortText)
  -- ^ Tags of the post this record describes.
  , related :: {-# UNPACK #-} !(Vector Post)
  -- ^ Posts produced by 'buildRelated' for this post.
  }
  deriving stock (Generic, Show)
  deriving anyclass (NFData)
-- JSON instances are generated by Template Haskell with aeson's default options.
$(deriveJSON defaultOptions ''RelatedPosts)
-- | One slot of the top-N ranking: the index of a candidate post and the
-- number of tags it shares with the post currently being processed.
data TopEntry = TopEntry
  { ix :: {-# UNPACK #-} !Word32
  , count :: {-# UNPACK #-} !Word8
  }

-- | Flat layout: the 'Word32' index at offset 0, the 'Word8' count directly
-- after it, then padding.
--
-- 'sizeOf' is rounded up to a multiple of 'alignment'. The element stride used
-- by 'peekElemOff' / 'pokeElemOff' (and by storable vectors) is 'sizeOf', so an
-- unpadded size of 5 would place the 'Word32' field of every element after the
-- first at a misaligned address.
instance Storable TopEntry where
  sizeOf _ = padded
    where
      align = alignment (undefined :: Word32)
      raw = sizeOf (undefined :: Word32) + sizeOf (undefined :: Word8)
      -- round the raw payload size up to the next multiple of the alignment
      padded = ((raw + align - 1) `div` align) * align
  alignment _ = alignment (undefined :: Word32)
  peek ptr = do
    entryIx <- peek (castPtr ptr)
    entryCount <- peekByteOff ptr (sizeOf (undefined :: Word32))
    pure (TopEntry entryIx entryCount)
  poke ptr (TopEntry entryIx entryCount) = do
    poke (castPtr ptr) entryIx
    pokeByteOff ptr (sizeOf (undefined :: Word32)) entryCount
-- | Mutable top-N buffer stored as two parallel byte arrays
-- (structure-of-arrays counterpart of 'TopEntry'). 'newTopN' allocates
-- 'limitTopN' slots in each.
data TopN s = TopN
  { ix :: !(MutableByteArray s)
  -- ^ Post indices, accessed as 'Word32' elements.
  , count :: !(MutableByteArray s)
  -- ^ Shared-tag counts, accessed as 'Word8' elements.
  }
-- | Allocate a fresh top-N buffer with 'limitTopN' slots, every slot zeroed.
--
-- The index array is sized from the width of 'Word32' rather than a literal,
-- and zeroing is delegated to 'resetTopN' so the two cannot drift apart.
newTopN :: ST s (TopN s)
newTopN = do
  ixBA <- newByteArray (limitTopN * sizeOf (undefined :: Word32))
  -- one byte per Word8 count
  countBA <- newByteArray limitTopN
  let topN = TopN ixBA countBA
  -- newByteArray leaves the contents unspecified, so clear before first use
  resetTopN topN
  pure topN
-- | Zero every slot (index and count) of the top-N buffer.
resetTopN :: TopN s -> ST s ()
resetTopN topN = mapM_ clearSlot [0 .. limitTopN - 1]
  where
    -- a cleared slot has index 0 and count 0
    clearSlot slot = writeTopEntry topN slot 0 0
-- | Fetch slot @i@ of the top-N buffer as a 'TopEntry'.
-- No bounds check is performed on @i@.
readTopEntry :: TopN s -> Int -> ST s TopEntry
readTopEntry (TopN ids counts) i = do
  !postIx <- readWord32BA ids i
  !shared <- readWord8BA counts i
  pure (TopEntry postIx shared)
-- | Store a post index and its shared-tag count into slot @i@ of the top-N
-- buffer. No bounds check is performed on @i@.
writeTopEntry :: TopN s -> Int -> Word32 -> Word8 -> ST s ()
writeTopEntry (TopN ids counts) i !postIx !shared =
  writeWord32BA ids i postIx *> writeWord8BA counts i shared
-- | Number of slots in a 'TopN' buffer, and therefore the number of entries
-- 'buildRelated' reads back for each post.
limitTopN :: Int
limitTopN = 5
{-# INLINE limitTopN #-}
-- | Entry point: index the posts by tag, then rank related posts for each one.
-- The result has one 'RelatedPosts' per input post, in input order.
computeRelatedPosts :: Vector Post -> ST s (Vector RelatedPosts)
computeRelatedPosts posts = do
  !byTag :: TagMap s <- H.initialize 0
  -- pair every post with its position once; both passes share the result
  let !numbered = V.indexed posts
  populateTagMap byTag numbered *> buildRelatedPosts byTag numbered
{-# INLINE computeRelatedPosts #-}
-- | Record, for every tag, the indices of all posts that carry it.
populateTagMap :: TagMap s -> Vector (Int, Post) -> ST s ()
populateTagMap tagMap postsIdx =
  V.forM_ postsIdx \(!postIx, MkPost{tags}) -> do
    let !entry = (fromIntegral postIx :: Word32)
    V.forM_ tags \tag -> do
      existing <- H.lookup tagMap tag
      case existing of
        -- first post seen with this tag: start a one-element bucket
        Nothing -> VSM.replicate 1 entry >>= H.insert tagMap tag
        Just bucket -> do
          -- grows by exactly one element per insert; could use exponential
          -- growth if the vector exposed its capacity
          !grown <- VSM.grow bucket 1
          VSM.write grown (VSM.length bucket) entry
          H.insert tagMap tag grown
{-# INLINE populateTagMap #-}
-- | Produce the 'RelatedPosts' record for every post, reusing one shared-tag
-- tally vector and one top-N buffer across all posts.
buildRelatedPosts :: TagMap s -> Vector (Int, Post) -> ST s (Vector RelatedPosts)
buildRelatedPosts tagMap postsIdx = do
  !tally :: STVector s Word8 <- VSM.replicate (V.length postsIdx) 0
  !best :: TopN s <- newTopN
  let relatedFor (!self, MkPost{_id, tags}) = do
        collectSharedTags tally tagMap tags
        -- a post must never be listed as related to itself
        VSM.write tally self 0
        rankTopN best tally
        !related <- buildRelated postsIdx best
        -- clear the scratch state before the next post
        resetTopN best
        VSM.set tally 0
        pure MkRelatedPosts{_id, tags, related}
  V.forM postsIdx relatedFor
{-# INLINE buildRelatedPosts #-}
-- | For each given tag, bump the tally of every post index stored under that
-- tag. Uses the non-'Maybe' lookup, so each tag is expected to already be a
-- key of the map.
collectSharedTags :: STVector s Word8 -> TagMap s -> Vector ShortText -> ST s ()
collectSharedTags sharedTags tagMap tags =
  V.forM_ tags \(!tag) -> do
    !owners <- H.lookup' tagMap tag
    VSM.forM_ owners \owner ->
      VSM.modify sharedTags (+ 1) (fromIntegral owner)
{-# INLINE collectSharedTags #-}
-- | Scan the shared-tag tallies and keep the highest-scoring post indices in
-- the top-N buffer, ordered by descending count (insertion into a fixed-size
-- sorted array). A post only displaces existing entries when its count is
-- strictly greater, so on ties the earlier index stays ahead.
rankTopN :: TopN s -> STVector s Word8 -> ST s ()
rankTopN topN sharedTags = do
  -- one-byte cell caching the count held in the last slot of the buffer
  !threshold <- newByteArray 1
  writeWord8BA threshold 0 (0 :: Word8)
  let !total = VSM.length sharedTags
      scan !postIx
        | postIx >= total = pure ()
        | otherwise = do
            !shared <- VSM.read sharedTags postIx
            !floorCount <- readWord8BA threshold 0
            when (shared > floorCount) do
              -- open a gap, then drop the candidate just below the last
              -- entry that was not shifted
              !below <- shiftSmaller (limitTopN - 2) shared
              writeTopEntry topN (below + 1) (fromIntegral postIx) shared
              lastEntry <- readTopEntry topN (limitTopN - 1)
              writeWord8BA threshold 0 lastEntry.count
            scan (postIx + 1)
  scan 0
  where
    -- Walk from @slot@ towards slot 0, moving every entry whose count is
    -- smaller than @shared@ one position further back. Returns the first slot
    -- that was left in place, or -1 when all of them were moved.
    shiftSmaller !slot !shared
      | slot < 0 = pure slot
      | otherwise = do
          entry <- readTopEntry topN slot
          if shared > entry.count
            then do
              writeTopEntry topN (slot + 1) entry.ix entry.count
              shiftSmaller (slot - 1) shared
            else pure slot
{-# INLINE rankTopN #-}
-- | Materialise the posts referenced by the top-N buffer, best match first.
--
-- Only slots that were actually filled are emitted. A slot with count 0 was
-- never written by the ranking pass and still holds index 0, so copying it out
-- would report post 0 (possibly the post itself) as related even though no tag
-- is shared. The result therefore holds at most 'limitTopN' posts.
buildRelated :: Vector (Int, Post) -> TopN s -> ST s (Vector Post)
buildRelated posts topN = do
  !filled <- countFilled 0
  !res <- VM.unsafeNew filled
  let go !slot
        | slot >= filled = V.unsafeFreeze res
        | otherwise = do
            entry <- readTopEntry topN slot
            VM.write res slot (snd (posts ! fromIntegral entry.ix))
            go (slot + 1)
  go 0
  where
    -- Entries are ordered by descending count, so the filled slots form a
    -- prefix that ends at the first zero count.
    countFilled !slot
      | slot >= limitTopN = pure slot
      | otherwise = do
          entry <- readTopEntry topN slot
          if entry.count == 0
            then pure slot
            else countFilled (slot + 1)
{-# INLINE buildRelated #-}
-- | Read the 'Word8' at element index @i@ of a mutable byte array via the
-- raw primop; no bounds check is performed.
readWord8BA :: MutableByteArray s -> Int -> ST s Word8
readWord8BA (MutableByteArray arr#) (I# idx#) =
  ST
    ( \st# ->
        case readWord8Array# arr# idx# st# of
          (# st'#, byte# #) -> (# st'#, W8# byte# #)
    )
{-# INLINE readWord8BA #-}
-- | Write a 'Word8' at element index @i@ of a mutable byte array via the raw
-- primop; no bounds check is performed.
writeWord8BA :: MutableByteArray s -> Int -> Word8 -> ST s ()
writeWord8BA (MutableByteArray arr#) (I# idx#) (W8# byte#) =
  ST
    ( \st# ->
        case writeWord8Array# arr# idx# byte# st# of
          st'# -> (# st'#, () #)
    )
{-# INLINE writeWord8BA #-}
-- | Read the 'Word32' at element index @i@ of a mutable byte array via the
-- raw primop. The index counts 32-bit elements, not bytes; no bounds check is
-- performed.
readWord32BA :: MutableByteArray s -> Int -> ST s Word32
readWord32BA (MutableByteArray arr#) (I# idx#) =
  ST
    ( \st# ->
        case readWord32Array# arr# idx# st# of
          (# st'#, word# #) -> (# st'#, W32# word# #)
    )
{-# INLINE readWord32BA #-}
-- | Write a 'Word32' at element index @i@ of a mutable byte array via the raw
-- primop. The index counts 32-bit elements, not bytes; no bounds check is
-- performed.
writeWord32BA :: MutableByteArray s -> Int -> Word32 -> ST s ()
writeWord32BA (MutableByteArray arr#) (I# idx#) (W32# word#) =
  ST
    ( \st# ->
        case writeWord32Array# arr# idx# word# st# of
          st'# -> (# st'#, () #)
    )
{-# INLINE writeWord32BA #-}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment