Skip to content

Instantly share code, notes, and snippets.

@sourabh2k15
Created February 24, 2026 04:47
Show Gist options
  • Select an option

  • Save sourabh2k15/532e785c12f875e7634cb8413c10c68c to your computer and use it in GitHub Desktop.

Select an option

Save sourabh2k15/532e785c12f875e7634cb8413c10c68c to your computer and use it in GitHub Desktop.
#@title Custom Linear layer backward pass
from torch.autograd.function import once_differentiable
import torch.nn as nn
import torch
class CustomLinearLayerNoBias(torch.autograd.Function):
    """Custom autograd Function computing ``Y = X @ W.t()`` (nn.Linear, no bias).

    The forward pass saves both inputs; the backward pass returns the
    analytic gradients:

        dL/dX = dL/dY @ W
        dL/dW = (dL/dY).t() @ X
    """

    @staticmethod
    def forward(ctx, X, W):
        # X: (batch, in_features); W: (out_features, in_features) — same
        # layout as nn.Linear.weight, hence the transpose in the matmul.
        ctx.save_for_backward(X, W)
        return X @ W.t()

    @staticmethod
    @once_differentiable  # gradients are computed with plain tensor ops; not twice-differentiable
    def backward(ctx, grad_output):
        X, W = ctx.saved_tensors
        grad_X = grad_W = None
        # Only compute a gradient when the corresponding forward input
        # actually requires it — avoids wasted matmuls and matches the
        # recommended torch.autograd.Function pattern.
        if ctx.needs_input_grad[0]:
            grad_X = grad_output @ W
        if ctx.needs_input_grad[1]:
            grad_W = grad_output.t() @ X
        return grad_X, grad_W
# Verify the custom Function against nn.Linear(bias=False) as ground truth:
# identical forward outputs, identical input- and weight-gradients.
torch.manual_seed(0)  # make the randomized comparison reproducible
X = torch.rand(3, 3)
W = torch.rand(2, 3)
X.requires_grad = True
W.requires_grad = True

torch_layer = nn.Linear(3, 2, bias=False)
# Copy W into the reference layer without tracking the copy in autograd;
# torch.no_grad() is the supported idiom (mutating .weight.data is discouraged).
# Parameters require grad by default, so no explicit requires_grad needed.
with torch.no_grad():
    torch_layer.weight.copy_(W)

custom_out = CustomLinearLayerNoBias.apply(X, W)

# Run the reference layer on an independent leaf copy of X so the two
# backward passes accumulate gradients into separate tensors.
X_copy = X.detach().clone()
X_copy.requires_grad = True
ground_truth_out = torch_layer(X_copy)

# Forward passes must agree.
torch.testing.assert_close(custom_out, ground_truth_out)

# Backward from the same scalar loss on each graph; gradients must agree.
custom_loss = custom_out.mean()
ground_truth_loss = ground_truth_out.mean()
custom_loss.backward()
ground_truth_loss.backward()
torch.testing.assert_close(W.grad, torch_layer.weight.grad)
torch.testing.assert_close(X.grad, X_copy.grad)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment