Skip to content

Instantly share code, notes, and snippets.

@corba777
Last active February 24, 2020 00:15
Show Gist options
  • Select an option

  • Save corba777/9a58f4514e7b403db5feffe178f9cb20 to your computer and use it in GitHub Desktop.

Select an option

Save corba777/9a58f4514e7b403db5feffe178f9cb20 to your computer and use it in GitHub Desktop.
JAX_Sort_and_argsort.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "JAX_Sort_and_argsort.ipynb",
"provenance": [],
"authorship_tag": "ABX9TyOyLIs+i4DwfpiAp06XngTt",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/corba777/9a58f4514e7b403db5feffe178f9cb20/jax_sort_and_argsort.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"id": "JYAtYEgTag2Y",
"colab_type": "code",
"outputId": "5f0332dd-6c3d-4865-ff3f-55d0d2aca77e",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 68
}
},
"source": [
"!pip install -q --upgrade jax jaxlib"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"\u001b[K |████████████████████████████████| 276kB 2.8MB/s \n",
"\u001b[K |████████████████████████████████| 27.3MB 144kB/s \n",
"\u001b[?25h Building wheel for jax (setup.py) ... \u001b[?25l\u001b[?25hdone\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "U7xppGExbSug",
"colab_type": "code",
"colab": {}
},
"source": [
"from __future__ import print_function, division\n",
"import jax.numpy as np\n",
"from jax import grad, jit, vmap,jacfwd, jacrev\n",
"from jax import random"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "2AcBkBOobbSO",
"colab_type": "code",
"colab": {}
},
"source": [
"import numpy as npp"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "BOQBKfVnbjar",
"colab_type": "code",
"outputId": "026fbaf2-2633-4648-e065-e673532ccb0b",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 68
}
},
"source": [
"# Fixed PRNG seed, so the sampled vector and every Jacobian below are reproducible.\n",
"key = random.PRNGKey(42)\n",
"test_array = random.normal(key, shape=(10,))  # 10 float32 samples from N(0, 1)\n",
"test_array"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"DeviceArray([ 0.36900425, -0.4606755 , -0.8650935 , 1.2080883 ,\n",
" 1.0030649 , -0.87080586, -0.3984998 , -0.6670093 ,\n",
" 0.33689344, 0.39822483], dtype=float32)"
]
},
"metadata": {
"tags": []
},
"execution_count": 14
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "JxQYR0SUdIJc",
"colab_type": "code",
"outputId": "68f8af44-e3a3-4568-d3fb-c4f182d0d465",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 187
}
},
"source": [
"# Forward-mode Jacobian of sort: a permutation matrix whose row i has a single 1\n",
"# in the column holding the i-th smallest entry of test_array.\n",
"jacfwd(np.sort, argnums=0)(test_array)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"DeviceArray([[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],\n",
" [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],\n",
" [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],\n",
" [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],\n",
" [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.]], dtype=float32)"
]
},
"metadata": {
"tags": []
},
"execution_count": 9
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "hi_Gs5eydqoD",
"colab_type": "code",
"outputId": "90cbbb29-2321-4cbb-c030-17a08dac229e",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 187
}
},
"source": [
"# Reverse mode produces the same permutation matrix as the forward-mode cell above.\n",
"jacrev(np.sort, argnums=0)(test_array)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"DeviceArray([[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],\n",
" [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],\n",
" [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],\n",
" [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],\n",
" [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.]], dtype=float32)"
]
},
"metadata": {
"tags": []
},
"execution_count": 10
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "KScU104jd2XB",
"colab_type": "code",
"outputId": "8c066edc-2467-4776-d700-d6bc2ad97577",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 187
}
},
"source": [
"# argsort is piecewise constant in its input, so its Jacobian is all zeros.\n",
"# With the JAX version this was run on, the result kept argsort's integer dtype.\n",
"jacfwd(np.argsort, argnums=0)(test_array)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"DeviceArray([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], dtype=int32)"
]
},
"metadata": {
"tags": []
},
"execution_count": 16
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "DTs8MD2KeJVg",
"colab_type": "code",
"colab": {}
},
"source": [
"def argsort(x):\n",
"    \"\"\"Return the sort order of `x` as a Python list of float scalars.\n",
"\n",
"    Each index from np.argsort is multiplied by 1.0 individually. Presumably\n",
"    this is done so that reverse-mode AD (jacrev) sees float outputs rather\n",
"    than integer ones -- confirm against the JAX version in use.\n",
"    \"\"\"\n",
"    return [idx * 1.0 for idx in np.argsort(x)]"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "yMy-fUH0gumJ",
"colab_type": "code",
"outputId": "c935d231-410f-4dfe-9d86-2fc9f330b491",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 187
}
},
"source": [
"# argsort() returns a list, so jacrev yields one (all-zero) gradient vector per output entry.\n",
"jacrev(argsort, argnums=0)(test_array)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"[DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),\n",
" DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),\n",
" DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),\n",
" DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),\n",
" DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),\n",
" DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),\n",
" DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),\n",
" DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),\n",
" DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),\n",
" DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)]"
]
},
"metadata": {
"tags": []
},
"execution_count": 37
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "Wpb2_sxphN9M",
"colab_type": "code",
"colab": {}
},
"source": [
""
],
"execution_count": 0,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment