Created
July 1, 2020 08:14
-
-
Save drcege/bd96768240ef3cc0a37479c5870b6e0c to your computer and use it in GitHub Desktop.
batch_triplet_loss.ipynb
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "nbformat": 4, | |
| "nbformat_minor": 0, | |
| "metadata": { | |
| "colab": { | |
| "name": "batch_triplet_loss.ipynb", | |
| "provenance": [], | |
| "authorship_tag": "ABX9TyMqHmAHy2SWbykAVPPGIwDz", | |
| "include_colab_link": true | |
| }, | |
| "kernelspec": { | |
| "name": "python3", | |
| "display_name": "Python 3" | |
| } | |
| }, | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "view-in-github", | |
| "colab_type": "text" | |
| }, | |
| "source": [ | |
| "<a href=\"https://colab.research.google.com/gist/drcege/bd96768240ef3cc0a37479c5870b6e0c/batch_triplet_loss.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "W3f7ZZ_IWNGo", | |
| "colab_type": "code", | |
| "colab": {} | |
| }, | |
| "source": [ | |
| "import tensorflow as tf\n", | |
| "import numpy as np" | |
| ], | |
| "execution_count": 1, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "3wCISCWcXD_v", | |
| "colab_type": "code", | |
| "colab": {} | |
| }, | |
| "source": [ | |
| "_margin = 0.3" | |
| ], | |
| "execution_count": 2, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "OvbFXgk5XIr5", | |
| "colab_type": "code", | |
| "colab": {} | |
| }, | |
| "source": [ | |
| "def euclidean_distance(x, y):\n", | |
| " return tf.norm(x - y, axis=-1)\n", | |
| "\n", | |
| "dist_func = euclidean_distance" | |
| ], | |
| "execution_count": 3, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "oEj2GlRzXQZO", | |
| "colab_type": "code", | |
| "colab": {} | |
| }, | |
| "source": [ | |
| "def triplet_loss(y_true, y_pred):\n", | |
| " bs = tf.shape(y_pred)[0]\n", | |
| " #tf.print('\\nbs', bs)\n", | |
| " sketch, photo = tf.split(y_pred, 2, axis=-1)\n", | |
| " \n", | |
| " pd = dist_func(sketch, photo)\n", | |
| " \n", | |
| " ss = tf.expand_dims(sketch, 0)\n", | |
| " pp = tf.expand_dims(photo, 1)\n", | |
| " nd = dist_func(ss, pp)\n", | |
| " \n", | |
| " t_loss = pd - nd + _margin # N * N \n", | |
| " eye = tf.eye(bs)\n", | |
| " tt = t_loss * (1 - eye)\n", | |
| " \n", | |
| " LOSS = tf.reduce_sum(tf.nn.relu(tt), axis=0) / tf.cast(bs-1, tf.float32) # N\n", | |
| " #tf.print(tf.shape(LOSS))\n", | |
| " return LOSS" | |
| ], | |
| "execution_count": 4, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "XT_K5I3VXXjm", | |
| "colab_type": "code", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 85 | |
| }, | |
| "outputId": "3c7383f1-da8e-4dcf-b028-6fe8dd10d294" | |
| }, | |
| "source": [ | |
| "A = tf.nn.l2_normalize(tf.random.uniform((3,2)), axis=-1); A" | |
| ], | |
| "execution_count": 5, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "<tf.Tensor: shape=(3, 2), dtype=float32, numpy=\n", | |
| "array([[0.87696016, 0.48056316],\n", | |
| " [0.9718337 , 0.23566774],\n", | |
| " [0.6842739 , 0.72922504]], dtype=float32)>" | |
| ] | |
| }, | |
| "metadata": { | |
| "tags": [] | |
| }, | |
| "execution_count": 5 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "jVyEiL7yXko0", | |
| "colab_type": "code", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 85 | |
| }, | |
| "outputId": "9e1ea84b-715d-4db5-f9b8-a92850ba2460" | |
| }, | |
| "source": [ | |
| "B = tf.nn.l2_normalize(tf.random.uniform((3,2)), axis=-1); B" | |
| ], | |
| "execution_count": 6, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "<tf.Tensor: shape=(3, 2), dtype=float32, numpy=\n", | |
| "array([[0.66403735, 0.74769926],\n", | |
| " [0.66695917, 0.74509424],\n", | |
| " [0.9976887 , 0.06795112]], dtype=float32)>" | |
| ] | |
| }, | |
| "metadata": { | |
| "tags": [] | |
| }, | |
| "execution_count": 6 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "ZI9-kbs2YTdC", | |
| "colab_type": "code", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 85 | |
| }, | |
| "outputId": "84f0816e-7b1a-4e06-c626-860ea8e6d580" | |
| }, | |
| "source": [ | |
| "C = tf.concat([A, B], axis=-1); C" | |
| ], | |
| "execution_count": 7, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "<tf.Tensor: shape=(3, 4), dtype=float32, numpy=\n", | |
| "array([[0.87696016, 0.48056316, 0.66403735, 0.74769926],\n", | |
| " [0.9718337 , 0.23566774, 0.66695917, 0.74509424],\n", | |
| " [0.6842739 , 0.72922504, 0.9976887 , 0.06795112]], dtype=float32)>" | |
| ] | |
| }, | |
| "metadata": { | |
| "tags": [] | |
| }, | |
| "execution_count": 7 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "80xAAzfkX0SM", | |
| "colab_type": "code", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 34 | |
| }, | |
| "outputId": "6ac5712a-ef73-42e1-f5db-e63a58d903b2" | |
| }, | |
| "source": [ | |
| "print(triplet_loss(None, C))" | |
| ], | |
| "execution_count": 8, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "text": [ | |
| "tf.Tensor([0.25777826 0.510126 1.0063429 ], shape=(3,), dtype=float32)\n" | |
| ], | |
| "name": "stdout" | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "qTycXF71ZdYz", | |
| "colab_type": "code", | |
| "colab": {} | |
| }, | |
| "source": [ | |
| "" | |
| ], | |
| "execution_count": 8, | |
| "outputs": [] | |
| } | |
| ] | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment