@strnan · Created January 9, 2026 14:15
Distributed training strategy submission: a DiLoCo strategy with a SparseLoCo-style outer optimizer, built for the exogym framework.

import datetime
import math
import os
from abc import ABC, abstractmethod
from copy import deepcopy
from dataclasses import dataclass
from typing import Any, Dict, Iterable, Type, TypeAlias, Union

import torch
import torch.distributed as dist
import torch.nn.utils as nn_utils
from torch.optim.lr_scheduler import LambdaLR

from exogym.aux.utils import LogModule

ParamsT: TypeAlias = Union[Iterable[torch.Tensor], Iterable[dict[str, Any]]]

# Enable TF32 matmuls and cuDNN autotuning for faster training on Ampere+ GPUs.
if torch.cuda.is_available():
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cudnn.benchmark = True
    try:
        torch.set_float32_matmul_precision("high")
    except Exception:
        pass

def mps_compatible(func):
    """Wrap a torch.distributed collective so it also works with MPS tensors.

    Distributed backends such as gloo do not operate on MPS tensors directly,
    so tensors are round-tripped through CPU and the results copied back.
    """

    def all_gather_wrapper(tensor_list, tensor, *args, **kwargs):
        is_tensor_mps = hasattr(tensor, "device") and tensor.device.type == "mps"
        is_list_mps = any(hasattr(t, "device") and t.device.type == "mps" for t in tensor_list)
        if is_tensor_mps or is_list_mps:
            cpu_tensor = tensor.data.to("cpu") if is_tensor_mps else tensor
            cpu_tensor_list = [
                t.data.to("cpu") if hasattr(t, "device") and t.device.type == "mps" else t
                for t in tensor_list
            ]
            result = func(cpu_tensor_list, cpu_tensor, *args, **kwargs)
            if is_tensor_mps:
                tensor.data.copy_(cpu_tensor.to("mps"))
            for i, t in enumerate(tensor_list):
                if hasattr(t, "device") and t.device.type == "mps":
                    t.data.copy_(cpu_tensor_list[i].to("mps"))
            return result
        return func(tensor_list, tensor, *args, **kwargs)

    def standard_wrapper(tensor, *args, **kwargs):
        if hasattr(tensor, "device") and tensor.device.type == "mps":
            cpu_tensor = tensor.data.to("cpu")
            result = func(cpu_tensor, *args, **kwargs)
            tensor.data.copy_(cpu_tensor.to("mps"))
            return result
        return func(tensor, *args, **kwargs)

    return all_gather_wrapper if func.__name__ == "all_gather" else standard_wrapper
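
# MPS-safe wrappers around the torch.distributed collectives, intended to be
# used in place of dist.broadcast / dist.all_reduce / dist.all_gather so the
# same code path works on CUDA, CPU, and Apple-silicon (MPS) nodes.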
@mps_compatible
def broadcast(tensor, src=0):
    return dist.broadcast(tensor, src=src)


@mps_compatible
def all_reduce(tensor, op=dist.ReduceOp.SUM):
    return dist.all_reduce(tensor, op=op)


@mps_compatible
def all_gather(tensor_list, tensor, group=None, async_op=False):
    return dist.all_gather(tensor_list, tensor, group=group, async_op=async_op)

@dataclass
class OptimSpec:
    """Serializable description of an optimizer: the class plus its constructor kwargs."""

    cls: Type[torch.optim.Optimizer]
    kwargs: Dict[str, Any]

    def build(self, model):
        return self.cls(model.parameters(), **(self.kwargs or {}))
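
# Example (illustrative values only):
#   OptimSpec(torch.optim.AdamW, {"lr": 3e-4}).build(model)
# is equivalent to torch.optim.AdamW(model.parameters(), lr=3e-4).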

class Strategy(ABC, LogModule):
    def __init__(self, lr_scheduler=None, lr_scheduler_kwargs=None, **kwargs):
        self.lr_scheduler = lr_scheduler
        self.lr_scheduler_kwargs = lr_scheduler_kwargs or {}
        self.kwargs = kwargs
        self.scheduler = None
        self.lr_callbacks = []
        self.max_steps = 1

    def _init_node(self, model, rank, num_nodes):
        if torch.cuda.is_available():
            local_rank = int(os.environ.get("LOCAL_RANK", "0"))
            torch.cuda.set_device(local_rank)
        self.model = model
        self.rank = rank
        self.num_nodes = num_nodes
        self.local_step = 0

    @abstractmethod
    def step(self):
        # Subclasses perform their update, then call super().step() to advance
        # the LR scheduler and the local step counter.
        self.nbytes = 0
        if self.scheduler is not None:
            self.scheduler.step()
            if self.rank == 0:
                for cb in self.lr_callbacks:
                    cb(self.scheduler.get_last_lr()[0])
        self.local_step += 1

    def zero_grad(self):
        self.optim.zero_grad(set_to_none=True)

    def _setup_scheduler(self):
        def lr_lambda(step: int):
            # Linear warmup followed by cosine decay to zero.
            warmup = int(self.lr_scheduler_kwargs.get("warmup_steps", 80))
            max_steps = int(self.lr_scheduler_kwargs.get("max_steps", self.max_steps))
            max_steps = max(1, max_steps)
            warmup = max(1, min(warmup, max_steps))
            if step < warmup:
                return step / warmup
            progress = (step - warmup) / max(1, max_steps - warmup)
            progress = min(max(progress, 0.0), 1.0)
            return 0.5 * (1.0 + math.cos(math.pi * progress))

        if self.lr_scheduler == "lambda_cosine":
            self.scheduler = LambdaLR(self.optim, lr_lambda)

    def __config__(self):
        return {"strategy": self.__class__.__name__}

class CommunicationModule(ABC):
    """Pluggable communication step run after each local optimizer update."""

    @abstractmethod
    def communicate(self, model, rank, num_nodes, local_step):
        pass

    @abstractmethod
    def _init_node(self, model, rank, num_nodes):
        pass

class CommunicateOptimizeStrategy(Strategy):
    def __init__(self, communication_modules, optim_spec=None, max_norm=None, **kwargs):
        super().__init__(**kwargs)
        self.communication_modules = communication_modules
        self.optim_spec = optim_spec
        self.max_norm = max_norm
        for m in self.communication_modules:
            m.strategy = self

    def _init_node(self, model, rank, num_nodes):
        super()._init_node(model, rank, num_nodes)
        self.optim = self.optim_spec.build(model)
        self._setup_scheduler()
        for m in self.communication_modules:
            m._init_node(model, rank, num_nodes)

    def step(self):
        # Clip gradients, take the local (inner) optimizer step, then let each
        # communication module synchronize state across nodes.
        if self.max_norm:
            nn_utils.clip_grad_norm_(self.model.parameters(), self.max_norm)
        self.optim.step()
        for m in self.communication_modules:
            m.communicate(self.model, self.rank, self.num_nodes, self.local_step)
        super().step()
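
# Strategies are composed from communication modules; a hypothetical variant
# could reuse the same machinery, e.g.:
#   CommunicateOptimizeStrategy(
#       communication_modules=[DiLoCoCommunicator(H=50, outer_optim_spec=...)],
#       optim_spec=OptimSpec(torch.optim.SGD, {"lr": 0.1}),
#   )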

class DiLoCoCommunicator(CommunicationModule):
    """DiLoCo outer loop: every H local steps, apply an outer optimizer to the
    pseudo-gradient (master weights minus current local weights) on CPU."""

    def __init__(self, H=25, outer_optim_spec=None):
        self.H = H
        self.outer_optim_spec = outer_optim_spec

    def _init_node(self, model, rank, num_nodes):
        # Dedicated gloo group for the CPU-side outer step; the positional
        # timedelta argument is days, i.e. an effectively unlimited timeout.
        self.pg = dist.new_group(backend="gloo", timeout=datetime.timedelta(60))
        self.master_model = deepcopy(model).to("cpu")
        for p in self.master_model.parameters():
            p.requires_grad = True
        self.outer_optim = self.outer_optim_spec.cls(
            self.master_model.parameters(),
            process_group=self.pg,
            **self.outer_optim_spec.kwargs,
        )

    def communicate(self, model, rank, num_nodes, local_step):
        if num_nodes > 1 and local_step > 0 and local_step % self.H == 0:
            self.outer_optim.zero_grad(set_to_none=True)
            local_sd = model.state_dict()
            # Pseudo-gradient: master parameters minus the locally trained ones.
            for n, p in self.master_model.named_parameters():
                p.grad = p.data - local_sd[n].data.to("cpu")
            self.outer_optim.step()
            # Reset the local model to the updated master weights.
            master_sd = self.master_model.state_dict()
            for n, p in model.named_parameters():
                p.data.copy_(master_sd[n].to(p.device))
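
# DiLoCo recap: each worker runs H local AdamW steps, then the outer optimizer
# treats delta = theta_master - theta_local as a gradient on the master copy.
# For workers to stay in sync, these deltas must be averaged across nodes; here
# that responsibility sits with the outer optimizer (which receives the gloo
# process group), since communicate() itself performs no collective.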

class DiLoCoStrategy(CommunicateOptimizeStrategy):
    def __init__(self, optim_spec, outer_optim_spec, H=25, **kwargs):
        self.comm = DiLoCoCommunicator(H=H, outer_optim_spec=outer_optim_spec)
        super().__init__(
            communication_modules=[self.comm],
            optim_spec=optim_spec,
            **kwargs,
        )

class SparseLoCo(torch.optim.SGD):
    """Outer optimizer for the DiLoCo step: averages the dense pseudo-gradients
    over the process group, applies decoupled weight decay, then a momentum-SGD
    update. The compression settings (top_k, chunk_size, DCT, quantization) are
    accepted for interface compatibility but unused by this simplified step()."""

    def __init__(
        self,
        params,
        lr,
        momentum=0.9,
        weight_decay=0.05,
        top_k=64,
        chunk_size=64,
        use_dct=True,
        use_quantization=True,
        quantization_bins=4,
        quantization_range=6,
        process_group=None,
        **kwargs,
    ):
        # Weight decay is applied in decoupled form in step(), so SGD gets weight_decay=0.0.
        super().__init__(params, lr=lr, momentum=momentum, weight_decay=0.0, **kwargs)
        self.decoupled_weight_decay = weight_decay
        self.process_group = process_group

    @torch.no_grad()
    def step(self, closure=None):
        # Average pseudo-gradients across workers so every node applies the same outer update.
        if self.process_group is not None and dist.is_initialized():
            world_size = dist.get_world_size(group=self.process_group)
            if world_size > 1:
                for group in self.param_groups:
                    for p in group["params"]:
                        if p.grad is not None:
                            dist.all_reduce(p.grad, op=dist.ReduceOp.SUM, group=self.process_group)
                            p.grad.div_(world_size)
        # Decoupled weight decay, then the underlying momentum-SGD step.
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is not None:
                    p.data.mul_(1.0 - group["lr"] * self.decoupled_weight_decay)
        return super().step(closure=closure)
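
# The parameter names above (top_k, chunk_size, use_dct, quantization_*) suggest
# the full method would chunk each pseudo-gradient, apply a DCT, keep the top-k
# coefficients and quantize them before communicating; that compressed collective
# is not implemented in this submission's simplified optimizer.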

STRATEGY = DiLoCoStrategy(
    optim_spec=OptimSpec(
        torch.optim.AdamW,
        {
            "lr": 0.0012,
            "betas": (0.9, 0.95),
            "weight_decay": 0.08,
            "eps": 1e-8,
        },
    ),
    outer_optim_spec=OptimSpec(
        SparseLoCo,
        {
            "lr": 0.6,
            "momentum": 0.9,
            "weight_decay": 0.05,
            "top_k": 64,
            "chunk_size": 64,
            "use_dct": True,
            "use_quantization": True,
            "quantization_bins": 4,
            "quantization_range": 6,
        },
    ),
    lr_scheduler="lambda_cosine",
    lr_scheduler_kwargs={
        "warmup_steps": 80,
        "max_steps": 100,
    },
    max_norm=1.0,
    H=25,
)
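
if __name__ == "__main__":
    # Minimal single-process smoke test (not part of the original submission;
    # the exogym harness normally constructs the workers and drives STRATEGY).
    # It spins up a 1-rank gloo group, runs one inner step, and exits.
    import torch.nn as nn

    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group(backend="gloo", rank=0, world_size=1)

    model = nn.Linear(8, 8)
    STRATEGY._init_node(model, rank=0, num_nodes=1)

    loss = model(torch.randn(4, 8)).sum()
    loss.backward()
    STRATEGY.step()       # clip -> AdamW step -> no outer sync (single node) -> scheduler
    STRATEGY.zero_grad()

    dist.destroy_process_group()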