@yberreby
Last active October 21, 2025 22:41
An MLX-based foray into muP/µP (maximal update parameterization; cf. Tensor Programs V)
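The script sweeps learning rates for a small residual MLP on a 1-D sine-regression task, at several widths, under both standard parametrization and μP (per-layer init and learning-rate scaling plus 1/√depth residual scaling). The expectation, following Tensor Programs V, is that the best learning rate stays roughly constant across widths under μP, while it drifts under standard parametrization. Results are written to Parquet and plotted as final loss vs. learning rate for each width.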
#!/usr/bin/env -S uv run
# /// script
# requires-python = ">=3.11"
# dependencies = [
# "matplotlib",
# "polars>=1.34.0",
# "tqdm>=4.67.1",
# ]
# ///
#
# Install mlx separately: use `mlx[cuda]` on Linux for GPU support,
# or plain `mlx` on macOS, where GPU acceleration (Metal) is built in.
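#
# Example invocations (the file name is arbitrary; `mup_mlx.py` below is just a placeholder):
#   uv run mup_mlx.py                                   # full sweep with the defaults below
#   uv run mup_mlx.py --max-width 512 --n-steps 200     # smaller, faster sweep
#   uv run mup_mlx.py --replot                          # re-plot the most recent saved .parquet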
"""μP (Maximal Update Parametrization) learning rate transfer experiment in MLX."""
import glob
import json
import math
import argparse
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Optional
import matplotlib.pyplot as plt
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from mlx.optimizers import MultiOptimizer
from mlx.utils import tree_flatten
import numpy as np
import numpy.typing as npt
import polars as pl
from tqdm import tqdm
##
# Data
##
def make_batch(n: int, d: int = 1) -> tuple[mx.array, mx.array]:
"""Generate sine regression data: y = sin(2πx) + noise"""
X = mx.random.uniform(-1, 1, (n, d))
y = mx.sin(2 * np.pi * X) + 0.1 * mx.random.normal((n, 1))
return X, y
##
# Configuration
##
@dataclass
class MuPConfig:
"""μP scaling configuration"""
width: int
depth: int
d_in: int
@property
def init_in(self) -> float:
"""Input layer init std: 1/√d_in"""
return 1.0 / math.sqrt(self.d_in)
@property
def init_hidden(self) -> float:
"""Hidden layer init std: 1/√width"""
return 1.0 / math.sqrt(self.width)
@property
def init_out(self) -> float:
"""Output layer init std: 1/width"""
return 1.0 / self.width
@property
def residual_scale(self) -> float:
"""Residual scaling: 1/√depth"""
return 1.0 / math.sqrt(self.depth)
@property
def lr_scale_in(self) -> float:
"""Input layer LR scaling: no width dependence"""
return 1.0
@property
def lr_scale_hidden(self) -> float:
"""Hidden layer LR scaling: 1/width"""
return 1.0 / self.width
@property
def lr_scale_out(self) -> float:
"""Output layer LR scaling: 1/width"""
return 1.0 / self.width
@property
def lr_scale_bias(self) -> float:
"""Bias LR scaling: no width dependence"""
return 1.0
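# Summary of the μP scaling implemented above (Adam-style optimizer, width = n):
#
#   layer             init std       LR multiplier
#   input  (d_in→n)   1/sqrt(d_in)   1
#   hidden (n→n)      1/sqrt(n)      1/n
#   output (n→1)      1/n            1/n
#   biases            0              1
#
# plus a 1/sqrt(depth) factor on each residual branch, which keeps the scale of the
# residual stream roughly depth-independent.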
@dataclass
class ExperimentConfig:
"""Experiment configuration"""
width: int
depth: int
d_in: int
batch_size: int
n_steps: int
n_seeds: int
##
# Parameter Filtering (for MultiOptimizer)
##
def _path_to_str(path: Any) -> str:
return (
"/".join(str(p) for p in path) if isinstance(path, (tuple, list)) else str(path)
)
def is_bias(path: Any, _: Any) -> bool:
return "bias" in _path_to_str(path)
def is_in_layer_weight(path: Any, _: Any) -> bool:
path_str = _path_to_str(path)
return "in_layer" in path_str and "weight" in path_str
def is_out_layer_weight(path: Any, _: Any) -> bool:
path_str = _path_to_str(path)
return "out_layer" in path_str and "weight" in path_str
##
# Model
##
class ResNet(nn.Module):
"""Simple residual MLP with optional μP initialization"""
def __init__(
self,
d_in: int,
width: int,
depth: int,
init_std_in: Optional[float],
init_std_hidden: Optional[float],
init_std_out: Optional[float],
use_layernorm: bool,
):
super().__init__()
self.in_layer = nn.Linear(d_in, width)
self.hidden_layers = [nn.Linear(width, width) for _ in range(depth)]
self.out_layer = nn.Linear(width, 1)
# Use Identity or LayerNorm based on config
self.norms = [
nn.LayerNorm(width) if use_layernorm else nn.Identity()
for _ in range(depth)
]
if init_std_in:
self._init_layer(self.in_layer, init_std_in)
if init_std_hidden:
for layer in self.hidden_layers:
self._init_layer(layer, init_std_hidden)
if init_std_out:
self._init_layer(self.out_layer, init_std_out)
def _init_layer(self, layer: nn.Linear, std: float) -> None:
layer.weight = mx.random.normal(layer.weight.shape, scale=std)
layer.bias = mx.zeros_like(layer.bias)
def __call__(self, x: mx.array, residual_scale: float) -> mx.array:
h = nn.relu(self.in_layer(x))
for layer, norm in zip(self.hidden_layers, self.norms):
h = h + residual_scale * nn.relu(layer(norm(h)))
return self.out_layer(h)
##
# Optimizer
##
def verify_parameter_filters(model: ResNet) -> None:
"""Sanity check: print which parameters match which filters"""
params = tree_flatten(model.parameters())
print("Parameter filter sanity check:")
for path, param in params:
path_str = _path_to_str(path)
categories = []
if is_bias(path, param):
categories.append("bias")
if is_in_layer_weight(path, param):
categories.append("in_layer")
if is_out_layer_weight(path, param):
categories.append("out_layer")
if not categories:
categories.append("hidden_layer")
print(f" {path_str:30s} → {', '.join(categories)}")
def create_optimizer(
base_lr: float, mup_config: Optional[MuPConfig] = None
) -> optim.Optimizer:
opt = optim.Adam
if mup_config is None:
return opt(learning_rate=base_lr)
# μP: per-layer learning rates
lr_in = base_lr * mup_config.lr_scale_in
lr_hidden = base_lr * mup_config.lr_scale_hidden
lr_out = base_lr * mup_config.lr_scale_out
lr_bias = base_lr * mup_config.lr_scale_bias
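    # MultiOptimizer matches parameters against `filters` in order; anything that
    # matches no filter (here: hidden-layer and LayerNorm weights) falls through to
    # the last optimizer in the list, which therefore runs at lr_hidden.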
return MultiOptimizer(
optimizers=[
opt(learning_rate=lr_bias),
opt(learning_rate=lr_in),
opt(learning_rate=lr_out),
opt(learning_rate=lr_hidden),
],
filters=[is_bias, is_in_layer_weight, is_out_layer_weight],
)
##
# Training
##
def train_one_config(
base_lr: float,
batch_size: int,
model: ResNet,
residual_scale: float,
n_steps: int,
mup_config: Optional[MuPConfig],
seed: Optional[int],
) -> tuple[mx.array, mx.array]:
if seed is not None:
mx.random.seed(seed)
d_in = model.in_layer.weight.shape[1]
def loss_fn(model: ResNet, X: mx.array, y: mx.array) -> mx.array:
return mx.mean((model(X, residual_scale) - y) ** 2)
loss_and_grad = nn.value_and_grad(model, loss_fn)
optimizer = create_optimizer(base_lr, mup_config)
# Measure init loss
X_init, y_init = make_batch(batch_size, d_in)
init_loss = loss_fn(model, X_init, y_init)
    # Early-exit criterion: treat a loss above 2x the initial loss as divergence
divergence_threshold = 2.0 * init_loss.item()
# Training loop (infinite data regime)
for _ in tqdm(range(n_steps), desc=f"LR={base_lr:.0e}", leave=False):
X, y = make_batch(batch_size, d_in)
loss, grads = loss_and_grad(model, X, y)
optimizer.update(model, grads)
if loss.item() > divergence_threshold:
return init_loss, loss
X_final, y_final = make_batch(batch_size, d_in)
final_loss = loss_fn(model, X_final, y_final)
return init_loss, final_loss
def sweep_learning_rates(
cfg: ExperimentConfig,
lr_grid: npt.NDArray[np.floating[Any]],
mup_config: Optional[MuPConfig],
) -> tuple[npt.NDArray[np.floating[Any]], npt.NDArray[np.floating[Any]]]:
"""Returns (init_losses, final_losses) with shape (n_seeds, n_lrs)"""
results_init = np.zeros((cfg.n_seeds, len(lr_grid)))
results_final = np.zeros((cfg.n_seeds, len(lr_grid)))
if mup_config is None:
init_in = init_hidden = init_out = None
residual_scale = 1.0
else:
init_in = mup_config.init_in
init_hidden = mup_config.init_hidden
init_out = mup_config.init_out
residual_scale = mup_config.residual_scale
for lr_idx, lr in tqdm(
list(enumerate(lr_grid)), desc=f" LRs (w={cfg.width})", leave=False
):
for seed in tqdm(
range(cfg.n_seeds), desc=f" Seeds (lr={lr:.0e})", leave=False
):
model = ResNet(
cfg.d_in,
cfg.width,
cfg.depth,
init_in,
init_hidden,
init_out,
use_layernorm=True,
)
init_loss, final_loss = train_one_config(
float(lr),
cfg.batch_size,
model,
residual_scale,
cfg.n_steps,
mup_config,
seed,
)
results_init[seed, lr_idx] = init_loss.item()
results_final[seed, lr_idx] = final_loss.item()
        # Early exit: if the seed-averaged loss diverged at this LR, assume all higher LRs diverge too
avg_init = np.mean(results_init[:, lr_idx])
avg_final = np.mean(results_final[:, lr_idx])
if avg_final > 2.0 * avg_init:
# Mark remaining LRs as diverged (copy current diverged values)
for remaining_lr_idx in range(lr_idx + 1, len(lr_grid)):
results_init[:, remaining_lr_idx] = results_init[:, lr_idx]
results_final[:, remaining_lr_idx] = results_final[:, lr_idx]
break
return results_init, results_final
##
# Persistence
##
def save_results(
results_default: dict[int, tuple[npt.NDArray[Any], npt.NDArray[Any]]],
results_mup: dict[int, tuple[npt.NDArray[Any], npt.NDArray[Any]]],
widths: list[int],
lr_grid: npt.NDArray[np.floating[Any]],
base_lr_grid: npt.NDArray[np.floating[Any]],
metadata: dict[str, Any],
prefix: str,
) -> None:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
save_path = f"{prefix}_{timestamp}.parquet"
records = []
for width in widths:
init_def, final_def = results_default[width]
init_mup, final_mup = results_mup[width]
n_seeds, n_lrs = init_def.shape
for seed in range(n_seeds):
for lr_idx in range(n_lrs):
records.append(
{
"width": width,
"seed": seed,
"lr_idx": lr_idx,
"lr": lr_grid[lr_idx],
"base_lr": base_lr_grid[lr_idx],
"init_loss_default": init_def[seed, lr_idx],
"final_loss_default": final_def[seed, lr_idx],
"init_loss_mup": init_mup[seed, lr_idx],
"final_loss_mup": final_mup[seed, lr_idx],
}
)
df = pl.DataFrame(records)
    # Append a sentinel row (width = -1) for backward compatibility; load_results() filters it out
metadata_row = {
"width": -1,
"seed": -1,
"lr_idx": -1,
"lr": 0.0,
"base_lr": 0.0,
"init_loss_default": 0.0,
"final_loss_default": 0.0,
"init_loss_mup": 0.0,
"final_loss_mup": 0.0,
}
df_with_meta = pl.concat([df, pl.DataFrame([metadata_row])])
# Store metadata as JSON column on all rows
metadata_json = json.dumps(metadata)
df_with_meta = df_with_meta.with_columns(pl.lit(metadata_json).alias("_metadata"))
df_with_meta.write_parquet(save_path)
print(f"Saved results to {save_path}")
def load_results(
load_path: str,
) -> tuple[
dict[int, tuple[npt.NDArray[Any], npt.NDArray[Any]]],
dict[int, tuple[npt.NDArray[Any], npt.NDArray[Any]]],
list[int],
npt.NDArray[np.floating[Any]],
npt.NDArray[np.floating[Any]],
dict[str, Any],
]:
df = pl.read_parquet(load_path)
# Extract metadata from JSON column
metadata = json.loads(df["_metadata"][0])
# Filter out metadata row
df = df.filter(pl.col("width") != -1)
widths = sorted(df["width"].unique().to_list())
n_lrs = df["lr_idx"].max() + 1
    # Recover the LR grids; sort explicitly, since polars' unique() does not guarantee order
    lr_grid = df["lr"].unique().sort().to_numpy()
    base_lr_grid = df["base_lr"].unique().sort().to_numpy()
results_default = {}
results_mup = {}
loss_cols = [
"init_loss_default",
"final_loss_default",
"init_loss_mup",
"final_loss_mup",
]
for width in widths:
width_df = df.filter(pl.col("width") == width).sort(["seed", "lr_idx"])
losses = [
width_df[col].to_numpy().reshape(metadata["n_seeds"], n_lrs)
for col in loss_cols
]
results_default[width] = (losses[0], losses[1])
results_mup[width] = (losses[2], losses[3])
return results_default, results_mup, widths, lr_grid, base_lr_grid, metadata
##
# Visualization
##
def _plot_method(
ax: plt.Axes,
results: dict[int, tuple[npt.NDArray[Any], npt.NDArray[Any]]],
widths: list[int],
lr_grid: npt.NDArray[np.floating[Any]],
n_seeds: int,
colors: npt.NDArray[Any],
) -> npt.NDArray[Any]:
"""Plot results for one method (standard or μP)"""
style_kwargs = {"linewidth": 2}
fill_kwargs = {"alpha": 0.2}
for i, width in enumerate(widths):
init, final = results[width]
mean = np.mean(final, axis=0)
se = np.std(final, axis=0) / np.sqrt(n_seeds)
color = colors[i]
ax.plot(lr_grid, mean, label=f"w={width}", color=color, **style_kwargs) # type: ignore
ax.fill_between(lr_grid, mean - se, mean + se, color=color, **fill_kwargs) # type: ignore
# Return first init for reference line
return results[widths[0]][0]
def plot_comparison(
results_default: dict[int, tuple[npt.NDArray[Any], npt.NDArray[Any]]],
results_mup: dict[int, tuple[npt.NDArray[Any], npt.NDArray[Any]]],
widths: list[int],
lr_grid: npt.NDArray[np.floating[Any]],
n_seeds: int,
output_path: str,
) -> None:
# Styling constants
FIGSIZE = (14, 5)
COLORMAP_START = 0.2
COLORMAP_END = 0.9
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=FIGSIZE)
colors = plt.cm.viridis(np.linspace(COLORMAP_START, COLORMAP_END, len(widths))) # type: ignore
# Compute shared axis limits
all_losses = [
results[w][1] for results in [results_default, results_mup] for w in widths
]
ylim_min = min(np.min(losses) for losses in all_losses)
# Max is the highest mean init loss across all μP widths
mup_init_losses = [results_mup[w][0] for w in widths]
ylim_max = max(np.mean(losses) for losses in mup_init_losses)
# Plot both methods
init_def = _plot_method(ax1, results_default, widths, lr_grid, n_seeds, colors)
init_mup = _plot_method(ax2, results_mup, widths, lr_grid, n_seeds, colors)
# Add init loss reference lines
ref_kwargs = {
"linestyle": "--",
"color": "r",
"alpha": 0.5,
"linewidth": 2,
"label": "Init loss",
}
for ax, init_losses in [(ax1, init_def), (ax2, init_mup)]:
mean_init = np.mean(init_losses)
ax.plot([lr_grid[0], lr_grid[-1]], [mean_init, mean_init], **ref_kwargs)
# Configure axes
for ax in [ax1, ax2]:
ax.set_ylabel("Final Loss")
ax.set_xlabel("Learning Rate")
ax.set_xscale("log")
ax.set_xlim(lr_grid[0], lr_grid[-1])
ax.set_ylim(ylim_min, ylim_max)
ax.legend()
ax.grid(alpha=0.3)
ax1.set_title("Standard Parametrization")
ax2.set_title("μP")
plt.tight_layout()
plt.savefig(output_path, dpi=150, bbox_inches="tight")
print(f"\nSaved plot to {output_path}")
##
# Main
##
def run_experiment(args: argparse.Namespace) -> None:
    # Generate width range (powers of 2), largest first so the slowest runs start first
widths = [
args.min_width * (2**i)
for i in range(int(np.log2(args.max_width / args.min_width)) + 1)
]
widths.reverse()
# Generate LR grid with batch size scaling (Goyal et al. linear scaling rule)
# base_lr is the learning rate for reference batch size 1
# actual lr scales linearly with batch size
reference_batch_size = 1
base_lr_grid = np.logspace(np.log10(args.min_lr), np.log10(args.max_lr), args.n_lrs)
lr_scale = args.batch_size / reference_batch_size
lr_grid = base_lr_grid * lr_scale
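    # Worked example with the defaults: batch_size=256 gives lr_scale=256, so a base LR
    # of 5e-11 maps to an actual LR of 5e-11 * 256 ≈ 1.3e-8, and 2e-3 maps to ≈ 0.51.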
print("Running μP experiment:")
print(f" Widths: {widths}")
print(f" Depth: {args.depth}")
print(f" Batch size: {args.batch_size} (LR scaled by {lr_scale:.3f}x)")
print(f" Input dim: {args.d_in}")
print(f" LR range: [{lr_grid[0]:.0e}, {lr_grid[-1]:.0e}] ({args.n_lrs} points)")
print(f" Steps: {args.n_steps}, Seeds: {args.n_seeds}")
# Print parameter counts
print("\n Parameter counts:")
models = []
for width in widths:
model = ResNet(args.d_in, width, args.depth, None, None, None, True)
models.append(model)
n_params = sum(v.size for _, v in tree_flatten(model.parameters())) # type: ignore
print(f" w={width}: {n_params:,} params")
# Verify parameter filters on first model
print()
verify_parameter_filters(models[0])
print()
# Run experiments
results_default: dict[int, tuple[npt.NDArray[Any], npt.NDArray[Any]]] = {}
results_mup: dict[int, tuple[npt.NDArray[Any], npt.NDArray[Any]]] = {}
for width in tqdm(widths, desc="Widths"):
cfg = ExperimentConfig(
width,
args.depth,
args.d_in,
args.batch_size,
args.n_steps,
args.n_seeds,
)
results_default[width] = sweep_learning_rates(cfg, lr_grid, None)
results_mup[width] = sweep_learning_rates(
cfg, lr_grid, MuPConfig(width, args.depth, args.d_in)
)
# Save and plot
metadata = {
"depth": args.depth,
"batch_size": args.batch_size,
"n_steps": args.n_steps,
"n_seeds": args.n_seeds,
"d_in": args.d_in,
}
save_results(
results_default,
results_mup,
widths,
lr_grid,
base_lr_grid,
metadata,
args.results_prefix,
)
plot_comparison(
results_default, results_mup, widths, lr_grid, args.n_seeds, args.plot_file
)
def main() -> None:
parser = argparse.ArgumentParser(description="μP LR transfer experiment in MLX")
parser.add_argument("--depth", type=int, default=2)
parser.add_argument("--n-steps", type=int, default=400)
parser.add_argument("--n-seeds", type=int, default=4)
parser.add_argument("--min-lr", type=float, default=5e-11)
parser.add_argument("--max-lr", type=float, default=2e-3)
parser.add_argument("--n-lrs", type=int, default=32)
parser.add_argument("--batch-size", type=int, default=256)
parser.add_argument("--d-in", type=int, default=1)
parser.add_argument("--min-width", type=int, default=128)
parser.add_argument("--max-width", type=int, default=1024)
parser.add_argument("--replot", action="store_true")
parser.add_argument("--results-prefix", type=str, default="mlx_mup_results")
parser.add_argument("--plot-file", type=str, default="mlx_mup.png")
args = parser.parse_args()
if args.replot:
pattern = f"{args.results_prefix}_*.parquet"
files = sorted(glob.glob(pattern))
if not files:
print(f"No results files found matching {pattern}")
return
load_path = files[-1]
print(f"Loading results from {load_path}...")
results_default, results_mup, widths, lr_grid, _, metadata = load_results(
load_path
)
print(f"Loaded results for widths: {widths}")
print(
f" Config: depth={metadata['depth']}, batch={metadata['batch_size']}, "
f"steps={metadata['n_steps']}, seeds={metadata['n_seeds']}"
)
plot_comparison(
results_default,
results_mup,
widths,
lr_grid,
metadata["n_seeds"],
args.plot_file,
)
else:
run_experiment(args)
if __name__ == "__main__":
main()