Created October 24, 2025 21:26
SNOO: Step-K Nesterov Outer Optimizer Implementation (Clybius Variant) (CPU)
import torch
from torch.optim import Optimizer

"""
Usage: Wrap your optimizer of choice with SnooC; the default params seem to work fine. The wrapped
params are cloned into SnooC's internal outer optimizer (SGD), and those copies are kept on the CPU
to save VRAM.
Example:
```
from SnooC_CPU import SnooC
optimizer = SnooC(torch.optim.AdamW([params], lr=0.001, weight_decay=0.01))
```
"""
# Initial implementation from KellerJordan's repository: https://github.com/KellerJordan/modded-nanogpt/blob/cb68de1621d056bb9c6b8269f440211dec8400fe/train_gpt_medium.py
class SnooC(Optimizer):
    """
    @DominikKallusky, @vishal9-team, @vinaysrao
    Sparse Nesterov Outer Optimizer (Snoo) is a momentum-based wrapper around any optimizer that can
    improve the stability and smoothness of the optimization process, and thus the quality of
    large language models (LLMs) and other models. Snoo implicitly adds temporal regularization
    to the parameters, smoothing the training trajectory and instilling a bias towards flatter
    minima and lower parameter norms. Snoo is computationally efficient, incurring minimal compute
    overhead and moderate memory usage.

    This version is modified by Clybius to store its state on the CPU in order to reduce GPU memory
    usage, with delayed initialization for easier integration into various training frameworks.
    """
    @torch.no_grad()
    def __init__(self, optimizer, lr: float = 0.67, momentum: float = 0.67, k: int = 20) -> None:
        self.optimizer = optimizer
        self.lr = lr
        self.momentum = momentum
        self.k = k
        self.current_step = 0
        self.model_params = None
        self.outer_params = None
        self.outer_optimizer = None
        # Delay initialization until the first step, as some training frameworks add param groups
        # only *after* the optimizer has been constructed.
        if self.optimizer.param_groups:
            self.param_groups = self.optimizer.param_groups
    @torch.no_grad()
    def _initialize_outer_optimizer(self):
        # Flatten the inner optimizer's param groups into a single list of tensors.
        params = []
        for pg in self.optimizer.param_groups:
            if len(pg['params']) > 1:
                for param in pg['params']:
                    if isinstance(param, torch.Tensor):
                        params.append(param)
            else:
                params.extend(pg['params'])
        if not params:
            return
        self.model_params = list(params)
        # CPU-resident copies of the params, updated only every k steps.
        self.outer_params = [p.clone().to('cpu') for p in self.model_params]
        # Use SGD as the 'outer' optimizer, as the outer update is simply Nesterov momentum.
        # TODO: Allow outer optimizer specification for user choice. Test Muon with this (steamhappy emote here).
        self.outer_optimizer = torch.optim.SGD(
            self.outer_params,
            lr=self.lr,
            momentum=self.momentum,
            nesterov=True,
        )
        self.param_groups = self.optimizer.param_groups
        del params
    @torch.no_grad()
    def step(self, closure=None):
        if self.outer_optimizer is None:
            # Delayed initialization at the first step.
            if self.optimizer.param_groups:
                self._initialize_outer_optimizer()
            # Fall back to the inner optimizer's step if there are no params yet.
            if self.outer_optimizer is None:
                return self.optimizer.step(closure)
        loss = self.optimizer.step(closure)
        if self.current_step % self.k == 0:
            # Perform the outer step on the CPU: use the displacement between the CPU copy
            # and the current (inner-updated) params as the pseudo-gradient for Nesterov SGD.
            for p_gpu, p_cpu in zip(self.model_params, self.outer_params):
                p_cpu.grad = p_cpu.data - p_gpu.data.to('cpu', non_blocking=True)
            self.outer_optimizer.step()
            # Write the updated CPU copy back into the model params.
            for p_gpu, p_cpu in zip(self.model_params, self.outer_params):
                p_gpu.copy_(p_cpu.data, non_blocking=True)
        self.current_step += 1
        return loss
    def zero_grad(self, set_to_none: bool = False):
        self.optimizer.zero_grad(set_to_none=set_to_none)

    def state_dict(self):
        return self.optimizer.state_dict()

    def load_state_dict(self, state_dict):
        self.optimizer.load_state_dict(state_dict)
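A minimal end-to-end sketch (not part of the gist) showing the wrapper in a training loop, assuming the file is saved as SnooC_CPU.py as in the docstring's import example; the model, batch data, and loss below are placeholders.
```
import torch
from SnooC_CPU import SnooC  # assumes the gist file is saved as SnooC_CPU.py

model = torch.nn.Linear(10, 2)  # placeholder model
inner = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
optimizer = SnooC(inner)  # defaults: lr=0.67, momentum=0.67, k=20

for _ in range(100):
    x = torch.randn(8, 10)          # placeholder batch
    y = torch.randint(0, 2, (8,))
    loss = torch.nn.functional.cross_entropy(model(x), y)
    loss.backward()
    optimizer.step()                # inner AdamW step; every k-th call also runs the CPU outer step
    optimizer.zero_grad(set_to_none=True)
```
Because the outer state (CPU parameter copies and SGD momentum) is built lazily on the first step() call, the wrapper can be constructed before a framework has finished populating the inner optimizer's param groups.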