@soulitzer
Created March 4, 2026 22:03
Zero-copy gradient packing into a contiguous buffer via custom autograd Function
"""Zero-copy gradient packing into a contiguous buffer.
Demonstrates how a custom autograd Function can write gradients directly
into a pre-allocated contiguous buffer, and have AccumulateGrad steal
the views (Case 1.1) so that .grad points into the buffer with no copy.
Key requirements for the steal path:
1. The gradient tensor must obey the layout contract (its strides must
   match the parameter's).
2. The gradient tensor's refcount must be <= num_expected_refs (typically 1):
   - clear ctx references before returning from backward
   - don't hold extra references in outer scope
"""
import torch
weight1 = torch.randn(3, 4, requires_grad=True)
weight2 = torch.randn(5, 3, requires_grad=True)
buf = torch.zeros(weight1.numel() + weight2.numel())
class PackedLinear(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight, buf_slice):
        ctx.save_for_backward(x, weight)
        ctx.buf_slice = buf_slice
        return x @ weight.T

    @staticmethod
    def backward(ctx, grad_output):
        x, weight = ctx.saved_tensors
        out = ctx.buf_slice
        ctx.buf_slice = None  # release ref so AccumulateGrad can steal
        torch.mm(grad_output.T, x, out=out)  # write grad wrt weight into the buffer
        # must return a real grad for x so backprop reaches the previous layer
        return grad_output @ weight, out, None
x = torch.randn(2, 4)
y = PackedLinear.apply(x, weight1, buf[:12].view(3, 4))
z = PackedLinear.apply(y, weight2, buf[12:].view(5, 3))
z.sum().backward()
assert weight1.grad.untyped_storage().data_ptr() == buf.untyped_storage().data_ptr()
assert weight2.grad.untyped_storage().data_ptr() == buf.untyped_storage().data_ptr()
print("Both .grad tensors share storage with buf — zero copy!")
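The refcount requirement from the docstring can be probed empirically. The sketch below (an illustration assuming current AccumulateGrad semantics; `LeakyLinear` and the `leaked` list are names invented here, not part of the gist) keeps one extra reference to the gradient tensor across backward, pushing its refcount past num_expected_refs: the steal is refused and .grad lands in freshly cloned storage instead of the buffer.

```python
import torch

leaked = []  # outer-scope list that keeps an extra reference alive

class LeakyLinear(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight, buf_slice):
        ctx.save_for_backward(x, weight)
        ctx.buf_slice = buf_slice
        return x @ weight.T

    @staticmethod
    def backward(ctx, grad_output):
        x, weight = ctx.saved_tensors
        out = ctx.buf_slice
        ctx.buf_slice = None                 # ctx ref is released correctly...
        torch.mm(grad_output.T, x, out=out)
        leaked.append(out)                   # ...but this extra ref defeats the steal
        return grad_output @ weight, out, None

w = torch.randn(3, 4, requires_grad=True)
buf2 = torch.zeros(12)
x = torch.randn(2, 4)
LeakyLinear.apply(x, w, buf2.view(3, 4)).sum().backward()

# AccumulateGrad fell back to a copy: .grad does NOT alias the buffer,
# even though the buffer still holds the same gradient values.
assert w.grad.untyped_storage().data_ptr() != buf2.untyped_storage().data_ptr()
assert torch.equal(w.grad, buf2.view(3, 4))
```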
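The layout-contract requirement can be probed the same way. In this sketch (again assuming current AccumulateGrad behavior; `LayoutBreaker` is a name invented here), the buffer slice has the right shape (3, 4) but transposed strides (1, 3), so even with the refcount in order the steal should be refused and .grad cloned into contiguous storage.

```python
import torch

class LayoutBreaker(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight, buf_slice):
        ctx.save_for_backward(x, weight)
        ctx.buf_slice = buf_slice
        return x @ weight.T

    @staticmethod
    def backward(ctx, grad_output):
        x, weight = ctx.saved_tensors
        out = ctx.buf_slice
        ctx.buf_slice = None              # refcount side is handled correctly
        out.copy_(grad_output.T @ x)      # write through the strided view
        return grad_output @ weight, out, None

w = torch.randn(3, 4, requires_grad=True)
buf3 = torch.zeros(12)
bad_slice = buf3.view(4, 3).T  # shape (3, 4) but strides (1, 3) != weight's (4, 1)
x = torch.randn(2, 4)
LayoutBreaker.apply(x, w, bad_slice).sum().backward()

# layout contract violated: .grad is a fresh contiguous clone, not a view of buf3
assert w.grad.untyped_storage().data_ptr() != buf3.untyped_storage().data_ptr()
assert w.grad.is_contiguous()
```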