Skip to content

Instantly share code, notes, and snippets.

@prabhatkgupta
Created February 14, 2026 23:33
Show Gist options
  • Select an option

  • Save prabhatkgupta/b873ec726cdae8ba5850917f998b0968 to your computer and use it in GitHub Desktop.

Select an option

Save prabhatkgupta/b873ec726cdae8ba5850917f998b0968 to your computer and use it in GitHub Desktop.
Mixture of Experts Routing Logic
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 106,
"id": "0e2151dc-de2c-4af0-aae8-84d9c4948ec6",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import math\n",
"import torch.nn.functional as F"
]
},
{
"cell_type": "markdown",
"id": "2fe3a8aa-1b02-48fc-b442-b1bd1b20550d",
"metadata": {},
"source": [
"# Define Parameters"
]
},
{
"cell_type": "code",
"execution_count": 117,
"id": "2622c5b1-6b59-4195-864c-dc5fd12cbb2a",
"metadata": {},
"outputs": [],
"source": [
"batch_size = 1\n",
"sequence_length = 4\n",
"total_tokens = batch_size * sequence_length\n",
"num_experts = 5\n",
"top_k = 3\n",
"embedding_dim = 10\n",
"capacity_factor = 1.1"
]
},
{
"cell_type": "markdown",
"id": "b2c32914-871f-40b4-9860-c9516b87829d",
"metadata": {},
"source": [
"# Input Tensor"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "a5d0f5ed-eee3-4deb-a3a0-ed1f03563a01",
"metadata": {},
"outputs": [],
"source": [
"tokens = torch.rand(batch_size, sequence_length, embedding_dim)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "5a9546af-c7ee-4f2d-924b-54deb96740e0",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([1, 4, 10]), torch.float32, device(type='cpu'))"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tokens.shape, tokens.dtype, tokens.device"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "b1f43300-ae31-45f3-b10f-30144b09dd49",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[0.4633, 0.4645, 0.1478, 0.0731, 0.8550, 0.0257, 0.9071, 0.7326,\n",
" 0.7471, 0.7278],\n",
" [0.0837, 0.5403, 0.0357, 0.4765, 0.1963, 0.5602, 0.5469, 0.0830,\n",
" 0.1511, 0.8322],\n",
" [0.3197, 0.2551, 0.1655, 0.4526, 0.5493, 0.9306, 0.3617, 0.9077,\n",
" 0.3160, 0.4185],\n",
" [0.4556, 0.4495, 0.4330, 0.9596, 0.6628, 0.0412, 0.4735, 0.0142,\n",
" 0.3043, 0.0344]]])"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tokens"
]
},
{
"cell_type": "markdown",
"id": "51b0926e-dc3a-44c1-bd58-1be435fda375",
"metadata": {},
"source": [
"# Routing Logic"
]
},
{
"cell_type": "code",
"execution_count": 41,
"id": "5a01fe23-d10a-4209-8e3d-89cef120aa8f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([1, 4, 5]), torch.float32, device(type='cpu'))"
]
},
"execution_count": 41,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"token_expert = torch.rand(1, sequence_length, num_experts)\n",
"token_expert.shape, token_expert.dtype, token_expert.device"
]
},
{
"cell_type": "code",
"execution_count": 42,
"id": "cb5e6d17-87d2-4e4c-ab22-19802e51e3a1",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[0.3223, 0.0997, 0.0972, 0.8917, 0.1840],\n",
" [0.7678, 0.5863, 0.9878, 0.0187, 0.3178],\n",
" [0.6149, 0.3391, 0.2534, 0.8794, 0.4257],\n",
" [0.6254, 0.2395, 0.1902, 0.8509, 0.7391]]])"
]
},
"execution_count": 42,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"token_expert"
]
},
{
"cell_type": "code",
"execution_count": 43,
"id": "0375b903-3391-4f44-9e84-bc1259da056f",
"metadata": {},
"outputs": [],
"source": [
"# router_weights = torch.rand(embedding_dim, num_experts)\n",
"# router_weights.shape, router_weights.dtype, router_weights.device\n",
"\n",
"router_weights = nn.Linear(in_features=embedding_dim, out_features=num_experts)\n",
"router_weights.weight = nn.Parameter(torch.rand(embedding_dim, num_experts).T)\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 44,
"id": "e7b79acc-aa16-4ad6-b2c4-6160542e6718",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([1, 4, 5]), torch.float32, device(type='cpu'))"
]
},
"execution_count": 44,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"token_expert = router_weights(tokens)\n",
"token_expert.shape, token_expert.dtype, token_expert.device"
]
},
{
"cell_type": "code",
"execution_count": 45,
"id": "0936b8b0-5f09-4cb6-ac4f-49fb3c691b80",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[1.8489, 3.0403, 3.1197, 2.1794, 2.7089],\n",
" [1.2508, 2.3485, 2.1812, 1.6347, 1.4857],\n",
" [1.6586, 2.8409, 2.7868, 2.7355, 2.3498],\n",
" [2.2430, 2.1960, 2.0732, 2.2447, 2.5344]]], grad_fn=<ViewBackward0>)"
]
},
"execution_count": 45,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"token_expert"
]
},
{
"cell_type": "code",
"execution_count": 46,
"id": "c23950d3-d4f4-4012-8a22-41c4da4b5376",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 46,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"token_expert.shape == (batch_size, sequence_length, num_experts)"
]
},
{
"cell_type": "markdown",
"id": "cf9a1436-1ae1-45dd-817d-fbd6085abf58",
"metadata": {},
"source": [
"# Select Top-K Experts per Token"
]
},
{
"cell_type": "code",
"execution_count": 47,
"id": "64d817d4-739c-4f97-a6c1-b6eb5b8beb8a",
"metadata": {},
"outputs": [],
"source": [
"# Top-k selection provides sparsity in calculation\n",
"\n",
"# Total number of parameter is increased, with experts increased diversity but computationally in-expensive because of selection of experts"
]
},
{
"cell_type": "code",
"execution_count": 48,
"id": "b47a359a-58b7-4b8e-ad3a-c1a44b7aa17a",
"metadata": {},
"outputs": [],
"source": [
"top_k_scores, top_k_mapping = token_expert.topk(k=top_k, dim=2)"
]
},
{
"cell_type": "code",
"execution_count": 49,
"id": "52fb6c80-de0e-4319-949a-7242bfe12813",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[3.1197, 3.0403, 2.7089],\n",
" [2.3485, 2.1812, 1.6347],\n",
" [2.8409, 2.7868, 2.7355],\n",
" [2.5344, 2.2447, 2.2430]]], grad_fn=<TopkBackward0>)"
]
},
"execution_count": 49,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"top_k_scores"
]
},
{
"cell_type": "code",
"execution_count": 50,
"id": "fbddf6bf-175c-4838-b002-9e9b67b60053",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[2, 1, 4],\n",
" [1, 2, 3],\n",
" [1, 2, 3],\n",
" [4, 3, 0]]])"
]
},
"execution_count": 50,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"top_k_mapping"
]
},
{
"cell_type": "code",
"execution_count": 60,
"id": "efb76a50-9b3c-4d21-856e-d948360f291d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[ -inf, 3.0403, 3.1197, -inf, 2.7089],\n",
" [ -inf, 2.3485, 2.1812, 1.6347, -inf],\n",
" [ -inf, 2.8409, 2.7868, 2.7355, -inf],\n",
" [2.2430, -inf, -inf, 2.2447, 2.5344]]],\n",
" grad_fn=<ScatterBackward0>)"
]
},
"execution_count": 60,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"token_expert_logits = torch.full_like(token_expert, float(\"-inf\"))\n",
"token_expert_logits.scatter_(index=top_k_mapping, src=top_k_scores, dim=2)"
]
},
{
"cell_type": "code",
"execution_count": 61,
"id": "0b5b1ae4-ffea-41bd-84fe-06236b2f9434",
"metadata": {},
"outputs": [],
"source": [
"top_k_scores_norm = F.softmax(token_expert_logits, dim=2)"
]
},
{
"cell_type": "code",
"execution_count": 62,
"id": "d0b4a844-9674-4462-91c3-7329fdbd0a1d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[0.0000, 0.3571, 0.3866, 0.0000, 0.2564],\n",
" [0.0000, 0.4281, 0.3622, 0.2097, 0.0000],\n",
" [0.0000, 0.3512, 0.3327, 0.3161, 0.0000],\n",
" [0.2994, 0.0000, 0.0000, 0.2999, 0.4007]]],\n",
" grad_fn=<SoftmaxBackward0>)"
]
},
"execution_count": 62,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"top_k_scores_norm"
]
},
{
"cell_type": "code",
"execution_count": 63,
"id": "485557aa-9f9b-40b8-a8c2-c0afcd32fdb5",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[0.3866, 0.3571, 0.2564],\n",
" [0.4281, 0.3622, 0.2097],\n",
" [0.3512, 0.3327, 0.3161],\n",
" [0.4007, 0.2999, 0.2994]]], grad_fn=<SoftmaxBackward0>)"
]
},
"execution_count": 63,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"F.softmax(top_k_scores, dim=2)"
]
},
{
"cell_type": "markdown",
"id": "2e5ade6f-2103-4a46-ace7-dd1e14e6704e",
"metadata": {},
"source": [
"# Convert into Priority wise topK"
]
},
{
"cell_type": "code",
"execution_count": 124,
"id": "29eca966-86bc-4934-aa15-418638e0a54f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"True\n"
]
}
],
"source": [
"token_expert_selection = F.one_hot(top_k_mapping, num_classes=num_experts)\n",
"print(token_expert_selection.shape == (batch_size, sequence_length, top_k, num_experts))\n",
"token_expert_selection = token_expert_selection.permute(0, 2, 1, 3)"
]
},
{
"cell_type": "code",
"execution_count": 125,
"id": "1b053edd-02f1-4936-817f-83837a39590d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[[0, 0, 1, 0, 0],\n",
" [0, 1, 0, 0, 0],\n",
" [0, 1, 0, 0, 0],\n",
" [0, 0, 0, 0, 1]],\n",
"\n",
" [[0, 1, 0, 0, 0],\n",
" [0, 0, 1, 0, 0],\n",
" [0, 0, 1, 0, 0],\n",
" [0, 0, 0, 1, 0]],\n",
"\n",
" [[0, 0, 0, 0, 1],\n",
" [0, 0, 0, 1, 0],\n",
" [0, 0, 0, 1, 0],\n",
" [1, 0, 0, 0, 0]]]])"
]
},
"execution_count": 125,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"token_expert_selection"
]
},
{
"cell_type": "code",
"execution_count": 126,
"id": "751e9da7-8e61-4673-991a-b97fa6b344ba",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"2"
]
},
"execution_count": 126,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"capacity = math.floor((top_k * capacity_factor * sequence_length) / num_experts)\n",
"capacity"
]
},
{
"cell_type": "markdown",
"id": "84cf4ab7-7f5a-49d5-b80e-3dd2791c2798",
"metadata": {},
"source": [
"# Mask Over-Capacity logits"
]
},
{
"cell_type": "code",
"execution_count": 127,
"id": "42b82401-0c1e-441f-8dcf-fcbb81c73321",
"metadata": {},
"outputs": [],
"source": [
"# for expert capacity estimation - load_balancing + memory efficiency in training/inference\n",
"\n",
"token_expert_selection_queue = token_expert_selection.reshape(batch_size, top_k * sequence_length, num_experts)\n",
"token_expert_selection_mask = torch.cumsum(token_expert_selection_queue, dim=1) <= capacity\n",
"token_expert_selection_mask = token_expert_selection_mask.reshape(batch_size, top_k, sequence_length, num_experts)\n"
]
},
{
"cell_type": "code",
"execution_count": 128,
"id": "5bac2071-b864-40d0-8922-3fc2988b2016",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[[ True, True, True, True, True],\n",
" [ True, True, True, True, True],\n",
" [ True, True, True, True, True],\n",
" [ True, True, True, True, True]],\n",
"\n",
" [[ True, False, True, True, True],\n",
" [ True, False, True, True, True],\n",
" [ True, False, False, True, True],\n",
" [ True, False, False, True, True]],\n",
"\n",
" [[ True, False, False, True, True],\n",
" [ True, False, False, True, True],\n",
" [ True, False, False, False, True],\n",
" [ True, False, False, False, True]]]])"
]
},
"execution_count": 128,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"token_expert_selection_mask"
]
},
{
"cell_type": "code",
"execution_count": 150,
"id": "7d071e28-0ffa-462d-bbb2-060b674624e8",
"metadata": {},
"outputs": [],
"source": [
"final_selection = token_expert_selection_mask * token_expert_selection"
]
},
{
"cell_type": "code",
"execution_count": 151,
"id": "a4635e1a-03ae-44c7-9c3c-99e38ee1d835",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[[0, 0, 1, 0, 0],\n",
" [0, 1, 0, 0, 0],\n",
" [0, 1, 0, 0, 0],\n",
" [0, 0, 0, 0, 1]],\n",
"\n",
" [[0, 0, 0, 0, 0],\n",
" [0, 0, 1, 0, 0],\n",
" [0, 0, 0, 0, 0],\n",
" [0, 0, 0, 1, 0]],\n",
"\n",
" [[0, 0, 0, 0, 1],\n",
" [0, 0, 0, 1, 0],\n",
" [0, 0, 0, 0, 0],\n",
" [1, 0, 0, 0, 0]]]])"
]
},
"execution_count": 151,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"final_selection"
]
},
{
"cell_type": "code",
"execution_count": 152,
"id": "06302b0b-d905-42d0-8442-e6aa60e1ece5",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[[0.0000, 0.0000, 0.0000],\n",
" [0.0000, 0.0000, 0.0000],\n",
" [0.3866, 0.0000, 0.0000],\n",
" [0.0000, 0.0000, 0.0000],\n",
" [0.0000, 0.0000, 0.2564]],\n",
"\n",
" [[0.0000, 0.0000, 0.0000],\n",
" [0.4281, 0.0000, 0.0000],\n",
" [0.0000, 0.3622, 0.0000],\n",
" [0.0000, 0.0000, 0.2097],\n",
" [0.0000, 0.0000, 0.0000]],\n",
"\n",
" [[0.0000, 0.0000, 0.0000],\n",
" [0.3512, 0.0000, 0.0000],\n",
" [0.0000, 0.0000, 0.0000],\n",
" [0.0000, 0.0000, 0.0000],\n",
" [0.0000, 0.0000, 0.0000]],\n",
"\n",
" [[0.0000, 0.0000, 0.2994],\n",
" [0.0000, 0.0000, 0.0000],\n",
" [0.0000, 0.0000, 0.0000],\n",
" [0.0000, 0.2999, 0.0000],\n",
" [0.4007, 0.0000, 0.0000]]]], grad_fn=<PermuteBackward0>)"
]
},
"execution_count": 152,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"final_probs = final_selection * top_k_scores_norm\n",
"\n",
"# (batch_size, top_k, sequence_length, num_experts) -> (batch_size, sequence_length, num_experts, top_k)\n",
"\n",
"final_probs = final_probs.permute(0, 2, 3, 1)\n",
"final_probs "
]
},
{
"cell_type": "markdown",
"id": "7d2c3c4e-3f46-4b87-b064-454379f7ee5f",
"metadata": {},
"source": [
"# Finally multiply with sequence tokens to get Expert vectors as input"
]
},
{
"cell_type": "code",
"execution_count": 148,
"id": "f54f38ea-3399-497b-839c-036c37a2cafd",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[0.4633, 0.4645, 0.1478, 0.0731, 0.8550, 0.0257, 0.9071, 0.7326,\n",
" 0.7471, 0.7278],\n",
" [0.0837, 0.5403, 0.0357, 0.4765, 0.1963, 0.5602, 0.5469, 0.0830,\n",
" 0.1511, 0.8322],\n",
" [0.3197, 0.2551, 0.1655, 0.4526, 0.5493, 0.9306, 0.3617, 0.9077,\n",
" 0.3160, 0.4185],\n",
" [0.4556, 0.4495, 0.4330, 0.9596, 0.6628, 0.0412, 0.4735, 0.0142,\n",
" 0.3043, 0.0344]]])"
]
},
"execution_count": 148,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# batch_size * sequence_length * embedding_dim\n",
"tokens"
]
},
{
"cell_type": "code",
"execution_count": 157,
"id": "9fe36f16-64c0-426d-85d2-38de1bb3f144",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
" 0.0000, 0.0000],\n",
" [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
" 0.0000, 0.0000],\n",
" [0.1364, 0.1346, 0.1296, 0.2873, 0.1984, 0.0123, 0.1418, 0.0043,\n",
" 0.0911, 0.0103]],\n",
"\n",
" [[0.1481, 0.3209, 0.0734, 0.3630, 0.2769, 0.5667, 0.3612, 0.3544,\n",
" 0.1757, 0.5033],\n",
" [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
" 0.0000, 0.0000],\n",
" [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
" 0.0000, 0.0000]],\n",
"\n",
" [[0.1791, 0.1795, 0.0571, 0.0283, 0.3305, 0.0099, 0.3507, 0.2832,\n",
" 0.2888, 0.2813],\n",
" [0.0303, 0.1957, 0.0129, 0.1726, 0.0711, 0.2029, 0.1981, 0.0301,\n",
" 0.0547, 0.3014],\n",
" [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
" 0.0000, 0.0000]],\n",
"\n",
" [[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
" 0.0000, 0.0000],\n",
" [0.1366, 0.1348, 0.1299, 0.2878, 0.1988, 0.0123, 0.1420, 0.0043,\n",
" 0.0913, 0.0103],\n",
" [0.0175, 0.1133, 0.0075, 0.0999, 0.0411, 0.1175, 0.1147, 0.0174,\n",
" 0.0317, 0.1745]],\n",
"\n",
" [[0.1826, 0.1801, 0.1735, 0.3845, 0.2656, 0.0165, 0.1897, 0.0057,\n",
" 0.1219, 0.0138],\n",
" [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
" 0.0000, 0.0000],\n",
" [0.1188, 0.1191, 0.0379, 0.0187, 0.2192, 0.0066, 0.2326, 0.1878,\n",
" 0.1915, 0.1866]]]], grad_fn=<UnsafeViewBackward0>)"
]
},
"execution_count": 157,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# batch_size, sequence_length, num_experts, top_k\n",
"\n",
"# (batch_size, num_experts, top_k, sequence_length) * (batch_size, sequence_length, embedding_dim)\n",
"\n",
"# batch_size * num_experts * top_k * embedding_dim\n",
"final_probs.permute(0, 2, 3, 1) @ tokens"
]
},
{
"cell_type": "code",
"execution_count": 159,
"id": "df1e72f1-7b42-40ad-959f-8c0d4c6a3778",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([0.1364, 0.1346, 0.1296, 0.2873, 0.1984, 0.0123, 0.1418, 0.0043, 0.0911,\n",
" 0.0103])"
]
},
"execution_count": 159,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"0.2994 * tokens[0][-1]"
]
},
{
"cell_type": "code",
"execution_count": 161,
"id": "6cdb3ecf-6867-494c-8511-ef66dee6b5a3",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([0.1481, 0.3209, 0.0734, 0.3630, 0.2769, 0.5667, 0.3611, 0.3543, 0.1757,\n",
" 0.5033])"
]
},
"execution_count": 161,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"0.4281 * tokens[0][1] + 0.3512 * tokens[0][2]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d8422df9-afe5-46fd-b153-ccf99e41b57c",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment