{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from argparse import Namespace\n",
"import math\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"import torch\n",
"from torch import nn\n",
"from torch.nn import functional as F\n",
"from torch.utils.data import DataLoader, random_split,Dataset,SubsetRandomSampler,ConcatDataset\n",
"from torch.optim import Adam\n",
"from torch import Tensor\n",
"from torch.autograd import Variable\n",
"import pytorch_lightning as pl\n",
"from pytorch_lightning.loggers import WandbLogger\n",
"import wandb\n",
"\n",
"class DatasetLoader(Dataset):\n",
" def __init__(self):\n",
" self.data = np.random.randn(100,768)\n",
"\n",
" def __len__(self):\n",
" return self.data.shape[0]\n",
"\n",
" def __getitem__(self, item):\n",
" return self.data[item].astype(np.float32)\n",
"class AutoEncoder(pl.LightningModule):\n",
" def __init__(self,hparams):\n",
" super().__init__()\n",
" self.hparams = hparams\n",
" hiddenDims = [384,192]\n",
" hiddenDims = hiddenDims[:self.hparams.hdims+1]\n",
" modules = []\n",
" inDim = 768\n",
" for hDim in hiddenDims:\n",
" modules.append(\n",
" nn.Sequential(\n",
" nn.Linear(inDim, hDim),\n",
" nn.ReLU())\n",
" )\n",
" inDim = hDim\n",
" self.encoder = nn.Sequential(*modules)\n",
" modules = []\n",
" hiddenDims.reverse()\n",
" hiddenDims.append(768)\n",
" for hDim in hiddenDims[1:]:\n",
" print(inDim,hDim)\n",
" modules.append(\n",
" nn.Sequential(\n",
" nn.Linear(inDim, hDim),\n",
" nn.ReLU())\n",
" )\n",
" inDim = hDim\n",
" self.decoder = nn.Sequential(*modules)\n",
"\n",
" def forward(self, x):\n",
" encoded = self.encoder(x)\n",
" decoded = self.decoder(encoded)\n",
"\n",
" return encoded,decoded\n",
"\n",
" def prepare_data(self):\n",
" data = DatasetLoader() \n",
" self.train_data = data[:]\n",
"\n",
" def train_valid_loaders(self,dataset, valid_fraction = 0.1, **kwargs):\n",
" num_train = len(dataset)\n",
" indices = list(range(num_train))\n",
" split = int(math.floor(valid_fraction* num_train))\n",
" np.random.seed(17)\n",
" np.random.shuffle(indices)\n",
" if 'num_workers' not in kwargs:\n",
" kwargs['num_workers'] = 1\n",
"\n",
" train_idx, valid_idx = indices[split:], indices[:split]\n",
" train_sampler = SubsetRandomSampler(train_idx)\n",
" valid_sampler = SubsetRandomSampler(valid_idx)\n",
"\n",
" train_loader = DataLoader(dataset,sampler=train_sampler,pin_memory=True,\n",
" **kwargs)\n",
" valid_loader = DataLoader(dataset,sampler=valid_sampler,pin_memory=True,\n",
" **kwargs)\n",
" return train_loader, valid_loader\n",
" def train_dataloader(self):\n",
" train_loader,_ = self.train_valid_loaders(self.train_data, valid_fraction = 0.1,batch_size=self.hparams.bs)\n",
" return train_loader\n",
"\n",
" def configure_optimizers(self):\n",
" return Adam(self.parameters(), lr = self.hparams.lr,weight_decay = self.hparams.wd)\n",
" return dataTrain\n",
"\n",
" def training_step(self, batch, batch_idx):\n",
" x = batch\n",
" encoded,decoded = self(x) \n",
" loss = F.mse_loss(decoded,x)\n",
" logs = {'train_loss': loss}\n",
" return {'loss': loss,'log':logs}\n",
"\n",
" def val_dataloader(self):\n",
" _,valid_loader = self.train_valid_loaders(self.train_data,batch_size=640)\n",
" return valid_loader\n",
"\n",
" def validation_step(self, batch, batch_idx):\n",
" x = batch\n",
" encoded,decoded = self(x)\n",
" loss = F.mse_loss(decoded,x)\n",
" logs = {'val_loss': loss}\n",
" return {'val_loss': loss,'log':logs}\n",
"\n",
" def validation_epoch_end(self, outputs):\n",
" avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()\n",
" tensorboard_logs = {'val_loss': avg_loss,'step': self.current_epoch}\n",
" return {'avg_val_loss': avg_loss, 'log': tensorboard_logs}\n",
"\n",
"# random data\n",
"x = np.random.random_sample((3, 768))\n",
"x = torch.tensor(x).float()\n",
"hparams = Namespace(\n",
" lr = 1e-3,\n",
" wd = 1e-5,\n",
" hdims = 6,\n",
" bs = 64)\n",
"\n",
"\n",
"# Init sweep\n",
"sweep_config = {\n",
" 'method': 'random', #grid, random\n",
" 'metric': {\n",
" 'name': 'val_loss',\n",
" 'goal': 'minimize' \n",
" },\n",
" 'parameters': {\n",
" 'hdims': {\n",
" 'values':[1,2]\n",
" },\n",
" }\n",
"}\n",
"sweep_id = wandb.sweep(sweep_config, project='bug-sweep')\n",
"\n",
"def train():\n",
" config_defaults = {\n",
" 'lr': 1e-3,\n",
" 'epochs':3,\n",
" 'wd':1e-5,\n",
" 'bs':64\n",
"}\n",
" run = wandb.init(config = config_defaults, reinit=False)\n",
" config = run.config\n",
" hparams = Namespace(\n",
" lr = config['lr'],\n",
" bs = config['bs'],\n",
" hdims = config['hdims'],\n",
" wd = config['wd']\n",
" )\n",
" hdims = config['hdims']\n",
" wandb_logger = WandbLogger() \n",
"\n",
" wandb_logger.log_hyperparams(hparams)\n",
" model = AutoEncoder(hparams)\n",
" trainer = pl.Trainer(\n",
" logger=wandb_logger,\n",
" max_epochs=3)\n",
" trainer.fit(model)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"wandb.agent(sweep_id, function=train)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.2"
}
},
"nbformat": 4,
"nbformat_minor": 4
}