Last active
January 12, 2026 19:05
-
-
Save Jacob-Stevens-Haas/754b3c2ff388c34257bb89b3e114ba2e to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """Utilities to convert a `dysts` dynamical system object's rhs to SymPy. | |
| This module inspects the source of an object's RHS method (by default | |
| named ``rhs``), parses the function using ``ast``, and converts the | |
| returned expression(s) into SymPy expressions. | |
| The conversion is intentionally conservative and aims to handle common | |
| patterns used in simple rhs implementations, e.g. returning a tuple/list | |
| of arithmetic expressions, using indexing into a state vector (``x[0]``), | |
| and calls to common ``numpy``/``math`` functions (``np.sin``, ``math.exp``, ...). | |
| Limitations: | |
| - It does not execute arbitrary code from the inspected function. | |
| - Complex control flow, loops, or non-trivial Python constructs may not | |
| be fully supported. | |
| Example | |
| ------- | |
| from dysts.flows import Lorenz | |
| from inspect_to_sympy import object_to_sympy_rhs | |
| lor = Lorenz() | |
| symbols, exprs, lambda_rhs = object_to_sympy_rhs(lor) | |
| # `symbols` is a list of SymPy symbols for the state vector | |
| # `exprs` is a list of SymPy expressions for the RHS | |
| # `lambda_rhs` is a SymPy Lambda mapping state symbols -> rhs expressions | |
| """ | |
| from __future__ import annotations | |
| import ast | |
| import inspect | |
| import textwrap | |
| from typing import Any | |
| from typing import Callable | |
| from typing import Dict | |
| from typing import List | |
| from typing import Tuple | |
| import numpy as np | |
| import sympy as sp | |
| from dysts.base import BaseDyn | |
| def _is_name(node: ast.AST, name: str) -> bool: | |
| return isinstance(node, ast.Name) and node.id == name | |
| class _ASTToSympy(ast.NodeVisitor): | |
| def __init__( | |
| self, | |
| state_name: str, | |
| state_symbols: List[sp.Symbol], | |
| locals_map: Dict[str, Any], | |
| ): | |
| self.state_name = state_name | |
| self.state_symbols = state_symbols | |
| self.locals = dict(locals_map) | |
| def generic_visit(self, node): | |
| raise NotImplementedError(f"AST node not supported: {node!r}") | |
| def visit_Constant(self, node: ast.Constant): | |
| return sp.sympify(node.value) | |
| def visit_Num(self, node: ast.Num): | |
| return sp.sympify(node.n) | |
| def visit_Name(self, node: ast.Name): | |
| if node.id in self.locals: | |
| return self.locals[node.id] | |
| return sp.Symbol(node.id) | |
| def visit_Tuple(self, node: ast.Tuple): | |
| elems = [] | |
| for elt in node.elts: | |
| val = self.visit(elt) | |
| if isinstance(val, (list, tuple)): | |
| elems.extend(list(val)) | |
| else: | |
| elems.append(val) | |
| return tuple(elems) | |
| def visit_List(self, node: ast.List): | |
| elems = [] | |
| for elt in node.elts: | |
| val = self.visit(elt) | |
| if isinstance(val, (list, tuple)): | |
| elems.extend(list(val)) | |
| else: | |
| elems.append(val) | |
| return elems | |
| def visit_Starred(self, node: ast.Starred): | |
| # Handle starred expressions like `*x` in list/tuple literals. | |
| # If the starred value is the state vector name, expand to state symbols. | |
| if isinstance(node.value, ast.Name) and node.value.id == self.state_name: | |
| return tuple(self.state_symbols) | |
| # Otherwise, evaluate the value and if it is a sequence, return its items | |
| val = self.visit(node.value) | |
| if isinstance(val, (list, tuple)): | |
| return tuple(val) | |
| raise NotImplementedError( | |
| "Unsupported starred expression; cannot expand non-iterable" | |
| ) | |
| def visit_BinOp(self, node: ast.BinOp): | |
| left = self.visit(node.left) | |
| right = self.visit(node.right) | |
| if isinstance(node.op, ast.Add): | |
| return left + right | |
| if isinstance(node.op, ast.Sub): | |
| return left - right | |
| if isinstance(node.op, ast.Mult): | |
| return left * right | |
| if isinstance(node.op, ast.Div): | |
| return left / right | |
| if isinstance(node.op, ast.Pow): | |
| return left**right | |
| if isinstance(node.op, ast.Mod): | |
| return left % right | |
| raise NotImplementedError(f"Binary op not supported: {node.op!r}") | |
| def visit_UnaryOp(self, node: ast.UnaryOp): | |
| operand = self.visit(node.operand) | |
| if isinstance(node.op, ast.USub): | |
| return -operand | |
| if isinstance(node.op, ast.UAdd): | |
| return +operand | |
| raise NotImplementedError(f"Unary op not supported: {node.op!r}") | |
| def visit_Call(self, node: ast.Call): | |
| # Determine function name | |
| func = node.func | |
| func_name = None | |
| mod_name = None | |
| if isinstance(func, ast.Name): | |
| func_name = func.id | |
| elif isinstance(func, ast.Attribute): | |
| # e.g. np.sin or math.exp | |
| if isinstance(func.value, ast.Name): | |
| mod_name = func.value.id | |
| func_name = func.attr | |
| else: | |
| raise NotImplementedError( | |
| f"Call to unsupported func node: {ast.dump(func)}" | |
| ) | |
| # Map common numpy/math functions to sympy | |
| func_map = { | |
| "sin": sp.sin, | |
| "cos": sp.cos, | |
| "tan": sp.tan, | |
| "exp": sp.exp, | |
| "log": sp.log, | |
| "sqrt": sp.sqrt, | |
| "abs": sp.Abs, | |
| "atan": sp.atan, | |
| "asin": sp.asin, | |
| "acos": sp.acos, | |
| } | |
| args = [self.visit(a) for a in node.args] | |
| # Special-case array constructors: return underlying list/tuple | |
| if func_name in ("array", "asarray") and mod_name in ("np", "numpy"): | |
| # expect a single positional arg that's a list/tuple | |
| if len(args) == 1: | |
| return args[0] | |
| if func_name in func_map: | |
| return func_map[func_name](*args) | |
| # Unknown function: create a Sympy Function | |
| symf = sp.Function(func_name) | |
| return symf(*args) | |
| def visit_Subscript(self, node: ast.Subscript): | |
| # Support patterns like x[0] where x is the state vector name | |
| value = node.value | |
| # handle simple constant index | |
| if _is_name(value, self.state_name): | |
| # Python >=3.9: slice is directly the node.slice | |
| idx_node = node.slice | |
| if isinstance(idx_node, ast.Constant): | |
| idx = idx_node.value | |
| else: | |
| raise NotImplementedError( | |
| "Only constant indices into state vector supported" | |
| ) | |
| return self.state_symbols[idx] | |
| # If it's something else, try to evaluate generically | |
| base = self.visit(value) | |
| # slice may be constant | |
| if isinstance(node.slice, ast.Constant): | |
| key = node.slice.value | |
| return base[key] | |
| raise NotImplementedError("Unsupported subscript pattern") | |
| def _numeric_consistency_check( | |
| dysts_flow: BaseDyn, | |
| rhsfunc: Callable, | |
| arg_names: List[str], | |
| state_names: List[str], | |
| vector_mode: bool, | |
| sys_dim: int, | |
| lambda_rhs: sp.Lambda, | |
| ) -> None: | |
| """Compare the original dysts rhs function to the SymPy-derived lambda. | |
| Raises a RuntimeError if they disagree. | |
| """ | |
| # default to nonnegative support (e.g. Lotka volterra) | |
| random_state = np.random.standard_exponential(size=sys_dim) | |
| # Construct call arguments for the original function (bound method). | |
| call_args = [] | |
| for name in arg_names: | |
| if name == "self": | |
| continue | |
| if name in state_names and not vector_mode: | |
| idx = state_names.index(name) | |
| call_args.append(random_state[idx]) | |
| elif name in state_names and vector_mode: | |
| call_args.append(np.asarray(random_state, dtype=float)) | |
| elif name == "t": | |
| call_args.append(float(np.random.standard_normal(size=()))) | |
| else: | |
| call_args.append(dysts_flow.params[name]) | |
| dysts_val = rhsfunc(*call_args) | |
| orig_arr = np.asarray(dysts_val, dtype=float).ravel() | |
| sym_val = lambda_rhs(*tuple(random_state)) | |
| sym_arr = np.asarray(sym_val, dtype=float).ravel() | |
| if orig_arr.shape != sym_arr.shape: | |
| raise RuntimeError( | |
| f"_rhs shape {orig_arr.shape} != sympy shape {sym_arr.shape}" | |
| ) | |
| if not np.allclose(orig_arr, sym_arr, rtol=1e-6, atol=1e-9): | |
| raise RuntimeError("Numeric mismatch between original and sympy conversion.") | |
| def dynsys_to_sympy( | |
| obj: Any, func_name: str = "_rhs" | |
| ) -> Tuple[List[sp.Symbol], List[sp.Expr], sp.Lambda]: | |
| """Inspect ``obj`` for a method named ``func_name`` and return a SymPy | |
| representation of its RHS. | |
| Returns: | |
| a tuple ``(state_symbols, exprs, lambda_rhs)`` where ``state_symbols`` | |
| is a list of SymPy symbols for the state vector, ``exprs`` is a list of | |
| SymPy expressions for the RHS components, and ``lambda_rhs`` is a SymPy | |
| Lambda mapping the state symbols to the RHS vector. | |
| Example: | |
| >>> from dysts.flows import Lorenz | |
| >>> from inspect_to_sympy import dynsys_to_sympy | |
| >>> lor = Lorenz() | |
| >>> symbols, exprs, lambda_rhs = dynsys_to_sympy(lor) | |
| >>> print(lor._rhs(1, 2, 3, t=0.0, **lor.params)) | |
| (10, 23, -6.0009999999999994) | |
| >>> print(tuple(lambda_rhs(1, 2, 3))) | |
| (10, 23, -6.00100000000000) | |
| """ | |
| if not hasattr(obj, func_name): | |
| raise AttributeError(f"Object has no attribute {func_name!r}") | |
| func = getattr(obj, func_name) | |
| src = inspect.getsource(func) | |
| src = textwrap.dedent(src) | |
| parsed = ast.parse(src) | |
| # Find first FunctionDef | |
| fndef = None | |
| for node in parsed.body: | |
| if isinstance(node, ast.FunctionDef): | |
| fndef = node | |
| break | |
| if fndef is None: | |
| raise RuntimeError("No function definition found in source") | |
| # Determine state argument names. Common dysts signature: | |
| # (self, *states, t, *parameters). Prefer obj.dimension when available. | |
| arg_names = [a.arg for a in fndef.args.args] | |
| if len(arg_names) == 0: | |
| raise RuntimeError("Function has no arguments") | |
| start_idx = 0 | |
| if arg_names[0] == "self": | |
| start_idx = 1 | |
| vector_mode = False | |
| state_args: List[str] | |
| t_idx = None | |
| if "t" in arg_names: | |
| t_idx = arg_names.index("t") | |
| if hasattr(obj, "dimension") and isinstance(getattr(obj, "dimension"), int): | |
| n_state = int(getattr(obj, "dimension")) | |
| if t_idx is not None: | |
| potential = arg_names[start_idx:t_idx] | |
| if len(potential) >= n_state: | |
| state_args = potential[:n_state] | |
| else: | |
| state_args = [arg_names[start_idx]] | |
| vector_mode = True | |
| else: | |
| potential = arg_names[start_idx:] | |
| if len(potential) >= n_state: | |
| state_args = potential[:n_state] | |
| else: | |
| state_args = [arg_names[start_idx]] | |
| vector_mode = True | |
| else: | |
| if t_idx is not None: | |
| state_args = arg_names[start_idx:t_idx] | |
| if len(state_args) == 0: | |
| state_args = [arg_names[start_idx]] | |
| vector_mode = True | |
| elif len(state_args) == 1: | |
| # single name could be vector or scalar; assume vector-mode | |
| vector_mode = True | |
| else: | |
| state_args = [arg_names[start_idx]] | |
| vector_mode = True | |
| # If vector_mode, inspect AST for subscript/index usage or tuple unpacking | |
| if vector_mode: | |
| state_name = state_args[0] | |
| max_index = -1 | |
| unpack_size = None | |
| for node in ast.walk(fndef): | |
| if ( | |
| isinstance(node, ast.Subscript) | |
| and isinstance(node.value, ast.Name) | |
| and node.value.id == state_name | |
| ): | |
| sl = node.slice | |
| if isinstance(sl, ast.Constant) and isinstance(sl.value, int): | |
| if sl.value > max_index: | |
| max_index = sl.value | |
| if isinstance(node, ast.Assign): | |
| if isinstance(node.value, ast.Name) and node.value.id == state_name: | |
| targets = node.targets | |
| if len(targets) == 1 and isinstance( | |
| targets[0], (ast.Tuple, ast.List) | |
| ): | |
| unpack_size = len(targets[0].elts) | |
| if unpack_size is not None: | |
| n_state = unpack_size | |
| elif max_index >= 0: | |
| n_state = max_index + 1 | |
| else: | |
| n_state = int(getattr(obj, "dimension", 3)) | |
| state_symbols = [sp.Symbol(f"x{i}") for i in range(n_state)] | |
| primary_state_name = state_name | |
| else: | |
| # individual state args -> use their arg names as symbol names | |
| state_symbols = [sp.Symbol(n) for n in state_args] | |
| primary_state_name = state_args[0] if len(state_args) > 0 else "x" | |
| # Build locals mapping from known state arg names and parameters | |
| locals_map: Dict[str, Any] = {} | |
| for i, name in enumerate(state_args): | |
| if i < len(state_symbols): | |
| locals_map[name] = state_symbols[i] | |
| # map parameters (if present) to numeric values or symbols | |
| if hasattr(obj, "parameters") and isinstance(getattr(obj, "parameters"), dict): | |
| params = getattr(obj, "parameters") | |
| if t_idx is not None: | |
| param_arg_names = arg_names[t_idx + 1 :] | |
| else: | |
| param_arg_names = [] | |
| for pname in param_arg_names: | |
| if pname in params: | |
| locals_map[pname] = sp.sympify(params[pname]) | |
| else: | |
| locals_map[pname] = sp.Symbol(pname) | |
| converter = _ASTToSympy(primary_state_name, state_symbols, locals_map) | |
| return_expr = None | |
| # Walk through function body statements, handle Assign and Return | |
| for stmt in fndef.body: | |
| if isinstance(stmt, ast.Assign): | |
| # only simple single-target assignments supported | |
| if len(stmt.targets) != 1: | |
| raise ValueError("Only single-target assignments supported") | |
| target = stmt.targets[0] | |
| if isinstance(target, ast.Name): | |
| value_expr = converter.visit(stmt.value) | |
| locals_map[target.id] = value_expr | |
| elif ( | |
| isinstance(target, (ast.Tuple, ast.List)) | |
| and isinstance(stmt.value, ast.Name) | |
| and stmt.value.id == state_name | |
| ): | |
| # unpacking like a,b,c = x -> map names to state symbols | |
| for i, elt in enumerate(target.elts): | |
| if isinstance(elt, ast.Name): | |
| locals_map[elt.id] = state_symbols[i] | |
| elif isinstance(stmt, ast.Return): | |
| return_expr = stmt.value | |
| if return_expr is None: | |
| # maybe last statement is an Expr with list construction; | |
| # try to find a Return node deep | |
| for node in ast.walk(fndef): | |
| if isinstance(node, ast.Return): | |
| return_expr = node.value | |
| break | |
| if return_expr is None: | |
| raise RuntimeError("No return expression found in function body") | |
| # Refresh converter with updated locals | |
| converter = _ASTToSympy(primary_state_name, state_symbols, locals_map) | |
| rhs_val = converter.visit(return_expr) | |
| if isinstance(rhs_val, (list, tuple)): | |
| exprs = list(rhs_val) | |
| else: | |
| # single expression: treat as 1-dim RHS | |
| exprs = [rhs_val] | |
| lambda_rhs = sp.Lambda(tuple(state_symbols), sp.Matrix(exprs)) | |
| # Run numeric consistency guard (raises on mismatch) | |
| _numeric_consistency_check( | |
| obj, | |
| func, | |
| arg_names, | |
| state_args, | |
| vector_mode, | |
| len(state_symbols), | |
| lambda_rhs, | |
| ) | |
| return state_symbols, exprs, lambda_rhs | |
| __all__ = ["dynsys_to_sympy"] |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment