{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Populating the interactive namespace from numpy and matplotlib\n"
]
}
],
"source": [
"%pylab inline"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import random\n",
"from numba import jit\n",
"\n",
"@jit(nopython=True)\n",
"def monte_carlo_pi(nsamples):\n",
"    acc = 0\n",
"    for i in range(nsamples):\n",
"        x = random.random()\n",
"        y = random.random()\n",
"        if (x ** 2 + y ** 2) < 1.0:\n",
"            acc += 1\n",
"    return 4.0 * acc / nsamples\n",
"\n",
"def monte_carlo_pi_no_numba(nsamples):\n",
"    acc = 0\n",
"    for i in range(nsamples):\n",
"        x = random.random()\n",
"        y = random.random()\n",
"        if (x ** 2 + y ** 2) < 1.0:\n",
"            acc += 1\n",
"    return 4.0 * acc / nsamples"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 4.25 ms, sys: 136 µs, total: 4.39 ms\n",
"Wall time: 4.25 ms\n"
]
},
{
"data": {
"text/plain": [
"3.1436"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%time monte_carlo_pi_no_numba(10000)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 421 ms, sys: 587 ms, total: 1.01 s\n",
"Wall time: 265 ms\n",
"CPU times: user 82 µs, sys: 67 µs, total: 149 µs\n",
"Wall time: 153 µs\n"
]
},
{
"data": {
"text/plain": [
"3.124"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%time monte_carlo_pi(10000)\n",
"%time monte_carlo_pi(10000)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Most important imports"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"from numba import jit, njit, types, vectorize"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### First, let's do things incorrectly"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"def original_function(input_list):\n",
"    output_list = []\n",
"    for item in input_list:\n",
"        if item % 2 == 0:\n",
"            output_list.append(2)\n",
"        else:\n",
"            output_list.append('1')\n",
"    return output_list\n",
"\n",
"test_list = list(range(100000))"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[2, '1', 2, '1', 2, '1', 2, '1', 2, '1']"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"original_function(test_list)[0:10]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can use `@jit` as a decorator on top of the function, but a decorator is just a function that takes a function and returns a new one, sooo we can also call it directly..."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"jitted_function = jit()(original_function)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"CPUDispatcher(<function original_function at 0x7fa3ec1d7c20>)"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"jitted_function"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"<ipython-input-6-af244bf1c758>:1: NumbaWarning: \n",
"Compilation is falling back to object mode WITH looplifting enabled because Function \"original_function\" failed type inference due to: Invalid use of BoundFunction(list.append for list(int64)) with parameters (Literal[str](1))\n",
" * parameterized\n",
"[1] During: resolving callee type: BoundFunction(list.append for list(int64))\n",
"[2] During: typing of call at <ipython-input-6-af244bf1c758> (7)\n",
"\n",
"\n",
"File \"<ipython-input-6-af244bf1c758>\", line 7:\n",
"def original_function(input_list):\n",
" <source elided>\n",
" else:\n",
" output_list.append('1')\n",
" ^\n",
"\n",
" def original_function(input_list):\n",
"<ipython-input-6-af244bf1c758>:1: NumbaWarning: \n",
"Compilation is falling back to object mode WITHOUT looplifting enabled because Function \"original_function\" failed type inference due to: cannot determine Numba type of <class 'numba.dispatcher.LiftedLoop'>\n",
"\n",
"File \"<ipython-input-6-af244bf1c758>\", line 3:\n",
"def original_function(input_list):\n",
" <source elided>\n",
" output_list = []\n",
" for item in input_list:\n",
" ^\n",
"\n",
" def original_function(input_list):\n",
"/home/jari/.virtualenvs/cv/lib/python3.7/site-packages/numba/compiler.py:742: NumbaWarning: Function \"original_function\" was compiled in object mode without forceobj=True, but has lifted loops.\n",
"\n",
"File \"<ipython-input-6-af244bf1c758>\", line 1:\n",
"def original_function(input_list):\n",
"^\n",
"\n",
" self.func_ir.loc))\n",
"/home/jari/.virtualenvs/cv/lib/python3.7/site-packages/numba/compiler.py:751: NumbaDeprecationWarning: \n",
"Fall-back from the nopython compilation path to the object mode compilation path has been detected, this is deprecated behaviour.\n",
"\n",
"For more information visit http://numba.pydata.org/numba-doc/latest/reference/deprecation.html#deprecation-of-object-mode-fall-back-behaviour-when-using-jit\n",
"\n",
"File \"<ipython-input-6-af244bf1c758>\", line 1:\n",
"def original_function(input_list):\n",
"^\n",
"\n",
" warnings.warn(errors.NumbaDeprecationWarning(msg, self.func_ir.loc))\n",
"<ipython-input-6-af244bf1c758>:1: NumbaWarning: \n",
"Compilation is falling back to object mode WITHOUT looplifting enabled because Function \"original_function\" failed type inference due to: non-precise type pyobject\n",
"[1] During: typing of argument at <ipython-input-6-af244bf1c758> (3)\n",
"\n",
"File \"<ipython-input-6-af244bf1c758>\", line 3:\n",
"def original_function(input_list):\n",
" <source elided>\n",
" output_list = []\n",
" for item in input_list:\n",
" ^\n",
"\n",
" def original_function(input_list):\n",
"/home/jari/.virtualenvs/cv/lib/python3.7/site-packages/numba/compiler.py:742: NumbaWarning: Function \"original_function\" was compiled in object mode without forceobj=True.\n",
"\n",
"File \"<ipython-input-6-af244bf1c758>\", line 3:\n",
"def original_function(input_list):\n",
" <source elided>\n",
" output_list = []\n",
" for item in input_list:\n",
" ^\n",
"\n",
" self.func_ir.loc))\n",
"/home/jari/.virtualenvs/cv/lib/python3.7/site-packages/numba/compiler.py:751: NumbaDeprecationWarning: \n",
"Fall-back from the nopython compilation path to the object mode compilation path has been detected, this is deprecated behaviour.\n",
"\n",
"For more information visit http://numba.pydata.org/numba-doc/latest/reference/deprecation.html#deprecation-of-object-mode-fall-back-behaviour-when-using-jit\n",
"\n",
"File \"<ipython-input-6-af244bf1c758>\", line 3:\n",
"def original_function(input_list):\n",
" <source elided>\n",
" output_list = []\n",
" for item in input_list:\n",
" ^\n",
"\n",
" warnings.warn(errors.NumbaDeprecationWarning(msg, self.func_ir.loc))\n"
]
},
{
"data": {
"text/plain": [
"[2, '1', 2, '1', 2, '1', 2, '1', 2, '1']"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"jitted_function(test_list)[0:10]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Uhhhhhhhhhh.... ???\n",
"\n",
"\n",
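"So Numba needs every element of a list to have the same type. Because `original_function` mixes `int` and `str`, `@jit` falls back to _object mode_ (see the warnings above): the code still works, but it confers no speed benefit."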
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 10.1 ms, sys: 1.37 ms, total: 11.5 ms\n",
"Wall time: 11 ms\n"
]
}
],
"source": [
"%time _ = original_function(test_list)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 26.1 ms, sys: 364 µs, total: 26.5 ms\n",
"Wall time: 26.2 ms\n"
]
}
],
"source": [
"%time _ = jitted_function(test_list)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In fact it is _slower_!!!"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
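"This fall-back exists to make Numba user friendly, but it does more harm than good. Avoid it at all costs by using `@jit(nopython=True)` or its shorthand `@njit`, which raise an error instead of silently falling back. (As the warnings above note, the fall-back is deprecated; Numba 0.59 and later removed it and made `nopython=True` the default for `@jit`.)"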
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"njitted_function = njit()(original_function)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"ename": "TypingError",
"evalue": "Failed in nopython mode pipeline (step: nopython frontend)\nInvalid use of BoundFunction(list.append for list(int64)) with parameters (Literal[str](1))\n * parameterized\n[1] During: resolving callee type: BoundFunction(list.append for list(int64))\n[2] During: typing of call at <ipython-input-6-af244bf1c758> (7)\n\n\nFile \"<ipython-input-6-af244bf1c758>\", line 7:\ndef original_function(input_list):\n <source elided>\n else:\n output_list.append('1')\n ^\n\nThis is not usually a problem with Numba itself but instead often caused by\nthe use of unsupported features or an issue in resolving types.\n\nTo see Python/NumPy features supported by the latest release of Numba visit:\nhttp://numba.pydata.org/numba-doc/latest/reference/pysupported.html\nand\nhttp://numba.pydata.org/numba-doc/latest/reference/numpysupported.html\n\nFor more information about typing errors and how to debug them visit:\nhttp://numba.pydata.org/numba-doc/latest/user/troubleshoot.html#my-code-doesn-t-compile\n\nIf you think your code should work with Numba, please report the error message\nand traceback, along with a minimal reproducer at:\nhttps://github.com/numba/numba/issues/new\n",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mTypingError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-14-7ad5bc480a39>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mnjitted_function\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtest_list\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;36m5\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m~/.virtualenvs/cv/lib/python3.7/site-packages/numba/dispatcher.py\u001b[0m in \u001b[0;36m_compile_for_args\u001b[0;34m(self, *args, **kws)\u001b[0m\n\u001b[1;32m 374\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpatch_message\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 375\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 376\u001b[0;31m \u001b[0merror_rewrite\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'typing'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 377\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0merrors\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mUnsupportedError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 378\u001b[0m \u001b[0;31m# Something unsupported is present in the user code, add help info\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/.virtualenvs/cv/lib/python3.7/site-packages/numba/dispatcher.py\u001b[0m in \u001b[0;36merror_rewrite\u001b[0;34m(e, issue_type)\u001b[0m\n\u001b[1;32m 341\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 342\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 343\u001b[0;31m \u001b[0mreraise\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 344\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 345\u001b[0m \u001b[0margtypes\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/.virtualenvs/cv/lib/python3.7/site-packages/numba/six.py\u001b[0m in \u001b[0;36mreraise\u001b[0;34m(tp, value, tb)\u001b[0m\n\u001b[1;32m 656\u001b[0m \u001b[0mvalue\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 657\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__traceback__\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mtb\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 658\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwith_traceback\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 659\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 660\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mTypingError\u001b[0m: Failed in nopython mode pipeline (step: nopython frontend)\nInvalid use of BoundFunction(list.append for list(int64)) with parameters (Literal[str](1))\n * parameterized\n[1] During: resolving callee type: BoundFunction(list.append for list(int64))\n[2] During: typing of call at <ipython-input-6-af244bf1c758> (7)\n\n\nFile \"<ipython-input-6-af244bf1c758>\", line 7:\ndef original_function(input_list):\n <source elided>\n else:\n output_list.append('1')\n ^\n\nThis is not usually a problem with Numba itself but instead often caused by\nthe use of unsupported features or an issue in resolving types.\n\nTo see Python/NumPy features supported by the latest release of Numba visit:\nhttp://numba.pydata.org/numba-doc/latest/reference/pysupported.html\nand\nhttp://numba.pydata.org/numba-doc/latest/reference/numpysupported.html\n\nFor more information about typing errors and how to debug them visit:\nhttp://numba.pydata.org/numba-doc/latest/user/troubleshoot.html#my-code-doesn-t-compile\n\nIf you think your code should work with Numba, please report the error message\nand traceback, along with a minimal reproducer at:\nhttps://github.com/numba/numba/issues/new\n"
]
}
],
"source": [
"njitted_function(test_list)[0:5]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Now** the compilation fails outright. Notice that compilation happens at _call_ time, not when `njit()` is applied. Because no types were specified, the compiler has to see an example of the input data before it can generate code. The same function could be compiled into one variant for floats and another for ints; both would work, and the compiler has no way of knowing in advance which one you'll pass."
]
},
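{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you already know the argument types, you can pass an explicit signature so that Numba compiles when the decorator is applied rather than on the first call. The cell below is a minimal sketch of both behaviors. The function names are made up for illustration, and it relies on the dispatcher's `signatures` attribute, which lists the specializations compiled so far."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from numba import njit\n",
"\n",
"# Lazy: no signature given, so nothing is compiled until the first call\n",
"@njit\n",
"def double_lazy(arr):\n",
"    return arr * 2\n",
"\n",
"print(double_lazy.signatures)  # empty: nothing compiled yet\n",
"double_lazy(np.arange(3))      # compiles an integer-array specialization\n",
"double_lazy(np.arange(3.0))    # a new argument type triggers a second compilation\n",
"print(double_lazy.signatures)  # now holds two signatures\n",
"\n",
"# Eager: the explicit signature 'int64(int64)' is compiled right here\n",
"@njit('int64(int64)')\n",
"def add_one(x):\n",
"    return x + 1\n",
"\n",
"print(add_one.signatures)  # one signature, before any call\n",
"print(add_one(41))"
]
},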
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's make a new function, a sane one this time ... well ... relatively."
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"def sane_function(input_list):\n",
"    output_list = []\n",
"    for item in input_list:\n",
"        if item % 2 == 0:\n",
"            output_list.append(2)\n",
"        else:\n",
"            output_list.append(1)\n",
"    return output_list\n",
"\n",
"test_list = list(range(100000))"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 11.6 ms, sys: 0 ns, total: 11.6 ms\n",
"Wall time: 11.1 ms\n"
]
},
{
"data": {
"text/plain": [
"[2, 1, 2, 1, 2]"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%time sane_function(test_list)[0:5]"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"njitted_sane_function = njit()(sane_function)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/jari/.virtualenvs/cv/lib/python3.7/site-packages/numba/ir_utils.py:1959: NumbaPendingDeprecationWarning: \n",
"Encountered the use of a type that is scheduled for deprecation: type 'reflected list' found for argument 'input_list' of function 'sane_function'.\n",
"\n",
"For more information visit http://numba.pydata.org/numba-doc/latest/reference/deprecation.html#deprecation-of-reflection-for-list-and-set-types\n",
"\n",
"File \"<ipython-input-15-9a7e18fa2d25>\", line 1:\n",
"def sane_function(input_list):\n",
"^\n",
"\n",
" warnings.warn(NumbaPendingDeprecationWarning(msg, loc=loc))\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 407 ms, sys: 0 ns, total: 407 ms\n",
"Wall time: 405 ms\n"
]
},
{
"data": {
"text/plain": [
"[2, 1, 2, 1, 2]"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%time njitted_sane_function(test_list)[0:5]"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 228 ms, sys: 178 µs, total: 228 ms\n",
"Wall time: 226 ms\n"
]
},
{
"data": {
"text/plain": [
"[2, 1, 2, 1, 2]"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%time njitted_sane_function(test_list)[0:5]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Oh boy, more warnings and _even slower_ code...\n",
"\n",
"Come on, this was supposed to speed my code up...\n",
"\n",
"Long story short: it's not a good idea to pass a plain Python list to Numba. Such a _reflected list_ has to be converted and type-checked element by element on every call, and any changes have to be copied back, which is slow. For now, use NumPy arrays instead. You can read more about it here (https://numba.pydata.org/numba-doc/dev/reference/pysupported.html#list)."
]
},
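{
"cell_type": "markdown",
"metadata": {},
"source": [
"The documentation page linked above also describes `numba.typed.List`, a list type that can be passed into compiled code without reflection. Here is a minimal sketch. It assumes a Numba release that ships `numba.typed.List` with an untyped `List()` constructor, where the element type is inferred from the first `append`. `count_even` is a made-up example function."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from numba import njit\n",
"from numba.typed import List\n",
"\n",
"@njit\n",
"def count_even(items):\n",
"    # A typed List is passed by reference: no per-call conversion of a Python list\n",
"    total = 0\n",
"    for item in items:\n",
"        if item % 2 == 0:\n",
"            total += 1\n",
"    return total\n",
"\n",
"# Build the typed List once; its element type (int64) is fixed by the first append\n",
"typed_items = List()\n",
"for x in range(10):\n",
"    typed_items.append(x)\n",
"\n",
"print(count_even(typed_items))  # 5 (0, 2, 4, 6, 8)"
]
},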
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"test_list = np.arange(100000)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 130 ms, sys: 362 µs, total: 130 ms\n",
"Wall time: 129 ms\n"
]
},
{
"data": {
"text/plain": [
"[2, 1, 2, 1, 2]"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%time njitted_sane_function(test_list)[0:5]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
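"Ahhh ... finally some speedup over the reflected list. (This first call with an array still includes compiling a new specialization for the new argument type, so it isn't yet a fair comparison with pure Python.)\n",
"\n",
"Now this is all extremely basic and we're not doing any real work. Before we move on to something more practical, let's briefly discuss `@vectorize`."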
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"@vectorize(nopython=True)\n",
"def non_list_function(item):\n",
"    if item % 2 == 0:\n",
"        return 2\n",
"    else:\n",
"        return 1"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This lets us write a function that operates on a single element and then call it on a whole array (here `test_list`, which is now a NumPy array)!"
]
},
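{
"cell_type": "markdown",
"metadata": {},
"source": [
"Because `@vectorize` builds a NumPy ufunc, the element-wise function also broadcasts over arrays of any shape and over scalars. Passing a list of explicit signatures compiles those loops up front, just as an explicit signature does for `@jit`. The cell below is a minimal sketch with a made-up `parity_code` function. It assumes 64-bit ints and floats are the only input types we care about."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from numba import vectorize\n",
"\n",
"# Explicit signatures: the loops are compiled here, and the result is a NumPy ufunc\n",
"@vectorize(['int64(int64)', 'float64(float64)'])\n",
"def parity_code(item):\n",
"    if item % 2 == 0:\n",
"        return 2\n",
"    else:\n",
"        return 1\n",
"\n",
"grid = np.arange(6).reshape(2, 3)\n",
"print(parity_code(grid))            # applied element-wise, output has the same 2x3 shape\n",
"print(parity_code(np.arange(3.0)))  # float64 input selects the float64 loop\n",
"print(parity_code(7))               # a scalar works too"
]
},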
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 56.2 ms, sys: 6.13 ms, total: 62.3 ms\n",
"Wall time: 60.8 ms\n"
]
},
{
"data": {
"text/plain": [
"array([2, 1, 2, ..., 1, 2, 1])"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%time non_list_function(test_list)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 304 µs, sys: 134 µs, total: 438 µs\n",
"Wall time: 241 µs\n"
]
},
{
"data": {
"text/plain": [
"array([2, 1, 2, ..., 1, 2, 1])"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%time non_list_function(test_list)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Oh ...\n",
"\n",
"So what's going on? In the first call the function is actually being compiled, so it takes much longer to run. In the second call we finally see an extreme speedup. Part of the reason is that the ufunc machinery pre-allocates a properly sized output array instead of growing a list to an unknown size, as our earlier functions did. We can apply the same fix to the original function by allocating an output array first."
]
},
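{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before applying that fix, a quick aside: because the first call includes compilation, timing a single call mixes compile time with run time. The cell below is a minimal sketch of measuring the two separately with `time.perf_counter`. `sum_of_squares` is a made-up example function, and the numbers will vary by machine."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import time\n",
"import numpy as np\n",
"from numba import njit\n",
"\n",
"@njit\n",
"def sum_of_squares(values):\n",
"    # Plain loop; Numba compiles it to machine code on the first call\n",
"    total = 0.0\n",
"    for v in values:\n",
"        total += v * v\n",
"    return total\n",
"\n",
"data = np.random.rand(1_000_000)\n",
"\n",
"start = time.perf_counter()\n",
"sum_of_squares(data)  # first call: type inference + compilation + execution\n",
"first = time.perf_counter() - start\n",
"\n",
"start = time.perf_counter()\n",
"sum_of_squares(data)  # second call: reuses the compiled specialization\n",
"second = time.perf_counter() - start\n",
"\n",
"print(f'first call: {first * 1e3:.1f} ms, second call: {second * 1e3:.3f} ms')"
]
},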
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"@njit\n",
"def allocated_func(input_list):\n",
"    output_list = np.zeros_like(input_list)\n",
"    for ii, item in enumerate(input_list):\n",
"        if item % 2 == 0:\n",
"            output_list[ii] = 2\n",
"        else:\n",
"            output_list[ii] = 1\n",
"    return output_list"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 131 ms, sys: 3.9 ms, total: 134 ms\n",
"Wall time: 133 ms\n"
]
},
{
"data": {
"text/plain": [
"array([2, 1, 2, ..., 1, 2, 1])"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%time allocated_func(test_list)"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 314 µs, sys: 132 µs, total: 446 µs\n",
"Wall time: 451 µs\n"
]
},
{
"data": {
"text/plain": [
"array([2, 1, 2, ..., 1, 2, 1])"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%time allocated_func(test_list)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Again, the first run is slow because it includes compilation, but the second run is extremely fast. OK, now that these stumbling blocks are out of the way, let's explore a few more things, write some more complex functions, and make them fast!\n",
"\n",
"Allow me to introduce: the spring-mass-damper system 😲"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"data": {