Zero-copy gradient packing into a contiguous buffer via custom autograd Function
| """Zero-copy gradient packing into a contiguous buffer. | |
| Demonstrates how a custom autograd Function can write gradients directly | |
| into a pre-allocated contiguous buffer, and have AccumulateGrad steal | |
| the views (Case 1.1) so that .grad points into the buffer with no copy. | |
| Key requirements for the steal path: | |
| 1. The gradient tensor must obey the layout contract (strides match the parameter) | |
| 2. The gradient tensor's refcount must be <= num_expected_refs (typically 1) | |
| - Clear ctx references before returning from backward | |
| - Don't hold extra references in outer scope | |
| """ | |
import torch

weight1 = torch.randn(3, 4, requires_grad=True)
weight2 = torch.randn(5, 3, requires_grad=True)
# Flat buffer sized to hold both packed gradients (3*4 + 5*3 = 27 elements).
buf = torch.zeros(weight1.numel() + weight2.numel())
class PackedLinear(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight, buf_slice):
        ctx.save_for_backward(x, weight)
        ctx.buf_slice = buf_slice
        return x @ weight.T

    @staticmethod
    def backward(ctx, grad_output):
        x, weight = ctx.saved_tensors
        out = ctx.buf_slice
        ctx.buf_slice = None  # release the ctx reference so AccumulateGrad can steal the view
        # Write the weight gradient directly into the pre-allocated slice of buf.
        torch.mm(grad_output.T, x, out=out)
        # Also propagate the gradient to the input so stacked layers get correct grads.
        grad_input = grad_output @ weight if ctx.needs_input_grad[0] else None
        return grad_input, out, None
x = torch.randn(2, 4)
y = PackedLinear.apply(x, weight1, buf[:12].view(3, 4))
z = PackedLinear.apply(y, weight2, buf[12:].view(5, 3))
z.sum().backward()

# Both parameter gradients are now views into buf: no copies were made.
assert weight1.grad.untyped_storage().data_ptr() == buf.untyped_storage().data_ptr()
assert weight2.grad.untyped_storage().data_ptr() == buf.untyped_storage().data_ptr()
print("Both .grad tensors share storage with buf - zero copy!")