torchao-inductor-mapping.py
"""
torch: 2.8.0
torchao: 0.11.0
"""
import torch
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
from torchao.quantization import quantize_
QUANTIZE = True  # Set to False to see matching FQNs in both printouts below
class ToyLinearModel(torch.nn.Module):
    def __init__(self, m: int, n: int):
        super().__init__()
        self.linear = torch.nn.Linear(m, n)

    def forward(self, x):
        return self.linear(x)
module = ToyLinearModel(1024, 1024).to(device='cuda', dtype=torch.bfloat16)
if QUANTIZE:
    # Dynamic float8 activation + float8 weight quantization (torchao)
    quantize_(module, Float8DynamicActivationFloat8WeightConfig())
args = (torch.randn(1, 1024, dtype=torch.bfloat16, device='cuda'),)
exported = torch.export.export(module, args=args)
# AOT-compile the exported program to a .pt2 package, then load it back
package_path = torch._inductor.aoti_compile_and_package(exported)
compiled = torch._inductor.aoti_load_package(package_path)
assert set(exported.state_dict.keys()) == set(module.state_dict().keys())
assert set(exported.state_dict.keys()) == set(compiled.get_constant_fqns()) # Does not pass if QUANTIZE is True
print(set(exported.state_dict.keys()))
# Output (QUANTIZE False): {'linear.bias', 'linear.weight'}
# Output (QUANTIZE True): {'linear.bias', 'linear.weight'}
print(set(compiled.get_constant_fqns()))
# Output (QUANTIZE False): {'linear.bias', 'linear.weight'}
# Output (QUANTIZE True): {'constant2', 'linear.parametrizations.weight.original0', 'linear.bias'}
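
# A hedged debugging sketch (not part of the original gist): diff the two name
# sets to see exactly which FQNs change under quantization. With QUANTIZE True,
# 'linear.weight' disappears from the compiled constants while 'constant2' and
# 'linear.parametrizations.weight.original0' appear, consistent with export
# flattening the torchao tensor-subclass weight into parametrization-style
# constants. Only set operations are used here, so this runs as written.
exported_names = set(exported.state_dict.keys())
compiled_names = set(compiled.get_constant_fqns())
print('only in exported state_dict:', exported_names - compiled_names)
print('only in compiled constants: ', compiled_names - exported_names)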