@strnan
Created January 3, 2026 10:23
Distributed training strategy submission
import math
import os
import torch
import torch.nn.utils as nn_utils
import torch.distributed as dist
import torch.fft
from einops import rearrange
import datetime
from copy import deepcopy
from dataclasses import dataclass
from torch.optim.lr_scheduler import LambdaLR
from typing import List, Type, Union, Optional, Dict, Any, TypeAlias, Callable, Iterable, Tuple
from abc import ABC, abstractmethod
from exogym.aux.utils import LogModule
ParamsT: TypeAlias = Union[Iterable[torch.Tensor], Iterable[dict[str, Any]]]
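
# NOTE: torch.fft does not ship dct/idct, so SparseLoCo.sparsify/reconstruct below call
# these minimal orthonormal DCT-II / DCT-III helpers instead. They are a matmul-based
# sketch local to this file; the names _dct_matrix, _dct and _idct are not a torch API.
def _dct_matrix(n, device=None, dtype=None):
    # Orthonormal DCT-II basis: C[k, i] = s_k * cos(pi * (i + 0.5) * k / n),
    # with s_0 = sqrt(1/n) and s_k = sqrt(2/n) for k > 0, so that C @ C.T == I.
    i = torch.arange(n, device=device, dtype=dtype if dtype is not None else torch.float32)
    C = torch.cos(math.pi * (i + 0.5) * i.unsqueeze(1) / n)
    C[0] *= math.sqrt(1.0 / n)
    C[1:] *= math.sqrt(2.0 / n)
    return C


def _dct(x):
    # DCT-II along the last dimension (norm="ortho" convention).
    C = _dct_matrix(x.shape[-1], device=x.device, dtype=x.dtype)
    return x @ C.T


def _idct(x):
    # DCT-III along the last dimension; exact inverse of _dct since the basis is orthogonal.
    C = _dct_matrix(x.shape[-1], device=x.device, dtype=x.dtype)
    return x @ C
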
def mps_compatible(func):
    def all_gather_wrapper(tensor_list, tensor, *args, **kwargs):
        is_tensor_mps = hasattr(tensor, "device") and tensor.device.type == "mps"
        is_list_mps = any(hasattr(t, "device") and t.device.type == "mps" for t in tensor_list)
        if is_tensor_mps or is_list_mps:
            cpu_tensor = tensor.data.to("cpu") if is_tensor_mps else tensor
            cpu_tensor_list = [
                t.data.to("cpu") if hasattr(t, "device") and t.device.type == "mps" else t
                for t in tensor_list
            ]
            result = func(cpu_tensor_list, cpu_tensor, *args, **kwargs)
            if is_tensor_mps:
                tensor.data.copy_(cpu_tensor.to("mps"))
            for i, t in enumerate(tensor_list):
                if hasattr(t, "device") and t.device.type == "mps":
                    t.data.copy_(cpu_tensor_list[i].to("mps"))
            return result
        return func(tensor_list, tensor, *args, **kwargs)

    def standard_wrapper(tensor, *args, **kwargs):
        if hasattr(tensor, "device") and tensor.device.type == "mps":
            cpu_tensor = tensor.data.to("cpu")
            result = func(cpu_tensor, *args, **kwargs)
            tensor.data.copy_(cpu_tensor.to("mps"))
            return result
        return func(tensor, *args, **kwargs)

    return all_gather_wrapper if func.__name__ == "all_gather" else standard_wrapper


@mps_compatible
def broadcast(tensor, src=0):
    return dist.broadcast(tensor, src=src)


@mps_compatible
def all_reduce(tensor, op=dist.ReduceOp.SUM):
    return dist.all_reduce(tensor, op=op)


@mps_compatible
def all_gather(tensor_list, tensor, group=None, async_op=False):
    return dist.all_gather(tensor_list, tensor, group=group, async_op=async_op)
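
# Illustrative example (assuming an initialized gloo process group): gloo collectives do
# not accept MPS tensors, so the wrappers above stage them through CPU copies and copy
# the result back, e.g.
#   t = torch.ones(4, device="mps")
#   all_reduce(t)   # runs dist.all_reduce on a CPU copy of t, then copies the sum back into t
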
@dataclass
class OptimSpec:
    cls: Type[torch.optim.Optimizer]
    kwargs: Dict[str, Any]

    def build(self, model):
        return self.cls(model.parameters(), **(self.kwargs or {}))


def ensure_optim_spec(
    optim: Union[str, OptimSpec, None], default: Optional[OptimSpec] = None, **kwargs
) -> OptimSpec:
    if optim is None:
        return default or OptimSpec(torch.optim.AdamW, kwargs)
    if isinstance(optim, OptimSpec):
        return optim
    raise TypeError(f"Unsupported optimizer spec: {optim!r}")
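
# Usage sketch (illustrative): build an optimizer from a spec, or fall back to AdamW.
#   spec = ensure_optim_spec(None, lr=3e-4)   # -> OptimSpec(torch.optim.AdamW, {"lr": 3e-4})
#   inner_optim = spec.build(model)           # torch.optim.AdamW(model.parameters(), lr=3e-4)
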
class Strategy(ABC, LogModule):
    def __init__(self, lr_scheduler=None, lr_scheduler_kwargs=None, **kwargs):
        self.lr_scheduler = lr_scheduler
        self.lr_scheduler_kwargs = lr_scheduler_kwargs or {}
        self.kwargs = kwargs
        self.scheduler = None
        self.lr_callbacks = []
        self.max_steps = 1

    def _init_node(self, model, rank, num_nodes):
        self.model = model
        self.rank = rank
        self.num_nodes = num_nodes
        self.local_step = 0

    @abstractmethod
    def step(self):
        self.local_step += 1

    def zero_grad(self):
        self.optim.zero_grad()

    def _setup_scheduler(self):
        def lr_lambda(step):
            warmup = self.lr_scheduler_kwargs.get("warmup_steps", 1)
            max_steps = self.lr_scheduler_kwargs.get("max_steps", self.max_steps)
            if step < warmup:
                return step / max(1, warmup)
            progress = (step - warmup) / max(1, max_steps - warmup)
            return 0.5 * (1.0 + math.cos(math.pi * progress))

        if self.lr_scheduler == "lambda_cosine":
            self.scheduler = LambdaLR(self.optim, lr_lambda)

    def __config__(self):
        return {"strategy": self.__class__.__name__}
class CommunicationModule(ABC):
    @abstractmethod
    def communicate(self, model, rank, num_nodes, local_step):
        pass

    @abstractmethod
    def _init_node(self, model, rank, num_nodes):
        pass
class CommunicateOptimizeStrategy(Strategy):
    def __init__(self, communication_modules, optim_spec=None, max_norm=None, **kwargs):
        super().__init__(**kwargs)
        self.communication_modules = communication_modules
        self.optim_spec = optim_spec
        self.max_norm = max_norm
        for m in self.communication_modules:
            m.strategy = self

    def _init_node(self, model, rank, num_nodes):
        super()._init_node(model, rank, num_nodes)
        # ensure_optim_spec falls back to AdamW when optim_spec is left as None.
        self.optim = ensure_optim_spec(self.optim_spec).build(model)
        self._setup_scheduler()
        for m in self.communication_modules:
            m._init_node(model, rank, num_nodes)

    def step(self):
        if self.max_norm:
            nn_utils.clip_grad_norm_(self.model.parameters(), self.max_norm)
        self.optim.step()
        for m in self.communication_modules:
            m.communicate(self.model, self.rank, self.num_nodes, self.local_step)
        if self.scheduler:
            self.scheduler.step()
        self.local_step += 1
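
# Note on CommunicateOptimizeStrategy.step() above: the per-step order is optional
# gradient clipping -> inner optimizer step -> communication modules -> LR scheduler,
# so communication always sees the parameters produced by the current inner update.
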
class DiLoCoCommunicator(CommunicationModule):
    def __init__(self, H=25, outer_optim_spec=None):
        self.H = H
        self.outer_optim_spec = outer_optim_spec

    def _init_node(self, model, rank, num_nodes):
        # Gloo group for CPU-side outer communication; the very generous timeout
        # (60 days) keeps the group from timing out between infrequent outer steps.
        self.pg = dist.new_group(backend="gloo", timeout=datetime.timedelta(days=60))
        self.master_model = deepcopy(model).to("cpu")
        for p in self.master_model.parameters():
            p.requires_grad = True
        self.outer_optim = self.outer_optim_spec.cls(
            self.master_model.parameters(),
            process_group=self.pg,
            **self.outer_optim_spec.kwargs,
        )

    def communicate(self, model, rank, num_nodes, local_step):
        if num_nodes > 1 and local_step > 0 and local_step % self.H == 0:
            self.outer_optim.zero_grad()
            # Outer "gradient" is the pseudo-gradient: master weights minus local weights.
            for n, p in self.master_model.named_parameters():
                p.grad = p.data - model.state_dict()[n].data.to("cpu")
            self.outer_optim.step()
            # Pull the updated master weights back into the local model.
            for n, p in model.named_parameters():
                p.data.copy_(self.master_model.state_dict()[n].to(p.device))
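
# DiLoCo outer step (every H inner steps): each node forms the pseudo-gradient
#   outer_grad = master_param - local_param
# on the CPU master copy, the outer optimizer (here SparseLoCo, via its process_group)
# reduces these pseudo-gradients across the gloo group and updates the master weights,
# and the updated master weights are copied back into the local model so every node
# starts the next H inner steps from the same point.
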
class DiLoCoStrategy(CommunicateOptimizeStrategy):
    def __init__(self, optim_spec, outer_optim_spec, H=25, **kwargs):
        self.comm = DiLoCoCommunicator(H=H, outer_optim_spec=outer_optim_spec)
        super().__init__(
            communication_modules=[self.comm],
            optim_spec=optim_spec,
            **kwargs,
        )
class SparseLoCo(torch.optim.SGD):
    def __init__(
        self,
        params,
        lr,
        momentum=0.9,
        weight_decay=0.05,
        top_k=64,
        chunk_size=64,
        use_dct=True,
        use_quantization=True,
        quantization_bins=4,
        quantization_range=6,
        process_group=None,
        **kwargs,
    ):
        # The base SGD is built with weight_decay=0; decay is applied manually in step().
        super().__init__(
            params,
            lr=lr,
            momentum=momentum,
            weight_decay=0.0,
            **kwargs,
        )
        self.decoupled_weight_decay = weight_decay
        self.process_group = process_group
        self.top_k = top_k
        self.chunk_size = chunk_size
        self.use_dct = use_dct
        self.use_quantization = use_quantization
        self.quantization_bins = quantization_bins
        self.quantization_range = quantization_range
    def sparsify(self, tensor):
        original_shape = tensor.shape
        tensor = tensor.flatten()
        num_elements = tensor.numel()
        num_chunks = math.ceil(num_elements / self.chunk_size)
        padded_size = num_chunks * self.chunk_size
        padded_tensor = torch.zeros(padded_size, device=tensor.device)
        padded_tensor[:num_elements] = tensor
        chunks = rearrange(padded_tensor, '(c s) -> c s', s=self.chunk_size)
        if self.use_dct:
            # torch.fft has no dct; use the orthonormal DCT-II helper defined above.
            chunks = _dct(chunks)
        importance = chunks.abs()
        k = min(self.top_k, self.chunk_size)
        # Keep the k largest-magnitude entries per chunk, preserving their signs.
        _, indices = torch.topk(importance, k, dim=-1)
        values = torch.gather(chunks, -1, indices)
        if self.use_quantization:
            min_v = -self.quantization_range
            max_v = self.quantization_range
            values = torch.clamp(
                (values - min_v) / (max_v - min_v) * (self.quantization_bins - 1),
                0,
                self.quantization_bins - 1,
            )
            values = values.round().int()
        return values, indices, chunks.shape, original_shape, num_elements
    def reconstruct(self, values, indices, chunk_shape, original_shape, num_elements, reduced=False):
        if self.use_quantization:
            min_v = -self.quantization_range
            max_v = self.quantization_range
            values = values.float() / (self.quantization_bins - 1) * (max_v - min_v) + min_v
        chunks = torch.zeros(chunk_shape, device=values.device)
        chunks.scatter_(-1, indices, values)
        if self.use_dct:
            # Inverse of the DCT applied in sparsify (see _idct above).
            chunks = _idct(chunks)
        flat = rearrange(chunks, 'c s -> (c s)')
        tensor = flat[:num_elements].reshape(original_shape)
        if reduced:
            tensor /= dist.get_world_size(self.process_group)
        return tensor
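
    # Illustrative shapes for sparsify/reconstruct (hypothetical 1000-element gradient,
    # chunk_size=64, top_k=64): the flattened gradient is zero-padded to 16 * 64 = 1024
    # entries and reshaped into 16 chunks; topk keeps 64 entries per chunk (all of them
    # here, since top_k == chunk_size), giving values/indices of shape (16, 64);
    # reconstruct scatters the kept values back, inverts the DCT if enabled, and trims
    # the padding before reshaping to the original gradient shape.
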
    @torch.no_grad()
    def step(self, closure=None):
        world_size = dist.get_world_size(self.process_group) if self.process_group else 1
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if self.decoupled_weight_decay != 0:
                    grad = grad.add(p.data, alpha=self.decoupled_weight_decay)
                if world_size > 1:
                    # Compress the local pseudo-gradient, gather every rank's sparse
                    # representation, then densify (dequantize + inverse DCT) and average.
                    values, indices, chunk_shape, original_shape, num_elements = self.sparsify(grad)
                    all_values_list = [torch.zeros_like(values) for _ in range(world_size)]
                    all_indices_list = [torch.zeros_like(indices) for _ in range(world_size)]
                    dist.all_gather(all_values_list, values, group=self.process_group)
                    dist.all_gather(all_indices_list, indices, group=self.process_group)
                    reduced_grad = torch.zeros_like(grad)
                    for v, i in zip(all_values_list, all_indices_list):
                        reduced_grad += self.reconstruct(v, i, chunk_shape, original_shape, num_elements)
                    grad = (reduced_grad / world_size).to(p.grad.device)
                state = self.state[p]
                if 'momentum_buffer' not in state:
                    buf = state['momentum_buffer'] = torch.clone(grad).detach()
                else:
                    buf = state['momentum_buffer']
                    buf.mul_(group['momentum']).add_(grad)
                p.data.add_(buf, alpha=-group['lr'])
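
# The SparseLoCo.step() update above is SGD with momentum on the (reduced) pseudo-gradient:
#   buf <- momentum * buf + grad;  p <- p - lr * buf,
# with weight decay folded into grad (grad += weight_decay * p) before the momentum update,
# bypassing the base SGD's weight_decay, which is set to 0 in __init__.
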
STRATEGY = DiLoCoStrategy(
    optim_spec=OptimSpec(
        torch.optim.AdamW,
        {"lr": 0.001},
    ),
    outer_optim_spec=OptimSpec(
        SparseLoCo,
        {
            "lr": 0.8,
            "momentum": 0.9,
            "weight_decay": 0.05,
            "top_k": 64,
            "chunk_size": 64,
            "use_dct": True,
            "use_quantization": True,
            "quantization_bins": 4,
            "quantization_range": 6,
        },
    ),
    lr_scheduler="lambda_cosine",
    lr_scheduler_kwargs={
        "warmup_steps": 800,
        "max_steps": 100,
    },
    max_norm=1.5,
    H=25,
)
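
# With this configuration each node runs AdamW (lr 1e-3, grad norm clipped at 1.5)
# locally, and every H = 25 inner steps the DiLoCo outer step exchanges DCT'd,
# top-k-sparsified, 4-bin-quantized pseudo-gradients and applies them on the CPU
# master copy with SparseLoCo (lr 0.8, momentum 0.9, weight decay 0.05).
# Note that warmup_steps (800) exceeds the scheduler's max_steps (100), so lr_lambda's
# cosine branch runs with a denominator of max(1, 100 - 800) = 1.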