Skip to content

Instantly share code, notes, and snippets.

@rnagasam
Created September 3, 2025 22:25
Show Gist options
  • Select an option

  • Save rnagasam/1d34b747aa5ae0593552aa3661876838 to your computer and use it in GitHub Desktop.

Select an option

Save rnagasam/1d34b747aa5ae0593552aa3661876838 to your computer and use it in GitHub Desktop.
Small sketch; checking refinement types in Python
import ast
import inspect
import z3
from collections import namedtuple
from dataclasses import dataclass
from typing import Union, Callable, Tuple
# ------------------------------------------------------------------------
# Types -- primitives and refinement types
# ------------------------------------------------------------------------
# A checker type: either a primitive (TPrim) or a refinement of another
# checker type (TSub). The members are string forward references because
# both classes are defined after this alias.
Type = Union['TPrim', 'TSub']
@dataclass(eq=True, frozen=True)
class TPrim:
    """Primitive (unrefined) type, identified by its name such as 'Int'.

    Instances are frozen and therefore hashable, which lets them serve as
    dictionary keys (e.g. in ``prim_types_interp``).
    """

    name: str
@dataclass(frozen=True)
class TSub:
    """Refinement (subset) type: the values of ``base`` satisfying ``pred``.

    ``pred`` receives a z3 term standing for a value of ``base`` and returns
    a z3 boolean constraint on it. Refinements may be nested (``base`` can
    itself be a TSub); the checker conjoins all predicates on the chain.

    Note: equality and hashing compare ``pred`` by identity, so two TSubs
    built from separate (even textually identical) lambdas are distinct.
    """

    base: Type
    # Callable takes its argument types as a list: Callable[[arg], ret].
    pred: Callable[[z3.ExprRef], z3.BoolRef]
# Primitive types. Each one is mapped to a z3 sort by prim_types_interp.
TInt = TPrim('Int')
TBool = TPrim('Bool')
TReal = TPrim('Real')
# Refinement types over Int. Each predicate receives a z3 term and returns a
# z3 boolean constraint on it.
Nat = TSub(TInt, lambda n: n >= 0)
Posint = TSub(TInt, lambda n: n > 0)
# Refines Nat, so gen_type_constraint conjoins n > 0 with Nat's n >= 0.
Posnat = TSub(Nat, lambda n: n > 0)
Even = TSub(TInt, lambda n: n % 2 == 0)
# NOTE(review): relies on z3's integer `%` giving a non-negative remainder
# for negative n (as Python's n % 2 does) -- confirm against the z3 docs.
Odd = TSub(TInt, lambda n: n % 2 == 1)
def is_square(n):
    """Return a z3 formula stating that ``n`` is a perfect square.

    The formula is ``exists m. n == m*m`` over the integers.
    """
    # z3.Exists binds every occurrence of the constant m in the body. With a
    # fixed name ('m'), a free constant called 'm' inside n would be captured
    # as well. FreshInt gives the bound variable a guaranteed-unique name.
    m = z3.FreshInt('m')
    return z3.Exists([m], n == m*m)


# Perfect squares: naturals satisfying is_square.
Square = TSub(Nat, is_square)
# ------------------------------------------------------------------------
# Example functions, annotated with types. We will build a small
# checker that can statically verify that the function implementations
# have the annotated type.
# ------------------------------------------------------------------------
def succ(n: Nat) -> Posint:
    """Example: the successor of a natural number is a positive integer."""
    return n+1
def plus2(n: Even) -> Even:
    """Example: adding 2 to an even integer keeps it even."""
    return n+2
def plus_odd(o1: Odd, o2: Odd) -> Even:
    """Example: the sum of two odd integers is even."""
    return o1 + o2
def my_abs(o: TInt) -> Nat:
    """Example: absolute value of an integer, a natural number."""
    # Written as `0-o` rather than `-o`; keep this form unless the checker's
    # expression translator (ExprToZ3) is known to handle unary minus.
    return (0-o) if o < 0 else o
def sqr(n: TInt) -> Square:
    """Example: n*n is a perfect square (Square's predicate is an existential)."""
    return n*n
# ------------------------------------------------------------------------
# Checker
# ------------------------------------------------------------------------
# Maps each primitive type to the z3 constructor used to create a fresh
# constant of that sort (gen_type_constraint calls it with a gensym name).
# TPrim keys work because TPrim is a frozen, hence hashable, dataclass.
prim_types_interp = {
    TInt: z3.Int,
    TBool: z3.Bool,
    TReal: z3.Real
}
def regular_arguments(fdef: ast.FunctionDef) -> list[ast.arg]:
    """Return the positional parameters of ``fdef`` in declaration order.

    Includes positional-only parameters (those before ``/``) as well as the
    ordinary positional-or-keyword ones. ``*args``, keyword-only parameters
    and ``**kwargs`` are not included.
    """
    # posonlyargs must be included too: otherwise their refinements are never
    # assumed, and any reference to them fails as an unbound variable.
    return fdef.args.posonlyargs + fdef.args.args
