@segyges
Created July 16, 2025 23:53
import torch


@torch.compile
def fwht_optimized(x, dim=-1):
    """
    Highly optimized FWHT specifically designed for torch.compile.
    Uses more compiler-friendly operations.
    """
    if dim != -1:
        x = x.transpose(dim, -1)
    n = x.shape[-1]
    # Note: `&` binds looser than `!=` in Python, so the parentheses are required.
    if (n & (n - 1)) != 0:
        raise ValueError(f"Size must be power of 2, got {n}")
    # Use torch operations that compile well
    result = x.clone()
    log_n = n.bit_length() - 1  # exact integer log2; n is a Python int here
    for i in range(log_n):
        h = 1 << i     # 2^i
        step = h << 1  # 2^(i+1)
        # Vectorized indices for all butterfly operations at this level
        base_indices = torch.arange(0, n, step, device=x.device)
        left_indices = base_indices.unsqueeze(1) + torch.arange(h, device=x.device)
        right_indices = left_indices + h
        # Flatten indices
        left_flat = left_indices.flatten()
        right_flat = right_indices.flatten()
        # Butterfly operations
        left_vals = result[..., left_flat]
        right_vals = result[..., right_flat]
        result[..., left_flat] = left_vals + right_vals
        result[..., right_flat] = left_vals - right_vals
    if dim != -1:
        result = result.transpose(dim, -1)
    return result
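# Sanity check (an illustrative addition, not part of the original gist; the
# helper name and tolerances are assumptions): the unnormalized Walsh-Hadamard
# matrix H satisfies H @ H == n * I, so applying the transform twice and
# dividing by n should recover the input up to float error.
def _sanity_check_fwht(n=256, atol=1e-4):
    x = torch.randn(4, n)  # last dim must be a power of 2
    twice = fwht_optimized(fwht_optimized(x))
    assert torch.allclose(twice / n, x, atol=atol)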
class FWHTLayer(torch.nn.Module):
    """
    PyTorch layer wrapper for FWHT that can be used in neural networks.
    """
    def __init__(self, dim=-1, normalize=False):
        super().__init__()
        self.dim = dim
        self.normalize = normalize

    def forward(self, x):
        result = fwht_optimized(x, self.dim)
        if self.normalize:
            # Scale by 1/n (the unnormalized transform satisfies H @ H == n * I)
            result = result / x.shape[self.dim]
        return result
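# Example usage (a minimal sketch; the layer sizes and toy model below are
# assumptions for illustration, not part of the original gist):
if __name__ == "__main__":
    model = torch.nn.Sequential(
        torch.nn.Linear(128, 256),           # FWHT needs a power-of-2 feature dim
        FWHTLayer(dim=-1, normalize=True),   # mixes features via Hadamard transform
        torch.nn.Linear(256, 10),
    )
    out = model(torch.randn(32, 128))
    print(out.shape)  # torch.Size([32, 10])
    _sanity_check_fwht()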