Skip to content

Instantly share code, notes, and snippets.

@rodjjo
Last active November 29, 2025 15:37
Show Gist options
  • Select an option

  • Save rodjjo/20e2e842fea9ed58114adb560a4566b6 to your computer and use it in GitHub Desktop.

Select an option

Save rodjjo/20e2e842fea9ed58114adb560a4566b6 to your computer and use it in GitHub Desktop.
Keep a fixed number of the model's layers on the GPU
import torch
import gc
class PartialOffloadMixin:
    '''
    Mixin that keeps the leading layers of a model on the GPU and hands the
    remaining layers to accelerate's ``cpu_offload``.

    Class attributes a subclass may override:
        LAYERS_KEEP_GPU: upper bound on how many leading layers stay on the GPU.
        MODEL_ATTR_NAME: name of the attribute holding the sub-model that owns
            the layer list; an empty string means the layer list is on ``self``.
        MODEL_LAYERS_ATTR_NAME: name of the attribute holding the layer list.
        OFFLOAD_ON_CALL: when true, ``__call__`` and ``generate`` move the
            model to the GPU before running and back to the CPU afterwards.

    NOTE: list this mixin BEFORE the model class in the bases. If it comes
    after, the model's own ``__call__`` / ``generate`` (for example
    ``torch.nn.Module.__call__``) are found first in the MRO and the
    overrides defined here never run, so OFFLOAD_ON_CALL has no effect.

    usage example:
        class MyQwen3ForCausalLM(PartialOffloadMixin, Qwen3ForCausalLM):
            LAYERS_KEEP_GPU = 22
            MODEL_ATTR_NAME = "model"
            MODEL_LAYERS_ATTR_NAME = "layers"
            OFFLOAD_ON_CALL = True

        model = MyQwen3ForCausalLM.from_pretrained(
            repo_id,
            subfolder="text_encoder",
            local_files_only=True,
            torch_dtype=torch.bfloat16,
        )
        model.eval()
        model.enable_partial_cpu_offload()
        # pseudo code of inference
        result = model(...)  # __call__ is overridden: go_gpu(True) ... go_gpu(False)

    example transformer:
        class MyZImageTransformer(PartialOffloadMixin, ZImageTransformer2DModel):
            MODEL_LAYERS_ATTR_NAME = "layers"
            LAYERS_KEEP_GPU = 22

        model = MyZImageTransformer.from_pretrained(
            repo_id,
            subfolder="transformer",
            torch_dtype=torch.bfloat16,
        )
        model.eval()
        model.enable_partial_cpu_offload()
        # denoise step
        model.go_gpu(True)
        while denoising:  # pseudo code
            predicted = model(...)
        model.go_gpu(False)
        # vae decode, etc
    '''

    LAYERS_KEEP_GPU = 22
    MODEL_ATTR_NAME = ""
    MODEL_LAYERS_ATTR_NAME = "layers"
    OFFLOAD_ON_CALL = False

    def __call__(self, *args, **kwds):
        '''Run the wrapped ``__call__``, bracketing it with GPU load/unload when OFFLOAD_ON_CALL is set.'''
        return self._run_offloaded(super().__call__, args, kwds)

    def generate(self, *args, **kwds):
        '''Run the wrapped ``generate``, bracketing it with GPU load/unload when OFFLOAD_ON_CALL is set.'''
        return self._run_offloaded(super().generate, args, kwds)

    def _run_offloaded(self, fn, args, kwds):
        '''Call ``fn(*args, **kwds)``; with OFFLOAD_ON_CALL, always return the model to the CPU afterwards.'''
        if not self.OFFLOAD_ON_CALL:
            return fn(*args, **kwds)
        try:
            self.go_gpu(to_gpu=True)
            return fn(*args, **kwds)
        finally:
            # Runs even when the upload or the call raises, so a failed
            # inference does not leave the weights occupying GPU memory.
            self.go_gpu(to_gpu=False)

    def get_model(self):
        '''Return the object that owns the layer list (``self`` when MODEL_ATTR_NAME is empty).'''
        if not self.MODEL_ATTR_NAME:
            return self
        return getattr(self, self.MODEL_ATTR_NAME)

    def _gpu_layer_count(self, layer_count):
        '''
        Return how many leading layers are moved by ``offload_layers_to_device``.

        The result is 18/20 of ``layer_count`` rounded down, capped at
        LAYERS_KEEP_GPU and never negative.
        '''
        return max(0, min(layer_count * 18 // 20, self.LAYERS_KEEP_GPU))

    def all_nn_modules_to_device(self, device):
        '''
        Move every module except the layer list to ``device``.

        The layer list is temporarily swapped for an empty ``ModuleList`` so
        that ``.to(device)`` does not touch the layers, and is put back
        afterwards even if the move fails.
        '''
        model = self.get_model()
        layers_obj = getattr(model, self.MODEL_LAYERS_ATTR_NAME)
        setattr(model, self.MODEL_LAYERS_ATTR_NAME, torch.nn.ModuleList())
        try:
            if model is not self:
                self.to(device)
            model.to(device)
        finally:
            # Restore the real layers no matter what, otherwise an error
            # during .to() would leave the model with an empty layer list.
            setattr(model, self.MODEL_LAYERS_ATTR_NAME, layers_obj)

    def offload_layers_to_device(self, device):
        '''Move the leading layers (the ones not handed to ``cpu_offload``) to ``device``.'''
        model = self.get_model()
        layers_obj = getattr(model, self.MODEL_LAYERS_ATTR_NAME)
        for index in range(self._gpu_layer_count(len(layers_obj))):
            layers_obj[index].to(device)

    def go_gpu(self, to_gpu: bool):
        '''
        Move the non-layer modules and the leading layers to the GPU
        (``to_gpu=True``) or back to the CPU (``to_gpu=False``).

        After moving to the CPU the garbage collector is run and the CUDA
        cache is emptied.
        '''
        device = torch.device("cuda" if to_gpu else "cpu")
        self.all_nn_modules_to_device(device)
        self.offload_layers_to_device(device)
        if not to_gpu:
            gc.collect()
            torch.cuda.empty_cache()

    def enable_partial_cpu_offload(self):
        '''
        Register accelerate CPU offload on every layer past the leading ones
        and set ``self.sequential_offloaded = True``.
        '''
        model = self.get_model()
        layers_obj = getattr(model, self.MODEL_LAYERS_ATTR_NAME)
        layer_count = len(layers_obj)
        for index in range(self._gpu_layer_count(layer_count), layer_count):
            self._enable_sequential_cpu_offload(layers_obj[index])
        setattr(self, "sequential_offloaded", True)

    def _enable_sequential_cpu_offload(self, module):
        '''Pass ``module`` to ``accelerate.cpu_offload`` with ``cuda:0`` as the device argument.'''
        # Imported lazily so accelerate is only required when offload is enabled.
        from accelerate import cpu_offload

        device = torch.device("cuda:0")
        # Buffers are offloaded only when the module directly owns parameters.
        offload_buffers = len(module._parameters) > 0
        cpu_offload(module, device, offload_buffers=offload_buffers)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment