Skip to content

Instantly share code, notes, and snippets.

@garrett361
Created February 25, 2026 22:05
Show Gist options
  • Select an option

  • Save garrett361/c2338998a96ad2ab73aeec61d458e541 to your computer and use it in GitHub Desktop.

Select an option

Save garrett361/c2338998a96ad2ab73aeec61d458e541 to your computer and use it in GitHub Desktop.
FSDP mixed grad dtypes
❯ torchrun --nproc-per-node 2 test.py
W0225 22:03:36.129000 3702213 torch/distributed/run.py:852]
W0225 22:03:36.129000 3702213 torch/distributed/run.py:852] *****************************************
W0225 22:03:36.129000 3702213 torch/distributed/run.py:852] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0225 22:03:36.129000 3702213 torch/distributed/run.py:852] *****************************************
[backward] lin1.weight: p.dtype=torch.bfloat16 grad_dtype=torch.float32
[backward] lin0.weight: p.dtype=torch.bfloat16 grad_dtype=torch.bfloat16
[Rank 0] FSDP reduce-scatter expects uniform gradient dtype but got {torch.bfloat16, torch.float32}
[...]
[Rank 1] FSDP reduce-scatter expects uniform gradient dtype but got {torch.float32, torch.bfloat16}
[...]
import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard
d_model = 128  # width shared by both linear layers and the input in the repro below
class DoubleLinear(nn.Module):
    """Two stacked bias-free linear layers, the second run temporarily in fp32.

    The forward pass up-casts ``lin1`` (in place) to float32 for its matmul,
    then restores whatever dtype it had before. As a result ``lin1``'s
    gradient is produced in float32 while ``lin0``'s stays in the parameter
    dtype — the mixed-gradient-dtype situation this repro demonstrates.
    """

    def __init__(self, d_model: int, device: torch.device) -> None:
        super().__init__()
        self.lin0 = nn.Linear(d_model, d_model, device=device, bias=False)
        self.lin1 = nn.Linear(d_model, d_model, device=device, bias=False)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        hidden = self.lin0(inputs)
        saved_dtype = self.lin1.weight.dtype
        # Run lin1 in float32, then put its parameters back in place.
        self.lin1.to(torch.float32)
        result = self.lin1(hidden.to(torch.float32))
        self.lin1.to(saved_dtype)
        return result
if __name__ == "__main__":
    # torchrun-provided environment variables.
    RANK = int(os.environ["RANK"])
    LOCAL_RANK = int(os.environ["LOCAL_RANK"])
    WORLD_SIZE = int(os.environ["WORLD_SIZE"])
    # Fix: bind to the node-local GPU index. RANK equals LOCAL_RANK only on a
    # single node; on multi-node runs RANK would index a nonexistent GPU.
    device = torch.device(f"cuda:{LOCAL_RANK}")
    torch.cuda.set_device(device)
    try:
        dist.init_process_group(backend="nccl", rank=RANK, world_size=WORLD_SIZE, device_id=device)
        model = DoubleLinear(d_model, device)

        # Hooks for printing dtypes prior to FSDP hook averaging
        def _make_grad_hook(mod_name):
            # Build a full-backward hook that prints, on rank 0 only, each
            # direct parameter's dtype alongside its gradient's dtype.
            def hook(module, grad_input, grad_output):
                if not RANK:
                    for n, p in module.named_parameters(recurse=False):
                        grad_dtype = p.grad.dtype if p.grad is not None else None
                        print(f"[backward] {mod_name}.{n}: {p.dtype=} {grad_dtype=}")

            return hook

        model.lin0.register_full_backward_hook(_make_grad_hook("lin0"))
        model.lin1.register_full_backward_hook(_make_grad_hook("lin1"))

        # bf16 params with fp32 reduction: lin1's fp32 grad (from the in-forward
        # upcast) plus lin0's bf16 grad yields mixed dtypes at reduce-scatter,
        # which is the error this script reproduces.
        mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
        fully_shard(model, mp_policy=mp_policy)

        inputs = torch.randn(1, d_model, device=device, dtype=torch.bfloat16, requires_grad=True)
        outputs = model(inputs)
        outputs.sum().backward()
    finally:
        # Fix: only tear down a group that was actually created; destroying an
        # uninitialized group raises and would mask the original exception.
        if dist.is_initialized():
            dist.destroy_process_group()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment