Created
January 14, 2025 11:37
-
-
Save Guest0x0/e0c8c8fd0974bec98aa081a2016d8570 to your computer and use it in GitHub Desktop.
STLC type inference with type graph optimization via an explicit substitution
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
(* Surface syntax of the simply-typed lambda calculus.
   Names of base types and names of term variables are both plain strings. *)
type type_name = string
and var = string

(* Terms: base-type constants, variables, unannotated lambdas and
   applications. *)
type term =
  | Atom of { type_name : type_name }
    (* a constant whose type is the base type [type_name] *)
  | Var of var
  | Lam of var * term
    (* [fun x -> body]; the binder carries no type annotation *)
  | App of term * term
    (* function applied to an argument *)
(* Type variables are identified by integers. *)
type tvar_id = int

(* Types: unification variables, base types and function types
   ([TFunc (domain, codomain)]). *)
type typ =
  | TVar of tvar_id
  | TAtom of string
  | TFunc of typ * typ

(* A type variable is either still unknown or has been solved to a type;
   that solution may itself mention other (solved or unsolved) variables. *)
type tvar_status = Unsolved | Solved of typ

(* The substitution: a mutable hash table from type-variable ids to their
   status, updated in place while solving constraints. *)
module Subst = Hashtbl

type subst = (tvar_id, tvar_status) Subst.t
(* Since we are maintaining a lazy substitution,
   before matching on a type we must "force" it,
   i.e. actually perform the substitution for solved type variables.
   Only the head of the type is forced: the components of a [TFunc]
   may still mention solved variables and are forced when inspected.
   We also perform path compression here
   (if [A] is substituted to [B] and [B] is substituted to [C],
   then when fetching [A] we make [A] point to [C] directly).
   The substitution is a mutable hash table, so the compressed binding
   is written back in place; only the forced type is returned.
*)
(* Force the head of [typ]: if it is a solved type variable, follow its
   solution (recursively) and return the resulting type; otherwise return
   [typ] unchanged. Every variable visited on the way is rebound directly
   to the final result (path compression, written into [subst] in place).
   Raises [Not_found] if a visited variable has no entry in [subst]. *)
let rec get_type subst typ =
  match typ with
  | TAtom _ | TFunc _ -> typ
  | TVar id ->
    match Subst.find subst id with
    | Unsolved -> typ
    | Solved bound ->
      let target = get_type subst bound in
      (* path compression: point [id] straight at its final representative *)
      Subst.replace subst id (Solved target);
      target
(* Occurs check: fail with "occurence check" if the type variable [tvar]
   appears anywhere inside [typ], looking through solved variables via
   [get_type]. [unify] runs this before binding [tvar := typ], so that no
   cyclic (infinite) type is ever recorded. *)
let rec check_occurence subst tvar typ =
  match get_type subst typ with
  | TAtom _ -> ()
  | TVar other when other = tvar -> failwith "occurence check"
  | TVar _ -> ()
  | TFunc (dom, cod) ->
    check_occurence subst tvar dom;
    check_occurence subst tvar cod
(* Make [ty1] and [ty2] equal by extending [subst] in place.
   Both sides are forced first; an unsolved variable on either side is bound
   to the other side (after the occurs check); base types must have the same
   name; function types are unified domain-first, then codomain.
   Fails with "occurence check" or "type mismatch". Bindings made before a
   failure are not rolled back. *)
let rec unify subst ty1 ty2 =
  let bind v t =
    check_occurence subst v t;
    Subst.replace subst v (Solved t)
  in
  let lhs = get_type subst ty1 in
  let rhs = get_type subst ty2 in
  match lhs, rhs with
  | TVar a, TVar b when a = b -> ()
  | TVar v, other | other, TVar v -> bind v other
  | TAtom a, TAtom b when a = b -> ()
  | TFunc (dom1, cod1), TFunc (dom2, cod2) ->
    unify subst dom1 dom2;
    unify subst cod1 cod2
  | _ -> failwith "type mismatch"
(* Global counter handing out type-variable ids. Nothing in this file resets
   it, so ids keep increasing even across different substitutions. *)
let tvar_id = ref 0

(* Allocate a new type variable, register it as [Unsolved] in [subst],
   and return its id. *)
let fresh_tvar subst =
  let id = !tvar_id in
  tvar_id := id + 1;
  Subst.add subst id Unsolved;
  id
(* Typing contexts: immutable maps from term variables ([var], i.e. strings,
   ordered by [String.compare]) to their types. *)
module Env = Map.Make (String)
(* Infer the type of [expr] under the typing context [env], recording every
   unification constraint in [subst] as it goes.
   The returned type is not forced: it may mention type variables that have
   since been solved in [subst].
   Fails with "undefined variable" on a variable missing from [env], and
   propagates the failures of [unify]. *)
let rec infer subst env expr =
  let fresh () = TVar (fresh_tvar subst) in
  match expr with
  | Atom { type_name } -> TAtom type_name
  | Var name ->
    (try Env.find name env with Not_found -> failwith "undefined variable")
  | Lam (param, body) ->
    (* the binder is unannotated, so it starts out as an unknown type *)
    let param_ty = fresh () in
    let body_ty = infer subst (Env.add param param_ty env) body in
    TFunc (param_ty, body_ty)
  | App (fn, arg) ->
    let fn_ty = infer subst env fn in
    let arg_ty = infer subst env arg in
    (* [fn] must be a function from [arg_ty] to some yet-unknown result *)
    let result_ty = fresh () in
    unify subst fn_ty (TFunc (arg_ty, result_ty));
    result_ty
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment