{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/vukrosic/2c0117344dd269263adf0b6e5382889f/excercise.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "AbCTKh0GJAGS"
},
"source": [
"# Implementing the NVFP4 Recipe From Scratch: A Developer's Tutorial\n",
"\n",
"This tutorial deconstructs the core algorithms from PR [#2177](https://github.com/NVIDIA/TransformerEngine/pull/2177/files) to teach you how to implement them conceptually. We will build Python/PyTorch reference functions that mirror the logic of the new C++/CUDA kernels.\n",
"\n",
"Our goal is to implement these key components:\n",
"\n",
"1. **Core 1D Block Quantization**: The fundamental scaling and casting logic for 1x16 blocks.\n",
"2. **2D Block Quantization**: An extension for quantizing 16x16 blocks, ideal for weights.\n",
"3. **Random Hadamard Transform (RHT)**: The pre-quantization step to improve accuracy.\n",
"4. **The Fused Operation**: Combining everything to produce the final `NVFP4Tensor` components.\n",
"\n",
"We will focus on the *algorithmic logic*, not CUDA-level performance optimizations.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "IuWzIXMwJAGT"
},
"outputs": [],
"source": [
"import torch\n",
"import math\n",
"\n",
"# For reproducibility\n",
"torch.manual_seed(0)\n",
"torch.cuda.manual_seed(0)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "P3w70_L5JAGT"
},
"source": [
"## Step 1: Understanding the Target - The NVFP4 E2M1 Format\n",
"\n",
"Before we can quantize, we need to know what we're converting *to*. NVFP4 in this PR uses the `E2M1` format (2 exponent bits, 1 mantissa bit). It's a 4-bit floating-point number. We can represent all possible 16 values in a lookup table (LUT). This helps us simulate the casting process.\n",
"\n",
"The C++ code uses native `__nv_fp4_e2m1` types, but this LUT is perfect for a Python reference.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "uEdE3LqzJAGT"
},
"outputs": [],
"source": [
"# The 16 possible values for an E2M1 FP4 number.\n",
"# Index corresponds to the 4-bit integer representation.\n",
"FP4_E2M1_LUT = torch.tensor([\n",
" # Positive values (first bit 0)\n",
" 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,\n",
" # Negative values (first bit 1)\n",
" -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0,\n",
"], dtype=torch.float32)\n",
"\n",
"# The maximum absolute value for E2M1 is 6.0. This is a critical constant.\n",
"FP4_E2M1_MAX_VAL = 6.0\n",
"\n",
"def find_closest_fp4_val(value):\n",
" \"\"\"Simulates casting a float to the nearest FP4 value.\"\"\"\n",
" # Find the value in our LUT that is closest to the input value.\n",
" # The index of this closest value is our 4-bit representation.\n",
" return torch.argmin(torch.abs(value - FP4_E2M1_LUT.to(value.device)))\n",
"\n",
"print(f\"FP4 E2M1 Lookup Table:\\n{FP4_E2M1_LUT}\")\n",
"print(f\"\\nExample: Casting 2.9 to FP4 -> finds value {FP4_E2M1_LUT[find_closest_fp4_val(torch.tensor(2.9))]}\")\n",
"print(f\"Example: Casting -4.2 to FP4 -> finds value {FP4_E2M1_LUT[find_closest_fp4_val(torch.tensor(-4.2))]}\")\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LDtc0nIsJAGT"
},
"source": [
"## Step 2: Implementing 1D Block Quantization\n",
"\n",
"This is the core logic. For each 1D block of 16 elements in a tensor row, we perform these steps. This logic is what the reference implementation `quantize_nvfp4_1d` in `test_cast_nvfp4_transpose.cu` performs.\n",
"\n",
"1. Find the absolute maximum value (`amax`) in the 16-element block.\n",
"2. Calculate a `scaling_factor` for this block. The formula is `amax / FP4_E2M1_MAX_VAL`.\n",
"3. **Scale** the original 16 values by dividing by the `scaling_factor`.\n",
"4. **Cast** the scaled values to the nearest FP4 value.\n",
"5. Store the resulting 16 4-bit integers and the single `scaling_factor`.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "r0xhEcdVJAGT"
},
"outputs": [],
"source": [
"def quantize_1d_block_reference(hp_tensor: torch.Tensor):\n",
" \"\"\"\n",
" Reference implementation for 1D block quantization (1x16 blocks).\n",
" \"\"\"\n",
" assert hp_tensor.dim() == 2, \"Input must be a 2D tensor\"\n",
" rows, cols = hp_tensor.shape\n",
" assert cols % 16 == 0, \"Columns must be divisible by 16\"\n",
"\n",
" # Outputs\n",
" num_scale_blocks = cols // 16\n",
" quantized_data = torch.zeros(rows, cols, dtype=torch.int8, device=hp_tensor.device)\n",
" scaling_factors = torch.zeros(rows, num_scale_blocks, dtype=hp_tensor.dtype, device=hp_tensor.device)\n",
"\n",
" for i in range(rows):\n",
" for j in range(num_scale_blocks):\n",
" # 1. Get the 1x16 block\n",
" start_col, end_col = j * 16, (j + 1) * 16\n",
" block = hp_tensor[i, start_col:end_col]\n",
"\n",
" # 2. Find amax\n",
" block_amax = torch.max(torch.abs(block))\n",
" if block_amax == 0: # Handle all-zero blocks\n",
" scaling_factors[i, j] = 0.0\n",
" # Quantized data is already 0\n",
" continue\n",
"\n",
" # 3. Calculate scaling factor\n",
" scaling_factor = block_amax / FP4_E2M1_MAX_VAL\n",
" scaling_factors[i, j] = scaling_factor\n",
"\n",
" # 4. Scale the block\n",
" scaled_block = block / scaling_factor\n",
"\n",
" # 5. Cast to FP4 (by finding closest value in LUT)\n",
" for k in range(16):\n",
" quantized_data[i, start_col + k] = find_closest_fp4_val(scaled_block[k])\n",
"\n",
" return quantized_data, scaling_factors\n",
"\n",
"# --- Test it ---\n",
"sample_tensor = torch.randn((2, 32), dtype=torch.bfloat16, device='cuda')\n",
"q_data_1d, scales_1d = quantize_1d_block_reference(sample_tensor)\n",
"\n",
"print(\"--- 1D Quantization Example ---\")\n",
"print(f\"Original Tensor Shape: {sample_tensor.shape}\")\n",
"print(f\"Quantized Data Shape: {q_data_1d.shape} (stores 4-bit integer indices)\")\n",
"print(f\"Scaling Factors Shape: {scales_1d.shape}\")\n",
"print(\"\\nFirst row's scaling factors:\")\n",
"print(scales_1d[0])\n"
]
},
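{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check (our own addition, not code from the PR), we can invert the scheme: look up each stored 4-bit index in the LUT and multiply by its block's scaling factor. The helper `dequantize_1d_block_reference` below is purely illustrative, but it shows that the quantized indices plus the scales are enough to approximately reconstruct the original tensor.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def dequantize_1d_block_reference(q_data: torch.Tensor, scaling_factors: torch.Tensor):\n",
"    \"\"\"Illustrative inverse of quantize_1d_block_reference: value ~ LUT[index] * block_scale.\"\"\"\n",
"    rows, cols = q_data.shape\n",
"    lut = FP4_E2M1_LUT.to(q_data.device)\n",
"    out = torch.zeros(rows, cols, dtype=torch.float32, device=q_data.device)\n",
"    for i in range(rows):\n",
"        for j in range(cols // 16):\n",
"            start_col, end_col = j * 16, (j + 1) * 16\n",
"            # Decode the 4-bit indices back to E2M1 values, then undo the block scaling.\n",
"            out[i, start_col:end_col] = lut[q_data[i, start_col:end_col].long()] * scaling_factors[i, j].float()\n",
"    return out\n",
"\n",
"# Round-trip the tensor quantized in the previous cell and measure the error.\n",
"recovered = dequantize_1d_block_reference(q_data_1d, scales_1d)\n",
"abs_err = (recovered - sample_tensor.float()).abs()\n",
"print(f'Max abs reconstruction error:  {abs_err.max().item():.4f}')\n",
"print(f'Mean abs reconstruction error: {abs_err.mean().item():.4f}')\n"
]
},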
{
"cell_type": "markdown",
"metadata": {
"id": "lkvQ_1pyJAGT"
},
"source": [
"## Step 3: Implementing 2D Block Quantization\n",
"\n",
"The PR enables 2D quantization for weights. The logic is similar, but the block size is now 16x16. There is only **one scaling factor for the entire 256-element block**. This is implemented in the reference function `quantize_nvfp4_2d` in `test_cast_nvfp4_transpose.cu`.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "pub7ckEjJAGU"
},
"outputs": [],
"source": [
"def quantize_2d_block_reference(hp_tensor: torch.Tensor):\n",
" \"\"\"\n",
" Reference implementation for 2D block quantization (16x16 blocks).\n",
" \"\"\"\n",
" assert hp_tensor.dim() == 2, \"Input must be a 2D tensor\"\n",
" rows, cols = hp_tensor.shape\n",
" assert rows % 16 == 0 and cols % 16 == 0, \"Dimensions must be divisible by 16\"\n",
"\n",
" # Outputs\n",
" num_blocks_y, num_blocks_x = rows // 16, cols // 16\n",
" quantized_data = torch.zeros_like(hp_tensor, dtype=torch.int8)\n",
" scaling_factors = torch.zeros(num_blocks_y, num_blocks_x, dtype=hp_tensor.dtype, device=hp_tensor.device)\n",
"\n",
" for i in range(num_blocks_y):\n",
" for j in range(num_blocks_x):\n",
" # 1. Get the 16x16 block\n",
" start_row, end_row = i * 16, (i + 1) * 16\n",
" start_col, end_col = j * 16, (j + 1) * 16\n",
" block = hp_tensor[start_row:end_row, start_col:end_col]\n",
"\n",
" # 2. Find amax for the entire 16x16 block\n",
" block_amax = torch.max(torch.abs(block))\n",
" if block_amax == 0:\n",
" scaling_factors[i, j] = 0.0\n",
" continue\n",
"\n",
" # 3. Calculate scaling factor\n",
" scaling_factor = block_amax / FP4_E2M1_MAX_VAL\n",
" scaling_factors[i, j] = scaling_factor\n",
"\n",
" # 4. Scale the block\n",
" scaled_block = block / scaling_factor\n",
"\n",
" # 5. Cast to FP4\n",
" # (Vectorized version for simplicity)\n",
" quantized_block = torch.zeros_like(scaled_block, dtype=torch.int8)\n",
" for y in range(16):\n",
" for x in range(16):\n",
" quantized_block[y, x] = find_closest_fp4_val(scaled_block[y, x])\n",
" quantized_data[start_row:end_row, start_col:end_col] = quantized_block\n",
"\n",
" return quantized_data, scaling_factors\n",
"\n",
"\n",
"# --- Test it ---\n",
"sample_tensor_2d = torch.randn((32, 64), dtype=torch.bfloat16, device='cuda')\n",
"q_data_2d, scales_2d = quantize_2d_block_reference(sample_tensor_2d)\n",
"\n",
"print(\"--- 2D Quantization Example ---\")\n",
"print(f\"Original Tensor Shape: {sample_tensor_2d.shape}\")\n",
"print(f\"Quantized Data Shape: {q_data_2d.shape}\")\n",
"print(f\"Scaling Factors Shape: {scales_2d.shape} (2x4 blocks of 16x16)\")\n",
"print(\"\\nScaling factors for all 16x16 blocks:\")\n",
"print(scales_2d)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QNnQIlU-JAGU"
},
"source": [
"## Step 4: Implementing Random Hadamard Transform (RHT)\n",
"\n",
"RHT is a pre-processing step applied to activations before quantization. It's a matrix multiplication with a special \"Hadamard\" matrix. The goal is to distribute the information across the vector, making quantization less lossy. The PR adds highly optimized kernels for this (`hadamard_transform_cast_fusion.cu`).\n",
"\n",
"Our reference will build the matrix and apply it block-wise.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "yRIh9ugLJAGU"
},
"outputs": [],
"source": [
"def get_hadamard_matrix(size, device):\n",
" \"\"\"Constructs a Hadamard matrix of a power-of-two size.\"\"\"\n",
" if size == 1:\n",
" return torch.ones((1, 1), device=device)\n",
" h_prev = get_hadamard_matrix(size // 2, device)\n",
" h_next = torch.cat([\n",
" torch.cat([h_prev, h_prev], dim=1),\n",
" torch.cat([h_prev, -h_prev], dim=1),\n",
" ], dim=0)\n",
" return h_next\n",
"\n",
"def random_hadamard_transform_reference(hp_tensor: torch.Tensor):\n",
" \"\"\"Applies a 16x16 RHT to the tensor block-wise.\"\"\"\n",
" rows, cols = hp_tensor.shape\n",
" assert cols % 16 == 0, \"Columns must be divisible by 16\"\n",
"\n",
" # The transform matrix includes normalization\n",
" h_matrix = get_hadamard_matrix(16, hp_tensor.device).to(hp_tensor.dtype)\n",
" h_matrix *= (1.0 / math.sqrt(16))\n",
"\n",
" transformed_tensor = torch.zeros_like(hp_tensor)\n",
"\n",
" for i in range(rows):\n",
" for j in range(cols // 16):\n",
" start_col, end_col = j * 16, (j + 1) * 16\n",
" block = hp_tensor[i, start_col:end_col]\n",
" # Apply the transform: block @ H\n",
" transformed_block = torch.matmul(block, h_matrix)\n",
" transformed_tensor[i, start_col:end_col] = transformed_block\n",
"\n",
" return transformed_tensor\n",
"\n",
"# --- Test it ---\n",
"sample_tensor_rht = torch.randn((1, 32), dtype=torch.bfloat16, device='cuda')\n",
"transformed_tensor = random_hadamard_transform_reference(sample_tensor_rht)\n",
"\n",
"print(\"--- RHT Example ---\")\n",
"print(\"Original first 16 values:\\n\", sample_tensor_rht[0, :16])\n",
"print(\"\\nTransformed first 16 values:\\n\", transformed_tensor[0, :16])\n",
"print(f\"Shape remains the same: {transformed_tensor.shape}\")\n",
"\n"
]
},
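{
"cell_type": "markdown",
"metadata": {},
"source": [
"A property worth verifying (a sanity check we add here; it is not code from the PR): the normalized Hadamard matrix is orthogonal, so the transform preserves norms and is exactly invertible before any quantization happens. The accuracy benefit therefore comes from spreading large values across the block, not from discarding information.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sanity check (our own addition): the normalized 16x16 Hadamard matrix H satisfies\n",
"# H @ H.T == I, so applying H is lossless and can be undone with H.T.\n",
"H = get_hadamard_matrix(16, 'cuda') * (1.0 / math.sqrt(16))\n",
"identity_error = (H @ H.T - torch.eye(16, device=H.device)).abs().max()\n",
"print(f'Max deviation of H @ H.T from identity: {identity_error.item():.2e}')\n",
"\n",
"# Round-trip a single 16-element block: transform with H, undo with H.T.\n",
"block = torch.randn(16, device='cuda')\n",
"round_trip_error = ((block @ H) @ H.T - block).abs().max()\n",
"print(f'Max round-trip error: {round_trip_error.item():.2e}')\n"
]
},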
{
"cell_type": "markdown",
"metadata": {
"id": "acAqwXOVJAGU"
},
"source": [
"## Step 5: The Fused Operation - Putting It All Together\n",
"\n",
"The true power of the PR is fusing all these steps into a single, efficient CUDA kernel. The kernel performs:\n",
"`Cast -> RHT (optional) -> Quantize -> Transpose -> Quantize (again for transposed layout)`\n",
"\n",
"This avoids materializing intermediate tensors in memory and is much faster. Let's create a Python function that orchestrates our reference components to simulate this entire pipeline. This mimics the `compute_ref` function in `test_cast_nvfp4_transpose.cu`.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "XntTaPY2JAGU"
},
"outputs": [],
"source": [
"def nvfp4_recipe_reference(\n",
" hp_tensor: torch.Tensor,\n",
" use_rht: bool,\n",
" use_2d_quant_for_weights: bool # In TE, this only applies to weights, but we simulate it here\n",
"):\n",
" \"\"\"\n",
" Simulates the full, fused quantization pipeline.\n",
" \"\"\"\n",
" # --- Process the input for row-wise (activation) usage ---\n",
" processed_tensor = random_hadamard_transform_reference(hp_tensor) if use_rht else hp_tensor\n",
" # Always use 1D quantization for activations/row-wise data\n",
" q_data, scales = quantize_1d_block_reference(processed_tensor)\n",
"\n",
" # --- Process the input for column-wise (weight) usage ---\n",
" hp_tensor_t = hp_tensor.T.contiguous()\n",
" if use_2d_quant_for_weights:\n",
" # NOTE: Real implementation pads to 16x16 blocks. We'll assume divisible dimensions.\n",
" q_data_t, scales_t = quantize_2d_block_reference(hp_tensor_t)\n",
" else:\n",
" q_data_t, scales_t = quantize_1d_block_reference(hp_tensor_t)\n",
"\n",
" print(\"Simulated fused operation successful!\")\n",
" return q_data, scales, q_data_t, scales_t\n",
"\n",
"# --- Test it with a realistic shape ---\n",
"activation_tensor = torch.randn((128, 2048), dtype=torch.bfloat16, device='cuda')\n",
"\n",
"q_activation, scales_activation, q_weight, scales_weight = nvfp4_recipe_reference(\n",
" activation_tensor,\n",
" use_rht=True,\n",
" use_2d_quant_for_weights=True\n",
")\n",
"\n",
"print(\"\\n--- Outputs of the Fused Pipeline ---\")\n",
"print(f\"Quantized Activation Shape: {q_activation.shape}\")\n",
"print(f\"Activation Scales Shape: {scales_activation.shape}\")\n",
"print(f\"Quantized Weight (Transposed) Shape: {q_weight.shape}\")\n",
"print(f\"Weight Scales (Transposed) Shape: {scales_weight.shape}\")\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "lYgXcWRMJAGU"
},
"source": [
"## Step 6: The `NVFP4Tensor` Data Structure\n",
"\n",
"Finally, why does the PR introduce a new `NVFP4Tensor` class in Python?\n",
"\n",
"Because the results of the fused operation (`q_data`, `scales`, `q_data_t`, `scales_t`) all belong together. They represent a single high-precision tensor in its quantized form. The `NVFP4Tensor` acts as a container for all these components.\n",
"\n",
"When a TE layer needs the tensor for a forward pass GEMM (activations), it uses `q_data` and `scales`. When it needs the tensor for a wgrad GEMM (weights), it uses `q_data_t` and `scales_t`. This avoids costly re-quantization or transposing of packed 4-bit data on the fly.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "NKYnyH0MJAGU"
},
"outputs": [],
"source": [
"from dataclasses import dataclass\n",
"\n",
"@dataclass\n",
"class NVFP4TensorReference:\n",
" \"\"\"A Python dataclass to represent the real NVFP4Tensor structure.\"\"\"\n",
" _rowwise_data: torch.Tensor\n",
" _rowwise_scale_inv: torch.Tensor\n",
" _columnwise_data: torch.Tensor\n",
" _columnwise_scale_inv: torch.Tensor\n",
" original_shape: tuple\n",
"\n",
"# Let's package our results into this structure\n",
"nvfp4_tensor_ref = NVFP4TensorReference(\n",
" _rowwise_data=q_activation,\n",
" _rowwise_scale_inv=scales_activation,\n",
" _columnwise_data=q_weight,\n",
" _columnwise_scale_inv=scales_weight,\n",
" original_shape=activation_tensor.shape\n",
")\n",
"\n",
"print(\"Representation of a complete NVFP4Tensor object:\")\n",
"print(nvfp4_tensor_ref)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DEJS6bXUJAGU"
},
"source": [
"## Conclusion\n",
"\n",
"You have now implemented the core algorithmic building blocks of the NVFP4 recipe from scratch.\n",
"\n",
"You've learned that the implementation is not just a simple cast, but a sophisticated, fused pipeline that involves:\n",
"1. **Block-based Scaling**: Calculating per-block scaling factors (either 1D or 2D).\n",
"2. **Optional Pre-processing (RHT)**: Applying a mathematical transform to improve numerical stability.\n",
"3. **Fused Operations**: Performing quantization and transposition in a single step to generate layouts for both forward and backward passes efficiently.\n",
"4. **A Specialized Data Structure**: Using `NVFP4Tensor` to hold all the necessary components (data, scales, transposed versions) together.\n",
"\n",
"The actual C++/CUDA code in the PR takes these exact algorithms and implements them with extreme performance optimizations, using techniques like shared memory, tensor core instructions, and careful data movement to make 4-bit training feasible at scale.\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "niTGw1h_JAGU"
},
"source": [
"# Advanced Lessons: Implementing the NVFP4 Recipe\n",
"\n",
"Welcome to the advanced implementation tutorial for the NVFP4 recipe. In the previous session, we built high-level Python models of the core algorithms. Now, we will dissect the engineering principles and low-level details from the PR to understand how this is implemented for maximum performance on a GPU.\n",
"\n",
"### Learning Path:\n",
"* **Lesson 1: The \"Why\" of Fused Kernels** - Why not just call the Python functions in sequence?\n",
"* **Lesson 2: Anatomy of the CUDA Kernel** - A conceptual breakdown of the C++ `block_scaled_1d_cast_transpose_kernel`.\n",
"* **Lesson 3: The Nuances of Two-Level Scaling** - Understanding the global (`S_enc`) and local (`S_dec_b`) scaling factors.\n",
"* **Lesson 4: Distributed Training & Quantized All-Gather** - How to handle custom data types in a multi-GPU setting.\n",
"* **Lesson 5: The Python API Glue** - How the `NVFP4Quantizer` class orchestrates everything.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "KyXORSe0JAGU"
},
"outputs": [],
"source": [
"import torch\n",
"import math\n",
"\n",
"# For reproducibility\n",
"torch.manual_seed(0)\n",
"torch.cuda.manual_seed(0)\n",
"\n",
"# Constants from the previous lesson\n",
"FP4_E2M1_MAX_VAL = 6.0\n",
"# A new constant from the PR: the max value of an FP8 E4M3 number, used for scaling factors.\n",
"FP8_E4M3_MAX_VAL = 448.0\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "u4AW5GXWJAGU"
},
"source": [
"## Lesson 1: The \"Why\" of Fused Kernels - The Memory Bottleneck\n",
"\n",
"In our previous tutorial, we implemented each step (RHT, Quantize, Transpose) as a separate Python function. On a real GPU, this would be incredibly inefficient. Why? **Memory Bandwidth**.\n",
"\n",
"A GPU is fastest when it's doing math (computing). It's relatively slow when it's moving data between its main memory (HBM) and its compute cores. Operations like ours are often **memory-bound**, meaning the GPU spends more time waiting for data than computing on it.\n",
"\n",
"Consider the \"naive\" approach:\n",
"1. `hp_tensor` is in Global Memory.\n",
"2. **Kernel 1 (RHT)**: Load `hp_tensor`, compute RHT, write `rht_tensor` back to Global Memory.\n",
"3. **Kernel 2 (Amax)**: Load `rht_tensor`, compute amax, write `amax_tensor` back to Global Memory.\n",
"4. **Kernel 3 (Quantize)**: Load `rht_tensor` and `amax_tensor`, compute scales and quantized data, write `q_tensor` and `scales_tensor` to Global Memory.\n",
"5. ...and so on for the transpose.\n",
"\n",
"This involves multiple round-trips to slow global memory. A **fused kernel**, like the one in this PR (`quantize_transpose_vector_blockwise_fp4.cu`), does all of this in a single trip.\n",
"\n",
"### The Fused Kernel Strategy:\n",
"1. **Launch ONE Kernel.**\n",
"2. Threads load a small tile of the `hp_tensor` from Global Memory into ultra-fast **Shared Memory**.\n",
"3. Perform all operations (RHT, amax reduction, scaling, casting) directly on the data in Shared Memory.\n",
"4. Write the final, tiny outputs (`q_tensor` tile, `scales_tensor` tile) back to Global Memory.\n",
"\n",
"This minimizes global memory traffic and maximizes computation, leading to massive speedups. The entire PR is built around this principle.\n"
]
},
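{
"cell_type": "markdown",
"metadata": {},
"source": [
"To put rough numbers on this, the sketch below is a back-of-the-envelope model we add for illustration (it is not from the PR). It assumes bf16 inputs (2 bytes per element) and packed FP4 outputs (0.5 bytes per element), ignores the comparatively tiny scaling-factor tensors, and counts the global-memory bytes moved by the naive multi-kernel pipeline above versus a single fused pass.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def estimate_hbm_traffic_mb(rows, cols, hp_bytes=2.0, fp4_bytes=0.5):\n",
"    \"\"\"Rough byte-count model (our own illustration, not from the PR).\"\"\"\n",
"    n = rows * cols\n",
"    # Naive pipeline from the list above:\n",
"    #   Kernel 1 (RHT):                  read hp, write hp\n",
"    #   Kernel 2 (amax):                 read hp\n",
"    #   Kernel 3 (quantize):             read hp, write fp4\n",
"    #   Kernel 4 (transpose + quantize): read hp, write fp4\n",
"    naive = 5 * n * hp_bytes + 2 * n * fp4_bytes\n",
"    # Fused kernel: read the hp tensor once, write the two FP4 outputs once.\n",
"    fused = 1 * n * hp_bytes + 2 * n * fp4_bytes\n",
"    return naive / 1e6, fused / 1e6\n",
"\n",
"naive_mb, fused_mb = estimate_hbm_traffic_mb(4096, 4096)\n",
"print(f'Naive multi-kernel pipeline: ~{naive_mb:.0f} MB of global memory traffic')\n",
"print(f'Single fused kernel:         ~{fused_mb:.0f} MB of global memory traffic')\n",
"print(f'Traffic reduction:           ~{naive_mb / fused_mb:.1f}x')\n"
]
},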
{
"cell_type": "markdown",
"metadata": {
"id": "QElLt6c8JAGU"
},
"source": [
"## Lesson 2: Anatomy of a Fused CUDA Kernel\n",
"\n",
"Let's write a \"pseudo-code\" walkthrough of the main kernel. We can't run CUDA C++ here, but we can model its logic and structure in Python comments to understand how it works. We'll focus on the `block_scaled_1d_cast_transpose_kernel` logic from the new C++ tests.\n",
"\n",
"A CUDA kernel is executed by a grid of *thread blocks*. Each block is responsible for processing one \"tile\" of the input data. Inside a block, threads cooperate using **Shared Memory**.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "32ixAc-zJAGU"
},
"outputs": [],
"source": [
"def conceptual_fused_kernel(hp_tensor):\n",
" \"\"\"A Python simulation of the fused kernel's logic for a single 16x16 tile.\"\"\"\n",
" # --- Kernel Launch Setup (Done by the CUDA runtime) ---\n",
" # Imagine this function is ONE thread block, given an index (blockIdx.x, blockIdx.y)\n",
" # to identify which 16x16 tile of the hp_tensor it should process.\n",
" # Let's assume this block is responsible for the tile starting at (0, 0).\n",
" TILE_DIM = 16\n",
" block_start_row, block_start_col = 0, 0\n",
"\n",
" # --- Inside the Kernel (Execution on GPU) ---\n",
"\n",
" # 1. Cooperative Loading into Shared Memory\n",
" # Each of the 256 threads in the block loads one element from global HBM\n",
" # into the fast, on-chip shared memory scratchpad.\n",
" shared_mem_tile = hp_tensor[\n",
" block_start_row : block_start_row + TILE_DIM,\n",
" block_start_col : block_start_col + TILE_DIM\n",
" ].clone() # .clone() simulates the copy to a new memory space.\n",
" # In CUDA, a `__syncthreads()` barrier would wait for all loads to complete.\n",
"\n",
" # 2. On-Chip AMAX Reduction (Row-wise)\n",
" # The threads now work on the fast shared memory tile.\n",
" # They cooperatively find the amax for each of the 16 rows in the tile.\n",
" row_amaxes = torch.max(torch.abs(shared_mem_tile), dim=1).values\n",
" # This is a simplified view. In CUDA, this is a multi-step reduction using\n",
" # warp-level primitives (`__shfl_down_sync`) and another `__syncthreads()`.\n",
"\n",
" # 3. Calculate Row-wise Scaling Factors\n",
" row_scales = row_amaxes / FP4_E2M1_MAX_VAL\n",
" # Handle division by zero for all-zero rows\n",
" row_scales[row_scales == 0] = 1.0\n",
"\n",
" # 4. Scale and Cast (Row-wise)\n",
" # Each thread scales its value and simulates the cast.\n",
" # The actual CUDA kernel uses a PTX instruction like `cvt.rn.satfinite.e2m1x2.f32`\n",
" # which converts two FP32 numbers to two packed FP4 numbers in one go.\n",
" scaled_tile = shared_mem_tile / row_scales.unsqueeze(1)\n",
" quantized_tile = torch.round(scaled_tile).clamp(-FP4_E2M1_MAX_VAL, FP4_E2M1_MAX_VAL) # Simplified cast logic\n",
"\n",
" # 5. On-Chip Transposition\n",
" # Threads cooperatively write to a second shared memory buffer in a transposed pattern.\n",
" transposed_shared_mem_tile = shared_mem_tile.T.contiguous()\n",
" # `__syncthreads()` ensures the transpose is complete.\n",
"\n",
" # 6. AMAX, Scale, and Cast (Column-wise / Transposed)\n",
" # The process is repeated on the transposed tile to get the column-wise outputs.\n",
" col_amaxes = torch.max(torch.abs(transposed_shared_mem_tile), dim=1).values\n",
" col_scales = col_amaxes / FP4_E2M1_MAX_VAL\n",
" col_scales[col_scales == 0] = 1.0\n",
" scaled_transposed_tile = transposed_shared_mem_tile / col_scales.unsqueeze(1)\n",
" quantized_transposed_tile = torch.round(scaled_transposed_tile).clamp(-FP4_E2M1_MAX_VAL, FP4_E2M1_MAX_VAL)\n",
"\n",
" # 7. Write Final Results to Global Memory\n",
" # The threads write their final results from shared memory back to the final output tensors in HBM.\n",
" # This is the only other time they touch global memory.\n",
" print(\"Conceptual kernel finished processing one tile.\")\n",
" return quantized_tile, row_scales, quantized_transposed_tile, col_scales\n",
"\n",
"# --- Run the conceptual model ---\n",
"sample_tile = torch.randn((16, 16), dtype=torch.float32, device='cuda')\n",
"q_data, scales, q_data_t, scales_t = conceptual_fused_kernel(sample_tile)\n",
"\n",
"print(f\"\\nRow-wise quantized data shape: {q_data.shape}\")\n",
"print(f\"Row-wise scales shape: {scales.shape} (One scale per row in the tile)\")\n",
"print(f\"Column-wise quantized data shape: {q_data_t.shape}\")\n",
"print(f\"Column-wise scales shape: {scales_t.shape} (One scale per column in the tile)\")\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2wdwxO9IJAGV"
},
"source": [
"## Lesson 3: The Nuances of Two-Level Scaling\n",
"\n",
"The previous lessons used a simplified scaling formula: `scale = amax / 6.0`. The actual implementation in the PR is more sophisticated, as seen in the C++ function `compute_global_encode_scaling_factor_FP4`. It uses a **two-level scaling system**.\n",
"\n",
"1. **Global Per-Tensor Scale (`S_enc`)**: A single FP32 scale factor is computed for the *entire tensor*. Its job is to map the tensor's global amax into a range that is friendly to FP8-E4M3, the format used for the *scaling factors themselves*.\n",
"\n",
"2. **Local Per-Block Scale (`S_dec_b`)**: This is the scale we've been calculating (`block_amax / 6.0`). It handles local variations.\n",
"\n",
"**The final scaling factor stored in memory is `S_final = S_dec_b * S_enc`**.\n",
"\n",
"Why do this? It improves numerical precision. By pre-scaling the entire tensor with `S_enc`, we ensure that the per-block `S_dec_b` values can be accurately represented by the FP8-E4M3 format.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "qpUmg3QzJAGV"
},
"outputs": [],
"source": [
"def two_level_scaling_reference(hp_tensor: torch.Tensor):\n",
" \"\"\"Reference implementation for the two-level scaling logic.\"\"\"\n",
" # -- Level 1: Global Scaling --\n",
" global_amax = torch.max(torch.abs(hp_tensor))\n",
"\n",
" # This formula is a direct translation of the C++ `compute_global_encode_scaling_factor_FP4`\n",
" # It maps the global amax to the dynamic range of FP8 * FP4\n",
" if global_amax == 0.0:\n",
" S_enc = 1.0\n",
" else:\n",
" S_enc = (FP8_E4M3_MAX_VAL * FP4_E2M1_MAX_VAL) / global_amax\n",
" S_enc = min(S_enc, torch.finfo(torch.float32).max) # Clamp to max float32\n",
"\n",
" # -- Level 2: Local Scaling (within a 1D block) --\n",
" rows, cols = hp_tensor.shape\n",
" num_scale_blocks = cols // 16\n",
" final_scales = torch.zeros(rows, num_scale_blocks, dtype=torch.float32, device=hp_tensor.device)\n",
"\n",
" for i in range(rows):\n",
" for j in range(num_scale_blocks):\n",
" block = hp_tensor[i, j*16:(j+1)*16]\n",
" block_amax = torch.max(torch.abs(block))\n",
"\n",
" # Calculate the local decoding scale\n",
" S_dec_b = block_amax / FP4_E2M1_MAX_VAL\n",
"\n",
" # Combine with global encoding scale to get the final scale\n",
" S_final = S_dec_b * S_enc\n",
"\n",
" # The final scale is then cast to FP8 E4M3 for storage.\n",
" # We will just store it as float32 for this reference.\n",
" final_scales[i, j] = S_final\n",
"\n",
" print(f\"Global Amax: {global_amax:.4f}, S_enc (Global Scale): {S_enc:.4f}\")\n",
" return final_scales\n",
"\n",
"# --- Test the two-level scaling ---\n",
"sample_tensor = torch.randn((2, 32), device='cuda') * 10 # Scale up to see a more interesting amax\n",
"final_scaling_factors = two_level_scaling_reference(sample_tensor)\n",
"print(\"\\nFinal (two-level) scaling factors for the first row:\")\n",
"print(final_scaling_factors[0])\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "O9cCbq4WJAGV"
},
"source": [
"## Lesson 4: Distributed Training & Quantized All-Gather\n",
"\n",
"Making a new feature work on one GPU is only half the battle. For large models, it must work with tensor parallelism across multiple GPUs. This PR adds a custom `_all_gather_nvfp4` function in `transformer_engine/pytorch/distributed.py`.\n",
"\n",
"**The Problem**: You can't just call `torch.distributed.all_gather` on an `NVFP4Tensor` object. The All-Gather operation only works on single, contiguous `torch.Tensor`s.\n",
"\n",
"**The Solution**:\n",
"1. Deconstruct the `NVFP4Tensor` on each GPU into its constituent `torch.Tensor` components (e.g., `_rowwise_data`, `_rowwise_scale_inv`).\n",
"2. Perform a separate `all_gather` operation on each component tensor.\n",
"3. Reconstruct a new, larger `NVFP4Tensor` on each GPU from the gathered components.\n",
"\n",
"**A New Problem (The \"Interleave\" Issue)**: When you gather a *transposed* tensor (like `_columnwise_data`) along the batch dimension, the data from different GPUs gets interleaved incorrectly.\n",
"\n",
"Imagine 2 GPUs. GPU0 has `[A0, B0]` and GPU1 has `[A1, B1]`. After gathering, the memory layout isn't `[A0, B0, A1, B1]`. It becomes something like `[A0, A1, B0, B1]`.\n",
"\n",
"To fix this, the PR adds a `swap_first_dims` operation. Let's simulate this.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "wUWKQa8QJAGV"
},
"outputs": [],
"source": [
"def simulate_distributed_gather_and_fix():\n",
" world_size = 4 # Simulate 4 GPUs\n",
" local_dim0, local_dim1 = 2, 8\n",
"\n",
" # Create dummy transposed data on each GPU\n",
" gpu_data = [torch.arange(local_dim0 * local_dim1, dtype=torch.float32).reshape(local_dim0, local_dim1) + (i*100) for i in range(world_size)]\n",
" print(f\"--- Data on GPU 0 (Transposed Layout) ---\\n{gpu_data[0]}\")\n",
"\n",
" # Simulate `all_gather` on the first dimension. This creates the interleaved result.\n",
" interleaved_data = torch.cat(gpu_data, dim=0)\n",
" print(f\"\\n--- Interleaved Data After All-Gather ---\\n{interleaved_data}\")\n",
"\n",
" # The `swap_first_dims` logic to fix the layout\n",
" # This is what `tex.swap_first_dims` in the PR does in a highly optimized way.\n",
" total_dim0 = interleaved_data.shape[0]\n",
" fixed_data = interleaved_data.reshape(world_size, total_dim0 // world_size, -1).transpose(0, 1).reshape(total_dim0, -1)\n",
"\n",
" print(f\"\\n--- Data After `swap_first_dims` Fix ---\\n{fixed_data}\")\n",
"\n",
"simulate_distributed_gather_and_fix()\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6CBe328SJAGV"
},
"source": [
"## Lesson 5: The Python API Glue - `NVFP4Quantizer`\n",
"\n",
"The `NVFP4Quantizer` class in `transformer_engine/pytorch/tensor/nvfp4_tensor.py` is the high-level orchestrator. It's the bridge between the Python world and the C++/CUDA backend.\n",
"\n",
"Let's break down its key responsibilities based on the PR:\n",
"\n",
"1. **Configuration (`__init__`)**: It reads the `Recipe` object and stores flags like `with_rht`, `stochastic_rounding`, and `with_2d_quantization`. It also pre-builds the RHT matrix if needed.\n",
"\n",
"2. **State Management**: It holds stateful information. For example, it generates and stores the random sign mask for the RHT matrix.\n",
"\n",
"3. **Backend Invocation (`quantize`)**: This is the main method. It takes a high-precision `torch.Tensor` as input.\n",
" * It checks the tensor shape and properties.\n",
" * It packages all the configuration flags and tensor pointers into a C-compatible structure (`QuantizationConfigWrapper`).\n",
" * It calls the core C++ function (e.g., `tex.quantize_fp4`) through the Pybind11 bridge. This is the function that launches the fused CUDA kernel we discussed in Lesson 2.\n",
"\n",
"4. **Object Creation**: The C++ function returns raw tensor data. The `NVFP4Quantizer` takes this raw data and uses it to construct and return a proper, user-friendly `NVFP4Tensor` Python object.\n",
"\n",
"This class design cleanly separates the high-level configuration and object management in Python from the low-level, high-performance computations in C++/CUDA.\n"
]
},
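{
"cell_type": "markdown",
"metadata": {},
"source": [
"To tie these responsibilities together, here is a minimal skeleton (our own sketch, not the real Transformer Engine class or its API). It mirrors the configure / quantize / package flow described above, but calls our Python reference pipeline from the earlier lessons in place of the `tex.*` C++ entry points.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class ConceptualNVFP4Quantizer:\n",
"    \"\"\"A conceptual stand-in for NVFP4Quantizer (our sketch, not TE's real API).\"\"\"\n",
"\n",
"    def __init__(self, with_rht: bool = True, with_2d_quantization: bool = True):\n",
"        # 1. Configuration: store the recipe flags. (2. State: the real class also\n",
"        #    builds and stores things like the random sign mask for the RHT matrix.)\n",
"        self.with_rht = with_rht\n",
"        self.with_2d_quantization = with_2d_quantization\n",
"\n",
"    def quantize(self, hp_tensor: torch.Tensor) -> NVFP4TensorReference:\n",
"        # 3. Backend invocation: the real class packs flags into a C-compatible config\n",
"        #    and calls the fused C++/CUDA kernel; here we call our reference pipeline.\n",
"        q_data, scales, q_data_t, scales_t = nvfp4_recipe_reference(\n",
"            hp_tensor,\n",
"            use_rht=self.with_rht,\n",
"            use_2d_quant_for_weights=self.with_2d_quantization,\n",
"        )\n",
"        # 4. Object creation: wrap the raw outputs in the container class from Step 6.\n",
"        return NVFP4TensorReference(\n",
"            _rowwise_data=q_data,\n",
"            _rowwise_scale_inv=scales,\n",
"            _columnwise_data=q_data_t,\n",
"            _columnwise_scale_inv=scales_t,\n",
"            original_shape=tuple(hp_tensor.shape),\n",
"        )\n",
"\n",
"quantizer = ConceptualNVFP4Quantizer(with_rht=True, with_2d_quantization=True)\n",
"nvfp4_out = quantizer.quantize(torch.randn((32, 64), dtype=torch.bfloat16, device='cuda'))\n",
"print(f'Rowwise quantized data shape: {nvfp4_out._rowwise_data.shape}')\n",
"print(f'Columnwise scales shape:      {nvfp4_out._columnwise_scale_inv.shape}')\n"
]
},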
{
"cell_type": "markdown",
"metadata": {
"id": "e7jPBHUNJAGV"
},
"source": [
"## Grand Conclusion\n",
"\n",
"You have now journeyed from a high-level user of the NVFP4 recipe to understanding the deepest implementation details. You've learned:\n",
"\n",
"- **Performance is King**: Fused kernels are essential to overcome memory bandwidth limitations, which is the primary motivation for the C++/CUDA implementation.\n",
"- **CUDA Programming Patterns**: Thread blocks, shared memory, and cooperative execution are the tools used to build these fused kernels.\n",
"- **Numerical Precision Matters**: The two-level scaling system is a clever trick to maintain accuracy when the scaling factors themselves must be stored in a low-precision format.\n",
"- **Distributed Systems are Complex**: Features must be designed with multi-GPU execution in mind, often requiring custom communication patterns like the fix for interleaved gathering.\n",
"- **APIs are Abstractions**: The Python `NVFP4Quantizer` class provides a clean interface that hides the immense complexity of the underlying C++/CUDA/distributed logic.\n",
"\n",
"You are now well-equipped to read through the files in PR #2177, such as `quantize_transpose_vector_blockwise_fp4.cu` and `distributed.py`, and recognize the patterns and algorithms we've discussed here."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.0"
},
"colab": {
"provenance": [],
"name": "Implementing the NVFP4 Recipe From Scratch.ipynb",
"include_colab_link": true
}
},
"nbformat": 4,
"nbformat_minor": 0
}