Skip to content

Instantly share code, notes, and snippets.

@LeeeeT
Last active July 7, 2025 14:30
Show Gist options
  • Select an option

  • Save LeeeeT/0c7790e7592b2e40c351af195757dbd9 to your computer and use it in GitHub Desktop.

Select an option

Save LeeeeT/0c7790e7592b2e40c351af195757dbd9 to your computer and use it in GitHub Desktop.
HOAS CoC
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any
type Term[T] = Var[T] | Typ[T] | Lam[T] | Fun[T] | App[T]
@dataclass(frozen=True)
class Var[T]:
var: T
@dataclass(frozen=True)
class Typ[T]:
pass
@dataclass(frozen=True)
class Lam[T]:
dom: Term[T]
bod: Callable[[Term[T]], Term[T]]
@dataclass(frozen=True)
class Fun[T]:
dom: Term[T]
cod: Callable[[Term[T]], Term[T]]
@dataclass(frozen=True)
class App[T]:
fun: Term[T]
arg: Term[T]
def show(term: Term[int], depth: int = 0) -> str:
match term:
case Var(var):
return f"{var}"
case Typ():
return "*"
case Lam(dom, bod):
return f"λ({depth} : {show(dom, depth)}). {show(bod(Var(depth)), depth + 1)}"
case Fun(dom, cod):
return f"Π({depth} : {show(dom, depth)}). {show(cod(Var(depth)), depth + 1)}"
case App(fun, arg):
fun = f"({show(fun, depth)})" if isinstance(fun, Lam) else show(fun, depth)
arg = f"({show(arg, depth)})" if not isinstance(arg, Var) else show(arg, depth)
return f"{fun} {arg}"
def nf[T](term: Term[T]) -> Term[T]:
match whnf(term):
case Lam(dom, bod):
return Lam(nf(dom), lambda arg: nf(bod(arg)))
case Fun(dom, cod):
return Fun(nf(dom), lambda arg: nf(cod(arg)))
case App(fun, arg):
return App(nf(fun), nf(arg))
case _:
return term
def whnf[T](term: Term[T]) -> Term[T]:
match term:
case App(fun, arg):
match whnf(fun):
case Lam(_, bod):
return whnf(bod(arg))
case _:
return term
case _:
return term
def equal(first: Term[object], second: Term[object]) -> bool:
match whnf(first), whnf(second):
case Var(v1), Var(v2):
return v1 is v2
case Typ(), Typ():
return True
case Lam(dom1, bod1), Lam(dom2, bod2):
v = Var(object())
return equal(dom1, dom2) and equal(bod1(v), bod2(v))
case Fun(dom1, cod1), Fun(dom2, cod2):
v = Var(object())
return equal(dom1, dom2) and equal(cod1(v), cod2(v))
case App(fun1, arg1), App(fun2, arg2):
return equal(fun1, fun2) and equal(arg1, arg2)
case _:
return False
def infer(term: Term[Any], ctx: dict[Term[Any], Term[Any]] = {}) -> Term[Any]:
match term:
case Var(var):
return ctx[Var(var)]
case Typ():
return Typ()
case Lam(dom, bod):
return Fun(dom, lambda arg: infer(bod(arg), ctx | {arg: dom}))
case Fun(dom, _):
return Typ()
case App(fun, arg):
match whnf(infer(fun, ctx)):
case Fun(dom, cod) if equal(infer(arg, ctx), dom):
return cod(arg)
case _:
raise TypeError
case _:
raise TypeError
# Church encodings of a few data types, written as HOAS terms.

# Unit = Π(R : *). R → R, inhabited by the polymorphic identity.
Unit: Term[Any] = Fun(Typ(), lambda r: Fun(r, lambda _u: r))
unit: Term[Any] = Lam(Typ(), lambda r: Lam(r, lambda u: u))

# Bool = Π(R : *). R → R → R; `true` selects the first branch, `false` the second.
Bool: Term[Any] = Fun(Typ(), lambda r: Fun(r, lambda _t: Fun(r, lambda _f: r)))
true: Term[Any] = Lam(Typ(), lambda r: Lam(r, lambda t: Lam(r, lambda _f: t)))
false: Term[Any] = Lam(Typ(), lambda r: Lam(r, lambda _t: Lam(r, lambda f: f)))

# Negation eliminates `b` at result type Bool with the branches swapped.
negate: Term[Any] = Lam(
    Bool,
    lambda b: App(App(App(b, Bool), false), true),
)

# Normalize `negate true`, then normalize its inferred type. Expected output:
#   λ(0 : *). λ(1 : 0). λ(2 : 0). 2    (i.e. `false`)
#   Π(0 : *). Π(1 : 0). Π(2 : 0). 0    (i.e. `Bool`)
main = App(negate, true)
print(show(nf(main)))
print(show(nf(infer(main))))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment