Created December 2, 2025 10:43
MLX - Intelligent Matrix Exponentiation
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from mlx.data.datasets import load_mnist
@mx.compile
def matrix_exp(M):
    """
    Compute the matrix exponential with a truncated Taylor series:
        exp(M) = I + M + M^2/2! + M^3/3! + ...
    This simplified approach is adequate for the small matrices used here;
    more robust implementations use scaling-and-squaring with a Padé
    approximant to stay accurate for matrices with large norm.
    Args:
        M: Input matrix of shape (..., n, n)
    Returns:
        Matrix exponential of shape (..., n, n)
    """
    # Identity matrix, broadcast over the batch dimension if present
    if M.ndim == 3:
        batch_size, n, _ = M.shape
        I = mx.eye(n)
        I = mx.broadcast_to(I[None, :, :], (batch_size, n, n))
    else:
        I = mx.eye(M.shape[0])

    # Initialize the result with I + M, then accumulate higher-order terms
    result = I + M
    M_power = M
    factorial = 1.0

    # Add terms up to order 20 for good accuracy
    for k in range(2, 21):
        factorial *= k
        M_power = M_power @ M
        result = result + M_power / factorial

    return result
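# Illustrative sanity check (a sketch added for clarity, not part of the
# original training script): for a diagonal matrix the exponential is just the
# element-wise exp of the diagonal, which gives an easy reference value for the
# truncated Taylor series above.
def _check_matrix_exp():
    d = mx.array([0.5, -1.0])
    approx = matrix_exp(mx.diag(d))   # truncated series
    exact = mx.diag(mx.exp(d))        # exact exponential of a diagonal matrix
    return mx.allclose(approx, exact).item()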
class MLayer(nn.Module):
    """
    M-layer as described:
    - Input: (B, C, H, W) or (B, in_dim)
    - Output: logits (B, h), plus exp(M) for regularization
    """
    def __init__(self, in_shape=(1, 28, 28), d=35, n=30, h=10, sigma=0.05):
        super().__init__()
        self.in_shape = in_shape
        in_dim = 1
        for s in in_shape:
            in_dim *= s
        self.d = d  # latent dimension
        self.n = n  # matrix size
        self.h = h  # output dimension (# classes)

        # Initialize parameters with proper scaling
        scale = sigma

        # \tilde{U}: input -> latent d
        self.U = nn.Linear(in_dim, d, bias=False)
        # Initialize U's weight
        self.U.weight = mx.random.normal(shape=self.U.weight.shape) * scale

        # \tilde{T}: (d, n, n) - stored as trainable parameters
        self.T = mx.random.normal(shape=(d, n, n)) * scale
        # \tilde{B}: (n, n)
        self.B = mx.random.normal(shape=(n, n)) * scale
        # \tilde{S}: (h, n, n)
        self.S = mx.random.normal(shape=(h, n, n)) * scale
        # \tilde{V}: (h,)
        self.V = mx.random.normal(shape=(h,)) * scale

    def __call__(self, x):
        """
        x: (B, C, H, W) or (B, in_dim)
        returns:
            logits: (B, h)
            expM: (B, n, n)  # for activity regularization
        """
        batch_size = x.shape[0]

        # Flatten if needed
        if x.ndim > 2:
            x = mx.reshape(x, (batch_size, -1))  # (B, in_dim)

        # Latent features Φ = U x -> (B, d)
        phi = self.U(x)  # (B, d)

        # Build M: M_{b,j,k} = B_{jk} + Σ_a φ_{b,a} T_{a,j,k}
        # phi: (B, d), T: (d, n, n) -> M: (B, n, n)
        M = self.B + mx.einsum("bd,djk->bjk", phi, self.T)

        # Nonlinearity: matrix exponential applied to each sample's matrix
        expM = matrix_exp(M)  # (B, n, n)

        # Output logits: logits_{b,m} = V_m + Σ_{j,k} S_{m,j,k} expM_{b,j,k}
        # S: (h, n, n), expM: (B, n, n) -> logits: (B, h)
        logits = self.V + mx.einsum("mjk,bjk->bm", self.S, expM)

        return logits, expM
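# Shape walk-through (added for illustration, not in the original gist): with
# the defaults d=35, n=30, h=10 and flattened MNIST input,
#     x: (B, 784) -> phi: (B, 35) -> M, expM: (B, 30, 30) -> logits: (B, 10)
# e.g.
#     layer = MLayer()
#     logits, expM = layer(mx.zeros((4, 784)))   # logits: (4, 10), expM: (4, 30, 30)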
class MNet(nn.Module):
    def __init__(self, d=35, n=30, h=10):
        super().__init__()
        self.m_layer = MLayer(in_shape=(1, 28, 28), d=d, n=n, h=h)

    def __call__(self, x):
        # Return logits and expM so the training loop can apply regularization
        return self.m_layer(x)
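# Optional parameter count (an illustrative sketch added here; `tree_flatten`
# comes from mlx.utils, the rest assumes the default sizes above):
#     from mlx.utils import tree_flatten
#     n_params = sum(p.size for _, p in tree_flatten(MNet().parameters()))
#     print(f"MNet has {n_params:,} trainable parameters")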
def train_mnet(
    epochs=10,
    d=35,
    n=30,
    h=10,
    lambda_reg=1e-4,
    lr=1e-3,
    momentum=0.9,
    batch_size=64,
):
    print(f"Using MLX default device: {mx.default_device()}")
    print("Performance note: the matrix exponential is compiled with mx.compile")

    # Load datasets once
    train_dataset = load_mnist(train=True)
    test_dataset = load_mnist(train=False)

    model = MNet(d=d, n=n, h=h)
    optimizer = optim.SGD(learning_rate=lr, momentum=momentum)

    def loss_fn(model, images, labels):
        logits, expM = model(images)
        loss_ce = nn.losses.cross_entropy(logits, labels, reduction="mean")
        # Activity regularization: mean over the batch of ||exp(M)||_F^2
        reg_term = mx.mean(mx.sum(expM ** 2, axis=(1, 2)))
        loss = loss_ce + lambda_reg * reg_term
        return loss, (logits, expM)

    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

    for epoch in range(1, epochs + 1):
        # --------- REBUILD TRAIN STREAM EACH EPOCH ----------
        train_loader = (
            train_dataset
            .shuffle()
            .to_stream()
            .key_transform("image", lambda x: x.astype("float32").reshape(-1) / 255.0)
            .batch(batch_size)
            .prefetch(8, 4)
        )

        running_loss = 0.0
        correct = 0
        total = 0

        for batch in train_loader:
            images = mx.array(batch["image"])
            labels = mx.array(batch["label"])

            (loss, (logits, expM)), grads = loss_and_grad_fn(model, images, labels)
            optimizer.update(model, grads)

            # Make sure parameter updates and loss/logits are materialized
            mx.eval(model.parameters(), optimizer.state, loss, logits)

            if mx.isnan(loss).item():
                print("Warning: NaN loss detected; stopping this epoch early.")
                break

            running_loss += loss.item() * labels.shape[0]
            predicted = mx.argmax(logits, axis=1)
            total += labels.shape[0]
            correct += mx.sum(predicted == labels).item()

        train_loss = running_loss / total if total > 0 else 0.0
        train_acc = 100.0 * correct / total if total > 0 else 0.0

        # --------- REBUILD TEST STREAM EACH EPOCH ----------
        test_loader = (
            test_dataset
            .to_stream()
            .key_transform("image", lambda x: x.astype("float32").reshape(-1) / 255.0)
            .batch(batch_size)
            .prefetch(8, 4)
        )

        test_loss = 0.0
        correct = 0
        total = 0

        for batch in test_loader:
            images = mx.array(batch["image"])
            labels = mx.array(batch["label"])

            logits, expM = model(images)
            loss_ce = nn.losses.cross_entropy(logits, labels, reduction="mean")
            reg_term = mx.mean(mx.sum(expM ** 2, axis=(1, 2)))
            loss = loss_ce + lambda_reg * reg_term

            test_loss += loss.item() * labels.shape[0]
            predicted = mx.argmax(logits, axis=1)
            total += labels.shape[0]
            correct += mx.sum(predicted == labels).item()

        test_loss = test_loss / total if total > 0 else 0.0
        test_acc = 100.0 * correct / total if total > 0 else 0.0

        print(
            f"Epoch [{epoch}/{epochs}] "
            f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}% "
            f"| Test Loss: {test_loss:.4f} | Test Acc: {test_acc:.2f}%"
        )

    return model
| if __name__ == "__main__": | |
| # Example run; you can tweak epochs etc. | |
| train_mnet(epochs=20) |