Skip to content

Instantly share code, notes, and snippets.

@drcege
Created July 1, 2020 08:14
Show Gist options
  • Select an option

  • Save drcege/bd96768240ef3cc0a37479c5870b6e0c to your computer and use it in GitHub Desktop.

Select an option

Save drcege/bd96768240ef3cc0a37479c5870b6e0c to your computer and use it in GitHub Desktop.
batch_triplet_loss.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "batch_triplet_loss.ipynb",
"provenance": [],
"authorship_tag": "ABX9TyMqHmAHy2SWbykAVPPGIwDz",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/drcege/bd96768240ef3cc0a37479c5870b6e0c/batch_triplet_loss.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"id": "W3f7ZZ_IWNGo",
"colab_type": "code",
"colab": {}
},
"source": [
"import tensorflow as tf\n",
"import numpy as np"
],
"execution_count": 1,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "3wCISCWcXD_v",
"colab_type": "code",
"colab": {}
},
"source": [
"# Hinge margin used by triplet_loss.\n",
"_margin = 0.3\n",
"\n",
"# Seed TensorFlow's global RNG so the random demo tensors A and B below are\n",
"# reproducible across Restart & Run All.\n",
"tf.random.set_seed(0)"
],
"execution_count": 2,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "OvbFXgk5XIr5",
"colab_type": "code",
"colab": {}
},
"source": [
"def euclidean_distance(x, y, epsilon=1e-12):\n",
"    \"\"\"Euclidean (L2) distance between x and y, reduced over the last axis.\n",
"\n",
"    x and y are broadcast against each other, so e.g. inputs of shape\n",
"    (1, N, D) and (N, 1, D) yield an (N, N) matrix of pairwise distances.\n",
"\n",
"    The squared distance is clamped to at least `epsilon` before the square\n",
"    root. tf.norm(x - y) has a NaN gradient when x == y (the derivative of\n",
"    sqrt is infinite at 0), which can propagate NaNs into the training\n",
"    gradients. The clamp only changes the returned value for distances below\n",
"    sqrt(epsilon), i.e. 1e-6 with the default.\n",
"    \"\"\"\n",
"    squared = tf.reduce_sum(tf.square(x - y), axis=-1)\n",
"    return tf.sqrt(tf.maximum(squared, epsilon))\n",
"\n",
"\n",
"# Distance function used by triplet_loss.\n",
"dist_func = euclidean_distance"
],
"execution_count": 3,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "oEj2GlRzXQZO",
"colab_type": "code",
"colab": {}
},
"source": [
"def triplet_loss(y_true, y_pred):\n",
"    \"\"\"Batch-all triplet loss between paired sketch and photo embeddings.\n",
"\n",
"    Args:\n",
"        y_true: Unused; presumably kept so the function fits the Keras\n",
"            `loss(y_true, y_pred)` signature.\n",
"        y_pred: Tensor whose last axis is the concatenation [sketch | photo]\n",
"            (so its size must be even). Expected to be 2-D, (N, 2 * D); the\n",
"            (N, N) masking below relies on that. Row i of the sketch half is\n",
"            paired with row i of the photo half; every other photo in the\n",
"            batch serves as a negative.\n",
"\n",
"    Returns:\n",
"        Tensor of shape (N,), N = batch size: for each anchor sketch i, the\n",
"        mean over the N - 1 photos j != i of\n",
"        relu(d(s_i, p_i) - d(s_i, p_j) + _margin), where d is dist_func.\n",
"        A batch of one has no negatives and yields 0 rather than NaN (0 / 0).\n",
"    \"\"\"\n",
"    batch_size = tf.shape(y_pred)[0]\n",
"    sketch, photo = tf.split(y_pred, 2, axis=-1)\n",
"\n",
"    # pos_dist[i] = d(s_i, p_i), shape (N,).\n",
"    pos_dist = dist_func(sketch, photo)\n",
"\n",
"    # Broadcasting (1, N, D) against (N, 1, D): neg_dist[j, i] = d(s_i, p_j).\n",
"    neg_dist = dist_func(tf.expand_dims(sketch, 0), tf.expand_dims(photo, 1))\n",
"\n",
"    # pos_dist broadcasts across rows, so column i holds every triplet whose\n",
"    # anchor is sketch i: t_loss[j, i] = d(s_i, p_i) - d(s_i, p_j) + margin.\n",
"    t_loss = pos_dist - neg_dist + _margin  # (N, N)\n",
"\n",
"    # Zero the diagonal: photo i is sketch i's positive, not a negative.\n",
"    # The mask is built in the loss dtype so non-float32 inputs also work.\n",
"    off_diagonal = 1 - tf.eye(batch_size, dtype=t_loss.dtype)\n",
"    hinge = tf.nn.relu(t_loss * off_diagonal)\n",
"\n",
"    # Average over the N - 1 negatives; clamp the count to 1 so a\n",
"    # single-example batch returns 0 instead of dividing by zero.\n",
"    num_negatives = tf.cast(tf.maximum(batch_size - 1, 1), t_loss.dtype)\n",
"    return tf.reduce_sum(hinge, axis=0) / num_negatives  # (N,)"
],
"execution_count": 4,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "XT_K5I3VXXjm",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 85
},
"outputId": "3c7383f1-da8e-4dcf-b028-6fe8dd10d294"
},
"source": [
"# Three random 2-d vectors, L2-normalised per row; used as the sketch half of C.\n",
"A = tf.nn.l2_normalize(tf.random.uniform(shape=(3, 2)), axis=-1)\n",
"A"
],
"execution_count": 5,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<tf.Tensor: shape=(3, 2), dtype=float32, numpy=\n",
"array([[0.87696016, 0.48056316],\n",
" [0.9718337 , 0.23566774],\n",
" [0.6842739 , 0.72922504]], dtype=float32)>"
]
},
"metadata": {
"tags": []
},
"execution_count": 5
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "jVyEiL7yXko0",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 85
},
"outputId": "9e1ea84b-715d-4db5-f9b8-a92850ba2460"
},
"source": [
"# Three random 2-d vectors, L2-normalised per row; used as the photo half of C.\n",
"B = tf.nn.l2_normalize(tf.random.uniform(shape=(3, 2)), axis=-1)\n",
"B"
],
"execution_count": 6,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<tf.Tensor: shape=(3, 2), dtype=float32, numpy=\n",
"array([[0.66403735, 0.74769926],\n",
" [0.66695917, 0.74509424],\n",
" [0.9976887 , 0.06795112]], dtype=float32)>"
]
},
"metadata": {
"tags": []
},
"execution_count": 6
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "ZI9-kbs2YTdC",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 85
},
"outputId": "84f0816e-7b1a-4e06-c626-860ea8e6d580"
},
"source": [
"# Pack as [sketch | photo] along the last axis, the layout triplet_loss splits apart.\n",
"C = tf.concat(values=[A, B], axis=-1)\n",
"C"
],
"execution_count": 7,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<tf.Tensor: shape=(3, 4), dtype=float32, numpy=\n",
"array([[0.87696016, 0.48056316, 0.66403735, 0.74769926],\n",
" [0.9718337 , 0.23566774, 0.66695917, 0.74509424],\n",
" [0.6842739 , 0.72922504, 0.9976887 , 0.06795112]], dtype=float32)>"
]
},
"metadata": {
"tags": []
},
"execution_count": 7
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "80xAAzfkX0SM",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"outputId": "6ac5712a-ef73-42e1-f5db-e63a58d903b2"
},
"source": [
"# triplet_loss ignores y_true, so None is passed for it.\n",
"print(triplet_loss(y_true=None, y_pred=C))"
],
"execution_count": 8,
"outputs": [
{
"output_type": "stream",
"text": [
"tf.Tensor([0.25777826 0.510126 1.0063429 ], shape=(3,), dtype=float32)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "qTycXF71ZdYz",
"colab_type": "code",
"colab": {}
},
"source": [
""
],
"execution_count": 8,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment