@chezbgone
Created March 1, 2022 03:05
arbitrarily nested lists with overloaded lists
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeFamilies #-}
-- for OverloadedLists
{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE FlexibleInstances #-}
import GHC.Exts
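
data Nat = Z | S Nat
  deriving Show

-- A Tensor of depth n: ListOf is a flat list (depth Z), and each Tensor
-- constructor adds one level of nesting (depth S n).
data Tensor (n :: Nat) a where
  ListOf :: [a] -> Tensor Z a
  Tensor :: [Tensor n a] -> Tensor (S n) a

deriving instance Show a => Show (Tensor n a)

-- At depth Z, a list literal holds plain elements...
instance IsList (Tensor Z a) where
  type Item (Tensor Z a) = a
  fromList = ListOf
  toList (ListOf as) = as

-- ...and at depth S n it holds tensors one level shallower, so the type
-- annotation decides how many levels of brackets the literal must have.
instance IsList (Tensor (S n) a) where
  type Item (Tensor (S n) a) = Tensor n a
  fromList = Tensor
  toList (Tensor ts) = ts

-- Three levels of brackets: Tensor (S (S Z)) -> Tensor (S Z) -> Tensor Z.
vals :: Num a => Tensor (S (S Z)) a
vals = [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]

-- Sum every element, recursing once per nesting level.
sumT :: Num a => Tensor n a -> a
sumT (ListOf as) = sum as
sumT (Tensor tens) = sum $ map sumT tens

main :: IO ()
main = do
  print $ sumT vals  -- 78 (a defaults to Integer)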
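
To make the desugaring concrete, here is a sketch (an illustration, not part of the original gist) that restates the types and adds a hypothetical `valsExplicit`, the same value written with the `Tensor` and `ListOf` constructors directly. It also derives `Eq`, which the gist does not, purely so the overloaded literal and the explicit form can be compared.

```haskell
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeFamilies #-}

import GHC.Exts (IsList (..))

data Nat = Z | S Nat

data Tensor (n :: Nat) a where
  ListOf :: [a] -> Tensor Z a
  Tensor :: [Tensor n a] -> Tensor (S n) a

deriving instance Show a => Show (Tensor n a)
-- Eq is an addition (the gist does not derive it), used only to compare forms.
deriving instance Eq a => Eq (Tensor n a)

instance IsList (Tensor Z a) where
  type Item (Tensor Z a) = a
  fromList = ListOf
  toList (ListOf as) = as

instance IsList (Tensor (S n) a) where
  type Item (Tensor (S n) a) = Tensor n a
  fromList = Tensor
  toList (Tensor ts) = ts

sumT :: Num a => Tensor n a -> a
sumT (ListOf as) = sum as
sumT (Tensor tens) = sum (map sumT tens)

-- The gist's literal: each bracket level becomes a call to the IsList
-- methods (fromListN, which defaults to fromList), with the instance
-- chosen from the expected type at that level.
vals :: Num a => Tensor (S (S Z)) a
vals = [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]

-- Hypothetical name: the same value built with the constructors directly,
-- i.e. what those IsList calls reduce to for this type.
valsExplicit :: Num a => Tensor (S (S Z)) a
valsExplicit =
  Tensor
    [ Tensor [ListOf [1, 2, 3], ListOf [4, 5, 6]]
    , Tensor [ListOf [7, 8, 9], ListOf [10, 11, 12]]
    ]

main :: IO ()
main = do
  let lit = vals :: Tensor (S (S Z)) Integer
      explicit = valsExplicit :: Tensor (S (S Z)) Integer
  print (lit == explicit)             -- True
  print (map sumT (toList explicit))  -- [21,57]
```

Because no constructor mixes depths, a literal with the wrong number of bracket levels for its annotated type is rejected at compile time rather than at run time.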