Created
May 18, 2023 17:27
-
-
Save shoyer/57107c2d2023d708683cdb5087810179 to your computer and use it in GitHub Desktop.
JAX einsum primitive .ipynb
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "nbformat": 4, | |
| "nbformat_minor": 0, | |
| "metadata": { | |
| "colab": { | |
| "provenance": [], | |
| "authorship_tag": "ABX9TyO0R2VDwSaxyTeVPyjiTCCE", | |
| "include_colab_link": true | |
| }, | |
| "kernelspec": { | |
| "name": "python3", | |
| "display_name": "Python 3" | |
| } | |
| }, | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "view-in-github", | |
| "colab_type": "text" | |
| }, | |
| "source": [ | |
| "<a href=\"https://colab.research.google.com/gist/shoyer/57107c2d2023d708683cdb5087810179/jax-einsum-primitive.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "effqC7pgqK0j", | |
| "outputId": "e28b151a-8483-4621-eac2-42906b12a629", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 136 | |
| } | |
| }, | |
| "source": [ | |
| "# Pin the versions this notebook was run with (see the pip log below). The code\n", | |
| "# imports internals such as `jax.api`, `jax.util.partial` and\n", | |
| "# `xla.lower_fun_initial_style`, which are not stable public API and may be\n", | |
| "# missing from newer releases, so upgrading with `-U` can break every cell.\n", | |
| "%pip install jax==0.1.69 jaxlib==0.1.47" | |
| ], | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "text": [ | |
| "Requirement already up-to-date: jax in /usr/local/lib/python3.6/dist-packages (0.1.69)\n", | |
| "Requirement already up-to-date: jaxlib in /usr/local/lib/python3.6/dist-packages (0.1.47)\n", | |
| "Requirement already satisfied, skipping upgrade: absl-py in /usr/local/lib/python3.6/dist-packages (from jax) (0.9.0)\n", | |
| "Requirement already satisfied, skipping upgrade: numpy>=1.12 in /usr/local/lib/python3.6/dist-packages (from jax) (1.18.5)\n", | |
| "Requirement already satisfied, skipping upgrade: opt-einsum in /usr/local/lib/python3.6/dist-packages (from jax) (3.2.1)\n", | |
| "Requirement already satisfied, skipping upgrade: scipy in /usr/local/lib/python3.6/dist-packages (from jaxlib) (1.4.1)\n", | |
| "Requirement already satisfied, skipping upgrade: six in /usr/local/lib/python3.6/dist-packages (from absl-py->jax) (1.12.0)\n" | |
| ], | |
| "name": "stdout" | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "E987RUewqIch" | |
| }, | |
| "source": [ | |
| "# Copyright 2023 Google LLC\n", | |
| "#\n", | |
| "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", | |
| "# you may not use this file except in compliance with the License.\n", | |
| "# You may obtain a copy of the License at\n", | |
| "#\n", | |
| "# https://www.apache.org/licenses/LICENSE-2.0\n", | |
| "#\n", | |
| "# Unless required by applicable law or agreed to in writing, software\n", | |
| "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", | |
| "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", | |
| "# See the License for the specific language governing permissions and\n", | |
| "# limitations under the License.\n", | |
| "\n", | |
| "import collections\n", | |
| "import functools\n", | |
| "import itertools\n", | |
| "import operator\n", | |
| "import threading\n", | |
| "\n", | |
| "import numpy as onp\n", | |
| "\n", | |
| "from jax import api\n", | |
| "from jax import core\n", | |
| "from jax import dtypes\n", | |
| "from jax.lax import lax\n", | |
| "from jax import linear_util as lu\n", | |
| "from jax.abstract_arrays import ShapedArray, raise_to_shaped\n", | |
| "from jax.api_util import flatten_fun_nokwargs, apply_flat_fun_nokwargs\n", | |
| "from jax.interpreters import ad\n", | |
| "from jax.interpreters import partial_eval as pe\n", | |
| "from jax.interpreters import xla\n", | |
| "from jax.interpreters import batching\n", | |
| "from jax.interpreters import masking\n", | |
| "from jax.lib import xla_bridge as xb\n", | |
| "from jax.lib import xla_client\n", | |
| "from jax.util import (partial, unzip2, safe_map, safe_zip, split_list,\n", | |
| " split_dict, cache, extend_name_stack)\n", | |
| "from jax.tree_util import (tree_flatten, tree_unflatten, treedef_is_leaf,\n", | |
| " treedef_children, treedef_tuple)\n", | |
| "from jax import ad_util\n", | |
| "import jax.numpy as jnp\n", | |
| "import jax.test_util as jtu\n", | |
| "\n", | |
| "# Shadow the builtin map/zip with jax.util's safe_map/safe_zip from here on.\n", | |
| "map, zip = safe_map, safe_zip\n", | |
| "\n", | |
| "\n", | |
| "def einsum(*operands):\n", | |
| "  \"\"\"Evaluate an einsum expression by binding the custom `einsum_p` primitive.\n", | |
| "\n", | |
| "  Accepts the numpy-style call forms handled by `_parse_einsum_input` (a\n", | |
| "  subscripts string followed by operands, or interleaved operand/sublist\n", | |
| "  pairs) and returns the single result array.\n", | |
| "  \"\"\"\n", | |
| "  inputs, output, arrays = _parse_einsum_input(operands)\n", | |
| "  (result,) = einsum_p.bind(\n", | |
| "      *arrays, input_strings=inputs.split(','), output_string=output)\n", | |
| "  return result\n", | |
| "\n", | |
| "def _einsum_impl(*operands, input_strings, output_string):\n", | |
| "  \"\"\"Reference implementation: delegate to `jnp.einsum`, result wrapped in a list.\"\"\"\n", | |
| "  expression = '{}->{}'.format(','.join(input_strings), output_string)\n", | |
| "  return [jnp.einsum(expression, *operands)]\n", | |
| "\n", | |
| "def sum_tangents(tangents):\n", | |
| "  \"\"\"Left-fold an iterable of tangents together with `ad.add_tangents`.\"\"\"\n", | |
| "  fold = functools.partial(functools.reduce, ad.add_tangents)\n", | |
| "  return fold(tangents)\n", | |
| "\n", | |
| "def _einsum_jvp(primals, tangents, *, input_strings, output_string):\n", | |
| "  \"\"\"JVP rule for `einsum_p`.\n", | |
| "\n", | |
| "  einsum is linear in each operand separately, so (product rule) the output\n", | |
| "  tangent is the sum, over operands with a non-zero tangent, of the same einsum\n", | |
| "  with that operand replaced by its tangent.\n", | |
| "\n", | |
| "  Returns `([primal_out], [tangent_out])` since `einsum_p` has multiple_results.\n", | |
| "  \"\"\"\n", | |
| "  subscripts = ','.join(input_strings) + '->' + output_string\n", | |
| "  this_einsum = functools.partial(einsum, subscripts)\n", | |
| "  out_primal = this_einsum(*primals)\n", | |
| "  terms = []\n", | |
| "  for index, tangent in enumerate(tangents):\n", | |
| "    if type(tangent) is ad.Zero:\n", | |
| "      continue  # symbolic zero: contributes nothing\n", | |
| "    operands = list(primals)\n", | |
| "    operands[index] = tangent\n", | |
| "    terms.append(this_einsum(*operands))\n", | |
| "  if terms:\n", | |
| "    out_tangent = sum_tangents(terms)\n", | |
| "  else:\n", | |
| "    # All input tangents are symbolic zeros. `sum_tangents` (functools.reduce\n", | |
| "    # with no initial value) would raise on an empty sequence, so return an\n", | |
| "    # explicit zero tangent matching the output instead.\n", | |
| "    out_tangent = jnp.zeros_like(out_primal)\n", | |
| "  return [out_primal], [out_tangent]\n", | |
| "\n", | |
| "def _einsum_transpose_rule(cotangent, *primals, input_strings, output_string):\n", | |
| "  \"\"\"Transpose rule for `einsum_p` with respect to its one linear operand.\n", | |
| "\n", | |
| "  `cotangent` is the list of output cotangents (a single entry, since `einsum`\n", | |
| "  unpacks one result). Exactly one entry of `primals` must be undefined (the\n", | |
| "  linear input); its cotangent is the einsum of the remaining operands and the\n", | |
| "  output cotangent, mapped back to that operand's subscripts. The returned list\n", | |
| "  holds None for every other operand.\n", | |
| "\n", | |
| "  Indices that appear only in the linear operand (absent from the other\n", | |
| "  operands and the output) cannot be produced by an einsum, so they are\n", | |
| "  restored by broadcasting. Repeated indices in the linear operand (diagonals)\n", | |
| "  are not supported.\n", | |
| "  \"\"\"\n", | |
| "  index, = [i for i, p in enumerate(primals) if ad.is_undefined_primal(p)]\n", | |
| "  target = input_strings[index]\n", | |
| "  if len(set(target)) != len(target):\n", | |
| "    raise NotImplementedError(\n", | |
| "        'einsum transpose with repeated indices in the linear operand '\n", | |
| "        f'({target!r}) is not supported')\n", | |
| "  other_strings = list(input_strings[:index]) + list(input_strings[index+1:])\n", | |
| "  other_operands = primals[:index] + primals[index+1:]\n", | |
| "  available = set(''.join(other_strings) + output_string)\n", | |
| "  kept = ''.join(c for c in target if c in available)\n", | |
| "  # Join the full term list (instead of appending ',' + output_string) so a\n", | |
| "  # single-operand einsum does not get a spurious empty leading term.\n", | |
| "  subscripts = ','.join(other_strings + [output_string]) + '->' + kept\n", | |
| "  result = einsum(subscripts, *other_operands, *cotangent)\n", | |
| "  if kept != target:\n", | |
| "    # Re-insert the summed-out axes as size-1 dims, then broadcast them back.\n", | |
| "    shape = primals[index].aval.shape\n", | |
| "    keep_shape = tuple(size if c in available else 1\n", | |
| "                       for c, size in zip(target, shape))\n", | |
| "    result = jnp.broadcast_to(jnp.reshape(result, keep_shape), shape)\n", | |
| "  out = [None] * len(primals)\n", | |
| "  out[index] = result\n", | |
| "  return out\n", | |
| "\n", | |
| "# The primitive itself. It is bound with a list of outputs (`einsum` unpacks\n", | |
| "# the single result), hence multiple_results.\n", | |
| "einsum_p = core.Primitive('einsum')\n", | |
| "einsum_p.multiple_results = True\n", | |
| "\n", | |
| "def generic_abstract_eval(*avals, **params):\n", | |
| "  \"\"\"Infer output avals by abstractly evaluating the reference `_einsum_impl`.\"\"\"\n", | |
| "  return pe.abstract_eval_fun(_einsum_impl, *avals, **params)\n", | |
| "\n", | |
| "# Evaluation: eager impl, abstract evaluation and XLA lowering all reuse\n", | |
| "# `_einsum_impl`.\n", | |
| "einsum_p.def_impl(_einsum_impl)\n", | |
| "einsum_p.def_abstract_eval(generic_abstract_eval)\n", | |
| "xla.initial_style_translations[einsum_p] = xla.lower_fun_initial_style(\n", | |
| "    _einsum_impl)\n", | |
| "\n", | |
| "# Differentiation: forward-mode rule plus the transpose used for reverse mode.\n", | |
| "ad.primitive_jvps[einsum_p] = _einsum_jvp\n", | |
| "ad.primitive_transposes[einsum_p] = _einsum_transpose_rule\n", | |
| "\n", | |
| "# TODO(shoyer): batching rule (should be pretty easy)\n", | |
| "# batching.primitive_batchers[einsum_p] = _einsum_batching_rule\n" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "jqWG8ZJbqfAa" | |
| }, | |
| "source": [ | |
| "#@title define `_parse_einsum_input` (from numpy) { display-mode: \"form\" }\n", | |
| "# Adapted from numpy.core.einsumfunc.\n", | |
| "\n", | |
| "# The 52 single-letter labels einsum accepts: lowercase, then uppercase.\n", | |
| "einsum_symbols = ('abcdefghijklmnopqrstuvwxyz'\n", | |
| "                  'ABCDEFGHIJKLMNOPQRSTUVWXYZ')\n", | |
| "einsum_symbols_set = {*einsum_symbols}\n", | |
| "\n", | |
| "# Operands are converted to JAX arrays while parsing.\n", | |
| "asarray = jnp.asarray\n", | |
| "\n", | |
| "def _parse_einsum_input(operands):\n", | |
| "  \"\"\"\n", | |
| "  A reproduction of einsum c side einsum parsing in python.\n", | |
| "\n", | |
| "  Normalizes both call forms accepted by numpy -- a subscripts string followed\n", | |
| "  by operands, or interleaved ``(operand, sublist, ..., [output_sublist])`` --\n", | |
| "  into explicit subscripts: every ellipsis is replaced by concrete letters and\n", | |
| "  an output subscript is always produced.\n", | |
| "\n", | |
| "  Returns\n", | |
| "  -------\n", | |
| "  input_subscripts : str\n", | |
| "      Parsed comma-separated input subscripts\n", | |
| "  output_subscript : str\n", | |
| "      Parsed output subscript\n", | |
| "  operands : list of array_like\n", | |
| "      The operands to use in the contraction, converted with `asarray`\n", | |
| "\n", | |
| "  Raises\n", | |
| "  ------\n", | |
| "  ValueError\n", | |
| "      On no operands, invalid symbols, a malformed '->' or ellipsis, output\n", | |
| "      symbols missing from the inputs, or a subscript/operand count mismatch.\n", | |
| "  TypeError\n", | |
| "      If a sublist contains something other than ints or Ellipsis.\n", | |
| "\n", | |
| "  Examples\n", | |
| "  --------\n", | |
| "  The operand list is simplified to reduce printing:\n", | |
| "\n", | |
| "  >>> a = onp.random.rand(4, 4)\n", | |
| "  >>> b = onp.random.rand(4, 4, 4)\n", | |
| "  >>> _parse_einsum_input(('...a,...a->...', a, b))\n", | |
| "  ('za,yza', 'yz', [a, b])\n", | |
| "\n", | |
| "  >>> _parse_einsum_input((a, [Ellipsis, 0], b, [Ellipsis, 0]))\n", | |
| "  ('za,yza', 'yz', [a, b])\n", | |
| "  \"\"\"\n", | |
| "\n", | |
| "  if len(operands) == 0:\n", | |
| "    raise ValueError(\"No input operands\")\n", | |
| "\n", | |
| "  if isinstance(operands[0], str):\n", | |
| "    subscripts = operands[0].replace(\" \", \"\")\n", | |
| "    operands = [asarray(v) for v in operands[1:]]\n", | |
| "\n", | |
| "    # Ensure all characters are valid\n", | |
| "    for s in subscripts:\n", | |
| "      if s in '.,->':\n", | |
| "        continue\n", | |
| "      if s not in einsum_symbols:\n", | |
| "        raise ValueError(\"Character %s is not a valid symbol.\" % s)\n", | |
| "\n", | |
| "  else:\n", | |
| "    # Interleaved form: operand, sublist, operand, sublist, ..., [output].\n", | |
| "    tmp_operands = list(operands)\n", | |
| "    operand_list = []\n", | |
| "    subscript_list = []\n", | |
| "    for _ in range(len(operands) // 2):\n", | |
| "      operand_list.append(tmp_operands.pop(0))\n", | |
| "      subscript_list.append(tmp_operands.pop(0))\n", | |
| "\n", | |
| "    output_list = tmp_operands[-1] if len(tmp_operands) else None\n", | |
| "    operands = [asarray(v) for v in operand_list]\n", | |
| "    subscripts = \"\"\n", | |
| "    last = len(subscript_list) - 1\n", | |
| "    for num, sub in enumerate(subscript_list):\n", | |
| "      for s in sub:\n", | |
| "        if s is Ellipsis:\n", | |
| "          subscripts += \"...\"\n", | |
| "        elif isinstance(s, int):\n", | |
| "          subscripts += einsum_symbols[s]\n", | |
| "        else:\n", | |
| "          raise TypeError(\"For this input type lists must contain \"\n", | |
| "                          \"either int or Ellipsis\")\n", | |
| "      if num != last:\n", | |
| "        subscripts += \",\"\n", | |
| "\n", | |
| "    if output_list is not None:\n", | |
| "      subscripts += \"->\"\n", | |
| "      for s in output_list:\n", | |
| "        if s is Ellipsis:\n", | |
| "          subscripts += \"...\"\n", | |
| "        elif isinstance(s, int):\n", | |
| "          subscripts += einsum_symbols[s]\n", | |
| "        else:\n", | |
| "          raise TypeError(\"For this input type lists must contain \"\n", | |
| "                          \"either int or Ellipsis\")\n", | |
| "  # Check for proper \"->\"\n", | |
| "  if (\"-\" in subscripts) or (\">\" in subscripts):\n", | |
| "    invalid = (subscripts.count(\"-\") > 1) or (subscripts.count(\">\") > 1)\n", | |
| "    if invalid or (subscripts.count(\"->\") != 1):\n", | |
| "      raise ValueError(\"Subscripts can only contain one '->'.\")\n", | |
| "\n", | |
| "  # Parse ellipses\n", | |
| "  if \".\" in subscripts:\n", | |
| "    used = subscripts.replace(\".\", \"\").replace(\",\", \"\").replace(\"->\", \"\")\n", | |
| "    # Letters that ellipses may expand to. Sorted so the choice is deterministic:\n", | |
| "    # iteration order of a set of str depends on hash randomization, which would\n", | |
| "    # make the primitive's subscript params (and printed jaxprs) vary by process.\n", | |
| "    unused = sorted(einsum_symbols_set - set(used))\n", | |
| "    ellipse_inds = \"\".join(unused)\n", | |
| "    longest = 0\n", | |
| "\n", | |
| "    if \"->\" in subscripts:\n", | |
| "      input_tmp, output_sub = subscripts.split(\"->\")\n", | |
| "      split_subscripts = input_tmp.split(\",\")\n", | |
| "      out_sub = True\n", | |
| "    else:\n", | |
| "      split_subscripts = subscripts.split(',')\n", | |
| "      out_sub = False\n", | |
| "\n", | |
| "    for num, sub in enumerate(split_subscripts):\n", | |
| "      if \".\" in sub:\n", | |
| "        if (sub.count(\".\") != 3) or (sub.count(\"...\") != 1):\n", | |
| "          raise ValueError(\"Invalid Ellipses.\")\n", | |
| "\n", | |
| "        # Report too few operands with the usual count error rather than\n", | |
| "        # letting operands[num] below fail with an IndexError.\n", | |
| "        if num >= len(operands):\n", | |
| "          raise ValueError(\"Number of einsum subscripts must be equal to the \"\n", | |
| "                           \"number of operands.\")\n", | |
| "\n", | |
| "        # Take into account numerical values\n", | |
| "        if operands[num].shape == ():\n", | |
| "          ellipse_count = 0\n", | |
| "        else:\n", | |
| "          ellipse_count = max(operands[num].ndim, 1)\n", | |
| "          ellipse_count -= (len(sub) - 3)\n", | |
| "\n", | |
| "        if ellipse_count > longest:\n", | |
| "          longest = ellipse_count\n", | |
| "\n", | |
| "        if ellipse_count < 0:\n", | |
| "          raise ValueError(\"Ellipses lengths do not match.\")\n", | |
| "        elif ellipse_count == 0:\n", | |
| "          split_subscripts[num] = sub.replace('...', '')\n", | |
| "        else:\n", | |
| "          rep_inds = ellipse_inds[-ellipse_count:]\n", | |
| "          split_subscripts[num] = sub.replace('...', rep_inds)\n", | |
| "\n", | |
| "    subscripts = \",\".join(split_subscripts)\n", | |
| "    if longest == 0:\n", | |
| "      out_ellipse = \"\"\n", | |
| "    else:\n", | |
| "      out_ellipse = ellipse_inds[-longest:]\n", | |
| "\n", | |
| "    if out_sub:\n", | |
| "      subscripts += \"->\" + output_sub.replace(\"...\", out_ellipse)\n", | |
| "    else:\n", | |
| "      # Special care for outputless ellipses\n", | |
| "      output_subscript = \"\"\n", | |
| "      tmp_subscripts = subscripts.replace(\",\", \"\")\n", | |
| "      for s in sorted(set(tmp_subscripts)):\n", | |
| "        if s not in (einsum_symbols):\n", | |
| "          raise ValueError(\"Character %s is not a valid symbol.\" % s)\n", | |
| "        if tmp_subscripts.count(s) == 1:\n", | |
| "          output_subscript += s\n", | |
| "      normal_inds = ''.join(sorted(set(output_subscript) -\n", | |
| "                                   set(out_ellipse)))\n", | |
| "\n", | |
| "      subscripts += \"->\" + out_ellipse + normal_inds\n", | |
| "\n", | |
| "  # Build output string if does not exist\n", | |
| "  if \"->\" in subscripts:\n", | |
| "    input_subscripts, output_subscript = subscripts.split(\"->\")\n", | |
| "  else:\n", | |
| "    input_subscripts = subscripts\n", | |
| "    # Build output subscripts\n", | |
| "    tmp_subscripts = subscripts.replace(\",\", \"\")\n", | |
| "    output_subscript = \"\"\n", | |
| "    for s in sorted(set(tmp_subscripts)):\n", | |
| "      if s not in einsum_symbols:\n", | |
| "        raise ValueError(\"Character %s is not a valid symbol.\" % s)\n", | |
| "      if tmp_subscripts.count(s) == 1:\n", | |
| "        output_subscript += s\n", | |
| "\n", | |
| "  # Make sure output subscripts are in the input\n", | |
| "  for char in output_subscript:\n", | |
| "    if char not in input_subscripts:\n", | |
| "      raise ValueError(\"Output character %s did not appear in the input\"\n", | |
| "                       % char)\n", | |
| "\n", | |
| "  # Make sure number operands is equivalent to the number of terms\n", | |
| "  if len(input_subscripts.split(',')) != len(operands):\n", | |
| "    raise ValueError(\"Number of einsum subscripts must be equal to the \"\n", | |
| "                     \"number of operands.\")\n", | |
| "\n", | |
| "  return (input_subscripts, output_subscript, operands)\n" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "3VUxAr_HqWlE" | |
| }, | |
| "source": [ | |
| "import jax\n", | |
| "import jax.test_util as jtu" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "SncY3Y9RqhuT", | |
| "outputId": "879f01db-15fc-4bdd-df98-4689e46ddb29", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 119 | |
| } | |
| }, | |
| "source": [ | |
| "# Binding the custom primitive shows up as a single `einsum` equation.\n", | |
| "vector, matrix = jnp.zeros((2,)), jnp.zeros((2, 3))\n", | |
| "jax.make_jaxpr(lambda x, y: einsum('i,ij->ij', x, y))(vector, matrix)" | |
| ], | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "text": [ | |
| "/usr/local/lib/python3.6/dist-packages/jax/lib/xla_bridge.py:127: UserWarning: No GPU/TPU found, falling back to CPU.\n", | |
| " warnings.warn('No GPU/TPU found, falling back to CPU.')\n" | |
| ], | |
| "name": "stderr" | |
| }, | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "{ lambda ; a b.\n", | |
| " let c = einsum[ input_strings=['i', 'ij']\n", | |
| " output_string=ij ] a b\n", | |
| " in (c,) }" | |
| ] | |
| }, | |
| "metadata": { | |
| "tags": [] | |
| }, | |
| "execution_count": 4 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "cYymXcbBqkeL", | |
| "outputId": "d1c24b47-d5b9-41d8-f1cf-dd1d7f12d79e", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 85 | |
| } | |
| }, | |
| "source": [ | |
| "# Three operands are still a single `einsum` equation in the jaxpr.\n", | |
| "shapes = [(2,), (2, 3), (3, 4)]\n", | |
| "jax.make_jaxpr(partial(einsum, 'i,ij,jk->ij'))(*[jnp.zeros(s) for s in shapes])" | |
| ], | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "{ lambda ; a b c.\n", | |
| " let d = einsum[ input_strings=['i', 'ij', 'jk']\n", | |
| " output_string=ij ] a b c\n", | |
| " in (d,) }" | |
| ] | |
| }, | |
| "metadata": { | |
| "tags": [] | |
| }, | |
| "execution_count": 5 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "EARozZRGqlmL" | |
| }, | |
| "source": [ | |
| "def make_einsum_grad(subscripts, einsum_fun=einsum, argnums=0):\n", | |
| "  \"\"\"Return the gradient (w.r.t. `argnums`) of sum(einsum_fun(...) ** 2).\n", | |
| "\n", | |
| "  Squaring and summing reduces the (possibly non-scalar) einsum output to the\n", | |
| "  scalar loss `jax.grad` requires. `einsum_fun` selects the implementation\n", | |
| "  under test (the custom primitive by default, or e.g. `jnp.einsum`).\n", | |
| "  \"\"\"\n", | |
| "  def f(*operands):\n", | |
| "    result = einsum_fun(subscripts, *operands)\n", | |
| "    return jnp.sum(result ** 2)\n", | |
| "  return jax.grad(f, argnums=argnums)" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "T2NGlsvSqm3E", | |
| "outputId": "d1b03be7-fc51-41d5-f545-fc5412b1cd37", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 153 | |
| } | |
| }, | |
| "source": [ | |
| "# The backward pass is one more `einsum` equation, produced by the transpose rule.\n", | |
| "grad_fn = make_einsum_grad('ij,jk->ij')\n", | |
| "jax.make_jaxpr(grad_fn)(jnp.zeros((2, 3)), jnp.zeros((3, 4)))" | |
| ], | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "{ lambda c ; a b.\n", | |
| " let d = einsum[ input_strings=['ij', 'jk']\n", | |
| " output_string=ij ] a b\n", | |
| " e = mul 2.0 d\n", | |
| " f = mul c e\n", | |
| " g = einsum[ input_strings=['jk', 'ij']\n", | |
| " output_string=ij ] b f\n", | |
| " in (g,) }" | |
| ] | |
| }, | |
| "metadata": { | |
| "tags": [] | |
| }, | |
| "execution_count": 10 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "CinG5HQOKEgM" | |
| }, | |
| "source": [ | |
| "import opt_einsum" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "Cb59R0xmKFHE" | |
| }, | |
| "source": [ | |
| "# Record which opt_einsum release is installed (compare with the pip log above).\n", | |
| "opt_einsum.__version__" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "C2HW3evBKehE" | |
| }, | |
| "source": [ | |
| "import collections" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "MhcVxm5gJvKN", | |
| "outputId": "03bf46ac-e2e1-49fb-83fc-e6c24334b3d5", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 153 | |
| } | |
| }, | |
| "source": [ | |
| "operands = 'abc,ad,be,cf,def,dg,eh,fi->ghi'\n", | |
| "# Every index has size 100.\n", | |
| "sizes = collections.defaultdict(lambda: 100)\n", | |
| "input_terms = operands.partition('->')[0].split(',')\n", | |
| "arrays = [jnp.zeros(tuple(sizes[label] for label in term)) for term in input_terms]\n", | |
| "jax.make_jaxpr(make_einsum_grad(operands))(*arrays)" | |
| ], | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "{ lambda i ; a b c d e f g h.\n", | |
| " let j = einsum[ input_strings=['abc', 'ad', 'be', 'cf', 'def', 'dg', 'eh', 'fi']\n", | |
| " output_string=ghi ] a b c d e f g h\n", | |
| " k = mul 2.0 j\n", | |
| " l = mul i k\n", | |
| " m = einsum[ input_strings=['ad', 'be', 'cf', 'def', 'dg', 'eh', 'fi', 'ghi']\n", | |
| " output_string=abc ] b c d e f g h l\n", | |
| " in (m,) }" | |
| ] | |
| }, | |
| "metadata": { | |
| "tags": [] | |
| }, | |
| "execution_count": 19 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "QltFpuSnLCit", | |
| "outputId": "a50ae749-5524-49e6-da5a-70d506ecf090", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 391 | |
| } | |
| }, | |
| "source": [ | |
| "# The contraction the transpose rule emitted in the gradient jaxpr above,\n", | |
| "# evaluated with jnp.einsum to see its dot_general lowering.\n", | |
| "operands = 'ad,be,cf,def,dg,eh,fi,ghi->abc'\n", | |
| "# Build the inputs here (every index of size 100) instead of reusing a `sizes`\n", | |
| "# dict defined by another cell, so this cell runs correctly on its own.\n", | |
| "dim_sizes = collections.defaultdict(lambda: 100)\n", | |
| "arrays = [jnp.zeros(tuple(dim_sizes[k] for k in op)) for op in operands.split('->')[0].split(',')]\n", | |
| "jax.make_jaxpr(partial(jnp.einsum, operands))(*arrays)" | |
| ], | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "{ lambda ; a b c d e f g h.\n", | |
| " let i = xla_call[ backend=None\n", | |
| " call_jaxpr={ lambda ; a b c d e f g h.\n", | |
| " let i = dot_general[ dimension_numbers=(((0,), (1,)), ((), ()))\n", | |
| " precision=None ] h e\n", | |
| " j = dot_general[ dimension_numbers=(((0,), (1,)), ((), ()))\n", | |
| " precision=None ] i f\n", | |
| " k = dot_general[ dimension_numbers=(((0,), (1,)), ((), ()))\n", | |
| " precision=None ] j g\n", | |
| " l = dot_general[ dimension_numbers=(((), ()), ((0, 1, 2), (0, 1, 2)))\n", | |
| " precision=None ] k d\n", | |
| " m = dot_general[ dimension_numbers=(((0,), (1,)), ((), ()))\n", | |
| " precision=None ] l a\n", | |
| " n = dot_general[ dimension_numbers=(((0,), (1,)), ((), ()))\n", | |
| " precision=None ] m b\n", | |
| " o = dot_general[ dimension_numbers=(((0,), (1,)), ((), ()))\n", | |
| " precision=None ] n c\n", | |
| " in (o,) }\n", | |
| " device=None\n", | |
| " donated_invars=(False, False, False, False, False, False, False, False)\n", | |
| " name=_einsum ] a b c d e f g h\n", | |
| " in (i,) }" | |
| ] | |
| }, | |
| "metadata": { | |
| "tags": [] | |
| }, | |
| "execution_count": 25 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "mXEiJoeVK8eN", | |
| "outputId": "ede223bf-9d72-4baf-d598-bab651161ddd", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 887 | |
| } | |
| }, | |
| "source": [ | |
| "operands = 'abc,ad,be,cf,def,dg,eh,fi->ghi'\n", | |
| "sizes = collections.defaultdict(lambda: 100)\n", | |
| "arrays = [jnp.zeros(tuple(sizes[k] for k in op)) for op in operands.split('->')[0].split(',')]\n", | |
| "# Same gradient, but differentiating jnp.einsum instead of the custom primitive.\n", | |
| "# `make_einsum_grad` takes the implementation as `einsum_fun`; an `einsum=`\n", | |
| "# keyword is not a parameter and raises TypeError.\n", | |
| "jax.make_jaxpr(make_einsum_grad(operands, einsum_fun=jnp.einsum))(*arrays)" | |
| ], | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "{ lambda i s ; a b c d e f g h.\n", | |
| " let j k l m n o p q r = xla_call[ backend=None\n", | |
| " call_jaxpr={ lambda ; a b c d e f g h i.\n", | |
| " let j = dot_general[ dimension_numbers=(((0,), (0,)), ((), ()))\n", | |
| " precision=None ] b a\n", | |
| " k = dot_general[ dimension_numbers=(((1,), (0,)), ((), ()))\n", | |
| " precision=None ] j c\n", | |
| " l = dot_general[ dimension_numbers=(((1,), (0,)), ((), ()))\n", | |
| " precision=None ] k d\n", | |
| " m = dot_general[ dimension_numbers=(((), ()), ((0, 1, 2), (0, 1, 2)))\n", | |
| " precision=None ] l e\n", | |
| " n = dot_general[ dimension_numbers=(((0,), (0,)), ((), ()))\n", | |
| " precision=None ] m f\n", | |
| " o = dot_general[ dimension_numbers=(((0,), (0,)), ((), ()))\n", | |
| " precision=None ] n g\n", | |
| " p = dot_general[ dimension_numbers=(((0,), (0,)), ((), ()))\n", | |
| " precision=None ] o h\n", | |
| " in (p, *, b, c, d, e, f, g, h) }\n", | |
| " device=None\n", | |
| " donated_invars=(False, False, False, False, False, False, False, False, False)\n", | |
| " name=jvp(_einsum) ] a b c d e f g h i\n", | |
| " t = mul 2.0 j\n", | |
| " u = mul s t\n", | |
| " v = xla_call[ backend=None\n", | |
| " call_jaxpr={ lambda ; a b c d e f g h i j k l m n o.\n", | |
| " let p = dot_general[ dimension_numbers=(((2,), (1,)), ((), ()))\n", | |
| " precision=None ] o g\n", | |
| " q = transpose[ permutation=(2, 0, 1) ] p\n", | |
| " r = dot_general[ dimension_numbers=(((2,), (1,)), ((), ()))\n", | |
| " precision=None ] q f\n", | |
| " s = transpose[ permutation=(2, 0, 1) ] r\n", | |
| " t = dot_general[ dimension_numbers=(((2,), (1,)), ((), ()))\n", | |
| " precision=None ] s e\n", | |
| " u = transpose[ permutation=(2, 0, 1) ] t\n", | |
| " v = dot_general[ dimension_numbers=(((), ()), ((0, 1, 2), (0, 1, 2)))\n", | |
| " precision=None ] u d\n", | |
| " w = dot_general[ dimension_numbers=(((2,), (1,)), ((), ()))\n", | |
| " precision=None ] v c\n", | |
| " x = transpose[ permutation=(0, 2, 1) ] w\n", | |
| " y = dot_general[ dimension_numbers=(((2,), (1,)), ((), ()))\n", | |
| " precision=None ] x b\n", | |
| " z = transpose[ permutation=(0, 2, 1) ] y\n", | |
| " ba = dot_general[ dimension_numbers=(((0,), (1,)), ((), ()))\n", | |
| " precision=None ] z a\n", | |
| " bb = transpose[ permutation=(2, 0, 1) ] ba\n", | |
| " in (bb,) }\n", | |
| " device=None\n", | |
| " donated_invars=(False, False, False, False, False, False, False, False, False, False, False, False, False, False, False)\n", | |
| " name=transpose(jvp(_einsum)) ] l m n o p q r i i i i i i i u\n", | |
| " in (v,) }" | |
| ] | |
| }, | |
| "metadata": { | |
| "tags": [] | |
| }, | |
| "execution_count": 20 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "T5nLpDnXL7H3", | |
| "outputId": "58c81d37-b3ba-4019-d395-a721d3b08c68", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 153 | |
| } | |
| }, | |
| "source": [ | |
| "operands = 'bdik,acaj,ikab,ajac,ikbd->'\n", | |
| "# A full contraction to a scalar; every index has size 10.\n", | |
| "sizes = collections.defaultdict(lambda: 10)\n", | |
| "input_terms = operands.partition('->')[0].split(',')\n", | |
| "arrays = [jnp.zeros(tuple(sizes[label] for label in term)) for term in input_terms]\n", | |
| "jax.make_jaxpr(make_einsum_grad(operands))(*arrays)\n" | |
| ], | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "{ lambda ; a b c d e.\n", | |
| " let f = einsum[ input_strings=['bdik', 'acaj', 'ikab', 'ajac', 'ikbd']\n", | |
| " output_string= ] a b c d e\n", | |
| " g = mul 2.0 f\n", | |
| " h = mul 1.0 g\n", | |
| " i = einsum[ input_strings=['acaj', 'ikab', 'ajac', 'ikbd', '']\n", | |
| " output_string=bdik ] b c d e h\n", | |
| " in (i,) }" | |
| ] | |
| }, | |
| "metadata": { | |
| "tags": [] | |
| }, | |
| "execution_count": 26 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "ikKb81hVRREY" | |
| }, | |
| "source": [ | |
| "def block_until_ready(tree, *rest, **kwargs):\n", | |
| "  \"\"\"Call `.block_until_ready()` on every leaf of `tree` via `jax.tree_map`.\"\"\"\n", | |
| "  def wait(leaf):\n", | |
| "    return leaf.block_until_ready()\n", | |
| "  return jax.tree_map(wait, tree, *rest, **kwargs)" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "qn1xnP2VXj5H" | |
| }, | |
| "source": [ | |
| "def make_einsum_grad2(subscripts, einsum_fun=einsum, argnums=0):\n", | |
| "  \"\"\"Return the gradient (w.r.t. `argnums`) of einsum_fun(subscripts, *operands).\n", | |
| "\n", | |
| "  There is no squared-sum here, so `subscripts` must produce a scalar output\n", | |
| "  (e.g. end in '->') for `jax.grad` to apply.\n", | |
| "  \"\"\"\n", | |
| "  def f(*operands):\n", | |
| "    result = einsum_fun(subscripts, *operands)\n", | |
| "    return result\n", | |
| "  return jax.grad(f, argnums=argnums)" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "wrMmF7xgQJQV", | |
| "outputId": "cadad34c-65b0-4204-d54a-efab70d977af", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 170 | |
| } | |
| }, | |
| "source": [ | |
| "operands = 'abcde,abfg,cdhi,ghjk,ielm,fjno,klpq,nopqm->'\n", | |
| "dim_size = 8\n", | |
| "argnums = (1, 2, 3, 4, 5, 6, 7)\n", | |
| "sizes = collections.defaultdict(lambda: dim_size)\n", | |
| "input_terms = operands.split('->')[0].split(',')\n", | |
| "arrays = [jnp.zeros(tuple(sizes[label] for label in term)) for term in input_terms]\n", | |
| "\n", | |
| "print(f\"expression: {operands}\")\n", | |
| "print(f\"dim_size: {dim_size}\")\n", | |
| "print(f\"gradient argnums: {argnums}\")\n", | |
| "\n", | |
| "\n", | |
| "def jitted_grad(einsum_impl):\n", | |
| "  \"\"\"Jit the squared-sum gradient using `einsum_impl`; compile it before timing.\"\"\"\n", | |
| "  compiled = jax.jit(make_einsum_grad(operands, einsum_fun=einsum_impl, argnums=argnums))\n", | |
| "  block_until_ready(compiled(*arrays))  # compile\n", | |
| "  return compiled\n", | |
| "\n", | |
| "\n", | |
| "print()\n", | |
| "print(\"einsum primitive\")\n", | |
| "f = jitted_grad(einsum)\n", | |
| "%timeit block_until_ready(f(*arrays))\n", | |
| "\n", | |
| "print()\n", | |
| "print(\"dot_general primitive\")\n", | |
| "f = jitted_grad(jnp.einsum)\n", | |
| "%timeit block_until_ready(f(*arrays))" | |
| ], | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "text": [ | |
| "expression: abcde,abfg,cdhi,ghjk,ielm,fjno,klpq,nopqm->\n", | |
| "dim_size: 8\n", | |
| "gradient argnums: (1, 2, 3, 4, 5, 6, 7)\n", | |
| "\n", | |
| "einsum primitive\n", | |
| "100 loops, best of 3: 4.88 ms per loop\n", | |
| "\n", | |
| "dot_general primitive\n", | |
| "100 loops, best of 3: 3.93 ms per loop\n" | |
| ], | |
| "name": "stdout" | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "icUau75pQV8M", | |
| "outputId": "acfdb690-8fb5-45a7-bbd4-e23ae4ab4f99", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 972 | |
| } | |
| }, | |
| "source": [ | |
| "operands = 'abcde,abfg,cdhi,ghjk,ielm,fjno,klpq->nopqm'\n", | |
| "sizes = collections.defaultdict(lambda: 16)\n", | |
| "arrays = [jnp.zeros(tuple(sizes[k] for k in op)) for op in operands.split('->')[0].split(',')]\n", | |
| "# `make_einsum_grad` takes the implementation as `einsum_fun`; an `einsum=`\n", | |
| "# keyword is not a parameter and raises TypeError.\n", | |
| "f = jax.jit(make_einsum_grad(operands, einsum_fun=jnp.einsum, argnums=(0,)))\n", | |
| "print(jax.make_jaxpr(f)(*arrays))\n", | |
| "block_until_ready(f(*arrays))  # compile\n", | |
| "%timeit block_until_ready(f(*arrays))" | |
| ], | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "text": [ | |
| "{ lambda h i ; a b c d e f g.\n", | |
| " let j = xla_call[ backend=None\n", | |
| " call_jaxpr={ lambda ; h q a b c d e f g.\n", | |
| " let i j k l m n o p = xla_call[ backend=None\n", | |
| " call_jaxpr={ lambda ; a b c d e f g h.\n", | |
| " let i = dot_general[ dimension_numbers=(((0, 1), (0, 1)), ((), ()))\n", | |
| " precision=None ] b a\n", | |
| " j = dot_general[ dimension_numbers=(((2, 3), (0, 1)), ((), ()))\n", | |
| " precision=None ] i c\n", | |
| " k = dot_general[ dimension_numbers=(((1, 3), (0, 1)), ((), ()))\n", | |
| " precision=None ] j d\n", | |
| " l = dot_general[ dimension_numbers=(((1, 2), (1, 0)), ((), ()))\n", | |
| " precision=None ] k e\n", | |
| " m = dot_general[ dimension_numbers=(((0, 1), (0, 1)), ((), ()))\n", | |
| " precision=None ] l f\n", | |
| " n = dot_general[ dimension_numbers=(((0, 1), (0, 1)), ((), ()))\n", | |
| " precision=None ] m g\n", | |
| " o = transpose[ permutation=(1, 2, 3, 4, 0) ] n\n", | |
| " in (o, *, b, c, d, e, f, g) }\n", | |
| " device=None\n", | |
| " donated_invars=(False, False, False, False, False, False, False, False)\n", | |
| " name=jvp(_einsum) ] a b c d e f g h\n", | |
| " r = mul 2.0 i\n", | |
| " s = mul q r\n", | |
| " t = xla_call[ backend=None\n", | |
| " call_jaxpr={ lambda ; a b c d e f g h i j k l m.\n", | |
| " let n = transpose[ permutation=(4, 0, 1, 2, 3) ] m\n", | |
| " o = dot_general[ dimension_numbers=(((3, 4), (2, 3)), ((), ()))\n", | |
| " precision=None ] n f\n", | |
| " p = transpose[ permutation=(3, 4, 0, 1, 2) ] o\n", | |
| " q = dot_general[ dimension_numbers=(((3, 4), (2, 3)), ((), ()))\n", | |
| " precision=None ] p e\n", | |
| " r = transpose[ permutation=(3, 4, 0, 1, 2) ] q\n", | |
| " s = dot_general[ dimension_numbers=(((3, 4), (2, 3)), ((), ()))\n", | |
| " precision=None ] r d\n", | |
| " t = transpose[ permutation=(0, 4, 3, 1, 2) ] s\n", | |
| " u = dot_general[ dimension_numbers=(((3, 4), (2, 3)), ((), ()))\n", | |
| " precision=None ] t c\n", | |
| " v = transpose[ permutation=(0, 3, 1, 4, 2) ] u\n", | |
| " w = dot_general[ dimension_numbers=(((3, 4), (2, 3)), ((), ()))\n", | |
| " precision=None ] v b\n", | |
| " x = transpose[ permutation=(0, 1, 3, 4, 2) ] w\n", | |
| " y = dot_general[ dimension_numbers=(((0, 1), (2, 3)), ((), ()))\n", | |
| " precision=None ] x a\n", | |
| " z = transpose[ permutation=(3, 4, 0, 1, 2) ] y\n", | |
| " in (z,) }\n", | |
| " device=None\n", | |
| " donated_invars=(False, False, False, False, False, False, False, False, False, False, False, False, False)\n", | |
| " name=transpose(jvp(_einsum)) ] k l m n o p h h h h h h s\n", | |
| " in (t,) }\n", | |
| " device=None\n", | |
| " donated_invars=(False, False, False, False, False, False, False, False, False)\n", | |
| " name=f ] h i a b c d e f g\n", | |
| " in (j,) }\n", | |
| "10 loops, best of 3: 130 ms per loop\n" | |
| ], | |
| "name": "stdout" | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "twzMuFgsNncy", | |
| "outputId": "57f07c77-ca7a-40f5-fdc5-b89aa8cced02", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 442 | |
| } | |
| }, | |
| "source": [ | |
| "operands = 'acaj,ikab,ajac,ikbd,->bdik'\n", | |
| "arrays = [jnp.zeros(tuple(sizes[k] for k in op)) for op in operands.split('->')[0].split(',')]\n", | |
| "jax.make_jaxpr(partial(jnp.einsum, operands))(*arrays)" | |
| ], | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "{ lambda f g ; a b c d e.\n", | |
| " let h = xla_call[ backend=None\n", | |
| " call_jaxpr={ lambda ; f j a b c d e.\n", | |
| " let g = mul c f\n", | |
| " h = reduce_sum[ axes=(0,) ] g\n", | |
| " i = transpose[ permutation=(1, 0, 2) ] h\n", | |
| " k = mul a j\n", | |
| " l = reduce_sum[ axes=(0,) ] k\n", | |
| " m = transpose[ permutation=(1, 0, 2) ] l\n", | |
| " n = dot_general[ dimension_numbers=(((2, 1), (1, 2)), ((0,), (0,)))\n", | |
| " precision=None ] i m\n", | |
| " o = dot_general[ dimension_numbers=(((0,), (2,)), ((), ()))\n", | |
| " precision=None ] n b\n", | |
| " p = reshape[ dimensions=None\n", | |
| " new_sizes=() ] e\n", | |
| " q = dot_general[ dimension_numbers=(((), ()), ((), ()))\n", | |
| " precision=None ] o p\n", | |
| " r = dot_general[ dimension_numbers=(((), ()), ((0, 1, 2), (0, 1, 2)))\n", | |
| " precision=None ] q d\n", | |
| " s = transpose[ permutation=(2, 3, 0, 1) ] r\n", | |
| " in (s,) }\n", | |
| " device=None\n", | |
| " donated_invars=(False, False, False, False, False, False, False)\n", | |
| " name=_einsum ] f g a b c d e\n", | |
| " in (h,) }" | |
| ] | |
| }, | |
| "metadata": { | |
| "tags": [] | |
| }, | |
| "execution_count": 27 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "BXbdK8W1P638" | |
| }, | |
| "source": [ | |
| "#" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "8nQdioSXKo78", | |
| "outputId": "059ee47d-77a5-467d-cfc2-872eed5029d8", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 714 | |
| } | |
| }, | |
| "source": [ | |
| "operands = 'bdik,acaj,ikab,ajac,ikbd->'\n", | |
| "sizes = collections.defaultdict(lambda: 10)\n", | |
| "arrays = [jnp.zeros(tuple(sizes[k] for k in op)) for op in operands.split('->')[0].split(',')]\n", | |
| "jax.make_jaxpr(make_einsum_grad(operands, einsum=jnp.einsum))(*arrays)" | |
| ], | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "{ lambda f g h ; a b c d e.\n", | |
| " let i j k l m = xla_call[ backend=None\n", | |
| " call_jaxpr={ lambda ; l p a b c d e f.\n", | |
| " let g = transpose[ permutation=(2, 0, 1, 3) ] e\n", | |
| " h = transpose[ permutation=(0, 2, 3, 1) ] a\n", | |
| " i = dot_general[ dimension_numbers=(((3,), (3,)), ((0, 1, 2), (0, 1, 2)))\n", | |
| " precision=None ] g h\n", | |
| " j = transpose[ permutation=(1, 2, 0) ] i\n", | |
| " k = dot_general[ dimension_numbers=(((2, 0, 1), (3, 0, 1)), ((), ()))\n", | |
| " precision=None ] j c\n", | |
| " m = mul d l\n", | |
| " n = reduce_sum[ axes=(0,) ] m\n", | |
| " o = transpose[ permutation=(1, 0, 2) ] n\n", | |
| " q = mul b p\n", | |
| " r = reduce_sum[ axes=(0,) ] q\n", | |
| " s = transpose[ permutation=(1, 0, 2) ] r\n", | |
| " t = dot_general[ dimension_numbers=(((2, 1), (1, 2)), ((0,), (0,)))\n", | |
| " precision=None ] o s\n", | |
| " u = dot_general[ dimension_numbers=(((0,), (0,)), ((), ()))\n", | |
| " precision=None ] k t\n", | |
| " in (u, *, g, c, t) }\n", | |
| " device=None\n", | |
| " donated_invars=(False, False, False, False, False, False, False, False)\n", | |
| " name=jvp(_einsum) ] f g a b c d e h\n", | |
| " n = mul 2.0 i\n", | |
| " o = mul 1.0 n\n", | |
| " p = xla_call[ backend=None\n", | |
| " call_jaxpr={ lambda ; a b c d e f g h.\n", | |
| " let i = dot_general[ dimension_numbers=(((), ()), ((), ()))\n", | |
| " precision=None ] h c\n", | |
| " j = dot_general[ dimension_numbers=(((0,), (2,)), ((), ()))\n", | |
| " precision=None ] i b\n", | |
| " k = transpose[ permutation=(2, 0, 1) ] j\n", | |
| " l = dot_general[ dimension_numbers=(((), ()), ((0, 1, 2), (0, 1, 2)))\n", | |
| " precision=None ] k a\n", | |
| " m = transpose[ permutation=(0, 3, 1, 2) ] l\n", | |
| " in (m,) }\n", | |
| " device=None\n", | |
| " donated_invars=(False, False, False, False, False, False, False, False)\n", | |
| " name=transpose(jvp(_einsum)) ] k l m h h h h o\n", | |
| " in (p,) }" | |
| ] | |
| }, | |
| "metadata": { | |
| "tags": [] | |
| }, | |
| "execution_count": 28 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "Dg_RhNupqoaD", | |
| "outputId": "3b089faa-59da-4a2c-98da-8363f2c42054", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 52 | |
| } | |
| }, | |
| "source": [ | |
| "jax.jit(make_einsum_grad('ij,jk->ij'))(jnp.zeros((2, 3)), jnp.zeros((3, 4)))" | |
| ], | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "DeviceArray([[0., 0., 0.],\n", | |
| " [0., 0., 0.]], dtype=float32)" | |
| ] | |
| }, | |
| "metadata": { | |
| "tags": [] | |
| }, | |
| "execution_count": 10 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "Lt8vgNuSqxj7", | |
| "outputId": "cc23622e-1946-4b33-eddb-b85c7f7a96a4", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 52 | |
| } | |
| }, | |
| "source": [ | |
| "make_einsum_grad('ij,jk->ij')(jnp.zeros((2, 3)), jnp.zeros((3, 4)))" | |
| ], | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "DeviceArray([[0., 0., 0.],\n", | |
| " [0., 0., 0.]], dtype=float32)" | |
| ] | |
| }, | |
| "metadata": { | |
| "tags": [] | |
| }, | |
| "execution_count": 11 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "7BEZ6RkKrOCr", | |
| "outputId": "b23acff3-4f61-4b56-f26e-943e4ae7d6c5", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 52 | |
| } | |
| }, | |
| "source": [ | |
| "from functools import partial\n", | |
| "import numpy as np\n", | |
| "\n", | |
| "rs = np.random.RandomState(0)\n", | |
| "f = partial(einsum, 'i,ij,j->ij')\n", | |
| "args = (rs.randn(2), rs.randn(2, 3), rs.randn(3,))\n", | |
| "jtu.check_grads(f, args, order=2)" | |
| ], | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "text": [ | |
| "/usr/local/lib/python3.6/dist-packages/jax/lib/xla_bridge.py:116: UserWarning: No GPU/TPU found, falling back to CPU.\n", | |
| " warnings.warn('No GPU/TPU found, falling back to CPU.')\n" | |
| ], | |
| "name": "stderr" | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "xAvXZH4XrjjF", | |
| "outputId": "d6774b36-088a-48bc-aecb-61dd23b604e8", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 51 | |
| } | |
| }, | |
| "source": [ | |
| "from functools import partial\n", | |
| "import numpy as np\n", | |
| "\n", | |
| "rs = np.random.RandomState(0)\n", | |
| "operands = 'ijk,ij,jk->ij'\n", | |
| "f = partial(einsum, operands)\n", | |
| "args = (rs.randn(2, 3, 4), rs.randn(2, 3), rs.randn(3, 4))\n", | |
| "jtu.check_grads(f, args, order=2)" | |
| ], | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "text": [ | |
| "/usr/local/lib/python3.6/dist-packages/jax/lib/xla_bridge.py:127: UserWarning: No GPU/TPU found, falling back to CPU.\n", | |
| " warnings.warn('No GPU/TPU found, falling back to CPU.')\n" | |
| ], | |
| "name": "stderr" | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "PEvOPu8xr5j8", | |
| "outputId": "f1ccef7a-4e8b-4aa2-e193-ec645b1f859b", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 155 | |
| } | |
| }, | |
| "source": [ | |
| "jax.make_jaxpr(make_einsum_grad(operands))(*args)" | |
| ], | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "{ lambda d ; a b c.\n", | |
| " let e = einsum[ input_strings=['ijk', 'ij', 'jk']\n", | |
| " output_string=ij ] a b c\n", | |
| " f = mul 2.0 e\n", | |
| " g = mul d f\n", | |
| " h = einsum[ input_strings=['ij', 'jk', 'ij']\n", | |
| " output_string=ijk ] b c g\n", | |
| " in (h,) }" | |
| ] | |
| }, | |
| "metadata": { | |
| "tags": [] | |
| }, | |
| "execution_count": 14 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "8_zcLutJujVl", | |
| "outputId": "35ebfc26-58d1-4189-80c8-391d9090c169", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 953 | |
| } | |
| }, | |
| "source": [ | |
| "print(jax.xla_computation(make_einsum_grad(operands, einsum=jnp.einsum))(*args).as_hlo_text())" | |
| ], | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "text": [ | |
| "HloModule xla_computation_f__3.44\n", | |
| "\n", | |
| "jit_pe_jvp__einsum__.8 {\n", | |
| " parameter.12 = pred[] parameter(3)\n", | |
| " parameter.11 = f32[3,4]{1,0} parameter(2)\n", | |
| " parameter.9 = f32[2,3,4]{2,1,0} parameter(0)\n", | |
| " transpose.14 = f32[3,2,4]{2,0,1} transpose(parameter.9), dimensions={1,0,2}\n", | |
| " dot.15 = f32[3,2]{1,0} dot(parameter.11, transpose.14), lhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_batch_dims={0}, rhs_contracting_dims={2}\n", | |
| " parameter.10 = f32[2,3]{1,0} parameter(1)\n", | |
| " transpose.16 = f32[3,2]{0,1} transpose(parameter.10), dimensions={1,0}\n", | |
| " dot.17 = f32[3,2]{1,0} dot(dot.15, transpose.16), lhs_batch_dims={0,1}, lhs_contracting_dims={}, rhs_batch_dims={0,1}, rhs_contracting_dims={}\n", | |
| " transpose.18 = f32[2,3]{0,1} transpose(dot.17), dimensions={1,0}\n", | |
| " constant.13 = pred[] constant(false)\n", | |
| " ROOT tuple.19 = (f32[2,3]{0,1}, pred[], f32[3,4]{1,0}, f32[3,2]{0,1}) tuple(transpose.18, constant.13, parameter.11, transpose.16)\n", | |
| "}\n", | |
| "\n", | |
| "jit_transpose_pe_jvp__einsum___.29 {\n", | |
| " parameter.32 = pred[] parameter(2)\n", | |
| " parameter.33 = pred[] parameter(3)\n", | |
| " constant.35 = pred[] constant(false)\n", | |
| " parameter.34 = f32[2,3]{1,0} parameter(4)\n", | |
| " transpose.36 = f32[3,2]{0,1} transpose(parameter.34), dimensions={1,0}\n", | |
| " parameter.31 = f32[3,2]{0,1} parameter(1)\n", | |
| " dot.37 = f32[3,2]{1,0} dot(transpose.36, parameter.31), lhs_batch_dims={0,1}, lhs_contracting_dims={}, rhs_batch_dims={0,1}, rhs_contracting_dims={}\n", | |
| " parameter.30 = f32[3,4]{1,0} parameter(0)\n", | |
| " dot.38 = f32[3,2,4]{2,1,0} dot(dot.37, parameter.30), lhs_batch_dims={0}, lhs_contracting_dims={}, rhs_batch_dims={0}, rhs_contracting_dims={}\n", | |
| " transpose.39 = f32[2,3,4]{2,0,1} transpose(dot.38), dimensions={1,0,2}\n", | |
| " ROOT tuple.40 = (f32[2,3,4]{2,0,1}) tuple(transpose.39)\n", | |
| "}\n", | |
| "\n", | |
| "ENTRY xla_computation_f__3.44 {\n", | |
| " constant.7 = pred[] constant(false)\n", | |
| " parameter.4 = f32[2,3,4]{2,1,0} parameter(0)\n", | |
| " parameter.5 = f32[2,3]{1,0} parameter(1)\n", | |
| " parameter.6 = f32[3,4]{1,0} parameter(2)\n", | |
| " constant.1 = pred[] constant(false)\n", | |
| " call.20 = (f32[2,3]{0,1}, pred[], f32[3,4]{1,0}, f32[3,2]{0,1}) call(parameter.4, parameter.5, parameter.6, constant.1), to_apply=jit_pe_jvp__einsum__.8\n", | |
| " get-tuple-element.22 = pred[] get-tuple-element(call.20), index=1\n", | |
| " get-tuple-element.23 = f32[3,4]{1,0} get-tuple-element(call.20), index=2\n", | |
| " get-tuple-element.24 = f32[3,2]{0,1} get-tuple-element(call.20), index=3\n", | |
| " constant.2 = f32[] constant(1)\n", | |
| " broadcast.3 = f32[2,3]{1,0} broadcast(constant.2), dimensions={}\n", | |
| " constant.25 = f32[] constant(2)\n", | |
| " broadcast.26 = f32[2,3]{1,0} broadcast(constant.25), dimensions={}\n", | |
| " get-tuple-element.21 = f32[2,3]{0,1} get-tuple-element(call.20), index=0\n", | |
| " multiply.27 = f32[2,3]{1,0} multiply(broadcast.26, get-tuple-element.21)\n", | |
| " multiply.28 = f32[2,3]{1,0} multiply(broadcast.3, multiply.27)\n", | |
| " call.41 = (f32[2,3,4]{2,0,1}) call(get-tuple-element.23, get-tuple-element.24, constant.1, constant.1, multiply.28), to_apply=jit_transpose_pe_jvp__einsum___.29\n", | |
| " get-tuple-element.42 = f32[2,3,4]{2,0,1} get-tuple-element(call.41), index=0\n", | |
| " ROOT tuple.43 = (f32[2,3,4]{2,0,1}) tuple(get-tuple-element.42)\n", | |
| "}\n", | |
| "\n", | |
| "\n" | |
| ], | |
| "name": "stdout" | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "OFQ8PslmsclX", | |
| "outputId": "8c4644d0-fc11-4a5c-8ee8-1377c8f39da9", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 935 | |
| } | |
| }, | |
| "source": [ | |
| "print(jax.xla_computation(make_einsum_grad(operands))(*args).as_hlo_text())" | |
| ], | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "text": [ | |
| "HloModule xla_computation_f__2.43\n", | |
| "\n", | |
| "jit__einsum__260.8 {\n", | |
| " constant.12 = pred[] constant(false)\n", | |
| " parameter.11 = f32[3,4]{1,0} parameter(2)\n", | |
| " parameter.9 = f32[2,3,4]{2,1,0} parameter(0)\n", | |
| " transpose.13 = f32[3,2,4]{2,0,1} transpose(parameter.9), dimensions={1,0,2}\n", | |
| " dot.14 = f32[3,2]{1,0} dot(parameter.11, transpose.13), lhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_batch_dims={0}, rhs_contracting_dims={2}\n", | |
| " parameter.10 = f32[2,3]{1,0} parameter(1)\n", | |
| " transpose.15 = f32[3,2]{0,1} transpose(parameter.10), dimensions={1,0}\n", | |
| " dot.16 = f32[3,2]{1,0} dot(dot.14, transpose.15), lhs_batch_dims={0,1}, lhs_contracting_dims={}, rhs_batch_dims={0,1}, rhs_contracting_dims={}\n", | |
| " transpose.17 = f32[2,3]{0,1} transpose(dot.16), dimensions={1,0}\n", | |
| " ROOT tuple.18 = (f32[2,3]{0,1}) tuple(transpose.17)\n", | |
| "}\n", | |
| "\n", | |
| "jit__einsum__261.28 {\n", | |
| " constant.32 = pred[] constant(false)\n", | |
| " parameter.31 = f32[2,3]{1,0} parameter(2)\n", | |
| " parameter.29 = f32[2,3]{1,0} parameter(0)\n", | |
| " dot.33 = f32[2,3]{1,0} dot(parameter.31, parameter.29), lhs_batch_dims={0,1}, lhs_contracting_dims={}, rhs_batch_dims={0,1}, rhs_contracting_dims={}\n", | |
| " transpose.34 = f32[3,2]{0,1} transpose(dot.33), dimensions={1,0}\n", | |
| " parameter.30 = f32[3,4]{1,0} parameter(1)\n", | |
| " dot.35 = f32[3,2,4]{2,1,0} dot(transpose.34, parameter.30), lhs_batch_dims={0}, lhs_contracting_dims={}, rhs_batch_dims={0}, rhs_contracting_dims={}\n", | |
| " transpose.36 = f32[2,3,4]{2,0,1} transpose(dot.35), dimensions={1,0,2}\n", | |
| " ROOT tuple.37 = (f32[2,3,4]{2,0,1}) tuple(transpose.36)\n", | |
| "}\n", | |
| "\n", | |
| "ENTRY xla_computation_f__2.43 {\n", | |
| " constant.6 = pred[] constant(false)\n", | |
| " constant.7 = pred[] constant(false)\n", | |
| " constant.27 = pred[] constant(false)\n", | |
| " parameter.4 = f32[2,3]{1,0} parameter(1)\n", | |
| " parameter.5 = f32[3,4]{1,0} parameter(2)\n", | |
| " constant.1 = f32[] constant(1)\n", | |
| " broadcast.2 = f32[2,3]{1,0} broadcast(constant.1), dimensions={}\n", | |
| " constant.23 = f32[] constant(2)\n", | |
| " broadcast.24 = f32[2,3]{1,0} broadcast(constant.23), dimensions={}\n", | |
| " parameter.3 = f32[2,3,4]{2,1,0} parameter(0)\n", | |
| " call.19 = (f32[2,3]{0,1}) call(parameter.3, parameter.4, parameter.5), to_apply=jit__einsum__260.8\n", | |
| " get-tuple-element.20 = f32[2,3]{0,1} get-tuple-element(call.19), index=0\n", | |
| " tuple.21 = (f32[2,3]{0,1}) tuple(get-tuple-element.20)\n", | |
| " get-tuple-element.22 = f32[2,3]{0,1} get-tuple-element(tuple.21), index=0\n", | |
| " multiply.25 = f32[2,3]{1,0} multiply(broadcast.24, get-tuple-element.22)\n", | |
| " multiply.26 = f32[2,3]{1,0} multiply(broadcast.2, multiply.25)\n", | |
| " call.38 = (f32[2,3,4]{2,0,1}) call(parameter.4, parameter.5, multiply.26), to_apply=jit__einsum__261.28\n", | |
| " get-tuple-element.39 = f32[2,3,4]{2,0,1} get-tuple-element(call.38), index=0\n", | |
| " tuple.40 = (f32[2,3,4]{2,0,1}) tuple(get-tuple-element.39)\n", | |
| " get-tuple-element.41 = f32[2,3,4]{2,0,1} get-tuple-element(tuple.40), index=0\n", | |
| " ROOT tuple.42 = (f32[2,3,4]{2,0,1}) tuple(get-tuple-element.41)\n", | |
| "}\n", | |
| "\n", | |
| "\n" | |
| ], | |
| "name": "stdout" | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "DrAru3d7ueL-" | |
| }, | |
| "source": [], | |
| "execution_count": null, | |
| "outputs": [] | |
| } | |
| ] | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment