Last active
February 24, 2020 00:15
-
-
Save corba777/9a58f4514e7b403db5feffe178f9cb20 to your computer and use it in GitHub Desktop.
JAX_Sort_and_argsort.ipynb
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "nbformat": 4, | |
| "nbformat_minor": 0, | |
| "metadata": { | |
| "colab": { | |
| "name": "JAX_Sort_and_argsort.ipynb", | |
| "provenance": [], | |
| "authorship_tag": "ABX9TyOyLIs+i4DwfpiAp06XngTt", | |
| "include_colab_link": true | |
| }, | |
| "kernelspec": { | |
| "name": "python3", | |
| "display_name": "Python 3" | |
| } | |
| }, | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "view-in-github", | |
| "colab_type": "text" | |
| }, | |
| "source": [ | |
| "<a href=\"https://colab.research.google.com/gist/corba777/9a58f4514e7b403db5feffe178f9cb20/jax_sort_and_argsort.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "JYAtYEgTag2Y", | |
| "colab_type": "code", | |
| "outputId": "5f0332dd-6c3d-4865-ff3f-55d0d2aca77e", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 68 | |
| } | |
| }, | |
| "source": [ | |
| "!pip install -q --upgrade jax jaxlib" | |
| ], | |
| "execution_count": 0, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "text": [ | |
| "\u001b[K |████████████████████████████████| 276kB 2.8MB/s \n", | |
| "\u001b[K |████████████████████████████████| 27.3MB 144kB/s \n", | |
| "\u001b[?25h Building wheel for jax (setup.py) ... \u001b[?25l\u001b[?25hdone\n" | |
| ], | |
| "name": "stdout" | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "U7xppGExbSug", | |
| "colab_type": "code", | |
| "colab": {} | |
| }, | |
| "source": [ | |
| "from __future__ import print_function, division\n", | |
| "import jax.numpy as np\n", | |
| "from jax import grad, jit, vmap,jacfwd, jacrev\n", | |
| "from jax import random" | |
| ], | |
| "execution_count": 0, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "2AcBkBOobbSO", | |
| "colab_type": "code", | |
| "colab": {} | |
| }, | |
| "source": [ | |
| "import numpy as npp" | |
| ], | |
| "execution_count": 0, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "BOQBKfVnbjar", | |
| "colab_type": "code", | |
| "outputId": "026fbaf2-2633-4648-e065-e673532ccb0b", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 68 | |
| } | |
| }, | |
| "source": [ | |
| "key = random.PRNGKey(42)\n", | |
| "test_array=random.normal(key,(10,))\n", | |
| "test_array" | |
| ], | |
| "execution_count": 0, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "DeviceArray([ 0.36900425, -0.4606755 , -0.8650935 , 1.2080883 ,\n", | |
| " 1.0030649 , -0.87080586, -0.3984998 , -0.6670093 ,\n", | |
| " 0.33689344, 0.39822483], dtype=float32)" | |
| ] | |
| }, | |
| "metadata": { | |
| "tags": [] | |
| }, | |
| "execution_count": 14 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "JxQYR0SUdIJc", | |
| "colab_type": "code", | |
| "outputId": "68f8af44-e3a3-4568-d3fb-c4f182d0d465", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 187 | |
| } | |
| }, | |
| "source": [ | |
| "jacfwd(np.sort)(test_array)" | |
| ], | |
| "execution_count": 0, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "DeviceArray([[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],\n", | |
| " [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],\n", | |
| " [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],\n", | |
| " [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],\n", | |
| " [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],\n", | |
| " [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],\n", | |
| " [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", | |
| " [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],\n", | |
| " [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],\n", | |
| " [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.]], dtype=float32)" | |
| ] | |
| }, | |
| "metadata": { | |
| "tags": [] | |
| }, | |
| "execution_count": 9 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "hi_Gs5eydqoD", | |
| "colab_type": "code", | |
| "outputId": "90cbbb29-2321-4cbb-c030-17a08dac229e", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 187 | |
| } | |
| }, | |
| "source": [ | |
| "jacrev(np.sort)(test_array)" | |
| ], | |
| "execution_count": 0, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "DeviceArray([[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],\n", | |
| " [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],\n", | |
| " [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],\n", | |
| " [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],\n", | |
| " [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],\n", | |
| " [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],\n", | |
| " [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", | |
| " [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],\n", | |
| " [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],\n", | |
| " [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.]], dtype=float32)" | |
| ] | |
| }, | |
| "metadata": { | |
| "tags": [] | |
| }, | |
| "execution_count": 10 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "KScU104jd2XB", | |
| "colab_type": "code", | |
| "outputId": "8c066edc-2467-4776-d700-d6bc2ad97577", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 187 | |
| } | |
| }, | |
| "source": [ | |
| "jacfwd(np.argsort)(test_array)" | |
| ], | |
| "execution_count": 0, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "DeviceArray([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", | |
| " [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", | |
| " [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", | |
| " [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", | |
| " [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", | |
| " [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", | |
| " [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", | |
| " [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", | |
| " [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", | |
| " [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], dtype=int32)" | |
| ] | |
| }, | |
| "metadata": { | |
| "tags": [] | |
| }, | |
| "execution_count": 16 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "DTs8MD2KeJVg", | |
| "colab_type": "code", | |
| "colab": {} | |
| }, | |
| "source": [ | |
def argsort(x):
    """Return the sorting permutation of ``x`` as a list of float scalars.

    ``np.argsort`` yields integer indices. They are converted to floats so
    the result can be passed to ``jacrev``. Presumably this is needed
    because reverse-mode differentiation rejects integer outputs; confirm
    against the JAX version in use. The permutation is piecewise constant
    in ``x``, so its Jacobian is all zeros, as the notebook's ``jacrev``
    output shows.

    Args:
        x: array to sort. ``np.argsort`` sorts along its last axis by default.

    Returns:
        A Python list built by iterating the float index array along its
        first axis. For 1-D ``x`` this is one 0-d float array per element.
    """
    # Convert to float with a single vectorized multiply instead of one
    # Python-level ``t * 1.0`` per element; the dtype promotion is identical.
    float_indices = np.argsort(x) * 1.0
    # Keep the list return type so callers (e.g. ``jacrev``) see the same
    # output structure as before.
    return list(float_indices)
| ], | |
| "execution_count": 0, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "yMy-fUH0gumJ", | |
| "colab_type": "code", | |
| "outputId": "c935d231-410f-4dfe-9d86-2fc9f330b491", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 187 | |
| } | |
| }, | |
| "source": [ | |
| "jacrev(argsort)(test_array)" | |
| ], | |
| "execution_count": 0, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "[DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),\n", | |
| " DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),\n", | |
| " DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),\n", | |
| " DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),\n", | |
| " DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),\n", | |
| " DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),\n", | |
| " DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),\n", | |
| " DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),\n", | |
| " DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),\n", | |
| " DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)]" | |
| ] | |
| }, | |
| "metadata": { | |
| "tags": [] | |
| }, | |
| "execution_count": 37 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "Wpb2_sxphN9M", | |
| "colab_type": "code", | |
| "colab": {} | |
| }, | |
| "source": [ | |
| "" | |
| ], | |
| "execution_count": 0, | |
| "outputs": [] | |
| } | |
| ] | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment