Created
July 10, 2020 06:26
-
-
Save Japanuspus/fa7d352ef0eec9d482956bec668bfaac to your computer and use it in GitHub Desktop.
Heterogeneous dispatch with numba can be fast
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "# Numba heterogeneous dispatch\n", | |
| "\n", | |
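| "We have many curve objects (>1000) from a few classes (<10) and want to call the same method (`interp`) on every object, dispatching on the class of each object.\n", | |
| "\n", | |
| "The cells below first time a homogeneous list, then a mixed list that is split by class so that each class is evaluated in one jitted loop." | |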
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 26, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import numba\n", | |
| "import math\n", | |
| "import numpy as np\n", | |
| "from numba.experimental import jitclass\n", | |
| "from numba import typed" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 21, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "\n", | |
| "@jitclass([\n", | |
| " ('gridsize', numba.float64),\n", | |
| " ('values', numba.float64[:]),\n", | |
| " ('grid0', numba.float64)\n", | |
| "])\n", | |
| "class Interp(object):\n", | |
| " def __init__(self, grid0, gridsize, values):\n", | |
| " self.gridsize = gridsize\n", | |
| " self.grid0 = grid0\n", | |
| " self.values = values\n", | |
| " \n", | |
| " def interp(self, x):\n", | |
| " idx_float =(x-self.grid0)/self.gridsize\n", | |
| " idx = math.floor(idx_float)\n", | |
| " w2 = idx_float-idx\n", | |
| " return self.values[idx]*(1-w2) + self.values[idx+1]*w2\n", | |
| " \n", | |
| " \n", | |
| "a = Interp(10., 2., np.linspace(0, 5, 11))\n", | |
| "assert a.interp(11.) == 0.25" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 38, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "[3.175, 4.175, 5.175]" | |
| ] | |
| }, | |
| "execution_count": 38, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "@numba.njit\n", | |
| "def interpolate_many(interpolators, x):\n", | |
| " return [it.interp(x) for it in interpolators]\n", | |
| " \n", | |
| "def interpolate_many_slow(interpolators, x):\n", | |
| " return [it.interp(x) for it in interpolators]\n", | |
| "\n", | |
| "interpolators = typed.List(Interp(10., 2., np.linspace(0+k, 5+k, 11)) for k in range(3))\n", | |
| "interpolate_many(interpolators, 0.7)\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 39, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "50 µs ± 365 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n", | |
| "2.03 ms ± 15 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# Rebuild the typed list with 1000 interpolators, then time the jitted loop against the plain-Python one\n", | |
| "interpolators = typed.List(Interp(10., 2., np.linspace(0+k, 5+k, 11)) for k in range(1000))\n", | |
| "%timeit interpolate_many(interpolators, 0.7)\n", | |
| "\n", | |
| "%timeit interpolate_many_slow(interpolators, 0.7)\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 31, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "@jitclass([\n", | |
| " ('a', numba.float64),\n", | |
| " ('b', numba.float64)\n", | |
| "])\n", | |
| "class Line(object):\n", | |
| " def __init__(self, b, a):\n", | |
| " self.a = a\n", | |
| " self.b = b\n", | |
| "\n", | |
| " def interp(self, x):\n", | |
| " return x*self.a+self.b\n", | |
| " \n", | |
| "b = Line(10., 2.)\n", | |
| "assert b.interp(3.) == 16." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 71, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "694 µs ± 3.93 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n", | |
| "80.1 µs ± 422 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# presplit mixed:\n", | |
| "mixed = []\n", | |
| "for k in range(500):\n", | |
| " mixed.append(Interp(10., 2., np.linspace(0+k, 5+k, 11)))\n", | |
| " mixed.append(Line(float(k), 2.))\n", | |
| " \n", | |
| "def split_interps_by_type(interpolators):\n", | |
| " interpolators=np.asarray(interpolators)\n", | |
| " u, indices = np.unique(np.asarray([str(type(i)) for i in interpolators]), return_inverse=True)\n", | |
| " # res = typed.List() # typed.List will not work as tuple types are different for each interpolator type\n", | |
| " res = []\n", | |
| " for type_idx,_ in enumerate(u):\n", | |
| " idx = indices == type_idx\n", | |
| " ints = typed.List(interpolators[idx])\n", | |
| " res.append((idx, ints))\n", | |
| " return res\n", | |
| "\n", | |
| "def interpolate_split(split_interps, x):\n", | |
| " res = np.zeros(split_interps[0][0].shape)\n", | |
| " for idx, interps in split_interps:\n", | |
| " res[idx]=interpolate_many(interps, x)\n", | |
| " return res\n", | |
| "\n", | |
| "# Baseline: call .interp on each object of the mixed list from the interpreter\n", | |
| "%timeit [interp.interp(2.) for interp in mixed]\n", | |
| "\n", | |
| "# Group once by class (not timed), then time the jitted per-class evaluation\n", | |
| "split_interps = split_interps_by_type(mixed)\n", | |
| "%timeit interpolate_split(split_interps, 2.)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| } | |
| ], | |
| "metadata": { | |
| "jupytext": { | |
| "formats": "ipynb,py:percent" | |
| }, | |
| "kernelspec": { | |
| "display_name": "Python 3", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "name": "" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 4 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment