Created
February 14, 2026 23:33
-
-
Save prabhatkgupta/b873ec726cdae8ba5850917f998b0968 to your computer and use it in GitHub Desktop.
Mixture of Experts Routing Logic
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "code", | |
| "execution_count": 106, | |
| "id": "0e2151dc-de2c-4af0-aae8-84d9c4948ec6", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import torch\n", | |
| "import torch.nn as nn\n", | |
| "import math\n", | |
| "import torch.nn.functional as F" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "2fe3a8aa-1b02-48fc-b442-b1bd1b20550d", | |
| "metadata": {}, | |
| "source": [ | |
| "# Define Parameters" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 117, | |
| "id": "2622c5b1-6b59-4195-864c-dc5fd12cbb2a", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Fix the RNG so that Restart Kernel -> Run All reproduces the same random tensors.\n", | |
| "seed = 0\n", | |
| "torch.manual_seed(seed)\n", | |
| "\n", | |
| "batch_size = 1  # sequences per batch\n", | |
| "sequence_length = 4  # tokens per sequence\n", | |
| "total_tokens = batch_size * sequence_length\n", | |
| "num_experts = 5  # experts the router scores for every token\n", | |
| "top_k = 3  # experts kept per token\n", | |
| "embedding_dim = 10  # feature size of each token vector\n", | |
| "capacity_factor = 1.1  # multiplier applied when computing the per-expert capacity" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "b2c32914-871f-40b4-9860-c9516b87829d", | |
| "metadata": {}, | |
| "source": [ | |
| "# Input Tensor" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "id": "a5d0f5ed-eee3-4deb-a3a0-ed1f03563a01", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "tokens = torch.rand(batch_size, sequence_length, embedding_dim)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "id": "5a9546af-c7ee-4f2d-924b-54deb96740e0", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "(torch.Size([1, 4, 10]), torch.float32, device(type='cpu'))" | |
| ] | |
| }, | |
| "execution_count": 4, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "tokens.shape, tokens.dtype, tokens.device" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "id": "b1f43300-ae31-45f3-b10f-30144b09dd49", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor([[[0.4633, 0.4645, 0.1478, 0.0731, 0.8550, 0.0257, 0.9071, 0.7326,\n", | |
| " 0.7471, 0.7278],\n", | |
| " [0.0837, 0.5403, 0.0357, 0.4765, 0.1963, 0.5602, 0.5469, 0.0830,\n", | |
| " 0.1511, 0.8322],\n", | |
| " [0.3197, 0.2551, 0.1655, 0.4526, 0.5493, 0.9306, 0.3617, 0.9077,\n", | |
| " 0.3160, 0.4185],\n", | |
| " [0.4556, 0.4495, 0.4330, 0.9596, 0.6628, 0.0412, 0.4735, 0.0142,\n", | |
| " 0.3043, 0.0344]]])" | |
| ] | |
| }, | |
| "execution_count": 5, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "tokens" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "51b0926e-dc3a-44c1-bd58-1be435fda375", | |
| "metadata": {}, | |
| "source": [ | |
| "# Routing Logic" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 41, | |
| "id": "5a01fe23-d10a-4209-8e3d-89cef120aa8f", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "(torch.Size([1, 4, 5]), torch.float32, device(type='cpu'))" | |
| ] | |
| }, | |
| "execution_count": 41, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# Placeholder scores with one row per token and one column per expert; the batch axis\n", | |
| "# uses batch_size (not a literal 1) so the shape stays right when batch_size changes.\n", | |
| "token_expert = torch.rand(batch_size, sequence_length, num_experts)\n", | |
| "token_expert.shape, token_expert.dtype, token_expert.device" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 42, | |
| "id": "cb5e6d17-87d2-4e4c-ab22-19802e51e3a1", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor([[[0.3223, 0.0997, 0.0972, 0.8917, 0.1840],\n", | |
| " [0.7678, 0.5863, 0.9878, 0.0187, 0.3178],\n", | |
| " [0.6149, 0.3391, 0.2534, 0.8794, 0.4257],\n", | |
| " [0.6254, 0.2395, 0.1902, 0.8509, 0.7391]]])" | |
| ] | |
| }, | |
| "execution_count": 42, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "token_expert" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 43, | |
| "id": "0375b903-3391-4f44-9e84-bc1259da056f", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# router_weights = torch.rand(embedding_dim, num_experts)\n", | |
| "# router_weights.shape, router_weights.dtype, router_weights.device\n", | |
| "\n", | |
| "router_weights = nn.Linear(in_features=embedding_dim, out_features=num_experts)\n", | |
| "router_weights.weight = nn.Parameter(torch.rand(embedding_dim, num_experts).T)\n", | |
| "\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 44, | |
| "id": "e7b79acc-aa16-4ad6-b2c4-6160542e6718", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "(torch.Size([1, 4, 5]), torch.float32, device(type='cpu'))" | |
| ] | |
| }, | |
| "execution_count": 44, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "token_expert = router_weights(tokens)\n", | |
| "token_expert.shape, token_expert.dtype, token_expert.device" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 45, | |
| "id": "0936b8b0-5f09-4cb6-ac4f-49fb3c691b80", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor([[[1.8489, 3.0403, 3.1197, 2.1794, 2.7089],\n", | |
| " [1.2508, 2.3485, 2.1812, 1.6347, 1.4857],\n", | |
| " [1.6586, 2.8409, 2.7868, 2.7355, 2.3498],\n", | |
| " [2.2430, 2.1960, 2.0732, 2.2447, 2.5344]]], grad_fn=<ViewBackward0>)" | |
| ] | |
| }, | |
| "execution_count": 45, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "token_expert" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 46, | |
| "id": "c23950d3-d4f4-4012-8a22-41c4da4b5376", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "True" | |
| ] | |
| }, | |
| "execution_count": 46, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "token_expert.shape == (batch_size, sequence_length, num_experts)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "cf9a1436-1ae1-45dd-817d-fbd6085abf58", | |
| "metadata": {}, | |
| "source": [ | |
| "# Select Top-K Experts per Token" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 47, | |
| "id": "64d817d4-739c-4f97-a6c1-b6eb5b8beb8a", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Top-k selection provides sparsity in calculation\n", | |
| "\n", | |
| "# More experts increase the total parameter count and diversity, but compute stays inexpensive because each token only runs through its selected top-k experts" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 48, | |
| "id": "b47a359a-58b7-4b8e-ad3a-c1a44b7aa17a", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "top_k_scores, top_k_mapping = token_expert.topk(k=top_k, dim=2)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 49, | |
| "id": "52fb6c80-de0e-4319-949a-7242bfe12813", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor([[[3.1197, 3.0403, 2.7089],\n", | |
| " [2.3485, 2.1812, 1.6347],\n", | |
| " [2.8409, 2.7868, 2.7355],\n", | |
| " [2.5344, 2.2447, 2.2430]]], grad_fn=<TopkBackward0>)" | |
| ] | |
| }, | |
| "execution_count": 49, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "top_k_scores" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 50, | |
| "id": "fbddf6bf-175c-4838-b002-9e9b67b60053", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor([[[2, 1, 4],\n", | |
| " [1, 2, 3],\n", | |
| " [1, 2, 3],\n", | |
| " [4, 3, 0]]])" | |
| ] | |
| }, | |
| "execution_count": 50, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "top_k_mapping" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 60, | |
| "id": "efb76a50-9b3c-4d21-856e-d948360f291d", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor([[[ -inf, 3.0403, 3.1197, -inf, 2.7089],\n", | |
| " [ -inf, 2.3485, 2.1812, 1.6347, -inf],\n", | |
| " [ -inf, 2.8409, 2.7868, 2.7355, -inf],\n", | |
| " [2.2430, -inf, -inf, 2.2447, 2.5344]]],\n", | |
| " grad_fn=<ScatterBackward0>)" | |
| ] | |
| }, | |
| "execution_count": 60, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# Start from all -inf so that experts outside the top-k get exactly zero weight after\n", | |
| "# softmax, then write each token's top-k scores back at their expert indices.\n", | |
| "token_expert_logits = torch.full_like(token_expert, float(\"-inf\")).scatter_(2, top_k_mapping, top_k_scores)\n", | |
| "token_expert_logits" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 61, | |
| "id": "0b5b1ae4-ffea-41bd-84fe-06236b2f9434", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "top_k_scores_norm = F.softmax(token_expert_logits, dim=2)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 62, | |
| "id": "d0b4a844-9674-4462-91c3-7329fdbd0a1d", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor([[[0.0000, 0.3571, 0.3866, 0.0000, 0.2564],\n", | |
| " [0.0000, 0.4281, 0.3622, 0.2097, 0.0000],\n", | |
| " [0.0000, 0.3512, 0.3327, 0.3161, 0.0000],\n", | |
| " [0.2994, 0.0000, 0.0000, 0.2999, 0.4007]]],\n", | |
| " grad_fn=<SoftmaxBackward0>)" | |
| ] | |
| }, | |
| "execution_count": 62, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "top_k_scores_norm" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 63, | |
| "id": "485557aa-9f9b-40b8-a8c2-c0afcd32fdb5", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor([[[0.3866, 0.3571, 0.2564],\n", | |
| " [0.4281, 0.3622, 0.2097],\n", | |
| " [0.3512, 0.3327, 0.3161],\n", | |
| " [0.4007, 0.2999, 0.2994]]], grad_fn=<SoftmaxBackward0>)" | |
| ] | |
| }, | |
| "execution_count": 63, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "F.softmax(top_k_scores, dim=2)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "2e5ade6f-2103-4a46-ace7-dd1e14e6704e", | |
| "metadata": {}, | |
| "source": [ | |
| "# Convert into Priority-wise Top-K" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 124, | |
| "id": "29eca966-86bc-4934-aa15-418638e0a54f", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "True\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "token_expert_selection = F.one_hot(top_k_mapping, num_classes=num_experts)\n", | |
| "print(token_expert_selection.shape == (batch_size, sequence_length, top_k, num_experts))\n", | |
| "token_expert_selection = token_expert_selection.permute(0, 2, 1, 3)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 125, | |
| "id": "1b053edd-02f1-4936-817f-83837a39590d", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor([[[[0, 0, 1, 0, 0],\n", | |
| " [0, 1, 0, 0, 0],\n", | |
| " [0, 1, 0, 0, 0],\n", | |
| " [0, 0, 0, 0, 1]],\n", | |
| "\n", | |
| " [[0, 1, 0, 0, 0],\n", | |
| " [0, 0, 1, 0, 0],\n", | |
| " [0, 0, 1, 0, 0],\n", | |
| " [0, 0, 0, 1, 0]],\n", | |
| "\n", | |
| " [[0, 0, 0, 0, 1],\n", | |
| " [0, 0, 0, 1, 0],\n", | |
| " [0, 0, 0, 1, 0],\n", | |
| " [1, 0, 0, 0, 0]]]])" | |
| ] | |
| }, | |
| "execution_count": 125, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "token_expert_selection" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 126, | |
| "id": "751e9da7-8e61-4673-991a-b97fa6b344ba", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "2" | |
| ] | |
| }, | |
| "execution_count": 126, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "capacity = math.floor((top_k * capacity_factor * sequence_length) / num_experts)\n", | |
| "capacity" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "84cf4ab7-7f5a-49d5-b80e-3dd2791c2798", | |
| "metadata": {}, | |
| "source": [ | |
| "# Mask Over-Capacity logits" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 127, | |
| "id": "42b82401-0c1e-441f-8dcf-fcbb81c73321", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Expert capacity limit: used for load balancing and memory efficiency in training/inference.\n", | |
| "\n", | |
| "# Merge the (top_k, sequence_length) axes into one queue axis; with top_k leading, every\n", | |
| "# token's 1st choice is queued ahead of any token's 2nd choice.\n", | |
| "token_expert_selection_queue = token_expert_selection.flatten(start_dim=1, end_dim=2)\n", | |
| "# Running per-expert count along the queue; positions whose count is within capacity pass.\n", | |
| "token_expert_selection_mask = (token_expert_selection_queue.cumsum(dim=1) <= capacity).view(batch_size, top_k, sequence_length, num_experts)\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 128, | |
| "id": "5bac2071-b864-40d0-8922-3fc2988b2016", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor([[[[ True, True, True, True, True],\n", | |
| " [ True, True, True, True, True],\n", | |
| " [ True, True, True, True, True],\n", | |
| " [ True, True, True, True, True]],\n", | |
| "\n", | |
| " [[ True, False, True, True, True],\n", | |
| " [ True, False, True, True, True],\n", | |
| " [ True, False, False, True, True],\n", | |
| " [ True, False, False, True, True]],\n", | |
| "\n", | |
| " [[ True, False, False, True, True],\n", | |
| " [ True, False, False, True, True],\n", | |
| " [ True, False, False, False, True],\n", | |
| " [ True, False, False, False, True]]]])" | |
| ] | |
| }, | |
| "execution_count": 128, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "token_expert_selection_mask" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 150, | |
| "id": "7d071e28-0ffa-462d-bbb2-060b674624e8", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "final_selection = token_expert_selection_mask * token_expert_selection" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 151, | |
| "id": "a4635e1a-03ae-44c7-9c3c-99e38ee1d835", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor([[[[0, 0, 1, 0, 0],\n", | |
| " [0, 1, 0, 0, 0],\n", | |
| " [0, 1, 0, 0, 0],\n", | |
| " [0, 0, 0, 0, 1]],\n", | |
| "\n", | |
| " [[0, 0, 0, 0, 0],\n", | |
| " [0, 0, 1, 0, 0],\n", | |
| " [0, 0, 0, 0, 0],\n", | |
| " [0, 0, 0, 1, 0]],\n", | |
| "\n", | |
| " [[0, 0, 0, 0, 1],\n", | |
| " [0, 0, 0, 1, 0],\n", | |
| " [0, 0, 0, 0, 0],\n", | |
| " [1, 0, 0, 0, 0]]]])" | |
| ] | |
| }, | |
| "execution_count": 151, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "final_selection" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 152, | |
| "id": "06302b0b-d905-42d0-8442-e6aa60e1ece5", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor([[[[0.0000, 0.0000, 0.0000],\n", | |
| " [0.0000, 0.0000, 0.0000],\n", | |
| " [0.3866, 0.0000, 0.0000],\n", | |
| " [0.0000, 0.0000, 0.0000],\n", | |
| " [0.0000, 0.0000, 0.2564]],\n", | |
| "\n", | |
| " [[0.0000, 0.0000, 0.0000],\n", | |
| " [0.4281, 0.0000, 0.0000],\n", | |
| " [0.0000, 0.3622, 0.0000],\n", | |
| " [0.0000, 0.0000, 0.2097],\n", | |
| " [0.0000, 0.0000, 0.0000]],\n", | |
| "\n", | |
| " [[0.0000, 0.0000, 0.0000],\n", | |
| " [0.3512, 0.0000, 0.0000],\n", | |
| " [0.0000, 0.0000, 0.0000],\n", | |
| " [0.0000, 0.0000, 0.0000],\n", | |
| " [0.0000, 0.0000, 0.0000]],\n", | |
| "\n", | |
| " [[0.0000, 0.0000, 0.2994],\n", | |
| " [0.0000, 0.0000, 0.0000],\n", | |
| " [0.0000, 0.0000, 0.0000],\n", | |
| " [0.0000, 0.2999, 0.0000],\n", | |
| " [0.4007, 0.0000, 0.0000]]]], grad_fn=<PermuteBackward0>)" | |
| ] | |
| }, | |
| "execution_count": 152, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "final_probs = final_selection * top_k_scores_norm\n", | |
| "\n", | |
| "# (batch_size, top_k, sequence_length, num_experts) -> (batch_size, sequence_length, num_experts, top_k)\n", | |
| "\n", | |
| "final_probs = final_probs.permute(0, 2, 3, 1)\n", | |
| "final_probs " | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "7d2c3c4e-3f46-4b87-b064-454379f7ee5f", | |
| "metadata": {}, | |
| "source": [ | |
| "# Finally, multiply by the sequence tokens to get the expert input vectors" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 148, | |
| "id": "f54f38ea-3399-497b-839c-036c37a2cafd", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor([[[0.4633, 0.4645, 0.1478, 0.0731, 0.8550, 0.0257, 0.9071, 0.7326,\n", | |
| " 0.7471, 0.7278],\n", | |
| " [0.0837, 0.5403, 0.0357, 0.4765, 0.1963, 0.5602, 0.5469, 0.0830,\n", | |
| " 0.1511, 0.8322],\n", | |
| " [0.3197, 0.2551, 0.1655, 0.4526, 0.5493, 0.9306, 0.3617, 0.9077,\n", | |
| " 0.3160, 0.4185],\n", | |
| " [0.4556, 0.4495, 0.4330, 0.9596, 0.6628, 0.0412, 0.4735, 0.0142,\n", | |
| " 0.3043, 0.0344]]])" | |
| ] | |
| }, | |
| "execution_count": 148, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# batch_size * sequence_length * embedding_dim\n", | |
| "tokens" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 157, | |
| "id": "9fe36f16-64c0-426d-85d2-38de1bb3f144", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor([[[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", | |
| " 0.0000, 0.0000],\n", | |
| " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", | |
| " 0.0000, 0.0000],\n", | |
| " [0.1364, 0.1346, 0.1296, 0.2873, 0.1984, 0.0123, 0.1418, 0.0043,\n", | |
| " 0.0911, 0.0103]],\n", | |
| "\n", | |
| " [[0.1481, 0.3209, 0.0734, 0.3630, 0.2769, 0.5667, 0.3612, 0.3544,\n", | |
| " 0.1757, 0.5033],\n", | |
| " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", | |
| " 0.0000, 0.0000],\n", | |
| " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", | |
| " 0.0000, 0.0000]],\n", | |
| "\n", | |
| " [[0.1791, 0.1795, 0.0571, 0.0283, 0.3305, 0.0099, 0.3507, 0.2832,\n", | |
| " 0.2888, 0.2813],\n", | |
| " [0.0303, 0.1957, 0.0129, 0.1726, 0.0711, 0.2029, 0.1981, 0.0301,\n", | |
| " 0.0547, 0.3014],\n", | |
| " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", | |
| " 0.0000, 0.0000]],\n", | |
| "\n", | |
| " [[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", | |
| " 0.0000, 0.0000],\n", | |
| " [0.1366, 0.1348, 0.1299, 0.2878, 0.1988, 0.0123, 0.1420, 0.0043,\n", | |
| " 0.0913, 0.0103],\n", | |
| " [0.0175, 0.1133, 0.0075, 0.0999, 0.0411, 0.1175, 0.1147, 0.0174,\n", | |
| " 0.0317, 0.1745]],\n", | |
| "\n", | |
| " [[0.1826, 0.1801, 0.1735, 0.3845, 0.2656, 0.0165, 0.1897, 0.0057,\n", | |
| " 0.1219, 0.0138],\n", | |
| " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", | |
| " 0.0000, 0.0000],\n", | |
| " [0.1188, 0.1191, 0.0379, 0.0187, 0.2192, 0.0066, 0.2326, 0.1878,\n", | |
| " 0.1915, 0.1866]]]], grad_fn=<UnsafeViewBackward0>)" | |
| ] | |
| }, | |
| "execution_count": 157, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# final_probs: (batch_size, sequence_length, num_experts, top_k)\n", | |
| "# permuted -> (batch_size, num_experts, top_k, sequence_length)\n", | |
| "\n", | |
| "# tokens.unsqueeze(1): (batch_size, 1, sequence_length, embedding_dim)\n", | |
| "# The extra axis lines tokens' batch axis up with the batch axis of the weights and\n", | |
| "# broadcasts over num_experts. Without it, matmul right-aligns tokens' batch axis\n", | |
| "# against num_experts, which is only correct when batch_size == 1.\n", | |
| "\n", | |
| "# result: (batch_size, num_experts, top_k, embedding_dim)\n", | |
| "final_probs.permute(0, 2, 3, 1) @ tokens.unsqueeze(1)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 159, | |
| "id": "df1e72f1-7b42-40ad-959f-8c0d4c6a3778", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor([0.1364, 0.1346, 0.1296, 0.2873, 0.1984, 0.0123, 0.1418, 0.0043, 0.0911,\n", | |
| " 0.0103])" | |
| ] | |
| }, | |
| "execution_count": 159, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "0.2994 * tokens[0][-1]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 161, | |
| "id": "6cdb3ecf-6867-494c-8511-ef66dee6b5a3", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor([0.1481, 0.3209, 0.0734, 0.3630, 0.2769, 0.5667, 0.3611, 0.3543, 0.1757,\n", | |
| " 0.5033])" | |
| ] | |
| }, | |
| "execution_count": 161, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "0.4281 * tokens[0][1] + 0.3512 * tokens[0][2]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "d8422df9-afe5-46fd-b153-ccf99e41b57c", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "Python 3 (ipykernel)", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.11.9" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 5 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment