Writing out the math used in Sonic MoE: https://arxiv.org/abs/2512.14080. Sec. 3.2 and Appendix C.
Goal: compute the backward pass without needing to cache large tensors, which would blow up with increased sparsification.
Tensors:
- X_{ted}: input tensor
- W^1_{edn}: up projection
- W^2_{end}: down projection
- H_{ten}: post-up-proj activation
- A_{ten}: post-act-fn activation
- S_{te}: router scores
- O_{td}: outputs
Indices: e indexes the experts, d the hidden dim, n the intermediate dimension (with n << d for sparse MoEs!), and t the token index.
The limit of interest is the case where we scale up e >> 1 with e * n (the total inner dimension of the MoE weights) held fixed. Tensors with an index like Y_{ted} are extremely expensive in this case.
Note: the above is a slight abuse of notation as the n-index of W^1 is twice the size of the n-index of W^2 for SwiGLU-like activation functions, but this will not make any qualitative difference, so we ignore it.
Einstein summation notation is used throughout.
H_{ten} = X_{ted} W^1_{edn} # up-proj
A_{ten} = phi(H_{ten}) # generic activation fn phi
Y_{ted} = A_{ten} W^2_{end} # down-proj
O_{td} = S_{te} Y_{ted} # Expert-weighted outputs
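For concreteness, here is a minimal PyTorch einsum sketch of this forward pass, dense over experts with toy shapes. The shapes, variable names, and the SiLU stand-in for phi are illustrative assumptions, not details from the paper:

```python
import torch

t, e, d, n = 8, 4, 32, 8            # tokens, experts, hidden dim, intermediate dim (n << d)
phi = torch.nn.functional.silu      # stand-in for a generic element-wise activation

X = torch.randn(t, e, d)            # X_{ted}: per-expert inputs
W1 = torch.randn(e, d, n)           # W^1_{edn}: up projection
W2 = torch.randn(e, n, d)           # W^2_{end}: down projection
S = torch.rand(t, e)                # S_{te}: router scores

H = torch.einsum("ted,edn->ten", X, W1)   # up-proj
A = phi(H)                                # activation
Y = torch.einsum("ten,end->ted", A, W2)   # down-proj
O = torch.einsum("te,ted->td", S, Y)      # expert-weighted outputs
```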
Paper notation: dX = dLoss / dX, for arbitrary tensor X. Note: not a fan of this notation, since dX already means something in math, plus it makes the chain-rule look weird: dLoss / dX = (dLoss / dY) (dY / dX) turns into dX = dY (dY / dX).
We need to compute dS_{te}, dW^2_{end}, and dH_{ten} in order to compute the gradients for all learnable parameters. Goals in doing this:
- Avoid needing to cache Y_{ted} or X_{ted}, since their O(n_toks * n_experts * hidden_dim) elements are unacceptably large for sparse (n_experts >> 1) MoE.
- Avoid writing dY_{ted} to HBM, for similar reasons.
The naive derivative of the score is
dS_{te} = dO_{t'd} (dO_{t'd} / dS_{te})
= dO_{td} Y_{ted}
Computing as above would require caching Y_{ted} ~ O(n_toks * n_experts * hidden_dim) elements, which is unacceptably large for sparse (n_experts >> 1, hidden_dim fixed) MoE.
Instead use Y_{ted} = A_{ten} W^2_{end} and group the computation as
dS_{te} = dO_{td} A_{ten} W^2_{end}
=== dA'_{ten} A_{ten}
avoiding the need to cache the large Y_{ted}. Also, dA'_{ten} = dO_{td} W^2_{end} is much smaller since n << d for sparse MoEs, meaning we also create smaller tensors during this backward pass and the reduction dimension (n vs. d) is smaller.
Sneak preview: we will save dA'_{ten} and reuse this tensor in the dH_{ten} computation.
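Continuing the toy sketch above, with dO as an assumed name for the upstream gradient dLoss/dO_{td}, the regrouped dS_{te} computation looks like:

```python
dO = torch.randn(t, d)  # upstream gradient dLoss/dO_{td}

# Contract dO with W^2 first: the result only has the small n-dimension.
dA_prime = torch.einsum("td,end->ten", dO, W2)   # dA'_{ten} = dO_{td} W^2_{end}
dS = torch.einsum("ten,ten->te", dA_prime, A)    # dS_{te}   = dA'_{ten} A_{ten}

# Naive version, for comparison only -- needs the large Y_{ted}:
# dS_naive = torch.einsum("td,ted->te", dO, Y)
```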
Now for dW^2_{end}. Naive:
dW^2_{end} = dY_{te'd'} (dY_{te'd'} / dW^2_{end})
= dY_{ted} A_{ten}
We want to eliminate the reliance on the large dY_{ted} tensor. To this end we use the chain-rule to write
dY_{ted} = dO_{t'd'} (dO_{t'd'} / dY_{ted})
= dO_{td} S_{te}
making the above
dW^2_{end} = dY_{te'd'} (dY_{te'd'} / dW^2_{end})
= dY_{ted} A_{ten}
= dO_{td} S_{te} A_{ten}
=== dO_{td} A'_{ten}
So, we successfully traded the large dY_{ted} for the smaller A'_{ten} = S_{te} A_{ten}.
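In the same toy sketch, the regrouped dW^2_{end} computation (A_prime is an assumed name for A'_{ten}) looks like:

```python
A_prime = torch.einsum("te,ten->ten", S, A)      # A'_{ten} = S_{te} A_{ten} (element-wise in e)
dW2 = torch.einsum("td,ten->end", dO, A_prime)   # dW^2_{end} = dO_{td} A'_{ten}, reducing over t only

# Naive version, for comparison only -- materializes the large dY_{ted}:
# dY = torch.einsum("td,te->ted", dO, S)
# dW2_naive = torch.einsum("ted,ten->end", dY, A)
```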
In the next section, we'll show that the derivative of the loss with respect to A'_{ten}, i.e. dA'_{ten}, is directly proportional to dA_{ten}, which is presumably where all of the A' notation comes from.
Finally, dH_{ten}. Naive:
dH_{ten} = dA_{ten} (dA_{ten} / dH_{ten})
= dA_{ten} phi'(H_{ten})
= dY_{t'e'd} ( dY_{t'e'd} / dA_{ten}) phi'(H_{ten})
= dY_{ted} W^2_{end} phi'(H_{ten})
where phi'(...) is the element-wise derivative of the activation function.
Computing as in the final line would be expensive (due to dY_{ted}), so we instead compute as in the second line, dH_{ten} = dA_{ten} phi'(H_{ten}).
The second factor, phi'(H_{ten}), is recomputed from the cached H_{ten}.
A trick is needed to compute dA_{ten} without the use of dY_{t'e'd}: it can be computed simply from dA'_{ten} = dO_{td} W^2_{end}, which we saved from the dS_{te} computation above.
The trick is just repeated chain rule applications:
dA_{ten} = dO_{t'd} (dO_{t'd} / dA_{ten})
= dO_{t'd} (dO_{t'd} / dY_{t''e'd}) (dY_{t''e'd} / dA_{ten})
= dO_{t'd} S_{t'e'} (dY_{t'e'd} / dA_{ten})
= dO_{td} S_{te} W^2_{end}
= S_{te} dA'_{ten}
Or maybe more simply/generally: the loss is a function of O_{td} which is in turn a function of the
combination S_{te}A_{ten} = A'_{ten} and other tensors:
Loss(O_{td}) = Loss(S_{te}A_{ten}, ...)
= Loss(A'_{ten}, ...)
where the ellipses are independent of both S_{te} and A_{ten}. For a generic function of this
form, dA_{ten} = S_{te} dA'_{ten} follows from the chain rule by taking a derivative with respect
to A_{ten}.
In summary:
dH_{ten} = S_{te} dA'_{ten} phi'(H_{ten})
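In the toy sketch, with SiLU standing in for phi (so phi'(H) = sigmoid(H) * (1 + H * (1 - sigmoid(H)))), this becomes:

```python
sig = torch.sigmoid(H)                           # recomputed from the cached H_{ten}
phi_prime = sig * (1 + H * (1 - sig))            # element-wise derivative of SiLU
dA = torch.einsum("te,ten->ten", S, dA_prime)    # dA_{ten} = S_{te} dA'_{ten}
dH = dA * phi_prime                              # dH_{ten}
```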
Finally, here is example PyTorch einsum code that illustrates the derivation.
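The following is a minimal sketch, not the paper's implementation: it uses dense per-expert tensors with toy shapes, SiLU as a stand-in for phi, and float64 so the comparisons are tight. It checks the grouped formulas for dS_{te}, dW^2_{end}, and dH_{ten} against autograd; all variable names and shapes are illustrative assumptions.

```python
import torch

torch.manual_seed(0)
torch.set_default_dtype(torch.float64)  # tight numerical comparisons

t, e, d, n = 8, 4, 32, 8                # tokens, experts, hidden dim, intermediate dim (n << d)
phi = torch.nn.functional.silu          # stand-in for a generic element-wise activation

X = torch.randn(t, e, d)                        # X_{ted}
W1 = torch.randn(e, d, n, requires_grad=True)   # W^1_{edn}
W2 = torch.randn(e, n, d, requires_grad=True)   # W^2_{end}
S = torch.rand(t, e, requires_grad=True)        # S_{te}

# Forward pass; autograd builds the reference graph.
H = torch.einsum("ted,edn->ten", X, W1)   # H_{ten}
H.retain_grad()                           # keep dH for comparison
A = phi(H)                                # A_{ten}
Y = torch.einsum("ten,end->ted", A, W2)   # Y_{ted} (reference path only)
O = torch.einsum("te,ted->td", S, Y)      # O_{td}

# Arbitrary upstream gradient dO_{td}; the loss is chosen so dLoss/dO = dO.
dO = torch.randn(t, d)
(O * dO).sum().backward()

# Manual backward: uses only dO, S, W2, and A / phi'(H) recomputed from the cached H.
# Never materializes Y_{ted} or dY_{ted}.
with torch.no_grad():
    dA_prime = torch.einsum("td,end->ten", dO, W2)   # dA'_{ten} = dO_{td} W^2_{end}
    dS = torch.einsum("ten,ten->te", dA_prime, A)    # dS_{te}   = dA'_{ten} A_{ten}
    A_prime = torch.einsum("te,ten->ten", S, A)      # A'_{ten}  = S_{te} A_{ten}
    dW2 = torch.einsum("td,ten->end", dO, A_prime)   # dW^2_{end} = dO_{td} A'_{ten}
    sig = torch.sigmoid(H)
    phi_prime = sig * (1 + H * (1 - sig))            # d/dH [H * sigmoid(H)]
    dH = torch.einsum("te,ten->ten", S, dA_prime) * phi_prime  # dH_{ten}

torch.testing.assert_close(dS, S.grad)
torch.testing.assert_close(dW2, W2.grad)
torch.testing.assert_close(dH, H.grad)
print("manual dS, dW2, dH match autograd")
```

From dH_{ten}, the remaining gradient dW^1_{edn} = X_{ted} dH_{ten} follows as a standard matmul backward, matching the statement above that dS_{te}, dW^2_{end}, and dH_{ten} are all we need for the learnable parameters.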