PyTorch CPU offloading for training and inference on low VRAM GPUs

Use like this:

import offload

model.gradient_checkpointing_enable() # REQUIRED for this offloading script to work
offload.offload(model.layer1, model.layer2)
model.layer3.cuda() # Keep layer3 always on GPU, without offloading it (example)
offload.offload(model.layer4, merge_forward_backward=True) # Layer 4 is loaded together with its gradients before the forward pass and stays on GPU until its backward pass finishes. TURN THIS OFF FOR INFERENCE

# ...

model.train()
for batch in train_dataset:
    outputs = model(batch.cuda()) # Keep all inputs on the GPU
    with offload.backward(): outputs.loss.backward() # offload.backward() is similar to merge_forward_backward: it saves extra loads/unloads, but is not *strictly* necessary
    optimizer.step()
    optimizer.zero_grad()
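
For inference, leave merge_forward_backward off and run the forward passes under torch.no_grad(); a minimal sketch (eval_dataset is a placeholder for your own data loader):

model.eval()
with torch.no_grad():
    for batch in eval_dataset:
        outputs = model(batch.cuda()) # Each offloaded layer is loaded right before its forward pass and unloaded right after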
offload.py

import torch
from dataclasses import dataclass
from contextlib import contextmanager

# Helpful: https://github.com/jvatsal/torch-offloading/

enabled = True
device = torch.device('cuda')
stream = torch.cuda.Stream()


@dataclass
class CPUParam:
    data: torch.Tensor
    grad: torch.Tensor | None
    ref_count: int


cpu_params: dict[torch.nn.Parameter, CPUParam] = {} # CPU copies of loaded model parameters

def load(param: torch.nn.Parameter, load_grad=True):
    '''Load parameter from CPU'''
    if param in cpu_params:
        cpu_params[param].ref_count += 1
        return
    cpu_param = CPUParam(
        data=param.data.cpu().pin_memory(),
        grad=param.grad.cpu().pin_memory() if param.grad is not None else None,
        ref_count=1,
    )
    cpu_params[param] = cpu_param

    # Synchronize the copy stream with the current stream before issuing async host-to-device copies
    current_stream = torch.cuda.current_stream()
    current_stream.wait_stream(stream)
    stream.wait_stream(current_stream)
    with torch.cuda.stream(stream):
        param.data = torch.empty_like(cpu_param.data, device=device)
        param.data.copy_(cpu_param.data, non_blocking=True)
        if load_grad and cpu_param.grad is not None:
            param.grad = torch.empty_like(cpu_param.grad, device=device)
            param.grad.copy_(cpu_param.grad, non_blocking=True)
        else: param.grad = None

def unload(param: torch.nn.Parameter, sync_data=True, unload_grad=True):
    '''Unload parameter to CPU'''
    cpu_param = cpu_params[param]
    cpu_param.ref_count -= 1
    if cpu_param.ref_count > 0: return
    del cpu_params[param]

    # Copy data (and, if requested, gradients) back to the pinned CPU buffers and point the parameter at them
    stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(stream):
        if sync_data: cpu_param.data.copy_(param.data, non_blocking=True)
        param.data = cpu_param.data
        if unload_grad and param.grad is not None:
            if cpu_param.grad is None: cpu_param.grad = torch.empty_like(param.grad, device='cpu').pin_memory()
            assert cpu_param.grad is not None
            cpu_param.grad.copy_(param.grad, non_blocking=True)
        param.grad = cpu_param.grad

trace = False # Set to True (or to a string of module class names) to log VRAM usage from the hooks

def trace_log(module: torch.nn.Module, msg: str):
    if not trace: return
    if isinstance(trace, str):
        if type(module).__name__ not in trace: return
    print(f'{type(module).__name__} {msg}; VRAM usage: {torch.cuda.memory_allocated(0) / 1024 / 1024} MiB')


_merge_forward_backward = set()
inside_backward_pass = False

def merge_forward_backward(module: torch.nn.Module):
    return module in _merge_forward_backward or inside_backward_pass

loaded_modules = set()

def forward_pre_hook(module: torch.nn.Module, inputs):
    if not enabled: return
    trace_log(module, 'forward_pre_hook')
    del inputs
    load_grad = merge_forward_backward(module)
    for param in module.parameters(): load(param, load_grad=load_grad)
    loaded_modules.add(module)
    torch.cuda.synchronize(device)


def forward_post_hook(module: torch.nn.Module, inputs, outputs):
    if not enabled or merge_forward_backward(module): return
    trace_log(module, 'forward_post_hook')
    del inputs, outputs
    for param in module.parameters(): unload(param, sync_data=False, unload_grad=False) # NOTE: Assumes param was not modified during forward
    loaded_modules.remove(module)


def backward_pre_hook(module: torch.nn.Module, grad_output):
    if not enabled or module in loaded_modules: return
    trace_log(module, 'backward_pre_hook')
    del grad_output
    for param in module.parameters(): load(param, load_grad=True)
    loaded_modules.add(module)


def backward_post_hook(module: torch.nn.Module, grad_input, grad_output):
    if not enabled: return
    trace_log(module, 'backward_post_hook')
    del grad_input, grad_output
    for param in module.parameters():
        if not param.requires_grad: unload(param, unload_grad=True)
    loaded_modules.remove(module)


def make_post_accumulate_grad_hook(module: torch.nn.Module):
    def post_accumulate_grad_hook(param: torch.nn.Parameter):
        if not enabled: return
        trace_log(module, 'post_accumulate_grad_hook') # Note: This fires multiple times, once per param, and is also a bit wrong for tied layers
        unload(param, unload_grad=True)
        torch.cuda.synchronize(device)
    return post_accumulate_grad_hook

registered_parameters = set()

def offload(*modules: torch.nn.Module, merge_forward_backward=False):
    '''
    Set up CPU offloading for the given modules.
    If merge_forward_backward is True, a module is loaded together with its gradients before the
    forward pass and is not unloaded until its backward pass has finished.
    Useful for the last module in the chain during training.
    '''
    for module in modules:
        module.cpu()
        if merge_forward_backward: _merge_forward_backward.add(module)
        module.register_forward_pre_hook(forward_pre_hook)
        module.register_forward_hook(forward_post_hook)
        module.register_full_backward_pre_hook(backward_pre_hook)
        module.register_full_backward_hook(backward_post_hook)
        post_accumulate_grad_hook = make_post_accumulate_grad_hook(module)
        for param in module.parameters():
            if param in registered_parameters: continue # Helps with tied modules
            registered_parameters.add(param)
            if param.requires_grad:
                param.register_post_accumulate_grad_hook(post_accumulate_grad_hook)

def move(data, device, non_blocking=False):
    '''Move a tensor, or tensors nested in a tuple/dict, to the given device'''
    if isinstance(data, tuple):
        return tuple(v.to(device, non_blocking=non_blocking) if isinstance(v, torch.Tensor) else v for v in data)
    elif isinstance(data, dict):
        return {key: val.to(device, non_blocking=non_blocking) if isinstance(val, torch.Tensor) else val for key, val in data.items()}
    elif isinstance(data, torch.Tensor):
        return data.to(device, non_blocking=non_blocking)

@contextmanager
def no_offload():
    '''Use this context manager to disable offloading temporarily'''
    global enabled
    last, enabled = enabled, False
    try: yield
    finally: enabled = last

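# Example (hypothetical usage; model.layer1 and x are placeholders): bypass the offloading
# hooks while a module has been moved onto the GPU by hand:
#
#   with no_offload():
#       model.layer1.cuda()
#       out = model.layer1(x)
#       model.layer1.cpu()
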
@contextmanager
def backward():
    '''
    Wrap your `loss.backward()` call with this context manager
    to skip unnecessary unloads and speed up the backward pass (only with gradient checkpointing)
    '''
    global inside_backward_pass
    last, inside_backward_pass = inside_backward_pass, True
    try: yield
    finally: inside_backward_pass = last