@baberabb
Created October 29, 2025 18:37
compare_weights
"""Utilities for comparing two weight tensors: summary statistics of their
absolute and relative differences, with tail quantiles computed via a
``torch.kthvalue``-based quantile that scales past ``torch.quantile``'s limits."""

from __future__ import annotations

from math import ceil, floor

import torch
from torch import Tensor


def compare_weights(tensor1: Tensor, tensor2: Tensor, epsilon: float | None = None) -> dict:
    """
    Compare two tensors and return detailed statistics about their differences.
    """
    if tensor1.shape != tensor2.shape:
        raise ValueError(f"Shape mismatch: {tensor1.shape} vs {tensor2.shape}")

    # Pick an absolute tolerance matched to the dtype's precision.
    if epsilon is None:
        dtype_epsilon = {
            torch.float16: 1e-3,   # fp16 machine epsilon ~9.77e-4
            torch.bfloat16: 1e-2,  # bfloat16 has even less mantissa precision
            torch.float32: 1e-6,   # fp32 machine epsilon ~1.19e-7
            torch.float64: 1e-14,  # fp64 machine epsilon ~2.22e-16
        }
        epsilon = dtype_epsilon.get(tensor1.dtype, 1e-6)
    rtol = 1e-3 if tensor1.dtype in (torch.float16, torch.bfloat16) else 1e-5

    # Quick allclose check
    all_close = torch.allclose(tensor1, tensor2, rtol=rtol, atol=epsilon)

    # Elementwise absolute differences
    abs_diff = torch.abs(tensor1 - tensor2)

    # Relative difference with safe division: the denominator is floored at
    # epsilon so elements where both tensors are ~0 do not blow up.
    denominator = torch.maximum(
        torch.abs(tensor1) + torch.abs(tensor2), torch.full_like(tensor1, epsilon)
    )
    rel_diff = abs_diff / denominator

    total_params = tensor1.numel()
    num_exact = (abs_diff < epsilon).sum().item()
    mean_abs = abs_diff.mean().item()
    mean_rel = rel_diff.mean().item()
    max_abs = abs_diff.max().item()
    max_rel = rel_diff.max().item()

    # Tail quantiles via torch_quantile (below), which avoids torch.quantile's
    # 2**24-element size limit on large weight tensors.
    abs_diff_flat = abs_diff.view(-1)
    rel_diff_flat = rel_diff.view(-1)
    p99_abs = torch_quantile(abs_diff_flat.float(), 0.99)
    p999_abs = torch_quantile(abs_diff_flat.float(), 0.999)
    p99_rel = torch_quantile(rel_diff_flat.float(), 0.99)

    stats = {
        "all_close": all_close,
        "p99_abs_diff": float(p99_abs.item()),
        "p999_abs_diff": float(p999_abs.item()),
        "p99_rel_diff": float(p99_rel.item()),
        "mean_abs_diff": mean_abs,
        "mean_rel_diff": mean_rel,
        "max_abs_diff": max_abs,
        "max_rel_diff": max_rel,
        "num_exact_matches": num_exact,
        "total_params": total_params,
        "exact_match_pct": 100 * num_exact / total_params,
        "dtype": str(tensor1.dtype),
        "epsilon": epsilon,
        "all_close_rtol": rtol,
        "all_close_atol": epsilon,
    }

    # Print summary
    # print(f"\nWeight Comparison Summary ({total_params:,} parameters):")
    # print(f"  All close (torch.allclose): {stats['all_close']}")
    # print(f"  99% of weights differ by < {stats['p99_abs_diff']:.6f} (absolute)")
    # print(f"  99.9% of weights differ by < {stats['p999_abs_diff']:.6f} (absolute)")
    # print(f"  99% of weights differ by < {stats['p99_rel_diff']:.4%} (relative)")
    # print(f"  Mean absolute difference: {stats['mean_abs_diff']:.6f}")
    # print(f"  Maximum absolute difference: {stats['max_abs_diff']:.6f}")
    # print(f"  Maximum relative difference: {stats['max_rel_diff']:.4%}")
    # print(f"  Exact matches: {stats['exact_match_pct']:.2f}%")
    return stats
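
# A minimal usage sketch (hypothetical shapes and values, not part of the
# original gist). compare_weights calls torch_quantile, defined below, so run
# this only after the whole file has been imported/executed:
#
#     w_ref = torch.randn(1024, 1024)
#     w_new = w_ref.half().float()  # simulate an fp16 round-trip of the weights
#     stats = compare_weights(w_ref, w_new)
#     print(stats["all_close"], stats["p99_abs_diff"], stats["exact_match_pct"])
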
def torch_quantile(  # noqa: PLR0913 (too many arguments)
    tensor: Tensor,
    q: float | Tensor,
    dim: int | None = None,
    *,
    keepdim: bool = False,
    interpolation: str = "linear",
    out: Tensor | None = None,
) -> Tensor:
    r"""Improved ``torch.quantile`` for one scalar quantile.

    Arguments
    ---------
    tensor: ``Tensor``
        See ``torch.quantile``.
    q: ``float``
        See ``torch.quantile``. Supports only scalar values currently.
    dim: ``int``, optional
        See ``torch.quantile``.
    keepdim: ``bool``
        See ``torch.quantile``. Supports only ``False`` currently.
        Defaults to ``False``.
    interpolation: ``{"linear", "lower", "higher", "midpoint", "nearest"}``
        See ``torch.quantile``. Defaults to ``"linear"``.
    out: ``Tensor``, optional
        See ``torch.quantile``. Currently not supported.

    Notes
    -----
    Uses ``torch.kthvalue``. Better than ``torch.quantile`` since:

    #. it has no :math:`2^{24}` tensor `size limit
       <https://github.com/pytorch/pytorch/issues/64947#issuecomment-2304371451>`_;
    #. it is much faster, at least on big tensors.
    """
    # Sanitization of: q
    q_float = float(q)  # May raise an (unpredictable) error
    if not 0 <= q_float <= 1:
        msg = f"Only values 0<=q<=1 are supported (got {q_float!r})"
        raise ValueError(msg)

    # Sanitization of: dim
    # Because one cannot pass `dim=None` to `squeeze()` or `kthvalue()`
    if dim_was_none := dim is None:
        dim = 0
        tensor = tensor.reshape((-1, *(1,) * (tensor.ndim - 1)))

    # Sanitization of: interpolation
    idx_float = q_float * (tensor.shape[dim] - 1)
    if interpolation == "nearest":
        idxs = [round(idx_float)]
    elif interpolation == "lower":
        idxs = [floor(idx_float)]
    elif interpolation == "higher":
        idxs = [ceil(idx_float)]
    elif interpolation in {"linear", "midpoint"}:
        low = floor(idx_float)
        idxs = [low] if idx_float == low else [low, low + 1]
        weight = idx_float - low if interpolation == "linear" else 0.5
    else:
        msg = (
            "Currently supported interpolations are {'linear', 'lower', 'higher', "
            f"'midpoint', 'nearest'}} (got {interpolation!r})"
        )
        raise ValueError(msg)

    # Sanitization of: out
    if out is not None:
        msg = f"Only None value is currently supported for out (got {out!r})"
        raise ValueError(msg)

    # Logic
    outs = [torch.kthvalue(tensor, idx + 1, dim, keepdim=True)[0] for idx in idxs]
    out = outs[0] if len(outs) == 1 else outs[0].lerp(outs[1], weight)

    # Rectification of: keepdim
    if keepdim:
        return out
    return out.squeeze() if dim_was_none else out.squeeze(dim)
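
A self-contained smoke test, added here as an illustration rather than part of the original gist: it cross-checks torch_quantile against torch.quantile on a tensor small enough for both to handle, then runs compare_weights on a weight matrix and a slightly perturbed copy. Shapes, tolerances, and the perturbation scale are arbitrary choices.

if __name__ == "__main__":
    # Cross-check torch_quantile against torch.quantile (both default to
    # "linear" interpolation) on a tensor well under the 2**24 limit.
    x = torch.randn(10_000)
    for q in (0.0, 0.5, 0.99, 0.999, 1.0):
        mine = torch_quantile(x, q)
        ref = torch.quantile(x, q)
        assert torch.isclose(mine, ref, rtol=1e-4, atol=1e-6), (q, mine, ref)

    # Compare a weight matrix against a slightly perturbed copy of itself.
    w_ref = torch.randn(512, 512)
    w_new = w_ref + 1e-7 * torch.randn_like(w_ref)
    stats = compare_weights(w_ref, w_new)
    print(
        f"all_close={stats['all_close']}  "
        f"p99_abs={stats['p99_abs_diff']:.3e}  "
        f"exact={stats['exact_match_pct']:.2f}%"
    )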