#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# THRML unified sampler (clean rewrite)
# - Perfect matching on planar cubic bridgeless graphs via 2-face-coloring + NAE vertex constraints.
# - Face K-coloring (default K=4) on the dual graph via Potts (penalize equal colors on adjacent faces).
# Supports two input modes:
# (1) Lambda term -> primal rotation system -> dual faces, Tutte embedding for plotting
# (2) Adjacency matrix (0/1 square text file) of the dual graph for direct face K-coloring
# Includes:
# * Block Gibbs with optional block partition via greedy vertex-coloring (--block-coloring)
# * Optional anneal ladder (--anneal, --beta-start/end, --phases, --sweeps-per-phase)
# * Adaptive enforcement: --enforce-proper (coloring) / --enforce-perfect (matching)
# * Plotting with adaptive radius and optional edge downsampling for big graphs
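# Example invocations (illustrative only; the file name "thrml_sampler.py" and the adjacency
# file "dual_adj.txt" are placeholders, the flags are the ones defined in main() below):
#   python thrml_sampler.py --task matching --anneal --enforce-perfect --trace
#   python thrml_sampler.py --task face4 --adj-file dual_adj.txt --K 4 --block-coloring --enforce-proper --trace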
import math
import argparse
import time
from datetime import datetime
from collections import defaultdict, deque
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
# ---- lambda-term utilities (provided by user's lambda_map7.py) ----
from lambda_map7 import (
    tokenize, Parser,
    build_rotation_with_labels,
    enumerate_faces, face_containing_edge,
    tutte_embedding_with_largest_face, tutte_embedding_with_face, tutte_embedding_with_face_aligned,
)
# ---- THRML imports ----
from thrml.block_management import Block
from thrml.block_sampling import BlockGibbsSpec, sample_states, SamplingSchedule
from thrml.models.discrete_ebm import CategoricalEBMFactor, CategoricalGibbsConditional
from thrml.factor import FactorSamplingProgram
from thrml.pgm import CategoricalNode
# -------------------- Logging helpers --------------------
_t0 = time.time()
def _now(): return datetime.now().strftime("%H:%M:%S")
def _elapsed(): return f"{(time.time()-_t0):7.2f}s"
def log(msg, enable=True):
    if enable:
        print(f"[{_now()} | {_elapsed()}] {msg}", flush=True)
# -------------------- Geometry / constraints (lambda-term path) --------------------
def he2f_map(rotation):
"""Map half-edge -> face id by walking 'next-left' around faces."""
def build_halfedge_next_left(rotation):
nxt = {}
for u, nbrs in rotation.items():
for v in nbrs:
order = rotation[v]
idx = order.index(u)
w = order[(idx - 1) % len(order)]
nxt[(u, v)] = (v, w)
return nxt
nxt = build_halfedge_next_left(rotation)
visited=set(); faces_list=[]; he2f={}
for u,nbrs in rotation.items():
for v in nbrs:
he=(u,v)
if he in visited: continue
cyc=[]; cur=he
while cur not in visited:
visited.add(cur); cyc.append(cur); cur=nxt[cur]
fid=len(faces_list); faces_list.append(cyc)
for h in cyc: he2f[h]=fid
return he2f, faces_list
def vertex_face_triples(rotation):
"""For each primal vertex (cubic), return the 3 incident face ids (in rotation order)."""
he2f, faces_list = he2f_map(rotation)
V_triples = {}
F = len(faces_list)
for v, nbrs in rotation.items():
if len(nbrs) != 3:
raise ValueError("Graph must be cubic for this model")
f = tuple(he2f[(v, u)] for u in nbrs)
V_triples[v] = tuple(f)
return V_triples, faces_list
def dual_edges_from_rotation(rotation):
"""List of dual edges (unordered face pairs) from the primal rotation system."""
edges, faces, edge_faces = enumerate_faces(rotation)
seen=set(); pairs=[]
for (u,v),(lf,rf) in edge_faces.items():
a,b = (lf,rf) if lf<rf else (rf,lf)
if (a,b) not in seen:
seen.add((a,b)); pairs.append((a,b))
return pairs
def tutte_embedding_for_term(rotation, root_var_edge):
    fid_guess, _ = face_containing_edge(rotation, root_var_edge)
    if fid_guess is None:
        pos, outer_face = tutte_embedding_with_largest_face(rotation)
    else:
        pos, outer_face = tutte_embedding_with_face(rotation, fid_guess)
    # Align a boundary edge for stable orientation
    outer_vertices = []
    for (u,v) in enumerate_faces(rotation)[1][outer_face]:
        if u not in outer_vertices:
            outer_vertices.append(u)
    Lb = len(outer_vertices)
    top_v = max(outer_vertices, key=lambda vv: pos[vv][1])
    tidx = outer_vertices.index(top_v)
    left_v = outer_vertices[(tidx - 1) % Lb]
    pos, outer_face = tutte_embedding_with_face_aligned(rotation, outer_face, (left_v, top_v))
    return pos, outer_face
def dual_bfs_bitstring(labels, rotation, outer_face):
"""Produce a bitstring-like summary by BFS in the dual starting at the outer face (root '0')."""
edges, faces, edge_faces = enumerate_faces(rotation)
dual_adj = defaultdict(list)
for (u,v), (lf,rf) in edge_faces.items():
dual_adj[lf].append(rf); dual_adj[rf].append(lf)
order = []
seen = {outer_face}
q = deque([outer_face])
while q:
f = q.popleft()
for g in dual_adj[f]:
if g not in seen:
seen.add(g); order.append(g); q.append(g)
# hex digit per face label (works for K<=16) with leading 0 for root
return "0" + "".join(hex(int(labels[f]))[2:] for f in order)
# -------------------- Adjacency input (dual graph) --------------------
def parse_adj_matrix(path):
"""Read square 0/1 adjacency, with or without spaces between digits."""
import numpy as onp
rows = []
with open(path, "r", encoding="utf-8") as fh:
for line in fh:
line = line.strip()
if not line: continue
if " " in line or " " in line:
toks = [t for t in line.replace(" "," ").split(" ") if t!=""]
else:
toks = list(line)
row = [int(x) for x in toks]
rows.append(row)
if not rows: raise ValueError("Empty adjacency file")
n = len(rows)
for r in rows:
if len(r) != n: raise ValueError("Adjacency must be square")
for v in r:
if v not in (0,1): raise ValueError("Adjacency entries must be 0/1")
A = onp.array(rows, dtype=onp.uint8)
onp.fill_diagonal(A, 0)
A = ((A + A.T) > 0).astype(onp.uint8) # symmetrize and binarize
return A
def adj_edges(A):
"""Return list of (i,j) with i<j where A[i,j]==1."""
n = A.shape[0]; pairs=[]
for i in range(n):
for j in range(i+1, n):
if A[i,j]: pairs.append((i,j))
return pairs
def greedy_vertex_coloring(A):
"""Greedy vertex coloring to produce large independent sets as Gibbs blocks."""
import numpy as onp
n = A.shape[0]
deg = A.sum(axis=1).astype(int)
order = list(onp.argsort(-deg)) # high-degree first
color_of = [-1]*n
classes = []
for v in order:
forbidden = {color_of[u] for u in range(n) if A[v,u] and color_of[u] != -1}
c = 0
while c in forbidden: c += 1
color_of[v] = c
if c >= len(classes): classes.append([v])
else: classes[c].append(v)
return classes # list of lists of vertices
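# Small sanity check of the behaviour above: for the path graph 0-1-2
# (A = [[0,1,0],[1,0,1],[0,1,0]]) the degree-ordered greedy pass returns [[1], [0, 2]].
# Each class is an independent set, so all of its nodes can be resampled together in one Gibbs block.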
def generic_layout(A, seed=0):
"""Spring layout (via networkx if available) or fallback to circle."""
try:
import networkx as nx
G = nx.from_numpy_array(A)
pos = nx.spring_layout(G, seed=seed, dim=2)
return {int(k):(float(v[0]), float(v[1])) for k,v in pos.items()}
except Exception:
import numpy as onp
n = A.shape[0]
ang = onp.linspace(0, 2*onp.pi, n, endpoint=False)
R = 1.0
return {i:(float(R*onp.cos(ang[i])), float(R*onp.sin(ang[i]))) for i in range(n)}
# -------------------- Plotting --------------------
def adaptive_radius(span, n_vertices, base_frac=0.012, min_frac=0.004):
"""Scale node radius as ~ 1/sqrt(n) with clamps for readability."""
import numpy as onp
n = max(int(n_vertices), 1)
scale = onp.sqrt(30.0 / n)
frac = base_frac * float(scale)
frac = float(min(base_frac, max(min_frac, frac)))
return frac * float(span)
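# Worked example: with span=1.0 a 30-vertex graph keeps the base radius 0.012, while a 480-vertex
# graph would get 0.012*sqrt(30/480)=0.003 and is clamped up to min_frac=0.004.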
def draw_term_coloring(rotation, faces_list, pos, labels, title, out_path,
                       palette=None, mono=False, mono_color="#9370DB", alpha=0.30):
    import numpy as onp
    fig = plt.figure(figsize=(8.6, 8.6)); ax = plt.gca()
    ax.set_aspect("equal"); ax.set_xticks([]); ax.set_yticks([])
    if mono:
        for fid, cyc in enumerate(faces_list):
            if int(labels[fid]) == 1:
                pts = onp.array([pos[u] for (u,_) in cyc])
                ax.fill(pts[:,0], pts[:,1], alpha=alpha, color=mono_color, zorder=0.1)
    else:
        if palette is None:
            palette = ["#4c78a8","#f58518","#e45756","#72b7b2"]
        for fid, cyc in enumerate(faces_list):
            c = palette[int(labels[fid]) % len(palette)]
            pts = onp.array([pos[u] for (u,_) in cyc])
            ax.fill(pts[:,0], pts[:,1], alpha=alpha, color=c, zorder=0.1)
    edges, _, _ = enumerate_faces(rotation)
    for (u,v) in edges:
        x1,y1 = pos[u]; x2,y2 = pos[v]
        ax.plot([x1,x2],[y1,y2], lw=1.0, alpha=0.45, color="#444", zorder=0.8)
    xs = [p[0] for p in pos.values()]; ys=[p[1] for p in pos.values()]
    span = max(max(xs)-min(xs), max(ys)-min(ys))
    R = adaptive_radius(span, len(rotation)); r = 0.6 * R
    for v in rotation:
        x,y = pos[v]
        circ_out = plt.Circle((x,y), R, fc="black", ec="black", lw=1.0, zorder=2.0)
        circ_in = plt.Circle((x,y), r, fc=(1,1,1,0), ec="none", zorder=2.1)
        ax.add_patch(circ_out); ax.add_patch(circ_in)
    ax.set_title(title, fontsize=10); plt.tight_layout(); plt.savefig(out_path, dpi=180); plt.close(fig)
def draw_term_matching(rotation, faces_list, pos, spins, matching, title, out_path,
                       face_alpha=0.30, face_color="#9370DB"):
    import numpy as onp
    fig = plt.figure(figsize=(8.6, 8.6)); ax = plt.gca()
    ax.set_aspect("equal"); ax.set_xticks([]); ax.set_yticks([])
    for fid, cyc in enumerate(faces_list):
        if int(spins[fid]) == 1:
            pts = onp.array([pos[u] for (u,_) in cyc])
            ax.fill(pts[:,0], pts[:,1], alpha=face_alpha, color=face_color, zorder=0.1)
    edges, _, _ = enumerate_faces(rotation)
    for (u,v) in edges:
        x1,y1 = pos[u]; x2,y2 = pos[v]
        ax.plot([x1,x2],[y1,y2], lw=1.0, alpha=0.45, color="#444", zorder=0.8)
    # draw matching as thick bands
    def data_width(ax, lw_pts):
        fig = ax.figure; dpi = fig.dpi
        px = lw_pts * dpi / 72.0
        xlim = ax.get_xlim(); ylim = ax.get_ylim()
        x_per_px = (xlim[1]-xlim[0]) / ax.bbox.width
        y_per_px = (ylim[1]-ylim[0]) / ax.bbox.height
        data_per_px = math.sqrt((x_per_px**2 + y_per_px**2)/2.0)
        return px * data_per_px
    xs = [p[0] for p in pos.values()]; ys=[p[1] for p in pos.values()]
    span = max(max(xs)-min(xs), max(ys)-min(ys))
    _ = adaptive_radius(span, len(rotation))
    w = 4.0 * data_width(ax, 1.0)
    for (u,v) in matching:
        p = np.array(pos[u]); q = np.array(pos[v])
        d = q - p; L = np.hypot(d[0], d[1]) + 1e-12
        n = np.array([-d[1], d[0]]) / L
        A = p + n*(w/2); B = p - n*(w/2); C = q - n*(w/2); D = q + n*(w/2)
        quad = np.vstack([A,B,C,D])
        ax.fill(quad[:,0], quad[:,1], alpha=0.35, color="#ff4d4d", zorder=1.2)
    ax.set_title(title, fontsize=10); plt.tight_layout(); plt.savefig(out_path, dpi=180); plt.close(fig)
def draw_adj_coloring(A, pos, labels, title, out_path, palette=None, alpha=0.9, node_size=None, max_edges=None):
    import numpy as onp
    if palette is None:
        palette = ["#4c78a8","#f58518","#e45756","#72b7b2"]
    fig = plt.figure(figsize=(8.6, 8.6)); ax = plt.gca()
    ax.set_aspect("equal"); ax.set_xticks([]); ax.set_yticks([])
    n = A.shape[0]
    edges_drawn = 0
    for i in range(n):
        for j in range(i+1, n):
            if A[i,j]:
                if (max_edges is not None) and (edges_drawn >= max_edges):
                    break
                x1,y1 = pos[i]; x2,y2 = pos[j]
                ax.plot([x1,x2],[y1,y2], lw=1.0, alpha=0.25, color="#444", zorder=0.5)
                edges_drawn += 1
        if (max_edges is not None) and (edges_drawn >= max_edges):
            break
    xs = [p[0] for p in pos.values()]; ys = [p[1] for p in pos.values()]
    span = max(max(xs)-min(xs), max(ys)-min(ys)) or 1.0
    if node_size is None:
        R = adaptive_radius(span, n)
    else:
        R = float(node_size) * span
    r = 0.6 * R
    for i in range(n):
        x,y = pos[i]; col = palette[int(labels[i]) % len(palette)]
        circ = plt.Circle((x,y), R, fc=col, ec="#222", lw=0.8, alpha=alpha, zorder=1.0)
        ax.add_patch(circ)
        inner = plt.Circle((x,y), r, fc=(1,1,1,0), ec="none", zorder=1.1)
        ax.add_patch(inner)
    ax.set_title(title, fontsize=10); plt.tight_layout(); plt.savefig(out_path, dpi=180); plt.close(fig)
# -------------------- THRML model construction --------------------
def build_prog_matching(V_triples, nodes, penalty=1.0, outer_node=None, clamp_penalty=25.0):
"""2-color faces; NAE at each vertex; softly clamp outer face to color 0 (symmetry)."""
a_nodes, b_nodes, c_nodes = [], [], []
for _, (f1, f2, f3) in V_triples.items():
a_nodes.append(nodes[f1]); b_nodes.append(nodes[f2]); c_nodes.append(nodes[f3])
K = 2
W = jnp.zeros((len(a_nodes), K, K, K), dtype=jnp.float32)
# Penalize all-0 and all-1 (NAE)
W = W.at[:, 0,0,0].set(-penalty)
W = W.at[:, 1,1,1].set(-penalty)
factors = [CategoricalEBMFactor([Block(a_nodes), Block(b_nodes), Block(c_nodes)], W)]
if outer_node is not None:
Wu = jnp.zeros((1, K), dtype=jnp.float32)
Wu = Wu.at[:, 1].set(-clamp_penalty) # prefer 0
factors.append(CategoricalEBMFactor([Block([outer_node])], Wu))
free_blocks = [Block([n]) for n in nodes]
spec = BlockGibbsSpec(free_blocks, [])
sampler = CategoricalGibbsConditional(K)
prog = FactorSamplingProgram(spec, [sampler for _ in spec.free_blocks], factors, [])
return prog, free_blocks, K
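# The 2x2x2 tensor W above acts as a soft not-all-equal (NAE) constraint per cubic vertex:
# the monochromatic triples (0,0,0) and (1,1,1) get log-weight -penalty while the six mixed
# triples get 0, so (reading W as additive log-weights, as the sign convention suggests) every
# vertex whose three incident faces are not all equal is favored by roughly a factor exp(penalty).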
def build_prog_faceK(nodes, dual_pairs, K=4, penalty_equal=1.0, outer_node=None,
                     clamp_color=0, clamp_penalty=25.0, neighbor_node=None, neighbor_color=1, neighbor_penalty=10.0,
                     color_blocks=None):
    """Potts: penalize equal colors on adjacent faces; add soft clamps for symmetry breaking."""
    a_nodes = [nodes[a] for (a,b) in dual_pairs]
    b_nodes = [nodes[b] for (a,b) in dual_pairs]
    W = jnp.zeros((len(dual_pairs), K, K), dtype=jnp.float32)
    for k in range(K):
        W = W.at[:, k, k].set(-penalty_equal) # penalize equality
    factors = [CategoricalEBMFactor([Block(a_nodes), Block(b_nodes)], W)]
    if outer_node is not None:
        Wu = jnp.zeros((1, K), dtype=jnp.float32)
        Wu = Wu.at[:, :].set(-clamp_penalty)
        Wu = Wu.at[:, clamp_color].set(0.0)
        factors.append(CategoricalEBMFactor([Block([outer_node])], Wu))
    if neighbor_node is not None:
        Wu2 = jnp.zeros((1, K), dtype=jnp.float32)
        Wu2 = Wu2.at[:, :].set(-neighbor_penalty)
        Wu2 = Wu2.at[:, neighbor_color].set(0.0)
        factors.append(CategoricalEBMFactor([Block([neighbor_node])], Wu2))
    if color_blocks:
        free_blocks = [Block([nodes[i] for i in cls]) for cls in color_blocks]
    else:
        free_blocks = [Block([n]) for n in nodes]
    spec = BlockGibbsSpec(free_blocks, [])
    sampler = CategoricalGibbsConditional(K)
    prog = FactorSamplingProgram(spec, [sampler for _ in spec.free_blocks], factors, [])
    return prog, free_blocks, K
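# For K=4 the pairwise matrix built above is the antiferromagnetic Potts interaction
# (rows/columns are the colors 0..3, p = penalty_equal):
#   [[-p, 0, 0, 0],
#    [ 0,-p, 0, 0],
#    [ 0, 0,-p, 0],
#    [ 0, 0, 0,-p]]
# i.e. equal colors across a dual edge are penalized and every unequal pair is free.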
# -------------------- Sampling helpers --------------------
def normalize_results(out):
"""THRML returns nested (mem, results); unwrap the states tensor."""
res = out
if isinstance(out, (list, tuple)):
res = out[-1]
if isinstance(res, (list, tuple)):
res = res[0]
return res
def run_thrml_sampling(key, prog, free_blocks, all_nodes, n_chains, n_warmup, n_samples,
                       steps_per_sample=2, jit_chains=False, trace=False, init_states=None):
    """Run sampling; optionally vectorize chains with jit+vmap; accept per-block initial states."""
    k_init, k_samp = jax.random.split(key, 2)
    init_per_block = []
    for bi, block in enumerate(free_blocks):
        k_init, sub = jax.random.split(k_init, 2)
        if init_states and bi < len(init_states) and init_states[bi] is not None:
            arr = jnp.asarray(init_states[bi])
            if arr.ndim == 1: arr = arr[None, :]
            if arr.shape[0] != n_chains:
                arr = jnp.broadcast_to(arr, (n_chains, arr.shape[1]))
        else:
            # default K=4 initializer; THRML ignores invalid categories via sampler range
            arr = jax.random.randint(sub, (n_chains, len(block.nodes)), 0, 4, dtype=jnp.uint8)
        init_per_block.append(arr.astype(jnp.uint8))
    schedule = SamplingSchedule(n_warmup=n_warmup, n_samples=n_samples, steps_per_sample=steps_per_sample)
    all_block = Block(all_nodes)
    if jit_chains:
        log("Compiling jit/vmap chain function...", trace)
        def one_chain_sample(subkey, *init_blocks):
            init_state_chain = [b for b in init_blocks]
            out = sample_states(subkey, prog, schedule, init_state_chain, [], [all_block])
            return normalize_results(out)
        keys = jax.random.split(k_samp, n_chains)
        try:
            vmapped = jax.jit(jax.vmap(one_chain_sample, in_axes=(0,) + (0,)*len(init_per_block)))
            samples = vmapped(keys, *[arr for arr in init_per_block])
            return np.array(samples)
        except Exception as e:
            log(f"jit/vmap failed ({e}); falling back to sequential loop.", trace)
    keys = jax.random.split(k_samp, n_chains)
    results = []
    for ci in range(n_chains):
        init_state_chain = [arr[ci] for arr in init_per_block]
        out = sample_states(keys[ci], prog, schedule, init_state_chain, [], [all_block])
        results.append(np.array(normalize_results(out)))
    return np.stack(results, axis=0) # (chains, n_samples, total_nodes)
def pick_best_sample(samples_array, score_fn):
"""Return (best_vec, best_score) by scanning chains × samples with a given score function."""
best = None; best_score = None
C, S, D = samples_array.shape
for c in range(C):
for s in range(S):
vec = samples_array[c, s]
sc = score_fn(vec)
if (best is None) or (sc < best_score):
best = vec; best_score = sc
return best, float(best_score)
def anneal_ladder(key, build_prog_fn, nodes, phases, betas, sweeps_per_phase,
                  steps_per_sample=1, jit_chains=False, trace=False, precompile=False):
    """Run a short single-chain schedule at increasing β to produce a good init state."""
    prog, free_blocks, K = build_prog_fn(betas[0])
    if precompile:
        schedule_pc = SamplingSchedule(n_warmup=1, n_samples=1, steps_per_sample=steps_per_sample)
        all_block_pc = Block(nodes)
        _ = sample_states(key, prog, schedule_pc, [jnp.zeros((len(b.nodes),), dtype=jnp.uint8) for b in free_blocks], [], [all_block_pc])
    state_for_next = None
    k = key
    for i in range(phases):
        beta_i = betas[i]
        log(f"Anneal phase {i+1}/{phases} at β={beta_i}", trace)
        if i > 0:
            prog, free_blocks, K = build_prog_fn(beta_i)
        init_states = None
        if state_for_next is not None:
            init_states = []
            offset = 0
            for b in free_blocks:
                size = len(b.nodes)
                vec = state_for_next[offset:offset+size]
                init_states.append(vec)
                offset += size
        k, sub = jax.random.split(k, 2)
        samples = run_thrml_sampling(sub, prog, free_blocks, nodes, n_chains=1,
                                     n_warmup=sweeps_per_phase, n_samples=1,
                                     steps_per_sample=steps_per_sample, jit_chains=jit_chains, trace=trace,
                                     init_states=init_states)
        state_for_next = samples[0, -1]
    init_states = []
    offset = 0
    for b in free_blocks:
        size = len(b.nodes)
        vec = state_for_next[offset:offset+size]
        init_states.append(vec[None, :])
        offset += size
    return init_states, free_blocks, K
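# Example of the ladder's behaviour: --beta-start 2 --beta-end 8 --phases 4 gives
# betas = [2.0, 4.0, 6.0, 8.0]; each phase runs `sweeps_per_phase` warmup sweeps of a single chain,
# and the next phase is warm-started from the last state of the previous one.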
# -------------------- CLI & main --------------------
DEF_TERM = r"λx. λy. λz. λw. λu. λv. λt. λp. λq. x (λr. λs. y (λm. λn. z (λo. w (λϕ. u (λψ. λω. v (λa. t (λb. p (λc. q (r ((s m) (n (o (ϕ ψ))))) (ω (a (b c)))))))))))"
def main():
    ap = argparse.ArgumentParser(description="THRML sampler: perfect matching (2-color+NAE) or face 4-coloring (Potts).")
    ap.add_argument("--task", choices=["matching","face4"], default="matching")
    ap.add_argument("--term", type=str, default=DEF_TERM, help="Planar lambda term (quoted).")
    ap.add_argument("--adj-file", type=str, default="", help="Square 0/1 adjacency for the dual graph (face coloring).")
    ap.add_argument("--K", type=int, default=4, help="Number of face colors for --task face4.")
    # penalties (pre-β); we scale by β internally
    ap.add_argument("--penalty", type=float, default=1.0, help="Penalty per violated constraint (pre-β).")
    ap.add_argument("--clamp-penalty", type=float, default=25.0, help="Penalty for symmetry clamps (pre-β).")
    ap.add_argument("--neighbor-penalty", type=float, default=10.0, help="Penalty for neighbor clamp (pre-β).")
    # anneal
    ap.add_argument("--anneal", action="store_true", help="Use β ladder preconditioning.")
    ap.add_argument("--beta-start", type=float, default=2.0)
    ap.add_argument("--beta-end", type=float, default=8.0)
    ap.add_argument("--phases", type=int, default=4)
    ap.add_argument("--sweeps-per-phase", type=int, default=40)
    # final sampling
    ap.add_argument("--beta", type=float, default=6.0, help="β if no anneal; ignored if --anneal (we use beta-end).")
    ap.add_argument("--chains", type=int, default=1)
    ap.add_argument("--warmup", type=int, default=150)
    ap.add_argument("--samples", type=int, default=40)
    ap.add_argument("--steps-per-sample", type=int, default=2)
    ap.add_argument("--jit-chains", action="store_true")
    ap.add_argument("--seed", type=int, default=0)
    # adaptive enforcement
    ap.add_argument("--enforce-proper", action="store_true", help="For face4: adapt β/penalty/sweeps to reach target conflicts.")
    ap.add_argument("--target-conflicts", type=int, default=0)
    ap.add_argument("--enforce-perfect", action="store_true", help="For matching: adapt to drive NAE violations to 0.")
    ap.add_argument("--max-rounds", type=int, default=6)
    ap.add_argument("--beta-mult", type=float, default=1.4)
    ap.add_argument("--penalty-mult", type=float, default=1.25)
    ap.add_argument("--warmup-mult", type=float, default=1.5)
    ap.add_argument("--steps-inc", type=int, default=1)
    ap.add_argument("--restarts", type=int, default=0)
    # plotting
    ap.add_argument("--png", type=str, default="thrml_out.png")
    ap.add_argument("--energy", type=str, default="thrml_energy.png")
    ap.add_argument("--trace", action="store_true")
    ap.add_argument("--precompile", type=int, default=0)
    ap.add_argument("--mono", action="store_true", help="Term face-coloring: shade faces of color==1 only.")
    ap.add_argument("--skip-plot", action="store_true")
    ap.add_argument("--plot-max-edges", type=int, default=None, help="Adjacency plotting: cap edges for speed.")
    ap.add_argument("--block-coloring", action="store_true", help="Adjacency: use greedy independent-set blocks for faster sampling.")
    args = ap.parse_args()
    log("Start", args.trace)
    key = jax.random.key(args.seed)
    # -------------- Adjacency path (face coloring) --------------
    if args.adj_file:
        if args.task != "face4":
            raise SystemExit("--adj-file is only supported with --task face4.")
        log(f"Reading adjacency from {args.adj_file} ...", args.trace)
        A = parse_adj_matrix(args.adj_file)
        n = A.shape[0]
        log(f"Adjacency loaded: n={n}, edges={int(A.sum()//2)}", args.trace)
        K = int(args.K)
        nodes = [CategoricalNode() for _ in range(n)]
        # symmetry clamps: fix node 0 to color 0; fix one neighbor to color 1
        neighbor_idx = None
        for j in range(1, n):
            if A[0, j]:
                neighbor_idx = j; break
        pairs = adj_edges(A)
        classes = greedy_vertex_coloring(A) if args.block_coloring else None
        def make_prog(beta_scale: float, penalty_scale: float):
            return build_prog_faceK(
                nodes, pairs, K=K,
                penalty_equal=penalty_scale * beta_scale,
                outer_node=nodes[0], clamp_color=0, clamp_penalty=args.clamp_penalty * beta_scale,
                neighbor_node=(nodes[neighbor_idx] if neighbor_idx is not None else None),
                neighbor_color=1, neighbor_penalty=args.neighbor_penalty * beta_scale,
                color_blocks=classes
            )
        def conflicts_of(vec):
            # number of monochromatic edges
            return sum(1 for (a,b) in pairs if int(vec[a]) == int(vec[b]))
        # Prep β/penalty based on anneal or fixed settings
        beta_cur = (args.beta_end if args.anneal else args.beta)
        penalty_cur = args.penalty
        warmup_cur = args.warmup
        steps_cur = args.steps_per_sample
        init_states=None; free_blocks=None; _K=None
        if args.anneal:
            betas = np.linspace(args.beta_start, args.beta_end, num=args.phases).tolist()
            log(f"Anneal ladder: betas={betas}, sweeps_per_phase={args.sweeps_per_phase}", args.trace)
            start_ann = time.time()
            def make_for_anneal(b): return make_prog(b, penalty_cur)
            init_states, free_blocks, _K = anneal_ladder(key, make_for_anneal, nodes, args.phases, betas,
                                                         sweeps_per_phase=args.sweeps_per_phase,
                                                         steps_per_sample=1, jit_chains=False, trace=args.trace, precompile=False)
            log(f"Anneal finished in {time.time()-start_ann:.2f}s", args.trace)
        else:
            prog, free_blocks, _K = make_prog(beta_cur, penalty_cur)
        # Adaptive loop (if requested)
        best_vec=None; best_conf=None
        rounds = 1 if not args.enforce_proper else args.max_rounds
        for ri in range(rounds):
            if args.enforce_proper and ri > 0:
                beta_cur *= args.beta_mult
                penalty_cur *= args.penalty_mult
                warmup_cur = int(max(warmup_cur * args.warmup_mult, warmup_cur + 1))
                steps_cur += args.steps_inc
            log(f"[adaptive] round {ri+1}: β={beta_cur:.3g}, penalty={penalty_cur:.3g}, warmup={warmup_cur}, steps={steps_cur}", args.trace)
            prog, free_blocks, _K = make_prog(beta_cur, penalty_cur)
            n_tries = max(1, args.restarts+1)
            cand_vec=None; cand_conf=None
            for ti in range(n_tries):
                log(f"Sampling: round={ri+1}/{rounds}, try={ti+1}/{n_tries} | chains={args.chains}, warmup={warmup_cur}, samples={args.samples}, steps={steps_cur} ...", args.trace)
                start_samp = time.time()
                samples = run_thrml_sampling(key, prog, free_blocks, nodes, n_chains=args.chains,
                                             n_warmup=warmup_cur, n_samples=args.samples,
                                             steps_per_sample=steps_cur,
                                             jit_chains=args.jit_chains, trace=args.trace,
                                             init_states=init_states)
                log(f"Sampling finished in {time.time()-start_samp:.2f}s", args.trace)
                samples_array = np.array(samples)
                vec, conf = pick_best_sample(samples_array, conflicts_of)
                if (cand_vec is None) or (conf < cand_conf):
                    cand_vec, cand_conf = vec, conf
                init_states = None # subsequent try restarts randomly
            if (best_vec is None) or (cand_conf < best_conf):
                best_vec, best_conf = cand_vec, cand_conf
            log(f"[adaptive] best conflicts this round: {cand_conf}", args.trace)
            if best_conf <= args.target_conflicts:
                log(f"[adaptive] target reached: conflicts={best_conf}", args.trace)
                break
        final_labels = best_vec
        conflicts = int(best_conf)
        proper = (conflicts <= args.target_conflicts)
        # Plot + trace (plot can be skipped or edge-limited for speed)
        pos = generic_layout(A, seed=args.seed)
        title = (f"THRML face-{K} coloring (adj): β≈{beta_cur:.2f}, penalty≈{penalty_cur:.2f} | "
                 f"conflicts={conflicts} | proper={proper}")
        if not args.skip_plot:
            draw_adj_coloring(A, pos, final_labels, title, args.png, max_edges=args.plot_max_edges)
            log(f"Graph saved: {args.png}", args.trace)
        try:
            mean_conf = []
            for sidx in range(samples_array.shape[1]):
                cur = samples_array[:, sidx, :]
                cs = []
                for ch in range(cur.shape[0]):
                    cs.append(sum(1 for (a,b) in pairs if int(cur[ch,a])==int(cur[ch,b])))
                mean_conf.append(np.mean(cs))
            fig = plt.figure(figsize=(6,3.2)); ax = plt.gca()
            ax.plot(mean_conf, lw=1.2); ax.set_xlabel("sample index"); ax.set_ylabel("mean conflicts")
            ax.set_title("Sampling trace (adj face coloring)")
            plt.tight_layout(); plt.savefig(args.energy, dpi=160); plt.close(fig)
            log(f"Energy trace saved: {args.energy}", args.trace)
        except Exception:
            pass
        print("---- Results (adjacency) ----")
        print(f"Nodes: {n}, Edges: {int(A.sum()//2)}")
        print(f"Conflicts (final): {int(conflicts)}, Proper coloring: {proper}")
        print(f"Final β≈{beta_cur:.3g}, penalty≈{penalty_cur:.3g}, warmup={warmup_cur}, steps={steps_cur}")
        if not args.skip_plot:
            print(f"[OK] Saved graph to '{args.png}' and energy trace to '{args.energy}'.")
        log("Done.", args.trace)
        return
    # -------------- Lambda-term path --------------
    log("Parsing lambda term...", args.trace)
    toks = tokenize(args.term); ast = Parser(toks).parse()
    log("Parsed term.", args.trace)
    log("Building rotation system...", args.trace)
    rotation, kinds, tree_edges, root_tree, var_edges, root_var_edge, edge_labels, outer_param, outer_body_str, root_body_id, root_var_id = build_rotation_with_labels(ast, "vbp", "paf")
    log(f"Rotation built: vertices={len(rotation)}", args.trace)
    log("Enumerating faces & constraints...", args.trace)
    V_triples, faces_list = vertex_face_triples(rotation)
    n_faces = len(faces_list)
    log(f"Faces={n_faces}, vertex-triples={len(V_triples)}", args.trace)
    log("Computing Tutte embedding...", args.trace)
    pos, outer_face = tutte_embedding_for_term(rotation, root_var_edge)
    log(f"Embedding done. outer_face={outer_face}", args.trace)
    log("Creating THRML nodes...", args.trace)
    nodes = [CategoricalNode() for _ in range(n_faces)]
    log("Nodes created.", args.trace)
    if args.task == "matching":
        # Program factory (β scales penalties)
        def make_prog(beta_scale: float, penalty_scale: float):
            return build_prog_matching(
                V_triples, nodes,
                penalty=penalty_scale * beta_scale,
                outer_node=nodes[outer_face],
                clamp_penalty=args.clamp_penalty * beta_scale
            )
        def score_of(vec):
            # NAE violations (we want 0)
            violated = 0
            for (a,b,c) in V_triples.values():
                s0,s1,s2 = int(vec[a]), int(vec[b]), int(vec[c])
                if s0==s1==s2: violated += 1
            return violated
        target = 0
        enforce = args.enforce_perfect
    else: # face K-coloring on the lambda-term dual
        K = int(args.K) if args.K else 4
        pairs = dual_edges_from_rotation(rotation)
        # Break symmetry by clamping outer_face to 0 and one neighbor to 1
        neighbor = None
        for (a,b) in pairs:
            if a == outer_face: neighbor = nodes[b]; break
            if b == outer_face: neighbor = nodes[a]; break
        def make_prog(beta_scale: float, penalty_scale: float):
            return build_prog_faceK(
                nodes, pairs, K=K,
                penalty_equal=penalty_scale * beta_scale,
                outer_node=nodes[outer_face], clamp_color=0, clamp_penalty=args.clamp_penalty * beta_scale,
                neighbor_node=neighbor, neighbor_color=1, neighbor_penalty=args.neighbor_penalty * beta_scale,
                color_blocks=None
            )
        def score_of(vec):
            return sum(1 for (a,b) in pairs if int(vec[a])==int(vec[b]))
        target = args.target_conflicts
        enforce = args.enforce_proper
    # Initialize via anneal if requested
    beta_cur = (args.beta_end if args.anneal else args.beta)
    penalty_cur = args.penalty
    warmup_cur = args.warmup
    steps_cur = args.steps_per_sample
    init_states=None; free_blocks=None; _K=None
    if args.anneal:
        betas = np.linspace(args.beta_start, args.beta_end, num=args.phases).tolist()
        log(f"Anneal ladder: betas={betas}, sweeps_per_phase={args.sweeps_per_phase}", args.trace)
        start_ann = time.time()
        def make_for_anneal(b): return make_prog(b, penalty_cur)
        init_states, free_blocks, _K = anneal_ladder(key, make_for_anneal, nodes, args.phases, betas,
                                                     sweeps_per_phase=args.sweeps_per_phase,
                                                     steps_per_sample=1, jit_chains=False, trace=args.trace, precompile=False)
        log(f"Anneal finished in {time.time()-start_ann:.2f}s", args.trace)
    else:
        prog, free_blocks, _K = make_prog(beta_cur, penalty_cur)
    # Adaptive loop
    best_vec=None; best_score=None
    rounds = 1 if not enforce else args.max_rounds
    for ri in range(rounds):
        if enforce and ri > 0:
            beta_cur *= args.beta_mult
            penalty_cur *= args.penalty_mult
            warmup_cur = int(max(warmup_cur * args.warmup_mult, warmup_cur + 1))
            steps_cur += args.steps_inc
        log(f"[adaptive] round {ri+1}: β={beta_cur:.3g}, penalty={penalty_cur:.3g}, warmup={warmup_cur}, steps={steps_cur}", args.trace)
        prog, free_blocks, _K = make_prog(beta_cur, penalty_cur)
        n_tries = max(1, args.restarts+1)
        cand_vec=None; cand_score=None
        for ti in range(n_tries):
            log(f"Sampling: round={ri+1}/{rounds}, try={ti+1}/{n_tries} | chains={args.chains}, warmup={warmup_cur}, samples={args.samples}, steps={steps_cur} ...", args.trace)
            start_samp = time.time()
            samples = run_thrml_sampling(key, prog, free_blocks, nodes, n_chains=args.chains,
                                         n_warmup=warmup_cur, n_samples=args.samples,
                                         steps_per_sample=steps_cur,
                                         jit_chains=args.jit_chains, trace=args.trace,
                                         init_states=init_states)
            log(f"Sampling finished in {time.time()-start_samp:.2f}s", args.trace)
            samples_array = np.array(samples)
            vec, sc = pick_best_sample(samples_array, score_of)
            if (cand_vec is None) or (sc < cand_score):
                cand_vec, cand_score = vec, sc
            init_states = None # randomize next try
        if (best_vec is None) or (cand_score < best_score):
            best_vec, best_score = cand_vec, cand_score
        log(f"[adaptive] best score this round: {cand_score}", args.trace)
        if best_score <= target:
            log(f"[adaptive] target reached: score={best_score}", args.trace)
            break
    final_labels = best_vec
if args.task == "matching":
violated = int(best_score)
edges_all, faces_all, edge_faces = enumerate_faces(rotation)
matching = set()
for (u,vx) in edges_all:
lf, rf = edge_faces[(u,vx)]
if int(final_labels[lf]) == int(final_labels[rf]):
matching.add(tuple(sorted((u,vx))))
deg = {vv:0 for vv in rotation}
for (u,vx) in matching:
deg[u]+=1; deg[vx]+=1
is_perfect = all(d == 1 for d in deg.values())
bitstring = dual_bfs_bitstring(final_labels, rotation, outer_face)
title = (f"THRML matching (term): β≈{beta_cur:.2f}, penalty≈{penalty_cur:.2f} | "
f"violations={violated} | perfect={is_perfect} | bitstring: {bitstring}")
if not args.skip_plot:
draw_term_matching(rotation, faces_list, pos, final_labels, matching, title, args.png)
log(f"Graph saved: {args.png}", args.trace)
# violations trace of last batch
try:
mean_viol = []
for sidx in range(samples_array.shape[1]):
cur = samples_array[:, sidx, :]
vs = []
for ch in range(cur.shape[0]):
cnt = 0
for (a,b,c) in V_triples.values():
if int(cur[ch,a])==int(cur[ch,b])==int(cur[ch,c]): cnt += 1
vs.append(cnt)
mean_viol.append(np.mean(vs))
fig = plt.figure(figsize=(6,3.2)); ax = plt.gca()
ax.plot(mean_viol, lw=1.2); ax.set_xlabel("sample index"); ax.set_ylabel("mean NAE violations")
ax.set_title("Sampling trace (matching)")
plt.tight_layout(); plt.savefig(args.energy, dpi=160); plt.close(fig)
log(f"Energy trace saved: {args.energy}", args.trace)
except Exception:
pass
print("---- Results (term/matching) ----")
print(f"Faces: {len(faces_list)}, Vertices: {len(rotation)}")
print(f"NAE violations (final): {int(violated)}, Perfect matching: {is_perfect}")
print(f"Matching edges: {len(matching)}")
print(f"Root face: {outer_face}")
if not args.skip_plot:
print(f"[OK] Saved graph to '{args.png}' and energy trace to '{args.energy}'.")
else:
K = int(args.K) if args.K else 4
pairs = dual_edges_from_rotation(rotation)
conflicts = int(best_score)
proper = (conflicts <= target)
title = (f"THRML face-{K} coloring (term): β≈{beta_cur:.2f}, penalty≈{penalty_cur:.2f} | "
f"conflicts={conflicts} | proper={proper}")
if not args.skip_plot:
draw_term_coloring(rotation, faces_list, pos, final_labels, title, args.png, mono=args.mono)
log(f"Graph saved: {args.png}", args.trace)
try:
mean_conf = []
for sidx in range(samples_array.shape[1]):
cur = samples_array[:, sidx, :]
cs = []
for ch in range(cur.shape[0]):
cs.append(sum(1 for (a,b) in pairs if int(cur[ch,a])==int(cur[ch,b])))
mean_conf.append(np.mean(cs))
fig = plt.figure(figsize=(6,3.2)); ax = plt.gca()
ax.plot(mean_conf, lw=1.2); ax.set_xlabel("sample index"); ax.set_ylabel("mean conflicts")
ax.set_title("Sampling trace (term face coloring)")
plt.tight_layout(); plt.savefig(args.energy, dpi=160); plt.close(fig)
log(f"Energy trace saved: {args.energy}", args.trace)
except Exception:
pass
print("---- Results (term/faceK) ----")
print(f"Faces: {len(faces_list)}, Vertices: {len(rotation)}")
print(f"Conflicts (final): {int(conflicts)}, Proper coloring: {proper}")
print(f"Root face: {outer_face}")
if not args.skip_plot:
print(f"[OK] Saved graph to '{args.png}' and energy trace to '{args.energy}'.")
log("Done.", args.trace)
if __name__ == "__main__":
    main()