Created
May 16, 2023 13:30
-
-
Save Adam-Vandervorst/a2a18eca4fa5c8561303db70628efb75 to your computer and use it in GitHub Desktop.
Term Unification in Python
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from dataclasses import dataclass | |
| from typing import Optional | |
class Term:
    """Base class of every term handled by the unifier."""


@dataclass
class Var(Term):
    """A logic variable, identified by its name `s`."""

    s: str


@dataclass
class Sym(Term):
    """A constant symbol named `s`; it only unifies with an equal symbol."""

    s: str


@dataclass
class Expr(Term):
    """A compound term made of the ordered sub-terms `ts`."""

    ts: list[Term]


# Bindings: variable name -> the term that variable is bound to.
Knowledge = dict[str, Term]
def mod_bind(d: Knowledge, v: str, t: Term) -> Knowledge:
    """Return a copy of `d` in which variable `v` is bound to `t`.

    An existing binding for `v` is replaced, keeping its position; `d` itself
    is never modified.
    """
    return {**d, v: t}
def lookup(d: Knowledge, v: str) -> Optional[Term]:
    """Return the value of variable `v` with all nested bindings substituted.

    Returns None when `v` has no binding, or when its binding leads, through
    other variables, to a variable that has no binding.
    """
    bound = d.get(v)
    if not bound:
        return bound
    return walk(d, bound)
def walk(d: Knowledge, t: Term) -> Optional[Term]:
    """Substitute the bindings of `d` into `t`, recursively.

    - A variable is replaced by its looked-up value, or by None when it has
      no value.
    - Inside an expression, variables without a value are kept as they are.
    - Any other term is returned unchanged.
    """
    if isinstance(t, Var):
        return lookup(d, t.s)
    if not isinstance(t, Expr):
        return t
    substituted = []
    for sub in t.ts:
        # Fall back to the original sub-term when it has no value.
        substituted.append(walk(d, sub) or sub)
    return Expr(substituted)
def unify(t1: Term, t2: Term, knowledge: Optional[Knowledge] = None) -> Optional[Knowledge]:
    """Unify `t1` with `t2`, using and extending the passed knowledge.

    Args:
        t1: left-hand term.
        t2: right-hand term.
        knowledge: bindings already in force; None means "no bindings".

    Returns:
        The knowledge extended with the bindings that make `t1` and `t2` equal,
        or None if they cannot be unified. A successful unification that needs
        no new bindings may return `{}`, which is falsy, so callers must test
        the result with `is None` rather than truthiness.
    """
    if knowledge is None:
        knowledge = {}
    if isinstance(t1, Var) and isinstance(t2, Var):
        if t1.s == t2.s:
            return knowledge  # a variable always unifies with itself
        return bind(t1.s, t2, knowledge)  # left-biased binding
    if isinstance(t1, Var):
        return bind(t1.s, t2, knowledge)
    if isinstance(t2, Var):
        return bind(t2.s, t1, knowledge)
    if isinstance(t1, Expr) and isinstance(t2, Expr) and len(t1.ts) == len(t2.ts):
        k = knowledge
        for left, right in zip(t1.ts, t2.ts):
            k = unify(left, right, k)
            if k is None:
                # Stop at the first mismatch: passing None on would make the
                # next call start over from empty knowledge and hide the failure.
                return None
        return k
    return knowledge if t1 == t2 else None
def bind(v: str, t: Term, knowledge: Knowledge) -> Optional[Knowledge]:
    """Extend `knowledge` with the binding `v <- t`, or return None on conflict.

    Variable-to-variable bindings act as aliases. When `v` is already bound
    to another variable, the new binding goes on the last variable of that
    chain (its root), so no earlier constraint is overwritten. Binding a
    variable to itself, or to another alias of the same root, leaves the
    knowledge unchanged instead of creating a cycle.

    Trying to unify `$x` with `Expr($x, b, c)` fails, because this is
    equivalent to infinite regress (see `occurs_check`).

    Args:
        v: name of the variable to bind.
        t: term to bind it to.
        knowledge: current bindings; never modified (`mod_bind` copies).

    Returns:
        The extended knowledge, or None if the binding is impossible.
    """
    def alias_root(name: str) -> str:
        # Follow `name <- Var(...)` links to the last variable in the chain.
        # The `seen` set only guards against malformed, already-cyclic input.
        seen = {name}
        bound = knowledge.get(name)
        while isinstance(bound, Var) and bound.s not in seen:
            name = bound.s
            seen.add(name)
            bound = knowledge.get(name)
        return name

    root = alias_root(v)
    if isinstance(t, Var) and alias_root(t.s) == root:
        # Both sides already denote the same variable.
        return knowledge
    if occurs_check(root, knowledge, t):
        return None
    current = lookup(knowledge, root)
    if current is None:
        return mod_bind(knowledge, root, t)
    # `root` already has a value, so it must agree with `t`.
    extended = unify(t, current, knowledge)
    if extended is None:
        return None
    return mod_bind(extended, root, t)
def occurs_check(v: str, knowledge: Knowledge, t: Term) -> bool:
    """Return True if variable `v` occurs in `t` once `knowledge` is applied.

    Aliases are taken into account on both sides:
    - `v` stands for every variable on its `v <- Var(...) <- ...` chain.
    - Every variable met while scanning `t` is expanded into its bound value.

    Callers use a True result to reject bindings such as `x <- Expr(x, b, c)`,
    which would describe an infinitely deep term.

    The scan is iterative and expands each variable's binding at most once,
    so deep nesting does not hit the recursion limit.
    """
    # v may be aliased to some other variable: collect its whole alias chain.
    aliases = {v}
    bound = knowledge.get(v)
    while isinstance(bound, Var) and bound.s not in aliases:
        aliases.add(bound.s)
        bound = knowledge.get(bound.s)

    visited: set[str] = set()
    pending: list[Term] = [t]
    while pending:
        term = pending.pop()
        if isinstance(term, Expr):
            pending.extend(term.ts)
        elif isinstance(term, Var):
            if term.s in aliases:
                return True
            if term.s not in visited:
                visited.add(term.s)
                value = knowledge.get(term.s)
                if value is not None:
                    pending.append(value)
    return False
def tvars(t: Term) -> list[str]:
    """List the names of the variables in `t`, left to right, with repeats."""
    if isinstance(t, Var):
        return [t.s]
    if not isinstance(t, Expr):
        return []
    names: list[str] = []
    for sub in t.ts:
        names.extend(tvars(sub))
    return names
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from unify import Var, Expr, Sym, unify | |
| import unittest | |
class UnifyTest(unittest.TestCase):
    """Behavioural tests for `unify`."""

    def _resolve(self, knowledge, t):
        """Fully apply `knowledge` to `t`, keeping unbound variables as they are."""
        if isinstance(t, Var):
            bound = knowledge.get(t.s)
            return t if bound is None else self._resolve(knowledge, bound)
        if isinstance(t, Expr):
            return Expr([self._resolve(knowledge, x) for x in t.ts])
        return t

    def assertUnifier(self, knowledge, t1, t2):
        """Check that `knowledge` makes `t1` and `t2` identical."""
        self.assertIsNotNone(knowledge)
        self.assertEqual(self._resolve(knowledge, t1), self._resolve(knowledge, t2))

    def test_basic(self):
        self.assertEqual(
            unify(Expr([Var("x"), Var("y")]),
                  Expr([Var("x'"), Sym("a")])),
            {"x": Var("x'"), "y": Sym("a")}
        )
        self.assertEqual(
            unify(Expr([Var("x"), Sym("a"), Var("x")]),
                  Expr([Expr([Sym("a"), Sym("b")]), Var("y"), Expr([Var("y"), Sym("b")])])),
            {"x": Expr([Var("y"), Sym("b")]), "y": Sym("a")}
        )
        t1 = Expr([Sym("A"), Expr([Sym("B"), Var("v")]), Expr([Sym("C"), Var("u"), Var("v")])])
        t2 = Expr([Sym("A"), Expr([Sym("B"), Var("w")]),
                   Expr([Sym("C"), Var("w"), Expr([Sym("f"), Var("x"), Var("y")])])])
        k = unify(t1, t2)
        # `v <- w` is recorded first; the later `v <- (f x y)` must constrain `w`
        # instead of overwriting `v`, otherwise `(B v)` and `(B w)` no longer match.
        self.assertEqual(
            k,
            {"v": Var("w"), "u": Var("w"), "w": Expr([Sym("f"), Var("x"), Var("y")])}
        )
        self.assertUnifier(k, t1, t2)

    def test_failure_is_not_forgotten(self):
        # A mismatch in an early argument must not be erased by later matches.
        self.assertIsNone(unify(Expr([Sym("a"), Sym("b")]), Expr([Sym("c"), Sym("b")])))

    def test_aliased_variables_do_not_form_a_cycle(self):
        t1 = Expr([Var("x"), Var("y"), Var("x")])
        t2 = Expr([Var("y"), Var("x"), Sym("a")])
        k = unify(t1, t2)
        self.assertEqual(k, {"x": Var("y"), "y": Sym("a")})
        self.assertUnifier(k, t1, t2)

    def test_occurs_check_follows_aliases(self):
        # After `x <- y`, binding `y <- (f x)` would mean y = (f y).
        self.assertIsNone(unify(Expr([Var("x"), Var("y")]),
                                Expr([Var("y"), Expr([Sym("f"), Var("x")])])))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment