Last active
April 21, 2025 22:34
-
-
Save CallumJHays/925890540e7a7515bbc2880e5438583b to your computer and use it in GitHub Desktop.
SymbolicMatrixFunction class for sympy that addresses https://github.com/sympy/sympy/issues/23221
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from typing import Any, Collection, Dict | |
| from sympy import Expr, MatrixExpr, Basic, MatrixSymbol, sin, Matrix, Tuple, Function, solve, ImmutableDenseMatrix, Symbol, symbols, StrPrinter, Eq, randMatrix, cos, MatMul, Derivative, Integral, ShapeError | |
| from sympy.core import cacheit | |
| from sympy.core.sympify import _sympify | |
| from sympy.core.symbol import Str | |
| import unittest | |
# Public API: the names exported by ``from <this module> import *``.
__all__ = [
    "SymbolicMatrixFunction",
    "solve_with_sym_matrices",
]
class SymbolicMatrixFunction(MatrixSymbol):
    """A matrix symbol whose entries are undefined functions of given expressions.

    ``SymbolicMatrixFunction("A", 3, 3, [t])`` behaves like ``MatrixSymbol("A", 3, 3)``
    but prints as ``A(t)``, and each entry ``A[i, j]`` is the applied undefined
    function ``A[i, j](t)``. Works around https://github.com/sympy/sympy/issues/23221.

    The argument expressions are stored as ``args[3]`` (a SymPy ``Tuple``), so any
    rebuild from ``args`` keeps them. That includes ``func(*args)`` in ``subs`` /
    ``xreplace``, and ``copy`` / ``pickle``, which call ``__new__`` but not ``__init__``.
    """

    def __new__(cls, name: "Str | str", n: int, m: int, function_of: Collection[Expr]):
        """Create the matrix.

        Args:
            name: Matrix name, as ``str`` or SymPy ``Str``.
            n: Number of rows (checked by the inherited ``_check_dim``).
            m: Number of columns (checked by the inherited ``_check_dim``).
            function_of: Expressions the entries depend on. It is iterated exactly
                once, so generators are accepted. It must be non-empty.

        Raises:
            ValueError: if ``function_of`` yields no expressions.
        """
        n, m = _sympify(n), _sympify(m)
        cls._check_dim(m)
        cls._check_dim(n)
        if isinstance(name, str):
            name = Str(name)
        # Tuple sympifies each element and keeps the object hashable.
        arguments = Tuple(*function_of)
        if not arguments:
            raise ValueError("SymbolicMatrixFunction must be a function of at least 1 expression")
        return Basic.__new__(cls, name, n, m, arguments)

    @property
    def function_of(self) -> Tuple:
        """The expressions this matrix is a function of, read from ``args``."""
        return self.args[3]

    def _sympystr(self, printer: StrPrinter) -> str:
        # Print the arguments through the active printer so its settings apply.
        rendered = ", ".join(printer._print(arg) for arg in self.function_of)
        return f"{self.name}({rendered})"

    @property
    def free_symbols(self):
        """Free symbols of the argument expressions, as for an applied function.

        For example, both ``A(t**2)`` and ``A(theta(t))`` have free symbols ``{t}``.
        """
        return set().union(*(arg.free_symbols for arg in self.function_of))

    def diff(self, *args, **kwargs):
        """Return an unevaluated ``Derivative`` of this matrix."""
        return Derivative(self, *args, **kwargs)

    def _entry(self, i, j, **kwargs) -> Function:
        # Accept (and ignore) extra keywords like the base ``_entry`` does.
        # Some SymPy callers, e.g. entry expansion of a matrix product, pass them.
        return Function(f"{self.name}[{i}, {j}]")(*self.function_of)  # type: ignore

    def integrate(self, *wrt: Symbol, **kwargs):
        """Return an unevaluated ``Integral`` of this matrix."""
        return Integral(self, *wrt, **kwargs)

    def __mul__(self, other):
        # Multiplying by the scalar 1 is a no-op. TODO: Identity matrix as well?
        if other == 1:
            return self
        return MatMul(self, other)

    def __rmul__(self, other):
        # Multiplying by the scalar 1 is a no-op. TODO: Identity matrix as well?
        if other == 1:
            return self
        return MatMul(other, self)

    # ``@`` and ``*`` both mean matrix multiplication for this class.
    __matmul__ = __mul__
    __rmatmul__ = __rmul__

    # NOTE: overriding ``has`` (also matching when every pattern is one of
    # ``function_of``) was tried for Integral support and left disabled.
def solve_with_sym_matrices(equations: Collection[Eq], *solve_for: Expr, explicit: bool = True):
    """Solve equations containing matrix symbols, then evaluate ``solve_for``.

    ``solve()`` does not yet handle ``MatrixSymbol`` properly. This helper works
    around that by substituting every matrix found in ``equations`` before solving:

    * ``explicit=True``: each ``MatrixSymbol`` (including ``SymbolicMatrixFunction``)
      is replaced by its ``as_explicit()`` matrix, and the matrix entries are the
      unknowns.
    * ``explicit=False``: each ``MatrixExpr`` is replaced by a non-commutative
      ``Symbol`` named ``x0``, ``x1``, ..., and those symbols are the unknowns.

    ``explicit=False`` may solve faster, but it cannot solve problems in which
    matrix internals are referenced in intermediate steps. (This trade-off is the
    original author's understanding and may not be exact.)

    NOTE(review): with ``explicit=False``, concrete matrices in the equations are
    presumably matched by ``atoms(MatrixExpr)`` too. If so, they are replaced by
    placeholders, and results may contain leftover ``x<n>`` symbols. Confirm.

    Args:
        equations: Equations to solve. Iterated once, so generators are accepted.
        *solve_for: Expressions to evaluate using the solution.
        explicit: Selects the substitution strategy described above.

    Returns:
        A tuple with one item per ``solve_for`` expression. Each item is the
        expression with matrices substituted, the solution applied, and ``doit()``
        called. Only a single solution mapping is supported (TODO: multiple
        solutions).
    """
    # Materialize once: the equations are traversed twice below.
    equations = list(equations)

    # Each matrix found in the equations, mapped to its stand-in.
    subs_map: Dict[MatrixExpr, "Symbol | ImmutableDenseMatrix"] = {}
    for equation in equations:
        for mat in equation.atoms(MatrixSymbol if explicit else MatrixExpr):
            if mat not in subs_map:
                subs_map[mat] = mat.as_explicit() if explicit \
                    else Symbol(f"x{len(subs_map)}", commutative=False)

    # solve() only unpacks a lone iterable of symbols. With several explicit
    # matrices, passing them as separate arguments would hand whole matrices
    # over as "symbols", so flatten them into a single list of unknowns.
    unknowns = []
    for stand_in in subs_map.values():
        if explicit:
            unknowns.extend(stand_in)  # iterating a Matrix yields its entries
        else:
            unknowns.append(stand_in)

    substituted = [equation.subs(subs_map) for equation in equations]
    # With no matrices at all, let solve() pick the unknowns itself (as before).
    sln = solve(substituted, unknowns) if unknowns else solve(substituted)
    return tuple(expr.subs(subs_map).subs(sln).doit() for expr in solve_for)
class TestSymbolicMatrixFunction(unittest.TestCase):
    """Tests for SymbolicMatrixFunction and solve_with_sym_matrices.

    Uses ``self.assert*`` instead of bare ``assert``, so checks still run under
    ``python -O`` and failures report both operands.
    """

    def _assert_str_repr(self, x: Any, expected: str):
        """Check that both ``str(x)`` and ``repr(x)`` render as ``expected``."""
        self.assertEqual(str(x), expected)
        self.assertEqual(repr(x), expected)

    def test_init(self):
        t = Symbol("t")
        A = SymbolicMatrixFunction("A", 3, 3, {t})
        # Construction keeps the name and the dimensions.
        self.assertEqual(A.name, "A")
        self.assertEqual(A.shape, (3, 3))

    def test_repr(self):
        t = Symbol("t")
        A = SymbolicMatrixFunction("A", 3, 3, {t})
        self._assert_str_repr(A, "A(t)")

    def test_diff(self):
        t = Symbol("t")
        A = SymbolicMatrixFunction("A", 3, 3, {t})
        dA = A.diff(t)
        self._assert_str_repr(dA, "Derivative(A(t), t)")

    def test_integral(self):
        t = Symbol("t")
        A = SymbolicMatrixFunction("A", 3, 3, {t})
        iA = A.integrate(t)
        self._assert_str_repr(iA, "Integral(A(t), t)")

    def test_add(self):
        t = Symbol("t")
        A = SymbolicMatrixFunction("A", 3, 3, {t})
        B = SymbolicMatrixFunction("B", 3, 3, {t})
        C = A + B
        self._assert_str_repr(C, "A(t) + B(t)")

    def test_sub(self):
        t = Symbol("t")
        A = SymbolicMatrixFunction("A", 3, 3, {t})
        B = SymbolicMatrixFunction("B", 3, 3, {t})
        C = A - B
        self._assert_str_repr(C, "A(t) - B(t)")

    @unittest.skip("disabled pending confirmation that mismatched shapes raise ShapeError")
    def test_addsub_wrongshape_fails(self):
        t = Symbol("t")
        A = SymbolicMatrixFunction("A", 3, 3, {t})
        B = SymbolicMatrixFunction("B", 3, 2, {t})
        with self.assertRaises(ShapeError):
            A + B
        with self.assertRaises(ShapeError):
            A - B

    def test_solve(self):
        t = Symbol("t")
        A = SymbolicMatrixFunction("A", 3, 3, {t})
        B = randMatrix(3)
        slnB, = solve_with_sym_matrices([Eq(A, B)], A)
        self.assertEqual(slnB, B)

    def test_elements_are_functions(self):
        t = Symbol("t")
        A = SymbolicMatrixFunction("A", 3, 3, {t})
        x = A[1, 2]
        self.assertIsInstance(x, Function)
        self.assertEqual(x.free_symbols, {t})
        self._assert_str_repr(x, "A[1, 2](t)")

    def test_diff_multivariate(self):
        x, y, z = symbols("x,y,z")  # type: ignore
        A = SymbolicMatrixFunction("A", 3, 3, [x, y, z])
        dA = A.diff(x, y)
        self._assert_str_repr(dA, "Derivative(A(x, y, z), x, y)")

    def test_integral_multivariate(self):
        x, y, z = symbols("x,y,z")  # type: ignore
        A = SymbolicMatrixFunction("A", 3, 3, [x, y, z])
        dA = A.integrate(x, y)
        self._assert_str_repr(dA, "Integral(A(x, y, z), x, y)")

    def test_solve_2d_rotation_diff_undef_theta(self):
        t = Symbol("t")
        theta: Function = Function("theta")(t)  # type: ignore
        R = SymbolicMatrixFunction("R", 2, 2, {theta})
        R_def_eqn: Eq = Eq(R, Matrix([
            [cos(theta), -sin(theta)],  # type: ignore
            [sin(theta), cos(theta)],
        ]))
        sln_Rdiff_wrt_theta, sln_Rdiff_wrt_t = solve_with_sym_matrices(
            [R_def_eqn], R.diff(theta), R.diff(t)
        )
        sln_Rdiff_wrt_theta_expected = Matrix([
            [-sin(theta), -cos(theta)],  # type: ignore
            [cos(theta), -sin(theta)],  # type: ignore
        ])
        self.assertEqual(sln_Rdiff_wrt_theta, sln_Rdiff_wrt_theta_expected)
        # Chain rule: dR/dt = dtheta/dt * dR/dtheta.
        self.assertEqual(sln_Rdiff_wrt_t, theta.diff() * sln_Rdiff_wrt_theta_expected)

    def test_solve_2d_rotation_diff_def_theta(self):
        t = Symbol("t")
        theta: Expr = t ** 2
        R = SymbolicMatrixFunction("R", 2, 2, {theta})
        R_def_eqn: Eq = Eq(R, Matrix([
            [cos(theta), -sin(theta)],  # type: ignore
            [sin(theta), cos(theta)],
        ]))
        sln_Rdiff, = solve_with_sym_matrices([R_def_eqn], R.diff(t))
        # Chain rule: dR/dt = dtheta/dt * dR/dtheta.
        sln_Rdiff_wrt_theta_expected = theta.diff() * Matrix([
            [-sin(theta), -cos(theta)],  # type: ignore
            [cos(theta), -sin(theta)],  # type: ignore
        ])
        self.assertEqual(sln_Rdiff, sln_Rdiff_wrt_theta_expected)
def _run_tests() -> None:
    """Run this module's unittest suite when the file is executed as a script."""
    unittest.main()


if __name__ == "__main__":
    _run_tests()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment