Skip to content

Instantly share code, notes, and snippets.

@LeeeeT
Last active March 10, 2026 16:59
Show Gist options
  • Select an option

  • Save LeeeeT/5d32c1aaffe0c79f687c2389716dd3a5 to your computer and use it in GitHub Desktop.

Select an option

Save LeeeeT/5d32c1aaffe0c79f687c2389716dd3a5 to your computer and use it in GitHub Desktop.
k-SIC + CNT + EVA
import itertools
from dataclasses import dataclass
type Name = int
name = itertools.count()
type Label = int
label = itertools.count()
type Pos = VAR | LAM | SUP | LAP | LD0 | LD1 | CNT | AVE
type Neg = SUB | APP | DUP | UDP | EVA
@dataclass(frozen=True, slots=True)
class VAR:
    """Positive variable occurrence reading the wire slot ``nam``.

    Once ``net.vars`` holds a term under ``nam`` the variable stands for
    that term; until then it renders as ``+<nam>``.
    """

    nam: Name  # wire slot looked up in net.vars
@dataclass(frozen=True, slots=True)
class LAM:
    """Lambda node."""

    bnd: Name  # name of the binder (a negative term in net.subs)
    bod: Name  # name of the body (a positive term in net.vars)
@dataclass(frozen=True, slots=True)
class SUP:
    """Labelled pair of positive terms.

    A DUP with the same label splits it into its two branches.
    """

    spl: Label  # label compared against a DUP's label
    sp0: Name  # first branch (positive)
    sp1: Name  # second branch (positive)
@dataclass(frozen=True, slots=True)
class LAP:
    """Deferred application of ``fun`` to ``arg``.

    When demanded by a negative term, it is replaced by an ``APP(arg, demander)``
    booked against ``fun``.
    """

    fun: Name  # positive term being applied
    arg: Name  # positive argument
@dataclass(frozen=True, slots=True)
class LD0:
    """Lazy projection 0 of duplicating the term at ``ldb`` under label ``ldl``.

    On the first demand of either projection, the term at ``ldb`` is moved
    out and a DUP is booked against it. ``lda`` then holds the copy meant
    for the sibling projection.
    """

    ldl: Label  # duplication label
    ldb: Name  # shared body; removed from net.vars once forced
    lda: Name  # slot receiving the sibling's copy once forced
@dataclass(frozen=True, slots=True)
class LD1:
    """Lazy projection 1 of duplicating the term at ``ldb`` under label ``ldl``.

    Mirror image of LD0. Whichever projection is demanded first forces the
    DUP, and the other one later reads its copy through ``lda``.
    """

    ldl: Label  # duplication label
    ldb: Name  # shared body; removed from net.vars once forced
    lda: Name  # slot receiving the sibling's copy once forced
@dataclass(frozen=True, slots=True)
class CNT:
    """Shared reference to the positive term at ``cnb``.

    Each demand splits the shared term with a UDP and re-points ``cnb`` at
    the remaining half.
    """

    cnb: Name  # name of the shared positive term
@dataclass(frozen=True, slots=True)
class AVE:
    """Positive wrapper that an EVA unwraps directly (rendered ``+!(...)``).

    NOTE(review): appears to mark an already-evaluated value; confirm intent.
    """

    bod: Name  # wrapped positive term
@dataclass(frozen=True, slots=True)
class SUB:
    """Negative end of a wire, reading the slot ``nam`` in net.subs."""

    nam: Name  # wire slot looked up in net.subs
@dataclass(frozen=True, slots=True)
class APP:
    """Application consumer: supplies ``arg`` and sends the result to ``ret``."""

    arg: Name  # positive argument
    ret: Name  # negative term receiving the result
@dataclass(frozen=True, slots=True)
class DUP:
    """Labelled duplicator sending the two copies to ``dp0`` and ``dp1``."""

    dpl: Label  # label compared against a SUP's label
    dp0: Name  # negative term receiving copy 0
    dp1: Name  # negative term receiving copy 1
@dataclass(frozen=True, slots=True)
class UDP:
    """Unlabelled duplicator; it picks a fresh label when it meets a LAM."""

    ud0: Name  # negative term receiving copy 0
    ud1: Name  # negative term receiving copy 1
@dataclass(frozen=True, slots=True)
class EVA:
    """Evaluator: goes under LAM and SUP, unwraps AVE, and sends the result to ``ret``."""

    ret: Name  # negative term receiving the evaluated result
type Rdx = tuple[Name, Name]
@dataclass(frozen=True)
class Net:
    """Interaction-net state.

    The dataclass itself is frozen, but its three containers are mutated
    in place while terms are built and reduced.
    """

    book: set[Rdx]  # pending (neg, pos) redexes consumed by reduce()
    vars: dict[Name, Pos]  # positive terms keyed by name
    subs: dict[Name, Neg]  # negative terms keyed by name
def empty_net() -> Net:
    """Return a Net with no redexes and no stored terms."""
    return Net(book=set(), vars={}, subs={})
def pos(net: Net, term: Pos) -> Name:
    """Store the positive ``term`` under a fresh name and return that name."""
    slot = next(name)
    net.vars[slot] = term
    return slot
def neg(net: Net, term: Neg) -> Name:
    """Store the negative ``term`` under a fresh name and return that name."""
    slot = next(name)
    net.subs[slot] = term
    return slot
def var(net: Net, nam: Name) -> Name:
    """Allocate a VAR reading slot ``nam``; return its name."""
    return pos(net, term=VAR(nam=nam))
def lam(net: Net, bnd: Name, bod: Name) -> Name:
    """Allocate a LAM with binder ``bnd`` and body ``bod``; return its name."""
    return pos(net, term=LAM(bnd=bnd, bod=bod))
def sup(net: Net, spl: Label, sp0: Name, sp1: Name) -> Name:
    """Allocate a SUP labelled ``spl`` over ``sp0``/``sp1``; return its name."""
    return pos(net, term=SUP(spl=spl, sp0=sp0, sp1=sp1))
def lap(net: Net, fun: Name, arg: Name) -> Name:
    """Allocate a deferred application LAP(fun, arg); return its name."""
    return pos(net, term=LAP(fun=fun, arg=arg))
def ld0(net: Net, ldl: Label, ldb: Name, lda: Name) -> Name:
    """Allocate lazy projection 0 of ``ldb`` under label ``ldl``; return its name."""
    return pos(net, term=LD0(ldl=ldl, ldb=ldb, lda=lda))
def ld1(net: Net, ldl: Label, ldb: Name, lda: Name) -> Name:
    """Allocate lazy projection 1 of ``ldb`` under label ``ldl``; return its name."""
    return pos(net, term=LD1(ldl=ldl, ldb=ldb, lda=lda))
def cnt(net: Net, cnb: Name) -> Name:
    """Allocate a shared reference CNT to ``cnb``; return its name."""
    return pos(net, term=CNT(cnb=cnb))
def ave(net: Net, bod: Name) -> Name:
    """Allocate an AVE wrapping ``bod``; return its name."""
    return pos(net, term=AVE(bod=bod))
def sub(net: Net, nam: Name) -> Name:
    """Allocate a SUB reading slot ``nam``; return its name."""
    return neg(net, term=SUB(nam=nam))
def app(net: Net, arg: Name, ret: Name) -> Name:
    """Allocate an APP supplying ``arg`` and returning into ``ret``; return its name."""
    return neg(net, term=APP(arg=arg, ret=ret))
def dup(net: Net, dpl: Label, dp0: Name, dp1: Name) -> Name:
    """Allocate a DUP labelled ``dpl`` feeding ``dp0``/``dp1``; return its name."""
    return neg(net, term=DUP(dpl=dpl, dp0=dp0, dp1=dp1))
def udp(net: Net, ud0: Name, ud1: Name) -> Name:
    """Allocate an unlabelled UDP feeding ``ud0``/``ud1``; return its name."""
    return neg(net, term=UDP(ud0=ud0, ud1=ud1))
def eva(net: Net, ret: Name) -> Name:
    """Allocate an EVA returning into ``ret``; return its name."""
    return neg(net, term=EVA(ret=ret))
def wire(net: Net) -> tuple[Name, Name]:
    """Create a wire: a VAR and a SUB sharing one fresh slot.

    Returns:
        ``(positive_end, negative_end)`` as names in net.vars / net.subs.
    """
    slot = next(name)
    positive_end = var(net, slot)
    negative_end = sub(net, slot)
    return positive_end, negative_end
def show_pos(net: Net, pos: Name) -> str:
match net.vars[pos]:
case VAR(nam) if nam in net.vars:
return show_pos(net, nam)
case VAR(nam):
return f"+{nam}"
case LAM(bnd, bod):
return f"+({show_neg(net, bnd)} {show_pos(net, bod)})"
case SUP(spl, sp0, sp1):
return f"+{spl}{{{show_pos(net, sp0)} {show_pos(net, sp1)}}}"
case LAP(fun, arg):
return f"+?({show_pos(net, fun)} {show_pos(net, arg)})"
case LD0(ldl, ldb, lda) if ldb in net.vars:
return f"+?0{{{ldl}}}{{{show_pos(net, ldb)}}}"
case LD0(ldl, ldb, lda):
return show_pos(net, lda)
case LD1(ldl, ldb, lda) if ldb in net.vars:
return f"+?1{{{ldl}}}{{{show_pos(net, ldb)}}}"
case LD1(ldl, ldb, lda):
return show_pos(net, lda)
case CNT(cnb):
return f"+%[{show_pos(net, cnb)}]"
case AVE(bod):
return f"+!({show_pos(net, bod)})"
def show_neg(net: Net, neg: Name) -> str:
match net.subs[neg]:
case SUB(nam) if nam in net.subs:
return show_neg(net, nam)
case SUB(nam):
return f"-{nam}"
case APP(arg, ret):
return f"-({show_pos(net, arg)} {show_neg(net, ret)})"
case DUP(dpl, dp0, dp1):
return f"-&{dpl}{{{show_neg(net, dp0)} {show_neg(net, dp1)}}}"
case UDP(ud0, ud1):
return f"-%[{show_neg(net, ud0)} {show_neg(net, ud1)}]"
case EVA(ret):
return f"-!({show_neg(net, ret)})"
def reduce(net: Net) -> int:
itrs = 0
while net.book:
itrs += 1
lhs, rhs = net.book.pop()
match net.subs.pop(lhs), net.vars.pop(rhs):
case lhsc, VAR(nam) if nam in net.vars:
net.subs[lhs] = lhsc
net.book.add((lhs, nam))
case lhsc, VAR(nam):
net.subs[nam] = lhsc
case SUB(nam), rhsc if nam in net.subs:
net.vars[rhs] = rhsc
net.book.add((nam, rhs))
case SUB(nam), rhsc:
net.vars[nam] = rhsc
case APP(arg, ret), LAM(bnd, bod):
net.book.add((bnd, arg))
net.book.add((ret, bod))
case APP(arg, ret), SUP(spl, sp0, sp1):
ap, an = wire(net)
bp, bn = wire(net)
x = next(name)
net.book.add((ret, sup(net, spl, lap(net, sp0, ld0(net, spl, arg, x)), lap(net, sp1, ld1(net, spl, arg, x)))))
case APP(arg, ret), AVE(bod):
ap, an = wire(net)
bp, bn = wire(net)
net.book.add((app(net, ap, bn), bod))
net.book.add((eva(net, an), arg))
net.book.add((ret, ave(net, bp)))
case DUP(dpl, dp0, dp1), LAM(bnd, bod):
ap, an = wire(net)
bp, bn = wire(net)
x = next(name)
net.book.add((dp0, lam(net, an, ld0(net, dpl, bod, x))))
net.book.add((dp1, lam(net, bn, ld1(net, dpl, bod, x))))
net.book.add((bnd, sup(net, dpl, ap, bp)))
case DUP(dpl, dp0, dp1), SUP(spl, sp0, sp1) if dpl == spl:
net.book.add((dp0, sp0))
net.book.add((dp1, sp1))
case DUP(dpl, dp0, dp1), SUP(spl, sp0, sp1):
x = next(name)
y = next(name)
net.book.add((dp0, sup(net, spl, ld0(net, dpl, sp0, x), ld0(net, dpl, sp1, y))))
net.book.add((dp1, sup(net, spl, ld1(net, dpl, sp0, x), ld1(net, dpl, sp1, y))))
case DUP(dpl, dp0, dp1), AVE(bod):
ap, an = wire(net)
bp, bn = wire(net)
net.book.add((dp0, ave(net, ap)))
net.book.add((dp1, ave(net, bp)))
net.book.add((dup(net, dpl, an, bn), bod))
case UDP(ud0, ud1), LAM(bnd, bod):
ap, an = wire(net)
bp, bn = wire(net)
lab = next(label)
x = next(name)
net.book.add((ud0, lam(net, an, ld0(net, lab, bod, x))))
net.book.add((ud1, lam(net, bn, ld1(net, lab, bod, x))))
net.book.add((bnd, sup(net, lab, ap, bp)))
case UDP(ud0, ud1), SUP(spl, sp0, sp1):
net.book.add((ud0, sup(net, spl, cnt(net, sp0), cnt(net, sp1))))
net.book.add((ud1, sup(net, spl, cnt(net, sp0), cnt(net, sp1))))
case UDP(ud0, ud1), CNT(cnb):
net.book.add((ud0, cnt(net, cnb)))
net.book.add((ud1, cnt(net, cnb)))
case UDP(ud0, ud1), AVE(bod):
ap, an = wire(net)
bp, bn = wire(net)
net.book.add((ud0, ave(net, ap)))
net.book.add((ud1, ave(net, bp)))
net.book.add((udp(net, an, bn), bod))
case EVA(ret), LAM(bnd, bod):
ap, an = wire(net)
bp, bn = wire(net)
net.book.add((ret, lam(net, an, bp)))
net.book.add((bnd, ave(net, ap)))
net.book.add((eva(net, bn), bod))
case EVA(ret), SUP(spl, sp0, sp1):
ap, an = wire(net)
bp, bn = wire(net)
net.book.add((eva(net, an), sp0))
net.book.add((eva(net, bn), sp1))
net.book.add((ret, sup(net, spl, ap, bp)))
case EVA(ret), AVE(bod):
net.book.add((ret, bod))
case lhsc, LAP(fun, arg):
net.subs[lhs] = lhsc
net.book.add((app(net, arg, lhs), fun))
case lhsc, LD0(ldl, ldb, lda) if ldb in net.vars:
net.subs[lhs] = lhsc
x = next(name)
net.vars[x] = net.vars.pop(ldb)
ap, an = wire(net)
net.book.add((dup(net, ldl, lhs, an), x))
net.vars[lda] = net.vars.pop(ap)
case lhsc, LD0(ldl, ldb, lda):
net.subs[lhs] = lhsc
net.book.add((lhs, lda))
case lhsc, LD1(ldl, ldb, lda) if ldb in net.vars:
net.subs[lhs] = lhsc
x = next(name)
net.vars[x] = net.vars.pop(ldb)
ap, an = wire(net)
net.book.add((dup(net, ldl, an, lhs), x))
net.vars[lda] = net.vars.pop(ap)
case lhsc, LD1(ldl, ldb, lda):
net.subs[lhs] = lhsc
net.book.add((lhs, lda))
case lhsc, CNT(cnb):
net.subs[lhs] = lhsc
x = next(name)
net.vars[x] = net.vars.pop(cnb)
ap, an = wire(net)
net.book.add((udp(net, lhs, an), x))
net.vars[cnb] = net.vars.pop(ap)
return itrs
def print_state(net: Net, root: Name, *, heap: bool = False) -> None:
    """Print the root term and pending redexes; with ``heap``, also dump all stored terms."""
    print("ROOT:")
    print(f" {show_pos(net, root)}")
    print()
    print("BOOK:")
    for neg_name, pos_name in net.book:
        print(f" {show_neg(net, neg_name)} ⋈ {show_pos(net, pos_name)}")
    if not heap:
        return
    print()
    print("VARS:")
    for slot in net.vars:
        print(f" {slot} = {show_pos(net, slot)}")
    print()
    print("SUBS:")
    for slot in net.subs:
        print(f" {slot} = {show_neg(net, slot)}")
def print_reduction(net: Net, root: Name, *, heap: bool = False) -> None:
    """Print the state, reduce the net in place, then print the state and step count."""

    def banner(title: str) -> None:
        # A 30-column framed heading followed by a blank line.
        print("=" * 30)
        print("=", title.center(26), "=")
        print("=" * 30)
        print()

    banner("INITIAL")
    print_state(net, root, heap=heap)
    print()
    banner("NORMALIZED")
    steps = reduce(net)
    print_state(net, root, heap=heap)
    print()
    print(f"ITRS: {steps}")
def mk_nil(net: Net) -> Name:
    """Return the positive end of a fresh wire; its negative end is not connected here."""
    return wire(net)[0]
def mk_era(net: Net) -> Name:
    """Return the negative end of a fresh wire; its positive end is not connected here."""
    return wire(net)[1]
def mk_ldp(net: Net, bod: Name, lab: Label | None = None) -> tuple[Name, Name]:
    """Return a pair of lazy projections (LD0, LD1) duplicating ``bod``.

    A fresh label is drawn when ``lab`` is None. Both projections share one
    fresh ``lda`` slot.
    """
    chosen = next(label) if lab is None else lab
    shared_slot = next(name)
    first = ld0(net, chosen, bod, shared_slot)
    second = ld1(net, chosen, bod, shared_slot)
    return first, second
def mk_cnt(net: Net, bod: Name, n: int) -> list[Name]:
    """Create ``n`` CNT references that all share the term at ``bod``."""
    refs: list[Name] = []
    for _ in range(n):
        refs.append(cnt(net, bod))
    return refs
def mk_eva(net: Net, bod: Name) -> Name:
    """Book an EVA against ``bod``; return the positive end that receives the result."""
    result_pos, result_neg = wire(net)
    evaluator = eva(net, result_neg)
    net.book.add((evaluator, bod))
    return result_pos
def mk_c2(net: Net) -> Name:
    """Build the Church numeral 2, λf.λx.(f (f x)); f is shared via a lazy-dup pair."""
    f_pos, f_neg = wire(net)
    x_pos, x_neg = wire(net)
    f_first, f_second = mk_ldp(net, f_pos)
    inner = lap(net, f_first, x_pos)  # (f x)
    outer = lap(net, f_second, inner)  # (f (f x))
    return lam(net, f_neg, lam(net, x_neg, outer))
def mk_cpow2(net: Net, k: int) -> Name:
    """Build the Church numeral for 2**k (k=1 -> c2, k=2 -> c4, k=3 -> c8, ...).

    Uses c_{2^k} = λf. c2 (c_{2^(k-1)} f). The levels are built with a loop
    rather than recursion, so large ``k`` does not hit Python's recursion
    limit. Nodes are allocated in the same order the recursive definition
    allocates them.

    Raises:
        ValueError: if ``k`` < 1.
    """
    if k < 1:
        raise ValueError("k must be >= 1")
    # Outer λf wires for levels k, k-1, ..., 2, allocated top-down.
    wires = [wire(net) for _ in range(k - 1)]
    term = mk_c2(net)  # base case: c_2
    # Wrap from the innermost level (2) outwards to level k.
    for fp, fn in reversed(wires):
        prev_f = lap(net, term, fp)  # f^(2^(j-1))
        body = lap(net, mk_c2(net), prev_f)  # square -> f^(2^j)
        term = lam(net, fn, body)
    return term
def mk_F(net: Net) -> Name:
    """Build λb.λt.λf.(b (f f) f); t is unused and f is shared three ways via CNT."""
    b_pos, b_neg = wire(net)
    _, t_neg = wire(net)
    f_pos, f_neg = wire(net)
    f0, f1, f2 = mk_cnt(net, f_pos, 3)
    ff = lap(net, f0, f1)  # (f f)
    b_ff = lap(net, b_pos, ff)  # (b (f f))
    b_ff_f = lap(net, b_ff, f2)  # (b (f f) f)
    return lam(net, b_neg, lam(net, t_neg, lam(net, f_neg, b_ff_f)))
def test_F_fusion(net: Net, N: int) -> Name:
    """Build F^(2^N): the numeral c_{2^N} applied to F."""
    numeral = mk_cpow2(net, N)
    step = mk_F(net)
    return lap(net, numeral, step)
def mk_rec(net: Net) -> Name:
    """Build λf.(λx.(f (x x)) λx.(f (x x))).

    Both the inner ``x`` and the whole ``λx.(f (x x))`` are shared through
    lazy-dup pairs.
    """
    f_pos, f_neg = wire(net)
    x_pos, x_neg = wire(net)
    x_first, x_second = mk_ldp(net, x_pos)
    self_app = lap(net, x_first, x_second)  # (x x)
    f_self_app = lap(net, f_pos, self_app)  # (f (x x))
    half = lam(net, x_neg, f_self_app)  # λx.(f (x x))
    left, right = mk_ldp(net, half)
    return lam(net, f_neg, lap(net, left, right))
def mk_E(net: Net) -> Name:
    """Build the E constructor λe.λo.λi.e (selects the first of three continuations)."""
    e_pos, e_neg = wire(net)
    _, o_neg = wire(net)
    _, i_neg = wire(net)
    term = e_pos
    for binder in (i_neg, o_neg, e_neg):
        term = lam(net, binder, term)
    return term
def mk_O(net: Net) -> Name:
    """Build the O constructor λp.λe.λo.λi.(o p) (passes p to the second continuation)."""
    p_pos, p_neg = wire(net)
    _, e_neg = wire(net)
    o_pos, o_neg = wire(net)
    _, i_neg = wire(net)
    term = lap(net, o_pos, p_pos)
    for binder in (i_neg, o_neg, e_neg, p_neg):
        term = lam(net, binder, term)
    return term
def mk_I(net: Net) -> Name:
    """Build the I constructor λp.λe.λo.λi.(i p) (passes p to the third continuation)."""
    p_pos, p_neg = wire(net)
    _, e_neg = wire(net)
    _, o_neg = wire(net)
    i_pos, i_neg = wire(net)
    term = lap(net, i_pos, p_pos)
    for binder in (i_neg, o_neg, e_neg, p_neg):
        term = lam(net, binder, term)
    return term
# (rec
#   λinsert.λn.(n
#     λxs.λe.λo.λi.(xs
#       (i E)
#       λxs.(i xs)
#       λxs.(i xs)
#     )
#     _
#     λp.λxs.λe.λo.λi.(xs
#       (o (insert p E))
#       λxs.(o (insert p xs))
#       λxs.(i (insert p xs))
#     )
#   )
# )
def mk_insert(net: Net) -> Name:
    """Build the recursive ``insert`` term sketched in the comment above.

    ``n`` is eliminated by passing it three continuations, one per
    constructor E / O p / I p (see mk_E, mk_O, mk_I). The O case is left as
    an unconnected wire end (``_``). Node-allocation order is significant
    only for name numbering.
    """
    insertp, insertn = wire(net)
    np, nn = wire(net)
    # `insert` occurs three times (all in the I case), so share it via CNT.
    insert0, insert1, insert2 = mk_cnt(net, insertp, 3)

    def mk_case_E() -> Name:
        # λxs.λe.λo.λi.(xs (i E) λxs.(i xs) λxs.(i xs)); e and o are unused.
        xsp, xsn = wire(net)
        ep, en = wire(net)
        op, on = wire(net)
        ip, in_ = wire(net)
        # `i` occurs three times, so share it via CNT.
        i0, i1, i2 = mk_cnt(net, ip, 3)

        def mk_case_E_case_E() -> Name:
            # (i E)
            return lap(net, i0, mk_E(net))

        def mk_case_E_case_O() -> Name:
            # λxs.(i xs)
            xsp, xsn = wire(net)
            return lam(net, xsn, lap(net, i1, xsp))

        def mk_case_E_case_I() -> Name:
            # λxs.(i xs)
            xsp, xsn = wire(net)
            return lam(net, xsn, lap(net, i2, xsp))

        # (xs case_E case_O case_I)
        case_E_ret = lap(net, lap(net, lap(net, xsp, mk_case_E_case_E()), mk_case_E_case_O()), mk_case_E_case_I())
        return lam(net, xsn, lam(net, en, lam(net, on, lam(net, in_, case_E_ret))))

    def mk_case_O() -> Name:
        # `_` in the sketch: an unconnected positive wire end.
        return mk_nil(net)

    def mk_case_I() -> Name:
        # λp.λxs.λe.λo.λi.(xs (o (insert p E)) λxs.(o (insert p xs)) λxs.(i (insert p xs)))
        pp, pn = wire(net)
        xsp, xsn = wire(net)
        ep, en = wire(net)
        op, on = wire(net)
        ip, in_ = wire(net)
        # p occurs three times and o twice; i occurs once and is used directly.
        p0, p1, p2 = mk_cnt(net, pp, 3)
        o0, o1 = mk_cnt(net, op, 2)

        def mk_case_I_case_E() -> Name:
            # (o (insert p E))
            return lap(net, o0, lap(net, lap(net, insert0, p0), mk_E(net)))

        def mk_case_I_case_O() -> Name:
            # λxs.(o (insert p xs))
            xsp, xsn = wire(net)
            return lam(net, xsn, lap(net, o1, lap(net, lap(net, insert1, p1), xsp)))

        def mk_case_I_case_I() -> Name:
            # λxs.(i (insert p xs))
            xsp, xsn = wire(net)
            return lam(net, xsn, lap(net, ip, lap(net, lap(net, insert2, p2), xsp)))

        # (xs case_E case_O case_I)
        case_I_ret = lap(net, lap(net, lap(net, xsp, mk_case_I_case_E()), mk_case_I_case_O()), mk_case_I_case_I())
        return lam(net, pn, lam(net, xsn, lam(net, en, lam(net, on, lam(net, in_, case_I_ret)))))

    # (n case_E case_O case_I), closed over λinsert.λn and passed to rec.
    ret = lap(net, lap(net, lap(net, np, mk_case_E()), mk_case_O()), mk_case_I())
    return lap(net, mk_rec(net), lam(net, insertn, lam(net, nn, ret)))
def test_insert_fusion(net: Net, N: int) -> Name:
    """Build (insert (I E))^(2^N): the numeral c_{2^N} applied to ``insert`` and seed ``I E``."""
    numeral = mk_cpow2(net, N)
    insert = mk_insert(net)
    seed = lap(net, mk_I(net), mk_E(net))
    return lap(net, numeral, lap(net, insert, seed))
def main() -> None:
    """Build F^(2^100), book an evaluator on it, and print the reduction."""
    net = empty_net()
    # ITRS: 9980
    term = test_F_fusion(net, 100)
    evaluated = mk_eva(net, term)
    print_reduction(net, evaluated)
# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment