{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from argparse import Namespace\n",
    "import math\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "import torch\n",
    "from torch import nn\n",
    "from torch.nn import functional as F\n",
    "from torch.utils.data import DataLoader, Dataset, SubsetRandomSampler\n",
    "from torch.optim import Adam\n",
    "import pytorch_lightning as pl\n",
    "from pytorch_lightning.loggers import WandbLogger\n",
    "import wandb\n",
    "\n",
| "class DatasetLoader(Dataset):\n", | |
| " def __init__(self):\n", | |
| " self.data = np.random.randn(100,768)\n", | |
| "\n", | |
| " def __len__(self):\n", | |
| " return self.data.shape[0]\n", | |
| "\n", | |
| " def __getitem__(self, item):\n", | |
| " return self.data[item].astype(np.float32)\n", | |
| "class AutoEncoder(pl.LightningModule):\n", | |
| " def __init__(self,hparams):\n", | |
| " super().__init__()\n", | |
| " self.hparams = hparams\n", | |
| " hiddenDims = [384,192]\n", | |
| " hiddenDims = hiddenDims[:self.hparams.hdims+1]\n", | |
| " modules = []\n", | |
| " inDim = 768\n", | |
| " for hDim in hiddenDims:\n", | |
| " modules.append(\n", | |
| " nn.Sequential(\n", | |
| " nn.Linear(inDim, hDim),\n", | |
| " nn.ReLU())\n", | |
| " )\n", | |
| " inDim = hDim\n", | |
| " self.encoder = nn.Sequential(*modules)\n", | |
| " modules = []\n", | |
| " hiddenDims.reverse()\n", | |
| " hiddenDims.append(768)\n", | |
| " for hDim in hiddenDims[1:]:\n", | |
| " print(inDim,hDim)\n", | |
| " modules.append(\n", | |
| " nn.Sequential(\n", | |
| " nn.Linear(inDim, hDim),\n", | |
| " nn.ReLU())\n", | |
| " )\n", | |
| " inDim = hDim\n", | |
| " self.decoder = nn.Sequential(*modules)\n", | |
| "\n", | |
| " def forward(self, x):\n", | |
| " encoded = self.encoder(x)\n", | |
| " decoded = self.decoder(encoded)\n", | |
| "\n", | |
| " return encoded,decoded\n", | |
| "\n", | |
| " def prepare_data(self):\n", | |
| " data = DatasetLoader() \n", | |
| " self.train_data = data[:]\n", | |
| "\n", | |
| " def train_valid_loaders(self,dataset, valid_fraction = 0.1, **kwargs):\n", | |
| " num_train = len(dataset)\n", | |
| " indices = list(range(num_train))\n", | |
| " split = int(math.floor(valid_fraction* num_train))\n", | |
| " np.random.seed(17)\n", | |
| " np.random.shuffle(indices)\n", | |
| " if 'num_workers' not in kwargs:\n", | |
| " kwargs['num_workers'] = 1\n", | |
| "\n", | |
| " train_idx, valid_idx = indices[split:], indices[:split]\n", | |
| " train_sampler = SubsetRandomSampler(train_idx)\n", | |
| " valid_sampler = SubsetRandomSampler(valid_idx)\n", | |
| "\n", | |
| " train_loader = DataLoader(dataset,sampler=train_sampler,pin_memory=True,\n", | |
| " **kwargs)\n", | |
| " valid_loader = DataLoader(dataset,sampler=valid_sampler,pin_memory=True,\n", | |
| " **kwargs)\n", | |
| " return train_loader, valid_loader\n", | |
| " def train_dataloader(self):\n", | |
| " train_loader,_ = self.train_valid_loaders(self.train_data, valid_fraction = 0.1,batch_size=self.hparams.bs)\n", | |
| " return train_loader\n", | |
| "\n", | |
| " def configure_optimizers(self):\n", | |
| " return Adam(self.parameters(), lr = self.hparams.lr,weight_decay = self.hparams.wd)\n", | |
| " return dataTrain\n", | |
| "\n", | |
| " def training_step(self, batch, batch_idx):\n", | |
| " x = batch\n", | |
| " encoded,decoded = self(x) \n", | |
| " loss = F.mse_loss(decoded,x)\n", | |
| " logs = {'train_loss': loss}\n", | |
| " return {'loss': loss,'log':logs}\n", | |
| "\n", | |
| " def val_dataloader(self):\n", | |
| " _,valid_loader = self.train_valid_loaders(self.train_data,batch_size=640)\n", | |
| " return valid_loader\n", | |
| "\n", | |
| " def validation_step(self, batch, batch_idx):\n", | |
| " x = batch\n", | |
| " encoded,decoded = self(x)\n", | |
| " loss = F.mse_loss(decoded,x)\n", | |
| " logs = {'val_loss': loss}\n", | |
| " return {'val_loss': loss,'log':logs}\n", | |
| "\n", | |
| " def validation_epoch_end(self, outputs):\n", | |
| " avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()\n", | |
| " tensorboard_logs = {'val_loss': avg_loss,'step': self.current_epoch}\n", | |
| " return {'avg_val_loss': avg_loss, 'log': tensorboard_logs}\n", | |
| "\n", | |
| "# random data\n", | |
| "x = np.random.random_sample((3, 768))\n", | |
| "x = torch.tensor(x).float()\n", | |
| "hparams = Namespace(\n", | |
| " lr = 1e-3,\n", | |
| " wd = 1e-5,\n", | |
| " hdims = 6,\n", | |
| " bs = 64)\n", | |
| "\n", | |
| "\n", | |
| "# Init sweep\n", | |
| "sweep_config = {\n", | |
| " 'method': 'random', #grid, random\n", | |
| " 'metric': {\n", | |
| " 'name': 'val_loss',\n", | |
| " 'goal': 'minimize' \n", | |
| " },\n", | |
| " 'parameters': {\n", | |
| " 'hdims': {\n", | |
| " 'values':[1,2]\n", | |
| " },\n", | |
| " }\n", | |
| "}\n", | |
| "sweep_id = wandb.sweep(sweep_config, project='bug-sweep')\n", | |
| "\n", | |
| "def train():\n", | |
| " config_defaults = {\n", | |
| " 'lr': 1e-3,\n", | |
| " 'epochs':3,\n", | |
| " 'wd':1e-5,\n", | |
| " 'bs':64\n", | |
| "}\n", | |
| " run = wandb.init(config = config_defaults, reinit=False)\n", | |
| " config = run.config\n", | |
| " hparams = Namespace(\n", | |
| " lr = config['lr'],\n", | |
| " bs = config['bs'],\n", | |
| " hdims = config['hdims'],\n", | |
| " wd = config['wd']\n", | |
| " )\n", | |
| " hdims = config['hdims']\n", | |
| " wandb_logger = WandbLogger() \n", | |
| "\n", | |
| " wandb_logger.log_hyperparams(hparams)\n", | |
| " model = AutoEncoder(hparams)\n", | |
| " trainer = pl.Trainer(\n", | |
| " logger=wandb_logger,\n", | |
| " max_epochs=3)\n", | |
| " trainer.fit(model)" | |
| ] | |
| }, | |
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
| "wandb.agent(sweep_id, function=train)" | |
| ] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "Python 3", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.8.2" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 4 | |
| } |