Skip to content

Instantly share code, notes, and snippets.

@adrn
Created June 26, 2025 14:14
Show Gist options
  • Select an option

  • Save adrn/a4881385a18b93b433ba80344b10ecee to your computer and use it in GitHub Desktop.

Select an option

Save adrn/a4881385a18b93b433ba80344b10ecee to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 57,
"id": "7c132282",
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"import numba\n",
"import numpy as np\n",
"import array_api_compat\n",
"import torch\n",
"import dask.array as da"
]
},
{
"cell_type": "code",
"execution_count": 66,
"id": "0dd124da",
"metadata": {},
"outputs": [],
"source": [
"def custom_func(A, x):\n",
"    \"\"\"Return ``A @ x`` plus the row sums of ``A`` for any Array-API backend.\n",
"\n",
"    The array namespace (numpy, jax, torch, dask, ...) is looked up from the\n",
"    inputs via ``array_api_compat``, so one implementation serves every backend.\n",
"    \"\"\"\n",
"    xp = array_api_compat.array_namespace(A, x)\n",
"    product = xp.matmul(A, x)\n",
"    row_sums = xp.sum(A, axis=1)\n",
"    return product + row_sums"
]
},
{
"cell_type": "code",
"execution_count": 67,
"id": "5ef90bae",
"metadata": {},
"outputs": [],
"source": [
"# JIT-compile the backend-agnostic function with numba and JAX.\n",
"# numba.jit compiles lazily, so unsupported code only fails on the first call:\n",
"# nopython mode cannot resolve `array_api_compat.array_namespace`.\n",
"numba_custom_func = numba.jit(custom_func)\n",
"jax_custom_func = jax.jit(custom_func)"
]
},
{
"cell_type": "code",
"execution_count": 70,
"id": "fc8752c2",
"metadata": {},
"outputs": [],
"source": [
"rng = np.random.default_rng(seed=42)\n",
"\n",
"A = rng.normal(size=(16, 8))\n",
"x = rng.normal(size=8)\n",
"\n",
"# The same A and x in each backend's native array type. Note: jnp.array\n",
"# yields float32 here (JAX's default without jax_enable_x64; see the dtype of\n",
"# the JAX output), so the jax copy is less precise than the float64 inputs.\n",
"all_arrs = {\n",
"    name: {\"A\": convert(A), \"x\": convert(x)}\n",
"    for name, convert in [\n",
"        (\"numpy\", lambda arr: arr),\n",
"        (\"jax\", jnp.array),\n",
"        (\"torch\", torch.tensor),\n",
"        (\"dask\", da.from_array),\n",
"    ]\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 71,
"id": "e30d9592",
"metadata": {},
"outputs": [],
"source": [
"# Run custom_func on every backend and compare against the NumPy result.\n",
"# np.asarray() forces lazy dask graphs to compute and converts jax/torch\n",
"# outputs; the tolerance allows for JAX's float32 copies of A and x.\n",
"expected = custom_func(**all_arrs[\"numpy\"])\n",
"for name, arrs in all_arrs.items():\n",
"    result = np.asarray(custom_func(**arrs))\n",
"    np.testing.assert_allclose(result, expected, rtol=1e-5, atol=1e-5, err_msg=name)"
]
},
{
"cell_type": "code",
"execution_count": 72,
"id": "bd6e2191",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Array([-4.1212697 , 0.95292354, -0.7877971 , 0.01498628, -0.7554294 ,\n",
" -0.23877025, -2.2117355 , -3.9803839 , 2.1482267 , -1.3329437 ,\n",
" -1.3176045 , 0.71187663, 0.583696 , -0.9810766 , -0.4296934 ,\n",
" 0.23938954], dtype=float32)"
]
},
"execution_count": 72,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# JIT-compiled call on the JAX copies of A and x.\n",
"jax_custom_func(A=all_arrs[\"jax\"][\"A\"], x=all_arrs[\"jax\"][\"x\"])"
]
},
{
"cell_type": "code",
"execution_count": 73,
"id": "b21099cc",
"metadata": {},
"outputs": [
{
"ename": "TypingError",
"evalue": "Failed in nopython mode pipeline (step: nopython frontend)\nUnknown attribute 'array_namespace' of type Module(<module 'array_api_compat' from '/Users/aprice-whelan/projects/scratch/.venv/lib/python3.12/site-packages/array_api_compat/__init__.py'>)\n\nFile \"../../../../var/folders/67/2zgxpmyd2z183j4k6r33nf740000gr/T/ipykernel_19851/3285249787.py\", line 2:\n<source missing, REPL/exec in use?>\n\nDuring: typing of get attribute at /var/folders/67/2zgxpmyd2z183j4k6r33nf740000gr/T/ipykernel_19851/3285249787.py (2)\n\nFile \"../../../../var/folders/67/2zgxpmyd2z183j4k6r33nf740000gr/T/ipykernel_19851/3285249787.py\", line 2:\n<source missing, REPL/exec in use?>\n\nDuring: Pass nopython_type_inference",
"output_type": "error",
"traceback": [
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
"\u001b[31mTypingError\u001b[39m Traceback (most recent call last)",
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[73]\u001b[39m\u001b[32m, line 1\u001b[39m\n\u001b[32m----> \u001b[39m\u001b[32m1\u001b[39m \u001b[43mnumba_custom_func\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mall_arrs\u001b[49m\u001b[43m[\u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mnumpy\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/projects/scratch/.venv/lib/python3.12/site-packages/numba/core/dispatcher.py:424\u001b[39m, in \u001b[36m_DispatcherBase._compile_for_args\u001b[39m\u001b[34m(self, *args, **kws)\u001b[39m\n\u001b[32m 420\u001b[39m msg = (\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mstr\u001b[39m(e).rstrip()\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m \u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[33mThis error may have been caused \u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 421\u001b[39m \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mby the following argument(s):\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;132;01m{\u001b[39;00margs_str\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[33m\"\u001b[39m)\n\u001b[32m 422\u001b[39m e.patch_message(msg)\n\u001b[32m--> \u001b[39m\u001b[32m424\u001b[39m \u001b[43merror_rewrite\u001b[49m\u001b[43m(\u001b[49m\u001b[43me\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[33;43m'\u001b[39;49m\u001b[33;43mtyping\u001b[39;49m\u001b[33;43m'\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[32m 425\u001b[39m \u001b[38;5;28;01mexcept\u001b[39;00m errors.UnsupportedError \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[32m 426\u001b[39m \u001b[38;5;66;03m# Something unsupported is present in the user code, add help info\u001b[39;00m\n\u001b[32m 427\u001b[39m error_rewrite(e, \u001b[33m'\u001b[39m\u001b[33munsupported_error\u001b[39m\u001b[33m'\u001b[39m)\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/projects/scratch/.venv/lib/python3.12/site-packages/numba/core/dispatcher.py:365\u001b[39m, in \u001b[36m_DispatcherBase._compile_for_args.<locals>.error_rewrite\u001b[39m\u001b[34m(e, issue_type)\u001b[39m\n\u001b[32m 363\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m e\n\u001b[32m 364\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m365\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m e.with_traceback(\u001b[38;5;28;01mNone\u001b[39;00m)\n",
"\u001b[31mTypingError\u001b[39m: Failed in nopython mode pipeline (step: nopython frontend)\nUnknown attribute 'array_namespace' of type Module(<module 'array_api_compat' from '/Users/aprice-whelan/projects/scratch/.venv/lib/python3.12/site-packages/array_api_compat/__init__.py'>)\n\nFile \"../../../../var/folders/67/2zgxpmyd2z183j4k6r33nf740000gr/T/ipykernel_19851/3285249787.py\", line 2:\n<source missing, REPL/exec in use?>\n\nDuring: typing of get attribute at /var/folders/67/2zgxpmyd2z183j4k6r33nf740000gr/T/ipykernel_19851/3285249787.py (2)\n\nFile \"../../../../var/folders/67/2zgxpmyd2z183j4k6r33nf740000gr/T/ipykernel_19851/3285249787.py\", line 2:\n<source missing, REPL/exec in use?>\n\nDuring: Pass nopython_type_inference"
]
}
],
"source": [
"# Expected failure (the point of this cell): numba's nopython mode cannot\n",
"# type the `array_api_compat.array_namespace` lookup -- see the TypingError.\n",
"numba_custom_func(**all_arrs[\"numpy\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f3089209",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment