An MLX-based foray into muP/µP (maximal update parameterization; cf. Tensor Programs V)
#!/usr/bin/env -S uv run
# /// script
# requires-python = ">=3.11"
# dependencies = [
#     "matplotlib",
#     "polars>=1.34.0",
#     "tqdm>=4.67.1",
# ]
# ///
#
# Install mlx separately - can use `mlx[cuda]` on Linux for GPU support.
# Just use `mlx` on macOS for GPU acceleration.
"""μP (Maximal Update Parametrization) learning rate transfer experiment in MLX."""

import glob
import json
import math
import argparse
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Optional

import matplotlib.pyplot as plt
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from mlx.optimizers import MultiOptimizer
from mlx.utils import tree_flatten
import numpy as np
import numpy.typing as npt
import polars as pl
from tqdm import tqdm


##
# Data
##
def make_batch(n: int, d: int = 1) -> tuple[mx.array, mx.array]:
    """Generate sine regression data: y = sin(2πx) + noise"""
    X = mx.random.uniform(-1, 1, (n, d))
    y = mx.sin(2 * np.pi * X) + 0.1 * mx.random.normal((n, 1))
    return X, y
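
# Shape note (this just describes the code above): X is (n, d); the sin term is
# (n, d) and broadcasts against the (n, 1) noise, so y is (n, d). With the
# default d=1, y is (n, 1) and matches the model's scalar output head.
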

##
# Configuration
##
@dataclass
class MuPConfig:
    """μP scaling configuration"""

    width: int
    depth: int
    d_in: int

    @property
    def init_in(self) -> float:
        """Input layer init std: 1/√d_in"""
        return 1.0 / math.sqrt(self.d_in)

    @property
    def init_hidden(self) -> float:
        """Hidden layer init std: 1/√width"""
        return 1.0 / math.sqrt(self.width)

    @property
    def init_out(self) -> float:
        """Output layer init std: 1/width"""
        return 1.0 / self.width

    @property
    def residual_scale(self) -> float:
        """Residual scaling: 1/√depth"""
        return 1.0 / math.sqrt(self.depth)

    @property
    def lr_scale_in(self) -> float:
        """Input layer LR scaling: no width dependence"""
        return 1.0

    @property
    def lr_scale_hidden(self) -> float:
        """Hidden layer LR scaling: 1/width"""
        return 1.0 / self.width

    @property
    def lr_scale_out(self) -> float:
        """Output layer LR scaling: 1/width"""
        return 1.0 / self.width

    @property
    def lr_scale_bias(self) -> float:
        """Bias LR scaling: no width dependence"""
        return 1.0
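
# Quick reference for the μP choices used in this script (restating MuPConfig
# above and ResNet._init_layer below; a summary, not additional rules):
#
#   layer type       init std           Adam LR multiplier
#   input weights    1 / sqrt(d_in)     1
#   hidden weights   1 / sqrt(width)    1 / width
#   output weights   1 / width          1 / width
#   biases           0 (zeros)          1
#
# plus a 1 / sqrt(depth) scale on every residual branch.
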

@dataclass
class ExperimentConfig:
    """Experiment configuration"""

    width: int
    depth: int
    d_in: int
    batch_size: int
    n_steps: int
    n_seeds: int


##
# Parameter Filtering (for MultiOptimizer)
##
def _path_to_str(path: Any) -> str:
    return (
        "/".join(str(p) for p in path) if isinstance(path, (tuple, list)) else str(path)
    )


def is_bias(path: Any, _: Any) -> bool:
    return "bias" in _path_to_str(path)


def is_in_layer_weight(path: Any, _: Any) -> bool:
    path_str = _path_to_str(path)
    return "in_layer" in path_str and "weight" in path_str


def is_out_layer_weight(path: Any, _: Any) -> bool:
    path_str = _path_to_str(path)
    return "out_layer" in path_str and "weight" in path_str


##
# Model
##
class ResNet(nn.Module):
    """Simple residual MLP with optional μP initialization"""

    def __init__(
        self,
        d_in: int,
        width: int,
        depth: int,
        init_std_in: Optional[float],
        init_std_hidden: Optional[float],
        init_std_out: Optional[float],
        use_layernorm: bool,
    ):
        super().__init__()
        self.in_layer = nn.Linear(d_in, width)
        self.hidden_layers = [nn.Linear(width, width) for _ in range(depth)]
        self.out_layer = nn.Linear(width, 1)
        # Use Identity or LayerNorm based on config
        self.norms = [
            nn.LayerNorm(width) if use_layernorm else nn.Identity()
            for _ in range(depth)
        ]
        if init_std_in is not None:
            self._init_layer(self.in_layer, init_std_in)
        if init_std_hidden is not None:
            for layer in self.hidden_layers:
                self._init_layer(layer, init_std_hidden)
        if init_std_out is not None:
            self._init_layer(self.out_layer, init_std_out)

    def _init_layer(self, layer: nn.Linear, std: float) -> None:
        layer.weight = mx.random.normal(layer.weight.shape, scale=std)
        layer.bias = mx.zeros_like(layer.bias)

    def __call__(self, x: mx.array, residual_scale: float) -> mx.array:
        h = nn.relu(self.in_layer(x))
        for layer, norm in zip(self.hidden_layers, self.norms):
            h = h + residual_scale * nn.relu(layer(norm(h)))
        return self.out_layer(h)
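
# Forward-pass structure (describing __call__ above): each block is pre-norm,
# h <- h + residual_scale * relu(layer(norm(h))). Under μP the caller passes
# residual_scale = 1/sqrt(depth) (MuPConfig.residual_scale); the standard
# baseline in sweep_learning_rates passes 1.0.
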

##
# Optimizer
##
def verify_parameter_filters(model: ResNet) -> None:
    """Sanity check: print which parameters match which filters"""
    params = tree_flatten(model.parameters())
    print("Parameter filter sanity check:")
    for path, param in params:
        path_str = _path_to_str(path)
        categories = []
        if is_bias(path, param):
            categories.append("bias")
        if is_in_layer_weight(path, param):
            categories.append("in_layer")
        if is_out_layer_weight(path, param):
            categories.append("out_layer")
        if not categories:
            categories.append("hidden_layer")
        print(f" {path_str:30s} → {', '.join(categories)}")


def create_optimizer(
    base_lr: float, mup_config: Optional[MuPConfig] = None
) -> optim.Optimizer:
    opt = optim.Adam
    if mup_config is None:
        return opt(learning_rate=base_lr)
    # μP: per-layer learning rates
    lr_in = base_lr * mup_config.lr_scale_in
    lr_hidden = base_lr * mup_config.lr_scale_hidden
    lr_out = base_lr * mup_config.lr_scale_out
    lr_bias = base_lr * mup_config.lr_scale_bias
    return MultiOptimizer(
        optimizers=[
            opt(learning_rate=lr_bias),
            opt(learning_rate=lr_in),
            opt(learning_rate=lr_out),
            opt(learning_rate=lr_hidden),
        ],
        filters=[is_bias, is_in_layer_weight, is_out_layer_weight],
    )
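
# Routing note (MLX MultiOptimizer semantics): each parameter goes to the first
# optimizer whose filter returns True; the last optimizer, which is given no
# filter, acts as the fallback. Here the fallback is the hidden-layer optimizer
# at lr = base_lr / width, so anything the three filters miss trains at the
# hidden-layer rate.
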

##
# Training
##
def train_one_config(
    base_lr: float,
    batch_size: int,
    model: ResNet,
    residual_scale: float,
    n_steps: int,
    mup_config: Optional[MuPConfig],
    seed: Optional[int],
) -> tuple[mx.array, mx.array]:
    if seed is not None:
        mx.random.seed(seed)
    d_in = model.in_layer.weight.shape[1]

    def loss_fn(model: ResNet, X: mx.array, y: mx.array) -> mx.array:
        return mx.mean((model(X, residual_scale) - y) ** 2)

    loss_and_grad = nn.value_and_grad(model, loss_fn)
    optimizer = create_optimizer(base_lr, mup_config)
    # Measure init loss
    X_init, y_init = make_batch(batch_size, d_in)
    init_loss = loss_fn(model, X_init, y_init)
    # For early exit
    divergence_threshold = 2.0 * init_loss.item()
    # Training loop (infinite data regime)
    for _ in tqdm(range(n_steps), desc=f"LR={base_lr:.0e}", leave=False):
        X, y = make_batch(batch_size, d_in)
        loss, grads = loss_and_grad(model, X, y)
        optimizer.update(model, grads)
        if loss.item() > divergence_threshold:
            return init_loss, loss
    X_final, y_final = make_batch(batch_size, d_in)
    final_loss = loss_fn(model, X_final, y_final)
    return init_loss, final_loss


def sweep_learning_rates(
    cfg: ExperimentConfig,
    lr_grid: npt.NDArray[np.floating[Any]],
    mup_config: Optional[MuPConfig],
) -> tuple[npt.NDArray[np.floating[Any]], npt.NDArray[np.floating[Any]]]:
    """Returns (init_losses, final_losses) with shape (n_seeds, n_lrs)"""
    results_init = np.zeros((cfg.n_seeds, len(lr_grid)))
    results_final = np.zeros((cfg.n_seeds, len(lr_grid)))
    if mup_config is None:
        init_in = init_hidden = init_out = None
        residual_scale = 1.0
    else:
        init_in = mup_config.init_in
        init_hidden = mup_config.init_hidden
        init_out = mup_config.init_out
        residual_scale = mup_config.residual_scale
    for lr_idx, lr in tqdm(
        list(enumerate(lr_grid)), desc=f" LRs (w={cfg.width})", leave=False
    ):
        for seed in tqdm(
            range(cfg.n_seeds), desc=f" Seeds (lr={lr:.0e})", leave=False
        ):
            model = ResNet(
                cfg.d_in,
                cfg.width,
                cfg.depth,
                init_in,
                init_hidden,
                init_out,
                use_layernorm=True,
            )
            init_loss, final_loss = train_one_config(
                float(lr),
                cfg.batch_size,
                model,
                residual_scale,
                cfg.n_steps,
                mup_config,
                seed,
            )
            results_init[seed, lr_idx] = init_loss.item()
            results_final[seed, lr_idx] = final_loss.item()
        # Early exit: if average loss diverged, all higher LRs will also diverge
        avg_init = np.mean(results_init[:, lr_idx])
        avg_final = np.mean(results_final[:, lr_idx])
        if avg_final > 2.0 * avg_init:
            # Mark remaining LRs as diverged (copy current diverged values)
            for remaining_lr_idx in range(lr_idx + 1, len(lr_grid)):
                results_init[:, remaining_lr_idx] = results_init[:, lr_idx]
                results_final[:, remaining_lr_idx] = results_final[:, lr_idx]
            break
    return results_init, results_final


##
# Persistence
##
def save_results(
    results_default: dict[int, tuple[npt.NDArray[Any], npt.NDArray[Any]]],
    results_mup: dict[int, tuple[npt.NDArray[Any], npt.NDArray[Any]]],
    widths: list[int],
    lr_grid: npt.NDArray[np.floating[Any]],
    base_lr_grid: npt.NDArray[np.floating[Any]],
    metadata: dict[str, Any],
    prefix: str,
) -> None:
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    save_path = f"{prefix}_{timestamp}.parquet"
    records = []
    for width in widths:
        init_def, final_def = results_default[width]
        init_mup, final_mup = results_mup[width]
        n_seeds, n_lrs = init_def.shape
        for seed in range(n_seeds):
            for lr_idx in range(n_lrs):
                records.append(
                    {
                        "width": width,
                        "seed": seed,
                        "lr_idx": lr_idx,
                        "lr": lr_grid[lr_idx],
                        "base_lr": base_lr_grid[lr_idx],
                        "init_loss_default": init_def[seed, lr_idx],
                        "final_loss_default": final_def[seed, lr_idx],
                        "init_loss_mup": init_mup[seed, lr_idx],
                        "final_loss_mup": final_mup[seed, lr_idx],
                    }
                )
    df = pl.DataFrame(records)
    # Append a sentinel row (width=-1); load_results filters it out before reshaping
    metadata_row = {
        "width": -1,
        "seed": -1,
        "lr_idx": -1,
        "lr": 0.0,
        "base_lr": 0.0,
        "init_loss_default": 0.0,
        "final_loss_default": 0.0,
        "init_loss_mup": 0.0,
        "final_loss_mup": 0.0,
    }
    df_with_meta = pl.concat([df, pl.DataFrame([metadata_row])])
    # Store metadata as JSON column on all rows
    metadata_json = json.dumps(metadata)
    df_with_meta = df_with_meta.with_columns(pl.lit(metadata_json).alias("_metadata"))
    df_with_meta.write_parquet(save_path)
    print(f"Saved results to {save_path}")


def load_results(
    load_path: str,
) -> tuple[
    dict[int, tuple[npt.NDArray[Any], npt.NDArray[Any]]],
    dict[int, tuple[npt.NDArray[Any], npt.NDArray[Any]]],
    list[int],
    npt.NDArray[np.floating[Any]],
    npt.NDArray[np.floating[Any]],
    dict[str, Any],
]:
    df = pl.read_parquet(load_path)
    # Extract metadata from JSON column
    metadata = json.loads(df["_metadata"][0])
    # Filter out metadata row
    df = df.filter(pl.col("width") != -1)
    widths = sorted(df["width"].unique().to_list())
    n_lrs = df["lr_idx"].max() + 1
    # Sorted unique LR grids (Series.unique() alone does not guarantee order)
    lr_grid = df["lr"].unique().sort().to_numpy()
    base_lr_grid = df["base_lr"].unique().sort().to_numpy()
    results_default = {}
    results_mup = {}
    loss_cols = [
        "init_loss_default",
        "final_loss_default",
        "init_loss_mup",
        "final_loss_mup",
    ]
    for width in widths:
        width_df = df.filter(pl.col("width") == width).sort(["seed", "lr_idx"])
        losses = [
            width_df[col].to_numpy().reshape(metadata["n_seeds"], n_lrs)
            for col in loss_cols
        ]
        results_default[width] = (losses[0], losses[1])
        results_mup[width] = (losses[2], losses[3])
    return results_default, results_mup, widths, lr_grid, base_lr_grid, metadata


##
# Visualization
##
def _plot_method(
    ax: plt.Axes,
    results: dict[int, tuple[npt.NDArray[Any], npt.NDArray[Any]]],
    widths: list[int],
    lr_grid: npt.NDArray[np.floating[Any]],
    n_seeds: int,
    colors: npt.NDArray[Any],
) -> npt.NDArray[Any]:
    """Plot results for one method (standard or μP)"""
    style_kwargs = {"linewidth": 2}
    fill_kwargs = {"alpha": 0.2}
    for i, width in enumerate(widths):
        init, final = results[width]
        mean = np.mean(final, axis=0)
        se = np.std(final, axis=0) / np.sqrt(n_seeds)
        color = colors[i]
        ax.plot(lr_grid, mean, label=f"w={width}", color=color, **style_kwargs)  # type: ignore
        ax.fill_between(lr_grid, mean - se, mean + se, color=color, **fill_kwargs)  # type: ignore
    # Return first init for reference line
    return results[widths[0]][0]


def plot_comparison(
    results_default: dict[int, tuple[npt.NDArray[Any], npt.NDArray[Any]]],
    results_mup: dict[int, tuple[npt.NDArray[Any], npt.NDArray[Any]]],
    widths: list[int],
    lr_grid: npt.NDArray[np.floating[Any]],
    n_seeds: int,
    output_path: str,
) -> None:
    # Styling constants
    FIGSIZE = (14, 5)
    COLORMAP_START = 0.2
    COLORMAP_END = 0.9
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=FIGSIZE)
    colors = plt.cm.viridis(np.linspace(COLORMAP_START, COLORMAP_END, len(widths)))  # type: ignore
    # Compute shared axis limits
    all_losses = [
        results[w][1] for results in [results_default, results_mup] for w in widths
    ]
    ylim_min = min(np.min(losses) for losses in all_losses)
    # Max is the highest mean init loss across all μP widths
    mup_init_losses = [results_mup[w][0] for w in widths]
    ylim_max = max(np.mean(losses) for losses in mup_init_losses)
    # Plot both methods
    init_def = _plot_method(ax1, results_default, widths, lr_grid, n_seeds, colors)
    init_mup = _plot_method(ax2, results_mup, widths, lr_grid, n_seeds, colors)
    # Add init loss reference lines
    ref_kwargs = {
        "linestyle": "--",
        "color": "r",
        "alpha": 0.5,
        "linewidth": 2,
        "label": "Init loss",
    }
    for ax, init_losses in [(ax1, init_def), (ax2, init_mup)]:
        mean_init = np.mean(init_losses)
        ax.plot([lr_grid[0], lr_grid[-1]], [mean_init, mean_init], **ref_kwargs)
    # Configure axes
    for ax in [ax1, ax2]:
        ax.set_ylabel("Final Loss")
        ax.set_xlabel("Learning Rate")
        ax.set_xscale("log")
        ax.set_xlim(lr_grid[0], lr_grid[-1])
        ax.set_ylim(ylim_min, ylim_max)
        ax.legend()
        ax.grid(alpha=0.3)
    ax1.set_title("Standard Parametrization")
    ax2.set_title("μP")
    plt.tight_layout()
    plt.savefig(output_path, dpi=150, bbox_inches="tight")
    print(f"\nSaved plot to {output_path}")


##
# Main
##
def run_experiment(args: argparse.Namespace) -> None:
    # Generate width range (powers of 2), largest first (slow stuff first)
    widths = [
        args.min_width * (2**i)
        for i in range(int(np.log2(args.max_width / args.min_width)) + 1)
    ]
    widths.reverse()
    # Generate LR grid with batch size scaling (Goyal et al. linear scaling rule)
    # base_lr is the learning rate for reference batch size 1
    # actual lr scales linearly with batch size
    reference_batch_size = 1
    base_lr_grid = np.logspace(np.log10(args.min_lr), np.log10(args.max_lr), args.n_lrs)
    lr_scale = args.batch_size / reference_batch_size
    lr_grid = base_lr_grid * lr_scale
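    # Worked example with the defaults (--batch-size 256, --min-lr 5e-11,
    # --max-lr 2e-3): lr_scale = 256, so the swept LRs run from roughly
    # 1.3e-8 up to roughly 5.1e-1.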
| print("Running μP experiment:") | |
| print(f" Widths: {widths}") | |
| print(f" Depth: {args.depth}") | |
| print(f" Batch size: {args.batch_size} (LR scaled by {lr_scale:.3f}x)") | |
| print(f" Input dim: {args.d_in}") | |
| print(f" LR range: [{lr_grid[0]:.0e}, {lr_grid[-1]:.0e}] ({args.n_lrs} points)") | |
| print(f" Steps: {args.n_steps}, Seeds: {args.n_seeds}") | |
| # Print parameter counts | |
| print("\n Parameter counts:") | |
| models = [] | |
| for width in widths: | |
| model = ResNet(args.d_in, width, args.depth, None, None, None, True) | |
| models.append(model) | |
| n_params = sum(v.size for _, v in tree_flatten(model.parameters())) # type: ignore | |
| print(f" w={width}: {n_params:,} params") | |
| # Verify parameter filters on first model | |
| print() | |
| verify_parameter_filters(models[0]) | |
| print() | |
| # Run experiments | |
| results_default: dict[int, tuple[npt.NDArray[Any], npt.NDArray[Any]]] = {} | |
| results_mup: dict[int, tuple[npt.NDArray[Any], npt.NDArray[Any]]] = {} | |
| for width in tqdm(widths, desc="Widths"): | |
| cfg = ExperimentConfig( | |
| width, | |
| args.depth, | |
| args.d_in, | |
| args.batch_size, | |
| args.n_steps, | |
| args.n_seeds, | |
| ) | |
| results_default[width] = sweep_learning_rates(cfg, lr_grid, None) | |
| results_mup[width] = sweep_learning_rates( | |
| cfg, lr_grid, MuPConfig(width, args.depth, args.d_in) | |
| ) | |
| # Save and plot | |
| metadata = { | |
| "depth": args.depth, | |
| "batch_size": args.batch_size, | |
| "n_steps": args.n_steps, | |
| "n_seeds": args.n_seeds, | |
| "d_in": args.d_in, | |
| } | |
| save_results( | |
| results_default, | |
| results_mup, | |
| widths, | |
| lr_grid, | |
| base_lr_grid, | |
| metadata, | |
| args.results_prefix, | |
| ) | |
| plot_comparison( | |
| results_default, results_mup, widths, lr_grid, args.n_seeds, args.plot_file | |
| ) | |


def main() -> None:
    parser = argparse.ArgumentParser(description="μP LR transfer experiment in MLX")
    parser.add_argument("--depth", type=int, default=2)
    parser.add_argument("--n-steps", type=int, default=400)
    parser.add_argument("--n-seeds", type=int, default=4)
    parser.add_argument("--min-lr", type=float, default=5e-11)
    parser.add_argument("--max-lr", type=float, default=2e-3)
    parser.add_argument("--n-lrs", type=int, default=32)
    parser.add_argument("--batch-size", type=int, default=256)
    parser.add_argument("--d-in", type=int, default=1)
    parser.add_argument("--min-width", type=int, default=128)
    parser.add_argument("--max-width", type=int, default=1024)
    parser.add_argument("--replot", action="store_true")
    parser.add_argument("--results-prefix", type=str, default="mlx_mup_results")
    parser.add_argument("--plot-file", type=str, default="mlx_mup.png")
    args = parser.parse_args()
    if args.replot:
        pattern = f"{args.results_prefix}_*.parquet"
        files = sorted(glob.glob(pattern))
        if not files:
            print(f"No results files found matching {pattern}")
            return
        load_path = files[-1]
        print(f"Loading results from {load_path}...")
        results_default, results_mup, widths, lr_grid, _, metadata = load_results(
            load_path
        )
        print(f"Loaded results for widths: {widths}")
        print(
            f" Config: depth={metadata['depth']}, batch={metadata['batch_size']}, "
            f"steps={metadata['n_steps']}, seeds={metadata['n_seeds']}"
        )
        plot_comparison(
            results_default,
            results_mup,
            widths,
            lr_grid,
            metadata["n_seeds"],
            args.plot_file,
        )
    else:
        run_experiment(args)


if __name__ == "__main__":
    main()