{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "796570dc-d6dc-4494-98c1-f1aed43b40cf",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch version: 2.5.1\n"
     ]
    }
   ],
| "source": [ | |
| "import torch\n", | |
| "import logging\n", | |
| "\n", | |
| "from torch.nn.attention.flex_attention import (\n", | |
| " create_block_mask,\n", | |
| " create_mask,\n", | |
| " flex_attention,\n", | |
| " _mask_mod_signature,\n", | |
| " _score_mod_signature,\n", | |
| " noop_mask,\n", | |
| " and_masks\n", | |
| ")\n", | |
| "\n", | |
| "from torch.nn.attention.flex_attention import flex_attention as reference_flex_attention\n", | |
| "\n", | |
| "from torch import Tensor\n", | |
| "import torch.nn.functional as F\n", | |
| "\n", | |
| "from typing import Optional\n", | |
| "from torch import Tensor\n", | |
| "\n", | |
| "print(\"torch version:\", torch.__version__)" | |
| ] | |
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "46a54d12-052a-40d8-9615-d5f828b9d33d",
   "metadata": {},
   "outputs": [],
   "source": [
| "# disabling recompiles logic because too verbose when testing\n", | |
| "# re-enable to double check when compiles are happening when expected\n", | |
| "torch._logging.set_logs(\n", | |
| " dynamo=logging.WARN,\n", | |
| " recompiles=False,\n", | |
| " recompiles_verbose=False\n", | |
| ")" | |
| ] | |
| }, | |
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "89bfd796-6739-4654-9b36-3ad30a11419f",
   "metadata": {},
   "outputs": [],
   "source": [
| "class DynamicFlexAttention:\n", | |
| " \"\"\"\n", | |
| " This wrapper class around Flex Attention allows for dynamic sequence\n", | |
| " lengths without having to excessively recompile flex_attention.\n", | |
| " It pads the inputs Q, K, V to the size the Flex Attention kernel\n", | |
| " was compiled for and uses Flex Attention's own masking mechanism to\n", | |
| " ignore the padding.\n", | |
| "\n", | |
| " Rebuilds happen when the input sequence length exceeds any past\n", | |
| " sequence length seen before.\n", | |
| "\n", | |
| " Recomputation of the blockmask does unfortunately have to occur\n", | |
| " for each new input.\n", | |
| "\n", | |
| " Caveat/TODOs:\n", | |
| "\n", | |
| " - flex attention fails to compile properly for float64 I think?\n", | |
| " So had to use high atol in torch.allclose\n", | |
| "\n", | |
| " - We assume that the batch size and num heads is\n", | |
| " static between passes. Would trigger kernel rebuilds if otherwise.\n", | |
| "\n", | |
| " - Potentially cache the blockmasks with an LRU/LFU cache?\n", | |
| "\n", | |
| " - Dynamically choose the `flex_attention` kernel too? Pre-compile\n", | |
| " flex_attention kernels in powers of 2? And then binary search/index\n", | |
| " into `ceiling_next_power_of_2(input_seq_len)`? Pretty quick to index\n", | |
| " into and prevent ridiculous padding sizes. Biggest would be in the\n", | |
| " order of double the input size.\n", | |
| "\n", | |
| " - Current interface requires you to pre-specify the mask_mod being used.\n", | |
| " Allow passing in a dictionary of mask_mods? I originally didn't realize\n", | |
| " you unfortunately had to rebuild the blockmask each time.\n", | |
| " \"\"\"\n", | |
| "\n", | |
| " def __init__(\n", | |
| " self,\n", | |
| " bs,\n", | |
| " num_heads,\n", | |
| " mask_mod: _mask_mod_signature = noop_mask,\n", | |
| " compile_options = None\n", | |
| " ):\n", | |
| " # Flex attention fails to compile with dynamic=True in my testing.\n", | |
| " # Hence the whole premise of this wrapper class.\n", | |
| " self._flex_attention = torch.compile(\n", | |
| " flex_attention,\n", | |
| " dynamic=False,\n", | |
| " options=compile_options,\n", | |
| " )\n", | |
| " self.bs = bs\n", | |
| " self.num_heads = num_heads\n", | |
| " self.max_seq_len = 0\n", | |
| " self.mask_mod = mask_mod\n", | |
| "\n", | |
| " def flex_attention(\n", | |
| " self,\n", | |
| " query: Tensor,\n", | |
| " key: Tensor,\n", | |
| " value: Tensor,\n", | |
| " score_mod: Optional[_score_mod_signature] = None,\n", | |
| " scale: Optional[float] = None,\n", | |
| " ) -> Tensor:\n", | |
| " bs, num_heads, q_len, head_dim = query.shape\n", | |
| " _, _, kv_len, _ = key.shape\n", | |
| "\n", | |
| " assert bs == self.bs and num_heads == self.num_heads, \\\n", | |
| " \"Dynamic batch sizes and number of heads not currently \" \\\n", | |
| " + \"supported for performance reasons. Pad inputs accordingly\" \\\n", | |
| " + \" if desired.\"\n", | |
| "\n", | |
| " self.max_seq_len = max(\n", | |
| " q_len,\n", | |
| " kv_len,\n", | |
| " self.max_seq_len\n", | |
| " )\n", | |
| "\n", | |
| " # TODO: See if we can make our own blockmask constructor?\n", | |
| " # Also LFU/LRU caching here?\n", | |
| " # https://x.com/cHHillee/status/1851418255749169419?lang=en\n", | |
| " blockmask = create_block_mask(\n", | |
| " and_masks(\n", | |
| " lambda _b, _h, q_i, kv_i: q_i < q_len,\n", | |
| " lambda _b, _h, q_i, kv_i: kv_i < kv_len,\n", | |
| " self.mask_mod\n", | |
| " ),\n", | |
| " B=None,\n", | |
| " H=None,\n", | |
| " Q_LEN=self.max_seq_len,\n", | |
| " KV_LEN=self.max_seq_len,\n", | |
| " )\n", | |
| "\n", | |
| " padded_q = F.pad(query, (0, 0, 0, self.max_seq_len - q_len))\n", | |
| " padded_k = F.pad(key, (0, 0, 0, self.max_seq_len - kv_len))\n", | |
| " padded_v = F.pad(value, (0, 0, 0, self.max_seq_len - kv_len))\n", | |
| "\n", | |
| " res = self._flex_attention(\n", | |
| " padded_q,\n", | |
| " padded_k,\n", | |
| " padded_v,\n", | |
| " score_mod=score_mod,\n", | |
| " block_mask=blockmask,\n", | |
| " scale=scale\n", | |
| " )\n", | |
| "\n", | |
| " return res[:, :, :q_len, :]\n" | |
| ] | |
| }, | |
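  {
   "cell_type": "code",
   "execution_count": null,
   "id": "blockmask-cache-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch only (not used by DynamicFlexAttention above): one\n",
    "# way to approach the \"cache the block masks with an LRU/LFU cache\" TODO\n",
    "# from the class docstring. It assumes the padding mask depends only on\n",
    "# (q_len, kv_len, max_seq_len, mask_mod), so those can serve as the cache\n",
    "# key. The name `cached_padded_block_mask` is hypothetical.\n",
    "from functools import lru_cache\n",
    "\n",
    "@lru_cache(maxsize=32)\n",
    "def cached_padded_block_mask(\n",
    "    q_len: int,\n",
    "    kv_len: int,\n",
    "    max_seq_len: int,\n",
    "    mask_mod: _mask_mod_signature = noop_mask,\n",
    "):\n",
    "    # Mask out padded query rows and key/value columns, then AND in the\n",
    "    # user-provided mask_mod, exactly as DynamicFlexAttention does.\n",
    "    return create_block_mask(\n",
    "        and_masks(\n",
    "            lambda _b, _h, q_i, kv_i: q_i < q_len,\n",
    "            lambda _b, _h, q_i, kv_i: kv_i < kv_len,\n",
    "            mask_mod,\n",
    "        ),\n",
    "        B=None,\n",
    "        H=None,\n",
    "        Q_LEN=max_seq_len,\n",
    "        KV_LEN=max_seq_len,\n",
    "    )"
   ]
  },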
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "3a6ae983-05be-4a4f-aaa9-3545e239fc93",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch_compile_options = {\n",
    "    \"epilogue_fusion\": True,\n",
    "    \"max_autotune\": True,\n",
    "    \"shape_padding\": True,\n",
    "    \"trace.enabled\": True,\n",
    "    \"triton.cudagraphs\": False,\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "717a93ba-32a1-4873-9710-843553c52940",
   "metadata": {},
   "outputs": [],
   "source": [
    "dyn_attn = DynamicFlexAttention(8, 8, compile_options=torch_compile_options)"
   ]
  },
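  {
   "cell_type": "code",
   "execution_count": null,
   "id": "causal-mask-usage-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hedged usage sketch: the instance above uses the default noop_mask, so\n",
    "# every position attends to every other. DynamicFlexAttention also accepts\n",
    "# any mask_mod; the standard causal mask_mod could be passed in like this.\n",
    "# The name `causal_dyn_attn` is just for illustration and is not used in\n",
    "# the comparison below.\n",
    "def causal_mask(b, h, q_idx, kv_idx):\n",
    "    return q_idx >= kv_idx\n",
    "\n",
    "causal_dyn_attn = DynamicFlexAttention(\n",
    "    8, 8,\n",
    "    mask_mod=causal_mask,\n",
    "    compile_options=torch_compile_options,\n",
    ")"
   ]
  },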
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "3ddb067b-51f8-4cde-bf6c-93684e6ba939",
   "metadata": {},
   "outputs": [],
   "source": [
    "def checkerboard(score, batch, head, token_q, token_kv):\n",
    "    score = torch.where(torch.abs(token_kv - token_q) % 2 == 1, score * 0.5, score)\n",
    "    score = torch.where(torch.abs(token_kv - token_q) % 2 == 0, score * 2.0, score)\n",
    "    return score\n",
    "\n",
    "def attention_is_all_close(attention_fn_1, attention_fn_2):\n",
    "    # The 1024 forces a rebuild for the longer sequence; the 2nd 512\n",
    "    # exercises the dynamic seq len path (padded back up to 1024).\n",
    "    for seq_len in [512, 1024, 512]:\n",
    "        query = torch.randn(8, 8, seq_len, 64, device=\"cuda\", dtype=torch.float32)\n",
    "        key = torch.randn(8, 8, seq_len, 64, device=\"cuda\", dtype=torch.float32)\n",
    "        value = torch.randn(8, 8, seq_len, 64, device=\"cuda\", dtype=torch.float32)\n",
    "\n",
    "        res_1 = attention_fn_1(query, key, value)\n",
    "        res_2 = attention_fn_2(query, key, value)\n",
    "        if not torch.allclose(res_1, res_2, atol=1e-02):\n",
    "            return False\n",
    "\n",
    "        res_1 = attention_fn_1(query, key, value, score_mod=checkerboard)\n",
    "        res_2 = attention_fn_2(query, key, value, score_mod=checkerboard)\n",
    "        if not torch.allclose(res_1, res_2, atol=1e-02):\n",
    "            return False\n",
    "\n",
    "    return True\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83dd70f6-ec05-470c-8ed3-4a8aa840e3c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "is_allclose = attention_is_all_close(dyn_attn.flex_attention, reference_flex_attention)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "id": "953af133-e513-4a6d-9306-36122fdc2544",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "attention all close True\n"
     ]
    }
   ],
   "source": [
    "print(\"attention all close\", is_allclose)"
   ]
  },
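  {
   "cell_type": "code",
   "execution_count": null,
   "id": "recompile-check-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sanity check (sketch): re-enable the recompile logging that was\n",
    "# silenced earlier and push a few more sequence lengths through the wrapper.\n",
    "# Watching the log shows when torch.compile actually recompiles; ideally\n",
    "# that only happens when a new maximum sequence length forces a rebuild.\n",
    "torch._logging.set_logs(\n",
    "    dynamo=logging.WARN,\n",
    "    recompiles=True,\n",
    "    recompiles_verbose=True\n",
    ")\n",
    "\n",
    "for seq_len in [256, 768, 1024]:\n",
    "    q = torch.randn(8, 8, seq_len, 64, device=\"cuda\", dtype=torch.float32)\n",
    "    k = torch.randn(8, 8, seq_len, 64, device=\"cuda\", dtype=torch.float32)\n",
    "    v = torch.randn(8, 8, seq_len, 64, device=\"cuda\", dtype=torch.float32)\n",
    "    _ = dyn_attn.flex_attention(q, k, v)"
   ]
  },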
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "006e5151-0fdd-45d8-be8a-82d22c306ee2",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}