compare_weights
from __future__ import annotations

from math import ceil, floor

import torch
from torch import Tensor


def compare_weights(tensor1, tensor2, epsilon=None):
    """
    Compare two tensors and return detailed statistics about their differences.

    Returns a dict with the ``torch.allclose`` verdict, mean/max/percentile
    absolute and relative differences, and the fraction of near-exact matches.
    """
    if tensor1.shape != tensor2.shape:
        raise ValueError(f"Shape mismatch: {tensor1.shape} vs {tensor2.shape}")

    # Default absolute tolerance chosen from the dtype's precision.
    if epsilon is None:
        dtype_epsilon = {
            torch.float16: 1e-3,  # fp16 machine epsilon ~9.77e-4
            torch.bfloat16: 1e-2,  # bfloat16 has even less mantissa precision
            torch.float32: 1e-6,  # fp32 machine epsilon ~1.19e-7
            torch.float64: 1e-14,  # fp64 machine epsilon ~2.22e-16
        }
        epsilon = dtype_epsilon.get(tensor1.dtype, 1e-6)
    rtol = 1e-3 if tensor1.dtype in [torch.float16, torch.bfloat16] else 1e-5

    # Quick allclose check
    all_close = torch.allclose(tensor1, tensor2, rtol=rtol, atol=epsilon)

    # Elementwise absolute differences
    abs_diff = torch.abs(tensor1 - tensor2)

    # Relative difference with safe division (the epsilon floor avoids 0/0)
    denominator = torch.maximum(
        torch.abs(tensor1) + torch.abs(tensor2), torch.full_like(tensor1, epsilon)
    )
    rel_diff = abs_diff / denominator

    total_params = tensor1.numel()
    # "Exact" here means within epsilon, not bitwise identical.
    num_exact = (abs_diff < epsilon).sum().item()
    mean_abs = abs_diff.mean().item()
    mean_rel = rel_diff.mean().item()
    max_abs = abs_diff.max().item()
    max_rel = rel_diff.max().item()

    abs_diff_flat = abs_diff.view(-1)
    rel_diff_flat = rel_diff.view(-1)
    p99_abs = torch_quantile(abs_diff_flat.float(), 0.99)
    p999_abs = torch_quantile(abs_diff_flat.float(), 0.999)
    p99_rel = torch_quantile(rel_diff_flat.float(), 0.99)

    stats = {
        "all_close": all_close,
        "p99_abs_diff": float(p99_abs.item()),
        "p999_abs_diff": float(p999_abs.item()),
        "p99_rel_diff": float(p99_rel.item()),
        "mean_abs_diff": mean_abs,
        "mean_rel_diff": mean_rel,
        "max_abs_diff": max_abs,
        "max_rel_diff": max_rel,
        "num_exact_matches": num_exact,
        "total_params": total_params,
        "exact_match_pct": 100 * num_exact / total_params,
        "dtype": str(tensor1.dtype),
        "epsilon": epsilon,
        "all_close_rtol": rtol,
        "all_close_atol": epsilon,
    }

    # # Print summary
    # print(f"\nWeight Comparison Summary ({total_params:,} parameters):")
    # print(f"  All close (torch.allclose): {stats['all_close']}")
    # print(f"  99% of weights differ by < {stats['p99_abs_diff']:.6f} (absolute)")
    # print(f"  99.9% of weights differ by < {stats['p999_abs_diff']:.6f} (absolute)")
    # print(f"  99% of weights differ by < {stats['p99_rel_diff']:.4%} (relative)")
    # print(f"  Mean absolute difference: {stats['mean_abs_diff']:.6f}")
    # print(f"  Maximum absolute difference: {stats['max_abs_diff']:.6f}")
    # print(f"  Maximum relative difference: {stats['max_rel_diff']:.4%}")
    # print(f"  Exact matches: {stats['exact_match_pct']:.2f}%")

    return stats
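
# Illustrative usage (a sketch, not part of the original gist; the tensor names
# and shapes are made up). Note that `compare_weights` calls `torch_quantile`,
# which is defined below, so run this only after the whole file has loaded:
#
#     w1 = torch.randn(1024, 1024)
#     w2 = w1 + 1e-7 * torch.randn_like(w1)  # simulate tiny checkpoint drift
#     stats = compare_weights(w1, w2)
#     print(stats["all_close"], stats["max_abs_diff"], stats["exact_match_pct"])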

def torch_quantile(  # noqa: PLR0913 (too many arguments)
    tensor: Tensor,
    q: float | Tensor,
    dim: int | None = None,
    *,
    keepdim: bool = False,
    interpolation: str = "linear",
    out: Tensor | None = None,
) -> Tensor:
    r"""Improved ``torch.quantile`` for one scalar quantile.

    Arguments
    ---------
    tensor: ``Tensor``
        See ``torch.quantile``.
    q: ``float``
        See ``torch.quantile``. Supports only scalar values currently.
    dim: ``int``, optional
        See ``torch.quantile``.
    keepdim: ``bool``
        See ``torch.quantile``. Supports only ``False`` currently.
        Defaults to ``False``.
    interpolation: ``{"linear", "lower", "higher", "midpoint", "nearest"}``
        See ``torch.quantile``. Defaults to ``"linear"``.
    out: ``Tensor``, optional
        See ``torch.quantile``. Currently not supported.

    Notes
    -----
    Uses ``torch.kthvalue``. Better than ``torch.quantile`` since:

    #. it has no :math:`2^{24}` tensor `size limit <https://github.com/pytorch/pytorch/issues/64947#issuecomment-2304371451>`_;
    #. it is much faster, at least on big tensor sizes.
    """
    # Sanitization of: q
    q_float = float(q)  # May raise an (unpredictable) error
    if not 0 <= q_float <= 1:
        msg = f"Only values 0<=q<=1 are supported (got {q_float!r})"
        raise ValueError(msg)

    # Sanitization of: dim
    # Because one cannot pass `dim=None` to `squeeze()` or `kthvalue()`
    if dim_was_none := dim is None:
        dim = 0
        tensor = tensor.reshape((-1, *(1,) * (tensor.ndim - 1)))

    # Sanitization of: interpolation
    idx_float = q_float * (tensor.shape[dim] - 1)
    if interpolation == "nearest":
        idxs = [round(idx_float)]
    elif interpolation == "lower":
        idxs = [floor(idx_float)]
    elif interpolation == "higher":
        idxs = [ceil(idx_float)]
    elif interpolation in {"linear", "midpoint"}:
        low = floor(idx_float)
        idxs = [low] if idx_float == low else [low, low + 1]
        weight = idx_float - low if interpolation == "linear" else 0.5
    else:
        msg = (
            "Currently supported interpolations are {'linear', 'lower', 'higher', "
            f"'midpoint', 'nearest'}} (got {interpolation!r})"
        )
        raise ValueError(msg)

    # Sanitization of: out
    if out is not None:
        msg = f"Only None value is currently supported for out (got {out!r})"
        raise ValueError(msg)

    # Logic: kthvalue is 1-indexed; when the quantile falls between two order
    # statistics, lerp between them with the weight computed above.
    outs = [torch.kthvalue(tensor, idx + 1, dim, keepdim=True)[0] for idx in idxs]
    out = outs[0] if len(outs) == 1 else outs[0].lerp(outs[1], weight)

    # Rectification of: keepdim
    if keepdim:
        return out
    return out.squeeze() if dim_was_none else out.squeeze(dim)
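
# Quick self-check (a sketch, not part of the original gist): on sizes where
# `torch.quantile` still works, `torch_quantile` should agree with it for
# every supported interpolation mode.
if __name__ == "__main__":
    x = torch.randn(10_000)
    for interp in ("linear", "lower", "higher", "midpoint", "nearest"):
        mine = torch_quantile(x, 0.99, interpolation=interp)
        ref = torch.quantile(x, 0.99, interpolation=interp)
        assert torch.allclose(mine, ref), (interp, mine, ref)

    # Unlike `torch.quantile`, this also handles tensors past the
    # 2**24-element limit referenced in the docstring above:
    big = torch.randn(2**24 + 1)
    print(torch_quantile(big, 0.999))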