Created July 16, 2025 23:53
-
-
Save segyges/4d6b83fa31f07c41bc6a449131bfcacb to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import torch | |
@torch.compile
def fwht_optimized(x, dim=-1):
    """
    Unnormalized Fast Walsh-Hadamard Transform of ``x`` along ``dim``.

    Each of the log2(n) butterfly levels is expressed as a reshape of the
    last axis into ``(n // (2h), 2, h)`` (left/right halves of every block)
    followed by a single ``torch.stack``. This avoids building index tensors
    at every level. It also avoids the host round-trip (``.item()``) that a
    tensor-based log2 would force, which causes a graph break under
    ``torch.compile``.

    Args:
        x: Input tensor with at least one dimension.
        dim: Dimension to transform. Its size ``n`` must be a positive
            power of 2.

    Returns:
        A new tensor with the same shape as ``x``, containing the transform in
        natural (Sylvester-Hadamard) order with no scaling. Applying it twice
        therefore yields ``n * x``.

    Raises:
        ValueError: if the size of ``dim`` is not a positive power of 2.
    """
    if dim != -1:
        x = x.transpose(dim, -1)
    n = x.shape[-1]
    # n == 0 slips through the bit trick (0 & -1 == 0) but is not a power of
    # 2. Without this guard it used to crash later with OverflowError.
    if n == 0 or n & (n - 1) != 0:
        raise ValueError(f"Size must be power of 2, got {n}")

    lead = tuple(x.shape[:-1])
    result = x
    h = 1  # half-width of the butterfly at the current level
    while h < n:
        step = h << 1
        # pairs[..., k, 0, j] is element k*step + j (left);
        # pairs[..., k, 1, j] is element k*step + h + j (right).
        pairs = result.reshape(*lead, n // step, 2, h)
        left = pairs[..., 0, :]
        right = pairs[..., 1, :]
        result = torch.stack((left + right, left - right), dim=-2).reshape(*lead, n)
        h = step

    if n == 1:
        # No butterfly ran; return a copy so the output never aliases the input.
        result = x.clone()

    if dim != -1:
        result = result.transpose(dim, -1)
    return result
class FWHTLayer(torch.nn.Module):
    """
    ``nn.Module`` wrapper around ``fwht_optimized`` for use inside networks.

    Args:
        dim: Dimension along which the transform is applied (default: last).
        normalize: When True, divide the transform by ``n``, the size of
            ``dim``. Because the unnormalized Hadamard matrix satisfies
            ``H @ H == n * I``, this scaling makes the layer the inverse of
            the unnormalized transform. It is not the orthonormal
            ``1/sqrt(n)`` scaling. True division also promotes integer inputs
            to floating point.
    """

    def __init__(self, dim=-1, normalize=False):
        super().__init__()
        # Plain attributes only: the layer registers no parameters or buffers.
        self.dim, self.normalize = dim, normalize

    def forward(self, x):
        """Apply the FWHT to ``x`` along ``self.dim``, optionally scaled by 1/n."""
        transformed = fwht_optimized(x, self.dim)
        if not self.normalize:
            return transformed
        size = x.shape[self.dim]
        return transformed / size
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.