def get_annotation(arg: ast.arg) -> ast.Name:
    """Return the annotation expression attached to parameter ``arg``.

    The checker expects an ``ast.Name`` (it reads ``.id``). However, the
    value is ``None`` for an unannotated parameter and can be any expression
    node (e.g. ``ast.Attribute`` for ``a.b``).
    """
    annotation = arg.annotation
    return annotation
def fresh_var():
    """Create a name generator that returns a new, unique name on every call.

    The returned ``gen(prefix=None)`` appends a running counter to ``prefix``,
    or to ``'x'`` when ``prefix`` is falsy. The counter starts at 1 and is
    shared by all prefixes.
    """
    counter = 0

    def gen(prefix=None):
        nonlocal counter
        counter += 1
        stem = prefix if prefix else 'x'
        return stem + str(counter)

    return gen


# Module-wide generator; gen_type_constraint uses it to name z3 constants.
gensym = fresh_var()
def gen_type_constraint(ty: Type) -> Tuple[z3.ExprRef, z3.BoolRef]:
    """Return ``(var, constraint)`` describing membership in ``ty``.

    ``var`` is a fresh z3 constant, named by ``gensym``, of the primitive
    sort underlying ``ty``. ``constraint`` is the conjunction of every
    refinement predicate on the chain from ``ty`` down to that primitive,
    applied to ``var``; it is ``True`` for a bare primitive.

    Raises:
        KeyError: if ``ty`` is a TPrim missing from ``prim_types_interp``.
        Exception: if ``ty`` is neither a TPrim nor a TSub.
    """
    if isinstance(ty, TPrim):
        # A primitive imposes no constraint beyond its sort.
        return (prim_types_interp[ty](gensym()), z3.BoolVal(True))
    elif isinstance(ty, TSub):
        # Build the base type's variable first, then add this layer's predicate.
        exp, cnst = gen_type_constraint(ty.base)
        return (exp, z3.simplify(z3.And(ty.pred(exp), cnst)))
    else:
        raise Exception(f'Invalid input: {ty}')
def get_function_type(f: ast.FunctionDef) -> Tuple[list[Tuple[str, str]], str]:
    """Return the annotated signature of ``f`` as plain names.

    Returns:
        ``(params, ret)``. ``params`` lists ``(parameter name, annotation
        identifier)`` pairs in declaration order, and ``ret`` is the return
        annotation's identifier. These are strings such as ``'Nat'``, not
        resolved Type objects.

    Raises:
        AttributeError: if the return annotation or a parameter annotation is
        missing or is not a bare name.
    """
    ret = f.returns.id
    params = [(arg.arg, get_annotation(arg).id) for arg in regular_arguments(f)]
    return (params, ret)
def get_function_def(f: str) -> ast.FunctionDef:
    """Return the AST of the function denoted by the expression string ``f``.

    ``f`` is evaluated with ``eval`` in this module's globals (e.g. ``'succ'``
    or ``'ast.NodeVisitor.visit'``), so it must come from a trusted source.
    The source is dedented before parsing, so methods and nested functions,
    whose source text is indented, parse as well.
    """
    import textwrap  # local import keeps this block self-contained

    func = eval(f)  # NOTE(review): eval -- never pass untrusted input here
    source = textwrap.dedent(inspect.getsource(func))
    return ast.parse(source).body[0]
class ExprToZ3(ast.NodeVisitor):
    """Translate a side-effect-free Python expression AST into a z3 term.

    Supported nodes:
      * names bound in ``context``;
      * ``bool``/``int``/``float`` literals;
      * unary ``-`` and ``not``;
      * binary ``+ - * /``;
      * comparisons, including chained ones;
      * ``and``/``or`` with any number of operands;
      * ``x if c else y``.

    Every other node raises ``Exception`` instead of being silently
    mistranslated.

    Each ``visit_*`` method stores its result in ``self.expr``. Call
    ``to_z3(node)`` to translate a node and get the result back.
    """

    # Comparison operator node type -> builder of the matching z3 term.
    _COMPARE_OPS = {
        ast.Lt: lambda a, b: a < b,
        ast.LtE: lambda a, b: a <= b,
        ast.Gt: lambda a, b: a > b,
        ast.GtE: lambda a, b: a >= b,
        ast.Eq: lambda a, b: a == b,
        ast.NotEq: lambda a, b: a != b,
    }

    # Arithmetic operator node type -> builder of the matching z3 term.
    # NOTE(review): on two Int terms z3's `/` is integer division, whereas
    # Python's `/` is true division (int / int -> float). `/` on Int-typed
    # values is therefore not modelled faithfully.
    _BIN_OPS = {
        ast.Add: lambda a, b: a + b,
        ast.Sub: lambda a, b: a - b,
        ast.Mult: lambda a, b: a * b,
        ast.Div: lambda a, b: a / b,
    }

    def __init__(self, context: Union[dict[str, z3.ExprRef], None] = None):
        """``context`` maps Python variable names to their z3 terms."""
        # None default: a literal `{}` default would be a single dict shared
        # by every instance created without an explicit context.
        self.context = {} if context is None else context
        self.expr = None

    def to_z3(self, node):
        """Translate ``node`` and return the resulting z3 expression."""
        self.expr = None
        self.visit(node)
        return self.expr

    def generic_visit(self, node):
        # Reached for every node type without a dedicated visit_* method.
        # The inherited NodeVisitor.generic_visit would recurse into the
        # children and leave whatever they produced in self.expr. That would
        # translate `-1` as `1` or `n.attr` as `n`, making the checker unsound.
        raise Exception(f'Unsupported expression: {type(node).__name__}')

    def visit_Name(self, node):
        try:
            self.expr = self.context[node.id]
        except KeyError:
            raise Exception(f'Variable {node.id} not bound') from None

    def visit_BoolOp(self, node):
        # Python folds `a and b and c` into one BoolOp with three values.
        operands = [self.to_z3(value) for value in node.values]
        if isinstance(node.op, ast.And):
            self.expr = z3.And(*operands)
        elif isinstance(node.op, ast.Or):
            self.expr = z3.Or(*operands)
        else:
            raise Exception(f'Unsupported operation: {node.op}')

    def visit_Compare(self, node):
        # `a < b <= c` means `a < b and b <= c`. Operands are pure, so reusing
        # the middle term in two clauses is sound.
        operands = [self.to_z3(node.left)]
        operands += [self.to_z3(comparator) for comparator in node.comparators]
        clauses = []
        for op, lhs, rhs in zip(node.ops, operands, operands[1:]):
            build = self._COMPARE_OPS.get(type(op))
            if build is None:
                raise Exception(f'Unsupported operation: {op}')
            clauses.append(build(lhs, rhs))
        self.expr = clauses[0] if len(clauses) == 1 else z3.And(*clauses)

    def visit_BinOp(self, node):
        lhs = self.to_z3(node.left)
        rhs = self.to_z3(node.right)
        build = self._BIN_OPS.get(type(node.op))
        if build is None:
            raise Exception(f'Unsupported operation: {node.op}')
        self.expr = build(lhs, rhs)

    def visit_UnaryOp(self, node):
        # Needed for negative literals too: ast.parse yields `-1` as
        # UnaryOp(USub, Constant(1)).
        operand = self.to_z3(node.operand)
        if isinstance(node.op, ast.USub):
            self.expr = -operand
        elif isinstance(node.op, ast.Not):
            self.expr = z3.Not(operand)
        else:
            raise Exception(f'Unsupported operation: {node.op}')

    def visit_Constant(self, node):
        v = node.value
        # bool before int: bool is a subclass of int.
        if isinstance(v, bool):
            self.expr = z3.BoolVal(v)
        elif isinstance(v, int):
            self.expr = z3.IntVal(v)
        elif isinstance(v, float):
            self.expr = z3.RealVal(v)
        else:
            raise Exception("Unsupported constant")

    def visit_IfExp(self, node):
        test = self.to_z3(node.test)
        body = self.to_z3(node.body)
        orelse = self.to_z3(node.orelse)
        self.expr = z3.If(test, body, orelse)
class Constraints(ast.NodeVisitor):
    """Build the type-correctness condition (TCC) of an annotated function.

    Visit the module that ``ast.parse`` returns for a single function's source.
    Afterwards ``self.tcc`` holds a z3 formula. If that formula is valid, every
    ``return`` expression satisfies the declared return refinement whenever
    the arguments satisfy theirs.

    The analysis is conservative:
      * ``if`` conditions are not used as path conditions, so some correct
        functions are rejected;
      * statements the checker cannot model raise ``Exception`` instead of
        being skipped.
    """

    def __init__(self):
        # Python parameter name -> fresh z3 constant standing for its value.
        self.context: dict[str, z3.ExprRef] = {}
        # Formula to prove valid: the precondition until a return is seen,
        # then the conjunction of all per-return obligations.
        self.tcc: z3.BoolRef = z3.BoolVal(True)
        # Type the return value must satisfy; set by visit_FunctionDef.
        self.return_annot: Type = None
        # Conjunction of the argument refinements (the precondition).
        self.pre: z3.BoolRef = z3.BoolVal(True)
        # One `pre => post` obligation per return statement seen so far.
        self.obligations: list[z3.BoolRef] = []

    def visit_Return(self, node):
        """Record the obligation that this return value meets return_annot."""
        if node.value is None:
            raise Exception('Bare return is not supported')
        ret = ExprToZ3(self.context).to_z3(node.value)
        if self.return_annot is None:
            raise Exception('Unexpected')
        v, cnst = gen_type_constraint(self.return_annot)
        # Put the returned term in place of the fresh variable v.
        post = z3.substitute(cnst, (v, ret))
        self.obligations.append(z3.Implies(z3.simplify(self.pre), post))
        # Every return must satisfy the postcondition on its own. Nesting the
        # obligations as (pre => post1) => post2 would be vacuously true
        # whenever an earlier return violates its postcondition.
        if len(self.obligations) == 1:
            self.tcc = self.obligations[0]
        else:
            self.tcc = z3.And(*self.obligations)
        print(f'TCC: {self.tcc}')

    def visit_FunctionDef(self, node):
        """Bind each parameter to a fresh z3 term, then check the body."""
        for arg in regular_arguments(node):
            # eval of the annotation identifier resolves the name in this
            # module's globals/builtins, e.g. 'Nat' -> the Nat TSub.
            z3_var, cnsts = gen_type_constraint(eval(get_annotation(arg).id))
            self.context[arg.arg] = z3_var
            self.pre = z3.And(cnsts, self.pre)
        if not self.obligations:
            # Until a return is recorded, the TCC is the precondition itself.
            self.tcc = self.pre
        self.return_annot = eval(node.returns.id)
        self._visit_body(node.body)

    def _visit_body(self, stmts):
        """Check a statement list, accepting only constructs modelled soundly."""
        for stmt in stmts:
            if any(isinstance(sub, ast.NamedExpr) for sub in ast.walk(stmt)):
                # `:=` would rebind a variable behind the checker's back.
                raise Exception('Assignment expressions are not supported')
            if isinstance(stmt, ast.Return):
                self.visit_Return(stmt)
            elif isinstance(stmt, ast.If):
                # Branch conditions are not tracked: every return is checked
                # under the argument refinements alone.
                self._visit_body(stmt.body)
                self._visit_body(stmt.orelse)
            elif isinstance(stmt, ast.Pass) or (
                isinstance(stmt, ast.Expr) and isinstance(stmt.value, ast.Constant)
            ):
                continue  # no-op statements, e.g. a docstring
            else:
                # Assignments, loops, nested defs, etc. could rebind a
                # parameter or replace return_annot, so refuse them.
                raise Exception(f'Unsupported statement: {type(stmt).__name__}')
def print_ast_repr(t):
    """Debugging helper: pretty-print the structure of AST node ``t``."""
    rendered = ast.dump(t, indent=2)
    print(rendered)
def prove(f):
    """Check the validity of the z3 formula ``f``.

    Validity is checked by asking whether the negation of ``f`` is
    satisfiable. Returns the solver verdict: ``z3.unsat`` means ``f`` is
    valid; any other result means it was not proved.
    """
    solver = z3.Solver()
    solver.add(z3.Not(f))
    verdict = solver.check()
    return verdict
def check_fun(f):
    """Statically check function ``f`` against its refinement-type annotations.

    Prints the TCC(s) generated while visiting ``f``, followed by 'ok' if z3
    proves the TCC valid. Otherwise prints 'Failed to verify'; this covers
    both a counterexample and an inconclusive solver result.

    Returns:
        True if the TCC was proved valid, False otherwise.
    """
    import textwrap  # local import keeps this block self-contained

    c = Constraints()
    # Dedent so functions defined in an indented scope (e.g. nested inside
    # another function) still parse.
    source = textwrap.dedent(inspect.getsource(f))
    c.visit(ast.parse(source))
    verified = prove(c.tcc) == z3.unsat
    print('ok' if verified else 'Failed to verify')
    return verified
# ------------------------------------------------------------------------
# Testing
# ------------------------------------------------------------------------
# Smoke tests. Each check_fun call prints the TCC(s), then a verdict.
# These five are expected to print 'ok'.
for _expected_ok in (succ, plus2, plus_odd, my_abs, sqr):
    check_fun(_expected_ok)
del _expected_ok


def incorrect(n: TInt) -> Posnat:
    """Deliberately wrong: n - 1 is not positive when n <= 1."""
    return n-1


check_fun(incorrect)  # expected to print 'Failed to verify'
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment