Created July 8, 2020 19:00
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from fastai2.tabular.all import *\n",
    "\n",
    "import wandb\n",
    "from fastai2.callback.wandb import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "wandb.login()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "path = untar_data(URLs.ADULT_SAMPLE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.read_csv(path/'adult.csv')\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dls = TabularDataLoaders.from_csv(path/'adult.csv', path=path, y_names=\"salary\",\n",
    "                                  cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race'],\n",
    "                                  cont_names = ['age', 'fnlwgt', 'education-num'],\n",
    "                                  procs = [Categorify, FillMissing, Normalize])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dls.show_batch()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "wandb.init(entity='borisd13', project='demo_config', tags=['tabular'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "WandbCallback has been modified for this experimentation:\n",
    "\n",
    "* removed the try/except so that we don't fail silently\n",
    "* added some debug print statements\n",
    "* converted test_items to a DataFrame before passing it to test_dl"
   ]
  },
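  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The DataFrame conversion is motivated by the fact that, for tabular data, `dls.valid_ds.items` is itself a pandas DataFrame, so indexing single rows with `.iloc` yields `Series` objects. The quick check below is a sketch added for illustration (it assumes the `dls` built above) and is not part of the original callback."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch (illustration only): inspect what the callback samples from.\n",
    "# For TabularDataLoaders, valid_ds.items is expected to be a DataFrame,\n",
    "# and a single row taken with .iloc is a pandas Series.\n",
    "type(dls.valid_ds.items), type(dls.valid_ds.items.iloc[0])"
   ]
  },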
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Modified for experimentation on TabularDataLoaders\n",
    "\n",
    "from fastai2.callback.wandb import *\n",
    "from fastai2.callback.wandb import _format_config\n",
    "\n",
    "class WandbCallback(Callback):\n",
    "    \"Saves model topology, losses & metrics\"\n",
    "    toward_end,remove_on_fetch,run_after = True,True,FetchPredsCallback\n",
    "    # Record if watch has been called previously (even in another instance)\n",
    "    _wandb_watch_called = False\n",
    "\n",
    "    def __init__(self, log=\"gradients\", log_preds=True, valid_dl=None, n_preds=36, seed=12345):\n",
    "        # Check if wandb.init has been called\n",
    "        if wandb.run is None:\n",
    "            raise ValueError('You must call wandb.init() before WandbCallback()')\n",
    "        # W&B log step\n",
    "        self._wandb_step = wandb.run.step - 1  # -1 except if the run has previously logged data (incremented at each batch)\n",
    "        self._wandb_epoch = 0 if not(wandb.run.step) else math.ceil(wandb.run.summary['epoch'])  # continue to next epoch\n",
    "        store_attr(self, 'log,log_preds,valid_dl,n_preds,seed')\n",
    "\n",
    "    def begin_fit(self):\n",
    "        \"Call watch method to log model topology, gradients & weights\"\n",
    "        self.run = not hasattr(self.learn, 'lr_finder') and not hasattr(self, \"gather_preds\") and rank_distrib()==0\n",
    "        if not self.run: return\n",
    "\n",
    "        # Log config parameters\n",
    "        log_config = self.learn.gather_args()\n",
    "        _format_config(log_config)\n",
    "        # Log all parameters at once\n",
    "        wandb.config.update(log_config, allow_val_change=True)\n",
    "\n",
    "        if not WandbCallback._wandb_watch_called:\n",
    "            WandbCallback._wandb_watch_called = True\n",
    "            # Logs model topology and optionally gradients and weights\n",
    "            wandb.watch(self.learn.model, log=self.log)\n",
    "\n",
    "        if hasattr(self, 'save_model'): self.save_model.add_save = Path(wandb.run.dir)/'bestmodel.pth'\n",
    "\n",
    "        if self.log_preds:\n",
    "            if not self.valid_dl:\n",
    "                # Initializes the batch watched\n",
    "                wandbRandom = random.Random(self.seed)  # For repeatability\n",
    "                self.n_preds = min(self.n_preds, len(self.dls.valid_ds))\n",
    "                idxs = wandbRandom.sample(range(len(self.dls.valid_ds)), self.n_preds)\n",
    "                test_items = [getattr(self.dls.valid_ds.items, 'iloc', self.dls.valid_ds.items)[i] for i in idxs]\n",
    "\n",
    "                # debug\n",
    "                print(f'{type(test_items)=}\\n')\n",
    "                print(f'{type(test_items[0])=}\\n')\n",
    "                print(f'{test_items[0]=}\\n')\n",
    "\n",
    "                test_items = pd.DataFrame(test_items)\n",
    "                self.valid_dl = self.dls.test_dl(test_items, with_labels=True)\n",
    "\n",
    "            self.learn.add_cb(FetchPredsCallback(dl=self.valid_dl, with_input=True, with_decoded=True))\n",
    "\n",
    "    def after_batch(self):\n",
    "        \"Log hyper-parameters and training loss\"\n",
    "        if self.training:\n",
    "            self._wandb_step += 1\n",
    "            self._wandb_epoch += 1/self.n_iter\n",
    "            hypers = {f'{k}_{i}':v for i,h in enumerate(self.opt.hypers) for k,v in h.items()}\n",
    "            wandb.log({'epoch': self._wandb_epoch, 'train_loss': self.smooth_loss, 'raw_loss': self.loss, **hypers}, step=self._wandb_step)\n",
    "\n",
    "    def after_epoch(self):\n",
    "        \"Log validation loss and custom metrics & log prediction samples\"\n",
    "        # Correct any epoch rounding error and overwrite value\n",
    "        self._wandb_epoch = round(self._wandb_epoch)\n",
    "        wandb.log({'epoch': self._wandb_epoch}, step=self._wandb_step)\n",
    "        # Log sample predictions\n",
    "        if self.log_preds:\n",
    "            inp,preds,targs,out = self.learn.fetch_preds.preds\n",
    "            b = tuplify(inp) + tuplify(targs)\n",
    "            x,y,its,outs = self.valid_dl.show_results(b, out, show=False, max_n=self.n_preds)\n",
    "            wandb.log(wandb_process(x, y, its, outs), step=self._wandb_step)\n",
    "        wandb.log({n:s for n,s in zip(self.recorder.metric_names, self.recorder.log) if n not in ['train_loss', 'epoch', 'time']}, step=self._wandb_step)\n",
    "\n",
    "    def after_fit(self):\n",
    "        self.run = True\n",
    "        if self.log_preds: self.remove_cb(FetchPredsCallback)\n",
    "        wandb.log({})  # ensure sync of last step"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = tabular_learner(dls, metrics=accuracy, cbs=[WandbCallback()])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Getting predictions outside of the training loop works."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.get_preds()"
   ]
  },
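  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As an extra sanity check (a sketch added for illustration, not part of the original gist): `get_preds` returns a `(predictions, targets)` tuple by default, so the shapes can be inspected directly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch (illustration only): unpack predictions and targets and look at their shapes.\n",
    "preds, targs = learn.get_preds()\n",
    "preds.shape, targs.shape"
   ]
  },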
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "But we cannot get predictions during the training loop.\n",
    "\n",
    "Note: the message `Could not gather input dimensions` is related to not being able to get the tensor dimensions to log as hyper-parameters, and is not really relevant here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.fit(1)"
   ]
  },
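  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To help narrow this down, the sketch below (added for illustration, not part of the original gist) reproduces outside of `fit` what `begin_fit` does: sample a few rows from `dls.valid_ds.items`, rebuild a DataFrame, and call `test_dl` on it. It assumes the `dls` defined above and wraps the call in a try/except so any exception is printed rather than raised."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch (illustration only): mimic the sampling done in WandbCallback.begin_fit\n",
    "# so the test_dl step can be inspected in isolation.\n",
    "sample_idxs = random.Random(12345).sample(range(len(dls.valid_ds)), 4)\n",
    "sample_items = [dls.valid_ds.items.iloc[i] for i in sample_idxs]\n",
    "sample_df = pd.DataFrame(sample_items)\n",
    "try:\n",
    "    sample_dl = dls.test_dl(sample_df, with_labels=True)\n",
    "    print('test_dl created:', sample_dl)\n",
    "except Exception as e:\n",
    "    print('test_dl failed:', e)"
   ]
  },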
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}