Alexander Stotsky Newton-Schulz inverses
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_spd_matrix

"""
Factorized Newton-Schulz iterations for the inverse of an SPD matrix, from the
paper by Alexander Stotsky, https://arxiv.org/pdf/2208.04068

From the efficiency equation EI = n^(1/m), where EI is the efficiency, n is the
order of the algorithm, and m is the number of matmuls per iteration, the order
n=11 gives the best convergence rate per matmul, with EI = 11^(1/6) ≈ 1.4913.
"""

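# Aside added for illustration (my own matmul tally, not taken from the gist
# or the paper): the efficiency index EI = n**(1/m) for each step order n,
# where m is the number of matmuls used by the corresponding ns_step_n*
# function below. Among the implemented orders, n=11 maximizes EI at ~1.4913.
MATMULS_PER_STEP = {
    2: 2, 3: 3, 4: 4, 5: 5, 6: 5, 7: 6, 8: 6, 9: 6, 10: 6, 11: 6,
    12: 7, 13: 8, 14: 9, 15: 8, 16: 8, 17: 9,
}

def efficiency_index(n):
    """Efficiency index n ** (1 / matmuls) for an order-n step."""
    return n ** (1.0 / MATMULS_PER_STEP[n])
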
# Each step applies one order-n Newton-Schulz update G <- p(F) @ G, where
# F = I - G @ A and p(F) = I + F + F^2 + ... + F^(n-1) is evaluated in
# factorized form to reduce the number of matmuls.
def ns_step_n2(G, A, I):
    F = I - G @ A
    return (I + F) @ G

def ns_step_n3(G, A, I):
    F = I - G @ A
    F2 = F @ F
    return (I + F + F2) @ G

def ns_step_n4(G, A, I):
    F = I - G @ A
    F2 = F @ F
    return (I + F2) @ (I + F) @ G

def ns_step_n5(G, A, I):
    F = I - G @ A
    F2 = F @ F
    P4 = (I + F2) @ (I + F)
    return (I + F @ P4) @ G

def ns_step_n6(G, A, I):
    F = I - G @ A
    F2 = F @ F
    F3 = F2 @ F
    return (I + F3) @ (I + F + F2) @ G

def ns_step_n7(G, A, I):
    F = I - G @ A
    F2 = F @ F
    F3 = F2 @ F
    P6 = (I + F3) @ (I + F + F2)
    return (I + F @ P6) @ G

def ns_step_n8(G, A, I):
    F = I - G @ A
    F2 = F @ F
    F4 = F2 @ F2
    return (I + F4) @ (I + F2) @ (I + F) @ G

def ns_step_n9(G, A, I):
    F = I - G @ A
    F2 = F @ F
    F3 = F2 @ F
    F6 = F3 @ F3
    return (I + F3 + F6) @ (I + F + F2) @ G

def ns_step_n10(G, A, I):
    F = I - G @ A
    F2 = F @ F
    F4 = F2 @ F2
    # I + F2 + F4 + F6 + F8 = I + (F2 + F4) @ (I + F4)
    return (I + (F2 + F4) @ (I + F4)) @ (I + F) @ G

def ns_step_n11(G, A, I):
    F = I - G @ A
    F2 = F @ F
    F4 = F2 @ F2
    # term_c = F2 + F4 + F6 + F8, so the result is (I + F + ... + F^10) @ G
    term_c = (F2 + F4) @ (I + F4)
    return (I + (I + term_c) @ (F + F2)) @ G

def ns_step_n12(G, A, I):
    F = I - G @ A
    F2 = F @ F
    F3 = F2 @ F
    F4 = F2 @ F2
    F8 = F4 @ F4
    return (I + F4 + F8) @ (I + F + F2 + F3) @ G

def ns_step_n13(G, A, I):
    F = I - G @ A
    F2 = F @ F
    F3 = F2 @ F
    F4 = F2 @ F2
    F8 = F4 @ F4
    P12 = (I + F4 + F8) @ (I + F + F2 + F3)
    return (I + F @ P12) @ G

def ns_step_n14(G, A, I):
    F = I - G @ A
    F2 = F @ F
    F3 = F2 @ F
    F4 = F2 @ F2
    F5 = F4 @ F
    F6 = F3 @ F3
    F7 = F6 @ F
    return (I + F7) @ (I + F + F2 + F3 + F4 + F5 + F6) @ G

def ns_step_n15(G, A, I):
    F = I - G @ A
    F2 = F @ F
    F3 = F2 @ F
    F4 = F2 @ F2
    F5 = F4 @ F
    F10 = F5 @ F5
    return (I + F5 + F10) @ (I + F + F2 + F3 + F4) @ G

def ns_step_n16(G, A, I):
    F = I - G @ A
    F2 = F @ F
    F4 = F2 @ F2
    F8 = F4 @ F4
    return (I + F8) @ (I + F4) @ (I + F2) @ (I + F) @ G

def ns_step_n17(G, A, I):
    F = I - G @ A
    F2 = F @ F
    F4 = F2 @ F2
    F8 = F4 @ F4
    P16 = (I + F8) @ (I + F4) @ (I + F2) @ (I + F)
    return (I + F @ P16) @ G

if __name__ == "__main__":
    dim = 64
    A = make_spd_matrix(dim, random_state=1)
    A_inv = np.linalg.inv(A)
    I = np.eye(dim)

    # Classic initialization: for SPD A the eigenvalues of G0 @ A lie in (0, 1],
    # so the spectral radius of I - G0 @ A is below 1 and every order converges.
    norm_A_1 = np.linalg.norm(A, 1)
    norm_A_inf = np.linalg.norm(A, np.inf)
    G0 = A / (norm_A_1 * norm_A_inf)

    max_iterations = 20
    orders = [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17]
    step_functions = {
        2: ns_step_n2, 3: ns_step_n3, 4: ns_step_n4, 5: ns_step_n5,
        6: ns_step_n6, 7: ns_step_n7, 8: ns_step_n8, 9: ns_step_n9,
        10: ns_step_n10, 11: ns_step_n11, 12: ns_step_n12, 13: ns_step_n13,
        14: ns_step_n14, 15: ns_step_n15, 16: ns_step_n16, 17: ns_step_n17,
    }

    results = {n: {'G': G0.copy(), 'errors': []} for n in orders}
    norm_A_inv = np.linalg.norm(A_inv, 'fro')

    print(f"Testing Newton-Schulz on {dim}x{dim} SPD matrix...")
    for k in range(max_iterations):
        for n in orders:
            G_current = results[n]['G']
            G_next = step_functions[n](G_current, A, I)
            results[n]['G'] = G_next
            # Relative error against the reference inverse from np.linalg.inv.
            error = np.linalg.norm(A_inv - G_next, 'fro') / norm_A_inv
            results[n]['errors'].append(error)

    print("\nFinal relative errors:")
    for n in orders:
        print(f"Order {n}: {results[n]['errors'][-1]:.3e}")

    plt.figure(figsize=(10, 6))
    iterations = np.arange(1, max_iterations + 1)
    for n in orders:
        plt.plot(iterations, results[n]['errors'], marker='o', linestyle='-', label=f'Order {n}')
    plt.yscale('log')
    plt.xlabel('Iteration Number')
    plt.ylabel('Relative Error')
    plt.title(f'Newton-Schulz Convergence ({dim}x{dim} matrix)')
    plt.legend()
    plt.grid(True, which='both', linestyle='--', linewidth=0.5)
    plt.xticks(iterations)
    plt.tight_layout()
    plt.show()
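
    # Optional check added for illustration (not part of the original gist):
    # in practice A_inv is unknown, so convergence can also be monitored via
    # the residual ||I - G @ A||_F, which needs no reference inverse.
    for n in (2, 11, 17):
        residual = np.linalg.norm(I - results[n]['G'] @ A, 'fro')
        print(f"Order {n} residual ||I - G A||_F: {residual:.3e}")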