-
-
Save LeeeeT/0c7790e7592b2e40c351af195757dbd9 to your computer and use it in GitHub Desktop.
HOAS CoC
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from collections.abc import Callable | |
| from dataclasses import dataclass | |
| from typing import Any | |
| type Term[T] = Var[T] | Typ[T] | Lam[T] | Fun[T] | App[T] | |
| @dataclass(frozen=True) | |
| class Var[T]: | |
| var: T | |
| @dataclass(frozen=True) | |
| class Typ[T]: | |
| pass | |
| @dataclass(frozen=True) | |
| class Lam[T]: | |
| dom: Term[T] | |
| bod: Callable[[Term[T]], Term[T]] | |
| @dataclass(frozen=True) | |
| class Fun[T]: | |
| dom: Term[T] | |
| cod: Callable[[Term[T]], Term[T]] | |
| @dataclass(frozen=True) | |
| class App[T]: | |
| fun: Term[T] | |
| arg: Term[T] | |
| def show(term: Term[int], depth: int = 0) -> str: | |
| match term: | |
| case Var(var): | |
| return f"{var}" | |
| case Typ(): | |
| return "*" | |
| case Lam(dom, bod): | |
| return f"λ({depth} : {show(dom, depth)}). {show(bod(Var(depth)), depth + 1)}" | |
| case Fun(dom, cod): | |
| return f"Π({depth} : {show(dom, depth)}). {show(cod(Var(depth)), depth + 1)}" | |
| case App(fun, arg): | |
| fun = f"({show(fun, depth)})" if isinstance(fun, Lam) else show(fun, depth) | |
| arg = f"({show(arg, depth)})" if not isinstance(arg, Var) else show(arg, depth) | |
| return f"{fun} {arg}" | |
| def nf[T](term: Term[T]) -> Term[T]: | |
| match whnf(term): | |
| case Lam(dom, bod): | |
| return Lam(nf(dom), lambda arg: nf(bod(arg))) | |
| case Fun(dom, cod): | |
| return Fun(nf(dom), lambda arg: nf(cod(arg))) | |
| case App(fun, arg): | |
| return App(nf(fun), nf(arg)) | |
| case _: | |
| return term | |
| def whnf[T](term: Term[T]) -> Term[T]: | |
| match term: | |
| case App(fun, arg): | |
| match whnf(fun): | |
| case Lam(_, bod): | |
| return whnf(bod(arg)) | |
| case _: | |
| return term | |
| case _: | |
| return term | |
| def equal(first: Term[object], second: Term[object]) -> bool: | |
| match whnf(first), whnf(second): | |
| case Var(v1), Var(v2): | |
| return v1 is v2 | |
| case Typ(), Typ(): | |
| return True | |
| case Lam(dom1, bod1), Lam(dom2, bod2): | |
| v = Var(object()) | |
| return equal(dom1, dom2) and equal(bod1(v), bod2(v)) | |
| case Fun(dom1, cod1), Fun(dom2, cod2): | |
| v = Var(object()) | |
| return equal(dom1, dom2) and equal(cod1(v), cod2(v)) | |
| case App(fun1, arg1), App(fun2, arg2): | |
| return equal(fun1, fun2) and equal(arg1, arg2) | |
| case _: | |
| return False | |
| def infer(term: Term[Any], ctx: dict[Term[Any], Term[Any]] = {}) -> Term[Any]: | |
| match term: | |
| case Var(var): | |
| return ctx[Var(var)] | |
| case Typ(): | |
| return Typ() | |
| case Lam(dom, bod): | |
| return Fun(dom, lambda arg: infer(bod(arg), ctx | {arg: dom})) | |
| case Fun(dom, _): | |
| return Typ() | |
| case App(fun, arg): | |
| match whnf(infer(fun, ctx)): | |
| case Fun(dom, cod) if equal(infer(arg, ctx), dom): | |
| return cod(arg) | |
| case _: | |
| raise TypeError | |
| case _: | |
| raise TypeError | |
# Church encodings written as HOAS terms. Lambda parameters get short names so
# they never shadow the module-level definitions of the same name.

# Unit = Π(R : *). R → R
Unit: Term[Any] = Fun(Typ(), lambda r: Fun(r, lambda _u: r))
# unit = λ(R : *). λ(u : R). u
unit: Term[Any] = Lam(Typ(), lambda r: Lam(r, lambda u: u))

# Bool = Π(R : *). R → R → R
Bool: Term[Any] = Fun(Typ(), lambda r: Fun(r, lambda _t: Fun(r, lambda _f: r)))
# true selects its first branch, false its second.
true: Term[Any] = Lam(Typ(), lambda r: Lam(r, lambda t: Lam(r, lambda _f: t)))
false: Term[Any] = Lam(Typ(), lambda r: Lam(r, lambda _t: Lam(r, lambda f: f)))

# negate b = b Bool false true: use b itself to choose the opposite boolean.
negate: Term[Any] = Lam(Bool, lambda b: App(App(App(b, Bool), false), true))

# Demo: print the normal form of ``negate true`` (the encoding of false), then
# its normalised type.
main = App(negate, true)
print(show(nf(main)))
print(show(nf(infer(main))))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment