Created
January 5, 2026 19:54
-
-
Save HDCharles/fcfbb2ebcd7ad867ccfff36d23e37bc4 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| @torch.no_grad | |
| @pytest.mark.unit | |
| @pytest.mark.parametrize( | |
| "n_balance_layers, group_size, n_input_features, strategy", | |
| [ | |
| (5, -1, 32, QuantizationStrategy.CHANNEL), # channel | |
| (4, 10, 40, QuantizationStrategy.GROUP), # group | |
| (4, torch.inf, 40, QuantizationStrategy.TENSOR), # tensor | |
| (3, 16, 64, QuantizationStrategy.TENSOR_GROUP), # tensor_group | |
| ], | |
| ) | |
| def test_compute_layer_means(n_balance_layers, group_size, n_input_features, strategy): | |
| """ | |
| Confirm our logic to compute duo_scaling layer means via a running tally | |
| matches the original memory-intensive AutoAWQ implementation, which concats | |
| all balance layers into a single tensor before reducing to mean | |
| Large models were prone to fail at this step. | |
| """ | |
| balance_layers = [ | |
| torch.nn.Linear(n_input_features, 10) for _ in range(n_balance_layers) | |
| ] | |
| group_size_arg = None | |
| match strategy: | |
| case QuantizationStrategy.CHANNEL: | |
| group_size = balance_layers[0].weight.shape[1] | |
| case QuantizationStrategy.TENSOR: | |
| group_size = n_input_features * 10 | |
| case _: | |
| group_size_arg = group_size | |
| for balance_layer in balance_layers: | |
| setattr( | |
| balance_layer, | |
| "quantization_scheme", | |
| QuantizationScheme( | |
| targets=["Linear"], | |
| weights=QuantizationArgs( | |
| strategy=strategy, | |
| group_size=group_size_arg, | |
| ), | |
| ), | |
| ) | |
| auto_awq_means = _auto_awq_normalize(balance_layers, group_size).mean(0) | |
| llmc_awq_means = AWQModifier._compute_layer_means(balance_layers).to( | |
| auto_awq_means.dtype | |
| ) | |
| assert_close(auto_awq_means, llmc_awq_means) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment