Skip to content

Instantly share code, notes, and snippets.

@LeeeeT
Last active October 29, 2025 17:43
Show Gist options
  • Select an option

  • Save LeeeeT/495dd1cffe343228cbc166111de03d64 to your computer and use it in GitHub Desktop.

Select an option

Save LeeeeT/495dd1cffe343228cbc166111de03d64 to your computer and use it in GitHub Desktop.
(Broken)
from dataclasses import dataclass
from itertools import count
type Name = int
name = count()
def fresh() -> Name:
return next(name)
type Term = Var | Lam | App | Dup | Sup | Bri | Ann
@dataclass(frozen=True)
class Var:
name: Name
@dataclass(frozen=True)
class Lam:
name: Name
body: Term
@dataclass(frozen=True)
class App:
fun: Term
arg: Term
@dataclass(frozen=True)
class Dup:
name: Name
body: Term
@dataclass(frozen=True)
class Sup:
fst: Term
snd: Term
@dataclass(frozen=True)
class Bri:
name: Name
body: Term
@dataclass(frozen=True)
class Ann:
typ: Term
val: Term
def show(term: Term) -> str:
match term:
case Var(name):
return f"{name}"
case Lam(name, body):
return f"λ{name}.{show(body)}"
case App(fun, arg):
return f"({show(fun)} {show(arg)})"
case Dup(name, body):
return f"&{name}.{show(body)}"
case Sup(fst, snd):
return f"{{{show(fst)} {show(snd)}}}"
case Bri(name, body):
return f"θ{name}.{show(body)}"
case Ann(typ, val):
return f"[{show(typ)} {show(val)}]"
def nf(term: Term, ctx: dict[Name, Term] | None = None) -> Term:
if ctx is None:
ctx = {}
match whnf(term, ctx):
case Var(name):
return Var(name)
case Lam(name, body):
return Lam(name, nf(body, ctx))
case App(fun, arg):
return App(nf(fun, ctx), nf(arg, ctx))
case Dup(name, body):
return Dup(name, nf(body, ctx))
case Sup(fst, snd):
return Sup(nf(fst, ctx), nf(snd, ctx))
case Bri(name, body):
return Bri(name, nf(body, ctx))
case Ann(typ, val):
return Ann(nf(typ, ctx), nf(val, ctx))
def whnf(term: Term, ctx: dict[Name, Term] | None = None) -> Term:
    """Reduce ``term`` to weak head normal form.

    ``ctx`` maps variable names to the terms substituted for them. It is
    created when omitted and is mutated in place. Beta steps and the
    commutation rules below record new bindings in it, and bound variables
    are resolved lazily by looking them up here. Several rules allocate new
    names with ``fresh()``; the order of those calls determines which names
    appear in the result.

    Returns the reduced term. When the head of an App, Dup or Ann node is
    stuck, the node is rebuilt around its reduced head.
    """
    if ctx is None:
        ctx = {}
    match term:
        # Bound variable: replace it with its substitution and keep reducing.
        # Unbound variables fall through to the final case unchanged.
        case Var(name) if name in ctx:
            return whnf(ctx[name], ctx)
        case App(fun, arg):
            match whnf(fun, ctx):
                # Beta step: bind the parameter to the (unreduced) argument.
                case Lam(name, body):
                    ctx[name] = arg
                    return whnf(body, ctx)
                # Applying a superposition distributes over both components:
                # {(fst &x.arg) (snd x)}, so the argument is duplicated.
                case Sup(fst, snd):
                    x = fresh()
                    return Sup(App(fst, Dup(x, arg)), App(snd, Var(x)))
                # Applying a θ-binder: its variable becomes λx.y, and the
                # result is θy.(body [arg x]).
                case Bri(name, body):
                    x = fresh()
                    y = fresh()
                    ctx[name] = Lam(x, Var(y))
                    return Bri(y, App(body, Ann(arg, Var(x))))
                # Stuck head: rebuild the application around it.
                case fun:
                    return App(fun, arg)
        case Dup(name, body):
            match whnf(body, ctx):
                # Duplicating a lambda: the result is the first copy
                # λx.(&y.lbody), the Dup's name receives the second copy λz.y,
                # and the lambda's own parameter becomes {x z} so each copy
                # sees its own argument.
                case Lam(lname, lbody):
                    x = fresh()
                    y = fresh()
                    z = fresh()
                    ctx[name] = Lam(z, Var(y))
                    ctx[lname] = Sup(Var(x), Var(z))
                    return Lam(x, Dup(y, lbody))
                # Duplicating a superposition: the first component is the
                # result and the second is bound to the Dup's name.
                case Sup(fst, snd):
                    ctx[name] = snd
                    return whnf(fst, ctx)
                # Duplicating a θ-binder: same shape as the lambda rule, with
                # Bri in place of Lam.
                case Bri(bname, bbody):
                    x = fresh()
                    y = fresh()
                    z = fresh()
                    ctx[name] = Bri(z, Var(y))
                    ctx[bname] = Sup(Var(x), Var(z))
                    return Bri(x, Dup(y, bbody))
                # Stuck body: keep the duplication around the reduced body.
                case body:
                    return Dup(name, body)
        case Ann(typ, val):
            match whnf(typ, ctx):
                # Lambda type: its variable becomes θx.y, and the result is
                # λy.[body (val x)].
                case Lam(name, body):
                    x = fresh()
                    y = fresh()
                    ctx[name] = Bri(x, Var(y))
                    return Lam(y, Ann(body, App(val, Var(x))))
                # Superposed type: annotate a duplicate of the value with each
                # component: {[fst &x.val] [snd x]}.
                case Sup(fst, snd):
                    x = fresh()
                    return Sup(Ann(fst, Dup(x, val)), Ann(snd, Var(x)))
                # θ-type: bind its variable to the annotated value itself,
                # then continue with the type's body.
                case Bri(name, body):
                    ctx[name] = val
                    return whnf(body, ctx)
                # Stuck type: keep the annotation around the reduced type.
                case typ:
                    return Ann(typ, val)
        # Lam, Sup, Bri and unbound Var are already in weak head normal form.
        case term:
            return term
def check(term: Term, ctx: dict[Name, Term] | None = None) -> bool:
    """Report whether the annotations inside ``term`` are consistent.

    The term is reduced with ``whnf`` and walked recursively, left child
    before right child, stopping at the first failure. Variables always
    pass. For an annotation ``[t1 v]``, the value ``v`` is reduced. If it is
    itself an annotation ``[t2 v2]``, the types ``t1`` and ``t2`` must be
    ``equal`` and ``[t2 v2]`` is then checked. Otherwise the reduced value is
    checked on its own.

    NOTE(review): the outer annotation's type ``t1`` is never checked
    itself; only its equality with a nested annotation's type is. Confirm
    this is intended.

    ``ctx`` is the substitution map shared with ``whnf``/``equal``; it is
    created when omitted and mutated in place.
    """
    if ctx is None:
        ctx = {}
    head = whnf(term, ctx)
    if isinstance(head, Var):
        return True
    if isinstance(head, (Lam, Dup, Bri)):
        return check(head.body, ctx)
    if isinstance(head, App):
        return check(head.fun, ctx) and check(head.arg, ctx)
    if isinstance(head, Sup):
        return check(head.fst, ctx) and check(head.snd, ctx)
    if isinstance(head, Ann):
        inner = whnf(head.val, ctx)
        if isinstance(inner, Ann):
            return equal(head.typ, inner.typ, ctx) and check(Ann(inner.typ, inner.val), ctx)
        return check(inner, ctx)
    # Not a Term node: nothing matched, mirror an unmatched `match`.
    return None
def equal(first: Term, second: Term, ctx: dict[Name, Term] | None = None) -> bool:
if ctx is None:
ctx = {}
# print(f"Checking: {show(first)} == {show(second)}")
match whnf(first, ctx), whnf(second, ctx):
case Var(name1), Var(name2):
return name1 == name2
case Lam(name1, body1), Lam(name2, body2):
x = fresh()
return equal(App(Lam(name1, body1), Var(x)), App(Lam(name2, body2), Var(x)), ctx)
case App(fun1, arg1), App(fun2, arg2):
return equal(fun1, fun2, ctx) and equal(arg1, arg2, ctx)
case Dup(name1, body1), Dup(name2, body2):
x = fresh()
return equal(Dup(name1, Sup(body1, Var(x))), Dup(name2, Sup(body2, Var(x))), ctx)
case Sup(fst1, snd1), Sup(fst2, snd2):
return equal(fst1, fst2, ctx) and equal(snd1, snd2, ctx)
case Bri(name1, body1), Bri(name2, body2):
x = fresh()
return equal(Ann(Bri(name1, body1), Var(x)), Ann(Bri(name2, body2), Var(x)), ctx)
case Ann(typ1, val1), Ann(typ2, val2):
return equal(typ1, typ2, ctx) and equal(val1, val2, ctx)
case _:
return False
# Sample types and terms. Names are drawn from fresh() in the same
# left-to-right order in which they occur in each term.

# Any = θx.x
x = fresh()
Any = Bri(x, Var(x))

# Fun = λa λb θf λx [b (f [a x])]
a = fresh()
b = fresh()
f = fresh()
x = fresh()
Fun = Lam(a, Lam(b, Bri(f, Lam(x, Ann(Var(b), App(Var(f), Ann(Var(a), Var(x))))))))

# Unit = θself λP λa {(self P {a: P}): P}
self = fresh()
P = fresh()
a = fresh()
P1 = fresh()
P2 = fresh()
Unit = Bri(
    self,
    Lam(
        P,
        Lam(
            a,
            Ann(
                Dup(P1, Var(P)),
                App(App(Var(self), Dup(P2, Var(P1))), Ann(Var(P2), Var(a))),
            ),
        ),
    ),
)

# unit = λP λa a
P = fresh()
a = fresh()
unit = Lam(P, Lam(a, Var(a)))

# Annotate the unit value with the Unit type, then normalise and check it.
main = Ann(Unit, unit)

print("Orig: ", show(main))
ctx = {}
print("NF: ", show(nf(main, ctx)))
print("Valid: ", check(main))
print("Context:", {name: show(term) for name, term in ctx.items()})
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment