Last active
July 4, 2025 14:13
-
-
Save ziap/b32eed4d7d453337402994719125d2aa to your computer and use it in GitHub Desktop.
Robust, analytic eigensolver for symmetric 3x3 matrices, faster than np.linalg.eigh. Based on <https://www.geometrictools.com/Documentation/RobustEigenSymmetric3x3.pdf>, with some modifications.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """ | |
| eigh6.py — Fast symmetric 3×3 eigen‐decomposition with Numba | |
| This module provides a highly optimized, stand-alone (no NumPy) implementation | |
| of the eigenvalue decomposition for real symmetric 3×3 matrices. It returns the | |
| three real eigenvalues in ascending order along with an orthonormal basis of | |
| eigenvectors. The implementation is designed for use with Numba’s @njit | |
| decorator and features: | |
| - Analytic solution for eigenvalues via closed‐form formulas (trigonometric | |
| method on the characteristic cubic). | |
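| - Direct construction of eigenvectors using the cross‐product "nullspace" trick | |
| for the first extreme eigenvalue and a 2×2 subproblem for the other extreme | |
| eigenvalue; the middle eigenvector is the cross product of those two. | |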
| - One‐step Rayleigh quotient refinement to boost accuracy. | |
| - Special‐case handling for nearly diagonal or nearly zero matrices. | |
| - No external dependencies beyond Python stdlib and Numba. | |
| Usage | |
| ----- | |
| >>> from eigh6 import eigh6 | |
| >>> A = (2.0, -1.0, 0.0, | |
| ... 2.0, 3.0, 1.0) | |
| >>> eigenvalues, eigenvectors = eigh6(A) | |
| >>> print("Eigenvalues:", eigenvalues) | |
| >>> for v in eigenvectors: | |
| ... print("Eigenvector:", v) | |
| Further reading | |
| --------------- | |
| - Eberly D. 2014. *A Robust Eigensolver for 3 × 3 Symmetric Matrices*. | |
| Geometric Tools, Redmond, WA, USA (December 2014), 11-18. | |
| https://www.geometrictools.com/Documentation/RobustEigenSymmetric3x3.pdf | |
| - Oliver K. Smith. 1961. *Eigenvalues of a symmetric 3 × 3 matrix*. | |
| Commun. ACM 4, 4 (April 1961), 168. https://doi.org/10.1145/355578.366316 | |
| """ | |
| import math | |
| import sys | |
| from numba import njit | |
type Vector = tuple[float, float, float]
"""
A length-3 tuple representing a 3D vector.
"""
type SymMat3 = tuple[float, float, float, float, float, float]
"""
The six unique entries of a symmetric 3×3 matrix in row‐major order:
⎡a00 a01 a02⎤
⎢ ⎥
⎢a01 a11 a12⎥ → (a00, a01, a02, a11, a12, a22)
⎢ ⎥
⎣a02 a12 a22⎦
"""
# Machine epsilon of Python floats; the module compares pivots, projected
# matrix entries and (scaled) matrix norms against it to decide whether a
# quantity is numerically zero.
_eps = sys.float_info.epsilon
| @njit | |
def _dot(v1: Vector, v2: Vector) -> float:
    """Return the dot product of the 3-vectors ``v1`` and ``v2``."""
    return v1[0]*v2[0] + v1[1]*v2[1] + v1[2]*v2[2]
| @njit | |
def _cross(v1: Vector, v2: Vector) -> Vector:
    """Return the cross product ``v1 × v2`` of two 3-vectors."""
    ax, ay, az = v1
    bx, by, bz = v2
    cx = ay*bz - az*by
    cy = az*bx - ax*bz
    cz = ax*by - ay*bx
    return (cx, cy, cz)
| @njit | |
def _solve_ldlt_3x3(A: SymMat3, b: Vector) -> Vector:
    """
    Solve the symmetric 3x3 system ``A x = b`` with an unpivoted LDL^T
    factorization. Used by the Rayleigh quotient refinement step.

    If any pivot of D is smaller in magnitude than machine epsilon the
    factorization is abandoned and ``b`` is returned unchanged. Although not
    numerically stable for indefinite matrices, this is still better than the
    adjugate method or Cramer's rule.
    """
    a00, a01, a02, a11, a12, a22 = A

    # --- Factorization: A = L D L^T with unit lower-triangular L ---
    piv0 = a00
    if abs(piv0) < _eps:
        return b
    m10 = a01/piv0
    m20 = a02/piv0

    piv1 = a11 - m10*a01
    if abs(piv1) < _eps:
        return b
    # Entry (2, 1) of the matrix after eliminating the first column
    s12 = a12 - m10*a02
    m21 = s12/piv1

    piv2 = a22 - (m20*a02 + m21*s12)
    if abs(piv2) < _eps:
        return b

    # --- Forward substitution: L y = b ---
    y0 = b[0]
    y1 = b[1] - m10*y0
    y2 = b[2] - m20*y0 - m21*y1

    # --- Diagonal scaling (D z = y) fused with back substitution (L^T x = z) ---
    x2 = y2/piv2
    x1 = y1/piv1 - m21*x2
    x0 = y0/piv0 - m10*x1 - m20*x2
    return (x0, x1, x2)
| @njit | |
def _refine_eigenpair(A: SymMat3, lam: float, v: Vector) -> tuple[float, Vector]:
    """
    Polish an approximate eigenpair ``(lam, v)`` of ``A`` with exactly one
    Rayleigh quotient iteration, avoiding an open-ended iterative algorithm.

    Returns the updated eigenvalue and the normalised eigenvector.
    """
    a00, a01, a02, a11, a12, a22 = A

    # Improve the eigenvector with one step of shifted inverse iteration:
    # solve (A - lam*I) w = v, then normalise w
    shifted = (a00 - lam, a01, a02, a11 - lam, a12, a22 - lam)
    w0, w1, w2 = _solve_ldlt_3x3(shifted, v)
    length = math.sqrt(w0*w0 + w1*w1 + w2*w2)
    ux = w0/length
    uy = w1/length
    uz = w2/length

    # The Rayleigh quotient u^T A u becomes the new eigenvalue estimate
    au0 = a00*ux + a01*uy + a02*uz
    au1 = a01*ux + a11*uy + a12*uz
    au2 = a02*ux + a12*uy + a22*uz
    rayleigh = ux*au0 + uy*au1 + uz*au2
    return rayleigh, (ux, uy, uz)
| @njit | |
def _first_evec(A: SymMat3, lam: float) -> tuple[float, Vector]:
    """
    Compute the eigenvector for eigenvalue ``lam`` using the cross product
    nullspace trick, then refine the pair.

    Requires rank(A - lam*I) = 2; otherwise the behavior is undefined.
    Returns the refined eigenvalue together with its eigenvector.
    """
    a00, a01, a02, a11, a12, a22 = A

    # Rows of M = A - lam*I
    row0 = (a00 - lam, a01, a02)
    row1 = (a01, a11 - lam, a12)
    row2 = (a02, a12, a22 - lam)

    # Every pairwise cross product of rows lies in the nullspace of M. Keep
    # the one with the largest squared norm, i.e. the one built from the two
    # most linearly independent rows.
    best = _cross(row1, row2)
    best_sq = _dot(best, best)

    cand = _cross(row2, row0)
    cand_sq = _dot(cand, cand)
    if cand_sq > best_sq:
        best, best_sq = cand, cand_sq

    cand = _cross(row0, row1)
    cand_sq = _dot(cand, cand)
    if cand_sq > best_sq:
        best, best_sq = cand, cand_sq

    return _refine_eigenpair(A, lam, best)
| @njit | |
def _second_evec(A: SymMat3, v0: Vector, lam1: float) -> tuple[float, Vector]:
    """
    Find the eigenvector for ``lam1`` by restricting ``A`` to the plane
    orthogonal to the already known eigenvector ``v0`` and solving the
    resulting 2x2 eigenvector problem. The result is refined before returning.
    """
    a00, a01, a02, a11, a12, a22 = A
    nx, ny, nz = v0

    # Build a simple orthonormal basis (U, V) of the plane orthogonal to v0,
    # zeroing the component that keeps the normalisation well conditioned
    if abs(nx) > abs(ny):
        scale = 1.0/math.sqrt(nx*nx + nz*nz)
        U = (-nz*scale, 0.0, nx*scale)
    else:
        scale = 1.0/math.sqrt(ny*ny + nz*nz)
        U = (0.0, nz*scale, -ny*scale)
    V = _cross(v0, U)

    # A*U and A*V, needed to project A onto span(U, V)
    AU = (
        a00*U[0] + a01*U[1] + a02*U[2],
        a01*U[0] + a11*U[1] + a12*U[2],
        a02*U[0] + a12*U[1] + a22*U[2],
    )
    AV = (
        a00*V[0] + a01*V[1] + a02*V[2],
        a01*V[0] + a11*V[1] + a12*V[2],
        a02*V[0] + a12*V[1] + a22*V[2],
    )

    # Entries of the 2x2 matrix M = (projected A) - lam1*I; we need M X = 0
    m00 = _dot(U, AU) - lam1
    m01 = _dot(U, AV)
    m11 = _dot(V, AV) - lam1

    # X is orthogonal to any non-zero row of M. Both rows contain m01, so
    # comparing |m00| with |m11| is enough to find the row of larger norm.
    if abs(m11) > abs(m00):
        row_x, row_y = m01, m11
        row_mag = abs(m11)
    else:
        row_x, row_y = m00, m01
        row_mag = abs(m00)

    if max(row_mag, abs(m01)) < _eps:
        # M is (numerically) zero, which happens with repeated eigenvalues;
        # any vector of the plane is then an eigenvector
        return _refine_eigenpair(A, lam1, U)

    # For the chosen row (row_x, row_y) the 2D solution is (-row_y, row_x);
    # lift it back to 3D through the (U, V) basis
    lifted = (
        row_x*V[0] - row_y*U[0],
        row_x*V[1] - row_y*U[1],
        row_x*V[2] - row_y*U[2],
    )
    return _refine_eigenpair(A, lam1, lifted)
| @njit | |
def eigh6(A: SymMat3) -> tuple[Vector, tuple[Vector, Vector, Vector]]:
    """
    Compute the eigenvalues and eigenvectors of the symmetric matrix A.

    Parameters
    ----------
    A: SymMat3
        The 6 upper-triangular entries of the input 3x3 symmetric matrix,
        ordered (a00, a01, a02, a11, a12, a22)

    Returns
    -------
    A tuple (λ, (v0, v1, v2)) where
    λ: Vector = (λ_min, λ_mid, λ_max)
        The sorted eigenvalues
    v0: Vector, v1: Vector, v2: Vector
        The corresponding orthonormal eigenvectors

    Notes
    -----
    The matrix is divided by its largest absolute entry before solving and the
    eigenvalues are multiplied back at the end, so matrices of any magnitude
    are handled; only an exactly zero matrix takes the early return.
    """
    max_abs = max(map(abs, A))
    # Float literals keep every return path on the same element type
    eye = (
        (1.0, 0.0, 0.0),
        (0.0, 1.0, 0.0),
        (0.0, 0.0, 1.0),
    )

    # Only the exact zero matrix cannot be normalised. A tiny but non-zero
    # matrix must still go through the scaled solver, otherwise its
    # eigenvalues would be reported as 0.
    if max_abs == 0.0:
        return (0.0, 0.0, 0.0), eye

    # Scale down for better numerical stability
    a00, a01, a02, a11, a12, a22 = A
    a00 /= max_abs; a01 /= max_abs; a02 /= max_abs
    a11 /= max_abs; a12 /= max_abs; a22 /= max_abs
    A = a00, a01, a02, a11, a12, a22

    norm_off = a01*a01 + a02*a02 + a12*a12
    # NOTE(review): norm_off is a *squared* norm compared against epsilon, so
    # off-diagonal entries up to about sqrt(eps) relative to the largest entry
    # are ignored here — confirm this accuracy trade-off is intended.
    if norm_off < _eps:
        # Numerically diagonal, effectively handled the triple eigenvalues case
        l0 = a00; l1 = a11; l2 = a22
        v0, v1, v2 = eye

        # Sort the eigenvalues in ascending order (3-element sorting network),
        # moving each eigenvector along with its eigenvalue
        if l1 < l0:
            l1, l0 = l0, l1
            v1, v0 = v0, v1
        if l2 < l1:
            l2, l1 = l1, l2
            v2, v1 = v1, v2
        if l1 < l0:
            l1, l0 = l0, l1
            v1, v0 = v0, v1
    else:
        # B = A - q*I is the traceless part of A; p normalises it so that the
        # eigenvalues of B/p can be written as 2*cos(angle)
        tr = a00 + a11 + a22
        q = tr/3.0
        b00 = a00 - q
        b11 = a11 - q
        b22 = a22 - q
        norm_diag = b00*b00 + b11*b11 + b22*b22
        p = math.sqrt((norm_diag + 2.0*norm_off)/6.0)

        # Compute det(1/p B) via minors of B
        c00 = b11*b22 - a12*a12
        c01 = a01*b22 - a12*a02
        c02 = a01*a12 - b11*a02
        detB = (b00*c00 - a01*c01 + a02*c02)/(p*p*p)

        # Clamp to the domain of acos to absorb rounding error
        halfdet = max(min(0.5*detB, 1.0), -1.0)
        phi = math.acos(halfdet)/3.0
        beta2 = 2.0*math.cos(phi)
        # The literal below is 2*pi/3
        beta0 = 2.0*math.cos(phi + 2.0943951023931954923084289)

        # Recover eigenvalues of A from eigenvalues of 1/p B
        l0 = q + p*beta0
        l2 = q + p*beta2

        # Pick build order to avoid double first eigenvalue
        if halfdet >= 0.0:
            # l2 > l1 >= l0
            l2, v2 = _first_evec(A, l2)
            l0, v0 = _second_evec(A, v2, l0)
        else:
            # l2 >= l1 > l0
            l0, v0 = _first_evec(A, l0)
            l2, v2 = _second_evec(A, v0, l2)

        # Compute the middle eigenpair from the other two
        # Don't refine because it may break orthogonality
        l1 = tr - l0 - l2
        v1 = _cross(v0, v2)

    # Scale eigenvalues back up
    return (l0*max_abs, l1*max_abs, l2*max_abs), (v0, v1, v2)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from collections.abc import Callable | |
| from time import perf_counter | |
| from typing import Literal | |
| import numpy as np | |
| from numba import njit | |
| from eigh6 import eigh6 | |
# 1-D float64 array; used in this script for vectors of eigenvalues.
type NPVec = np.ndarray[tuple[int], np.dtype[np.float64]]
# 2-D float64 array; used in this script for input matrices and for matrices
# of eigenvectors.
type NPMat = np.ndarray[tuple[int, int], np.dtype[np.float64]]
| @njit | |
def eigh6_np(A: NPMat) -> tuple[NPVec, NPMat]:
    """
    NumPy-facing adapter around ``eigh6``.

    Reads the upper triangle of ``A``, calls ``eigh6`` on the packed entries
    and returns the eigenvalues as an array together with a matrix whose
    columns are the eigenvectors (the tuple of eigenvectors, transposed).
    """
    upper = (A[0, 0], A[0, 1], A[0, 2], A[1, 1], A[1, 2], A[2, 2])
    eigenvalues, eigenvectors = eigh6(upper)
    return np.array(eigenvalues), np.array(eigenvectors).T
| @njit(cache=True) | |
def eigh_lapack(A: NPMat) -> tuple[NPVec, NPMat]:
    """
    Reference decomposition: forward ``A`` to ``np.linalg.eigh`` and return
    its (eigenvalues, eigenvectors) result unchanged.
    """
    return np.linalg.eigh(A)
| @njit(cache=True) | |
def generate_tests(
    n_trials: int,
    n_rows: int,
    scale: float,
    psd: bool,
    rng: np.random.Generator
) -> np.ndarray[tuple[int, int, int], np.dtype[np.float64]]:
    """
    Build ``n_trials`` random symmetric 3x3 test matrices.

    For each trial, a Gram matrix of ``n_rows`` standard-normal samples is
    scaled by ``scale / (n_rows - 1)`` and eigendecomposed. Its eigenvectors
    are reused, and the eigenvalues are shuffled, optionally sign-flipped or
    zeroed, and possibly duplicated before the matrix is reassembled.

    Parameters
    ----------
    n_trials: number of matrices to generate
    n_rows: number of random sample rows behind each Gram matrix
    scale: multiplier applied to the Gram matrix
    psd: when False, each kept eigenvalue is multiplied by a random sign
        drawn from {-1, 0, 1}
    rng: random generator supplying all randomness

    Returns
    -------
    Array of shape (n_trials, 3, 3) holding the generated matrices.
    """
    dim = 3
    sign_pool = np.array([-1, -1, 0, 1, 1, 1], np.int64)
    matrices = np.empty((n_trials, dim, dim), np.float64)

    for trial in range(n_trials):
        samples = rng.standard_normal((n_rows, dim))
        factor = scale / (n_rows - 1)
        spectrum, basis = np.linalg.eigh((samples.T @ samples) * factor)

        # Keep either 2 or 3 of the eigenvalues, taken in random order
        n_keep = rng.integers(0, dim - 1) + 2
        kept = rng.permutation(spectrum)[:n_keep]

        if not psd:
            # Randomly negate, zero out or keep every retained eigenvalue
            kept *= sign_pool[rng.integers(0, sign_pool.shape[0], size=(n_keep,))]

        if n_keep < dim:
            # Draw 3 values with replacement from 2, forcing a repeated
            # eigenvalue
            diagonal = kept[rng.integers(0, kept.shape[0], size=(dim,))]
        else:
            diagonal = rng.permutation(kept)

        # Reassemble basis * diag(diagonal) * basis^T
        matrices[trial] = (basis * diagonal) @ basis.T

    return matrices
def fuzz_test(
    n_trials: int,
    n_rows: int,
    scale: float,
    eigh_fn: Callable[[NPMat], tuple[NPVec, NPMat]],
    psd: bool = False,
    right_evec: Literal["inverse", "transpose"] = "transpose",
    tol: float = 1e-10,
    seed: int | None = None
) -> None:
    """
    Fuzz an eigensolver on random symmetric 3x3 matrices and print a report.

    The report shows a few example eigenvalues next to NumPy's symmetric
    solver, the decomposition throughput, every case whose reconstruction
    ``Q diag(D) Q^-1`` deviates from the input by more than ``tol``, and the
    maximum reconstruction error.

    Parameters
    ----------
    n_trials: number of random matrices to test
    n_rows, scale, psd: forwarded to ``generate_tests``
    eigh_fn: solver under test; maps a matrix to (eigenvalues, eigenvectors)
        with the eigenvectors stored as columns
    right_evec: how to obtain the right factor of the reconstruction;
        "transpose" assumes orthonormal eigenvectors (Q^-1 = Q^T), "inverse"
        inverts Q explicitly
    tol: largest accepted absolute entry-wise reconstruction error
    seed: seed for the random generator (None for a random seed)

    Raises
    ------
    ValueError
        If ``right_evec`` is not "inverse" or "transpose".
    """
    # Fail early instead of hitting an unbound variable inside the check loop
    if right_evec not in ("inverse", "transpose"):
        raise ValueError(
            f"right_evec must be 'inverse' or 'transpose', got {right_evec!r}"
        )

    print(f"Generating {n_trials} testcases...", end=" ", flush=True)
    rng = np.random.default_rng(seed)
    Gs = generate_tests(n_trials, n_rows, scale, psd, rng)
    print("Done")

    print("Example eigenvalues:")
    # Slicing keeps this safe when fewer than 10 matrices were generated.
    # These calls happen before the timer starts, so any one-time warm-up
    # cost of eigh_fn is presumably excluded from the timing below.
    for G in Gs[:10]:
        D, _ = eigh_fn(G)
        print("Custom:", D)
        # eigvalsh is the symmetric routine: real eigenvalues, ascending order
        print("LAPACK:", np.linalg.eigvalsh(G))

    start = perf_counter()
    DQs = list(map(eigh_fn, Gs))
    print(f"Decomposition speed: {n_trials / (perf_counter() - start):.2f} ops/s")

    max_err = 0.0
    failures = 0
    for idx, ((D, Q), G) in enumerate(zip(DQs, Gs), start=1):
        QT = np.linalg.inv(Q) if right_evec == "inverse" else Q.T

        # Rebuild the matrix as Q diag(D) Q^-1 and compare entry-wise
        G_hat = (Q * D) @ QT
        err = np.max(np.abs(G - G_hat))
        max_err = max(max_err, err)

        if err > tol:
            print(f"[Case {idx}]: reconstruction error {err:.4e} exceeded tol={tol:.4e}")
            print("Eigenvalues:", D)
            failures += 1

    if failures == 0:
        # Report the tolerance that was requested; the observed maximum error
        # is printed separately below
        print(f"All {n_trials} reconstructions passed within tol={tol:.4e}")
    else:
        print(f"{n_trials - failures}/{n_trials} cases passed within tol={tol:.4e}")

    print(f"Maximum error: {max_err:.4e}")
if __name__ == "__main__":
    # Run the fuzz test on eigh6_np with a fixed seed: 500000 matrices, each
    # built from 10 sample rows at scale 1, reconstructed with Q^T as the
    # right factor and checked against a 1e-14 tolerance.
    fuzz_test(500000, 10, 1, eigh6_np, right_evec="transpose", tol=1e-14, seed=42)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment