@CoryKornowicz · Created December 2, 2025
MLX - Intelligent Matrix Exponentiation
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from mlx.data.datasets import load_mnist
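# Note: this script assumes both the `mlx` and `mlx-data` packages are installed
# (e.g. `pip install mlx mlx-data`); `mlx.data.datasets.load_mnist` is provided by mlx-data.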
@mx.compile
def matrix_exp(M):
    """
    Compute the matrix exponential with a truncated Taylor series.

    This is a simplified approximation suitable for small, well-scaled matrices:
    exp(M) = I + M + M^2/2! + M^3/3! + ...

    Args:
        M: Input matrix of shape (..., n, n)

    Returns:
        Matrix exponential of shape (..., n, n)
    """
    # Identity matrix, broadcast over the batch dimension if present
    if M.ndim == 3:
        batch_size, n, _ = M.shape
        I = mx.eye(n)
        I = mx.broadcast_to(I[None, :, :], (batch_size, n, n))
    else:
        I = mx.eye(M.shape[0])

    # Start the series at I + M, then accumulate higher-order terms
    result = I + M
    M_power = M
    factorial = 1.0
    # Truncate the series at order 20, which is accurate for matrices of modest norm
    for k in range(2, 21):
        factorial *= k
        M_power = M_power @ M
        result = result + M_power / factorial
    return result
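
# Illustrative sanity check (not part of the original gist and never called by the
# training code): for a diagonal matrix the exponential acts element-wise on the
# diagonal, so the truncated Taylor series can be compared against mx.exp.
def _sanity_check_matrix_exp(atol=1e-4):
    D = mx.array([[0.3, 0.0], [0.0, -0.2]])
    approx = matrix_exp(D)
    exact = mx.diag(mx.exp(mx.array([0.3, -0.2])))  # diag(e^0.3, e^-0.2)
    return mx.allclose(approx, exact, atol=atol)
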
class MLayer(nn.Module):
    """
    M-layer (matrix exponentiation layer) as described in "Intelligent Matrix Exponentiation":
    - Input: (B, C, H, W) or (B, in_dim)
    - Output: logits (B, h), plus exp(M) for regularization
    """

    def __init__(self, in_shape=(1, 28, 28), d=35, n=30, h=10, sigma=0.05):
        super().__init__()
        self.in_shape = in_shape
        in_dim = 1
        for s in in_shape:
            in_dim *= s
        self.d = d  # latent dimension
        self.n = n  # matrix size
        self.h = h  # output dimension (# classes)
        # Initialize parameters with proper scaling
        scale = sigma
        # \tilde{U}: input -> latent d
        self.U = nn.Linear(in_dim, d, bias=False)
        # Initialize U's weight
        self.U.weight = mx.random.normal(shape=self.U.weight.shape) * scale
        # \tilde{T}: (d, n, n) - stored as trainable parameters
        self.T = mx.random.normal(shape=(d, n, n)) * scale
        # \tilde{B}: (n, n)
        self.B = mx.random.normal(shape=(n, n)) * scale
        # \tilde{S}: (h, n, n)
        self.S = mx.random.normal(shape=(h, n, n)) * scale
        # \tilde{V}: (h,)
        self.V = mx.random.normal(shape=(h,)) * scale

    def __call__(self, x):
        """
        x: (B, C, H, W) or (B, in_dim)
        returns:
            logits: (B, h)
            expM: (B, n, n)  # for activity regularization
        """
        batch_size = x.shape[0]
        # Flatten if needed
        if x.ndim > 2:
            x = mx.reshape(x, (batch_size, -1))  # (B, in_dim)
        # Latent features Φ = U x -> (B, d)
        phi = self.U(x)  # (B, d)
        # Build M: M_{jk} = B_{jk} + Σ_a φ_a T_{a,j,k}
        # phi: (B, d), T: (d, n, n) -> M: (B, n, n)
        M = self.B + mx.einsum("bd,djk->bjk", phi, self.T)
        # Nonlinearity: matrix exponential applied per sample
        expM = matrix_exp(M)  # (B, n, n)
        # Output logits: V_m + Σ_{j,k} S_{m,j,k} * expM_{b,j,k}
        # S: (h, n, n), expM: (B, n, n) -> logits: (B, h)
        logits = self.V + mx.einsum("mjk,bjk->bm", self.S, expM)
        return logits, expM
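
# Illustrative shape walkthrough (not part of the original gist; the batch size of 4
# is arbitrary): with the defaults, a (B, 1, 28, 28) input is flattened to (B, 784),
# mapped to phi of shape (B, 35), expanded to M and exp(M) of shape (B, 30, 30), and
# finally reduced to logits of shape (B, 10).
def _demo_mlayer_shapes():
    layer = MLayer()
    x = mx.random.normal(shape=(4, 1, 28, 28))
    logits, expM = layer(x)
    return logits.shape, expM.shape  # expected: (4, 10) and (4, 30, 30)
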
class MNet(nn.Module):
    def __init__(self, d=35, n=30, h=10):
        super().__init__()
        self.m_layer = MLayer(in_shape=(1, 28, 28), d=d, n=n, h=h)

    def __call__(self, x):
        # Return logits and expM so the training loop can apply regularization
        return self.m_layer(x)

def train_mnet(
    epochs=10,
    d=35,
    n=30,
    h=10,
    lambda_reg=1e-4,
    lr=1e-3,
    momentum=0.9,
    batch_size=64,
):
    print(f"Using MLX with device: {mx.default_device()}")
    print("Performance notes: compiled matrix exponential, prefetched data streams")

    # Load datasets once
    train_dataset = load_mnist(train=True)
    test_dataset = load_mnist(train=False)

    model = MNet(d=d, n=n, h=h)
    optimizer = optim.SGD(learning_rate=lr, momentum=momentum)

    def loss_fn(model, images, labels):
        logits, expM = model(images)
        loss_ce = nn.losses.cross_entropy(logits, labels, reduction="mean")
        # Activity regularization: mean over the batch of the squared Frobenius norm of exp(M)
        reg_term = mx.mean(mx.sum(expM ** 2, axis=(1, 2)))
        loss = loss_ce + lambda_reg * reg_term
        return loss, (logits, expM)

    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

    for epoch in range(1, epochs + 1):
        # --------- REBUILD TRAIN STREAM EACH EPOCH ----------
        train_loader = (
            train_dataset
            .shuffle()
            .to_stream()
            .key_transform("image", lambda x: x.astype("float32").reshape(-1) / 255.0)
            .batch(batch_size)
            .prefetch(8, 4)
        )

        running_loss = 0.0
        correct = 0
        total = 0
        for batch in train_loader:
            images = mx.array(batch["image"])
            labels = mx.array(batch["label"])
            (loss, (logits, expM)), grads = loss_and_grad_fn(model, images, labels)
            optimizer.update(model, grads)
            # Make sure parameter updates and loss/logits are materialized
            mx.eval(model.parameters(), optimizer.state, loss, logits)
            if mx.isnan(loss).item():
                print("Warning: NaN loss detected; stopping this epoch early.")
                break
            running_loss += loss.item() * labels.shape[0]
            predicted = mx.argmax(logits, axis=1)
            total += labels.shape[0]
            correct += mx.sum(predicted == labels).item()
        train_loss = running_loss / total if total > 0 else 0.0
        train_acc = 100.0 * correct / total if total > 0 else 0.0

        # --------- REBUILD TEST STREAM EACH EPOCH ----------
        test_loader = (
            test_dataset
            .to_stream()
            .key_transform("image", lambda x: x.astype("float32").reshape(-1) / 255.0)
            .batch(batch_size)
            .prefetch(8, 4)
        )

        test_loss = 0.0
        correct = 0
        total = 0
        for batch in test_loader:
            images = mx.array(batch["image"])
            labels = mx.array(batch["label"])
            logits, expM = model(images)
            loss_ce = nn.losses.cross_entropy(logits, labels, reduction="mean")
            reg_term = mx.mean(mx.sum(expM ** 2, axis=(1, 2)))
            loss = loss_ce + lambda_reg * reg_term
            test_loss += loss.item() * labels.shape[0]
            predicted = mx.argmax(logits, axis=1)
            total += labels.shape[0]
            correct += mx.sum(predicted == labels).item()
        test_loss = test_loss / total if total > 0 else 0.0
        test_acc = 100.0 * correct / total if total > 0 else 0.0

        print(
            f"Epoch [{epoch}/{epochs}] "
            f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}% "
            f"| Test Loss: {test_loss:.4f} | Test Acc: {test_acc:.2f}%"
        )

    return model

if __name__ == "__main__":
    # Example run; you can tweak epochs etc.
    train_mnet(epochs=20)