Skip to content

Instantly share code, notes, and snippets.

@Epivalent
Created January 21, 2026 10:12
Show Gist options
  • Select an option

  • Save Epivalent/bef4d22fe4e54170dc84f4004a73193c to your computer and use it in GitHub Desktop.

Select an option

Save Epivalent/bef4d22fe4e54170dc84f4004a73193c to your computer and use it in GitHub Desktop.
#!/usr/bin/env python3
"""TopCVP_beta CLI
This script packages the current working construction (paper-faithful):
Primitives (all are *string* lambda terms built via plc_suite.terms):
- fetch_{i,n} (paper), with NOT-substitution implemented by swapping CSTT/CSTF
*only* at the (t==n+1 and j==i) app-slot inside the fetch constructor.
- and'_n (paper)
- OR derived purely by De Morgan using NOT-substitution + and'
Normalization/tagging:
- Default uses /mnt/data/planar_nf.py (fast normalizer).
- Optionally (--normalizer external --normalizer-bin PATH) uses an external
binary normalizer, which receives the term as argv[1] and prints the NF on stdout.
Examples:
# tag a boolean term
python3 topcvp_beta_cli.py tag --term "(\\b.\\k.\\f.(b k f))"
# run the paper example circuit (prints new-bit tags only)
python3 topcvp_beta_cli.py paper-example
# evaluate a CVP gate list (JSON)
python3 topcvp_beta_cli.py cvp --eqs-json '[["const",1],["const",0],["and",1,2],["not",1],["or",3,4]]'
# evaluate a DIMACS CNF with a provided assignment
python3 topcvp_beta_cli.py cnf-eval --dimacs foo.cnf --assign "1,-2,3"
"""
from __future__ import annotations
import argparse
import json
import os
import subprocess
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union
sys.path.insert(0, '/mnt/data')
import plc_suite.terms as T
# -------------------------
# Normalization backends
# -------------------------
class NFEngine:
    """Interface for a beta-normalizer: term source text in, normal-form text out."""

    def normalize(self, src: str) -> str:
        """Return the normal form of *src*; concrete backends must override this."""
        raise NotImplementedError
class PlanarNFEngine(NFEngine):
    """Normalizer backed by the in-process ``planar_nf`` module."""

    def __init__(self):
        # planar_nf is a local module found via sys.path (this script prepends /mnt/data).
        from planar_nf import normalize as nf_fn, parse as parse_fn, to_string as show_fn
        self._parse, self._normalize, self._to_string = parse_fn, nf_fn, show_fn

    def normalize(self, src: str) -> str:
        """Parse *src*, normalize it, and serialize the normal form back to text."""
        return self._to_string(self._normalize(self._parse(src)))
class ExternalBinEngine(NFEngine):
    """Normalizer that delegates to an external executable.

    Contract: the binary receives the term as its first command-line argument
    and writes the normal form to stdout; a non-zero exit status is a failure.
    """

    def __init__(self, bin_path: str):
        self.bin_path = bin_path

    def normalize(self, src: str) -> str:
        """Run the binary on *src* and return its stripped stdout.

        Raises RuntimeError (including the binary's stderr) on a non-zero exit.
        """
        # Very large terms can exceed the OS argv size limit; wrap the binary
        # (e.g. have it read from a file) if that becomes a problem.
        result = subprocess.run(
            [self.bin_path, src],
            capture_output=True,
            text=True,
            check=False,
        )
        if result.returncode:
            raise RuntimeError(
                f"External normalizer failed (rc={result.returncode}).\nSTDERR:\n{result.stderr.strip()}"
            )
        return result.stdout.strip()
@dataclass
class BoolTagger:
    """Classify boolean lambda terms as 'T', 'F' or '?' via DEC normal forms.

    ``dec_true_nf`` / ``dec_false_nf`` hold the normal-form strings of
    ``(DEC TRUE)`` and ``(DEC FALSE)``.  A term is tagged by normalizing
    ``(DEC term)`` and comparing the resulting string against both; anything
    else is '?'.  Results (including '?') are memoized in ``cache``.
    """

    eng: NFEngine
    dec_true_nf: str
    dec_false_nf: str
    cache: Dict[str, str]

    @classmethod
    def build(cls, eng: NFEngine) -> 'BoolTagger':
        """Normalize the two reference terms once and return a tagger with an empty cache."""
        def nf_of(boolean: str) -> str:
            return eng.normalize(f"({T.DEC} {boolean})")

        return cls(eng=eng, dec_true_nf=nf_of(T.TRUE), dec_false_nf=nf_of(T.FALSE), cache={})

    def tag(self, term: str) -> str:
        """Return 'T', 'F' or '?' for *term*, consulting and filling the cache."""
        hit = self.cache.get(term)
        if hit is not None:
            return hit
        nf = self.eng.normalize(f"({T.DEC} {term})")
        if nf == self.dec_true_nf:
            label = 'T'
        elif nf == self.dec_false_nf:
            label = 'F'
        else:
            label = '?'
        self.cache[term] = label
        return label
def make_engine(args: argparse.Namespace) -> NFEngine:
    """Instantiate the normalization backend selected on the command line.

    Raises SystemExit when ``--normalizer external`` lacks ``--normalizer-bin``
    or when the backend name is unknown.
    """
    kind = args.normalizer
    if kind == 'external':
        if not args.normalizer_bin:
            raise SystemExit('need --normalizer-bin when --normalizer external')
        return ExternalBinEngine(args.normalizer_bin)
    if kind == 'planar_nf':
        return PlanarNFEngine()
    raise SystemExit(f'unknown normalizer: {args.normalizer}')
# -------------------------
# Term construction (paper)
# -------------------------
# Short module-level aliases for the plc_suite.terms construction primitives
# and closed constant terms used by the builders below.
var = T.var
lam = T.lam
lams = T.lams
app = T.app
apps = getattr(T, 'apps', None)  # optional helper (None if absent); not referenced elsewhere in this script
TRUE = T.TRUE
FALSE = T.FALSE
COMP = T.COMP
AND = T.AND
CSTT = T.CSTT
CSTF = T.CSTF
def proj_of_tuple(t: str, j: int, m: int) -> str:
    """Project the j-th component (1-based) out of the m-tuple term *t*.

    Builds ``(t (lambda x1..xm. xj))``.

    Raises ValueError when *j* is outside ``1..m``: the selector would otherwise
    mention a variable ``xj`` that it does not bind, silently producing a term
    that cannot be a proper projection.
    """
    if not 1 <= j <= m:
        raise ValueError(f'projection index {j} out of range 1..{m}')
    params = [f'x{k}' for k in range(1, m + 1)]
    return app(t, lams(params, var(f'x{j}')))
def compose(f: str, g: str) -> str:
    """Apply the COMP combinator to *f* and then *g*, i.e. build ``((COMP f) g)``."""
    comp_f = app(COMP, f)
    return app(comp_f, g)
def append_const_op(n: int, const: str) -> str:
    """Unary tuple operator that appends *const*: <x1..xn> -> <x1..xn, const>.

    The result is the CPS term ``lambda v k. v (lambda x1..xn. k x1 .. xn const)``,
    so applying it to an n-tuple yields an (n+1)-tuple.  (The previous docstring
    contained a backslash sequence that Python 3.12+ reports as an invalid escape.)
    """
    params = [f'x{i}' for i in range(1, n + 1)]
    spine = var('k')
    for arg in [var(p) for p in params] + [const]:
        spine = app(spine, arg)
    # op v k = v (lambda x1..xn. k x1..xn const)
    return lams(['v', 'k'], app(var('v'), lams(params, spine)))
def setterT(n: int, j: int, i: int, *, which: str, neg: bool = False) -> str:
    """Build a unary transformer on Church (n+1)-tuples, following the paper.

    The result is ``lambda t k. t (lambda b1..b_{n+1}. k a1 .. a_{n+1})`` where
    ``a_j = C b_j`` with C = CSTT for which='t' or CSTF for which='f'.  When
    ``j == i`` (and j <= n) the appended slot is wrapped too,
    ``a_{n+1} = C' b_{n+1}``, where C' equals C unless neg=True, which swaps
    CSTT<->CSTF there (the NOT-substitution).  Every other ``a_p`` is ``b_p``.

    Raises ValueError if *which* is neither 't' nor 'f'.
    """
    if which not in ('t', 'f'):
        raise ValueError(which)
    at_j = CSTT if which == 't' else CSTF
    # NOT-substitution: flip the appended slot's constant when neg is set.
    tail_is_true = (which == 't') != neg
    at_tail = CSTT if tail_is_true else CSTF

    width = n + 1
    names = [f'b{p}' for p in range(1, width + 1)]

    def slot(pos: int) -> str:
        if pos == j:
            return app(at_j, var(f'b{pos}'))
        if pos == width and j == i:
            return app(at_tail, var(f'b{pos}'))
        return var(f'b{pos}')

    result = var('k')
    for pos in range(1, width + 1):
        result = app(result, slot(pos))
    # unary transformer on tuples: t |-> (lambda k. t (lambda b1..b_{n+1}. k ...))
    return lam('t', lam('k', app(var('t'), lams(names, result))))
def cj(n: int, j: int, i: int, *, neg: bool = False) -> str:
    """Build the c_j helper used by fetch_paper.

    Result: ``lambda g h. COMP g (COMP S h)`` where S is the CSTT-variant
    setterT(n, j, i, which='t', neg=neg).
    """
    setter = setterT(n, j, i, which='t', neg=neg)
    inner = compose(setter, var('h'))
    return lams(['g', 'h'], compose(var('g'), inner))
def fj(n: int, j: int, i: int, *, neg: bool = False) -> str:
    """Build the f_j helper used by fetch_paper: the CSTF-variant setterT transformer."""
    return setterT(n, j, i, neg=neg, which='f')
def fetch_paper(i: int, n: int, *, neg: bool = False) -> str:
    """Build fetch_{i,n}, which (as not_gate/and_gate use it) appends a copy of x_i.

    The term is ``lambda v. v (lambda x1..xn. x1 c1 f1 (... (xn cn fn BASE)))``
    where BASE is the all-FALSE (n+1)-tuple and c_j / f_j come from cj / fj.
    With neg=True the NOT-substitution is applied inside those helpers.

    Raises ValueError unless ``1 <= i <= n``.
    """
    if i < 1 or i > n:
        raise ValueError('need 1 <= i <= n')
    acc = T.tup([FALSE] * (n + 1))
    # Fold from x_n down to x_1 so x_1 ends up outermost.
    for j in reversed(range(1, n + 1)):
        step = app(var(f'x{j}'), cj(n, j, i, neg=neg))
        step = app(step, fj(n, j, i, neg=neg))
        acc = app(step, acc)
    params = [f'x{j}' for j in range(1, n + 1)]
    return lam('v', app(var('v'), lams(params, acc)))
def and_prime_paper(n: int) -> str:
    """Build and'_n: <x1..x_{n+2}> -> <x1..xn, x_{n+1} AND x_{n+2}>.

    CPS form: ``lambda v k. v (lambda x1..x_{n+2}. k x1 .. xn (AND x_{n+1} x_{n+2}))``.
    """
    params = [f'x{p}' for p in range(1, n + 3)]
    conj = app(app(AND, var(f'x{n + 1}')), var(f'x{n + 2}'))
    result = var('k')
    for name in params[:n]:
        result = app(result, var(name))
    result = app(result, conj)
    return lams(['v', 'k'], app(var('v'), lams(params, result)))
def and_gate(i: int, j: int, n: int, v: str) -> str:
    """Append (xi AND xj) to a length-n vector v, producing length n+1."""
    # Two fetches grow the vector n -> n+1 (copy of xj) -> n+2 (copy of xi);
    # and'_n then replaces those last two slots by their conjunction.
    with_j = app(fetch_paper(j, n), v)
    with_both = app(fetch_paper(i, n + 1), with_j)
    return app(and_prime_paper(n), with_both)
def not_gate(i: int, n: int, v: str) -> str:
    """Append NOT(xi) to a length-n vector v, producing length n+1."""
    negated_fetch = fetch_paper(i, n, neg=True)
    return app(negated_fetch, v)
def or_gate_dm(i: int, j: int, n: int, v: str) -> Tuple[str, int, int]:
    """Append (xi OR xj) using De Morgan with only not_gate + and_gate.

    xi OR xj == NOT(NOT xi AND NOT xj), which takes four primitive steps, so the
    vector grows from length n to n+4 and the result sits in the last slot.
    Returns (v_out, n_out, out_pos) with n_out == out_pos == n + 4.
    """
    neg_i, neg_j, both, result = n + 1, n + 2, n + 3, n + 4
    v = not_gate(i, n, v)                   # x_{n+1} := ~xi
    v = not_gate(j, n + 1, v)               # x_{n+2} := ~xj
    v = and_gate(neg_i, neg_j, n + 2, v)    # x_{n+3} := ~xi & ~xj
    v = not_gate(both, n + 3, v)            # x_{n+4} := xi OR xj
    return v, result, result
# -------------------------
# CVP gate compilation
# -------------------------
# One CVP gate: ('const', 0|1), ('not', i), ('and', i, j) or ('or', i, j),
# where i / j are 1-based indices of earlier gates.
Eq = Union[
    Tuple[str, int],
    Tuple[str, int, int],
]
def compile_cvp(eqs: Sequence[Eq]) -> Tuple[str, int, List[int]]:
    """Compile a high-level CVP gate list to a single vector term.

    Gates are numbered 1.. in list order; every operand of gate k must name an
    earlier gate (``1 <= operand < k``).

    Returns (vector_term, final_length, wire_pos):
    - wire_pos[k-1] is the actual vector position of logical gate k.
    - OR introduces auxiliaries, but wire_pos always points to the gate output.

    Raises ValueError for an unknown op, a gate with too few operands, a const
    other than 0/1, or an operand that is not an earlier gate.  The whole list is
    checked before any term is built; previously an operand of 0 or a negative
    number silently wrapped to another wire through Python negative indexing.
    """
    arity = {'const': 1, 'not': 1, 'and': 2, 'or': 2}
    for k, eq in enumerate(eqs, start=1):
        if not eq:
            raise ValueError(f'empty gate at position {k}')
        op = eq[0]
        if op not in arity:
            raise ValueError(f'unknown gate op: {op}')
        need = arity[op]
        if len(eq) < need + 1:
            raise ValueError(f'gate {k} ({op}) expects {need} operand(s): {eq!r}')
        if op == 'const':
            if int(eq[1]) not in (0, 1):
                raise ValueError(f'gate {k}: const must be 0 or 1, got {eq[1]!r}')
            continue
        for ref in eq[1:need + 1]:
            if not 1 <= int(ref) < k:
                raise ValueError(
                    f'gate {k} ({op}) references gate {ref}; expected an earlier gate 1..{k - 1}'
                )

    v = T.tup([])
    n = 0
    wire_pos: List[int] = []
    for eq in eqs:
        op = eq[0]
        if op == 'const':
            const = TRUE if int(eq[1]) == 1 else FALSE
            v = app(append_const_op(n, const), v)
            n += 1
            wire_pos.append(n)
        elif op == 'not':
            src_pos = wire_pos[int(eq[1]) - 1]
            v = not_gate(src_pos, n, v)
            n += 1
            wire_pos.append(n)
        elif op == 'and':
            pi, pj = wire_pos[int(eq[1]) - 1], wire_pos[int(eq[2]) - 1]
            v = and_gate(pi, pj, n, v)
            n += 1
            wire_pos.append(n)
        else:  # 'or' (the only remaining op after validation)
            pi, pj = wire_pos[int(eq[1]) - 1], wire_pos[int(eq[2]) - 1]
            v, n, out_pos = or_gate_dm(pi, pj, n, v)
            wire_pos.append(out_pos)
    return v, n, wire_pos
def eval_cvp_output_tag(tagger: BoolTagger, eqs: Sequence[Eq]) -> str:
    """Compile *eqs* and tag the output of the last gate as 'T', 'F' or '?'.

    Raises ValueError for an empty gate list (there is no output wire);
    previously this surfaced as an IndexError from ``wire_pos[-1]``.
    """
    if not eqs:
        raise ValueError('empty gate list: nothing to evaluate')
    v, n, wire_pos = compile_cvp(eqs)
    return tagger.tag(proj_of_tuple(v, wire_pos[-1], n))
# -------------------------
# CNF helper
# -------------------------
def parse_dimacs(path: str) -> Tuple[int, List[List[int]]]:
    """Parse a DIMACS CNF file into ``(nvars, clauses)``.

    Literals are read as one stream in which ``0`` terminates a clause, so a
    clause may span several lines and one line may hold several clauses.  A bare
    ``0`` is an empty clause and is kept (it makes the formula unsatisfiable).
    Comment lines (``c``) and the problem line (``p cnf V C``) are skipped, and
    parsing stops at a ``%`` line, the end-of-data marker used by SATLIB files.
    A final clause missing its terminating ``0`` is still accepted.  When the
    header gives no positive variable count, the largest variable index used in
    the clauses is returned instead.
    """
    clauses: List[List[int]] = []
    pending: List[int] = []
    nvars = 0
    with open(path, 'r', encoding='utf-8') as f:
        for raw in f:
            line = raw.strip()
            if not line or line.startswith('c'):
                continue
            if line.startswith('%'):
                break
            if line.startswith('p'):
                parts = line.split()
                if len(parts) >= 4 and parts[1] == 'cnf':
                    nvars = int(parts[2])
                continue
            for tok in line.split():
                lit = int(tok)
                if lit == 0:
                    clauses.append(pending)
                    pending = []
                else:
                    pending.append(lit)
    if pending:
        clauses.append(pending)
    if nvars <= 0:
        # fall back to the largest variable index actually used
        nvars = max((abs(lit) for clause in clauses for lit in clause), default=0)
    return nvars, clauses
def compile_cnf_to_cvp(nvars: int, clauses: List[List[int]], assignment: List[int]) -> List[Eq]:
    """Compile CNF under a *fixed* assignment to a CVP gate list.

    assignment is a list of ints like [1,-2,3] meaning x1=T, x2=F, x3=T.
    Missing vars default to F.

    Gates 1..N are constants for x1..xN, where N is *nvars* raised to the largest
    variable index occurring in *clauses*; this way an out-of-range literal can
    never alias an auxiliary gate.  Each negative literal adds a NOT gate, each
    clause becomes an OR chain (an empty clause becomes const 0), and the
    clauses are joined by an AND chain (no clauses gives const 1).

    The formula's value is always the LAST gate, which is what
    eval_cvp_output_tag reads.  When the result would otherwise be an earlier
    wire (e.g. a single clause consisting of one positive literal), an
    ``x AND x`` gate copies it to the end.

    Raises ValueError on a literal 0.
    """
    if any(lit == 0 for clause in clauses for lit in clause):
        raise ValueError('literal 0 is not a valid CNF literal')
    width = max([nvars] + [abs(lit) for clause in clauses for lit in clause])

    asg: Dict[int, int] = {abs(v): (1 if v > 0 else 0) for v in assignment}
    # First, create x1..xN as constants (the assignment).
    eqs: List[Eq] = [('const', int(asg.get(i, 0))) for i in range(1, width + 1)]

    def lit_to_wire(lit: int) -> int:
        if lit > 0:
            return lit
        # negated literal: create a NOT gate
        eqs.append(('not', -lit))
        return len(eqs)

    clause_wires: List[int] = []
    for clause in clauses:
        if not clause:
            # empty clause => UNSAT
            eqs.append(('const', 0))
            clause_wires.append(len(eqs))
            continue
        cur = lit_to_wire(clause[0])
        for lit in clause[1:]:
            nxt = lit_to_wire(lit)
            eqs.append(('or', cur, nxt))
            cur = len(eqs)
        clause_wires.append(cur)

    # AND all clauses
    if not clause_wires:
        eqs.append(('const', 1))
        return eqs
    cur = clause_wires[0]
    for w in clause_wires[1:]:
        eqs.append(('and', cur, w))
        cur = len(eqs)
    if cur != len(eqs):
        # The output must be the last gate; x AND x == x copies it there.
        eqs.append(('and', cur, cur))
    return eqs
# -------------------------
# CLI helpers
# -------------------------
def load_term_arg(term: Optional[str], term_file: Optional[str]) -> str:
    """Return the term given inline (--term) or the contents of --term-file.

    Exactly one of the two must be provided; otherwise SystemExit is raised.
    """
    have_inline = term is not None
    have_file = term_file is not None
    if not (have_inline or have_file):
        raise SystemExit('need --term or --term-file')
    if have_inline and have_file:
        raise SystemExit('use only one of --term or --term-file')
    if have_file:
        return Path(term_file).read_text(encoding='utf-8')
    return term  # type: ignore[return-value]
def load_eqs(args: argparse.Namespace) -> List[Eq]:
    """Load a CVP gate list from --eqs-file (preferred) or --eqs-json.

    The JSON must be an array of gates such as ``["const", 1]``, ``["not", i]``,
    ``["and", i, j]`` or ``["or", i, j]``.  Operands are coerced with int(), and
    elements beyond the expected operands are ignored.  Any malformed input
    (invalid JSON, a non-array document, an unknown op, missing or non-integer
    operands) raises SystemExit with a message naming the problem, instead of
    escaping as a raw IndexError/ValueError traceback.
    """
    if args.eqs_file:
        raw = Path(args.eqs_file).read_text(encoding='utf-8')
    elif args.eqs_json:
        raw = args.eqs_json
    else:
        raise SystemExit('need --eqs-json or --eqs-file')
    try:
        data = json.loads(raw)
    except json.JSONDecodeError as err:
        raise SystemExit(f'invalid gate-list JSON: {err}') from err
    if not isinstance(data, list):
        raise SystemExit(f'gate list must be a JSON array, got {type(data).__name__}')

    arity = {'const': 1, 'not': 1, 'and': 2, 'or': 2}
    eqs: List[Eq] = []
    for e in data:
        if not isinstance(e, (list, tuple)) or not e:
            raise SystemExit(f'bad eq: {e!r}')
        op = e[0]
        if not isinstance(op, str) or op not in arity:
            raise SystemExit(f'unknown op in eq: {e!r}')
        need = arity[op]
        if len(e) < need + 1:
            raise SystemExit(f'bad eq (expected {need} operand(s)): {e!r}')
        try:
            operands = tuple(int(x) for x in e[1:need + 1])
        except (TypeError, ValueError):
            raise SystemExit(f'bad eq (operands must be integers): {e!r}') from None
        eqs.append((op,) + operands)
    return eqs
def parse_assign(s: str) -> List[int]:
    """Parse an assignment like "1,-2,3" or "1 -2 3" into a list of signed ints."""
    return [int(tok) for tok in s.replace(',', ' ').split()]
# -------------------------
# Subcommands
# -------------------------
def cmd_norm(args: argparse.Namespace) -> int:
    """`norm` subcommand: print the normal form of the given term; returns 0."""
    engine = make_engine(args)
    print(engine.normalize(load_term_arg(args.term, args.term_file)))
    return 0
def cmd_tag(args: argparse.Namespace) -> int:
    """`tag` subcommand: print T, F or ? for the given boolean term; returns 0."""
    tagger = BoolTagger.build(make_engine(args))
    source = load_term_arg(args.term, args.term_file)
    print(tagger.tag(source))
    return 0
def cmd_paper_example(args: argparse.Namespace) -> int:
    """`paper-example` subcommand: build the paper's example circuit gate by gate.

    After every gate, the newly appended bit (the last vector slot) is tagged
    and printed next to the value expected from the *observed* tags of its
    inputs (kept in ``bits``), so each check tests one gate in isolation rather
    than against a ground-truth evaluation.  Mismatches are only printed as
    BAD; the function always returns 0.
    """
    eng = make_engine(args)
    tagger = BoolTagger.build(eng)
    # Paper example (as we used):
    # x1:=1; x2:=0; x3:=1; x4=x1&x2; x5=~x1; x6=x5&x3; x7=x4|x6
    # We implement OR via De Morgan macro (4 extra bits). We only check each *new* bit.
    # (So the paper's x7 = x4|x6 ends up at x10 here; see the OR macro below.)
    def step(v: str, n: int, expected: str) -> str:
        # Tag the last slot x_n of the length-n vector v, print it with an
        # OK/BAD verdict against `expected`, and return the observed tag.
        newbit = tagger.tag(proj_of_tuple(v, n, n))
        ok = 'OK' if newbit == expected else 'BAD'
        print(f" new x{n} = {newbit} (expect {expected}) => {ok}")
        return newbit
    # bits[k-1] holds the observed tag of x_k.
    bits: List[str] = []
    v = T.tup([])
    n = 0
    # x1 := 1
    v = app(append_const_op(n, TRUE), v); n += 1
    bits.append(step(v, n, 'T'))
    # x2 := 0
    v = app(append_const_op(n, FALSE), v); n += 1
    bits.append(step(v, n, 'F'))
    # x3 := 1
    v = app(append_const_op(n, TRUE), v); n += 1
    bits.append(step(v, n, 'T'))
    # x4 := x1 & x2
    v = and_gate(1, 2, n, v); n += 1
    exp = 'T' if (bits[0] == 'T' and bits[1] == 'T') else 'F'
    bits.append(step(v, n, exp))
    # x5 := ~x1
    v = not_gate(1, n, v); n += 1
    exp = 'T' if bits[0] == 'F' else 'F'
    bits.append(step(v, n, exp))
    # x6 := x5 & x3
    v = and_gate(5, 3, n, v); n += 1
    exp = 'T' if (bits[4] == 'T' and bits[2] == 'T') else 'F'
    bits.append(step(v, n, exp))
    # OR macro: x7 := ~x4 ; x8 := ~x6 ; x9 := x7 & x8 ; x10 := ~x9
    v = not_gate(4, n, v); n += 1
    exp = 'T' if bits[3] == 'F' else 'F'
    bits.append(step(v, n, exp))
    v = not_gate(6, n, v); n += 1
    exp = 'T' if bits[5] == 'F' else 'F'
    bits.append(step(v, n, exp))
    v = and_gate(7, 8, n, v); n += 1
    exp = 'T' if (bits[6] == 'T' and bits[7] == 'T') else 'F'
    bits.append(step(v, n, exp))
    v = not_gate(9, n, v); n += 1
    exp = 'T' if bits[8] == 'F' else 'F'
    bits.append(step(v, n, exp))
    # Summary: the derived OR bit (x10) versus a direct OR of the observed x4, x6 tags.
    print(f"\nFinal derived OR bit (x{n}) = {bits[-1]}")
    print("Expected (x4 OR x6) =", 'T' if (bits[3] == 'T' or bits[5] == 'T') else 'F')
    return 0
def cmd_cvp(args: argparse.Namespace) -> int:
    """`cvp` subcommand: compile a JSON gate list and print the last gate's tag.

    With --dump-wires, prints the logical-gate -> vector-position map.  With
    --dump-gate-tags, tags and prints every gate output.  The gate list is
    loaded before the engine is built, so input errors are reported before the
    reference normal forms are computed.  An empty list raises SystemExit
    instead of an IndexError on ``wire_pos[-1]``.  Returns 0.
    """
    eqs = load_eqs(args)
    if not eqs:
        raise SystemExit('empty gate list: nothing to evaluate')
    eng = make_engine(args)
    tagger = BoolTagger.build(eng)
    v, n, wire_pos = compile_cvp(eqs)
    if args.dump_wires:
        print('wire_pos (logical_gate -> vector_position):')
        for i, p in enumerate(wire_pos, start=1):
            print(f' g{i} -> x{p}')
    if args.dump_gate_tags:
        for gi, p in enumerate(wire_pos, start=1):
            b = tagger.tag(proj_of_tuple(v, p, n))
            print(f'g{gi} (x{p}) = {b}')
    out_pos = wire_pos[-1]
    out = tagger.tag(proj_of_tuple(v, out_pos, n))
    print(out)
    return 0
def cmd_cnf_eval(args: argparse.Namespace) -> int:
    """`cnf-eval` subcommand: evaluate a DIMACS CNF under --assign and print T/F/?.

    With --dump-eqs, the compiled CVP gate list is printed as JSON first.  Returns 0.
    """
    tagger = BoolTagger.build(make_engine(args))
    nvars, clauses = parse_dimacs(args.dimacs)
    eqs = compile_cnf_to_cvp(nvars, clauses, parse_assign(args.assign))
    if args.dump_eqs:
        print(json.dumps(eqs))
    print(eval_cvp_output_tag(tagger, eqs))
    return 0
# -------------------------
# Main
# -------------------------
def build_argparser() -> argparse.ArgumentParser:
    """Construct the CLI parser: global backend options plus one subparser per command.

    Each subparser stores its handler in ``func`` for main() to dispatch on.
    """
    parser = argparse.ArgumentParser(prog='topcvp_beta_cli.py')
    parser.add_argument(
        '--normalizer',
        choices=['planar_nf', 'external'],
        default='planar_nf',
        help='normalization backend',
    )
    parser.add_argument(
        '--normalizer-bin',
        default=None,
        help='path to external normalizer binary (only for --normalizer external)',
    )
    commands = parser.add_subparsers(dest='cmd', required=True)

    # `norm` and `tag` share the same term-input options.
    for name, help_text, handler in (
        ('norm', 'normalize a term and print NF', cmd_norm),
        ('tag', 'tag a boolean term as T/F/? using DEC', cmd_tag),
    ):
        term_cmd = commands.add_parser(name, help=help_text)
        term_cmd.add_argument('--term', default=None)
        term_cmd.add_argument('--term-file', default=None)
        term_cmd.set_defaults(func=handler)

    paper = commands.add_parser('paper-example', help="run the paper example circuit (fast tagging)")
    paper.set_defaults(func=cmd_paper_example)

    cvp = commands.add_parser('cvp', help='evaluate a CVP gate list (JSON) and print final T/F/?')
    cvp.add_argument('--eqs-json', default=None, help='JSON list of gates')
    cvp.add_argument('--eqs-file', default=None, help='path to JSON file')
    cvp.add_argument('--dump-wires', action='store_true', help='print logical->vector mapping')
    cvp.add_argument('--dump-gate-tags', action='store_true', help='tag each gate output')
    cvp.set_defaults(func=cmd_cvp)

    cnf = commands.add_parser('cnf-eval', help='evaluate DIMACS CNF under a fixed assignment')
    cnf.add_argument('--dimacs', required=True)
    cnf.add_argument('--assign', required=True, help='e.g. "1,-2,3" meaning x1=T,x2=F,x3=T')
    cnf.add_argument('--dump-eqs', action='store_true', help='print the compiled gate list')
    cnf.set_defaults(func=cmd_cnf_eval)
    return parser
def main(argv: Optional[Sequence[str]] = None) -> int:
    """CLI entry point: parse *argv* (sys.argv[1:] when None) and run the chosen subcommand."""
    namespace = build_argparser().parse_args(argv)
    handler = namespace.func
    return int(handler(namespace))
if __name__ == '__main__':
    # Propagate the subcommand's return value as the process exit status.
    sys.exit(main())
@Epivalent
Copy link
Author

r"""
planar_nf.py — fast beta-normalizer for (mostly) planar/linear untyped lambda terms.

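Input syntax:
  - Abstraction: \x.BODY        (also accepts 'λ' instead of '\')
                 BODY is exactly ONE atom or parenthesized group: \x.(f x) is λx.(f x),
                 whereas \x. f x parses as ((\x.f) x).
  - Application: (F A)          (fully parenthesized recommended; N-ary apps allowed: (f a b c))
  - Variables:   identifiers like x, x1, b5, foo_bar (ASCII letters, digits, underscore)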

Key idea for speed on planar/linear terms:
  - During parsing, each binder records its *unique* variable occurrence node (if any).
  - Beta-reduction then becomes O(1): splice the argument subtree into that occurrence (no copying).
  - Strong normalization is implemented iteratively (no Python recursion).

If a binder is used more than once, we raise (that’s outside “planar/linear”).
"""

from __future__ import annotations
from dataclasses import dataclass
from typing import Optional, List, Tuple, Union, Dict, Iterator
import itertools
import re

# -----------------------------
# Core term data structures
# -----------------------------

# Process-wide counter giving every Binder a distinct integer id, starting at 1.
_binder_ids = itertools.count(start=1)

# -----------------------------
# Core term data structures
# -----------------------------

class Node:
    """Base class for term nodes; keeps the parent back-link used for O(1) splicing."""

    __slots__ = ("parent", "pslot")

    def __init__(self) -> None:
        # Parent node (None at the root or when detached) and the field of the
        # parent that holds this node: "body" (Abs) or "fn"/"arg" (App).
        self.parent: Optional[Node] = None
        self.pslot: Optional[str] = None


class Binder:
    """A lambda-bound variable; ``occ`` is its single occurrence, or None if unused."""

    __slots__ = ("name", "id", "occ")

    def __init__(self, name: str) -> None:
        self.name = name
        self.id = next(_binder_ids)
        # Planar/linear restriction: at most one Var may register here (enforced by Var).
        self.occ: Optional[Var] = None


class Var(Node):
    """A variable occurrence; *binder* is None for a free variable.

    Registers itself as *binder*'s occurrence and raises ValueError if the
    binder already has one (non-linear use).
    """

    __slots__ = ("binder", "name")

    def __init__(self, name: str, binder: Optional[Binder] = None) -> None:
        super().__init__()
        self.binder = binder
        self.name = name
        if binder is None:
            return
        if binder.occ is not None:
            # Non-linear: multiple occurrences
            raise ValueError(f"Non-linear variable use: binder '{binder.name}' appears more than once.")
        binder.occ = self


class Abs(Node):
    """Lambda abstraction over *binder*; adopts *body* as its "body" child."""

    __slots__ = ("binder", "body")

    def __init__(self, binder: Binder, body: Node) -> None:
        super().__init__()
        self.binder = binder
        self.body = body
        body.parent, body.pslot = self, "body"


class App(Node):
    """Application ``(fn arg)``; adopts both subterms as its children."""

    __slots__ = ("fn", "arg")

    def __init__(self, fn: Node, arg: Node) -> None:
        super().__init__()
        self.fn, self.arg = fn, arg
        for child, slot in ((fn, "fn"), (arg, "arg")):
            child.parent = self
            child.pslot = slot


Term = Union[Var, Abs, App]

# -----------------------------
# Tokenizer
# -----------------------------

# ASCII identifier: a letter or underscore followed by letters, digits or underscores.
_ident_re = re.compile(r"[A-Za-z_][A-Za-z0-9_]*")

def tokenize(src: str) -> Iterator[str]:
    """Yield the tokens of *src*.

    Tokens are "(", ")", ".", "\\" (emitted for both a backslash and 'λ') and
    identifiers.  Whitespace is skipped; any other character raises SyntaxError.
    """
    pos, end = 0, len(src)
    while pos < end:
        ch = src[pos]
        if ch.isspace():
            pos += 1
        elif ch in "().":
            yield ch
            pos += 1
        elif ch in ("\\", "λ"):
            yield "\\"
            pos += 1
        else:
            m = _ident_re.match(src, pos)
            if m is None:
                raise SyntaxError(f"Unexpected character at {pos}: {src[pos:pos+20]!r}")
            yield m.group(0)
            pos = m.end()

# -----------------------------
# Parser (iterative, stack-based)
# -----------------------------

# NOTE(review): _LP and _LAM are not referenced anywhere in this module; parse()
# pushes its own ("LPAREN", len(bstack)) / ("LAM", binder) marker tuples inline.
_LP = ("LPAREN", None)  # marker (kind, binder_stack_len)
_LAM = ("LAM", None)    # marker (kind, Binder)

def parse(src: str) -> Term:
    """
    Parse *src* into a Term with an explicit stack (no Python recursion).

    Grammar as implemented:
      - "(t1 t2 ... tk)" is the left-nested application ((t1 t2) ... tk);
        a single item "(t)" is just t.
      - "\\x." (or "λx.") binds x over exactly the NEXT complete term: the
        lambda is closed (reduce_lams) as soon as one atom or parenthesized
        group follows it, so "\\x.(f x)" is λx.(f x) but "\\x. f x" parses as
        ((\\x.f) x).  Nested lambdas "\\x.\\y.(...)" close inside-out.
      - An identifier refers to the innermost binder of that name still on
        the binder stack; otherwise it becomes a free variable.
      - Several terms left at top level are folded left into applications.

    Raises SyntaxError on malformed input, ValueError (from Var) when a binder
    is used more than once, and RuntimeError on internal stack mismatches.
    NOTE(review): a trailing lambda with no body at top level (e.g. "f \\x.")
    is silently dropped by the final Node filter rather than reported.
    """
    toks = list(tokenize(src))
    # tstack holds Nodes plus marker tuples ("LPAREN", bstack_len) / ("LAM", Binder).
    tstack: List[object] = []
    # bstack: binders currently in scope, innermost last.
    bstack: List[Binder] = []
    i = 0

    def reduce_lams() -> None:
        # If stack ends with [("LAM", binder), <term>], reduce to Abs(binder, term)
        nonlocal tstack, bstack
        while len(tstack) >= 2 and isinstance(tstack[-2], tuple) and tstack[-2][0] == "LAM" and isinstance(tstack[-1], Node):
            _, b = tstack[-2]
            body = tstack.pop()  # type: ignore
            tstack.pop()
            # Pop binder from binder stack (must match)
            if not bstack or bstack[-1] is not b:
                raise RuntimeError("Binder stack mismatch while reducing lambda.")
            bstack.pop()
            tstack.append(Abs(b, body))  # type: ignore

    while i < len(toks):
        tok = toks[i]
        i += 1

        if tok == "(":
            tstack.append(("LPAREN", len(bstack)))
            continue

        if tok == ")":
            # Collect the Nodes of this group; any leftover LAM marker is a malformed group.
            items: List[Node] = []
            while tstack:
                top = tstack.pop()
                if isinstance(top, tuple) and top[0] == "LPAREN":
                    target_len = top[1]
                    # Restore binder stack to what it was at '('
                    if len(bstack) < target_len:
                        raise RuntimeError("Binder stack underflow at ')'.")
                    del bstack[target_len:]
                    break
                if not isinstance(top, Node):
                    raise SyntaxError("Malformed group inside parentheses.")
                items.append(top)
            else:
                raise SyntaxError("Unmatched ')'.")
            items.reverse()
            if not items:
                raise SyntaxError("Empty parentheses group.")
            # (a b c) => ((a b) c)
            term: Node = items[0]
            for nxt in items[1:]:
                term = App(term, nxt)
            tstack.append(term)
            # The finished group may be the body a pending lambda is waiting for.
            reduce_lams()
            continue

        if tok == "\\":
            # Expect "<name> ." next; the binder is in scope until its lambda is reduced.
            if i >= len(toks): raise SyntaxError("Unexpected end after lambda.")
            name = toks[i]; i += 1
            if i >= len(toks) or toks[i] != ".": raise SyntaxError("Expected '.' after lambda parameter.")
            i += 1
            b = Binder(name)
            bstack.append(b)
            tstack.append(("LAM", b))
            continue

        if tok == ".":
            raise SyntaxError("Unexpected '.'.")

        # identifier: resolve to nearest binder with same name, else free var
        bnd: Optional[Binder] = None
        for b in reversed(bstack):
            if b.name == tok:
                bnd = b
                break
        tstack.append(Var(tok, bnd))
        reduce_lams()

    # Any leftover LPAREN marker means some '(' was never closed.
    if any(isinstance(x, tuple) and x[0] == "LPAREN" for x in tstack):
        raise SyntaxError("Unmatched '('.")

    # Fold any remaining top-level applications (rare; usually one term)
    terms = [x for x in tstack if isinstance(x, Node)]
    if not terms:
        raise SyntaxError("No term parsed.")
    term: Node = terms[0]
    for nxt in terms[1:]:
        term = App(term, nxt)
    return term  # type: ignore

# -----------------------------
# Beta reduction (O(1) for linear)
# -----------------------------

def _replace_child(parent: Node, slot: str, new: Node) -> None:
    """Install *new* as the *slot* child of *parent* and point its back-links at *parent*.

    *slot* must be "body" for an Abs or "fn"/"arg" for an App; any other
    combination raises RuntimeError.
    """
    if isinstance(parent, App) and slot in ("fn", "arg"):
        setattr(parent, slot, new)
    elif isinstance(parent, Abs) and slot == "body":
        parent.body = new
    else:
        raise RuntimeError("Bad parent/slot in replacement.")
    new.parent, new.pslot = parent, slot

def beta(abs_node: Abs, arg: Node) -> Node:
    """Contract the redex ``(abs_node arg)`` destructively and return the detached result.

    Relies on linearity: the binder has at most one occurrence, so *arg* is
    spliced into that occurrence site without copying.  If the binder is
    unused, *arg* is simply dropped and the body is returned.  *abs_node* must
    not be reused afterwards.
    """
    binder = abs_node.binder
    body = abs_node.body
    site = binder.occ

    if site is not None and (site is body or site.parent is None):
        # The body is (or behaves as) the bound variable itself: result is the argument.
        result = arg
    else:
        if site is not None:
            # Splice arg into the occurrence site inside the body.
            slot = site.pslot
            assert slot is not None
            _replace_child(site.parent, slot, arg)
        result = body

    result.parent = None
    result.pslot = None
    return result

# -----------------------------
# Weak-head normalization (iterative)
# -----------------------------

def whnf(t: Node) -> Node:
    """
    Reduce t by normal-order at the head (left spine) until it is in weak-head normal form.
    Implemented with spine reuse to avoid allocations.

    The tree is rewired in place (App.fn links are rewritten and beta() splices
    arguments into their occurrence sites), so callers should continue with the
    returned node, which is detached (parent/pslot set to None).
    """
    while True:
        # Unwind left spine collecting App nodes
        # (apps[0] is the outermost application; cur ends at the spine head).
        apps: List[App] = []
        cur: Node = t
        while isinstance(cur, App):
            apps.append(cur)
            cur = cur.fn

        reduced = False

        # Rebuild from inner -> outer, performing beta steps when possible
        for app in reversed(apps):
            app.fn = cur
            cur.parent = app
            cur.pslot = "fn"

            if isinstance(cur, Abs):
                cur = beta(cur, app.arg)
                reduced = True
                # cur is detached (parent None) here; outer rebuild will attach it
            else:
                cur = app
                # A non-Abs head (e.g. an App produced by a beta step) is only
                # re-linked here; the next pass re-unwinds and reduces it.

        cur.parent = None
        cur.pslot = None
        t = cur

        # A pass with no beta step means the spine head is a variable (or t is
        # itself an Abs/Var), i.e. t is in weak-head normal form.
        if not reduced:
            return t

# -----------------------------
# Full (strong) normalization, iterative
# -----------------------------

def normalize(t: Node) -> Node:
    """
    Compute full beta-normal form (strong normalization), normal-order style.
    Works best when the term is planar/linear.

    Each subterm is first put in weak-head normal form, then its children are
    normalized left to right using an explicit frame stack instead of recursion.
    The input tree is rewired in place (see whnf/beta); use the returned node.
    """
    t = whnf(t)

    # Manual stack to avoid recursion.
    # Frames:
    #   ("abs", abs_node)
    #   ("app_fn", app_node)  - fn done, next normalize arg
    #   ("app_arg", app_node) - arg done, finalize app and whnf it
    stack: List[Tuple[str, Node]] = []

    while True:
        # Descend: remember the parent in a frame, whnf the child, keep going down.
        if isinstance(t, Abs):
            stack.append(("abs", t))
            t = whnf(t.body)
            continue

        if isinstance(t, App):
            stack.append(("app_fn", t))
            t = whnf(t.fn)
            continue

        # t is Var (or free symbol)
        # Ascend: reattach the finished subterm t to its parent frame by frame.
        while stack:
            tag, node = stack.pop()

            if tag == "abs":
                absn: Abs = node  # type: ignore
                absn.body = t
                t.parent = absn
                t.pslot = "body"
                t = absn
                continue

            if tag == "app_fn":
                appn: App = node  # type: ignore
                appn.fn = t
                t.parent = appn
                t.pslot = "fn"
                stack.append(("app_arg", appn))
                t = whnf(appn.arg)
                break  # go back to outer loop

            if tag == "app_arg":
                appn: App = node  # type: ignore
                appn.arg = t
                t.parent = appn
                t.pslot = "arg"
                t = whnf(appn)  # may reduce (and returns detached)
                continue

            raise RuntimeError("Unknown frame tag.")

        else:
            # while-else: reached only when the stack empties without a `break`,
            # i.e. every frame has been reattached up to the root.
            # Finished rebuilding to root
            t.parent = None
            t.pslot = None
            return t

# -----------------------------
# Utilities
# -----------------------------

def node_count(t: Node) -> int:
    """Count the distinct node objects reachable from *t* (each counted once)."""
    visited = set()
    todo = [t]
    while todo:
        node = todo.pop()
        key = id(node)
        if key in visited:  # only possible if sharing was introduced manually
            continue
        visited.add(key)
        if isinstance(node, App):
            todo.extend((node.arg, node.fn))
        elif isinstance(node, Abs):
            todo.append(node.body)
    return len(visited)

def to_string(t: Node) -> str:
    """
    Serialize back to fully-parenthesized syntax:
      Var          => x
      Abs          => (\\x.BODY)
      App          => (F A)

    Every binder gets a printed name that is unique within the output.  A
    binder keeps its own name when that name is still free to use; otherwise
    it gets the next unused numeric suffix (x, x1, x2, ...).  Names of free
    variables are reserved up front, so a bound variable is never printed with
    the same name as a free variable or as another binder.  The old
    counter-only scheme could print e.g. a renamed binder ``x1`` over a free
    ``x1`` or over a binder literally named ``x1``, which changes the meaning
    of the output.  Consequently a binder whose name matches any free variable
    is printed with a suffix.

    Note: printing very large terms is inherently expensive (output size).
    """
    # Pass 1: collect free-variable names; they are printed verbatim.
    taken = set()
    scan: List[Node] = [t]
    while scan:
        x = scan.pop()
        if isinstance(x, Var):
            if x.binder is None:
                taken.add(x.name)
        elif isinstance(x, Abs):
            scan.append(x.body)
        elif isinstance(x, App):
            scan.append(x.arg)
            scan.append(x.fn)

    env: Dict[int, str] = {}          # binder id -> printed name (while in scope)
    next_suffix: Dict[str, int] = {}  # base name -> next numeric suffix to try

    def fresh(base: str) -> str:
        k = next_suffix.get(base, 0)
        name = base if k == 0 else f"{base}{k}"
        while name in taken:
            k += 1
            name = f"{base}{k}"
        next_suffix[base] = k + 1
        taken.add(name)
        return name

    out: List[str] = []
    # action kinds: 'str', 'node', 'pop'
    actions: List[Tuple[str, object]] = [("node", t)]

    while actions:
        kind, payload = actions.pop()

        if kind == "str":
            out.append(payload)  # type: ignore
            continue

        if kind == "pop":
            bid = payload  # type: ignore
            env.pop(bid, None)
            continue

        # kind == "node"
        n = payload  # type: ignore
        if isinstance(n, Var):
            if n.binder is None:
                out.append(n.name)
            else:
                out.append(env.get(n.binder.id, n.binder.name))
            continue

        if isinstance(n, Abs):
            b = n.binder
            name = env.get(b.id)
            if name is None:
                name = fresh(b.name)
                env[b.id] = name
            # Actions are popped LIFO, so push them in reverse print order.
            actions.append(("pop", b.id))
            actions.append(("str", ")"))
            actions.append(("node", n.body))
            actions.append(("str", "."))
            actions.append(("str", name))
            actions.append(("str", "\\"))
            actions.append(("str", "("))
            continue

        if isinstance(n, App):
            actions.append(("str", ")"))
            actions.append(("node", n.arg))
            actions.append(("str", " "))
            actions.append(("node", n.fn))
            actions.append(("str", "("))
            continue

        raise TypeError("Unknown node type in to_string().")

    return "".join(out)

def parse_and_normalize(src: str) -> Node:
    """Parse *src* and return its beta-normal form (convenience wrapper)."""
    term = parse(src)
    return normalize(term)

if __name__ == "__main__":
    # Benchmark driver: parse and normalize argv[1], report node counts and timings.
    import sys
    import time

    if len(sys.argv) < 2:
        print("usage: python3 planar_nf.py '<term>'")
        raise SystemExit(2)
    term_src = sys.argv[1]
    t0 = time.perf_counter()
    ast = parse(term_src)
    t1 = time.perf_counter()
    # Count BEFORE normalizing: normalize() rewires the parsed tree in place, so
    # counting from `ast` afterwards would under-report the parsed size.
    parsed_nodes = node_count(ast)
    t1b = time.perf_counter()
    nf = normalize(ast)
    t2 = time.perf_counter()
    print(f"parsed nodes: {parsed_nodes}  parse_s={t1-t0:.4f}")
    print(f"nf nodes:     {node_count(nf)}   norm_s={t2-t1b:.4f}")
    # Printing the NF (to_string(nf)) is left out: serializing giant terms is slow.
    ```

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment