Implementing the NVFP4 Recipe From Scratch.ipynb
| { | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "view-in-github", | |
| "colab_type": "text" | |
| }, | |
| "source": [ | |
| "<a href=\"https://colab.research.google.com/gist/vukrosic/2c0117344dd269263adf0b6e5382889f/excercise.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "AbCTKh0GJAGS" | |
| }, | |
| "source": [ | |
| "# Implementing the NVFP4 Recipe From Scratch: A Developer's Tutorial\n", | |
| "\n", | |
| "This tutorial deconstructs the core algorithms from PR [#2177](https://github.com/NVIDIA/TransformerEngine/pull/2177/files) to teach you how to implement them conceptually. We will build Python/PyTorch reference functions that mirror the logic of the new C++/CUDA kernels.\n", | |
| "\n", | |
| "Our goal is to implement these key components:\n", | |
| "\n", | |
| "1. **Core 1D Block Quantization**: The fundamental scaling and casting logic for 1x16 blocks.\n", | |
| "2. **2D Block Quantization**: An extension for quantizing 16x16 blocks, ideal for weights.\n", | |
| "3. **Random Hadamard Transform (RHT)**: The pre-quantization step to improve accuracy.\n", | |
| "4. **The Fused Operation**: Combining everything to produce the final `NVFP4Tensor` components.\n", | |
| "\n", | |
| "We will focus on the *algorithmic logic*, not CUDA-level performance optimizations.\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "id": "IuWzIXMwJAGT" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "import torch\n", | |
| "import math\n", | |
| "\n", | |
| "# For reproducibility\n", | |
| "torch.manual_seed(0)\n", | |
| "torch.cuda.manual_seed(0)\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "P3w70_L5JAGT" | |
| }, | |
| "source": [ | |
| "## Step 1: Understanding the Target - The NVFP4 E2M1 Format\n", | |
| "\n", | |
| "Before we can quantize, we need to know what we're converting *to*. NVFP4 in this PR uses the `E2M1` format (2 exponent bits, 1 mantissa bit). It's a 4-bit floating-point number. We can represent all possible 16 values in a lookup table (LUT). This helps us simulate the casting process.\n", | |
| "\n", | |
| "The C++ code uses native `__nv_fp4_e2m1` types, but this LUT is perfect for a Python reference.\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "id": "uEdE3LqzJAGT" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "# The 16 possible values for an E2M1 FP4 number.\n", | |
| "# Index corresponds to the 4-bit integer representation.\n", | |
| "FP4_E2M1_LUT = torch.tensor([\n", | |
| " # Positive values (first bit 0)\n", | |
| " 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,\n", | |
| " # Negative values (first bit 1)\n", | |
| " -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0,\n", | |
| "], dtype=torch.float32)\n", | |
| "\n", | |
| "# The maximum absolute value for E2M1 is 6.0. This is a critical constant.\n", | |
| "FP4_E2M1_MAX_VAL = 6.0\n", | |
| "\n", | |
| "def find_closest_fp4_val(value):\n", | |
| " \"\"\"Simulates casting a float to the nearest FP4 value.\"\"\"\n", | |
| " # Find the value in our LUT that is closest to the input value.\n", | |
| " # The index of this closest value is our 4-bit representation.\n", | |
| " return torch.argmin(torch.abs(value - FP4_E2M1_LUT.to(value.device)))\n", | |
| "\n", | |
| "print(f\"FP4 E2M1 Lookup Table:\\n{FP4_E2M1_LUT}\")\n", | |
| "print(f\"\\nExample: Casting 2.9 to FP4 -> finds value {FP4_E2M1_LUT[find_closest_fp4_val(torch.tensor(2.9))]}\")\n", | |
| "print(f\"Example: Casting -4.2 to FP4 -> finds value {FP4_E2M1_LUT[find_closest_fp4_val(torch.tensor(-4.2))]}\")\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "LDtc0nIsJAGT" | |
| }, | |
| "source": [ | |
| "## Step 2: Implementing 1D Block Quantization\n", | |
| "\n", | |
| "This is the core logic. For each 1D block of 16 elements in a tensor row, we perform these steps. This logic is what the reference implementation `quantize_nvfp4_1d` in `test_cast_nvfp4_transpose.cu` performs.\n", | |
| "\n", | |
| "1. Find the absolute maximum value (`amax`) in the 16-element block.\n", | |
| "2. Calculate a `scaling_factor` for this block. The formula is `amax / FP4_E2M1_MAX_VAL`.\n", | |
| "3. **Scale** the original 16 values by dividing by the `scaling_factor`.\n", | |
| "4. **Cast** the scaled values to the nearest FP4 value.\n", | |
| "5. Store the resulting 16 4-bit integers and the single `scaling_factor`.\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "id": "r0xhEcdVJAGT" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "def quantize_1d_block_reference(hp_tensor: torch.Tensor):\n", | |
| " \"\"\"\n", | |
| " Reference implementation for 1D block quantization (1x16 blocks).\n", | |
| " \"\"\"\n", | |
| " assert hp_tensor.dim() == 2, \"Input must be a 2D tensor\"\n", | |
| " rows, cols = hp_tensor.shape\n", | |
| " assert cols % 16 == 0, \"Columns must be divisible by 16\"\n", | |
| "\n", | |
| " # Outputs\n", | |
| " num_scale_blocks = cols // 16\n", | |
| " quantized_data = torch.zeros(rows, cols, dtype=torch.int8, device=hp_tensor.device)\n", | |
| " scaling_factors = torch.zeros(rows, num_scale_blocks, dtype=hp_tensor.dtype, device=hp_tensor.device)\n", | |
| "\n", | |
| " for i in range(rows):\n", | |
| " for j in range(num_scale_blocks):\n", | |
| " # 1. Get the 1x16 block\n", | |
| " start_col, end_col = j * 16, (j + 1) * 16\n", | |
| " block = hp_tensor[i, start_col:end_col]\n", | |
| "\n", | |
| " # 2. Find amax\n", | |
| " block_amax = torch.max(torch.abs(block))\n", | |
| " if block_amax == 0: # Handle all-zero blocks\n", | |
| " scaling_factors[i, j] = 0.0\n", | |
| " # Quantized data is already 0\n", | |
| " continue\n", | |
| "\n", | |
| " # 3. Calculate scaling factor\n", | |
| " scaling_factor = block_amax / FP4_E2M1_MAX_VAL\n", | |
| " scaling_factors[i, j] = scaling_factor\n", | |
| "\n", | |
| " # 4. Scale the block\n", | |
| " scaled_block = block / scaling_factor\n", | |
| "\n", | |
| " # 5. Cast to FP4 (by finding closest value in LUT)\n", | |
| " for k in range(16):\n", | |
| " quantized_data[i, start_col + k] = find_closest_fp4_val(scaled_block[k])\n", | |
| "\n", | |
| " return quantized_data, scaling_factors\n", | |
| "\n", | |
| "# --- Test it ---\n", | |
| "sample_tensor = torch.randn((2, 32), dtype=torch.bfloat16, device='cuda')\n", | |
| "q_data_1d, scales_1d = quantize_1d_block_reference(sample_tensor)\n", | |
| "\n", | |
| "print(\"--- 1D Quantization Example ---\")\n", | |
| "print(f\"Original Tensor Shape: {sample_tensor.shape}\")\n", | |
| "print(f\"Quantized Data Shape: {q_data_1d.shape} (stores 4-bit integer indices)\")\n", | |
| "print(f\"Scaling Factors Shape: {scales_1d.shape}\")\n", | |
| "print(\"\\nFirst row's scaling factors:\")\n", | |
| "print(scales_1d[0])\n" | |
| ] | |
| }, | |
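| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "As a quick sanity check (our own addition for illustration, not code from the PR), we can dequantize: look up each 4-bit index in the LUT and multiply by its block's scaling factor. The reconstruction should be close to, but not exactly equal to, the original tensor.\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def dequantize_1d_block_reference(q_data: torch.Tensor, scales: torch.Tensor):\n", | |
| "    \"\"\"Illustrative helper: reconstructs an approximation of the original tensor.\"\"\"\n", | |
| "    # Look up the FP4 value for every stored 4-bit index...\n", | |
| "    values = FP4_E2M1_LUT.to(q_data.device)[q_data.long()]\n", | |
| "    # ...then re-apply the per-block scaling factor (one scale per 16 columns).\n", | |
| "    expanded_scales = scales.to(torch.float32).repeat_interleave(16, dim=1)\n", | |
| "    return values * expanded_scales\n", | |
| "\n", | |
| "# --- Round-trip check on the 1D example above ---\n", | |
| "reconstructed = dequantize_1d_block_reference(q_data_1d, scales_1d)\n", | |
| "abs_err = (reconstructed - sample_tensor.float()).abs()\n", | |
| "print(f\"Max reconstruction error: {abs_err.max().item():.4f}\")\n", | |
| "print(f\"Mean reconstruction error: {abs_err.mean().item():.4f}\")\n" | |
| ] | |
| }, | |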
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "lkvQ_1pyJAGT" | |
| }, | |
| "source": [ | |
| "## Step 3: Implementing 2D Block Quantization\n", | |
| "\n", | |
| "The PR enables 2D quantization for weights. The logic is similar, but the block size is now 16x16. There is only **one scaling factor for the entire 256-element block**. This is implemented in the reference function `quantize_nvfp4_2d` in `test_cast_nvfp4_transpose.cu`.\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "id": "pub7ckEjJAGU" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "def quantize_2d_block_reference(hp_tensor: torch.Tensor):\n", | |
| " \"\"\"\n", | |
| " Reference implementation for 2D block quantization (16x16 blocks).\n", | |
| " \"\"\"\n", | |
| " assert hp_tensor.dim() == 2, \"Input must be a 2D tensor\"\n", | |
| " rows, cols = hp_tensor.shape\n", | |
| " assert rows % 16 == 0 and cols % 16 == 0, \"Dimensions must be divisible by 16\"\n", | |
| "\n", | |
| " # Outputs\n", | |
| " num_blocks_y, num_blocks_x = rows // 16, cols // 16\n", | |
| " quantized_data = torch.zeros_like(hp_tensor, dtype=torch.int8)\n", | |
| " scaling_factors = torch.zeros(num_blocks_y, num_blocks_x, dtype=hp_tensor.dtype, device=hp_tensor.device)\n", | |
| "\n", | |
| " for i in range(num_blocks_y):\n", | |
| " for j in range(num_blocks_x):\n", | |
| " # 1. Get the 16x16 block\n", | |
| " start_row, end_row = i * 16, (i + 1) * 16\n", | |
| " start_col, end_col = j * 16, (j + 1) * 16\n", | |
| " block = hp_tensor[start_row:end_row, start_col:end_col]\n", | |
| "\n", | |
| " # 2. Find amax for the entire 16x16 block\n", | |
| " block_amax = torch.max(torch.abs(block))\n", | |
| " if block_amax == 0:\n", | |
| " scaling_factors[i, j] = 0.0\n", | |
| " continue\n", | |
| "\n", | |
| " # 3. Calculate scaling factor\n", | |
| " scaling_factor = block_amax / FP4_E2M1_MAX_VAL\n", | |
| " scaling_factors[i, j] = scaling_factor\n", | |
| "\n", | |
| " # 4. Scale the block\n", | |
| " scaled_block = block / scaling_factor\n", | |
| "\n", | |
| " # 5. Cast to FP4\n", | |
| " # (Vectorized version for simplicity)\n", | |
| " quantized_block = torch.zeros_like(scaled_block, dtype=torch.int8)\n", | |
| " for y in range(16):\n", | |
| " for x in range(16):\n", | |
| " quantized_block[y, x] = find_closest_fp4_val(scaled_block[y, x])\n", | |
| " quantized_data[start_row:end_row, start_col:end_col] = quantized_block\n", | |
| "\n", | |
| " return quantized_data, scaling_factors\n", | |
| "\n", | |
| "\n", | |
| "# --- Test it ---\n", | |
| "sample_tensor_2d = torch.randn((32, 64), dtype=torch.bfloat16, device='cuda')\n", | |
| "q_data_2d, scales_2d = quantize_2d_block_reference(sample_tensor_2d)\n", | |
| "\n", | |
| "print(\"--- 2D Quantization Example ---\")\n", | |
| "print(f\"Original Tensor Shape: {sample_tensor_2d.shape}\")\n", | |
| "print(f\"Quantized Data Shape: {q_data_2d.shape}\")\n", | |
| "print(f\"Scaling Factors Shape: {scales_2d.shape} (2x4 blocks of 16x16)\")\n", | |
| "print(\"\\nScaling factors for all 16x16 blocks:\")\n", | |
| "print(scales_2d)\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "QNnQIlU-JAGU" | |
| }, | |
| "source": [ | |
| "## Step 4: Implementing Random Hadamard Transform (RHT)\n", | |
| "\n", | |
| "RHT is a pre-processing step applied to activations before quantization. It's a matrix multiplication with a special \"Hadamard\" matrix. The goal is to distribute the information across the vector, making quantization less lossy. The PR adds highly optimized kernels for this (`hadamard_transform_cast_fusion.cu`).\n", | |
| "\n", | |
| "Our reference will build the matrix and apply it block-wise.\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "id": "yRIh9ugLJAGU" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "def get_hadamard_matrix(size, device):\n", | |
| " \"\"\"Constructs a Hadamard matrix of a power-of-two size.\"\"\"\n", | |
| " if size == 1:\n", | |
| " return torch.ones((1, 1), device=device)\n", | |
| " h_prev = get_hadamard_matrix(size // 2, device)\n", | |
| " h_next = torch.cat([\n", | |
| " torch.cat([h_prev, h_prev], dim=1),\n", | |
| " torch.cat([h_prev, -h_prev], dim=1),\n", | |
| " ], dim=0)\n", | |
| " return h_next\n", | |
| "\n", | |
| "def random_hadamard_transform_reference(hp_tensor: torch.Tensor):\n", | |
| " \"\"\"Applies a 16x16 RHT to the tensor block-wise.\"\"\"\n", | |
| " rows, cols = hp_tensor.shape\n", | |
| " assert cols % 16 == 0, \"Columns must be divisible by 16\"\n", | |
| "\n", | |
| " # The transform matrix includes normalization\n", | |
| " h_matrix = get_hadamard_matrix(16, hp_tensor.device).to(hp_tensor.dtype)\n", | |
| " h_matrix *= (1.0 / math.sqrt(16))\n", | |
| "\n", | |
| " transformed_tensor = torch.zeros_like(hp_tensor)\n", | |
| "\n", | |
| " for i in range(rows):\n", | |
| " for j in range(cols // 16):\n", | |
| " start_col, end_col = j * 16, (j + 1) * 16\n", | |
| " block = hp_tensor[i, start_col:end_col]\n", | |
| " # Apply the transform: block @ H\n", | |
| " transformed_block = torch.matmul(block, h_matrix)\n", | |
| " transformed_tensor[i, start_col:end_col] = transformed_block\n", | |
| "\n", | |
| " return transformed_tensor\n", | |
| "\n", | |
| "# --- Test it ---\n", | |
| "sample_tensor_rht = torch.randn((1, 32), dtype=torch.bfloat16, device='cuda')\n", | |
| "transformed_tensor = random_hadamard_transform_reference(sample_tensor_rht)\n", | |
| "\n", | |
| "print(\"--- RHT Example ---\")\n", | |
| "print(\"Original first 16 values:\\n\", sample_tensor_rht[0, :16])\n", | |
| "print(\"\\nTransformed first 16 values:\\n\", transformed_tensor[0, :16])\n", | |
| "print(f\"Shape remains the same: {transformed_tensor.shape}\")\n", | |
| "\n" | |
| ] | |
| }, | |
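| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "A quick property check (our own addition, not from the PR): the normalized Hadamard matrix is orthogonal and, for this Sylvester construction, symmetric. That means the transform preserves norms, is exactly invertible, and applying it twice recovers the original data (up to floating-point rounding). This is why the RHT can spread outlier energy across a block without losing information.\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Verify orthogonality: H_norm @ H_norm.T should be the identity matrix.\n", | |
| "h_norm = get_hadamard_matrix(16, 'cuda') * (1.0 / math.sqrt(16))\n", | |
| "is_orthogonal = torch.allclose(h_norm @ h_norm.T, torch.eye(16, device='cuda'), atol=1e-5)\n", | |
| "print(f\"Normalized Hadamard matrix is orthogonal: {is_orthogonal}\")\n", | |
| "\n", | |
| "# Because the matrix is orthogonal and symmetric, applying the transform twice\n", | |
| "# is the identity: we recover the original tensor.\n", | |
| "roundtrip = random_hadamard_transform_reference(\n", | |
| "    random_hadamard_transform_reference(sample_tensor_rht.float())\n", | |
| ")\n", | |
| "max_err = (roundtrip - sample_tensor_rht.float()).abs().max().item()\n", | |
| "print(f\"Max round-trip error after applying the transform twice: {max_err:.6f}\")\n" | |
| ] | |
| }, | |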
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "acAqwXOVJAGU" | |
| }, | |
| "source": [ | |
| "## Step 5: The Fused Operation - Putting It All Together\n", | |
| "\n", | |
| "The true power of the PR is fusing all these steps into a single, efficient CUDA kernel. The kernel performs:\n", | |
| "`Cast -> RHT (optional) -> Quantize -> Transpose -> Quantize (again for transposed layout)`\n", | |
| "\n", | |
| "This avoids materializing intermediate tensors in memory and is much faster. Let's create a Python function that orchestrates our reference components to simulate this entire pipeline. This mimics the `compute_ref` function in `test_cast_nvfp4_transpose.cu`.\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "id": "XntTaPY2JAGU" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "def nvfp4_recipe_reference(\n", | |
| " hp_tensor: torch.Tensor,\n", | |
| " use_rht: bool,\n", | |
| " use_2d_quant_for_weights: bool # In TE, this only applies to weights, but we simulate it here\n", | |
| "):\n", | |
| " \"\"\"\n", | |
| " Simulates the full, fused quantization pipeline.\n", | |
| " \"\"\"\n", | |
| " # --- Process the input for row-wise (activation) usage ---\n", | |
| " processed_tensor = random_hadamard_transform_reference(hp_tensor) if use_rht else hp_tensor\n", | |
| " # Always use 1D quantization for activations/row-wise data\n", | |
| " q_data, scales = quantize_1d_block_reference(processed_tensor)\n", | |
| "\n", | |
| " # --- Process the input for column-wise (weight) usage ---\n", | |
| " hp_tensor_t = hp_tensor.T.contiguous()\n", | |
| " if use_2d_quant_for_weights:\n", | |
| " # NOTE: Real implementation pads to 16x16 blocks. We'll assume divisible dimensions.\n", | |
| " q_data_t, scales_t = quantize_2d_block_reference(hp_tensor_t)\n", | |
| " else:\n", | |
| " q_data_t, scales_t = quantize_1d_block_reference(hp_tensor_t)\n", | |
| "\n", | |
| " print(\"Simulated fused operation successful!\")\n", | |
| " return q_data, scales, q_data_t, scales_t\n", | |
| "\n", | |
| "# --- Test it with a realistic shape ---\n", | |
| "activation_tensor = torch.randn((128, 2048), dtype=torch.bfloat16, device='cuda')\n", | |
| "\n", | |
| "q_activation, scales_activation, q_weight, scales_weight = nvfp4_recipe_reference(\n", | |
| " activation_tensor,\n", | |
| " use_rht=True,\n", | |
| " use_2d_quant_for_weights=True\n", | |
| ")\n", | |
| "\n", | |
| "print(\"\\n--- Outputs of the Fused Pipeline ---\")\n", | |
| "print(f\"Quantized Activation Shape: {q_activation.shape}\")\n", | |
| "print(f\"Activation Scales Shape: {scales_activation.shape}\")\n", | |
| "print(f\"Quantized Weight (Transposed) Shape: {q_weight.shape}\")\n", | |
| "print(f\"Weight Scales (Transposed) Shape: {scales_weight.shape}\")\n", | |
| "\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "lYgXcWRMJAGU" | |
| }, | |
| "source": [ | |
| "## Step 6: The `NVFP4Tensor` Data Structure\n", | |
| "\n", | |
| "Finally, why does the PR introduce a new `NVFP4Tensor` class in Python?\n", | |
| "\n", | |
| "Because the results of the fused operation (`q_data`, `scales`, `q_data_t`, `scales_t`) all belong together. They represent a single high-precision tensor in its quantized form. The `NVFP4Tensor` acts as a container for all these components.\n", | |
| "\n", | |
| "When a TE layer needs the tensor for a forward pass GEMM (activations), it uses `q_data` and `scales`. When it needs the tensor for a wgrad GEMM (weights), it uses `q_data_t` and `scales_t`. This avoids costly re-quantization or transposing of packed 4-bit data on the fly.\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "id": "NKYnyH0MJAGU" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "from dataclasses import dataclass\n", | |
| "\n", | |
| "@dataclass\n", | |
| "class NVFP4TensorReference:\n", | |
| " \"\"\"A Python dataclass to represent the real NVFP4Tensor structure.\"\"\"\n", | |
| " _rowwise_data: torch.Tensor\n", | |
| " _rowwise_scale_inv: torch.Tensor\n", | |
| " _columnwise_data: torch.Tensor\n", | |
| " _columnwise_scale_inv: torch.Tensor\n", | |
| " original_shape: tuple\n", | |
| "\n", | |
| "# Let's package our results into this structure\n", | |
| "nvfp4_tensor_ref = NVFP4TensorReference(\n", | |
| " _rowwise_data=q_activation,\n", | |
| " _rowwise_scale_inv=scales_activation,\n", | |
| " _columnwise_data=q_weight,\n", | |
| " _columnwise_scale_inv=scales_weight,\n", | |
| " original_shape=activation_tensor.shape\n", | |
| ")\n", | |
| "\n", | |
| "print(\"Representation of a complete NVFP4Tensor object:\")\n", | |
| "print(nvfp4_tensor_ref)\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "DEJS6bXUJAGU" | |
| }, | |
| "source": [ | |
| "## Conclusion\n", | |
| "\n", | |
| "You have now implemented the core algorithmic building blocks of the NVFP4 recipe from scratch.\n", | |
| "\n", | |
| "You've learned that the implementation is not just a simple cast, but a sophisticated, fused pipeline that involves:\n", | |
| "1. **Block-based Scaling**: Calculating per-block scaling factors (either 1D or 2D).\n", | |
| "2. **Optional Pre-processing (RHT)**: Applying a mathematical transform to improve numerical stability.\n", | |
| "3. **Fused Operations**: Performing quantization and transposition in a single step to generate layouts for both forward and backward passes efficiently.\n", | |
| "4. **A Specialized Data Structure**: Using `NVFP4Tensor` to hold all the necessary components (data, scales, transposed versions) together.\n", | |
| "\n", | |
| "The actual C++/CUDA code in the PR takes these exact algorithms and implements them with extreme performance optimizations, using techniques like shared memory, tensor core instructions, and careful data movement to make 4-bit training feasible at scale.\n", | |
| "\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "niTGw1h_JAGU" | |
| }, | |
| "source": [ | |
| "# Advanced Lessons: Implementing the NVFP4 Recipe\n", | |
| "\n", | |
| "Welcome to the advanced implementation tutorial for the NVFP4 recipe. In the previous session, we built high-level Python models of the core algorithms. Now, we will dissect the engineering principles and low-level details from the PR to understand how this is implemented for maximum performance on a GPU.\n", | |
| "\n", | |
| "### Learning Path:\n", | |
| "* **Lesson 1: The \"Why\" of Fused Kernels** - Why not just call the Python functions in sequence?\n", | |
| "* **Lesson 2: Anatomy of the CUDA Kernel** - A conceptual breakdown of the C++ `block_scaled_1d_cast_transpose_kernel`.\n", | |
| "* **Lesson 3: The Nuances of Two-Level Scaling** - Understanding the global (`S_enc`) and local (`S_dec_b`) scaling factors.\n", | |
| "* **Lesson 4: Distributed Training & Quantized All-Gather** - How to handle custom data types in a multi-GPU setting.\n", | |
| "* **Lesson 5: The Python API Glue** - How the `NVFP4Quantizer` class orchestrates everything.\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "id": "KyXORSe0JAGU" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "import torch\n", | |
| "import math\n", | |
| "\n", | |
| "# For reproducibility\n", | |
| "torch.manual_seed(0)\n", | |
| "torch.cuda.manual_seed(0)\n", | |
| "\n", | |
| "# Constants from the previous lesson\n", | |
| "FP4_E2M1_MAX_VAL = 6.0\n", | |
| "# A new constant from the PR: the max value of an FP8 E4M3 number, used for scaling factors.\n", | |
| "FP8_E4M3_MAX_VAL = 448.0\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "u4AW5GXWJAGU" | |
| }, | |
| "source": [ | |
| "## Lesson 1: The \"Why\" of Fused Kernels - The Memory Bottleneck\n", | |
| "\n", | |
| "In our previous tutorial, we implemented each step (RHT, Quantize, Transpose) as a separate Python function. On a real GPU, this would be incredibly inefficient. Why? **Memory Bandwidth**.\n", | |
| "\n", | |
| "A GPU is fastest when it's doing math (computing). It's relatively slow when it's moving data between its main memory (HBM) and its compute cores. Operations like ours are often **memory-bound**, meaning the GPU spends more time waiting for data than computing on it.\n", | |
| "\n", | |
| "Consider the \"naive\" approach:\n", | |
| "1. `hp_tensor` is in Global Memory.\n", | |
| "2. **Kernel 1 (RHT)**: Load `hp_tensor`, compute RHT, write `rht_tensor` back to Global Memory.\n", | |
| "3. **Kernel 2 (Amax)**: Load `rht_tensor`, compute amax, write `amax_tensor` back to Global Memory.\n", | |
| "4. **Kernel 3 (Quantize)**: Load `rht_tensor` and `amax_tensor`, compute scales and quantized data, write `q_tensor` and `scales_tensor` to Global Memory.\n", | |
| "5. ...and so on for the transpose.\n", | |
| "\n", | |
| "This involves multiple round-trips to slow global memory. A **fused kernel**, like the one in this PR (`quantize_transpose_vector_blockwise_fp4.cu`), does all of this in a single trip.\n", | |
| "\n", | |
| "### The Fused Kernel Strategy:\n", | |
| "1. **Launch ONE Kernel.**\n", | |
| "2. Threads load a small tile of the `hp_tensor` from Global Memory into ultra-fast **Shared Memory**.\n", | |
| "3. Perform all operations (RHT, amax reduction, scaling, casting) directly on the data in Shared Memory.\n", | |
| "4. Write the final, tiny outputs (`q_tensor` tile, `scales_tensor` tile) back to Global Memory.\n", | |
| "\n", | |
| "This minimizes global memory traffic and maximizes computation, leading to massive speedups. The entire PR is built around this principle.\n" | |
| ] | |
| }, | |
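| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "To make the memory-bandwidth argument concrete, here is a rough back-of-the-envelope count of global-memory traffic (our own illustrative numbers, ignoring the small scale tensors and any caching) for a 128x2048 BF16 tensor.\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Illustrative arithmetic only: compare global-memory bytes moved by a naive\n", | |
| "# multi-kernel pipeline vs. a single fused kernel.\n", | |
| "rows, cols = 128, 2048\n", | |
| "n = rows * cols\n", | |
| "bf16_bytes, fp4_bytes = 2, 0.5  # packed FP4 stores two values per byte\n", | |
| "\n", | |
| "# Naive: RHT (read + write BF16), amax (read BF16), quantize (read BF16,\n", | |
| "# write two packed FP4 layouts).\n", | |
| "naive_traffic = 2 * n * bf16_bytes + n * bf16_bytes + n * bf16_bytes + 2 * n * fp4_bytes\n", | |
| "# Fused: read the BF16 input once, write the two packed FP4 outputs once.\n", | |
| "fused_traffic = n * bf16_bytes + 2 * n * fp4_bytes\n", | |
| "\n", | |
| "print(f\"Naive pipeline traffic: ~{naive_traffic / 1e6:.2f} MB\")\n", | |
| "print(f\"Fused kernel traffic:   ~{fused_traffic / 1e6:.2f} MB\")\n", | |
| "print(f\"Reduction factor:       ~{naive_traffic / fused_traffic:.1f}x\")\n" | |
| ] | |
| }, | |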
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "QElLt6c8JAGU" | |
| }, | |
| "source": [ | |
| "## Lesson 2: Anatomy of a Fused CUDA Kernel\n", | |
| "\n", | |
| "Let's write a \"pseudo-code\" walkthrough of the main kernel. We can't run CUDA C++ here, but we can model its logic and structure in Python comments to understand how it works. We'll focus on the `block_scaled_1d_cast_transpose_kernel` logic from the new C++ tests.\n", | |
| "\n", | |
| "A CUDA kernel is executed by a grid of *thread blocks*. Each block is responsible for processing one \"tile\" of the input data. Inside a block, threads cooperate using **Shared Memory**.\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "id": "32ixAc-zJAGU" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "def conceptual_fused_kernel(hp_tensor):\n", | |
| " \"\"\"A Python simulation of the fused kernel's logic for a single 16x16 tile.\"\"\"\n", | |
| " # --- Kernel Launch Setup (Done by the CUDA runtime) ---\n", | |
| " # Imagine this function is ONE thread block, given an index (blockIdx.x, blockIdx.y)\n", | |
| " # to identify which 16x16 tile of the hp_tensor it should process.\n", | |
| " # Let's assume this block is responsible for the tile starting at (0, 0).\n", | |
| " TILE_DIM = 16\n", | |
| " block_start_row, block_start_col = 0, 0\n", | |
| "\n", | |
| " # --- Inside the Kernel (Execution on GPU) ---\n", | |
| "\n", | |
| " # 1. Cooperative Loading into Shared Memory\n", | |
| " # Each of the 256 threads in the block loads one element from global HBM\n", | |
| " # into the fast, on-chip shared memory scratchpad.\n", | |
| " shared_mem_tile = hp_tensor[\n", | |
| " block_start_row : block_start_row + TILE_DIM,\n", | |
| " block_start_col : block_start_col + TILE_DIM\n", | |
| " ].clone() # .clone() simulates the copy to a new memory space.\n", | |
| " # In CUDA, a `__syncthreads()` barrier would wait for all loads to complete.\n", | |
| "\n", | |
| " # 2. On-Chip AMAX Reduction (Row-wise)\n", | |
| " # The threads now work on the fast shared memory tile.\n", | |
| " # They cooperatively find the amax for each of the 16 rows in the tile.\n", | |
| " row_amaxes = torch.max(torch.abs(shared_mem_tile), dim=1).values\n", | |
| " # This is a simplified view. In CUDA, this is a multi-step reduction using\n", | |
| " # warp-level primitives (`__shfl_down_sync`) and another `__syncthreads()`.\n", | |
| "\n", | |
| " # 3. Calculate Row-wise Scaling Factors\n", | |
| " row_scales = row_amaxes / FP4_E2M1_MAX_VAL\n", | |
| " # Handle division by zero for all-zero rows\n", | |
| " row_scales[row_scales == 0] = 1.0\n", | |
| "\n", | |
| " # 4. Scale and Cast (Row-wise)\n", | |
| " # Each thread scales its value and simulates the cast.\n", | |
| " # The actual CUDA kernel uses a PTX instruction like `cvt.rn.satfinite.e2m1x2.f32`\n", | |
| " # which converts two FP32 numbers to two packed FP4 numbers in one go.\n", | |
| " scaled_tile = shared_mem_tile / row_scales.unsqueeze(1)\n", | |
| " quantized_tile = torch.round(scaled_tile).clamp(-FP4_E2M1_MAX_VAL, FP4_E2M1_MAX_VAL) # Simplified cast logic\n", | |
| "\n", | |
| " # 5. On-Chip Transposition\n", | |
| " # Threads cooperatively write to a second shared memory buffer in a transposed pattern.\n", | |
| " transposed_shared_mem_tile = shared_mem_tile.T.contiguous()\n", | |
| " # `__syncthreads()` ensures the transpose is complete.\n", | |
| "\n", | |
| " # 6. AMAX, Scale, and Cast (Column-wise / Transposed)\n", | |
| " # The process is repeated on the transposed tile to get the column-wise outputs.\n", | |
| " col_amaxes = torch.max(torch.abs(transposed_shared_mem_tile), dim=1).values\n", | |
| " col_scales = col_amaxes / FP4_E2M1_MAX_VAL\n", | |
| " col_scales[col_scales == 0] = 1.0\n", | |
| " scaled_transposed_tile = transposed_shared_mem_tile / col_scales.unsqueeze(1)\n", | |
| " quantized_transposed_tile = torch.round(scaled_transposed_tile).clamp(-FP4_E2M1_MAX_VAL, FP4_E2M1_MAX_VAL)\n", | |
| "\n", | |
| " # 7. Write Final Results to Global Memory\n", | |
| " # The threads write their final results from shared memory back to the final output tensors in HBM.\n", | |
| " # This is the only other time they touch global memory.\n", | |
| " print(\"Conceptual kernel finished processing one tile.\")\n", | |
| " return quantized_tile, row_scales, quantized_transposed_tile, col_scales\n", | |
| "\n", | |
| "# --- Run the conceptual model ---\n", | |
| "sample_tile = torch.randn((16, 16), dtype=torch.float32, device='cuda')\n", | |
| "q_data, scales, q_data_t, scales_t = conceptual_fused_kernel(sample_tile)\n", | |
| "\n", | |
| "print(f\"\\nRow-wise quantized data shape: {q_data.shape}\")\n", | |
| "print(f\"Row-wise scales shape: {scales.shape} (One scale per row in the tile)\")\n", | |
| "print(f\"Column-wise quantized data shape: {q_data_t.shape}\")\n", | |
| "print(f\"Column-wise scales shape: {scales_t.shape} (One scale per column in the tile)\")\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "2wdwxO9IJAGV" | |
| }, | |
| "source": [ | |
| "## Lesson 3: The Nuances of Two-Level Scaling\n", | |
| "\n", | |
| "The previous lessons used a simplified scaling formula: `scale = amax / 6.0`. The actual implementation in the PR is more sophisticated, as seen in the C++ function `compute_global_encode_scaling_factor_FP4`. It uses a **two-level scaling system**.\n", | |
| "\n", | |
| "1. **Global Per-Tensor Scale (`S_enc`)**: A single FP32 scale factor is computed for the *entire tensor*. Its job is to map the tensor's global amax into a range that is friendly to FP8-E4M3, the format used for the *scaling factors themselves*.\n", | |
| "\n", | |
| "2. **Local Per-Block Scale (`S_dec_b`)**: This is the scale we've been calculating (`block_amax / 6.0`). It handles local variations.\n", | |
| "\n", | |
| "**The final scaling factor stored in memory is `S_final = S_dec_b * S_enc`**.\n", | |
| "\n", | |
| "Why do this? It improves numerical precision. By pre-scaling the entire tensor with `S_enc`, we ensure that the per-block `S_dec_b` values can be accurately represented by the FP8-E4M3 format.\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "id": "qpUmg3QzJAGV" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "def two_level_scaling_reference(hp_tensor: torch.Tensor):\n", | |
| " \"\"\"Reference implementation for the two-level scaling logic.\"\"\"\n", | |
| " # -- Level 1: Global Scaling --\n", | |
| " global_amax = torch.max(torch.abs(hp_tensor))\n", | |
| "\n", | |
| " # This formula is a direct translation of the C++ `compute_global_encode_scaling_factor_FP4`\n", | |
| " # It maps the global amax to the dynamic range of FP8 * FP4\n", | |
| " if global_amax == 0.0:\n", | |
| " S_enc = 1.0\n", | |
| " else:\n", | |
| " S_enc = (FP8_E4M3_MAX_VAL * FP4_E2M1_MAX_VAL) / global_amax\n", | |
| " S_enc = min(S_enc, torch.finfo(torch.float32).max) # Clamp to max float32\n", | |
| "\n", | |
| " # -- Level 2: Local Scaling (within a 1D block) --\n", | |
| " rows, cols = hp_tensor.shape\n", | |
| " num_scale_blocks = cols // 16\n", | |
| " final_scales = torch.zeros(rows, num_scale_blocks, dtype=torch.float32, device=hp_tensor.device)\n", | |
| "\n", | |
| " for i in range(rows):\n", | |
| " for j in range(num_scale_blocks):\n", | |
| " block = hp_tensor[i, j*16:(j+1)*16]\n", | |
| " block_amax = torch.max(torch.abs(block))\n", | |
| "\n", | |
| " # Calculate the local decoding scale\n", | |
| " S_dec_b = block_amax / FP4_E2M1_MAX_VAL\n", | |
| "\n", | |
| " # Combine with global encoding scale to get the final scale\n", | |
| " S_final = S_dec_b * S_enc\n", | |
| "\n", | |
| " # The final scale is then cast to FP8 E4M3 for storage.\n", | |
| " # We will just store it as float32 for this reference.\n", | |
| " final_scales[i, j] = S_final\n", | |
| "\n", | |
| " print(f\"Global Amax: {global_amax:.4f}, S_enc (Global Scale): {S_enc:.4f}\")\n", | |
| " return final_scales\n", | |
| "\n", | |
| "# --- Test the two-level scaling ---\n", | |
| "sample_tensor = torch.randn((2, 32), device='cuda') * 10 # Scale up to see a more interesting amax\n", | |
| "final_scaling_factors = two_level_scaling_reference(sample_tensor)\n", | |
| "print(\"\\nFinal (two-level) scaling factors for the first row:\")\n", | |
| "print(final_scaling_factors[0])\n" | |
| ] | |
| }, | |
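| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Why does this keep the scales representable? Substituting the definitions, `S_final = (block_amax / 6) * (448 * 6 / global_amax) = 448 * block_amax / global_amax`, which is at most 448 (reached by the block that contains the global amax). So every per-block scale lands inside the FP8 E4M3 range. A quick check on our example (our own addition):\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# With two-level scaling, no per-block scale should exceed the FP8 E4M3 max.\n", | |
| "max_scale = final_scaling_factors.max().item()\n", | |
| "print(f\"Largest per-block scale: {max_scale:.4f} (FP8 E4M3 max: {FP8_E4M3_MAX_VAL})\")\n", | |
| "# Allow a tiny tolerance for float32 rounding in the reference computation.\n", | |
| "assert max_scale <= FP8_E4M3_MAX_VAL * 1.001, \"Scales should fit the FP8 E4M3 range\"\n" | |
| ] | |
| }, | |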
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "O9cCbq4WJAGV" | |
| }, | |
| "source": [ | |
| "## Lesson 4: Distributed Training & Quantized All-Gather\n", | |
| "\n", | |
| "Making a new feature work on one GPU is only half the battle. For large models, it must work with tensor parallelism across multiple GPUs. This PR adds a custom `_all_gather_nvfp4` function in `transformer_engine/pytorch/distributed.py`.\n", | |
| "\n", | |
| "**The Problem**: You can't just call `torch.distributed.all_gather` on an `NVFP4Tensor` object. The All-Gather operation only works on single, contiguous `torch.Tensor`s.\n", | |
| "\n", | |
| "**The Solution**:\n", | |
| "1. Deconstruct the `NVFP4Tensor` on each GPU into its constituent `torch.Tensor` components (e.g., `_rowwise_data`, `_rowwise_scale_inv`).\n", | |
| "2. Perform a separate `all_gather` operation on each component tensor.\n", | |
| "3. Reconstruct a new, larger `NVFP4Tensor` on each GPU from the gathered components.\n", | |
| "\n", | |
| "**A New Problem (The \"Interleave\" Issue)**: When you gather a *transposed* tensor (like `_columnwise_data`) along the batch dimension, the data from different GPUs gets interleaved incorrectly.\n", | |
| "\n", | |
| "Imagine 2 GPUs. GPU0 has `[A0, B0]` and GPU1 has `[A1, B1]`. After gathering, the memory layout isn't `[A0, B0, A1, B1]`. It becomes something like `[A0, A1, B0, B1]`.\n", | |
| "\n", | |
| "To fix this, the PR adds a `swap_first_dims` operation. Let's simulate this.\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "id": "wUWKQa8QJAGV" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "def simulate_distributed_gather_and_fix():\n", | |
| " world_size = 4 # Simulate 4 GPUs\n", | |
| " local_dim0, local_dim1 = 2, 8\n", | |
| "\n", | |
| " # Create dummy transposed data on each GPU\n", | |
| " gpu_data = [torch.arange(local_dim0 * local_dim1, dtype=torch.float32).reshape(local_dim0, local_dim1) + (i*100) for i in range(world_size)]\n", | |
| " print(f\"--- Data on GPU 0 (Transposed Layout) ---\\n{gpu_data[0]}\")\n", | |
| "\n", | |
| " # Simulate `all_gather` on the first dimension. This creates the interleaved result.\n", | |
| " interleaved_data = torch.cat(gpu_data, dim=0)\n", | |
| " print(f\"\\n--- Interleaved Data After All-Gather ---\\n{interleaved_data}\")\n", | |
| "\n", | |
| " # The `swap_first_dims` logic to fix the layout\n", | |
| " # This is what `tex.swap_first_dims` in the PR does in a highly optimized way.\n", | |
| " total_dim0 = interleaved_data.shape[0]\n", | |
| " fixed_data = interleaved_data.reshape(world_size, total_dim0 // world_size, -1).transpose(0, 1).reshape(total_dim0, -1)\n", | |
| "\n", | |
| " print(f\"\\n--- Data After `swap_first_dims` Fix ---\\n{fixed_data}\")\n", | |
| "\n", | |
| "simulate_distributed_gather_and_fix()\n" | |
| ] | |
| }, | |
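| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "For completeness, here is a minimal single-process sketch of the deconstruct -> gather-each-component -> reconstruct pattern described above. This is our own illustration built on the `NVFP4TensorReference` dataclass from earlier, not the PR's `_all_gather_nvfp4`: `torch.cat` stands in for `torch.distributed.all_gather`, and the column-wise components would additionally need the `swap_first_dims` fix shown above.\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def simulated_all_gather_nvfp4(shards):\n", | |
| "    \"\"\"Sketch: gather each component tensor, then rebuild one larger container.\"\"\"\n", | |
| "    gathered = NVFP4TensorReference(\n", | |
| "        _rowwise_data=torch.cat([s._rowwise_data for s in shards], dim=0),\n", | |
| "        _rowwise_scale_inv=torch.cat([s._rowwise_scale_inv for s in shards], dim=0),\n", | |
| "        # Column-wise pieces are gathered the same way, but in the real code they\n", | |
| "        # also go through the swap_first_dims fix to restore the correct layout.\n", | |
| "        _columnwise_data=torch.cat([s._columnwise_data for s in shards], dim=0),\n", | |
| "        _columnwise_scale_inv=torch.cat([s._columnwise_scale_inv for s in shards], dim=0),\n", | |
| "        original_shape=(sum(s.original_shape[0] for s in shards), shards[0].original_shape[1]),\n", | |
| "    )\n", | |
| "    return gathered\n", | |
| "\n", | |
| "# Pretend the same shard lives on two \"GPUs\"\n", | |
| "gathered_ref = simulated_all_gather_nvfp4([nvfp4_tensor_ref, nvfp4_tensor_ref])\n", | |
| "print(f\"Gathered row-wise data shape: {gathered_ref._rowwise_data.shape}\")\n", | |
| "print(f\"Gathered logical shape: {gathered_ref.original_shape}\")\n" | |
| ] | |
| }, | |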
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "6CBe328SJAGV" | |
| }, | |
| "source": [ | |
| "## Lesson 5: The Python API Glue - `NVFP4Quantizer`\n", | |
| "\n", | |
| "The `NVFP4Quantizer` class in `transformer_engine/pytorch/tensor/nvfp4_tensor.py` is the high-level orchestrator. It's the bridge between the Python world and the C++/CUDA backend.\n", | |
| "\n", | |
| "Let's break down its key responsibilities based on the PR:\n", | |
| "\n", | |
| "1. **Configuration (`__init__`)**: It reads the `Recipe` object and stores flags like `with_rht`, `stochastic_rounding`, and `with_2d_quantization`. It also pre-builds the RHT matrix if needed.\n", | |
| "\n", | |
| "2. **State Management**: It holds stateful information. For example, it generates and stores the random sign mask for the RHT matrix.\n", | |
| "\n", | |
| "3. **Backend Invocation (`quantize`)**: This is the main method. It takes a high-precision `torch.Tensor` as input.\n", | |
| " * It checks the tensor shape and properties.\n", | |
| " * It packages all the configuration flags and tensor pointers into a C-compatible structure (`QuantizationConfigWrapper`).\n", | |
| " * It calls the core C++ function (e.g., `tex.quantize_fp4`) through the Pybind11 bridge. This is the function that launches the fused CUDA kernel we discussed in Lesson 2.\n", | |
| "\n", | |
| "4. **Object Creation**: The C++ function returns raw tensor data. The `NVFP4Quantizer` takes this raw data and uses it to construct and return a proper, user-friendly `NVFP4Tensor` Python object.\n", | |
| "\n", | |
| "This class design cleanly separates the high-level configuration and object management in Python from the low-level, high-performance computations in C++/CUDA.\n" | |
| ] | |
| }, | |
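| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "To make that division of labor concrete, here is a rough sketch of such an orchestrator built from our Python reference functions. The class name, its arguments, and its internals are our own simplification for illustration; the real `NVFP4Quantizer` calls the fused C++/CUDA backend instead of Python loops.\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "class NVFP4QuantizerSketch:\n", | |
| "    \"\"\"Illustrative stand-in for the real NVFP4Quantizer (simplified).\"\"\"\n", | |
| "\n", | |
| "    def __init__(self, with_rht: bool = True, with_2d_quantization: bool = True):\n", | |
| "        # 1. Configuration: store the recipe flags.\n", | |
| "        #    (The real class also prepares state such as the RHT sign mask here.)\n", | |
| "        self.with_rht = with_rht\n", | |
| "        self.with_2d_quantization = with_2d_quantization\n", | |
| "\n", | |
| "    def quantize(self, hp_tensor: torch.Tensor) -> NVFP4TensorReference:\n", | |
| "        # 2. Backend invocation: the real class calls a single fused C++/CUDA op;\n", | |
| "        #    here we call the Python reference pipeline from the earlier lessons.\n", | |
| "        q_data, scales, q_data_t, scales_t = nvfp4_recipe_reference(\n", | |
| "            hp_tensor,\n", | |
| "            use_rht=self.with_rht,\n", | |
| "            use_2d_quant_for_weights=self.with_2d_quantization,\n", | |
| "        )\n", | |
| "        # 3. Object creation: package the raw outputs into the tensor container.\n", | |
| "        return NVFP4TensorReference(\n", | |
| "            _rowwise_data=q_data,\n", | |
| "            _rowwise_scale_inv=scales,\n", | |
| "            _columnwise_data=q_data_t,\n", | |
| "            _columnwise_scale_inv=scales_t,\n", | |
| "            original_shape=hp_tensor.shape,\n", | |
| "        )\n", | |
| "\n", | |
| "quantizer = NVFP4QuantizerSketch(with_rht=True, with_2d_quantization=True)\n", | |
| "nvfp4_out = quantizer.quantize(torch.randn((32, 64), dtype=torch.bfloat16, device='cuda'))\n", | |
| "print(f\"Row-wise quantized data shape: {nvfp4_out._rowwise_data.shape}\")\n" | |
| ] | |
| }, | |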
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "e7jPBHUNJAGV" | |
| }, | |
| "source": [ | |
| "## Grand Conclusion\n", | |
| "\n", | |
| "You have now journeyed from a high-level user of the NVFP4 recipe to understanding the deepest implementation details. You've learned:\n", | |
| "\n", | |
| "- **Performance is King**: Fused kernels are essential to overcome memory bandwidth limitations, which is the primary motivation for the C++/CUDA implementation.\n", | |
| "- **CUDA Programming Patterns**: Thread blocks, shared memory, and cooperative execution are the tools used to build these fused kernels.\n", | |
| "- **Numerical Precision Matters**: The two-level scaling system is a clever trick to maintain accuracy when the scaling factors themselves must be stored in a low-precision format.\n", | |
| "- **Distributed Systems are Complex**: Features must be designed with multi-GPU execution in mind, often requiring custom communication patterns like the fix for interleaved gathering.\n", | |
| "- **APIs are Abstractions**: The Python `NVFP4Quantizer` class provides a clean interface that hides the immense complexity of the underlying C++/CUDA/distributed logic.\n", | |
| "\n", | |
| "You are now well-equipped to read through the files in PR #2177, such as `quantize_transpose_vector_blockwise_fp4.cu` and `distributed.py`, and recognize the patterns and algorithms we've discussed here." | |
| ] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "Python 3", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.8.0" | |
| }, | |
| "colab": { | |
| "provenance": [], | |
| "name": "Implementing the NVFP4 Recipe From Scratch.ipynb", | |
| "include_colab_link": true | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 0 | |
| } |