Created
September 3, 2025 22:25
-
-
Save rnagasam/1d34b747aa5ae0593552aa3661876838 to your computer and use it in GitHub Desktop.
Small sketch; checking refinement types in Python
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import ast | |
| import inspect | |
| import z3 | |
| from collections import namedtuple | |
| from dataclasses import dataclass | |
| from typing import Union, Callable, Tuple | |
| # ------------------------------------------------------------------------ | |
| # Types -- primitives and refinement types | |
| # ------------------------------------------------------------------------ | |
# A checker type is either a primitive type or a refinement of another type.
Type = Union['TPrim', 'TSub']


@dataclass(frozen=True)
class TPrim:
    """A primitive (unrefined) type, identified by its name.

    The dataclass is frozen, so instances are immutable and hashable.
    """

    name: str
@dataclass(frozen=True)
class TSub:
    """A refinement type: the values of ``base`` that satisfy ``pred``."""

    # Type being refined; it may itself be a TSub, so refinements can stack.
    base: Type
    # Maps the Z3 expression standing for a value to the Z3 formula that the
    # value must satisfy. The argument types of Callable must be given as a
    # list.
    pred: Callable[[z3.ExprRef], z3.BoolRef]
# Primitive types understood by the checker.
TInt = TPrim(name='Int')
TBool = TPrim(name='Bool')
TReal = TPrim(name='Real')

# Refinement types. Each predicate receives the Z3 expression for the value.
Nat = TSub(base=TInt, pred=lambda v: v >= 0)
Posint = TSub(base=TInt, pred=lambda v: v > 0)
Posnat = TSub(base=Nat, pred=lambda v: v > 0)
Even = TSub(base=TInt, pred=lambda v: v % 2 == 0)
Odd = TSub(base=TInt, pred=lambda v: v % 2 == 1)
def is_square(n):
    """Return the Z3 formula stating that ``n`` is a perfect square.

    The formula is ``exists m. n == m * m`` with ``m`` a Z3 integer.
    """
    root = z3.Int('m')
    squared_root = root * root
    return z3.Exists([root], n == squared_root)


# Natural numbers that are perfect squares.
Square = TSub(base=Nat, pred=is_square)
| # ------------------------------------------------------------------------ | |
| # Example functions, annotated with types. We will build a small | |
| # checker that can statically verify that the function implementations | |
| # have the annotated type. | |
| # ------------------------------------------------------------------------ | |
def succ(n: Nat) -> Posint:
    """Return ``n`` plus one."""
    return n + 1
def plus2(n: Even) -> Even:
    """Return ``n`` plus two."""
    return n + 2
def plus_odd(o1: Odd, o2: Odd) -> Even:
    """Return the sum of the two arguments."""
    return o1 + o2
def my_abs(o: TInt) -> Nat:
    """Return the absolute value of ``o``."""
    return (0 - o) if o < 0 else o
def sqr(n: TInt) -> Square:
    """Return ``n`` multiplied by itself."""
    return n * n
| # ------------------------------------------------------------------------ | |
| # Checker | |
| # ------------------------------------------------------------------------ | |
# Maps each primitive type to the Z3 function that is called with a name to
# create a variable for that type.
prim_types_interp = {
    TInt: z3.Int,
    TBool: z3.Bool,
    TReal: z3.Real,
}
def regular_arguments(fdef: ast.FunctionDef) -> list[ast.arg]:
    """Return the ordinary positional-or-keyword parameters of ``fdef``.

    Only ``fdef.args.args`` is returned; positional-only, keyword-only,
    ``*args`` and ``**kwargs`` parameters are not included.
    """
    signature = fdef.args
    return signature.args
def get_annotation(arg: ast.arg) -> ast.Name:
    """Return the annotation node attached to the parameter ``arg``.

    The node is returned exactly as stored in the AST, which is ``None`` for
    a parameter without an annotation.
    """
    annotation = arg.annotation
    return annotation
def fresh_var():
    """Create a generator of unique variable names.

    The returned callable takes an optional ``prefix`` and returns that prefix
    (``'x'`` when the prefix is missing or empty) followed by a counter that
    starts at 1 and grows by one on every call.
    """
    counter = [0]  # single mutable cell shared by every call of ``gen``

    def gen(prefix=None):
        counter[0] += 1
        stem = prefix if prefix else 'x'
        return stem + str(counter[0])

    return gen


# Module-wide supply of names for the Z3 variables the checker creates.
gensym = fresh_var()
def gen_type_constraint(ty: Type) -> Tuple[z3.ExprRef, z3.BoolRef]:
    """Create a fresh Z3 variable for ``ty`` together with its constraint.

    Args:
        ty: A primitive type or a (possibly nested) refinement type.

    Returns:
        A pair ``(var, constraint)``. ``var`` is a fresh Z3 variable created
        for the primitive type at the bottom of ``ty``; ``constraint`` is the
        simplified conjunction of every refinement predicate applied to
        ``var`` (``True`` for a bare primitive).

    Raises:
        Exception: if ``ty`` is neither a ``TPrim`` nor a ``TSub``.
    """
    if isinstance(ty, TPrim):
        # A primitive type imposes no constraint of its own.
        return (prim_types_interp[ty](gensym()), z3.BoolVal(True))
    if isinstance(ty, TSub):
        # Reuse the variable of the base type so that stacked refinements
        # all constrain the same value.
        var, base_constraint = gen_type_constraint(ty.base)
        return (var, z3.simplify(z3.And(ty.pred(var), base_constraint)))
    raise Exception(f'Invalid input: {ty}')
def get_function_type(
    f: ast.FunctionDef,
) -> Tuple[list[Tuple[str, str]], str]:
    """Read the annotated signature of ``f`` from its AST.

    Args:
        f: Function definition whose regular parameters and return value are
            all annotated with plain names.

    Returns:
        A pair ``(params, ret)``. ``params`` lists ``(parameter name,
        annotation name)`` for each regular parameter in order, and ``ret``
        is the name used as the return annotation. Annotations are returned
        as the identifier strings found in the source, not as type objects.
    """
    params = [
        (param.arg, get_annotation(param).id)
        for param in regular_arguments(f)
    ]
    return (params, f.returns.id)
def get_function_def(f: str) -> ast.FunctionDef:
    """Return the AST of the function that the expression ``f`` evaluates to.

    Args:
        f: Python expression, typically a function name, that is evaluated
            to obtain the function object.

    Returns:
        The first top-level node of the parsed source of that function.
    """
    import textwrap  # local import: only this helper needs it

    # NOTE(review): eval executes arbitrary code; only pass trusted names.
    func = eval(f)
    # Methods and nested functions come back indented from getsource, which
    # ast.parse rejects, so strip the common leading whitespace first.
    source = textwrap.dedent(inspect.getsource(func))
    return ast.parse(source).body[0]
class ExprToZ3(ast.NodeVisitor):
    """Translate a Python expression AST into a Z3 expression.

    Supported nodes: names bound in ``context``, ``bool``/``int``/``float``
    constants, ``and``/``or``, comparisons (including chained ones), the
    binary operators ``+ - * /``, unary ``-`` and ``not``, and conditional
    expressions. Every other node raises ``Exception`` instead of being
    skipped, so an unsupported construct cannot be silently mistranslated.
    """

    # Comparison operators mapped to the expression they build.
    _COMPARISONS = {
        ast.Lt: lambda a, b: a < b,
        ast.LtE: lambda a, b: a <= b,
        ast.Gt: lambda a, b: a > b,
        ast.GtE: lambda a, b: a >= b,
        ast.Eq: lambda a, b: a == b,
        ast.NotEq: lambda a, b: a != b,
    }

    # Binary arithmetic operators mapped to the expression they build.
    # NOTE(review): on Z3 integer expressions ``/`` is presumably integer
    # division, unlike Python's true division -- confirm before relying on it.
    _ARITHMETIC = {
        ast.Add: lambda a, b: a + b,
        ast.Sub: lambda a, b: a - b,
        ast.Mult: lambda a, b: a * b,
        ast.Div: lambda a, b: a / b,
    }

    def __init__(self, context: 'dict[str, z3.ExprRef] | None' = None):
        """Create a translator.

        Args:
            context: Maps variable names to the Z3 expressions that stand
                for them. The mapping is used as given (not copied); when
                omitted, each translator gets its own empty dict.
        """
        # A mutable default value would be shared by every instance.
        self.context = {} if context is None else context
        self.expr = None

    def to_z3(self, node):
        """Translate ``node`` and return the resulting Z3 expression."""
        self.visit(node)
        return self.expr

    def generic_visit(self, node):
        """Reject any node that has no dedicated ``visit_*`` method."""
        raise Exception(f'Unsupported expression: {type(node).__name__}')

    def visit_Name(self, node):
        """Look a variable up in the context."""
        try:
            self.expr = self.context[node.id]
        except KeyError:
            raise Exception(f'Variable {node.id} not bound') from None

    def visit_BoolOp(self, node):
        """Translate ``a and b and ...`` / ``a or b or ...``."""
        operands = [self.to_z3(value) for value in node.values]
        if isinstance(node.op, ast.And):
            self.expr = z3.And(*operands)
        elif isinstance(node.op, ast.Or):
            self.expr = z3.Or(*operands)
        else:
            raise Exception(f'Unsupported operation: {node.op}')

    def visit_Compare(self, node):
        """Translate a comparison; ``a < b < c`` becomes ``a < b and b < c``."""
        operands = [self.to_z3(e) for e in [node.left, *node.comparators]]
        clauses = []
        for op, lhs, rhs in zip(node.ops, operands, operands[1:]):
            build = self._COMPARISONS.get(type(op))
            if build is None:
                raise Exception(f'Unsupported operation: {op}')
            clauses.append(build(lhs, rhs))
        # Keep a plain comparison unwrapped in the common single-operator case.
        self.expr = clauses[0] if len(clauses) == 1 else z3.And(*clauses)

    def visit_BinOp(self, node):
        """Translate a binary arithmetic operation."""
        lhs = self.to_z3(node.left)
        rhs = self.to_z3(node.right)
        build = self._ARITHMETIC.get(type(node.op))
        if build is None:
            raise Exception(f'Unsupported operation: {node.op}')
        self.expr = build(lhs, rhs)

    def visit_UnaryOp(self, node):
        """Translate unary minus and ``not``."""
        operand = self.to_z3(node.operand)
        if isinstance(node.op, ast.USub):
            self.expr = -operand
        elif isinstance(node.op, ast.Not):
            self.expr = z3.Not(operand)
        else:
            raise Exception(f'Unsupported operation: {node.op}')

    def visit_Constant(self, node):
        """Translate a literal into the matching Z3 value."""
        v = node.value
        # bool is tested first because it is a subclass of int.
        if isinstance(v, bool):
            self.expr = z3.BoolVal(v)
        elif isinstance(v, int):
            self.expr = z3.IntVal(v)
        elif isinstance(v, float):
            self.expr = z3.RealVal(v)
        else:
            raise Exception("Unsupported constant")

    def visit_IfExp(self, node):
        """Translate ``body if test else orelse`` into a Z3 if-then-else."""
        test = self.to_z3(node.test)
        body = self.to_z3(node.body)
        orelse = self.to_z3(node.orelse)
        self.expr = z3.If(test, body, orelse)
class Constraints(ast.NodeVisitor):
    """Build the type-correctness condition (TCC) of an annotated function.

    Visiting a function definition binds each regular parameter to a fresh
    Z3 variable and records the conjunction of the parameter constraints.
    Each ``return`` statement then adds the obligation that the returned
    expression satisfies the declared return type, and ``tcc`` becomes
    ``Implies(parameter constraints, all return obligations)``.

    Conditions of enclosing ``if`` statements are not taken into account.
    """

    def __init__(self):
        self.context: dict[str, z3.ExprRef] = {}
        self.tcc: z3.BoolRef = z3.BoolVal(True)
        self.return_annot: Type = None
        # Hypothesis of the TCC: what is known about the parameters.
        self._precondition: z3.BoolRef = self.tcc
        # One obligation per ``return`` statement seen so far.
        self._obligations: list[z3.BoolRef] = []

    def visit_Return(self, node):
        """Add the obligation of one ``return`` statement and rebuild ``tcc``.

        Raises:
            Exception: if no return annotation has been recorded, i.e. the
                ``return`` was not reached through ``visit_FunctionDef``.
        """
        if self.return_annot is None:
            raise Exception('Unexpected')
        ret = ExprToZ3(self.context).to_z3(node.value)
        v, cnst = gen_type_constraint(self.return_annot)
        self._obligations.append(z3.substitute(cnst, (v, ret)))
        # Every return must satisfy the return type under the same
        # hypothesis. Chaining Implies on the previous tcc would instead make
        # the earlier obligation a hypothesis of the later one.
        if len(self._obligations) == 1:
            goal = self._obligations[0]
        else:
            goal = z3.And(*self._obligations)
        self.tcc = z3.Implies(z3.simplify(self._precondition), goal)
        print(f'TCC: {self.tcc}')
        self.generic_visit(node)

    def visit_FunctionDef(self, node):
        """Bind the parameters of ``node`` and record its return type.

        Raises:
            Exception: if a regular parameter or the return value is not
                annotated with a plain name.
        """
        for arg in regular_arguments(node):
            annotation = get_annotation(arg)
            if not isinstance(annotation, ast.Name):
                raise Exception(
                    f'Parameter {arg.arg} of {node.name} must be annotated '
                    'with a type name')
            # NOTE(review): eval resolves the annotation in this module's
            # namespace and executes arbitrary code; only check trusted
            # sources.
            z3_var, cnsts = gen_type_constraint(eval(annotation.id))
            self.context[arg.arg] = z3_var
            self.tcc = z3.And(cnsts, self.tcc)
        if not isinstance(node.returns, ast.Name):
            raise Exception(
                f'Return value of {node.name} must be annotated with a '
                'type name')
        self._precondition = self.tcc
        self._obligations = []
        self.return_annot = eval(node.returns.id)
        self.generic_visit(node)
def print_ast_repr(t):
    """Print a dump of the AST node ``t``, indented by two spaces per level."""
    dumped = ast.dump(t, indent=2)
    print(dumped)
def prove(f):
    """Check the formula ``f`` by asking Z3 whether its negation is satisfiable.

    Returns the solver's verdict on ``Not(f)``; callers treat ``z3.unsat`` as
    "``f`` is proved".
    """
    negated = z3.Not(f)
    solver = z3.Solver()
    solver.add(negated)
    return solver.check()
def check_fun(f):
    """Verify that function ``f`` respects its refinement-type annotations.

    The source of ``f`` is parsed, its type-correctness condition is built,
    and Z3 is asked to prove it. Prints ``'ok'`` when the condition is
    proved and ``'Failed to verify'`` otherwise, which includes the case
    where the solver returns something other than ``unsat`` or ``sat``.
    """
    import textwrap  # local import: only this helper needs it

    # Methods and nested functions come back indented from getsource, which
    # ast.parse rejects, so strip the common leading whitespace first.
    source = textwrap.dedent(inspect.getsource(f))
    c = Constraints()
    c.visit(ast.parse(source))
    if prove(c.tcc) == z3.unsat:
        print('ok')
    else:
        print('Failed to verify')
| # ------------------------------------------------------------------------ | |
| # Testing | |
| # ------------------------------------------------------------------------ | |
# Each of these examples is expected to verify (check_fun prints 'ok').
for _example in (succ, plus2, plus_odd, my_abs, sqr):
    check_fun(_example)
def incorrect(n: TInt) -> Posnat:
    """Deliberately ill-typed example: ``n - 1`` is not always positive."""
    return n - 1


check_fun(incorrect)  # prints 'Failed to verify'
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment