# Swish: Plain vs. Memory-Efficient Custom Autograd Implementation

This notebook compares the GPU memory usage of a naive Swish activation module against a custom `torch.autograd.Function` that recomputes the sigmoid in the backward pass instead of storing it.

```python
import gc

import torch
import torch.nn as nn
import numpy as np

from sklearn.datasets import make_classification
```
```python
X, y = make_classification(
    n_samples=1024,
    n_features=256,
    n_informative=128,
    n_redundant=0,
    n_repeated=0,
    n_classes=2,
    n_clusters_per_class=2,
    flip_y=0.01,
    class_sep=1.0,
    hypercube=True,
    shuffle=True,
    random_state=42
)
y.shape
```
```
(1024,)
```
```python
def get_model(swish_module):
    # Deliberately make the model very large
    width = 2 ** 19
    return nn.Sequential(
        nn.Linear(256, width),
        swish_module(),
        nn.BatchNorm1d(width),
        nn.Linear(width, 1)
    )
```
```python
criterion = nn.BCEWithLogitsLoss()
batch_size = 128
```
```python
def print_parameter_count(model):
    print("# of parameters: {:,d}".format(
        sum(p.numel() for p in model.parameters())))
    print("# of trainable parameters: {:,d}".format(
        sum(p.numel() for p in model.parameters() if p.requires_grad)))
```
## Plain Swish Version
```python
class PlainSwish(nn.Module):
    def forward(self, input_tensor):
        return input_tensor * torch.sigmoid(input_tensor)
```
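Why this version costs extra memory: autograd's multiply node saves both of its operands for the backward pass, so the sigmoid output (a full activation-sized tensor) stays alive in addition to the input. A minimal sketch of that bookkeeping, added here for illustration and assuming a GPU with roughly 1 GB free:

```python
# The multiply saves both inputs for backward, so sigmoid(x) (256 MB here)
# stays alive alongside the 256 MB product held by `out`.
x = torch.randn(128, 2 ** 19, device="cuda", requires_grad=True)
before = torch.cuda.memory_allocated()
out = x * torch.sigmoid(x)
print((torch.cuda.memory_allocated() - before) / 1024 ** 2)  # ~512 MB
```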
```python
model = get_model(PlainSwish).cuda()
print_parameter_count(model)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
optimizer.zero_grad()
torch.cuda.memory_allocated() / 1024 ** 2
```
```
# of parameters: 136,314,881
# of trainable parameters: 136,314,881
524.0009765625
```
```python
for i in range(0, 1024, batch_size):
    Xt = torch.tensor(X[i:i+batch_size], dtype=torch.float).cuda()
    yt = torch.tensor(y[i:i+batch_size], dtype=torch.float).cuda()
    print("data:", torch.cuda.memory_allocated() / 1024 ** 2)
    pred = model(Xt)[:, 0]
    print("forw:", torch.cuda.memory_allocated() / 1024 ** 2)
    loss = criterion(pred, yt)
    print("loss:", torch.cuda.memory_allocated() / 1024 ** 2)
    loss.backward()
    print("back:", torch.cuda.memory_allocated() / 1024 ** 2)
    optimizer.step()
    optimizer.zero_grad()
    print("step:", torch.cuda.memory_allocated() / 1024 ** 2)
    print("=" * 20)
```
```
data: 524.12646484375
forw: 1552.126953125
loss: 1552.12744140625
back: 1044.1279296875
step: 1044.1279296875
====================
data: 1044.1279296875
forw: 2072.1279296875
loss: 2072.1279296875
back: 1044.1279296875
step: 1044.1279296875
====================
```

*(the remaining six iterations repeat the second block verbatim)*
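Reading the steady-state trace: the jump from `step:` to `forw:` is 2072.128 − 1044.128 = 1028 MB, which matches four full-width activations kept for backward (the first linear's output, the sigmoid output, the swish product, and the batch-norm output) plus batch-norm's saved batch statistics. A back-of-envelope check, added here as an illustration:

```python
# Hypothetical decomposition of the plain version's forward peak:
act_mb = 128 * 2 ** 19 * 4 / 1024 ** 2      # one (batch, width) fp32 activation = 256 MB
bn_stats_mb = 2 * 2 ** 19 * 4 / 1024 ** 2   # batch-norm saved mean / inv-std
print(4 * act_mb + bn_stats_mb)             # 1028.0 == 2072.128 - 1044.128
```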
```python
del optimizer, model, Xt, yt, loss, pred
gc.collect()
torch.cuda.memory_allocated() / 1024
```
```
0.0
```
## Custom Swish Version
```python
class Swish(torch.autograd.Function):
    @staticmethod
    def forward(ctx, i):
        result = i * torch.sigmoid(i)
        # Save only the input; the sigmoid is recomputed in backward,
        # trading a little compute for one activation-sized tensor of memory.
        ctx.save_for_backward(i)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        i = ctx.saved_tensors[0]
        sigmoid_i = torch.sigmoid(i)
        # d/dx [x * sigmoid(x)] = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
        return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))

class CustomSwish(nn.Module):
    def forward(self, input_tensor):
        return Swish.apply(input_tensor)
```
```python
model = get_model(CustomSwish).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
optimizer.zero_grad()
torch.cuda.memory_allocated() / 1024  # in KB here (== 524.0009765625 MB, same as the plain version)
```
```
536577.0
```
```python
for i in range(0, 1024, batch_size):
    Xt = torch.tensor(X[i:i+batch_size], dtype=torch.float).cuda()
    yt = torch.tensor(y[i:i+batch_size], dtype=torch.float).cuda()
    print("data:", torch.cuda.memory_allocated() / 1024 ** 2)
    pred = model(Xt)[:, 0]
    print("forw:", torch.cuda.memory_allocated() / 1024 ** 2)
    loss = criterion(pred, yt)
    print("loss:", torch.cuda.memory_allocated() / 1024 ** 2)
    loss.backward()
    print("back:", torch.cuda.memory_allocated() / 1024 ** 2)
    optimizer.step()
    optimizer.zero_grad()
    print("step:", torch.cuda.memory_allocated() / 1024 ** 2)
    print("=" * 20)
```
```
data: 524.12646484375
forw: 1296.126953125
loss: 1296.12744140625
back: 1044.1279296875
step: 1044.1279296875
====================
data: 1044.1279296875
forw: 1816.1279296875
loss: 1816.1279296875
back: 1044.1279296875
step: 1044.1279296875
====================
```

*(the remaining six iterations repeat the second block verbatim)*
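The steady-state forward peak drops from 2072 MB to 1816 MB: because the custom backward recomputes the sigmoid, the forward keeps three full-width activations alive instead of four. The 256 MB saved is exactly one `(batch_size, width)` fp32 tensor (a check added for illustration):

```python
# One (batch_size, width) fp32 activation -- the sigmoid output that the
# custom version no longer stores for backward:
print(128 * 2 ** 19 * 4 / 1024 ** 2)  # 256.0 == 2072.128 - 1816.128
```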
```python
del optimizer, model, Xt, yt, loss, pred
gc.collect()
torch.cuda.memory_allocated() / 1024
```
```
0.0
```
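As an aside (an addition, not in the original gist): PyTorch 1.7 and later ship this activation natively as `nn.SiLU`, so on recent versions no custom module is needed:

```python
# Requires PyTorch >= 1.7 (newer than the version this notebook ran on):
# the built-in SiLU is the same x * sigmoid(x) activation.
x = torch.randn(4, 4)
print(torch.allclose(nn.SiLU()(x), x * torch.sigmoid(x)))  # True
```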
*(Notebook metadata: Python 3 kernel, language version 3.7.3.)*