Skip to content

Instantly share code, notes, and snippets.

@LeeeeT
Last active March 10, 2026 23:33
Show Gist options
  • Select an option

  • Save LeeeeT/a8ae05d9364e01a3f8496cdba627844a to your computer and use it in GitHub Desktop.

Select an option

Save LeeeeT/a8ae05d9364e01a3f8496cdba627844a to your computer and use it in GitHub Desktop.
k-SIC + CNT + REF + EVA
import itertools
from collections.abc import Callable
from dataclasses import dataclass
# Every term in a Net is stored at an integer heap address.
type Location = int
# Process-wide allocator of fresh locations; shared by every Net (pos/neg/net_clone draw from it).
location = itertools.count()
# Labels tag superposition/duplication nodes: equal labels annihilate in reduce, different ones commute.
type Label = int
# Process-wide allocator of fresh labels.
label = itertools.count()
# Names of global definitions (keys of Net.scop, referenced by REF).
type Identifier = str
# Positive terms are stored in Net.vars, negative terms in Net.subs.
type Pos = VAR | LAM | SUP | LAP | LD0 | LD1 | CNT | REF | AVE
type Neg = SUB | APP | DUP | UDP | EVA
@dataclass(frozen=True)
class VAR:
    """Positive end of the wire named ``loc``.

    Once ``net.vars[loc]`` exists (the name is bound), ``reduce`` and
    ``show_pos`` follow it there.
    """
    loc: Location  # wire name shared with the matching SUB
@dataclass(frozen=True)
class LAM:
    """Lambda node."""
    bnd: Location  # negative location of the binder
    bod: Location  # positive location of the body
@dataclass(frozen=True)
class SUP:
    """Superposition of two positive terms under label ``spl``.

    In ``reduce``, a DUP with the same label annihilates it pairwise; a DUP
    with a different label commutes through it.
    """
    spl: Label  # superposition label
    sp0: Location  # first branch (positive)
    sp1: Location  # second branch (positive)
@dataclass(frozen=True)
class LAP:
    """Lazy application of ``fun`` to ``arg``.

    When consumed, ``reduce`` books ``APP(arg, <consumer>)`` against ``fun``.
    """
    fun: Location  # positive location of the function
    arg: Location  # positive location of the argument
@dataclass(frozen=True)
class LD0:
    """First projection of a lazy duplication of ``ldb`` under label ``ldl``.

    LD0/LD1 pairs share ``lda``. Whichever fires first moves ``ldb`` into a
    real DUP and parks the other copy at ``lda``. The second one then finds
    ``ldb`` gone and forwards to ``lda``.
    """
    ldl: Label  # duplication label
    ldb: Location  # positive term being duplicated (present in vars until forced)
    lda: Location  # shared slot receiving the other copy once forced
@dataclass(frozen=True)
class LD1:
    """Second projection of a lazy duplication; see ``LD0`` for the protocol."""
    ldl: Label  # duplication label
    ldb: Location  # positive term being duplicated (present in vars until forced)
    lda: Location  # shared slot receiving the other copy once forced
@dataclass(frozen=True)
class CNT:
    """Shared reference to the positive term at ``cnb``.

    Each consumption inserts a UDP: one copy goes to the consumer and the
    other is left at ``cnb`` for later CNT reads.
    """
    cnb: Location  # location holding the current shared value
@dataclass(frozen=True)
class REF:
    """Reference to the global definition ``ide`` in ``Net.scop``.

    ``reduce`` expands it by cloning that definition.
    """
    ide: Identifier
@dataclass(frozen=True)
class AVE:
    """Wrapper around positive term ``bod``, printed as ``+!(...)``.

    ``EVA(ret)`` meeting it connects ``ret`` straight to ``bod``.
    NOTE(review): presumably the counterpart of EVA — confirm intended semantics.
    """
    bod: Location  # wrapped positive term
@dataclass(frozen=True)
class SUB:
    """Negative end of the wire named ``loc``.

    Once ``net.subs[loc]`` exists, ``reduce`` and ``show_neg`` follow it there.
    """
    loc: Location  # wire name shared with the matching VAR
@dataclass(frozen=True)
class APP:
    """Application eliminator: feeds ``arg`` to the function it meets."""
    arg: Location  # positive location of the argument
    ret: Location  # negative location receiving the result
@dataclass(frozen=True)
class DUP:
    """Labelled duplicator with two negative outputs."""
    dpl: Label  # duplication label (compared with SUP labels in reduce)
    dp0: Location  # first copy (negative)
    dp1: Location  # second copy (negative)
@dataclass(frozen=True)
class UDP:
    """Unlabelled duplicator, created when a CNT is consumed.

    ``reduce`` allocates a fresh label when it meets a LAM.
    """
    ud0: Location  # first copy (negative)
    ud1: Location  # second copy (negative)
@dataclass(frozen=True)
class EVA:
    """Eliminator printed as ``-!(...)``.

    In ``reduce``:
    - on LAM, it rebuilds the lambda with the binder wrapped in AVE and
      continues into the body with another EVA;
    - on SUP, it distributes over both branches;
    - on AVE, it connects ``ret`` straight to the wrapped term.
    """
    ret: Location  # negative location receiving the result
# A global definition: the Net holding its term plus that term's root location.
type Package = tuple[Net, Location]
# An active pair: (negative location in Net.subs, positive location in Net.vars).
type Redex = tuple[Location, Location]
@dataclass(frozen=True)
class Net:
    """Mutable heap of an interaction net.

    The dataclass is frozen, so the four containers cannot be rebound, but
    they are mutated in place by the builders and by ``reduce``.
    """
    vars: dict[Location, Pos]  # positive terms by location
    subs: dict[Location, Neg]  # negative terms by location
    scop: dict[Identifier, Package]  # global definitions resolved by REF
    book: set[Redex]  # pending active pairs consumed by reduce
def net_empty() -> Net:
    """Create a Net with empty heaps, no definitions and no pending redexes."""
    return Net(vars={}, subs={}, scop={}, book=set())
def net_embed(net: Net, embedding: Net) -> None:
    """Merge every container of ``embedding`` into ``net`` in place.

    Entries already present in ``net`` under the same key are overwritten.
    """
    for field_name in ("vars", "subs", "scop", "book"):
        target = getattr(net, field_name)
        target.update(getattr(embedding, field_name))
def net_clone(net: Net, locations: dict[Location, Location], labels: dict[Label, Label]) -> Net:
    """Return a copy of ``net`` with every location and label renamed.

    ``locations`` and ``labels`` are renaming maps filled in as a side effect
    (each original key maps to its fresh clone). A caller may pre-seed them,
    and can read them afterwards; ``reduce`` uses this to find the clone of a
    definition's root. ``scop`` entries are shared with the source, not copied.

    Fresh names are drawn from the global ``location`` / ``label`` counters in
    traversal order. The clone's numbering therefore depends on the iteration
    order of ``net.vars``, ``net.subs`` and ``net.book``.
    """
    def clone_loc(loc: Location) -> Location:
        # Memoized, so every occurrence of one location maps to the same clone.
        if loc not in locations:
            locations[loc] = next(location)
        return locations[loc]
    def clone_lab(lab: Label) -> Label:
        # Memoized the same way for labels.
        if lab not in labels:
            labels[lab] = next(label)
        return labels[lab]
    def clone_pos(pos: Pos) -> Pos:
        # Rebuild a positive term with renamed locations/labels (REF names are global, kept as-is).
        match pos:
            case VAR(loc):
                return VAR(clone_loc(loc))
            case LAM(bnd, bod):
                return LAM(clone_loc(bnd), clone_loc(bod))
            case SUP(spl, sp0, sp1):
                return SUP(clone_lab(spl), clone_loc(sp0), clone_loc(sp1))
            case LAP(fun, arg):
                return LAP(clone_loc(fun), clone_loc(arg))
            case LD0(ldl, ldb, lda):
                return LD0(clone_lab(ldl), clone_loc(ldb), clone_loc(lda))
            case LD1(ldl, ldb, lda):
                return LD1(clone_lab(ldl), clone_loc(ldb), clone_loc(lda))
            case CNT(cnb):
                return CNT(clone_loc(cnb))
            case REF(ide):
                return REF(ide)
            case AVE(bod):
                return AVE(clone_loc(bod))
    def clone_neg(neg: Neg) -> Neg:
        # Rebuild a negative term with renamed locations/labels.
        match neg:
            case SUB(loc):
                return SUB(clone_loc(loc))
            case APP(arg, ret):
                return APP(clone_loc(arg), clone_loc(ret))
            case DUP(dpl, dp0, dp1):
                return DUP(clone_lab(dpl), clone_loc(dp0), clone_loc(dp1))
            case UDP(ud0, ud1):
                return UDP(clone_loc(ud0), clone_loc(ud1))
            case EVA(ret):
                return EVA(clone_loc(ret))
    new = net_empty()
    # In `d[k()] = v()` Python evaluates v() before k(), so term contents are
    # renamed before the slot key; this fixes the allocation order.
    for loc, pos in net.vars.items():
        new.vars[clone_loc(loc)] = clone_pos(pos)
    for loc, neg in net.subs.items():
        new.subs[clone_loc(loc)] = clone_neg(neg)
    for ide, dfn in net.scop.items():
        new.scop[ide] = dfn
    for lhs, rhs in net.book:
        new.book.add((clone_loc(lhs), clone_loc(rhs)))
    return new
