@egafni
Created August 5, 2021 23:45
minimal pytorch lightning
"""
barebones replacement module for pytorch_lightning which can't be installed with tensorflow1 because of a
tensorboard dependency
"""
import abc
import torch
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
class LitModule(abc.ABC):
    """Minimal stand-in for pytorch_lightning.LightningModule."""

    def __init__(self):
        super().__init__()
        self.logger = SummaryWriter()
        self.prog_bar_dict = dict()

    def log(self, name, value, step, prog_bar=True):
        # Log a scalar to tensorboard; optionally surface it on the tqdm progress bar.
        self.logger.add_scalar(name, value, step)
        if prog_bar:
            self.prog_bar_dict[name] = value

    @abc.abstractmethod
    def forward(self, batch):
        pass

    @abc.abstractmethod
    def configure_optimizers(self):
        # expected to return ([optimizers], [schedulers]), mirroring pytorch_lightning
        pass

    @abc.abstractmethod
    def training_step(self, batch, batch_idx):
        # must return the loss tensor
        pass

    def validation_step(self, batch, batch_idx):
        pass

    def test_step(self, batch, batch_idx):
        pass

    def cuda(self):
        # subclasses are expected to store their network as `self.model`
        self.model = self.model.cuda()
        return self
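

# --- Example (not part of the original gist): a minimal LitModule subclass. ---
# A hedged sketch for a toy regression task; the architecture, loss function,
# and hyperparameters below are illustrative, not prescribed by the gist.
class ExampleRegressor(LitModule):
    def __init__(self):
        super().__init__()
        self.model = torch.nn.Linear(8, 1)  # `cuda()` above expects `self.model`
        self.loss_fn = torch.nn.MSELoss()

    def forward(self, batch):
        x, _ = batch
        return self.model(x)

    def configure_optimizers(self):
        # follow the ([optimizers], [schedulers]) convention that LitTrainer.fit unpacks
        optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        return [optimizer], [scheduler]

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = self.loss_fn(self.model(x), y)
        self.log("train_loss", loss.item(), step=batch_idx)
        return loss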
class LitDataModule:
    """Minimal stand-in for pytorch_lightning.LightningDataModule; override as needed."""

    def __init__(self):
        pass

    def prepare_data(self):
        pass

    def setup(self, stage=None):
        pass

    def train_dataloader(self):
        pass

    def val_dataloader(self):
        pass

    def test_dataloader(self):
        pass
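

# --- Example (not part of the original gist): a minimal LitDataModule subclass. ---
# Synthetic random tensors stand in for a real dataset; shapes match ExampleRegressor.
from torch.utils.data import DataLoader, TensorDataset


class ExampleDataModule(LitDataModule):
    def setup(self, stage=None):
        x = torch.randn(256, 8)
        y = torch.randn(256, 1)
        self.train_ds = TensorDataset(x, y)

    def train_dataloader(self):
        return DataLoader(self.train_ds, batch_size=32)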
class LitTrainer:
    def __init__(self, max_epochs, gpus=None):
        self.max_epochs = max_epochs
        self.gpus = gpus
        self.global_step = 0

    def fit(self, lit_module: LitModule, datamodule: LitDataModule):
        if self.gpus:
            lit_module = lit_module.cuda()
        # (a full Lightning loop would also put the model in train mode here)
        torch.set_grad_enabled(True)
        losses = []  # kept for potential inspection; unused by the loop itself
        datamodule.prepare_data()
        datamodule.setup()
        train_dataloader = datamodule.train_dataloader()
        # configure_optimizers returns ([optimizers], [schedulers]); use the first of each
        optimizers, schedulers = lit_module.configure_optimizers()
        optimizer, scheduler = optimizers[0], schedulers[0]
        for epoch in range(self.max_epochs):
            with tqdm(train_dataloader, unit="batch") as pbar:
                for batch_idx, batch in enumerate(pbar):
                    self.global_step += 1
                    # forward
                    if self.gpus:
                        batch = [x.cuda() for x in batch]
                    loss = lit_module.training_step(batch, batch_idx)
                    losses.append(loss.detach())
                    # clear gradients
                    optimizer.zero_grad()
                    # backward
                    loss.backward()
                    # update parameters
                    optimizer.step()
                    # refresh progress-bar metrics logged via LitModule.log
                    metrics = " ".join(
                        f"[{k}={v}]" for k, v in lit_module.prog_bar_dict.items()
                    )
                    pbar.set_description(f"[e{epoch}] {metrics}")
            # step the learning-rate scheduler once per epoch
            scheduler.step()
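

# --- Example usage (not part of the original gist) ---
# End-to-end sketch wiring the illustrative ExampleRegressor and ExampleDataModule
# above into LitTrainer; runs on CPU with the default gpus=None.
if __name__ == "__main__":
    trainer = LitTrainer(max_epochs=2)
    trainer.fit(ExampleRegressor(), datamodule=ExampleDataModule())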