PyTorch CPU offloading for training and inference on low-VRAM GPUs
Usage:

import offload

model.gradient_checkpointing_enable()  # REQUIRED: this offloading script relies on gradient checkpointing
offload.offload(model.layer1, model.layer2)  # offload these layers to the CPU when they are not computing
model.layer3.cuda()  # example: keep layer3 on the GPU at all times, without offloading it
offload.offload(model.layer4, merge_forward_backward=True)  # layer4 is not offloaded between its forward and backward pass; TURN THIS OFF FOR INFERENCE
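For intuition, offloading like this is commonly built on PyTorch's module hooks: weights live on the CPU and are moved to the GPU only while the layer's forward pass runs. Since gradient checkpointing re-runs the forward pass during backward, the same hooks bring the weights back for the backward computation, which is plausibly why this script requires it. The sketch below is a hypothetical illustration of that pattern, not the actual implementation of offload.offload; the name offload_to_cpu and the compute_device parameter are made up for the example.

```python
import torch
import torch.nn as nn

def offload_to_cpu(module: nn.Module, compute_device: str = "cuda"):
    # Hypothetical helper (NOT the real offload.offload): keep the module's
    # weights on the CPU, moving them to compute_device only while the
    # module's forward pass runs.
    def pre_hook(mod, inputs):
        mod.to(compute_device)  # load weights just before compute
        return tuple(t.to(compute_device) for t in inputs)

    def post_hook(mod, inputs, output):
        mod.to("cpu")  # evict weights right after compute
        return output

    module.to("cpu")
    module.register_forward_pre_hook(pre_hook)
    module.register_forward_hook(post_hook)
    return module

# Demo: falls back to CPU so the snippet also runs without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
layer = offload_to_cpu(nn.Linear(8, 8), compute_device=device)
y = layer(torch.randn(2, 8))
print(y.shape, layer.weight.device)  # weights sit on the CPU again after forward
```

A merge_forward_backward-style option would simply skip the eviction in post_hook during training, so the weights stay resident between the forward and backward passes at the cost of extra VRAM.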