Skip to content

Instantly share code, notes, and snippets.

@Adam-Vandervorst
Created May 16, 2023 13:30
Show Gist options
  • Select an option

  • Save Adam-Vandervorst/a2a18eca4fa5c8561303db70628efb75 to your computer and use it in GitHub Desktop.

Select an option

Save Adam-Vandervorst/a2a18eca4fa5c8561303db70628efb75 to your computer and use it in GitHub Desktop.
Term Unification in Python
from dataclasses import dataclass
from typing import Optional
class Term:
    """Base class shared by every kind of term: variables, symbols, expressions."""


@dataclass
class Var(Term):
    """A (logic) variable, identified by its name ``s``."""

    s: str


@dataclass
class Sym(Term):
    """A constant symbol, identified by its name ``s``."""

    s: str


@dataclass
class Expr(Term):
    """A compound term: the ordered sub-terms ``ts``."""

    ts: list[Term]


# Bindings accumulated during unification: variable name -> bound term.
Knowledge = dict[str, Term]
def mod_bind(d: Knowledge, v: str, t: Term) -> Knowledge:
    """Return a copy of ``d`` with the binding ``v <- t`` added or replaced.

    ``d`` itself is never mutated, so earlier knowledge stays valid for
    callers that still hold it.
    """
    return d | {v: t}
def lookup(d: Knowledge, v: str) -> Optional[Term]:
    """Return the resolved value of variable ``v``, or None if it has none.

    The stored value is passed through ``walk``, so variables inside it are
    replaced by their own values where those exist.

    NOTE(review): if ``v`` is bound only to a chain of variables ending in an
    unbound one (e.g. ``{"x": Var("y")}`` with ``y`` unbound), ``walk`` yields
    None for that last variable and so this returns None as well -- the alias
    is not visible to callers.
    """
    bound = d.get(v)
    if not bound:
        return bound
    return walk(d, bound)
def walk(d: Knowledge, t: Term) -> Optional[Term]:
    """Substitute the values recorded in ``d`` into term ``t``.

    - An expression is rebuilt element by element; an element whose own walk
      gives a falsy result (None for an unresolved variable) is kept as-is.
    - A variable is replaced by ``lookup`` of its name, which is None when no
      value can be resolved.
    - Any other term (e.g. a symbol) is returned unchanged.
    """
    if isinstance(t, Expr):
        resolved = []
        for item in t.ts:
            value = walk(d, item)
            resolved.append(value if value else item)
        return Expr(resolved)
    if isinstance(t, Var):
        return lookup(d, t.s)
    return t
def unify(t1: Term, t2: Term, knowledge: Optional[Knowledge] = None) -> Optional[Knowledge]:
    """Unify ``t1`` with ``t2``, using and extending ``knowledge``.

    Args:
        t1: Left term. When both sides are distinct variables, the binding is
            made on this side (left-biased).
        t2: Right term.
        knowledge: Bindings to respect and extend. ``None`` means "no bindings
            yet" and is treated as an empty mapping.

    Returns:
        The (possibly extended) bindings under which the terms unify, or
        ``None`` if they cannot be unified.
    """
    if knowledge is None:
        knowledge = {}
    if isinstance(t1, Var) and isinstance(t2, Var):
        if t1.s == t2.s:
            return knowledge  # does global variable equality
        return bind(t1.s, t2, knowledge)  # left-biased binding
    if isinstance(t1, Var):
        return bind(t1.s, t2, knowledge)
    if isinstance(t2, Var):
        return bind(t2.s, t1, knowledge)
    if isinstance(t1, Expr) and isinstance(t2, Expr) and len(t1.ts) == len(t2.ts):
        k: Optional[Knowledge] = knowledge
        for left, right in zip(t1.ts, t2.ts):
            k = unify(left, right, k)
            # Stop at the first failure: passing None on to the next call
            # would be read as "no knowledge yet" and silently restart from {}.
            if k is None:
                return None
        return k
    if t1 == t2:
        return knowledge
    return None
def _resolves_to(knowledge: Knowledge, t: Term, v: str) -> bool:
    """Return True if ``t`` is variable ``v`` or aliases it via variable-to-variable bindings."""
    seen: set[str] = set()
    # `seen` guards against looping on an already-cyclic chain of aliases.
    while isinstance(t, Var) and t.s not in seen:
        if t.s == v:
            return True
        seen.add(t.s)
        t = knowledge.get(t.s)
    return False


def bind(v: str, t: Term, knowledge: Knowledge) -> Optional[Knowledge]:
    """Bind variable ``v`` to term ``t``, consistently with ``knowledge``.

    Trying to unify `$x` with `Expr($x, b, c)` should fail as this is
    equivalent to infinite regress, so an occurs check runs first.

    If ``v`` already has a value, ``t`` must unify with it; ``v`` is then
    rebound to ``t`` in the resulting knowledge.

    Returns:
        The extended knowledge, or None if the binding is impossible.

    NOTE(review): ``lookup`` returns None when ``v`` is bound only to a chain
    of variables ending in an unbound one, so such a ``v`` takes the
    "unbound" path below and its existing alias is overwritten.
    """
    if _resolves_to(knowledge, t, v):
        # `t` is `v` itself or already an alias of it: nothing to add.
        # Binding anyway would store a variable cycle (e.g. x -> y, y -> x)
        # on which lookup/walk recurse forever.
        return knowledge
    if occurs_check(v, knowledge, t):
        return None
    r = lookup(knowledge, v)
    if r is None:
        return mod_bind(knowledge, v, t)
    nk = unify(t, r, knowledge)
    # Explicit None test: an empty dict is a valid (successful) knowledge.
    if nk is None:
        return None
    return mod_bind(nk, v, t)
def occurs_check(v: str, knowledge: Knowledge, t: Term) -> bool:
    """Return True if variable ``v`` occurs in ``t``, looking through ``knowledge``.

    The variables written in ``t`` are considered, plus every variable found
    (recursively) in the looked-up values of those variables.
    """

    def expand(names: list[str]) -> list[str]:
        # `names` followed by everything reachable from each of them.
        collected = list(names)
        for name in names:
            collected.extend(reachable_from(name))
        return collected

    def reachable_from(name: str) -> list[str]:
        value = lookup(knowledge, name)
        return expand([] if value is None else tvars(value))

    # v may be aliased to some other variable.
    # NOTE(review): `lookup` resolves through `walk`, which never yields a Var
    # (a variable is either substituted or reported as None), so as written
    # `target` always ends up being `v` itself -- confirm intended behavior.
    representative = lookup(knowledge, v)
    target = representative.s if isinstance(representative, Var) else v
    return target in expand(tvars(t))
def tvars(t: Term) -> list[str]:
    """List the names of the variables occurring in ``t``, left to right.

    Duplicates are kept: a variable that appears twice is listed twice.
    """
    if isinstance(t, Expr):
        names = []
        for sub in t.ts:
            names.extend(tvars(sub))
        return names
    return [t.s] if isinstance(t, Var) else []
from unify import Var, Expr, Sym, unify
import unittest
class UnifyTest(unittest.TestCase):
    """Regression tests for `unify` on small hand-written term pairs."""

    def test_basic(self):
        """Each pair of terms unifies to exactly the expected bindings."""
        cases = [
            # Variable-to-variable binding is made on the left-hand side.
            (
                Expr([Var("x"), Var("y")]),
                Expr([Var("x'"), Sym("a")]),
                {"x": Var("x'"), "y": Sym("a")},
            ),
            # `x` is expected to hold the last term it was unified with.
            (
                Expr([Var("x"), Sym("a"), Var("x")]),
                Expr([Expr([Sym("a"), Sym("b")]), Var("y"), Expr([Var("y"), Sym("b")])]),
                {"x": Expr([Var("y"), Sym("b")]), "y": Sym("a")},
            ),
            # NOTE(review): the expected bindings leave `w` unbound, so after
            # substitution the left side has Expr([Sym("B"), f(x, y)]) where
            # the right side still has Expr([Sym("B"), Var("w")]). This pins
            # current behavior rather than a true unifier -- confirm.
            (
                Expr([Sym("A"), Expr([Sym("B"), Var("v")]), Expr([Sym("C"), Var("u"), Var("v")])]),
                Expr([Sym("A"), Expr([Sym("B"), Var("w")]), Expr([Sym("C"), Var("w"), Expr([Sym("f"), Var("x"), Var("y")])])]),
                {"v": Expr([Sym("f"), Var("x"), Var("y")]), "u": Var("w")},
            ),
        ]
        for left, right, expected in cases:
            self.assertEqual(unify(left, right), expected)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment