Skip to content

Instantly share code, notes, and snippets.

@ziap
Last active July 4, 2025 14:13
Show Gist options
  • Select an option

  • Save ziap/b32eed4d7d453337402994719125d2aa to your computer and use it in GitHub Desktop.

Select an option

Save ziap/b32eed4d7d453337402994719125d2aa to your computer and use it in GitHub Desktop.
robust, analytic eigensolver for symmetric 3x3 matrix, faster than np.linalg.eigh, based on <https://www.geometrictools.com/Documentation/RobustEigenSymmetric3x3.pdf> with some modifications
"""
eigh6.py — Fast symmetric 3×3 eigen‐decomposition with Numba
This module provides a highly optimized, stand-alone (no Numpy) implementation
of the eigenvalue decomposition for real symmetric 3×3 matrices. It returns the
three real eigenvalues in ascending order along with an orthonormal basis of
eigenvectors. The implementation is designed for use with Numba’s @njit
decorator and features:
- Analytic solution for eigenvalues via closed‐form formulas (trigonometric
method on the characteristic cubic).
- Direct construction of eigenvectors using cross‐product "nullspace" trick
and 2×2 subproblem for intermediate eigenvalues.
- One‐step Rayleigh quotient refinement to boost accuracy.
- Special‐case handling for nearly diagonal or nearly zero matrices.
- No external dependencies beyond Python stdlib and Numba.
Usage
-----
>>> from eigh6 import eigh6
>>> A = (2.0, -1.0, 0.0,
... 2.0, 3.0, 1.0)
>>> eigenvalues, eigenvectors = eigh6(A)
>>> print("Eigenvalues:", eigenvalues)
>>> for v in eigenvectors:
... print("Eigenvector:", v)
Further reading
---------------
- Eberly D. 2014. *A Robust Eigensolver for 3 × 3 Symmetric Matrices*.
Geometric Tools, Redmond, WA, USA (December 2014), 11-18.
https://www.geometrictools.com/Documentation/RobustEigenSymmetric3x3.pdf
- Oliver K. Smith. 1961. *Eigenvalues of a symmetric 3 × 3 matrix*.
Commun. ACM 4, 4 (April 1961), 168. https://doi.org/10.1145/355578.366316
"""
import math
import sys
from numba import njit
type Vector = tuple[float, float, float]
"""
A length-3 tuple representing a 3D vector.
"""
type SymMat3 = tuple[float, float, float, float, float, float]
"""
The six unique entries of a symmetric 3×3 matrix in row‐major order:
⎡a00 a01 a02⎤
⎢ ⎥
⎢a01 a11 a12⎥ → (a00, a01, a02, a11, a12, a22)
⎢ ⎥
⎣a02 a12 a22⎦
"""
# Machine epsilon of the platform float. The routines in this module compare
# pivots, norms and matrix entries against it as an absolute "numerically
# zero" threshold.
_eps = sys.float_info.epsilon
@njit
def _dot(v1: Vector, v2: Vector) -> float:
    """
    Return the inner product of two 3-vectors.
    """
    ax, ay, az = v1
    bx, by, bz = v2
    # Accumulate left to right: (x-term + y-term) + z-term
    xx = ax*bx
    yy = ay*by
    zz = az*bz
    return xx + yy + zz
@njit
def _cross(v1: Vector, v2: Vector) -> Vector:
    """
    Return the cross product v1 × v2 of two 3-vectors.
    """
    ax, ay, az = v1
    bx, by, bz = v2
    cx = ay*bz - az*by
    cy = az*bx - ax*bz
    cz = ax*by - ay*bx
    return (cx, cy, cz)
@njit
def _solve_ldlt_3x3(A: SymMat3, b: Vector) -> Vector:
    """
    Solve the symmetric 3x3 system A x = b through an unpivoted LDL^T
    factorization. This is the linear solve used by the Rayleigh quotient
    refinement step. Although not numerically stable for an indefinite
    matrix, it is better than the adjugate method or Cramer's rule.

    If any pivot of D has magnitude below machine epsilon the factorization
    is abandoned and the right-hand side `b` is returned unchanged.
    """
    a00, a01, a02, a11, a12, a22 = A

    # First pivot and first column of L
    piv0 = a00
    if abs(piv0) < _eps:
        return b
    l10 = a01/piv0
    l20 = a02/piv0

    # Second pivot
    piv1 = a11 - l10*a01
    if abs(piv1) < _eps:
        return b

    # Remaining off-diagonal entry after eliminating the first column; it is
    # needed both for l21 and for the last pivot
    s12 = a12 - l10*a02
    l21 = s12/piv1

    # Third pivot
    piv2 = a22 - (l20*a02 + l21*s12)
    if abs(piv2) < _eps:
        return b

    # Forward substitution: L y = b
    b0, b1, b2 = b[0], b[1], b[2]
    y0 = b0
    y1 = b1 - l10*y0
    y2 = b2 - l20*y0 - l21*y1

    # Back substitution with the diagonal scaling folded in: L^T x = D^-1 y
    x2 = y2/piv2
    x1 = y1/piv1 - l21*x2
    x0 = y0/piv0 - l10*x1 - l20*x2
    return (x0, x1, x2)
@njit
def _refine_eigenpair(A: SymMat3, lam: float, v: Vector) -> tuple[float, Vector]:
    """
    Polish an approximate eigenpair (lam, v) of A with a single Rayleigh
    quotient iteration, so that no iterative algorithm is needed.

    Returns the updated eigenvalue together with the updated, normalized
    eigenvector.
    """
    a00, a01, a02, a11, a12, a22 = A

    # Improve the eigenvector with one step of shifted inverse iteration:
    # solve (A - lam*I) w = v and normalize the result
    shifted = (a00 - lam, a01, a02, a11 - lam, a12, a22 - lam)
    w = _solve_ldlt_3x3(shifted, v)
    length = math.sqrt(_dot(w, w))
    wx, wy, wz = w[0], w[1], w[2]
    u = (wx/length, wy/length, wz/length)

    # The new eigenvalue estimate is the Rayleigh quotient u^T A u
    ux, uy, uz = u
    Au = (
        a00*ux + a01*uy + a02*uz,
        a01*ux + a11*uy + a12*uz,
        a02*ux + a12*uy + a22*uz,
    )
    return _dot(u, Au), u
@njit
def _first_evec(A: SymMat3, lam: float) -> tuple[float, Vector]:
    """
    Compute the first eigenvector using the cross product nullspace trick,
    then refine the resulting eigenpair.

    Requires rank(A - lam*I) = 2, otherwise the behavior is undefined.
    """
    a00, a01, a02, a11, a12, a22 = A

    # The three rows of M = A - lam*I
    row0 = (a00 - lam, a01, a02)
    row1 = (a01, a11 - lam, a12)
    row2 = (a02, a12, a22 - lam)

    # Start with the candidate row1 x row2 ...
    best = _cross(row1, row2)
    best_sq = _dot(best, best)

    # ... and replace it whenever another pair of rows gives a cross product
    # of strictly larger squared norm, i.e. comes from the two most linearly
    # independent rows
    cand = _cross(row2, row0)
    cand_sq = _dot(cand, cand)
    if cand_sq > best_sq:
        best, best_sq = cand, cand_sq

    cand = _cross(row0, row1)
    cand_sq = _dot(cand, cand)
    if cand_sq > best_sq:
        best, best_sq = cand, cand_sq

    return _refine_eigenpair(A, lam, best)
@njit
def _second_evec(A: SymMat3, v0: Vector, lam1: float) -> tuple[float, Vector]:
    """
    Find the second eigenvector by restricting A to the plane orthogonal to
    the first eigenvector v0 and solving the resulting 2x2 eigenvector
    problem for the eigenvalue lam1. The result is refined before returning.
    """
    a00, a01, a02, a11, a12, a22 = A
    nx, ny, nz = v0

    # Pick a simple basis (U, V) for the subspace orthogonal to v0: U is v0
    # with one coordinate dropped and the other two rotated by 90 degrees
    if abs(nx) > abs(ny):
        scale = 1.0/math.sqrt(nx*nx + nz*nz)
        U = (-nz*scale, 0.0, nx*scale)
    else:
        scale = 1.0/math.sqrt(ny*ny + nz*nz)
        U = (0.0, nz*scale, -ny*scale)
    V = _cross(v0, U)

    ux, uy, uz = U
    wx, wy, wz = V

    # Images of the basis vectors under A
    AU = (
        a00*ux + a01*uy + a02*uz,
        a01*ux + a11*uy + a12*uz,
        a02*ux + a12*uy + a22*uz,
    )
    AV = (
        a00*wx + a01*wy + a02*wz,
        a01*wx + a11*wy + a12*wz,
        a02*wx + a12*wy + a22*wz,
    )

    # Entries of the projected 2x2 matrix M = [U V]^T A [U V] - lam1*I
    m00 = _dot(U, AU) - lam1
    m01 = _dot(U, AV)
    m11 = _dot(V, AV) - lam1

    # The solution of M X = 0 is orthogonal to any non-zero row of M: for a
    # row (x, y) it is (-y, x). Both rows contain m01, so comparing |m00| with
    # |m11| is enough to tell which row has the larger norm
    if abs(m11) > abs(m00):
        row_x, row_y = m01, m11
        diag_mag = abs(m11)
    else:
        row_x, row_y = m00, m01
        diag_mag = abs(m00)

    if max(diag_mag, abs(m01)) < _eps:
        # M is numerically zero, which happens with (numerically) repeated
        # eigenvalues; any vector of the subspace will do, so use U
        return _refine_eigenpair(A, lam1, U)

    # Lift the 2D solution (-y, x) back to 3D as -y*U + x*V
    guess = (
        row_x*wx - row_y*ux,
        row_x*wy - row_y*uy,
        row_x*wz - row_y*uz,
    )
    return _refine_eigenpair(A, lam1, guess)
@njit
def eigh6(A: SymMat3) -> tuple[Vector, tuple[Vector, Vector, Vector]]:
    """
    Compute the eigenvalues and eigenvectors of the symmetric matrix A.

    Parameters
    ----------
    A: SymMat3
        The 6 upper-triangular entries of the input 3x3 symmetric matrix,
        ordered (a00, a01, a02, a11, a12, a22)

    Returns
    -------
    A tuple (λ, (v0, v1, v2)) where
    λ: Vector = (λ_min, λ_mid, λ_max)
        The sorted eigenvalues
    v0: Vector, v1: Vector, v2: Vector
        The corresponding orthonormal eigenvectors

    Notes
    -----
    A matrix whose largest absolute entry is below machine epsilon is treated
    as the zero matrix: the eigenvalues are reported as 0.0 and the
    eigenvectors as the standard basis. Every returned component is a float,
    including on this path and on the numerically diagonal path.
    """
    max_abs = max(map(abs, A))

    # Float literals keep the return type identical on every code path
    eye = (
        (1.0, 0.0, 0.0),
        (0.0, 1.0, 0.0),
        (0.0, 0.0, 1.0),
    )

    if max_abs < _eps:
        return (0.0, 0.0, 0.0), eye

    # Scale down so the largest entry has magnitude 1, for better numerical
    # stability; the eigenvalues are scaled back up before returning
    a00, a01, a02, a11, a12, a22 = A
    a00 /= max_abs; a01 /= max_abs; a02 /= max_abs
    a11 /= max_abs; a12 /= max_abs; a22 /= max_abs
    A = a00, a01, a02, a11, a12, a22

    norm_off = a01*a01 + a02*a02 + a12*a12

    if norm_off < _eps:
        # Numerically diagonal, which also handles the triple eigenvalue case
        l0 = a00; l1 = a11; l2 = a22
        v0, v1, v2 = eye

        # Sort the eigenvalues in ascending order with three compare-swaps,
        # moving each basis vector together with its eigenvalue
        if l1 < l0:
            l1, l0 = l0, l1
            v1, v0 = v0, v1
        if l2 < l1:
            l2, l1 = l1, l2
            v2, v1 = v1, v2
        if l1 < l0:
            l1, l0 = l0, l1
            v1, v0 = v0, v1
    else:
        # Shift by the mean eigenvalue q and scale by p, so that B = A - q*I
        # is traceless and (1/p) B has eigenvalues of the form 2*cos(angle)
        tr = a00 + a11 + a22
        q = tr/3.0
        b00 = a00 - q
        b11 = a11 - q
        b22 = a22 - q
        norm_diag = b00*b00 + b11*b11 + b22*b22
        p = math.sqrt((norm_diag + 2.0*norm_off)/6.0)

        # Compute det(1/p B) via minors of B
        c00 = b11*b22 - a12*a12
        c01 = a01*b22 - a12*a02
        c02 = a01*a12 - b11*a02
        detB = (b00*c00 - a01*c01 + a02*c02)/(p*p*p)

        # Clamp to [-1, 1] so rounding cannot push acos out of its domain
        halfdet = max(min(0.5*detB, 1.0), -1.0)
        phi = math.acos(halfdet)/3.0

        # Largest and smallest roots; the added constant is 2*pi/3
        beta2 = 2.0*math.cos(phi)
        beta0 = 2.0*math.cos(phi + 2.0943951023931954923084289)

        # Recover eigenvalues of A from eigenvalues of 1/p B
        l0 = q + p*beta0
        l2 = q + p*beta2

        # Pick the build order so that the eigenvalue handed to _first_evec
        # is never the (possibly) doubled one
        if halfdet >= 0.0:
            # l2 > l1 >= l0
            l2, v2 = _first_evec(A, l2)
            l0, v0 = _second_evec(A, v2, l0)
        else:
            # l2 >= l1 > l0
            l0, v0 = _first_evec(A, l0)
            l2, v2 = _second_evec(A, v0, l2)

        # Compute the middle eigenpair from the other two
        # Don't refine because it may break orthogonality
        l1 = tr - l0 - l2
        v1 = _cross(v0, v2)

    # Scale eigenvalues back up
    return (l0*max_abs, l1*max_abs, l2*max_abs), (v0, v1, v2)
from collections.abc import Callable
from time import perf_counter
from typing import Literal
import numpy as np
from numba import njit
from eigh6 import eigh6
# One-dimensional float64 array; used below as the type of eigenvalue vectors.
type NPVec = np.ndarray[tuple[int], np.dtype[np.float64]]
# Two-dimensional float64 array; used below for the input matrices and for
# the eigenvector matrices returned by the solvers.
type NPMat = np.ndarray[tuple[int, int], np.dtype[np.float64]]
@njit
def eigh6_np(A: NPMat) -> tuple[NPVec, NPMat]:
    """
    NumPy-facing adapter around `eigh6`.

    Reads the upper triangle of the 2-D array A, calls `eigh6` on the six
    packed entries, and returns the eigenvalues as an array together with the
    eigenvectors arranged as the columns of a matrix.
    """
    packed = (
        A[0, 0], A[0, 1], A[0, 2],
        A[1, 1], A[1, 2],
        A[2, 2],
    )
    evals, evecs = eigh6(packed)
    # eigh6 yields one eigenvector per tuple entry (i.e. per row), so
    # transpose to put them in columns
    return np.array(evals), np.array(evecs).T
@njit(cache=True)
def eigh_lapack(A: NPMat) -> tuple[NPVec, NPMat]:
    """
    Reference solver: forward A to `np.linalg.eigh` from compiled code and
    return its result unchanged.
    """
    result = np.linalg.eigh(A)
    return result
@njit(cache=True)
def generate_tests(
    n_trials: int,
    n_rows: int,
    scale: float,
    psd: bool,
    rng: np.random.Generator
) -> np.ndarray[tuple[int, int, int], np.dtype[np.float64]]:
    """
    Generate `n_trials` random symmetric 3x3 test matrices.

    Each matrix is built from the eigen-decomposition of a scaled Gram matrix
    of `n_rows` standard-normal samples. Its spectrum is then resampled so
    that repeated eigenvalues occur, and, unless `psd` is true, eigenvalues
    are multiplied by factors drawn from {-1, 0, 1}.

    Returns an array of shape (n_trials, 3, 3).
    """
    D = 3
    out = np.empty((n_trials, D, D), np.float64)
    # Multipliers applied to the eigenvalues when indefinite matrices are
    # allowed; drawn uniformly from this pool
    sign_pool = np.array([-1, -1, 0, 1, 1, 1], np.int64)

    for t in range(n_trials):
        samples = rng.standard_normal((n_rows, D))
        gram_scale = scale / (n_rows - 1)
        spectrum, basis = np.linalg.eigh((samples.T @ samples) * gram_scale)

        # Keep either 2 or 3 of the eigenvalues, chosen in random order
        n_keep = rng.integers(0, D - 1) + 2
        kept = rng.permutation(spectrum)[:n_keep]

        if not psd:
            kept *= sign_pool[rng.integers(0, sign_pool.shape[0], size=(n_keep,))]

        if n_keep < D:
            # Fewer than D values kept: draw D of them with replacement, which
            # forces at least one repeated eigenvalue
            diag = kept[rng.integers(0, kept.shape[0], size=(D,))]
        else:
            diag = rng.permutation(kept)

        # Reassemble basis * diag(diag) * basis^T
        out[t] = (basis * diag) @ basis.T

    return out
def fuzz_test(
    n_trials: int,
    n_rows: int,
    scale: float,
    eigh_fn: Callable[[NPMat], tuple[NPVec, NPMat]],
    psd: bool = False,
    right_evec: Literal["inverse", "transpose"] = "transpose",
    tol: float = 1e-10,
    seed: int | None = None
) -> None:
    """
    Fuzz-test and time an eigen-decomposition routine on random symmetric
    3x3 matrices, printing a report to stdout.

    Parameters
    ----------
    n_trials: int
        Number of random matrices to generate and decompose
    n_rows: int
        Forwarded to `generate_tests`
    scale: float
        Forwarded to `generate_tests`
    eigh_fn: callable
        Routine under test; maps a matrix to (eigenvalues, eigenvectors)
        with the eigenvectors stored as columns
    psd: bool
        Forwarded to `generate_tests`
    right_evec: "inverse" or "transpose"
        How the right factor of the reconstruction Q diag(D) Q' is formed:
        the matrix inverse of Q, or its transpose
    tol: float
        Largest accepted entry-wise reconstruction error
    seed: int or None
        Seed for `np.random.default_rng`

    Raises
    ------
    ValueError
        If `right_evec` is not one of the two supported modes
    """
    # Reject an unknown mode up front instead of failing on an unbound
    # variable in the middle of the verification loop
    if right_evec not in ("inverse", "transpose"):
        raise ValueError(
            f"right_evec must be 'inverse' or 'transpose', got {right_evec!r}"
        )

    print(f"Generating {n_trials} testcases...", end=" ", flush=True)
    rng = np.random.default_rng(seed)
    Gs = generate_tests(n_trials, n_rows, scale, psd, rng)
    print("Done")

    print("Example eigenvalues:")
    # Slicing shows at most 10 examples and stays valid when n_trials < 10
    for G in Gs[:10]:
        D, _ = eigh_fn(G)
        print("Custom:", D)
        # The inputs are symmetric, so use the symmetric solver; it returns
        # real eigenvalues already sorted in ascending order
        D_ = np.linalg.eigvalsh(G)
        print("LAPACK:", D_)

    start = perf_counter()
    DQs = list(map(eigh_fn, Gs))
    elapsed = perf_counter() - start
    print(f"Decomposition speed: {n_trials / elapsed:.2f} ops/s")

    max_err = 0.0
    failures = 0
    for idx, ((D, Q), G) in enumerate(zip(DQs, Gs)):
        QT = np.linalg.inv(Q) if right_evec == "inverse" else Q.T

        # Rebuild the matrix from its factors and measure the worst entry
        G_hat = (Q * D) @ QT
        err = np.max(np.abs(G - G_hat))
        max_err = max(max_err, err)

        if err > tol:
            print(f"[Case {idx + 1}]: reconstruction error {err:.4e} exceeded tol={tol:.4e}")
            print("Eigenvalues:", D)
            failures += 1

    if failures == 0:
        # Report the tolerance that was checked; the observed maximum error
        # is printed separately below
        print(f"All {n_trials} reconstructions passed within tol={tol:.4e}")
    else:
        print(f"{n_trials - failures}/{n_trials} cases passed within tol={tol:.4e}")
    print(f"Maximum error: {max_err:.4e}")
if __name__ == "__main__":
    # Script entry point: fuzz the Numba solver on half a million matrices
    # with a fixed seed and a tight reconstruction tolerance
    fuzz_test(
        n_trials=500000,
        n_rows=10,
        scale=1,
        eigh_fn=eigh6_np,
        right_evec="transpose",
        tol=1e-14,
        seed=42,
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment