Skip to content

Instantly share code, notes, and snippets.

@LeeeeT
Created March 1, 2026 11:43
Show Gist options
  • Select an option

  • Save LeeeeT/fe041f3dc3fcb197f88df1ee46b1fb75 to your computer and use it in GitHub Desktop.

Select an option

Save LeeeeT/fe041f3dc3fcb197f88df1ee46b1fb75 to your computer and use it in GitHub Desktop.
k-SIC + readback + collapse
import itertools
from collections.abc import Callable
from dataclasses import dataclass
type Name = int
name = itertools.count()
type Label = int
label = itertools.count()
type Binder = int
binder = itertools.count()
type Pos = Var | Nil | Lam | Sup | Atm
type Neg = Sub | Era | App | Dup | Red
@dataclass(frozen=True)
class Var:
    """Positive end of a wire; ``nam`` is the wire name shared with its Sub end."""

    nam: Name


@dataclass(frozen=True)
class Sub:
    """Negative end of a wire; ``nam`` is the wire name shared with its Var end."""

    nam: Name


@dataclass(frozen=True)
class Nil:
    """Positive nullary node (rendered as ``+_``)."""


@dataclass(frozen=True)
class Era:
    """Negative eraser node (rendered as ``-_``)."""


@dataclass(frozen=True)
class Lam:
    """Lambda: ``bnd`` names the negative binder port, ``bod`` the positive body."""

    bnd: Name
    bod: Name


@dataclass(frozen=True)
class App:
    """Application: ``arg`` is the positive argument, ``ret`` the negative result port."""

    arg: Name
    ret: Name


@dataclass(frozen=True)
class Dup:
    """Duplicator labelled ``dpl`` with two negative output ports."""

    dpl: Label
    dp0: Name
    dp1: Name


@dataclass(frozen=True)
class Sup:
    """Superposition labelled ``spl`` with two positive branches."""

    spl: Label
    sp0: Name
    sp1: Name


@dataclass(frozen=True)
class Atm:
    """Opaque positive atom holding an already-rendered readback string."""

    val: str


@dataclass(frozen=True)
class Red:
    """Readback node: ``ctx`` wraps a rendered subterm; the result is sent to ``ret``."""

    ctx: Callable[[str], str]
    ret: Name
type Rdx = tuple[Name, Name]
@dataclass(frozen=True)
class Net:
book: set[Rdx]
vars: dict[Name, Pos]
subs: dict[Name, Neg]
def empty_net() -> Net:
    """Return a fresh net with no redexes and no nodes."""
    return Net(book=set(), vars={}, subs={})
def _fresh_pos(net: Net, node: Pos) -> Name:
    """Store positive *node* in ``net.vars`` under a fresh name and return that name."""
    key = next(name)
    net.vars[key] = node
    return key


def _fresh_neg(net: Net, node: Neg) -> Name:
    """Store negative *node* in ``net.subs`` under a fresh name and return that name."""
    key = next(name)
    net.subs[key] = node
    return key


def var(net: Net, nam: Name) -> Name:
    """Allocate the positive end of wire *nam*."""
    return _fresh_pos(net, Var(nam))


def sub(net: Net, nam: Name) -> Name:
    """Allocate the negative end of wire *nam*."""
    return _fresh_neg(net, Sub(nam))


def nil(net: Net) -> Name:
    """Allocate a positive Nil node."""
    return _fresh_pos(net, Nil())


def era(net: Net) -> Name:
    """Allocate a negative eraser node."""
    return _fresh_neg(net, Era())


def lam(net: Net, bnd: Name, bod: Name) -> Name:
    """Allocate a lambda with binder port *bnd* and body *bod*."""
    return _fresh_pos(net, Lam(bnd, bod))


def app(net: Net, arg: Name, ret: Name) -> Name:
    """Allocate an application with argument *arg* and result port *ret*."""
    return _fresh_neg(net, App(arg, ret))


def dup(net: Net, dpl: Label, dp0: Name, dp1: Name) -> Name:
    """Allocate a duplicator labelled *dpl* with outputs *dp0* and *dp1*."""
    return _fresh_neg(net, Dup(dpl, dp0, dp1))


def sup(net: Net, spl: Label, sp0: Name, sp1: Name) -> Name:
    """Allocate a superposition labelled *spl* with branches *sp0* and *sp1*."""
    return _fresh_pos(net, Sup(spl, sp0, sp1))


def atm(net: Net, val: str) -> Name:
    """Allocate an atom carrying the string *val*."""
    return _fresh_pos(net, Atm(val))


def red(net: Net, ctx: Callable[[str], str], ret: Name) -> Name:
    """Allocate a readback node with context *ctx* and result port *ret*."""
    return _fresh_neg(net, Red(ctx, ret))
def wire(net: Net) -> tuple[Name, Name]:
    """Create a fresh wire and return its (positive end, negative end) names.

    Both ends reference the same freshly drawn wire name.
    """
    shared = next(name)
    pos_end = var(net, shared)
    neg_end = sub(net, shared)
    return pos_end, neg_end
def show_pos(net: Net, pos: Name) -> str:
    """Render the positive node named *pos* as text.

    Wire ends whose wire name has a positive node stored in ``net.vars`` are
    followed transparently. Unresolved ends print as ``+<wire name>``. Atoms
    print their raw value.
    """
    node = net.vars[pos]
    if isinstance(node, Var):
        if node.nam in net.vars:
            return show_pos(net, node.nam)
        return f"+{node.nam}"
    if isinstance(node, Nil):
        return "+_"
    if isinstance(node, Lam):
        return f"+({show_neg(net, node.bnd)} {show_pos(net, node.bod)})"
    if isinstance(node, Sup):
        return f"+{node.spl}{{{show_pos(net, node.sp0)} {show_pos(net, node.sp1)}}}"
    if isinstance(node, Atm):
        return node.val
def show_neg(net: Net, neg: Name) -> str:
    """Render the negative node named *neg* as text.

    Wire ends whose wire name has a negative node stored in ``net.subs`` are
    followed transparently. Unresolved ends print as ``-<wire name>``.
    Readback contexts are shown applied to the placeholder ``X``.
    """
    node = net.subs[neg]
    if isinstance(node, Sub):
        if node.nam in net.subs:
            return show_neg(net, node.nam)
        return f"-{node.nam}"
    if isinstance(node, Era):
        return "-_"
    if isinstance(node, App):
        return f"-({show_pos(net, node.arg)} {show_neg(net, node.ret)})"
    if isinstance(node, Dup):
        return f"-{node.dpl}{{{show_neg(net, node.dp0)} {show_neg(net, node.dp1)}}}"
    if isinstance(node, Red):
        return f"-R({node.ctx('X')}, {show_neg(net, node.ret)})"
def reduce(net: Net) -> int:
    """Apply interaction rules until ``net.book`` is empty.

    Each step pops a redex ``(lhs, rhs)``, where ``lhs`` names a node in
    ``net.subs`` and ``rhs`` a node in ``net.vars``. It removes both nodes and
    applies the rule for that pair. A rule may allocate new nodes and push
    new redexes. ``net`` is mutated in place.

    Returns:
        The number of redexes popped, including pure wire-forwarding steps.
    """
    itrs = 0
    while net.book:
        itrs += 1
        lhs, rhs = net.book.pop()
        # Both nodes are removed up front; a rule that still needs one puts it back.
        match net.subs.pop(lhs), net.vars.pop(rhs):
            # The Sub end of this wire already parked a positive node under the
            # wire name: restore the negative node and retry against that node.
            case lhsc, Var(nam) if nam in net.vars:
                net.subs[lhs] = lhsc
                net.book.add((lhs, nam))
            # Nothing parked yet: park the negative node under the wire name so
            # the Sub end finds it later.
            case lhsc, Var(nam):
                net.subs[nam] = lhsc
            # Mirror image of the two cases above, for the Sub end of a wire.
            case Sub(nam), rhsc if nam in net.subs:
                net.vars[rhs] = rhsc
                net.book.add((nam, rhs))
            case Sub(nam), rhsc:
                net.vars[nam] = rhsc
            # Eraser meets a nullary node: both disappear.
            case Era(), Nil():
                pass
            # Erasing a lambda: give its binder Nil and erase its body.
            case Era(), Lam(bnd, bod):
                net.book.add((bnd, nil(net)))
                net.book.add((era(net), bod))
            # Erasing a superposition erases both branches.
            case Era(), Sup(spl, sp0, sp1):
                net.book.add((era(net), sp0))
                net.book.add((era(net), sp1))
            case Era(), Atm(val):
                pass
            # Applying Nil: erase the argument; the result is Nil.
            case App(arg, ret), Nil():
                net.book.add((era(net), arg))
                net.book.add((ret, nil(net)))
            # Beta step: binder receives the argument, result port receives the body.
            case App(arg, ret), Lam(bnd, bod):
                net.book.add((bnd, arg))
                net.book.add((ret, bod))
            # Application meets a superposition: duplicate the argument with the
            # sup's label, apply each branch to one copy, superpose the results.
            case App(arg, ret), Sup(spl, sp0, sp1):
                ap, an = wire(net)
                bp, bn = wire(net)
                cp, cn = wire(net)
                dp, dn = wire(net)
                net.book.add((dup(net, spl, an, bn), arg))
                net.book.add((ret, sup(net, spl, cp, dp)))
                net.book.add((app(net, ap, cn), sp0))
                net.book.add((app(net, bp, dn), sp1))
            # Applying an atom: read back the argument and render "(val <arg>)"
            # into ret. `val=val` binds the value now, avoiding late binding.
            case App(arg, ret), Atm(val):
                net.book.add((red(net, lambda x, val=val: f"({val} {x})", ret), arg))
            # Duplicating Nil yields two Nils.
            case Dup(dpl, dp0, dp1), Nil():
                net.book.add((dp0, nil(net)))
                net.book.add((dp1, nil(net)))
            # Duplicating a lambda: two new lambdas; the original binder receives
            # a dpl-labelled sup of their binders, and the original body is
            # duplicated into their bodies.
            case Dup(dpl, dp0, dp1), Lam(bnd, bod):
                ap, an = wire(net)
                bp, bn = wire(net)
                cp, cn = wire(net)
                dp, dn = wire(net)
                net.book.add((dp0, lam(net, an, bp)))
                net.book.add((dp1, lam(net, cn, dp)))
                net.book.add((bnd, sup(net, dpl, ap, cp)))
                net.book.add((dup(net, dpl, bn, dn), bod))
            # Matching labels: annihilate, pairing the branches directly.
            case Dup(dpl, dp0, dp1), Sup(spl, sp0, sp1) if dpl == spl:
                net.book.add((dp0, sp0))
                net.book.add((dp1, sp1))
            # Different labels: commute; each copy becomes a spl-labelled sup of
            # dpl-duplicated branches.
            case Dup(dpl, dp0, dp1), Sup(spl, sp0, sp1):
                ap, an = wire(net)
                bp, bn = wire(net)
                cp, cn = wire(net)
                dp, dn = wire(net)
                net.book.add((dp0, sup(net, spl, ap, bp)))
                net.book.add((dp1, sup(net, spl, cp, dp)))
                net.book.add((dup(net, dpl, an, cn), sp0))
                net.book.add((dup(net, dpl, bn, dn), sp1))
            # Atoms are copied by value.
            case Dup(dpl, dp0, dp1), Atm(val):
                net.book.add((dp0, atm(net, val)))
                net.book.add((dp1, atm(net, val)))
            # Readback of Nil renders as "_" inside the current context.
            case Red(ctx, ret), Nil():
                net.book.add((ret, atm(net, ctx("_"))))
            # Readback of a lambda: bind the binder to a fresh numbered atom and
            # keep reading the body inside "λx.<body>". Default args freeze ctx/x.
            case Red(ctx, ret), Lam(bnd, bod):
                x = next(binder)
                net.book.add((bnd, atm(net, f"{x}")))
                net.book.add((red(net, lambda y, ctx=ctx, x=x: ctx(f"λ{x}.{y}"), ret), bod))
            # Readback distributes over a superposition, preserving its label.
            case Red(ctx, ret), Sup(spl, sp0, sp1):
                ap, an = wire(net)
                bp, bn = wire(net)
                net.book.add((red(net, ctx, an), sp0))
                net.book.add((red(net, ctx, bn), sp1))
                net.book.add((ret, sup(net, spl, ap, bp)))
            # Reached an atom: render it inside the context and emit the result.
            case Red(ctx, ret), Atm(val):
                net.book.add((ret, atm(net, ctx(val))))
    return itrs
def print_state(net: Net, root: Name, *, wires: bool = False) -> None:
    """Print the root term and the pending redexes.

    When *wires* is true, also dump every stored positive and negative node.
    """
    print("ROOT:")
    print(f" {show_pos(net, root)}")
    print()
    print("BOOK:")
    for neg_name, pos_name in net.book:
        print(f" {show_neg(net, neg_name)} ⋈ {show_pos(net, pos_name)}")
    if not wires:
        return
    print()
    print("VARS:")
    for key in net.vars:
        print(f" {key} = {show_pos(net, key)}")
    print()
    print("SUBS:")
    for key in net.subs:
        print(f" {key} = {show_neg(net, key)}")
def _print_banner(title: str) -> None:
    """Print a 30-column boxed section header with *title* centred, then a blank line."""
    print("=" * 30)
    print("=", title.center(26), "=")
    print("=" * 30)
    print()


def print_reduction(net: Net, root: Name, *, readback: bool = False, collapse: bool = False, wires: bool = False) -> None:
    """Print the net before and after normalization, plus optional post-processing.

    Args:
        net: The net to reduce; it is mutated in place.
        root: Name of the positive root term.
        readback: Read the normal form back into text and print it under a
            READBACK section.
        collapse: Read back, then collapse every superposition and print each
            branch under a COLLAPSED section.
        wires: Also dump all stored nodes in each state listing.
    """
    _print_banner("INITIAL")
    print_state(net, root, wires=wires)
    print()
    _print_banner("NORMALIZED")
    itrs = reduce(net)
    print_state(net, root, wires=wires)
    print()
    print(f"ITRS: {itrs}")
    if readback or collapse:
        # Attach a readback node to the normal form and run it to completion.
        root = mk_red(net, root)
        reduce(net)
    if readback:
        # Previously the readback result was computed but never shown.
        print()
        _print_banner("READBACK")
        print_state(net, root, wires=wires)
    if collapse:
        print()
        _print_banner("COLLAPSED")
        print_collapse(net, root)
def collect_labels(net: Net) -> set[Label]:
    """Return the set of labels used by superposition nodes stored in ``net.vars``."""
    return {node.spl for node in net.vars.values() if isinstance(node, Sup)}
def _split_leaf(net: Net, lab: Label, leaf: Name) -> tuple[Name, Name]:
    """Attach a dup labelled *lab* to *leaf*; return the positive ends of its two outputs."""
    left_pos, left_neg = wire(net)
    right_pos, right_neg = wire(net)
    net.book.add((dup(net, lab, left_neg, right_neg), leaf))
    return left_pos, right_pos


def mk_collapse_tree(net: Net, root: Name) -> list[Name]:
    """Build a binary tree of dups over *root*, one level per superposition label.

    The returned leaves are the positive ends of the last level, ordered
    left to right. The dups are only queued in ``net.book``, not reduced.
    """
    frontier = [root]
    for lab in collect_labels(net):
        frontier = [half for leaf in frontier for half in _split_leaf(net, lab, leaf)]
    return frontier
def print_collapse(net: Net, root: Name) -> list[str]:
    """Collapse the superpositions under *root* and print each resulting branch.

    Builds a dup tree over *root* (one level per Sup label in the net),
    normalizes the net, and prints one numbered line per leaf.

    Returns:
        The rendered leaves (without the ``"n) "`` prefix), in printed order.
    """
    leaves = mk_collapse_tree(net, root)
    reduce(net)
    rendered: list[str] = []
    for index, leaf in enumerate(leaves):
        text = show_pos(net, leaf)
        print(f"{index}) {text}")
        rendered.append(text)
    return rendered
def mk_app(net: Net, fun: Name, arg: Name) -> Name:
    """Queue the application of *fun* to *arg*; return the positive result end."""
    result_pos, result_neg = wire(net)
    applier = app(net, arg, result_neg)
    net.book.add((applier, fun))
    return result_pos
def mk_dup(net: Net, bod: Name, lab: Label | None = None) -> tuple[Name, Name]:
    """Queue a duplication of the positive term *bod*.

    Args:
        net: Net to extend.
        bod: Name of the positive term to copy.
        lab: Duplicator label; a fresh label is drawn when omitted.

    Returns:
        The positive ends of the two copies.
    """
    if lab is None:
        lab = next(label)
    ap, an = wire(net)
    bp, bn = wire(net)
    # dup() takes the label before the two outputs. Previously the label was
    # omitted, which raised TypeError on every call.
    net.book.add((dup(net, lab, an, bn), bod))
    return ap, bp
def mk_red(net: Net, bod: Name) -> Name:
    """Queue a readback of *bod* with the identity context; return the result end."""
    out_pos, out_neg = wire(net)
    reader = red(net, lambda x: x, out_neg)
    net.book.add((reader, bod))
    return out_pos
def mk_true(net: Net) -> Name:
    """Build the Church boolean ``λt.λf.t`` and return its positive root."""
    t_pos, t_neg = wire(net)
    f_pos, f_neg = wire(net)
    # f is unused in the body, so its occurrence is erased.
    net.book.add((era(net), f_pos))
    inner = lam(net, f_neg, t_pos)
    return lam(net, t_neg, inner)
def mk_false(net: Net) -> Name:
    """Build the Church boolean ``λt.λf.f`` and return its positive root."""
    t_pos, t_neg = wire(net)
    f_pos, f_neg = wire(net)
    # t is unused in the body, so its occurrence is erased.
    net.book.add((era(net), t_pos))
    inner = lam(net, f_neg, f_pos)
    return lam(net, t_neg, inner)
def mk_pair(net: Net, a: Name, b: Name) -> Name:
    """Build the Church pair ``λp.(p a b)`` and return its positive root."""
    p_pos, p_neg = wire(net)
    applied_a = mk_app(net, p_pos, a)
    applied_ab = mk_app(net, applied_a, b)
    return lam(net, p_neg, applied_ab)
def test_collapse_annihilation(net: Net) -> Name:
    """Build the pair ``(L{T,F}, L{T,F})`` where both superpositions share label L."""
    shared = next(label)
    first = sup(net, shared, mk_true(net), mk_false(net))
    second = sup(net, shared, mk_true(net), mk_false(net))
    return mk_pair(net, first, second)
def test_collapse_commutation(net: Net) -> Name:
    """Build the pair ``(L0{T,F}, L1{T,F})`` with two distinct labels."""
    first_label, second_label = next(label), next(label)
    first = sup(net, first_label, mk_true(net), mk_false(net))
    second = sup(net, second_label, mk_true(net), mk_false(net))
    return mk_pair(net, first, second)
def main() -> None:
    """Reduce the distinct-label pair example and print its collapsed branches.

    Swap in ``test_collapse_annihilation`` to see same-label annihilation instead.
    """
    net = empty_net()
    print_reduction(net, test_collapse_commutation(net), collapse=True)


main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment