Skip to content

Instantly share code, notes, and snippets.

@ogrisel
Created November 18, 2025 14:51
Show Gist options
  • Select an option

  • Save ogrisel/95033cb862acd8f3dac7d4b7c78216d9 to your computer and use it in GitHub Desktop.

Select an option

Save ogrisel/95033cb862acd8f3dac7d4b7c78216d9 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 12,
"id": "4b5599b2",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"os.environ[\"PYTORCH_ENABLE_MPS_FALLBACK\"] = \"1\"\n",
"os.environ[\"SCIPY_ARRAY_API\"] = \"1\"\n",
"\n",
"import torch\n",
"import numpy as np\n",
"from array_api_compat import get_namespace\n"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "ef949524",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The memory_profiler extension is already loaded. To reload it, use:\n",
" %reload_ext memory_profiler\n"
]
}
],
"source": [
"%load_ext memory_profiler"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "6bb78031",
"metadata": {},
"outputs": [],
"source": [
"def nested_where(data):\n",
" xp = get_namespace(data)\n",
" return xp.where(\n",
" data > 0,\n",
" xp.where(data > 5, xp.where(data > 20, data * 3, data * 2), data + 10),\n",
" xp.where(data < 10, data - 10, xp.where(data < -20, data / 2, data)),\n",
" )\n",
"\n",
"nested_where_compiled = torch.compile(nested_where)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "913d8ec8",
"metadata": {},
"outputs": [],
"source": [
"# Benchmark input size: 1e8 float32 values, i.e. about 381 MiB per array copy.\n",
"N_SAMPLES = 100_000_000\n",
"SEED = 42\n",
"\n",
"rng = np.random.default_rng(SEED)\n",
"# Generator.uniform has no dtype argument: values are drawn as float64 and then\n",
"# downcast, so this line briefly needs ~3x the final array size in host memory.\n",
"data_np = rng.uniform(-50, 50, size=N_SAMPLES).astype(\"float32\")\n",
"# Same values copied to the MPS device for the torch benchmarks.\n",
"data_mps = torch.from_numpy(data_np).to(\"mps\")"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "f5447762",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"33"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import gc\n",
"\n",
"# Warm up both MPS variants before any measurement. torch.compile compiles on\n",
"# the first call, so without this the %%memit cell for nested_where_compiled\n",
"# would otherwise be its first call and include compilation; the eager path\n",
"# is warmed too so both are measured in the same state (e.g. any first-use\n",
"# MPS setup is excluded).\n",
"nested_where(data_mps)[:5].cpu()\n",
"nested_where_compiled(data_mps)[:5].cpu()\n",
"\n",
"# Release unreferenced Python objects so the measurements start from a clean state.\n",
"gc.collect()"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "28a95751",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"peak memory: 1002.81 MiB, increment: 3.61 MiB\n"
]
}
],
"source": [
"%%memit\n",
"nested_where(data_mps)[:5].cpu()"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "70c7af1e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"181 ms ± 2.74 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
]
}
],
"source": [
"%%timeit\n",
"nested_where(data_mps)[:5].cpu()"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "205ea02f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"peak memory: 315.00 MiB, increment: 4.06 MiB\n"
]
}
],
"source": [
"%%memit\n",
"nested_where_compiled(data_mps)[:5].cpu()"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "a11d85c8",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"180 ms ± 1.01 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
]
}
],
"source": [
"%%timeit\n",
"nested_where_compiled(data_mps)[:5].cpu()"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "e77bccb3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"peak memory: 2476.77 MiB, increment: 2165.73 MiB\n"
]
}
],
"source": [
"%%memit\n",
"nested_where(data_np)[:5]"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "8875834a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"707 ms ± 30.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
"source": [
"%%timeit\n",
"nested_where(data_np)[:5]"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "dev",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment