Skip to content

Instantly share code, notes, and snippets.

@Japanuspus
Created July 10, 2020 06:26
Show Gist options
  • Select an option

  • Save Japanuspus/fa7d352ef0eec9d482956bec668bfaac to your computer and use it in GitHub Desktop.

Select an option

Save Japanuspus/fa7d352ef0eec9d482956bec668bfaac to your computer and use it in GitHub Desktop.
Heterogeneous dispatch with numba can be fast
Display the source blob
Display the rendered blob
Raw
{
"cells": [
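{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Numba heterogeneous dispatch\n",
"\n",
"We have many curve objects (more than 1000) drawn from a few classes (fewer than 10) and want to call the same method, `interp(x)`, on every object, dispatching on each object's class.\n",
"\n",
"A numba `typed.List` holds elements of a single type, so a mixed list cannot be handed to one `@njit` loop. The last section therefore groups the objects by class and makes one jitted call per group; in the recorded run this took about 80 µs for 1000 mixed objects, versus about 694 µs for a plain Python loop over the same objects."
]
},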
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"import numba\n",
"import math\n",
"import numpy as np\n",
"from numba.experimental import jitclass\n",
"from numba import typed"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"@jitclass([\n",
"    ('gridsize', numba.float64),\n",
"    ('values', numba.float64[:]),\n",
"    ('grid0', numba.float64)\n",
"])\n",
"class Interp(object):\n",
"    \"\"\"Piecewise-linear interpolant on a uniform grid.\n",
"\n",
"    Sample k of `values` sits at x = grid0 + k*gridsize.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, grid0, gridsize, values):\n",
"        self.gridsize = gridsize\n",
"        self.grid0 = grid0\n",
"        self.values = values\n",
"\n",
"    def interp(self, x):\n",
"        \"\"\"Linearly interpolate `values` at `x` (needs at least two samples).\n",
"\n",
"        Points outside [grid0, grid0 + (n-1)*gridsize] are linearly extrapolated\n",
"        from the nearest edge segment. The segment index is clamped so we never\n",
"        index outside `values`: numba does not bounds-check by default, so an\n",
"        unclamped index wraps around for x < grid0 and reads past the end of the\n",
"        array at or above the last grid point.\n",
"        \"\"\"\n",
"        idx_float = (x - self.grid0) / self.gridsize\n",
"        last_segment = self.values.shape[0] - 2\n",
"        idx = min(max(math.floor(idx_float), 0), last_segment)\n",
"        w2 = idx_float - idx  # weight of the right-hand sample; outside [0, 1] when extrapolating\n",
"        return self.values[idx] * (1 - w2) + self.values[idx + 1] * w2\n",
"\n",
"\n",
"a = Interp(10., 2., np.linspace(0, 5, 11))\n",
"assert a.interp(11.) == 0.25\n",
"assert a.interp(30.) == 5.0  # last grid point: no read past the end of values\n",
"assert a.interp(8.) == -0.5  # below the grid: extrapolated from the first segment"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[0.25, 1.25, 2.25]"
]
},
"execution_count": 38,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"@numba.njit\n",
"def interpolate_many(interpolators, x):\n",
"    \"\"\"Return [it.interp(x) for each it], looping in compiled code.\n",
"\n",
"    `interpolators` must be a typed.List of a single jitclass type.\n",
"    \"\"\"\n",
"    return [it.interp(x) for it in interpolators]\n",
"\n",
"\n",
"def interpolate_many_slow(interpolators, x):\n",
"    \"\"\"Same result as `interpolate_many`, but looping in Python (timing baseline).\"\"\"\n",
"    return [it.interp(x) for it in interpolators]\n",
"\n",
"\n",
"interpolators = typed.List(Interp(10., 2., np.linspace(0+k, 5+k, 11)) for k in range(3))\n",
"# Query inside the grid [10, 30] shared by all interpolators; a point below\n",
"# grid0 = 10 would index `values` with negative (wrapped-around) indices.\n",
"interpolate_many(interpolators, 11.0)"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"50 µs ± 365 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n",
"2.03 ms ± 15 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
]
}
],
"source": [
"# 1000 homogeneous Interp objects, queried at x = 11.0 (inside every grid [10, 30]).\n",
"# A distinct name keeps the 3-element `interpolators` from the previous cell intact.\n",
"many_interpolators = typed.List(Interp(10., 2., np.linspace(0+k, 5+k, 11)) for k in range(1000))\n",
"%timeit interpolate_many(many_interpolators, 11.0)\n",
"\n",
"%timeit interpolate_many_slow(many_interpolators, 11.0)"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
"@jitclass([\n",
"    ('a', numba.float64),  # slope\n",
"    ('b', numba.float64),  # intercept\n",
"])\n",
"class Line:\n",
"    \"\"\"Straight line y = a*x + b, exposing the same `interp(x)` method as `Interp`.\n",
"\n",
"    Note the constructor argument order: intercept `b` first, then slope `a`.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, b, a):\n",
"        self.b = b\n",
"        self.a = a\n",
"\n",
"    def interp(self, x):\n",
"        \"\"\"Evaluate the line at `x`.\"\"\"\n",
"        return self.a * x + self.b\n",
"\n",
"\n",
"b = Line(10., 2.)\n",
"assert b.interp(3.) == 16."
]
},
{
"cell_type": "code",
"execution_count": 71,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"694 µs ± 3.93 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n",
"80.1 µs ± 422 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n"
]
}
],
"source": [
"# Mixed population: 500 Interp and 500 Line objects, interleaved.\n",
"mixed = []\n",
"for k in range(500):\n",
"    mixed.append(Interp(10., 2., np.linspace(0+k, 5+k, 11)))\n",
"    mixed.append(Line(float(k), 2.))\n",
"\n",
"\n",
"def split_interps_by_type(interpolators):\n",
"    \"\"\"Group a heterogeneous sequence of interpolators by their exact class.\n",
"\n",
"    Returns a list of (mask, typed_list) pairs, one per class: `mask` is a boolean\n",
"    array over the input positions and `typed_list` holds that class's objects in\n",
"    input order. The result is a plain list because each pair has a different\n",
"    tuple type, which a single typed.List cannot hold.\n",
"    \"\"\"\n",
"    kinds = [type(obj) for obj in interpolators]\n",
"    groups = []\n",
"    # Key on the class object itself rather than str(type(...)): distinct classes\n",
"    # with the same module and name share a repr and would be merged.\n",
"    for cls in dict.fromkeys(kinds):\n",
"        mask = np.array([kind is cls for kind in kinds], dtype=bool)\n",
"        members = typed.List([obj for obj, kind in zip(interpolators, kinds) if kind is cls])\n",
"        groups.append((mask, members))\n",
"    return groups\n",
"\n",
"\n",
"def interpolate_split(split_interps, x):\n",
"    \"\"\"Evaluate every interpolator at `x` with one jitted call per class group.\n",
"\n",
"    Takes the output of `split_interps_by_type` and returns a float array with the\n",
"    results in the original input order (an empty array for empty input).\n",
"    \"\"\"\n",
"    if not split_interps:\n",
"        return np.zeros(0)\n",
"    res = np.zeros(split_interps[0][0].shape)\n",
"    for mask, interps in split_interps:\n",
"        res[mask] = interpolate_many(interps, x)\n",
"    return res\n",
"\n",
"\n",
"x_query = 11.0  # inside every Interp grid [10, 30]; Line accepts any x\n",
"%timeit [interp.interp(x_query) for interp in mixed]\n",
"\n",
"split_interps = split_interps_by_type(mixed)\n",
"assert np.allclose(interpolate_split(split_interps, x_query), [interp.interp(x_query) for interp in mixed])\n",
"%timeit interpolate_split(split_interps, x_query)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"jupytext": {
"formats": "ipynb,py:percent"
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": ""
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment