Sonic MoE Backwards Math

Writing out the math used in Sonic MoE (https://arxiv.org/abs/2512.14080), specifically Sec. 3.2 and Appendix C.

Setup

Goal: compute the backwards pass without needing to cache large tensors whose size would blow up with increased sparsification.

Tensors:

  • X_{ted}: input tensors
  • W^1_{edn}: up projection
  • W^2_{end}: down projection
  • H_{ten}: post-up-proj activation
  • A_{ten}: post-act-fn activation
  • S_{te}: router scores
  • Y_{ted}: per-expert outputs (pre-routing-weight)
  • O_{td}: outputs

Indices: e indexes the experts, d the hidden dim, n the intermediate dimension (with n << d for sparse MoEs!), and t the token index.

The limit of interest is the case where we scale up e >> 1 with e * n (the total inner dimension of the MoE weights) held fixed. Tensors carrying the full {t, e, d} index set, like Y_{ted}, are extremely expensive in this case.

Note: the above is a slight abuse of notation as the n-index of W^1 is twice the size of the n-index of W^2 for SwiGLU-like activation functions, but this will not make any qualitative difference, so we ignore it.

Forward Pass

Einstein summation notation is used throughout:

H_{ten} = X_{ted} W^1_{edn} # up-proj
A_{ten} = phi(H_{ten}) # generic activation fn phi
Y_{ted} = A_{ten} W^2_{end} # down-proj
O_{td} = S_{te} Y_{ted} # Expert-weighted outputs
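
A minimal PyTorch einsum sketch of the above (not from the paper; the sizes and the choice of phi = SiLU are illustrative assumptions, and the SwiGLU doubling noted earlier is ignored):

import torch

t, e, d, n = 8, 4, 32, 8  # hypothetical token / expert / hidden / intermediate sizes

X = torch.randn(t, e, d)   # X_{ted}
W1 = torch.randn(e, d, n)  # W^1_{edn}
W2 = torch.randn(e, n, d)  # W^2_{end}
S = torch.rand(t, e)       # S_{te} router scores

H = torch.einsum("ted,edn->ten", X, W1)  # up-proj
A = torch.nn.functional.silu(H)          # generic activation fn phi (SiLU here)
Y = torch.einsum("ten,end->ted", A, W2)  # down-proj
O = torch.einsum("te,ted->td", S, Y)     # expert-weighted outputs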

Backwards Pass

Paper notation: dX = dLoss / dX, for an arbitrary tensor X. Note: not a fan of this notation, since dX already means something in math, plus it makes the chain rule look weird: dLoss / dX = (dLoss / dY) (dY / dX) turns into dX = dY (dY / dX).

We need to compute dS_{te}, dW^2_{end}, and dH_{ten} in order to compute the gradients for all learnable parameters. Goals in doing this:

  • Avoid needing to cache Y_{ted} or X_{ted}, since their O(n_toks * n_experts * hidden_dim) elements are unacceptably large for sparse (n_experts >> 1) MoEs; see the rough element count after this list.
  • Avoid writing dY_{ted} to HBM, for similar reasons.
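
For a rough sense of the scale involved, here is a back-of-the-envelope element count with made-up sizes (purely illustrative, not numbers from the paper):

n_toks, n_experts, hidden_dim = 16_384, 64, 4_096  # hypothetical sizes
inner_dim = 4 * hidden_dim // n_experts            # e * n held fixed at 4 * d, say

ted_elems = n_toks * n_experts * hidden_dim  # e.g. Y_{ted}, dY_{ted}
ten_elems = n_toks * n_experts * inner_dim   # e.g. A_{ten}, dA'_{ten}

# Ratio is d / n = 16x for these sizes, and it grows linearly with n_experts
# when e * n is held fixed.
print(f"ted / ten element ratio: {ted_elems / ten_elems:.0f}x")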

Score Derivative dS_{te}

The naive derivative of the score is

dS_{te} = dO_{t'd} (dO_{t'd} / dS_{te})
        = dO_{td}  Y_{ted}

Computing as above would require caching Y_{ted} ~ O(n_toks * n_experts * hidden_dim) elements, which is unacceptably large for sparse (n_experts >> 1, hidden_dim fixed) MoE.

Instead use Y_{ted} = A_{ten} W^2_{end} and group the computation as

dS_{te} = dO_{td} A_{ten} W^2_{end}
        === dA'_{ten} A_{ten}

avoiding the need to cache the large Y_{ted}. Also, dA'_{ten} = dO_{td} W^2_{end} is much smaller since n << d for sparse MoEs, meaning we also create smaller tensors during this backward pass and the reduction dimension (n vs. d) is smaller.

Sneak preview: we will save dA'_{ten} and reuse this tensor in the dH_{ten} computation.
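
As a sketch in PyTorch einsum (reusing the hypothetical tensors from the forward-pass snippet; dO stands in for the upstream gradient of O_{td}):

dO = torch.randn(t, d)

# The naive form would contract against the large Y_{ted}:
#   dS = torch.einsum("td,ted->te", dO, Y)
# The regrouped form only touches (t, e, n)-sized tensors:
dA_prime = torch.einsum("td,end->ten", dO, W2)  # dA'_{ten} = dO_{td} W^2_{end}
dS = torch.einsum("ten,ten->te", dA_prime, A)   # dS_{te} = dA'_{ten} A_{ten}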

Down-proj Weight Derivative dW^2_{end}

Naive:

dW^2_{end} = dY_{te'd} (dY_{te'd} / dW^2_{end})
           = dY_{ted} A_{ten}

We want to eliminate the reliance on the large dY_{ted} tensor. To this end we use the chain-rule to write

dY_{ted} = dO_{t'd'} (dO_{t'd'} / dY_{ted})
         = dO_{td} S_{te}

making the above

dW^2_{end} = dY_{te'd} (dY_{te'd} / dW^2_{end})
           = dO_{td} S_{te} A_{ten}
           === dO_{td} A'_{ten}

So, we successfully traded the large dY_{te'd} for the smaller A'_{ten} = S_{te} A_{ten}.

In the next section, we'll show that the derivative of the loss with respect to A'_{ten}, i.e. dA'_{ten}, is directly proportional to dA_{ten}, which is presumably where all of the A' notation comes from.
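
A corresponding sketch for dW^2_{end}, again with the hypothetical tensors from the snippets above:

A_prime = S[..., None] * A                      # A'_{ten} = S_{te} A_{ten}
dW2 = torch.einsum("td,ten->end", dO, A_prime)  # dW^2_{end} = dO_{td} A'_{ten}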

Up-proj Activation Derivative dH_{ten}

Naive:

dH_{ten} = dA_{ten} (dA_{ten} / dH_{ten})
         = dA_{ten} phi'(H_{ten})
         = dY_{t'e'd} ( dY_{t'e'd} / dA_{ten}) phi'(H_{ten})
         = dY_{ted} W^2_{end} phi'(H_{ten})

where phi'(...) is the element-wise derivative of the activation function.

Computing as in the final line would be expensive (due to dY_{ted}), so we instead compute as in the second line dH_{ten} = dA_{ten} phi'(H_{ten}).

The second term is recomputed from the cached H_{ten}. A trick is needed to compute dA_{ten} without the use of dY_{t'e'd}: it can be computed simply from dA'_{ten} = dO_{td} W^2_{end}, which we saved from the dS_{te} computation above.

The trick is just repeated chain rule applications:

dA_{ten} = dO_{t'd} (dO_{t'd} / dA_{ten})
         = dO_{t'd} (dO_{t'd} / dY_{t''e'd}) (dY_{t''e'd} / dA_{ten})
         = dO_{t'd} S_{t'e'} (dY_{t'e'd} / dA_{ten})
         = dO_{td} S_{te} W^2_{end}
         = S_{te} dA'_{ten}

Or maybe more simply/generally: the loss is a function of O_{td} which is in turn a function of the combination S_{te}A_{ten} = A'_{ten} and other tensors:

Loss(O_{td}) = Loss(S_{te}A_{ten}, ...)
             = Loss(A'_{ten}, ...)

where the ellipses are independent of both S_{te} and A_{ten}. For a generic function of this form, dA_{ten} = S_{te} dA'_{ten} follows from the chain rule by taking a derivative with respect to A_{ten}.

In summary:

dH_{ten} = S_{te} dA'_{ten} phi'(H_{ten})
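
Putting the pieces together, a small sketch that recomputes phi'(H_{ten}) and checks the formula against autograd (illustrative only; phi = SiLU and the tensors are the hypothetical ones from the earlier snippets):

H = torch.einsum("ted,edn->ten", X, W1).requires_grad_()
A = torch.nn.functional.silu(H)
Y = torch.einsum("ten,end->ted", A, W2)
O = torch.einsum("te,ted->td", S, Y)
O.backward(dO)                                  # populates H.grad with the reference dH_{ten}

sig = torch.sigmoid(H.detach())
phi_prime = sig * (1 + H.detach() * (1 - sig))  # elementwise d silu(h) / dh
dA_prime = torch.einsum("td,end->ten", dO, W2)  # saved from the dS_{te} step
dH = S[..., None] * dA_prime * phi_prime        # dH_{ten} = S_{te} dA'_{ten} phi'(H_{ten})

torch.testing.assert_close(dH, H.grad, rtol=1e-4, atol=1e-4)
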
GarlGuo commented Dec 21, 2025

Thank you! I will also provide an example pytorch einsum code that illustrates this derivation.