def _store(table: dict[Location, Pos | Neg], term: Pos | Neg) -> Location:
    """Put ``term`` into ``table`` at a freshly allocated location and return it."""
    slot = next(location)
    table[slot] = term
    return slot
def pos(net: Net, term: Pos) -> Location:
    """Allocate a positive ``term`` in ``net.vars`` and return its location."""
    return _store(net.vars, term)
def neg(net: Net, term: Neg) -> Location:
    """Allocate a negative ``term`` in ``net.subs`` and return its location."""
    return _store(net.subs, term)
def var(net: Net, loc: Location) -> Location:
    """Allocate ``VAR(loc)`` and return its location."""
    return pos(net, VAR(loc=loc))
def lam(net: Net, bnd: Location, bod: Location) -> Location:
    """Allocate ``LAM(bnd, bod)`` and return its location."""
    return pos(net, LAM(bnd=bnd, bod=bod))
def sup(net: Net, spl: Label, sp0: Location, sp1: Location) -> Location:
    """Allocate ``SUP(spl, sp0, sp1)`` and return its location."""
    return pos(net, SUP(spl=spl, sp0=sp0, sp1=sp1))
def lap(net: Net, fun: Location, arg: Location) -> Location:
    """Allocate ``LAP(fun, arg)`` and return its location."""
    return pos(net, LAP(fun=fun, arg=arg))
def ld0(net: Net, ldl: Label, ldb: Location, lda: Location) -> Location:
    """Allocate ``LD0(ldl, ldb, lda)`` and return its location."""
    return pos(net, LD0(ldl=ldl, ldb=ldb, lda=lda))
def ld1(net: Net, ldl: Label, ldb: Location, lda: Location) -> Location:
    """Allocate ``LD1(ldl, ldb, lda)`` and return its location."""
    return pos(net, LD1(ldl=ldl, ldb=ldb, lda=lda))
def cnt(net: Net, cnb: Location) -> Location:
    """Allocate ``CNT(cnb)`` and return its location."""
    return pos(net, CNT(cnb=cnb))
def ref(net: Net, ide: Identifier) -> Location:
    """Allocate ``REF(ide)`` and return its location."""
    return pos(net, REF(ide=ide))
def ave(net: Net, bod: Location) -> Location:
    """Allocate ``AVE(bod)`` and return its location."""
    return pos(net, AVE(bod=bod))
def sub(net: Net, loc: Location) -> Location:
    """Allocate ``SUB(loc)`` and return its location."""
    return neg(net, SUB(loc=loc))
def app(net: Net, arg: Location, ret: Location) -> Location:
    """Allocate ``APP(arg, ret)`` and return its location."""
    return neg(net, APP(arg=arg, ret=ret))
def dup(net: Net, dpl: Label, dp0: Location, dp1: Location) -> Location:
    """Allocate ``DUP(dpl, dp0, dp1)`` and return its location."""
    return neg(net, DUP(dpl=dpl, dp0=dp0, dp1=dp1))
def udp(net: Net, ud0: Location, ud1: Location) -> Location:
    """Allocate ``UDP(ud0, ud1)`` and return its location."""
    return neg(net, UDP(ud0=ud0, ud1=ud1))
def eva(net: Net, ret: Location) -> Location:
    """Allocate ``EVA(ret)`` and return its location."""
    return neg(net, EVA(ret=ret))
def wire(net: Net) -> tuple[Location, Location]:
    """Create a fresh wire and return ``(positive_end, negative_end)``.

    Both ends are distinct heap nodes (a VAR and a SUB) that refer to the
    same freshly allocated name.
    """
    name = next(location)
    positive_end = var(net, name)
    negative_end = sub(net, name)
    return positive_end, negative_end
def define(net: Net, ide: Identifier, cons: Callable[[Net], Location]) -> None:
    """Register a global definition ``ide`` in ``net.scop``.

    ``cons`` builds the term inside a private empty Net and returns its root.
    That (net, root) package is what REF nodes named ``ide`` expand to.
    """
    body = net_empty()
    root = cons(body)
    net.scop[ide] = (body, root)
def show_pos(net: Net, pos: Location) -> str:
match net.vars[pos]:
case VAR(loc) if loc in net.vars:
return show_pos(net, loc)
case VAR(loc):
return f"+{loc}"
case LAM(bnd, bod):
return f"+({show_neg(net, bnd)} {show_pos(net, bod)})"
case SUP(spl, sp0, sp1):
return f"+{spl}{{{show_pos(net, sp0)} {show_pos(net, sp1)}}}"
case LAP(fun, arg):
return f"+?({show_pos(net, fun)} {show_pos(net, arg)})"
case LD0(ldl, ldb, lda) if ldb in net.vars:
return f"+?0&{ldl}{{{show_pos(net, ldb)}}}"
case LD0(ldl, ldb, lda):
return show_pos(net, lda)
case LD1(ldl, ldb, lda) if ldb in net.vars:
return f"+?1&{ldl}{{{show_pos(net, ldb)}}}"
case LD1(ldl, ldb, lda):
return show_pos(net, lda)
case CNT(cnb):
return f"+%[{show_pos(net, cnb)}]"
case REF(ide):
return f"+@{ide}"
case AVE(bod):
return f"+!({show_pos(net, bod)})"
def show_neg(net: Net, neg: Location) -> str:
match net.subs[neg]:
case SUB(loc) if loc in net.subs:
return show_neg(net, loc)
case SUB(loc):
return f"-{loc}"
case APP(arg, ret):
return f"-({show_pos(net, arg)} {show_neg(net, ret)})"
case DUP(dpl, dp0, dp1):
return f"-&{dpl}{{{show_neg(net, dp0)} {show_neg(net, dp1)}}}"
case UDP(ud0, ud1):
return f"-%[{show_neg(net, ud0)} {show_neg(net, ud1)}]"
case EVA(ret):
return f"-!({show_neg(net, ret)})"
def reduce(net: Net) -> int:
itrs = 0
while net.book:
itrs += 1
lhs, rhs = net.book.pop()
match net.subs.pop(lhs), net.vars.pop(rhs):
case lhsc, VAR(loc) if loc in net.vars:
net.subs[lhs] = lhsc
net.book.add((lhs, loc))
case lhsc, VAR(loc):
net.subs[loc] = lhsc
case SUB(loc), rhsc if loc in net.subs:
net.vars[rhs] = rhsc
net.book.add((loc, rhs))
case SUB(loc), rhsc:
net.vars[loc] = rhsc
case lhsc, REF(ide):
net.subs[lhs] = lhsc
dfn, rot = net.scop[ide]
net_embed(net, net_clone(dfn, locations:={}, labels:={}))
net.book.add((lhs, locations[rot]))
case APP(arg, ret), LAM(bnd, bod):
net.book.add((bnd, arg))
net.book.add((ret, bod))
case APP(arg, ret), SUP(spl, sp0, sp1):
ap, an = wire(net)
bp, bn = wire(net)
x = next(location)
net.book.add((ret, sup(net, spl, lap(net, sp0, ld0(net, spl, arg, x)), lap(net, sp1, ld1(net, spl, arg, x)))))
case APP(arg, ret), AVE(bod):
ap, an = wire(net)
bp, bn = wire(net)
net.book.add((app(net, ap, bn), bod))
net.book.add((eva(net, an), arg))
net.book.add((ret, ave(net, bp)))
case DUP(dpl, dp0, dp1), LAM(bnd, bod):
ap, an = wire(net)
bp, bn = wire(net)
x = next(location)
net.book.add((dp0, lam(net, an, ld0(net, dpl, bod, x))))
net.book.add((dp1, lam(net, bn, ld1(net, dpl, bod, x))))
net.book.add((bnd, sup(net, dpl, ap, bp)))
case DUP(dpl, dp0, dp1), SUP(spl, sp0, sp1) if dpl == spl:
net.book.add((dp0, sp0))
net.book.add((dp1, sp1))
case DUP(dpl, dp0, dp1), SUP(spl, sp0, sp1):
x = next(location)
y = next(location)
net.book.add((dp0, sup(net, spl, ld0(net, dpl, sp0, x), ld0(net, dpl, sp1, y))))
net.book.add((dp1, sup(net, spl, ld1(net, dpl, sp0, x), ld1(net, dpl, sp1, y))))
case DUP(dpl, dp0, dp1), AVE(bod):
ap, an = wire(net)
bp, bn = wire(net)
net.book.add((dp0, ave(net, ap)))
net.book.add((dp1, ave(net, bp)))
net.book.add((dup(net, dpl, an, bn), bod))
case UDP(ud0, ud1), LAM(bnd, bod):
ap, an = wire(net)
bp, bn = wire(net)
lab = next(label)
x = next(location)
net.book.add((ud0, lam(net, an, ld0(net, lab, bod, x))))
net.book.add((ud1, lam(net, bn, ld1(net, lab, bod, x))))
net.book.add((bnd, sup(net, lab, ap, bp)))
case UDP(ud0, ud1), SUP(spl, sp0, sp1):
net.book.add((ud0, sup(net, spl, cnt(net, sp0), cnt(net, sp1))))
net.book.add((ud1, sup(net, spl, cnt(net, sp0), cnt(net, sp1))))
case UDP(ud0, ud1), CNT(cnb):
net.book.add((ud0, cnt(net, cnb)))
net.book.add((ud1, cnt(net, cnb)))
case UDP(ud0, ud1), AVE(bod):
ap, an = wire(net)
bp, bn = wire(net)
net.book.add((ud0, ave(net, ap)))
net.book.add((ud1, ave(net, bp)))
net.book.add((udp(net, an, bn), bod))
case EVA(ret), LAM(bnd, bod):
ap, an = wire(net)
bp, bn = wire(net)
net.book.add((ret, lam(net, an, bp)))
net.book.add((bnd, ave(net, ap)))
net.book.add((eva(net, bn), bod))
case EVA(ret), SUP(spl, sp0, sp1):
ap, an = wire(net)
bp, bn = wire(net)
net.book.add((eva(net, an), sp0))
net.book.add((eva(net, bn), sp1))
net.book.add((ret, sup(net, spl, ap, bp)))
case EVA(ret), AVE(bod):
net.book.add((ret, bod))
case lhsc, LAP(fun, arg):
net.subs[lhs] = lhsc
net.book.add((app(net, arg, lhs), fun))
case lhsc, LD0(ldl, ldb, lda) if ldb in net.vars:
net.subs[lhs] = lhsc
x = next(location)
net.vars[x] = net.vars.pop(ldb)
ap, an = wire(net)
net.book.add((dup(net, ldl, lhs, an), x))
net.vars[lda] = net.vars.pop(ap)
case lhsc, LD0(ldl, ldb, lda):
net.subs[lhs] = lhsc
net.book.add((lhs, lda))
case lhsc, LD1(ldl, ldb, lda) if ldb in net.vars:
net.subs[lhs] = lhsc
x = next(location)
net.vars[x] = net.vars.pop(ldb)
ap, an = wire(net)
net.book.add((dup(net, ldl, an, lhs), x))
net.vars[lda] = net.vars.pop(ap)
case lhsc, LD1(ldl, ldb, lda):
net.subs[lhs] = lhsc
net.book.add((lhs, lda))
case lhsc, CNT(cnb):
net.subs[lhs] = lhsc
x = next(location)
net.vars[x] = net.vars.pop(cnb)
ap, an = wire(net)
net.book.add((udp(net, lhs, an), x))
net.vars[cnb] = net.vars.pop(ap)
return itrs
def print_scope(net: Net) -> None:
    """Print each global definition's root term and its pending redexes.

    Definitions are separated by a blank line.
    """
    for position, (ide, (dfn, rot)) in enumerate(net.scop.items()):
        if position:
            print()
        print(f"{ide} = {show_pos(dfn, rot)}")
        for neg_loc, pos_loc in dfn.book:
            print(f" {show_neg(dfn, neg_loc)} ⋈ {show_pos(dfn, pos_loc)}")
def print_state(net: Net, root: Location, *, heap: bool = False) -> None:
    """Print the term at ``root`` and the pending redexes.

    When ``heap`` is true, also dump every entry of ``net.vars`` and
    ``net.subs``.
    """
    print("ROOT:")
    print(f" {show_pos(net, root)}")
    print()
    print("BOOK:")
    for neg_loc, pos_loc in net.book:
        print(f" {show_neg(net, neg_loc)} ⋈ {show_pos(net, pos_loc)}")
    if not heap:
        return
    def dump(title: str, table: dict[Location, object], render: Callable[[Net, Location], str]) -> None:
        # One heap section: a blank line, a title, then "<loc> = <term>" rows.
        print()
        print(title)
        for slot in table:
            print(f" {slot} = {render(net, slot)}")
    dump("VARS:", net.vars, show_pos)
    dump("SUBS:", net.subs, show_neg)
def print_reduction(net: Net, root: Location, *, heap: bool = False) -> None:
    """Print the scope, the initial state, then reduce ``net`` and print the result.

    Also prints the number of interactions performed.
    """
    def banner(title: str) -> None:
        # A 30-column boxed section heading followed by a blank line.
        print("=" * 30)
        print("=", title.center(26), "=")
        print("=" * 30)
        print()
    banner("SCOPE")
    print_scope(net)
    print()
    banner("INITIAL")
    print_state(net, root, heap=heap)
    print()
    banner("NORMALIZED")
    itrs = reduce(net)
    print_state(net, root, heap=heap)
    print()
    print(f"ITRS: {itrs}")
def mk_nil(net: Net) -> Location:
    """Return the positive end of a fresh wire; its negative end is left unconnected."""
    positive, _unused = wire(net)
    return positive
def mk_era(net: Net) -> Location:
    """Return the negative end of a fresh wire; its positive end is left unconnected."""
    _unused, negative = wire(net)
    return negative
def mk_ldp(net: Net, bod: Location, lab: Label | None = None) -> tuple[Location, Location]:
    """Create a lazy LD0/LD1 pair duplicating ``bod``.

    ``lab`` defaults to a fresh label. The pair shares one fresh slot for the
    deferred second copy.
    """
    dup_label = next(label) if lab is None else lab
    shared_slot = next(location)
    first = ld0(net, dup_label, bod, shared_slot)
    second = ld1(net, dup_label, bod, shared_slot)
    return first, second
def mk_cnt(net: Net, bod: Location, n: int) -> list[Location]:
    """Return ``n`` freshly allocated CNT references to the term at ``bod``."""
    references: list[Location] = []
    for _ in range(n):
        references.append(cnt(net, bod))
    return references
def mk_eva(net: Net, bod: Location) -> Location:
    """Book an EVA against ``bod`` and return the positive end carrying its result."""
    result, sink = wire(net)
    evaluator = eva(net, sink)
    net.book.add((evaluator, bod))
    return result
def mk_dup(net: Net, x: Location, lab: Label | None = None) -> tuple[Location, Location]:
    """Book a DUP against ``x`` and return the positive ends of its two copies.

    ``lab`` defaults to a fresh label.
    """
    dup_label = lab if lab is not None else next(label)
    left_pos, left_neg = wire(net)
    right_pos, right_neg = wire(net)
    duplicator = dup(net, dup_label, left_neg, right_neg)
    net.book.add((duplicator, x))
    return left_pos, right_pos
def mk_Z(net: Net) -> Location:
    """Build ``λz.λs.z`` and return its root."""
    z_pos, z_neg = wire(net)
    _s_pos, s_neg = wire(net)
    term = z_pos
    for binder in (s_neg, z_neg):  # innermost binder first
        term = lam(net, binder, term)
    return term
def mk_S(net: Net) -> Location:
    """Build ``λn.λz.λs.(s n)`` and return its root."""
    n_pos, n_neg = wire(net)
    _z_pos, z_neg = wire(net)
    s_pos, s_neg = wire(net)
    term = lap(net, s_pos, n_pos)
    for binder in (s_neg, z_neg, n_neg):  # innermost binder first
        term = lam(net, binder, term)
    return term
def mk_add(net: Net) -> Location:
    """Build ``λa.λb.(a b λp.(@S (@ADD p b)))`` and return its root.

    ``b`` is used twice, so it is split with an explicit DUP.
    """
    a_pos, a_neg = wire(net)
    b_pos, b_neg = wire(net)
    p_pos, p_neg = wire(net)
    b_zero, b_succ = mk_dup(net, b_pos)
    head = lap(net, a_pos, b_zero)
    succ_ref = ref(net, "S")
    add_ref = ref(net, "ADD")
    add_p = lap(net, add_ref, p_pos)
    add_p_b = lap(net, add_p, b_succ)
    succ_call = lap(net, succ_ref, add_p_b)
    succ_case = lam(net, p_neg, succ_call)
    body = lap(net, head, succ_case)
    return lam(net, a_neg, lam(net, b_neg, body))
def mk_1(net: Net) -> Location:
    """Build ``(@S @Z)`` and return its root."""
    succ_ref = ref(net, "S")
    zero_ref = ref(net, "Z")
    return lap(net, succ_ref, zero_ref)
def mk_c2(net: Net) -> Location:
    """Build the Church numeral two, ``λf.λx.(f (f x))``, and return its root.

    ``f`` is shared through a lazy LD0/LD1 pair.
    """
    f_pos, f_neg = wire(net)
    x_pos, x_neg = wire(net)
    f_first, f_second = mk_ldp(net, f_pos)
    once = lap(net, f_first, x_pos)
    twice = lap(net, f_second, once)
    inner = lam(net, x_neg, twice)
    return lam(net, f_neg, inner)
def mk_cpow2(net: Net, k: int) -> Location:
    """Build the Church numeral ``2**k`` from ``@C2``.

    k=1 gives ``@C2`` itself. For larger k the result is
    ``c_{2^k} = λf. (@C2 (c_{2^(k-1)} f))``.

    Raises:
        ValueError: if ``k < 1``.
    """
    if k < 1:
        raise ValueError("k must be >= 1")
    elif k == 1:
        return ref(net, "C2")
    else:
        f_pos, f_neg = wire(net)
        half = mk_cpow2(net, k - 1)  # c_{2^(k-1)}
        half_applied = lap(net, half, f_pos)  # f^(2^(k-1))
        doubled = lap(net, ref(net, "C2"), half_applied)  # f^(2^k)
        return lam(net, f_neg, doubled)
def mk_id(net: Net) -> Location:
    """Build ``λx.x`` and return its root.

    The body goes through a CNT reference. A second CNT to ``x`` is also
    allocated and never used.
    NOTE(review): presumably deliberate, to exercise an unused CNT — confirm.
    """
    x_pos, x_neg = wire(net)
    first_ref, _spare_ref = mk_cnt(net, x_pos, 2)
    return lam(net, x_neg, first_ref)
def mk_true(net: Net) -> Location:
    """Build ``λt.λf.t`` and return its root."""
    t_pos, t_neg = wire(net)
    _f_pos, f_neg = wire(net)
    term = t_pos
    for binder in (f_neg, t_neg):  # innermost binder first
        term = lam(net, binder, term)
    return term
def mk_false(net: Net) -> Location:
    """Build ``λt.λf.f`` and return its root."""
    _t_pos, t_neg = wire(net)
    f_pos, f_neg = wire(net)
    term = f_pos
    for binder in (f_neg, t_neg):  # innermost binder first
        term = lam(net, binder, term)
    return term
def mk_F(net: Net) -> Location:
    """Build ``λb.λt.λf.(b (f f) f)`` and return its root.

    ``f`` occurs three times, via three CNT references.
    """
    b_pos, b_neg = wire(net)
    _t_pos, t_neg = wire(net)
    f_pos, f_neg = wire(net)
    f_a, f_b, f_c = mk_cnt(net, f_pos, 3)
    self_app = lap(net, f_a, f_b)
    term = lap(net, lap(net, b_pos, self_app), f_c)
    for binder in (f_neg, t_neg, b_neg):  # innermost binder first
        term = lam(net, binder, term)
    return term
def test_F_fusion(net: Net, N: int) -> Location:
    """Build ``@F^(2^N)``: the Church numeral ``2**N`` applied to ``@F``."""
    power = mk_cpow2(net, N)
    f_ref = ref(net, "F")
    return lap(net, power, f_ref)
def mk_E(net: Net) -> Location:
    """Build ``λe.λo.λi.e`` and return its root."""
    e_pos, e_neg = wire(net)
    _o_pos, o_neg = wire(net)
    _i_pos, i_neg = wire(net)
    term = e_pos
    for binder in (i_neg, o_neg, e_neg):  # innermost binder first
        term = lam(net, binder, term)
    return term
def mk_O(net: Net) -> Location:
    """Build ``λp.λe.λo.λi.(o p)`` and return its root."""
    p_pos, p_neg = wire(net)
    _e_pos, e_neg = wire(net)
    o_pos, o_neg = wire(net)
    _i_pos, i_neg = wire(net)
    term = lap(net, o_pos, p_pos)
    for binder in (i_neg, o_neg, e_neg, p_neg):  # innermost binder first
        term = lam(net, binder, term)
    return term
def mk_I(net: Net) -> Location:
    """Build ``λp.λe.λo.λi.(i p)`` and return its root."""
    p_pos, p_neg = wire(net)
    _e_pos, e_neg = wire(net)
    _o_pos, o_neg = wire(net)
    i_pos, i_neg = wire(net)
    term = lap(net, i_pos, p_pos)
    for binder in (i_neg, o_neg, e_neg, p_neg):  # innermost binder first
        term = lam(net, binder, term)
    return term
# λn.(n
#   λxs.λe.λo.λi.(xs
#     (i @E)
#     λxs.(i xs)
#     λxs.(i xs)
#   )
#   @NIL
#   λp.λxs.λe.λo.λi.(xs
#     (o (@INSERT p @E))
#     λxs.(o (@INSERT p xs))
#     λxs.(i (@INSERT p xs))
#   )
# )
def mk_insert(net: Net) -> Location:
    """Build the INSERT term shown in the comment above and return its root.

    ``n`` is eliminated by applying it to three cases, in the order E, O, I
    (matching the argument order of mk_E / mk_O / mk_I).
    Variables used several times are shared through CNT references (mk_cnt).

    NOTE(review): the O case is ``@NIL``, which ``main`` in this file never
    defines. Reducing that branch would raise KeyError in ``reduce``'s REF
    rule; test_insert_fusion only exercises the I and E branches.
    """
    np, nn = wire(net)
    def mk_case_E() -> Location:
        # λxs.λe.λo.λi.(xs (i @E) λxs.(i xs) λxs.(i xs)); e and o are unused.
        xsp, xsn = wire(net)
        ep, en = wire(net)
        op, on = wire(net)
        ip, in_ = wire(net)
        i0, i1, i2 = mk_cnt(net, ip, 3)  # i occurs three times
        def mk_case_E_case_E() -> Location:
            return lap(net, i0, ref(net, "E"))
        def mk_case_E_case_O() -> Location:
            xsp, xsn = wire(net)  # inner xs, shadows the outer one
            return lam(net, xsn, lap(net, i1, xsp))
        def mk_case_E_case_I() -> Location:
            xsp, xsn = wire(net)
            return lam(net, xsn, lap(net, i2, xsp))
        case_E_ret = lap(net, lap(net, lap(net, xsp, mk_case_E_case_E()), mk_case_E_case_O()), mk_case_E_case_I())
        return lam(net, xsn, lam(net, en, lam(net, on, lam(net, in_, case_E_ret))))
    def mk_case_O() -> Location:
        return ref(net, "NIL")
    def mk_case_I() -> Location:
        # λp.λxs.λe.λo.λi.(xs (o (@INSERT p @E)) λxs.(o (@INSERT p xs)) λxs.(i (@INSERT p xs)))
        pp, pn = wire(net)
        xsp, xsn = wire(net)
        ep, en = wire(net)
        op, on = wire(net)
        ip, in_ = wire(net)
        p0, p1, p2 = mk_cnt(net, pp, 3)  # p occurs three times
        o0, o1 = mk_cnt(net, op, 2)  # o occurs twice; i only once, so it is used directly
        def mk_case_I_case_E() -> Location:
            return lap(net, o0, lap(net, lap(net, ref(net, "INSERT"), p0), ref(net, "E")))
        def mk_case_I_case_O() -> Location:
            xsp, xsn = wire(net)
            return lam(net, xsn, lap(net, o1, lap(net, lap(net, ref(net, "INSERT"), p1), xsp)))
        def mk_case_I_case_I() -> Location:
            xsp, xsn = wire(net)
            return lam(net, xsn, lap(net, ip, lap(net, lap(net, ref(net, "INSERT"), p2), xsp)))
        case_I_ret = lap(net, lap(net, lap(net, xsp, mk_case_I_case_E()), mk_case_I_case_O()), mk_case_I_case_I())
        return lam(net, pn, lam(net, xsn, lam(net, en, lam(net, on, lam(net, in_, case_I_ret)))))
    # Nested call order fixes the allocation order of every location above.
    ret = lap(net, lap(net, lap(net, np, mk_case_E()), mk_case_O()), mk_case_I())
    return lam(net, nn, ret)
def test_insert_fusion(net: Net, N: int) -> Location:
    """Build ``(@INSERT (@I @E))^(2^N)``.

    This is the Church numeral ``2**N`` applied to ``(@INSERT (@I @E))``.
    """
    power = mk_cpow2(net, N)
    insert_ref = ref(net, "INSERT")
    i_ref = ref(net, "I")
    e_ref = ref(net, "E")
    key = lap(net, i_ref, e_ref)
    inserter = lap(net, insert_ref, key)
    return lap(net, power, inserter)
def main() -> None:
    """Register the global definitions, build the INSERT benchmark and print its reduction."""
    net = net_empty()
    # Registration order is kept stable: it fixes scope order and location numbering.
    definitions: tuple[tuple[Identifier, Callable[[Net], Location]], ...] = (
        ("ID", mk_id),
        ("TRUE", mk_true),
        ("FALSE", mk_false),
        ("C2", mk_c2),
        ("F", mk_F),
        ("E", mk_E),
        ("O", mk_O),
        ("I", mk_I),
        ("INSERT", mk_insert),
    )
    for name, builder in definitions:
        define(net, name, builder)
    # Alternative benchmark: root = test_F_fusion(net, 100)
    root = test_insert_fusion(net, 4)
    print_reduction(net, mk_eva(net, root))
# Run the benchmark only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment