@Clybius · Created October 24, 2025 21:26
SNOO: Step-K Nesterov Outer Optimizer Implementation (Clybius Variant) (CPU)
import torch
from torch.optim import Optimizer
"""
Usage: Wrap your optimizer of choice with SnooC, default params seem to work fine. Params and such are copied into SnooC's internal optimizer (SGD), which are placed on the CPU to save VRAM.
Example:
```
from SnooC_CPU import SnooC
optimizer = SnooC(torch.optim.AdamW([params], lr=0.001, weight_decay=0.01))
```
"""
# Initial implementation from KellerJordan's repository: https://github.com/KellerJordan/modded-nanogpt/blob/cb68de1621d056bb9c6b8269f440211dec8400fe/train_gpt_medium.py
class SnooC(Optimizer):
"""
@DominikKallusky, @vishal9-team, @vinaysrao
Sparse Nesterov Outer Optimizer (Snoo) is a momentum-based wrapper to any optimizer that can
improve the stability and smoothness of the optimization process and thus the quality
of large language models (LLM) and other models. Snoo implicitly adds temporal regularization
to the parameters, thus smoothing the training trajectory and instilling a bias towards flatter
minima and lower parameter norms. Snoo is computationally efficient, incurring minimal overhead
in compute and moderate memory usage.
This version is modified by Clybius to store its states on the CPU in order to reduce GPU memory usage,
with delayed initialization for easier implementation in various frameworks.
"""
    @torch.no_grad()
    def __init__(self, optimizer, lr: float = 0.67, momentum: float = 0.67, k: int = 20) -> None:
        self.optimizer = optimizer
        self.lr = lr
        self.momentum = momentum
        self.k = k
        self.current_step = 0
        self.model_params = None
        self.outer_params = None
        self.outer_optimizer = None
        # Delay initialization until the first step, as some training frameworks add param groups
        # only *after* the optimizer has been constructed.
        if self.optimizer.param_groups:
            self.param_groups = self.optimizer.param_groups

    @torch.no_grad()
    def _initialize_outer_optimizer(self):
        params = []
        for pg in self.optimizer.param_groups:
            if len(pg['params']) > 1:
                for param in pg['params']:
                    if isinstance(param, torch.Tensor):
                        params.append(param)
            else:
                params.extend(pg['params'])
        if not params:
            return
        self.model_params = list(params)
        self.outer_params = [p.clone().to('cpu') for p in self.model_params]
        # Use SGD as the 'outer' optimizer, since the outer step is just Nesterov momentum.
        # TODO: Allow the outer optimizer to be user-specified. Test Muon with this (steamhappy emote here).
        self.outer_optimizer = torch.optim.SGD(
            self.outer_params,
            lr=self.lr,
            momentum=self.momentum,
            nesterov=True,
        )
        self.param_groups = self.optimizer.param_groups
        del params

    @torch.no_grad()
    def step(self, closure=None):
        if self.outer_optimizer is None:
            # Delayed initialization at the first step.
            if self.optimizer.param_groups:
                self._initialize_outer_optimizer()
            # Fall back to the inner optimizer's step if there are no params.
            if self.outer_optimizer is None:
                return self.optimizer.step(closure)
        loss = self.optimizer.step(closure)
        if self.current_step % self.k == 0:
            # Perform the outer step on CPU.
            for p_gpu, p_cpu in zip(self.model_params, self.outer_params):
                p_cpu.grad = p_cpu.data - p_gpu.data.to('cpu', non_blocking=True)
            self.outer_optimizer.step()
            for p_gpu, p_cpu in zip(self.model_params, self.outer_params):
                p_gpu.copy_(p_cpu.data, non_blocking=True)
        self.current_step += 1
        return loss

    def zero_grad(self, set_to_none: bool = False):
        self.optimizer.zero_grad(set_to_none=set_to_none)

    def state_dict(self):
        return self.optimizer.state_dict()

    def load_state_dict(self, state_dict):
        self.optimizer.load_state_dict(state_dict)
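
# --- Usage sketch (not part of the original gist) ------------------------------------------------
# A minimal end-to-end example of SnooC wrapping AdamW in an ordinary training loop, as described
# in the usage docstring above. The model and random batch below are hypothetical stand-ins; only
# SnooC (this file) and torch.optim.AdamW (PyTorch) are assumed to exist as named.
if __name__ == "__main__":
    import torch.nn as nn

    model = nn.Linear(128, 10)  # hypothetical stand-in for a real model
    optimizer = SnooC(torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01))

    for _ in range(100):
        x = torch.randn(32, 128)               # hypothetical random batch
        target = torch.randint(0, 10, (32,))   # hypothetical random labels
        loss = nn.functional.cross_entropy(model(x), target)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()  # inner AdamW step every call; outer Nesterov step every k (default 20) calls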