Skip to content

Instantly share code, notes, and snippets.

@peterroelants
Created September 18, 2025 07:58
Show Gist options
  • Select an option

  • Save peterroelants/19a3e8988f2c989982ba9105a902ad8a to your computer and use it in GitHub Desktop.

Select an option

Save peterroelants/19a3e8988f2c989982ba9105a902ad8a to your computer and use it in GitHub Desktop.
Illustration of the difference between a 2x2 conv and a patchify + 1x1 conv
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"id": "6ed99221",
"metadata": {},
"source": [
"# Illustration of the difference between a 2x2 conv and a patchify + 1x1 conv"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "0c6ab949",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import einops"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "aae9b1f3",
"metadata": {},
"outputs": [],
"source": [
"def patchify_3d(\n",
" x: torch.Tensor, # [B, C, T, H, W]\n",
" patch_size_time: int,\n",
" patch_size_height: int,\n",
" patch_size_width: int,\n",
") -> torch.Tensor: # [B, C, T*patch_size_time, H*patch_size_height, W*patch_size_width]\n",
" \"\"\"\n",
" Patchify input tensor (space-to-depth)\n",
"\n",
" This is similar to torch.nn.functional.pixel_shuffle, but for 3D tensors.\n",
" \"\"\"\n",
" return einops.rearrange(\n",
" x,\n",
" \"b c (t pst) (h psh) (w psw) -> b (c pst psh psw) t h w\",\n",
" pst=patch_size_time,\n",
" psh=patch_size_height,\n",
" psw=patch_size_width,\n",
" )\n",
"\n",
"\n",
"def unpatchify_3d(\n",
" x: torch.Tensor, # [B, C, T*patch_size_time, H*patch_size_height, W*patch_size_width]\n",
" patch_size_time: int,\n",
" patch_size_height: int,\n",
" patch_size_width: int,\n",
") -> torch.Tensor: # [B, C, T, H, W]\n",
" return einops.rearrange(\n",
" x,\n",
" \"b (c pst psh psw) t h w -> b c (t pst) (h psh) (w psw)\",\n",
" pst=patch_size_time,\n",
" psh=patch_size_height,\n",
" psw=patch_size_width,\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "fbd437c7",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"x.shape=torch.Size([4, 48, 3, 10, 12])\n",
"y.shape=torch.Size([4, 128, 3, 5, 6])\n"
]
}
],
"source": [
"nb_channels_out = 128\n",
"\n",
"batch_size = 4\n",
"nb_channels_in = 48\n",
"time_steps = 3\n",
"height = 10\n",
"width = 12\n",
"\n",
"x = torch.randn(batch_size, nb_channels_in, time_steps, height, width)\n",
"print(f\"{x.shape=}\")\n",
"\n",
"y = torch.empty(batch_size, nb_channels_out, time_steps, height // 2, width // 2)\n",
"print(f\"{y.shape=}\")"
]
},
{
"cell_type": "markdown",
"id": "6167cf13",
"metadata": {},
"source": [
"## Downsample"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "b4c59ef6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"x_down.shape=torch.Size([4, 128, 3, 5, 6])\n",
"conv_down.weight.data.shape=torch.Size([128, 48, 1, 2, 2])\n",
"Total parameters in conv_down: 24,576\n"
]
}
],
"source": [
"# 2x2 conv downsample\n",
"conv_down = torch.nn.Conv3d(\n",
" in_channels=nb_channels_in,\n",
" out_channels=nb_channels_out,\n",
" kernel_size=(1, 2, 2),\n",
" padding=(0, 0, 0),\n",
" stride=(1, 2, 2),\n",
" bias=False,\n",
")\n",
"\n",
"x_down = conv_down(x)\n",
"print(f\"{x_down.shape=}\")\n",
"assert x_down.shape == y.shape\n",
"\n",
"print(f\"{conv_down.weight.data.shape=}\")\n",
"total_params = sum(p.numel() for p in conv_down.parameters())\n",
"print(f\"Total parameters in conv_down: {total_params:,}\")\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "b3959860",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"x_patched.shape=torch.Size([4, 192, 3, 5, 6])\n",
"x_patched_down.shape=torch.Size([4, 128, 3, 5, 6])\n",
"conv_1x1_down.weight.data.shape=torch.Size([128, 192, 1, 1, 1])\n",
"Total parameters in conv_1x1: 24,576\n"
]
}
],
"source": [
"# Patchify and 1x1 conv\n",
"x_patched = patchify_3d(x, patch_size_time=1, patch_size_height=2, patch_size_width=2)\n",
"print(f\"{x_patched.shape=}\")\n",
"\n",
"\n",
"conv_1x1_down = torch.nn.Conv3d(\n",
" in_channels=nb_channels_in*4,\n",
" out_channels=nb_channels_out,\n",
" kernel_size=(1, 1, 1),\n",
" padding=(0, 0, 0),\n",
" stride=(1, 1, 1),\n",
" bias=False,\n",
")\n",
"# Set identity weights\n",
"conv_1x1_down.weight.data = einops.rearrange(\n",
" conv_down.weight.data, \n",
" 'o i t (h ph) (w pw) -> o (i ph pw) t h w',\n",
" ph=2,\n",
" pw=2,\n",
")\n",
"\n",
"x_patched_down = conv_1x1_down(x_patched)\n",
"print(f\"{x_patched_down.shape=}\")\n",
"assert x_patched_down.shape == y.shape\n",
"\n",
"print(f\"{conv_1x1_down.weight.data.shape=}\")\n",
"total_params_1x1 = sum(p.numel() for p in conv_1x1_down.parameters())\n",
"print(f\"Total parameters in conv_1x1: {total_params_1x1:,}\")\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "fe93c965",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.abs(x_patched_down - x_down).max()=tensor(1.0729e-06, grad_fn=<MaxBackward1>)\n"
]
}
],
"source": [
"assert x_down.shape == x_patched_down.shape == y.shape\n",
"print(f\"{torch.abs(x_patched_down - x_down).max()=}\")\n",
"assert torch.allclose(x_down, x_patched_down, atol=1e-6)"
]
},
{
"cell_type": "markdown",
"id": "0f21265c",
"metadata": {},
"source": [
"## Upsample"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "756acfee",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"y_up.shape=torch.Size([4, 48, 3, 10, 12])\n",
"conv_up.weight.data.shape=torch.Size([128, 48, 1, 2, 2])\n",
"Total parameters in conv_up: 24,576\n"
]
}
],
"source": [
"# 2x2 conv upsample\n",
"conv_up = torch.nn.ConvTranspose3d(\n",
" in_channels=nb_channels_out,\n",
" out_channels=nb_channels_in,\n",
" kernel_size=(1, 2, 2),\n",
" padding=(0, 0, 0),\n",
" stride=(1, 2, 2),\n",
" bias=False,\n",
")\n",
"\n",
"\n",
"\n",
"y_up = conv_up(y)\n",
"print(f\"{y_up.shape=}\")\n",
"assert y_up.shape == x.shape\n",
"print(f\"{conv_up.weight.data.shape=}\")\n",
"total_params = sum(p.numel() for p in conv_up.parameters())\n",
"print(f\"Total parameters in conv_up: {total_params:,}\")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "dbcf12dc",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"y_proj.shape=torch.Size([4, 192, 3, 5, 6])\n",
"y_unpatched.shape=torch.Size([4, 48, 3, 10, 12])\n",
"conv1x1_up.weight.data.shape=torch.Size([192, 128, 1, 1, 1])\n",
"Total parameters in conv1x1_up: 24,576\n"
]
}
],
"source": [
"# 1x1 projection and unpatchify\n",
"conv1x1_up = torch.nn.Conv3d(\n",
" in_channels=nb_channels_out,\n",
" out_channels=nb_channels_in * 4,\n",
" kernel_size=(1, 1, 1),\n",
" padding=(0, 0, 0),\n",
" stride=(1, 1, 1),\n",
" bias=False,\n",
")\n",
"# Set identity weights\n",
"conv1x1_up.weight.data = einops.rearrange(\n",
" conv_up.weight.data, \n",
" 'o i t (h ph) (w pw) -> (i ph pw) o t h w',\n",
" ph=2,\n",
" pw=2,\n",
")\n",
"\n",
"y_proj = conv1x1_up(y)\n",
"print(f\"{y_proj.shape=}\")\n",
"\n",
"y_unpatched = unpatchify_3d(\n",
" y_proj,\n",
" patch_size_time=1,\n",
" patch_size_height=2,\n",
" patch_size_width=2,\n",
")\n",
"print(f\"{y_unpatched.shape=}\")\n",
"assert y_unpatched.shape == x.shape\n",
"\n",
"print(f\"{conv1x1_up.weight.data.shape=}\")\n",
"total_params_unpatch = sum(p.numel() for p in conv1x1_up.parameters())\n",
"print(f\"Total parameters in conv1x1_up: {total_params_unpatch:,}\")\n"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "0f51c0af",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.abs(y_up - y_unpatched).max()=tensor(0., grad_fn=<MaxBackward1>)\n"
]
}
],
"source": [
"assert y_up.shape == y_unpatched.shape == x.shape\n",
"print(f\"{torch.abs(y_up - y_unpatched).max()=}\")\n",
"assert torch.allclose(y_up, y_unpatched, atol=1e-6)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "anam-audio-to-latent",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.11"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment