Created
November 18, 2025 14:51
-
-
Save ogrisel/95033cb862acd8f3dac7d4b7c78216d9 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "code", | |
| "execution_count": 12, | |
| "id": "4b5599b2", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import os\n", | |
| "\n", | |
| "# Environment variables are set before the library imports below so they are\n", | |
| "# already in place when torch is imported. PYTORCH_ENABLE_MPS_FALLBACK=1 lets\n", | |
| "# operations the MPS backend does not implement fall back to the CPU instead\n", | |
| "# of raising.\n", | |
| "os.environ[\"PYTORCH_ENABLE_MPS_FALLBACK\"] = \"1\"\n", | |
| "# SCIPY_ARRAY_API=1 opts SciPy into its array API support. SciPy is not\n", | |
| "# imported in the cells shown here -- NOTE(review): confirm it is still needed.\n", | |
| "os.environ[\"SCIPY_ARRAY_API\"] = \"1\"\n", | |
| "\n", | |
| "import torch\n", | |
| "import numpy as np\n", | |
| "# get_namespace returns the array-API namespace of its argument, which lets\n", | |
| "# nested_where run the same code on NumPy arrays and torch tensors.\n", | |
| "from array_api_compat import get_namespace\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 13, | |
| "id": "ef949524", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "The memory_profiler extension is already loaded. To reload it, use:\n", | |
| " %reload_ext memory_profiler\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# Provides the %%memit cell magic used by the benchmark cells below.\n", | |
| "%load_ext memory_profiler" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 14, | |
| "id": "6bb78031", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def nested_where(data):\n", | |
| "    \"\"\"Piecewise elementwise transform written as three levels of ``where``.\n", | |
| "\n", | |
| "    Benchmark workload: every branch expression and comparison mask is\n", | |
| "    evaluated on the full array before ``where`` selects from it, so an eager\n", | |
| "    backend materializes each one as a full-size temporary. The cells below\n", | |
| "    compare eager NumPy / torch-MPS execution against ``torch.compile``.\n", | |
| "\n", | |
| "    Parameters\n", | |
| "    ----------\n", | |
| "    data : array\n", | |
| "        Array accepted by ``array_api_compat.get_namespace`` (a NumPy array\n", | |
| "        or a torch tensor in this notebook).\n", | |
| "\n", | |
| "    Returns\n", | |
| "    -------\n", | |
| "    array\n", | |
| "        Array from the same namespace as ``data``, mapping each element x to:\n", | |
| "        3x if x > 20; 2x if 5 < x <= 20; x + 10 if 0 < x <= 5;\n", | |
| "        x - 10 if -10 < x <= 0; x if -20 <= x <= -10; x / 2 if x < -20.\n", | |
| "    \"\"\"\n", | |
| "    xp = get_namespace(data)\n", | |
| "    return xp.where(\n", | |
| "        data > 0,\n", | |
| "        xp.where(data > 5, xp.where(data > 20, data * 3, data * 2), data + 10),\n", | |
| "        # This test must be able to fail when data <= 0 (a condition such as\n", | |
| "        # data < 10 would always hold here and leave the inner where dead).\n", | |
| "        # All branches are still computed, so the cost per call is unchanged.\n", | |
| "        xp.where(data > -10, data - 10, xp.where(data < -20, data / 2, data)),\n", | |
| "    )\n", | |
| "\n", | |
| "# Compiled variant; torch.compile compiles lazily on the first call.\n", | |
| "nested_where_compiled = torch.compile(nested_where)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 15, | |
| "id": "913d8ec8", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "N_SAMPLES = 100_000_000  # about 381 MiB per float32 copy\n", | |
| "\n", | |
| "rng = np.random.default_rng(42)\n", | |
| "# Draw float32 values directly: rng.uniform only produces float64, so\n", | |
| "# uniform(...).astype(...) would transiently allocate an extra ~763 MiB array.\n", | |
| "data_np = rng.random(N_SAMPLES, dtype=np.float32)\n", | |
| "data_np *= 100  # in place, no temporaries: values now in [0, 100)\n", | |
| "data_np -= 50   # in place: values now approximately uniform on [-50, 50)\n", | |
| "# Device copy for the torch runs; data_np is kept for the NumPy baseline.\n", | |
| "data_mps = torch.from_numpy(data_np).to(\"mps\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 16, | |
| "id": "f5447762", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "33" | |
| ] | |
| }, | |
| "execution_count": 16, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "import gc\n", | |
| "\n", | |
| "# Presumably run to release unreachable objects left by earlier cells before\n", | |
| "# the %%memit measurements below -- TODO confirm intent.\n", | |
| "gc.collect()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 17, | |
| "id": "28a95751", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "peak memory: 1002.81 MiB, increment: 3.61 MiB\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "%%memit\n", | |
| "# Eager run on MPS. .cpu() on a 5-element slice copies only 5 values to the\n", | |
| "# host but must wait for the MPS work that produces them.\n", | |
| "# NOTE(review): %%memit samples host-process memory, which may not fully\n", | |
| "# reflect MPS buffers; consider cross-checking with\n", | |
| "# torch.mps.current_allocated_memory().\n", | |
| "nested_where(data_mps)[:5].cpu()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 18, | |
| "id": "70c7af1e", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "181 ms ± 2.74 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "%%timeit\n", | |
| "# Eager MPS timing; the .cpu() copy of the slice waits for the device work,\n", | |
| "# so the whole computation is included in each timed call.\n", | |
| "nested_where(data_mps)[:5].cpu()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 19, | |
| "id": "205ea02f", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "peak memory: 315.00 MiB, increment: 4.06 MiB\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "%%memit\n", | |
| "# First call of nested_where_compiled in this notebook, so this measurement\n", | |
| "# also covers torch.compile's one-off compilation.\n", | |
| "# NOTE(review): %%memit samples host-process memory, which may not fully\n", | |
| "# reflect MPS buffers; consider cross-checking with\n", | |
| "# torch.mps.current_allocated_memory().\n", | |
| "nested_where_compiled(data_mps)[:5].cpu()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 20, | |
| "id": "a11d85c8", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "180 ms ± 1.01 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "%%timeit\n", | |
| "# Steady-state timing of the compiled function, assuming the %%memit cell\n", | |
| "# above already triggered compilation.\n", | |
| "nested_where_compiled(data_mps)[:5].cpu()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 21, | |
| "id": "e77bccb3", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "peak memory: 2476.77 MiB, increment: 2165.73 MiB\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "%%memit\n", | |
| "# NumPy (CPU) baseline through the same array-API code path: every\n", | |
| "# intermediate and the full result are computed before [:5] slices it.\n", | |
| "nested_where(data_np)[:5]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 22, | |
| "id": "8875834a", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "707 ms ± 30.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "%%timeit\n", | |
| "# NumPy (CPU) baseline timing; no device synchronization is involved.\n", | |
| "nested_where(data_np)[:5]" | |
| ] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "dev", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.13.7" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 5 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment