@shunting314
Created December 9, 2025 00:13
diff --git a/src/liger_kernel/ops/rms_norm.py b/src/liger_kernel/ops/rms_norm.py
index af867f4..5f7f2ac 100644
--- a/src/liger_kernel/ops/rms_norm.py
+++ b/src/liger_kernel/ops/rms_norm.py
@@ -450,6 +450,7 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
     elif X.device.type == "xpu":
         sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
+        sm_count = sm_count * 32
     # fp32 for numerical stability especially.
     _dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
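
The hunk only changes how many rows the `_dW` scratch buffer gets on XPU. A minimal pure-Python sketch of why that should not change the result, assuming (this is not shown in the hunk) that each launched program accumulates its share of rows into its own `_dW` row and the rows are summed afterwards. `partial_dw` and `num_programs` are hypothetical names used for illustration only, not part of liger_kernel:

```python
import math


def partial_dw(dY, X, rstd, num_programs):
    """Weight gradient of RMSNorm via per-program partial sums (illustrative only)."""
    n_rows, n_cols = len(X), len(X[0])
    rows_per_program = math.ceil(n_rows / num_programs)
    # One accumulator row per program, mirroring the (sm_count, n_cols) buffer.
    _dW = [[0.0] * n_cols for _ in range(num_programs)]
    for pid in range(num_programs):
        start = pid * rows_per_program
        end = min(start + rows_per_program, n_rows)
        # Programs whose start is past n_rows get an empty range and stay zero.
        for r in range(start, end):
            for c in range(n_cols):
                _dW[pid][c] += dY[r][c] * X[r][c] * rstd[r]
    # Final reduction across programs, one value per column.
    return [sum(_dW[p][c] for p in range(num_programs)) for c in range(n_cols)]
```

Under that assumption, a larger `sm_count` only splits the rows across more programs; the reduced gradient is the same, at the cost of a larger fp32 buffer.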