Created
June 26, 2025 14:14
-
-
Save adrn/a4881385a18b93b433ba80344b10ecee to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "code", | |
| "execution_count": 57, | |
| "id": "7c132282", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import jax\n", | |
| "import jax.numpy as jnp\n", | |
| "import numba\n", | |
| "import numpy as np\n", | |
| "import array_api_compat\n", | |
| "import torch\n", | |
| "import dask.array as da" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 66, | |
| "id": "0dd124da", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "def custom_func(A, x):\n", | |
| " xp = array_api_compat.array_namespace(A, x)\n", | |
| " return xp.matmul(A, x) + xp.sum(A, axis=1)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 67, | |
| "id": "5ef90bae", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "numba_custom_func = numba.jit(custom_func)\n", | |
| "jax_custom_func = jax.jit(custom_func)\n", | |
| "# torch_custom_func = torch.jit.script(custom_func)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 70, | |
| "id": "fc8752c2", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "rng = np.random.default_rng(42)\n", | |
| "\n", | |
| "A = rng.normal(size=(16, 8))\n", | |
| "x = rng.normal(size=(8,))\n", | |
| "\n", | |
| "all_arrs = {\n", | |
| " \"numpy\": {\"A\": A, \"x\": x},\n", | |
| " \"jax\": {\"A\": jnp.array(A), \"x\": jnp.array(x)},\n", | |
| " \"torch\": {\"A\": torch.tensor(A), \"x\": torch.tensor(x)},\n", | |
| " \"dask\": {\"A\": da.from_array(A), \"x\": da.from_array(x)},\n", | |
| "}" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 71, | |
| "id": "e30d9592", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "for name, arrs in all_arrs.items():\n", | |
| " custom_func(**arrs)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 72, | |
| "id": "bd6e2191", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "Array([-4.1212697 , 0.95292354, -0.7877971 , 0.01498628, -0.7554294 ,\n", | |
| " -0.23877025, -2.2117355 , -3.9803839 , 2.1482267 , -1.3329437 ,\n", | |
| " -1.3176045 , 0.71187663, 0.583696 , -0.9810766 , -0.4296934 ,\n", | |
| " 0.23938954], dtype=float32)" | |
| ] | |
| }, | |
| "execution_count": 72, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "jax_custom_func(**all_arrs[\"jax\"])" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 73, | |
| "id": "b21099cc", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "ename": "TypingError", | |
| "evalue": "Failed in nopython mode pipeline (step: nopython frontend)\nUnknown attribute 'array_namespace' of type Module(<module 'array_api_compat' from '/Users/aprice-whelan/projects/scratch/.venv/lib/python3.12/site-packages/array_api_compat/__init__.py'>)\n\nFile \"../../../../var/folders/67/2zgxpmyd2z183j4k6r33nf740000gr/T/ipykernel_19851/3285249787.py\", line 2:\n<source missing, REPL/exec in use?>\n\nDuring: typing of get attribute at /var/folders/67/2zgxpmyd2z183j4k6r33nf740000gr/T/ipykernel_19851/3285249787.py (2)\n\nFile \"../../../../var/folders/67/2zgxpmyd2z183j4k6r33nf740000gr/T/ipykernel_19851/3285249787.py\", line 2:\n<source missing, REPL/exec in use?>\n\nDuring: Pass nopython_type_inference", | |
| "output_type": "error", | |
| "traceback": [ | |
| "\u001b[31m---------------------------------------------------------------------------\u001b[39m", | |
| "\u001b[31mTypingError\u001b[39m Traceback (most recent call last)", | |
| "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[73]\u001b[39m\u001b[32m, line 1\u001b[39m\n\u001b[32m----> \u001b[39m\u001b[32m1\u001b[39m \u001b[43mnumba_custom_func\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mall_arrs\u001b[49m\u001b[43m[\u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mnumpy\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n", | |
| "\u001b[36mFile \u001b[39m\u001b[32m~/projects/scratch/.venv/lib/python3.12/site-packages/numba/core/dispatcher.py:424\u001b[39m, in \u001b[36m_DispatcherBase._compile_for_args\u001b[39m\u001b[34m(self, *args, **kws)\u001b[39m\n\u001b[32m 420\u001b[39m msg = (\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mstr\u001b[39m(e).rstrip()\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m \u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[33mThis error may have been caused \u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 421\u001b[39m \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mby the following argument(s):\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;132;01m{\u001b[39;00margs_str\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[33m\"\u001b[39m)\n\u001b[32m 422\u001b[39m e.patch_message(msg)\n\u001b[32m--> \u001b[39m\u001b[32m424\u001b[39m \u001b[43merror_rewrite\u001b[49m\u001b[43m(\u001b[49m\u001b[43me\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[33;43m'\u001b[39;49m\u001b[33;43mtyping\u001b[39;49m\u001b[33;43m'\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[32m 425\u001b[39m \u001b[38;5;28;01mexcept\u001b[39;00m errors.UnsupportedError \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[32m 426\u001b[39m \u001b[38;5;66;03m# Something unsupported is present in the user code, add help info\u001b[39;00m\n\u001b[32m 427\u001b[39m error_rewrite(e, \u001b[33m'\u001b[39m\u001b[33munsupported_error\u001b[39m\u001b[33m'\u001b[39m)\n", | |
| "\u001b[36mFile \u001b[39m\u001b[32m~/projects/scratch/.venv/lib/python3.12/site-packages/numba/core/dispatcher.py:365\u001b[39m, in \u001b[36m_DispatcherBase._compile_for_args.<locals>.error_rewrite\u001b[39m\u001b[34m(e, issue_type)\u001b[39m\n\u001b[32m 363\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m e\n\u001b[32m 364\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m365\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m e.with_traceback(\u001b[38;5;28;01mNone\u001b[39;00m)\n", | |
| "\u001b[31mTypingError\u001b[39m: Failed in nopython mode pipeline (step: nopython frontend)\nUnknown attribute 'array_namespace' of type Module(<module 'array_api_compat' from '/Users/aprice-whelan/projects/scratch/.venv/lib/python3.12/site-packages/array_api_compat/__init__.py'>)\n\nFile \"../../../../var/folders/67/2zgxpmyd2z183j4k6r33nf740000gr/T/ipykernel_19851/3285249787.py\", line 2:\n<source missing, REPL/exec in use?>\n\nDuring: typing of get attribute at /var/folders/67/2zgxpmyd2z183j4k6r33nf740000gr/T/ipykernel_19851/3285249787.py (2)\n\nFile \"../../../../var/folders/67/2zgxpmyd2z183j4k6r33nf740000gr/T/ipykernel_19851/3285249787.py\", line 2:\n<source missing, REPL/exec in use?>\n\nDuring: Pass nopython_type_inference" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "numba_custom_func(**all_arrs[\"numpy\"])" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "f3089209", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": ".venv", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.12.10" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 5 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment