import torch
from dataclasses import dataclass
from contextlib import contextmanager

# Helpful: https://github.com/jvatsal/torch-offloading/

enabled = True  # global on/off switch; see no_offload()
device = torch.device('cuda')
stream = torch.cuda.Stream()  # side stream used for all asynchronous host<->device parameter copies

@dataclass
class CPUParam:
    data: torch.Tensor
    grad: torch.Tensor | None
    ref_count: int

cpu_params: dict[torch.nn.Parameter, CPUParam] = {}  # CPU copies of loaded model parameters

def load(param: torch.nn.Parameter, load_grad=True):
    '''Load a parameter (and optionally its gradient) from its pinned CPU copy onto the GPU'''
    if param in cpu_params:
        cpu_params[param].ref_count += 1
        return

    cpu_param = CPUParam(
        data=param.data.cpu().pin_memory(),
        grad=param.grad.cpu().pin_memory() if param.grad is not None else None,
        ref_count=1,
    )
    cpu_params[param] = cpu_param

    # Make the copy stream and the compute stream agree on the parameter's state
    # before issuing the asynchronous uploads.
    current_stream = torch.cuda.current_stream()
    current_stream.wait_stream(stream)
    stream.wait_stream(current_stream)

    with torch.cuda.stream(stream):
        param.data = torch.empty_like(cpu_param.data, device=device)
        param.data.copy_(cpu_param.data, non_blocking=True)
        if load_grad and cpu_param.grad is not None:
            param.grad = torch.empty_like(cpu_param.grad, device=device)
            param.grad.copy_(cpu_param.grad, non_blocking=True)
        else: param.grad = None

def unload(param: torch.nn.Parameter, sync_data=True, unload_grad=True):
    '''Unload a parameter (and optionally its gradient) back to its pinned CPU copy'''
    cpu_param = cpu_params[param]
    cpu_param.ref_count -= 1
    if cpu_param.ref_count > 0: return
    del cpu_params[param]

    # Let all compute that used the parameter finish before copying it back.
    stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(stream):
        if sync_data: cpu_param.data.copy_(param.data, non_blocking=True)
        param.data = cpu_param.data
        if unload_grad and param.grad is not None:
            if cpu_param.grad is None: cpu_param.grad = torch.empty_like(param.grad, device='cpu').pin_memory()
            assert cpu_param.grad is not None
            cpu_param.grad.copy_(param.grad, non_blocking=True)
            param.grad = cpu_param.grad
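
# A minimal sketch (not part of this module's API) of driving load()/unload() by hand.
# `layer` and `x` are placeholder names for a CPU-resident nn.Module and a CUDA input:
#
#   for p in layer.parameters(): load(p, load_grad=False)
#   torch.cuda.synchronize(device)   # wait for the async host->device copies
#   y = layer(x)
#   for p in layer.parameters(): unload(p, sync_data=False, unload_grad=False)
#
# The hooks below automate exactly this pattern around each registered module.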

trace = False  # set to True to log every hooked module, or to a string to filter by class name
def trace_log(module: torch.nn.Module, msg: str):
    if not trace: return
    if isinstance(trace, str):
        if type(module).__name__ not in trace: return
    print(f'{type(module).__name__} {msg}; VRAM usage: {torch.cuda.memory_allocated(0) / 1024 / 1024} MiB')
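
# Usage sketch (assuming this file is imported as `offloading`):
#
#   offloading.trace = True       # log every hooked module
#   offloading.trace = 'Linear'   # log only modules whose class name occurs in the string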

_merge_forward_backward = set()  # modules that stay loaded from forward until their backward finishes
inside_backward_pass = False     # toggled by the backward() context manager below
def merge_forward_backward(module: torch.nn.Module):
    return module in _merge_forward_backward or inside_backward_pass

loaded_modules = set()
def forward_pre_hook(module: torch.nn.Module, inputs):
    if not enabled: return
    trace_log(module, 'forward_pre_hook')

    del inputs
    load_grad = merge_forward_backward(module)
    for param in module.parameters(): load(param, load_grad=load_grad)
    loaded_modules.add(module)
    torch.cuda.synchronize(device)  # make sure the uploads finished before the module computes

def forward_post_hook(module: torch.nn.Module, inputs, outputs):
    if not enabled or merge_forward_backward(module): return
    trace_log(module, 'forward_post_hook')

    del inputs, outputs
    for param in module.parameters(): unload(param, sync_data=False, unload_grad=False)  # NOTE: Assumes param was not modified during forward
    loaded_modules.remove(module)

def backward_pre_hook(module: torch.nn.Module, grad_output):
    if not enabled or module in loaded_modules: return  # already resident, e.g. a merge_forward_backward module
    trace_log(module, 'backward_pre_hook')

    del grad_output
    for param in module.parameters(): load(param, load_grad=True)
    loaded_modules.add(module)

def backward_post_hook(module: torch.nn.Module, grad_input, grad_output):
    if not enabled: return
    trace_log(module, 'backward_post_hook')

    del grad_input, grad_output
    for param in module.parameters():
        # Parameters that require grad are unloaded later, by their post-accumulate-grad hook.
        if not param.requires_grad: unload(param, unload_grad=True)
    loaded_modules.remove(module)

def make_post_accumulate_grad_hook(module: torch.nn.Module):
    def post_accumulate_grad_hook(param: torch.nn.Parameter):
        if not enabled: return
        trace_log(module, 'post_accumulate_grad_hook')  # Note: This fires multiple times, once per param, and is also a bit wrong for tied layers
        unload(param, unload_grad=True)
        torch.cuda.synchronize(device)

    return post_accumulate_grad_hook

registered_parameters = set()
def offload(*modules: torch.nn.Module, merge_forward_backward=False):
    '''
    Set up CPU offloading for the given modules.
    If merge_forward_backward is True, the module will be loaded before the forward pass with gradients
    and won't be unloaded until its backward pass has finished.
    Useful for the last module in the chain during training.
    '''
    for module in modules:
        module.cpu()
        if merge_forward_backward: _merge_forward_backward.add(module)
        module.register_forward_pre_hook(forward_pre_hook)
        module.register_forward_hook(forward_post_hook)
        module.register_full_backward_pre_hook(backward_pre_hook)
        module.register_full_backward_hook(backward_post_hook)
        post_accumulate_grad_hook = make_post_accumulate_grad_hook(module)
        for param in module.parameters():
            if param in registered_parameters: continue  # Helps with tied modules
            registered_parameters.add(param)

            if param.requires_grad:
                param.register_post_accumulate_grad_hook(post_accumulate_grad_hook)
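
# Usage sketch: hook up every block of a model, merging forward and backward for the last one
# (model/layers are placeholder names; the model is assumed to be a plain sequence of blocks):
#
#   blocks = list(model.layers)
#   offload(*blocks[:-1])
#   offload(blocks[-1], merge_forward_backward=True)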

def move(data, device, non_blocking=False):
    '''Move a tensor, or the tensors inside a tuple/dict, to `device`; non-tensor values pass through unchanged'''
    if isinstance(data, tuple):
        return tuple(v.to(device, non_blocking=non_blocking) if isinstance(v, torch.Tensor) else v for v in data)
    elif isinstance(data, dict):
        return {key: val.to(device, non_blocking=non_blocking) if isinstance(val, torch.Tensor) else val for key, val in data.items()}
    elif isinstance(data, torch.Tensor):
        return data.to(device, non_blocking=non_blocking)
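
# Sketch: stage a dict-shaped batch onto the GPU without blocking the host
# (`batch` is a placeholder; non_blocking=True is only effective with pinned CPU memory):
#
#   batch = move(batch, device, non_blocking=True)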

@contextmanager
def no_offload():
    '''Use this context manager to disable offloading temporarily'''
    global enabled
    last, enabled = enabled, False
    try: yield
    finally: enabled = last

@contextmanager
def backward():
    '''
    Wrap your `loss.backward()` with this context manager
    to skip unnecessary unloads and speed up the backward pass (only useful with gradient checkpointing)
    '''
    global inside_backward_pass
    last, inside_backward_pass = inside_backward_pass, True
    try: yield
    finally: inside_backward_pass = last
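
# End-to-end sketch of how the pieces compose during training. `model`, `dataloader`,
# and `compute_loss` are placeholder names, not part of this module:
#
#   offload(*model.layers)
#   optimizer = torch.optim.AdamW(model.parameters())  # steps run on the pinned CPU copies
#   for batch in dataloader:
#       batch = move(batch, device, non_blocking=True)
#       loss = compute_loss(model, batch)
#       with backward():
#           loss.backward()
#       optimizer.step()
#       optimizer.zero_grad()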