@strnan
Created January 3, 2026 05:15
Distributed training strategy submission
import math
import os
import torch
import torch.nn.utils as nn_utils
import torch.distributed as dist
import torch.fft
from einops import rearrange
import datetime
from copy import deepcopy
from dataclasses import dataclass
from torch.optim.lr_scheduler import LambdaLR
from typing import List, Type, Union, Optional, Dict, Any, TypeAlias, Callable, Iterable, Tuple
from abc import ABC, abstractmethod
from exogym.aux.utils import LogModule
ParamsT: TypeAlias = Union[Iterable[torch.Tensor], Iterable[dict[str, Any]]]

def mps_compatible(func):
    """Run collectives through CPU copies when tensors live on Apple MPS,
    since the communication backends cannot operate on MPS tensors directly."""

    def all_gather_wrapper(tensor_list, tensor, *args, **kwargs):
        is_tensor_mps = hasattr(tensor, "device") and tensor.device.type == "mps"
        is_list_mps = any(hasattr(t, "device") and t.device.type == "mps" for t in tensor_list)
        if is_tensor_mps or is_list_mps:
            cpu_tensor = tensor.data.to("cpu") if is_tensor_mps else tensor
            cpu_tensor_list = [
                t.data.to("cpu") if hasattr(t, "device") and t.device.type == "mps" else t
                for t in tensor_list
            ]
            result = func(cpu_tensor_list, cpu_tensor, *args, **kwargs)
            if is_tensor_mps:
                tensor.data.copy_(cpu_tensor.to("mps"))
            for i, t in enumerate(tensor_list):
                if hasattr(t, "device") and t.device.type == "mps":
                    t.data.copy_(cpu_tensor_list[i].to("mps"))
            return result
        return func(tensor_list, tensor, *args, **kwargs)

    def standard_wrapper(tensor, *args, **kwargs):
        if hasattr(tensor, "device") and tensor.device.type == "mps":
            cpu_tensor = tensor.data.to("cpu")
            result = func(cpu_tensor, *args, **kwargs)
            tensor.data.copy_(cpu_tensor.to("mps"))
            return result
        return func(tensor, *args, **kwargs)

    return all_gather_wrapper if func.__name__ == "all_gather" else standard_wrapper

@mps_compatible
def broadcast(tensor, src=0):
    return dist.broadcast(tensor, src=src)


@mps_compatible
def all_reduce(tensor, op=dist.ReduceOp.SUM):
    return dist.all_reduce(tensor, op=op)


@mps_compatible
def all_gather(tensor_list, tensor, group=None, async_op=False):
    return dist.all_gather(tensor_list, tensor, group=group, async_op=async_op)

def quantize(tensor, bins, q_range):
    """Uniform affine quantization of `tensor` into `bins` integer levels.
    (`q_range` is accepted for interface compatibility but not used here.)"""
    min_val, max_val = tensor.min(), tensor.max()
    scale = (max_val - min_val) / (bins - 1)
    scale = torch.clamp(scale, min=1e-12)  # guard against a constant tensor (max == min)
    zero_point = torch.round(-min_val / scale)
    return torch.clamp(torch.round((tensor / scale) + zero_point), 0, bins - 1), scale, zero_point


def dequantize(quantized, scale, zero_point):
    return (quantized - zero_point) * scale
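
# Worked example: for a tensor with min=-1.0, max=1.0 and bins=8,
# scale = (1.0 - (-1.0)) / 7 ≈ 0.286 and zero_point = round(1.0 / 0.286) = round(3.5) = 4,
# so a value of 0.0 maps to level 4 and dequantizes back to (4 - 4) * scale = 0.0.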

def topk_sparsify(tensor, k, chunk_size):
    """Keep the k largest-magnitude entries of `tensor` and zero out the rest.
    Returns the sparsified tensor and a boolean mask, both in `tensor`'s shape.
    (`chunk_size` is accepted for interface compatibility but not used here.)"""
    k = min(k, tensor.numel())  # small parameters (e.g. biases) may have fewer than k entries
    abs_t = torch.abs(tensor)
    _, indices = torch.topk(abs_t.reshape(-1), k)
    mask = torch.zeros(tensor.numel(), dtype=torch.bool, device=tensor.device)
    mask[indices] = True
    return (tensor.reshape(-1) * mask).reshape(tensor.shape), mask.reshape(tensor.shape)
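
# Together, topk_sparsify and quantize form the compression path used in
# SparseQuantizedDiLoCoCommunicator.communicate() below: the pseudo-gradient is
# sparsified, its surviving entries are quantized, the codes are all-reduced, and the
# result is dequantized with each rank's local scale/zero-point.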

@dataclass
class OptimSpec:
    cls: Type[torch.optim.Optimizer]
    kwargs: Optional[Dict[str, Any]] = None

    def build(self, model):
        return self.cls(model.parameters(), **(self.kwargs or {}))


def ensure_optim_spec(optim: Union[str, OptimSpec, None], default: Optional[OptimSpec] = None, **kwargs) -> OptimSpec:
    if optim is None:
        return default or OptimSpec(torch.optim.AdamW, kwargs)
    if isinstance(optim, OptimSpec):
        return optim
    raise TypeError(f"Unsupported optimizer spec: {optim!r}")

class Strategy(ABC, LogModule):
    def __init__(self, lr_scheduler=None, lr_scheduler_kwargs=None, **kwargs):
        self.lr_scheduler = lr_scheduler
        self.lr_scheduler_kwargs = lr_scheduler_kwargs or {}
        self.kwargs = kwargs
        self.scheduler = None
        self.lr_callbacks = []
        self.max_steps = 1

    def _init_node(self, model, rank, num_nodes):
        self.model = model
        self.rank = rank
        self.num_nodes = num_nodes
        self.local_step = 0

    @abstractmethod
    def step(self):
        self.local_step += 1

    def zero_grad(self):
        # `self.optim` is expected to be created by the concrete strategy subclass.
        self.optim.zero_grad()

    def _setup_scheduler(self):
        def lr_lambda(step):
            # Linear warmup followed by cosine decay.
            warmup = self.lr_scheduler_kwargs.get("warmup_steps", 1)
            max_steps = self.lr_scheduler_kwargs.get("max_steps", self.max_steps)
            if step < warmup:
                return step / max(1, warmup)
            progress = (step - warmup) / max(1, max_steps - warmup)
            return 0.5 * (1.0 + math.cos(math.pi * progress))

        if self.lr_scheduler == "lambda_cosine":
            self.scheduler = LambdaLR(self.optim, lr_lambda)

    def __config__(self):
        return {"strategy": self.__class__.__name__}

class CommunicationModule(ABC):
    @abstractmethod
    def communicate(self, model, rank, num_nodes, local_step):
        pass

    @abstractmethod
    def _init_node(self, model, rank, num_nodes):
        pass

class SparseQuantizedDiLoCoCommunicator(CommunicationModule):
    """DiLoCo-style communicator: every H inner steps, each rank forms a pseudo-gradient
    (master params minus local params, plus error feedback), top-k sparsifies and quantizes
    it, and all-reduces the result; rank 0 then runs the outer optimizer on the master copy
    and broadcasts the updated parameters to all nodes."""

    def __init__(self, H=15, outer_optim_spec=None, top_k=32, chunk_size=128, bins=8, q_range=4, adaptive=True):
        self.H = H
        self.outer_optim_spec = outer_optim_spec
        self.top_k = top_k
        self.chunk_size = chunk_size
        self.bins = bins
        self.q_range = q_range
        self.adaptive = adaptive
        self.error_buffers = None

    def _init_node(self, model, rank, num_nodes):
        # Note: timedelta(60) is 60 days (timedelta's first positional argument is days).
        self.pg = dist.new_group(backend="gloo", timeout=datetime.timedelta(60))
        # Every rank keeps a CPU copy of the master (global) parameters so the pseudo-gradient
        # can be formed locally; only rank 0 runs the outer optimizer.
        self.master_model = deepcopy(model).to("cpu")
        for p in self.master_model.parameters():
            p.requires_grad = True
        if rank == 0:
            self.outer_optim = self.outer_optim_spec.cls(
                self.master_model.parameters(), process_group=self.pg, **self.outer_optim_spec.kwargs
            )
        self.error_buffers = {n: torch.zeros_like(p.data) for n, p in model.named_parameters()}
    def communicate(self, model, rank, num_nodes, local_step):
        if num_nodes > 1 and local_step > 0 and local_step % self.H == 0:
            # Optionally shrink the top-k budget as training progresses.
            curr_k = self.top_k if not self.adaptive else max(16, int(self.top_k * (1 - local_step / 100)))
            for n, p in model.named_parameters():
                pseudo_grad = self.master_model.state_dict()[n].data.to(p.device) - p.data + self.error_buffers[n]
                sparse_grad, mask = topk_sparsify(pseudo_grad, curr_k, self.chunk_size)
                quantized, scale, zp = quantize(sparse_grad, self.bins, self.q_range)
                # Average the quantized codes across nodes; scale/zero-point stay local to each rank.
                all_reduce(quantized, op=dist.ReduceOp.SUM)
                quantized /= num_nodes
                dequantized = dequantize(quantized, scale, zp)
                averaged = torch.zeros_like(pseudo_grad)
                averaged[mask] = dequantized[mask]
                # Error feedback: carry the part lost to sparsification/quantization into the next round.
                self.error_buffers[n] = pseudo_grad - averaged
            if rank == 0:
                self.outer_optim.zero_grad()
                for n, param in self.master_model.named_parameters():
                    # Outer pseudo-gradient on rank 0: master params minus rank 0's current local params.
                    param.grad = param.data - model.state_dict()[n].data.to("cpu")
                self.outer_optim.step()
                for n, p in model.named_parameters():
                    p.data.copy_(self.master_model.state_dict()[n].to(p.device))
            # Broadcast rank 0's outer-updated parameters to every node.
            for p in model.parameters():
                broadcast(p.data, src=0)
            # Keep every rank's CPU master copy in sync with the broadcast parameters.
            master_sd = self.master_model.state_dict()
            for n, p in model.named_parameters():
                master_sd[n].copy_(p.data.to("cpu"))

# CommunicateOptimizeStrategy is assumed to be provided by the surrounding exogym codebase
# (it is not defined in this gist); it is expected to build the inner optimizer from
# `optim_spec` and invoke each communication module during training.
class SparseQuantizedDiLoCoStrategy(CommunicateOptimizeStrategy):
    def __init__(self, optim_spec, outer_optim_spec, H=15, top_k=32, chunk_size=128, bins=8, q_range=4, **kwargs):
        self.comm = SparseQuantizedDiLoCoCommunicator(
            H=H,
            outer_optim_spec=outer_optim_spec,
            top_k=top_k,
            chunk_size=chunk_size,
            bins=bins,
            q_range=q_range,
        )
        super().__init__(communication_modules=[self.comm], optim_spec=optim_spec, **kwargs)

class SparseLoCo(torch.optim.SGD):
    """Outer optimizer: a thin wrapper around SGD with momentum. The compression-related
    arguments are accepted but unused here, `process_group` is stored, and the decoupled
    weight decay is recorded but not applied inside `step`."""

    def __init__(self, params, lr, momentum=0.95, weight_decay=0.02, top_k=32, chunk_size=128,
                 use_dct=False, use_quantization=True, quantization_bins=8, quantization_range=4,
                 process_group=None, **kwargs):
        # Decoupled weight decay is recorded separately; the base SGD gets weight_decay=0.0.
        super().__init__(params, lr=lr, momentum=momentum, weight_decay=0.0, **kwargs)
        self.decoupled_weight_decay = weight_decay
        self.process_group = process_group

    @torch.no_grad()
    def step(self, closure=None):
        super().step()

STRATEGY = SparseQuantizedDiLoCoStrategy(
    optim_spec=OptimSpec(torch.optim.AdamW, {"lr": 0.0006}),
    outer_optim_spec=OptimSpec(
        SparseLoCo,
        {
            "lr": 0.7,
            "momentum": 0.95,
            "weight_decay": 0.02,
            "top_k": 32,
            "chunk_size": 128,
            "use_dct": False,
            "use_quantization": True,
            "quantization_bins": 8,
            "quantization_range": 4,
        },
    ),
    lr_scheduler="lambda_cosine",
    lr_scheduler_kwargs={"warmup_steps": 300, "max_steps": 100},
    max_norm=1.0,
    H=15,
    top_k=32,
    chunk_size=128,
    bins=8,
    q_range=4,
)
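
# Optional sanity check for the compression helpers above. It only exercises
# quantize/dequantize/topk_sparsify on a random tensor; `k=5` and `bins=8` are arbitrary
# illustrative values, no torch.distributed setup is needed, and running the file directly
# assumes the module-level imports (exogym, CommunicateOptimizeStrategy) resolve.
if __name__ == "__main__":
    x = torch.randn(4, 8)
    sparse, mask = topk_sparsify(x, k=5, chunk_size=128)
    assert mask.sum().item() == 5
    assert sparse[~mask].abs().sum().item() == 0.0
    q, scale, zp = quantize(sparse, bins=8, q_range=4)
    recon = dequantize(q, scale, zp)
    # With 8 bins the reconstruction is coarse; it should stay within half a quantization step.
    assert (recon - sparse).abs().max() <= 0.5 * scale + 1e-6
    print("compression helpers round-trip OK")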