Created
August 21, 2018 18:29
-
-
Save c0nn3r/b446ccaf241003632ccc7ca4fb2a37af to your computer and use it in GitHub Desktop.
Learned Positional Embedding Masking
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "metadata": { | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "import torch", | |
| "execution_count": 1, | |
| "outputs": [] | |
| }, | |
| { | |
| "metadata": {}, | |
| "cell_type": "markdown", | |
| "source": "#### Learned Positional Embedding Masking" | |
| }, | |
| { | |
| "metadata": {}, | |
| "cell_type": "markdown", | |
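| "source": "First, we construct a batch of shape `[batch (1), sequence (3), channel (4)]`, in which `0` entries are padding. We want an end result in which every non-padding entry is replaced by its 1-based position along the last dimension, and every padding entry stays `0`." | |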
| }, | |
| { | |
| "metadata": { | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "example_batch = torch.Tensor([\n [[1, 2, 0, 0],\n [2, 1, 3, 0],\n [2, 1, 3, 0]]\n])", | |
| "execution_count": 2, | |
| "outputs": [] | |
| }, | |
| { | |
| "metadata": { | |
| "scrolled": true, | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "example_batch", | |
| "execution_count": 3, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": "tensor([[[1., 2., 0., 0.],\n [2., 1., 3., 0.],\n [2., 1., 3., 0.]]])" | |
| }, | |
| "execution_count": 3, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ] | |
| }, | |
| { | |
| "metadata": { | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "mask = (example_batch == 0)", | |
| "execution_count": 4, | |
| "outputs": [] | |
| }, | |
| { | |
| "metadata": { | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "mask", | |
| "execution_count": 5, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": "tensor([[[0, 0, 1, 1],\n [0, 0, 0, 1],\n [0, 0, 0, 1]]], dtype=torch.uint8)" | |
| }, | |
| "execution_count": 5, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ] | |
| }, | |
| { | |
| "metadata": { | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "example_batch.size(-1)", | |
| "execution_count": 6, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": "4" | |
| }, | |
| "execution_count": 6, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ] | |
| }, | |
| { | |
| "metadata": { | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "index_filled = torch.arange(1, example_batch.size(-1) + 1).expand_as(example_batch)", | |
| "execution_count": 7, | |
| "outputs": [] | |
| }, | |
| { | |
| "metadata": { | |
| "scrolled": false, | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "index_filled", | |
| "execution_count": 8, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": "tensor([[[1, 2, 3, 4],\n [1, 2, 3, 4],\n [1, 2, 3, 4]]])" | |
| }, | |
| "execution_count": 8, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ] | |
| }, | |
| { | |
| "metadata": { | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "index_filled.size()", | |
| "execution_count": 9, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": "torch.Size([1, 3, 4])" | |
| }, | |
| "execution_count": 9, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ] | |
| }, | |
| { | |
| "metadata": { | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "mask", | |
| "execution_count": 10, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": "tensor([[[0, 0, 1, 1],\n [0, 0, 0, 1],\n [0, 0, 0, 1]]], dtype=torch.uint8)" | |
| }, | |
| "execution_count": 10, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ] | |
| }, | |
| { | |
| "metadata": { | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "masked = index_filled.masked_fill_(mask, 0)", | |
| "execution_count": 11, | |
| "outputs": [] | |
| }, | |
| { | |
| "metadata": {}, | |
| "cell_type": "markdown", | |
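| "source": "In the result below, the last two columns are zeroed in every row, not only where each individual row is padded. The cause is that `expand_as` returns a view whose rows all share the same underlying storage (stride 0), so the in-place `masked_fill_` writes every row's mask into the same memory, and the union of all masks (here, the last two columns) ends up zeroed everywhere. Cloning the expanded tensor before filling (`index_filled.clone().masked_fill_(mask, 0)`), or using the out-of-place `masked_fill`, gives each row its own storage and produces the intended result (recent PyTorch versions reject this in-place write with an error instead)." | |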
| }, | |
| { | |
| "metadata": { | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "masked", | |
| "execution_count": 12, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": "tensor([[[1, 2, 0, 0],\n [1, 2, 0, 0],\n [1, 2, 0, 0]]])" | |
| }, | |
| "execution_count": 12, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ] | |
| }, | |
| { | |
| "metadata": {}, | |
| "cell_type": "markdown", | |
| "source": "However, the end result should look like this:" | |
| }, | |
| { | |
| "metadata": { | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "torch.Tensor([\n [[1, 2, 0, 0],\n [1, 2, 3, 0],\n [1, 2, 3, 0]]\n])", | |
| "execution_count": 13, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": "tensor([[[1., 2., 0., 0.],\n [1., 2., 3., 0.],\n [1., 2., 3., 0.]]])" | |
| }, | |
| "execution_count": 13, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ] | |
| }, | |
| { | |
| "metadata": { | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "", | |
| "execution_count": null, | |
| "outputs": [] | |
| } | |
| ], | |
| "metadata": { | |
| "gist": { | |
| "id": "", | |
| "data": { | |
| "description": "Learned Positional Embedding Masking", | |
| "public": true | |
| } | |
| }, | |
| "kernelspec": { | |
| "name": "python3", | |
| "display_name": "Python 3", | |
| "language": "python" | |
| }, | |
| "language_info": { | |
| "name": "python", | |
| "version": "3.6.6", | |
| "mimetype": "text/x-python", | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "pygments_lexer": "ipython3", | |
| "nbconvert_exporter": "python", | |
| "file_extension": ".py" | |
| }, | |
| "latex_envs": { | |
| "eqNumInitial": 1, | |
| "eqLabelWithNumbers": true, | |
| "current_citInitial": 1, | |
| "cite_by": "apalike", | |
| "bibliofile": "biblio.bib", | |
| "LaTeX_envs_menu_present": true, | |
| "labels_anchors": false, | |
| "latex_user_defs": false, | |
| "user_envs_cfg": false, | |
| "report_style_numbering": false, | |
| "autoclose": false, | |
| "autocomplete": true, | |
| "hotkeys": { | |
| "equation": "Ctrl-E", | |
| "itemize": "Ctrl-I" | |
| } | |
| }, | |
| "varInspector": { | |
| "window_display": false, | |
| "cols": { | |
| "lenName": 16, | |
| "lenType": 16, | |
| "lenVar": 40 | |
| }, | |
| "kernels_config": { | |
| "python": { | |
| "library": "var_list.py", | |
| "delete_cmd_prefix": "del ", | |
| "delete_cmd_postfix": "", | |
| "varRefreshCmd": "print(var_dic_list())" | |
| }, | |
| "r": { | |
| "library": "var_list.r", | |
| "delete_cmd_prefix": "rm(", | |
| "delete_cmd_postfix": ") ", | |
| "varRefreshCmd": "cat(var_dic_list()) " | |
| } | |
| }, | |
| "types_to_exclude": [ | |
| "module", | |
| "function", | |
| "builtin_function_or_method", | |
| "instance", | |
| "_Feature" | |
| ] | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 2 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment