Created
February 25, 2026 22:05
-
-
Save garrett361/c2338998a96ad2ab73aeec61d458e541 to your computer and use it in GitHub Desktop.
FSDP mixed grad dtypes
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| ❯ torchrun --nproc-per-node 2 test.py | |
| W0225 22:03:36.129000 3702213 torch/distributed/run.py:852] | |
| W0225 22:03:36.129000 3702213 torch/distributed/run.py:852] ***************************************** | |
| W0225 22:03:36.129000 3702213 torch/distributed/run.py:852] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. | |
| W0225 22:03:36.129000 3702213 torch/distributed/run.py:852] ***************************************** | |
| [backward] lin1.weight: p.dtype=torch.bfloat16 grad_dtype=torch.float32 | |
| [backward] lin0.weight: p.dtype=torch.bfloat16 grad_dtype=torch.bfloat16 | |
| [Rank 0] FSDP reduce-scatter expects uniform gradient dtype but got {torch.bfloat16, torch.float32} | |
| [...] | |
| [Rank 1] FSDP reduce-scatter expects uniform gradient dtype but got {torch.float32, torch.bfloat16} | |
| [...] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import os | |
| import torch | |
| import torch.distributed as dist | |
| import torch.nn as nn | |
| from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard | |
d_model = 128  # width of both linear layers (also the script's input feature size)
class DoubleLinear(nn.Module):
    """Two stacked bias-free linear layers where the second runs in float32.

    ``forward`` temporarily upcasts ``lin1`` (and its input) to float32 and
    then restores the layer's previous dtype. Under a low-precision policy
    this makes ``lin0`` and ``lin1`` produce gradients of different dtypes,
    which is exactly what the surrounding script is demonstrating.
    """

    def __init__(self, d_model: int, device: torch.device) -> None:
        super().__init__()
        for layer_name in ("lin0", "lin1"):
            setattr(
                self,
                layer_name,
                nn.Linear(d_model, d_model, device=device, bias=False),
            )

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        hidden = self.lin0(inputs)
        # Remember lin1's current dtype so it can be restored after the fp32 pass.
        saved_dtype = self.lin1.weight.dtype
        self.lin1.to(torch.float32)
        result = self.lin1(hidden.to(torch.float32))
        self.lin1.to(saved_dtype)
        return result
if __name__ == "__main__":
    RANK = int(os.environ["RANK"])
    LOCAL_RANK = int(os.environ["LOCAL_RANK"])
    WORLD_SIZE = int(os.environ["WORLD_SIZE"])
    # Bind each process to its node-local GPU. Using the global RANK here is
    # only correct on a single node; LOCAL_RANK is valid for multi-node
    # launches too (and was previously read but never used).
    device = torch.device(f"cuda:{LOCAL_RANK}")
    torch.cuda.set_device(device)
    try:
        dist.init_process_group(backend="nccl", rank=RANK, world_size=WORLD_SIZE, device_id=device)
        model = DoubleLinear(d_model, device)

        # Hooks for printing dtypes prior to FSDP hook averaging
        def _make_grad_hook(mod_name):
            def hook(module, grad_input, grad_output):
                # Print from rank 0 only to keep the output readable.
                if not RANK:
                    for n, p in module.named_parameters(recurse=False):
                        grad_dtype = p.grad.dtype if p.grad is not None else None
                        print(f"[backward] {mod_name}.{n}: {p.dtype=} {grad_dtype=}")

            return hook

        model.lin0.register_full_backward_hook(_make_grad_hook("lin0"))
        model.lin1.register_full_backward_hook(_make_grad_hook("lin1"))

        # bf16 compute with fp32 gradient reduction; lin1's fp32 detour in
        # forward is what triggers the mixed grad-dtype failure under FSDP.
        mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
        fully_shard(model, mp_policy=mp_policy)

        inputs = torch.randn(1, d_model, device=device, dtype=torch.bfloat16, requires_grad=True)
        outputs = model(inputs)
        outputs.sum().backward()
    finally:
        # Only tear down if init actually succeeded; destroying an
        # uninitialized process group would raise here and mask the
        # original exception from the try block.
        if dist.is_initialized():
            dist.destroy_process_group()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment