@W3SS
Forked from karpathy/pytorch_strangeness.py
Created February 19, 2026 02:45
pytorch strangeness
import torch
import torch.nn as nn
torch.manual_seed(42)
x = torch.randn(2, 768)
# matrix multiply "ignores" the second row when calculating the first row
w = torch.randn(768, 768)
z1 = x[0] @ w
z2 = (x @ w)[0]
print((z1-z2).abs().max().item()) # prints 0 (should be 0, OK)
# linear does not!
m = nn.Linear(768, 768, bias=False)
with torch.no_grad():
    # nn.Linear computes x @ weight.T, so copying w.T makes m(x) mathematically equal to x @ w
    m.weight.copy_(w.T)
q1 = m(x[0])
q2 = m(x)[0]
print((q1-q2).abs().max().item()) # prints ~2e-5 (should be 0?!)
# and z1 != q1
print((z1-q1).abs().max().item()) # prints ~9e-5 (should be 0?!)
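
The script above does not say where the discrepancies come from. One plausible reading, stated here as an assumption rather than something the gist confirms, is that the single-row and batched calls reach different low-level kernels that add up the 768 products in different orders. Floating-point addition is not associative, so a different order can give a slightly different float32 result even though the math is identical. The sketch below illustrates only that effect, in pure Python, by emulating float32 rounding with `struct`. The helpers `f32`, `dot_sequential` and `dot_blocked` are hypothetical names made up for this illustration, and the block size of 8 is arbitrary; none of this reproduces what PyTorch actually does internally.

```python
import math
import random
import struct


def f32(v):
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", v))[0]


def dot_sequential(a, b):
    """Dot product accumulated left to right, rounded to float32 at each step."""
    acc = 0.0
    for x, y in zip(a, b):
        acc = f32(acc + f32(x * y))
    return acc


def dot_blocked(a, b, block=8):
    """Same dot product, but summed as per-block partials that are added afterwards.

    The block size is an arbitrary choice for illustration only.
    """
    partials = [
        dot_sequential(a[i:i + block], b[i:i + block])
        for i in range(0, len(a), block)
    ]
    acc = 0.0
    for p in partials:
        acc = f32(acc + p)
    return acc


random.seed(42)
a = [f32(random.gauss(0, 1)) for _ in range(768)]
b = [f32(random.gauss(0, 1)) for _ in range(768)]

# math.fsum gives an accurately rounded float64 sum to compare against.
reference = math.fsum(x * y for x, y in zip(a, b))
seq = dot_sequential(a, b)
blk = dot_blocked(a, b)

print("sequential vs blocked:  ", abs(seq - blk))
print("sequential vs reference:", abs(seq - reference))
print("blocked vs reference:   ", abs(blk - reference))
```

If this reading is right, neither q1 nor q2 is "the wrong one": both are float32 approximations of the same exact value. Comparing such results with a tolerance (for example `torch.allclose`) is then more meaningful than expecting a difference of exactly 0.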