Created
August 5, 2021 23:45
-
-
Save egafni/b3ffe9bb1c4986fd0a86afcbe6e10236 to your computer and use it in GitHub Desktop.
Minimal PyTorch Lightning
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """ | |
| barebones replacement module for pytorch_lightning which can't be installed with tensorflow1 because of a | |
| tensorboard dependency | |
| """ | |
| import abc | |
| import torch | |
| from torch.utils.tensorboard import SummaryWriter | |
| from tqdm import tqdm | |
class LitModule:
    """Minimal stand-in for ``pytorch_lightning.LightningModule``.

    Subclasses must implement ``configure_optimizers`` and ``training_step``
    (both are called by ``LitTrainer.fit``) and may implement ``forward``.
    ``cuda()`` expects the subclass to keep its network in ``self.model``.

    The ``@abc.abstractmethod`` markers are advisory only: this class does not
    use ``abc.ABCMeta``, so a subclass that omits them can still be
    instantiated. Calling an unimplemented hook raises ``NotImplementedError``
    rather than silently returning ``None``.
    """

    def __init__(self):
        super().__init__()
        # TensorBoard writer created with default arguments.
        self.logger = SummaryWriter()
        # Latest value of each metric logged with prog_bar=True; read by the trainer's progress bar.
        self.prog_bar_dict = dict()

    def log(self, name, value, step=None, prog_bar=True):
        """Log scalar ``value`` under ``name`` to TensorBoard at ``step``.

        Args:
            name: metric tag.
            value: a Python number or a single-element tensor.
            step: global step recorded with the scalar (``None`` lets the
                writer decide, as with ``SummaryWriter.add_scalar``).
            prog_bar: if true, also remember the value for the progress bar.
        """
        self.logger.add_scalar(name, value, step)
        if prog_bar:
            if isinstance(value, torch.Tensor):
                # Store a plain number: prints cleanly in the progress bar and does not
                # keep a non-detached loss's autograd graph alive until the next log call.
                value = value.detach().item()
            self.prog_bar_dict[name] = value

    @abc.abstractmethod
    def forward(self, batch):
        """Run the model on ``batch`` (optional; ``LitTrainer`` never calls it)."""
        raise NotImplementedError(f"{type(self).__name__} does not implement forward()")

    @abc.abstractmethod
    def configure_optimizers(self):
        """Return the optimizer(s) and scheduler(s) used by ``LitTrainer.fit``."""
        raise NotImplementedError(
            f"{type(self).__name__} does not implement configure_optimizers()"
        )

    @abc.abstractmethod
    def training_step(self, batch, batch_idx):
        """Compute and return the training loss tensor for one batch."""
        raise NotImplementedError(
            f"{type(self).__name__} does not implement training_step()"
        )

    def validation_step(self, batch, batch_idx):
        """Optional validation hook; no-op by default."""
        pass

    def test_step(self, batch, batch_idx):
        """Optional test hook; no-op by default."""
        pass

    def cuda(self):
        """Move ``self.model`` (set by the subclass) to the default CUDA device and return self."""
        self.model = self.model.cuda()
        return self
class LitDataModule:
    """Minimal stand-in for ``pytorch_lightning.LightningDataModule``.

    Every hook does nothing and returns ``None``; subclasses override the ones
    they need. ``LitTrainer.fit`` calls ``prepare_data()``, then ``setup()``,
    then ``train_dataloader()``.
    """

    def __init__(self):
        return None

    def prepare_data(self):
        """One-time data preparation hook (default: do nothing)."""
        return None

    def setup(self, stage=None):
        """Setup hook; ``stage`` is accepted but ignored by this default."""
        return None

    def train_dataloader(self):
        """Return the training dataloader (default: ``None``)."""
        return None

    def val_dataloader(self):
        """Return the validation dataloader (default: ``None``)."""
        return None

    def test_dataloader(self):
        """Return the test dataloader (default: ``None``)."""
        return None
class LitTrainer:
    """Minimal stand-in for ``pytorch_lightning.Trainer`` running a plain training loop.

    Only training is implemented; the validation/test hooks are never invoked.

    Args:
        max_epochs: number of passes over the training dataloader.
        gpus: if truthy, the module (via its ``cuda()`` method) and every batch
            are moved to the default CUDA device.
    """

    def __init__(self, max_epochs, gpus=None):
        self.max_epochs = max_epochs
        self.gpus = gpus
        # Incremented once per training batch across all epochs (the first batch is step 1).
        self.global_step = 0

    def fit(self, lit_module: "LitModule", datamodule: "LitDataModule"):
        """Train ``lit_module`` on ``datamodule.train_dataloader()`` for ``max_epochs`` epochs.

        ``lit_module.configure_optimizers()`` may return a bare optimizer, or an
        ``(optimizers, schedulers)`` pair where each side is a single object or a
        list/tuple. Only the first optimizer and the first scheduler (if any)
        are used. The scheduler is stepped once at the end of every epoch.
        """
        if self.gpus:
            lit_module = lit_module.cuda()
        # NOTE(review): lit_module is not switched into train() mode here; confirm callers rely on that.
        torch.set_grad_enabled(True)

        datamodule.prepare_data()
        datamodule.setup()
        train_dataloader = datamodule.train_dataloader()
        optimizer, scheduler = self._unpack_optimizers(lit_module.configure_optimizers())

        for epoch in range(self.max_epochs):
            with tqdm(train_dataloader, unit="batch") as pbar:
                pbar.set_description(self._describe(epoch, lit_module))
                for batch_idx, batch in enumerate(pbar):
                    self.global_step += 1
                    if self.gpus:
                        batch = self._to_cuda(batch)
                    # forward
                    loss = lit_module.training_step(batch, batch_idx)
                    # clear gradients, backward, update parameters
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    # Refresh so metrics logged during this step show up immediately.
                    pbar.set_description(self._describe(epoch, lit_module))
            # ReduceLROnPlateau.step() needs a monitored metric, which this trainer does not compute.
            if scheduler is not None and not isinstance(
                scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau
            ):
                scheduler.step()

    @staticmethod
    def _unpack_optimizers(config):
        """Return ``(optimizer, scheduler_or_None)`` from ``configure_optimizers()`` output."""
        if isinstance(config, torch.optim.Optimizer):
            return config, None
        optimizers, schedulers = config
        if isinstance(optimizers, (list, tuple)):
            optimizers = optimizers[0]
        if isinstance(schedulers, (list, tuple)):
            schedulers = schedulers[0] if schedulers else None
        return optimizers, schedulers

    @staticmethod
    def _describe(epoch, lit_module):
        """Build the progress-bar description from the module's progress-bar metrics."""
        metrics = " ".join(f"[{k}={v}]" for k, v in lit_module.prog_bar_dict.items())
        return f"[e{epoch}] {metrics}"

    @classmethod
    def _to_cuda(cls, batch):
        """Move a batch (object with ``.cuda()``, dict, or list/tuple of them) to CUDA."""
        if hasattr(batch, "cuda"):
            return batch.cuda()
        if isinstance(batch, dict):
            return {key: cls._to_cuda(value) for key, value in batch.items()}
        if isinstance(batch, (list, tuple)):
            # Sequences come back as lists, as in the original per-element transfer.
            return [cls._to_cuda(item) for item in batch]
        return batch
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.