Skip to content

Instantly share code, notes, and snippets.

@evanatyourservice
Created June 5, 2025 18:00
Show Gist options
  • Select an option

  • Save evanatyourservice/6eccc5d2f0caedc0387dcc6ea7c9574f to your computer and use it in GitHub Desktop.

Select an option

Save evanatyourservice/6eccc5d2f0caedc0387dcc6ea7c9574f to your computer and use it in GitHub Desktop.
Alexander Stotsky: Newton-Schulz inverses
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_spd_matrix
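"""
Factorized Newton-Schulz iterations for the inverse of an SPD matrix, from the
paper by Alexander Stotsky: https://arxiv.org/pdf/2208.04068

An order-n step sets G <- (I + F + F^2 + ... + F^(n-1)) @ G, where
F = I - G @ A is the current residual, so the new residual is F^n. The
polynomial is evaluated in factorized form to save matrix multiplications.

From the efficiency index EI = n^(1/n_p), where n is the order of the step
and n_p is the number of matrix multiplications per step, n = 11 gives the
best convergence rate per matmul of the orders implemented here, with
EI = 11^(1/6) ≈ 1.4913.
"""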
def ns_step_n2(G, A, I):
    """Order-2 Newton-Schulz step (algebraically 2G - G @ A @ G).

    With residual r = I - G @ A, returns (I + r) @ G. The residual of the
    result, I - G_new @ A, equals r^2. Costs 2 matrix multiplications.

    Args:
        G: current approximation of inv(A).
        A: matrix being inverted.
        I: identity matrix matching A (the demo in this file passes np.eye(dim)).

    Returns:
        The updated approximation G_new.
    """
    multiplier = I + (I - G @ A)  # I + r
    return multiplier @ G
def ns_step_n3(G, A, I):
    """Order-3 Newton-Schulz step.

    With residual r = I - G @ A, returns (I + r + r^2) @ G. The residual of
    the result equals r^3. Costs 3 matrix multiplications.

    Args:
        G: current approximation of inv(A).
        A: matrix being inverted.
        I: identity matrix matching A (the demo in this file passes np.eye(dim)).

    Returns:
        The updated approximation G_new.
    """
    r = I - G @ A
    series = I + r + r @ r
    return series @ G
def ns_step_n4(G, A, I):
    """Order-4 Newton-Schulz step.

    With residual r = I - G @ A, returns (I + r + r^2 + r^3) @ G, evaluated
    as (I + r^2) @ (I + r). The residual of the result equals r^4. Costs 4
    matrix multiplications.

    Args:
        G: current approximation of inv(A).
        A: matrix being inverted.
        I: identity matrix matching A (the demo in this file passes np.eye(dim)).

    Returns:
        The updated approximation G_new.
    """
    r = I - G @ A
    r_sq = r @ r
    poly = (I + r_sq) @ (I + r)  # I + r + r^2 + r^3
    return poly @ G
def ns_step_n5(G, A, I):
    """Order-5 Newton-Schulz step.

    With residual r = I - G @ A, returns (I + r + ... + r^4) @ G, evaluated
    as I + r @ ((I + r^2) @ (I + r)). The residual of the result equals r^5.
    Costs 5 matrix multiplications.

    Args:
        G: current approximation of inv(A).
        A: matrix being inverted.
        I: identity matrix matching A (the demo in this file passes np.eye(dim)).

    Returns:
        The updated approximation G_new.
    """
    r = I - G @ A
    r_sq = r @ r
    four_terms = (I + r_sq) @ (I + r)  # I + r + r^2 + r^3
    poly = I + r @ four_terms          # I + r + ... + r^4
    return poly @ G
def ns_step_n6(G, A, I):
    """Order-6 Newton-Schulz step.

    With residual r = I - G @ A, returns (I + r + ... + r^5) @ G, evaluated
    as (I + r^3) @ (I + r + r^2). The residual of the result equals r^6.
    Costs 5 matrix multiplications.

    Args:
        G: current approximation of inv(A).
        A: matrix being inverted.
        I: identity matrix matching A (the demo in this file passes np.eye(dim)).

    Returns:
        The updated approximation G_new.
    """
    r = I - G @ A
    r2 = r @ r
    r3 = r2 @ r
    outer = I + r3       # powers 0, 3
    inner = I + r + r2   # powers 0, 1, 2
    return (outer @ inner) @ G
def ns_step_n7(G, A, I):
    """Order-7 Newton-Schulz step.

    With residual r = I - G @ A, returns (I + r + ... + r^6) @ G, evaluated
    as I + r @ ((I + r^3) @ (I + r + r^2)). The residual of the result equals
    r^7. Costs 6 matrix multiplications.

    Args:
        G: current approximation of inv(A).
        A: matrix being inverted.
        I: identity matrix matching A (the demo in this file passes np.eye(dim)).

    Returns:
        The updated approximation G_new.
    """
    r = I - G @ A
    r2 = r @ r
    r3 = r2 @ r
    six_terms = (I + r3) @ (I + r + r2)  # I + r + ... + r^5
    poly = I + r @ six_terms             # I + r + ... + r^6
    return poly @ G
def ns_step_n8(G, A, I):
    """Order-8 Newton-Schulz step.

    With residual r = I - G @ A, returns (I + r + ... + r^7) @ G, evaluated
    as (I + r^4) @ (I + r^2) @ (I + r). The residual of the result equals
    r^8. Costs 6 matrix multiplications.

    Args:
        G: current approximation of inv(A).
        A: matrix being inverted.
        I: identity matrix matching A (the demo in this file passes np.eye(dim)).

    Returns:
        The updated approximation G_new.
    """
    r = I - G @ A
    r2 = r @ r
    r4 = r2 @ r2
    acc = I + r4
    acc = acc @ (I + r2)
    acc = acc @ (I + r)
    return acc @ G
def ns_step_n9(G, A, I):
    """Order-9 Newton-Schulz step.

    With residual r = I - G @ A, returns (I + r + ... + r^8) @ G, evaluated
    as (I + r^3 + r^6) @ (I + r + r^2). The residual of the result equals
    r^9. Costs 6 matrix multiplications.

    Args:
        G: current approximation of inv(A).
        A: matrix being inverted.
        I: identity matrix matching A (the demo in this file passes np.eye(dim)).

    Returns:
        The updated approximation G_new.
    """
    r = I - G @ A
    r2 = r @ r
    r3 = r2 @ r
    r6 = r3 @ r3
    outer = I + r3 + r6  # powers 0, 3, 6
    inner = I + r + r2   # powers 0, 1, 2
    return (outer @ inner) @ G
def ns_step_n10(G, A, I):
    """Order-10 Newton-Schulz step.

    With residual F = I - G @ A, returns (I + F + ... + F^9) @ G, evaluated
    as (I + (F^2 + F^4) @ (I + F^4)) @ (I + F). The residual of the result,
    I - G_new @ A, equals F^10. Costs 6 matrix multiplications.

    Args:
        G: current approximation of inv(A).
        A: matrix being inverted.
        I: identity matrix matching A (the demo in this file passes np.eye(dim)).

    Returns:
        The updated approximation G_new.
    """
    F = I - G @ A
    F2 = F @ F
    F4 = F2 @ F2
    # (F^2 + F^4)(I + F^4) = F^2 + F^4 + F^6 + F^8, so `even` holds the even
    # powers 0..8. The factor must be (I + F4): using (I + F2) would give
    # F^2 + 2F^4 + F^6, and the step would degrade to 4th order
    # (residual -F^4 + F^6 + F^8).
    even = I + (F2 + F4) @ (I + F4)
    # (I + F^2 + ... + F^8)(I + F) = I + F + ... + F^9
    return even @ (I + F) @ G
def ns_step_n11(G, A, I):
    """Order-11 Newton-Schulz step (best EI = 11^(1/6) among the orders here).

    With residual F = I - G @ A, returns (I + F + ... + F^10) @ G, evaluated
    as I + (I + (F^2 + F^4) @ (I + F^4)) @ (F + F^2). The residual of the
    result, I - G_new @ A, equals F^11. Costs 6 matrix multiplications.

    Args:
        G: current approximation of inv(A).
        A: matrix being inverted.
        I: identity matrix matching A (the demo in this file passes np.eye(dim)).

    Returns:
        The updated approximation G_new.
    """
    F = I - G @ A
    F2 = F @ F
    F4 = F2 @ F2
    # (F^2 + F^4)(I + F^4) = F^2 + F^4 + F^6 + F^8. The factor must be
    # (I + F4): using (I + F2) would give F^2 + 2F^4 + F^6, and the step
    # would degrade to 5th order (residual -F^5 + F^7 + F^9).
    term_c = (F2 + F4) @ (I + F4)
    # (I + F^2 + ... + F^8)(F + F^2) = F + F^2 + ... + F^10; add I for F^0.
    return (I + (I + term_c) @ (F + F2)) @ G
def ns_step_n12(G, A, I):
    """Order-12 Newton-Schulz step.

    With residual r = I - G @ A, returns (I + r + ... + r^11) @ G, evaluated
    as (I + r^4 + r^8) @ (I + r + r^2 + r^3). The residual of the result
    equals r^12. Costs 7 matrix multiplications.

    Args:
        G: current approximation of inv(A).
        A: matrix being inverted.
        I: identity matrix matching A (the demo in this file passes np.eye(dim)).

    Returns:
        The updated approximation G_new.
    """
    r = I - G @ A
    r2 = r @ r
    r3 = r2 @ r
    r4 = r2 @ r2
    r8 = r4 @ r4
    outer = I + r4 + r8        # powers 0, 4, 8
    inner = I + r + r2 + r3    # powers 0..3
    return (outer @ inner) @ G
def ns_step_n13(G, A, I):
    """Order-13 Newton-Schulz step.

    With residual r = I - G @ A, returns (I + r + ... + r^12) @ G, evaluated
    as I + r @ ((I + r^4 + r^8) @ (I + r + r^2 + r^3)). The residual of the
    result equals r^13. Costs 8 matrix multiplications.

    Args:
        G: current approximation of inv(A).
        A: matrix being inverted.
        I: identity matrix matching A (the demo in this file passes np.eye(dim)).

    Returns:
        The updated approximation G_new.
    """
    r = I - G @ A
    r2 = r @ r
    r3 = r2 @ r
    r4 = r2 @ r2
    r8 = r4 @ r4
    twelve_terms = (I + r4 + r8) @ (I + r + r2 + r3)  # I + r + ... + r^11
    poly = I + r @ twelve_terms                        # I + r + ... + r^12
    return poly @ G
def ns_step_n14(G, A, I):
    """Order-14 Newton-Schulz step.

    With residual r = I - G @ A, returns (I + r + ... + r^13) @ G, evaluated
    as (I + r^7) @ (I + r + ... + r^6). The residual of the result equals
    r^14. Costs 9 matrix multiplications.

    Args:
        G: current approximation of inv(A).
        A: matrix being inverted.
        I: identity matrix matching A (the demo in this file passes np.eye(dim)).

    Returns:
        The updated approximation G_new.
    """
    r = I - G @ A
    r2 = r @ r
    r3 = r2 @ r
    r4 = r2 @ r2
    r5 = r4 @ r
    r6 = r3 @ r3
    r7 = r6 @ r
    # sum(..., I) adds left to right: ((I + r) + r2) + ... + r6
    head = sum((r, r2, r3, r4, r5, r6), I)
    return ((I + r7) @ head) @ G
def ns_step_n15(G, A, I):
    """Order-15 Newton-Schulz step.

    With residual r = I - G @ A, returns (I + r + ... + r^14) @ G, evaluated
    as (I + r^5 + r^10) @ (I + r + r^2 + r^3 + r^4). The residual of the
    result equals r^15. Costs 8 matrix multiplications.

    Args:
        G: current approximation of inv(A).
        A: matrix being inverted.
        I: identity matrix matching A (the demo in this file passes np.eye(dim)).

    Returns:
        The updated approximation G_new.
    """
    r = I - G @ A
    r2 = r @ r
    r3 = r2 @ r
    r4 = r2 @ r2
    r5 = r4 @ r
    r10 = r5 @ r5
    outer = I + r5 + r10                # powers 0, 5, 10
    inner = sum((r, r2, r3, r4), I)     # I + r + ... + r^4, added left to right
    return (outer @ inner) @ G
def ns_step_n16(G, A, I):
    """Order-16 Newton-Schulz step.

    With residual r = I - G @ A, returns (I + r + ... + r^15) @ G, evaluated
    as (I + r^8) @ (I + r^4) @ (I + r^2) @ (I + r). The residual of the
    result equals r^16. Costs 8 matrix multiplications.

    Args:
        G: current approximation of inv(A).
        A: matrix being inverted.
        I: identity matrix matching A (the demo in this file passes np.eye(dim)).

    Returns:
        The updated approximation G_new.
    """
    r = I - G @ A
    r2 = r @ r
    r4 = r2 @ r2
    r8 = r4 @ r4
    poly = I + r8
    # Multiply in the remaining binary factors, highest power first.
    for power in (r4, r2, r):
        poly = poly @ (I + power)
    return poly @ G
def ns_step_n17(G, A, I):
    """Order-17 Newton-Schulz step.

    With residual r = I - G @ A, returns (I + r + ... + r^16) @ G, evaluated
    as I + r @ ((I + r^8) @ (I + r^4) @ (I + r^2) @ (I + r)). The residual of
    the result equals r^17. Costs 9 matrix multiplications.

    Args:
        G: current approximation of inv(A).
        A: matrix being inverted.
        I: identity matrix matching A (the demo in this file passes np.eye(dim)).

    Returns:
        The updated approximation G_new.
    """
    r = I - G @ A
    r2 = r @ r
    r4 = r2 @ r2
    r8 = r4 @ r4
    sixteen_terms = I + r8
    for power in (r4, r2, r):
        sixteen_terms = sixteen_terms @ (I + power)  # ends as I + r + ... + r^15
    poly = I + r @ sixteen_terms                     # I + r + ... + r^16
    return poly @ G
if __name__ == "__main__":
    # Demo: start every order from the same initial guess on one random SPD
    # matrix, run a fixed number of steps, and report/plot the relative
    # Frobenius error ||inv(A) - G_k||_F / ||inv(A)||_F after each step.
    dim = 64
    A = make_spd_matrix(dim, random_state=1)
    A_inv = np.linalg.inv(A)
    I = np.eye(dim)

    # Initial guess A / (||A||_1 * ||A||_inf).
    G0 = A / (np.linalg.norm(A, 1) * np.linalg.norm(A, np.inf))
    max_iterations = 20

    step_functions = {
        2: ns_step_n2, 3: ns_step_n3, 4: ns_step_n4, 5: ns_step_n5,
        6: ns_step_n6, 7: ns_step_n7, 8: ns_step_n8, 9: ns_step_n9,
        10: ns_step_n10, 11: ns_step_n11, 12: ns_step_n12, 13: ns_step_n13,
        14: ns_step_n14, 15: ns_step_n15, 16: ns_step_n16, 17: ns_step_n17,
    }
    orders = list(step_functions)  # 2..17, in insertion order

    norm_A_inv = np.linalg.norm(A_inv, 'fro')
    print(f"Testing Newton-Schulz on {dim}x{dim} SPD matrix...")

    # Orders never interact, so run each order's whole trajectory in turn.
    error_history = {}
    for n in orders:
        step = step_functions[n]
        G = G0.copy()
        errors = []
        for _ in range(max_iterations):
            G = step(G, A, I)
            errors.append(np.linalg.norm(A_inv - G, 'fro') / norm_A_inv)
        error_history[n] = errors

    print("\nFinal relative errors:")
    for n, errors in error_history.items():
        print(f"Order {n}: {errors[-1]:.3e}")

    iterations = np.arange(1, max_iterations + 1)
    plt.figure(figsize=(10, 6))
    for n, errors in error_history.items():
        plt.plot(iterations, errors, marker='o', linestyle='-', label=f'Order {n}')
    plt.yscale('log')
    plt.xlabel('Iteration Number')
    plt.ylabel('Relative Error')
    plt.title(f'Newton-Schulz Convergence ({dim}x{dim} matrix)')
    plt.legend()
    plt.grid(True, which='both', linestyle='--', linewidth=0.5)
    plt.xticks(iterations)
    plt.tight_layout()
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment