@zyklotomic
Created March 9, 2025 18:12
{
"cells": [
{
"cell_type": "code",
"execution_count": 6,
"id": "796570dc-d6dc-4494-98c1-f1aed43b40cf",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch version: 2.5.1\n"
]
}
],
"source": [
"import torch\n",
"import logging\n",
"\n",
"from torch.nn.attention.flex_attention import (\n",
" create_block_mask,\n",
" create_mask,\n",
" flex_attention,\n",
" _mask_mod_signature,\n",
" _score_mod_signature,\n",
" noop_mask,\n",
" and_masks\n",
")\n",
"\n",
"from torch.nn.attention.flex_attention import flex_attention as reference_flex_attention\n",
"\n",
"from torch import Tensor\n",
"import torch.nn.functional as F\n",
"\n",
"from typing import Optional\n",
"from torch import Tensor\n",
"\n",
"print(\"torch version:\", torch.__version__)"
]
},
{
"cell_type": "code",
"execution_count": 37,
"id": "46a54d12-052a-40d8-9615-d5f828b9d33d",
"metadata": {},
"outputs": [],
"source": [
"# disabling recompiles logic because too verbose when testing\n",
"# re-enable to double check when compiles are happening when expected\n",
"torch._logging.set_logs(\n",
" dynamo=logging.WARN,\n",
" recompiles=False,\n",
" recompiles_verbose=False\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 38,
"id": "89bfd796-6739-4654-9b36-3ad30a11419f",
"metadata": {},
"outputs": [],
"source": [
"class DynamicFlexAttention:\n",
" \"\"\"\n",
" This wrapper class around Flex Attention allows for dynamic sequence\n",
" lengths without having to excessively recompile flex_attention.\n",
" It pads the inputs Q, K, V to the size the Flex Attention kernel\n",
" was compiled for and uses Flex Attention's own masking mechanism to\n",
" ignore the padding.\n",
"\n",
" Rebuilds happen when the input sequence length exceeds any past\n",
" sequence length seen before.\n",
"\n",
" Recomputation of the blockmask does unfortunately have to occur\n",
" for each new input.\n",
"\n",
" Caveat/TODOs:\n",
"\n",
" - flex attention fails to compile properly for float64 I think?\n",
" So had to use high atol in torch.allclose\n",
"\n",
" - We assume that the batch size and num heads is\n",
" static between passes. Would trigger kernel rebuilds if otherwise.\n",
"\n",
" - Potentially cache the blockmasks with an LRU/LFU cache?\n",
"\n",
" - Dynamically choose the `flex_attention` kernel too? Pre-compile\n",
" flex_attention kernels in powers of 2? And then binary search/index\n",
" into `ceiling_next_power_of_2(input_seq_len)`? Pretty quick to index\n",
" into and prevent ridiculous padding sizes. Biggest would be in the\n",
" order of double the input size.\n",
"\n",
" - Current interface requires you to pre-specify the mask_mod being used.\n",
" Allow passing in a dictionary of mask_mods? I originally didn't realize\n",
" you unfortunately had to rebuild the blockmask each time.\n",
" \"\"\"\n",
"\n",
" def __init__(\n",
" self,\n",
" bs,\n",
" num_heads,\n",
" mask_mod: _mask_mod_signature = noop_mask,\n",
" compile_options = None\n",
" ):\n",
" # Flex attention fails to compile with dynamic=True in my testing.\n",
" # Hence the whole premise of this wrapper class.\n",
" self._flex_attention = torch.compile(\n",
" flex_attention,\n",
" dynamic=False,\n",
" options=compile_options,\n",
" )\n",
" self.bs = bs\n",
" self.num_heads = num_heads\n",
" self.max_seq_len = 0\n",
" self.mask_mod = mask_mod\n",
"\n",
" def flex_attention(\n",
" self,\n",
" query: Tensor,\n",
" key: Tensor,\n",
" value: Tensor,\n",
" score_mod: Optional[_score_mod_signature] = None,\n",
" scale: Optional[float] = None,\n",
" ) -> Tensor:\n",
" bs, num_heads, q_len, head_dim = query.shape\n",
" _, _, kv_len, _ = key.shape\n",
"\n",
" assert bs == self.bs and num_heads == self.num_heads, \\\n",
" \"Dynamic batch sizes and number of heads not currently \" \\\n",
" + \"supported for performance reasons. Pad inputs accordingly\" \\\n",
" + \" if desired.\"\n",
"\n",
" self.max_seq_len = max(\n",
" q_len,\n",
" kv_len,\n",
" self.max_seq_len\n",
" )\n",
"\n",
" # TODO: See if we can make our own blockmask constructor?\n",
" # Also LFU/LRU caching here?\n",
" # https://x.com/cHHillee/status/1851418255749169419?lang=en\n",
" blockmask = create_block_mask(\n",
" and_masks(\n",
" lambda _b, _h, q_i, kv_i: q_i < q_len,\n",
" lambda _b, _h, q_i, kv_i: kv_i < kv_len,\n",
" self.mask_mod\n",
" ),\n",
" B=None,\n",
" H=None,\n",
" Q_LEN=self.max_seq_len,\n",
" KV_LEN=self.max_seq_len,\n",
" )\n",
"\n",
" padded_q = F.pad(query, (0, 0, 0, self.max_seq_len - q_len))\n",
" padded_k = F.pad(key, (0, 0, 0, self.max_seq_len - kv_len))\n",
" padded_v = F.pad(value, (0, 0, 0, self.max_seq_len - kv_len))\n",
"\n",
" res = self._flex_attention(\n",
" padded_q,\n",
" padded_k,\n",
" padded_v,\n",
" score_mod=score_mod,\n",
" block_mask=blockmask,\n",
" scale=scale\n",
" )\n",
"\n",
" return res[:, :, :q_len, :]\n"
]
},
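{
"cell_type": "markdown",
"id": "b7e4a1c0-9d2f-4c6a-8e31-5f0a2d7c41aa",
"metadata": {},
"source": [
"A rough, untested sketch of the power-of-2 bucketing idea from the TODO list in the class docstring above: lazily compile one `flex_attention` kernel per power-of-2 sequence length and pad inputs up to the nearest bucket, so the padding overhead is at most roughly 2x and a shorter sequence never forces a recompile. `ceiling_next_power_of_2` and `get_bucketed_kernel` are illustrative names, not part of `DynamicFlexAttention` or the PyTorch API."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c1f8d2e3-7a4b-4f9c-9b20-6e1d3a8b52bb",
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical sketch only (not used by DynamicFlexAttention above):\n",
"# bucket sequence lengths into powers of two and keep one compiled\n",
"# flex_attention wrapper per bucket, reusing torch and flex_attention\n",
"# from the imports at the top of the notebook.\n",
"from typing import Callable, Dict\n",
"\n",
"def ceiling_next_power_of_2(n: int) -> int:\n",
"    # smallest power of two >= n, so the padded size is at most ~2x the input\n",
"    return 1 << (max(n, 1) - 1).bit_length()\n",
"\n",
"_bucketed_kernels: Dict[int, Callable] = {}\n",
"\n",
"def get_bucketed_kernel(seq_len: int, compile_options=None):\n",
"    # one compiled wrapper per bucket; the caller pads inputs to the bucket\n",
"    # size, so each wrapper only ever sees a single sequence length\n",
"    bucket = ceiling_next_power_of_2(seq_len)\n",
"    if bucket not in _bucketed_kernels:\n",
"        _bucketed_kernels[bucket] = torch.compile(\n",
"            flex_attention, dynamic=False, options=compile_options\n",
"        )\n",
"    return bucket, _bucketed_kernels[bucket]"
]
},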
{
"cell_type": "code",
"execution_count": 39,
"id": "3a6ae983-05be-4a4f-aaa9-3545e239fc93",
"metadata": {},
"outputs": [],
"source": [
"torch_compile_options = {\n",
" \"epilogue_fusion\" : True,\n",
" \"max_autotune\" : True,\n",
" \"shape_padding\" : True,\n",
" \"trace.enabled\" : True,\n",
" \"triton.cudagraphs\" : False,\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 40,
"id": "717a93ba-32a1-4873-9710-843553c52940",
"metadata": {},
"outputs": [],
"source": [
"dyn_attn = DynamicFlexAttention(8, 8, compile_options = torch_compile_options)"
]
},
{
"cell_type": "code",
"execution_count": 47,
"id": "3ddb067b-51f8-4cde-bf6c-93684e6ba939",
"metadata": {},
"outputs": [],
"source": [
"def checkerboard(score, batch, head, token_q, token_kv):\n",
" score = torch.where(torch.abs(token_kv - token_q) % 2 == 1, score * 0.5, score)\n",
" score = torch.where(torch.abs(token_kv - token_q) % 2 == 0, score * 2.0, score)\n",
" return score\n",
"\n",
"def attention_is_all_close(attention_fn_1, attention_fn_2):\n",
" # forces dynamic seq len recompilation in the 2nd 512\n",
" for seq_len in [512, 1024, 512]:\n",
" query = torch.randn(8, 8, seq_len, 64, device=\"cuda\", dtype=torch.float32)\n",
" key = torch.randn(8, 8, seq_len, 64, device=\"cuda\", dtype=torch.float32)\n",
" value = torch.randn(8, 8, seq_len, 64, device=\"cuda\", dtype=torch.float32)\n",
" \n",
" res_1 = attention_fn_1(query, key, value)\n",
" res_2 = attention_fn_2(query, key, value)\n",
" if not torch.allclose(res_1, res_2, atol=1e-02):\n",
" return False\n",
" \n",
" res_1 = attention_fn_1(query, key, value, score_mod=checkerboard)\n",
" res_2 = attention_fn_2(query, key, value, score_mod=checkerboard)\n",
" if not torch.allclose(res_1, res_2, atol=1e-02):\n",
" return False\n",
"\n",
" return True\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "83dd70f6-ec05-470c-8ed3-4a8aa840e3c1",
"metadata": {},
"outputs": [],
"source": [
"is_allclose = attention_is_all_close(dyn_attn.flex_attention, reference_flex_attention)"
]
},
{
"cell_type": "code",
"execution_count": 49,
"id": "953af133-e513-4a6d-9306-36122fdc2544",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"attention all close True\n"
]
}
],
"source": [
"print(\"attention all close\", is_allclose)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "006e5151-0fdd-45d8-be8a-82d22c306ee2",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}