@noskill
Created January 13, 2025 09:53
enable_sequential_cpu_offload on multiple GPUs
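The script below creates one weight-sharing copy of a FluxPipeline per GPU, enables sequential CPU offload on each copy against a different device, and runs generation on all copies in parallel from a thread pool. Because each copy is rebuilt with empty weights and then filled via load_state_dict(..., assign=True), the copies reuse the original pipeline's tensors instead of duplicating them in RAM.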
import copy as cp
import logging
from contextlib import nullcontext  # fallback context manager when accelerate is absent
from concurrent.futures import ThreadPoolExecutor

import torch
from diffusers import FluxPipeline
from diffusers.utils import is_accelerate_available
import optimum.quanto
from optimum.quanto import freeze, qfloat8, quantize as _quantize

if is_accelerate_available():
    from accelerate import init_empty_weights
else:
    init_empty_weights = nullcontext
def quantize(pipe, dtype=qfloat8):
    """Quantize the heavy pipeline components in place and freeze their weights."""
    components = ['unet', 'transformer', 'text_encoder', 'text_encoder_2', 'vae']
    for component in components:
        if hasattr(pipe, component):
            component_obj = getattr(pipe, component)
            _quantize(component_obj, weights=dtype)
            freeze(component_obj)
            # Mark the component so weightshare_copy can replicate the quantization
            component_obj._is_quantized = True
            component_obj._quantization_dtype = dtype

IS_QUANTIZED = '_is_quantized'
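# Optional: quantize the shared weights once, before making the copies below, e.g.
#   quantize(pipe, dtype=qfloat8)
# weightshare_copy re-applies the same quantization to each copy so the
# state_dict layouts match when the weights are assigned.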
def weightshare_copy(pipe):
    """
    Create a new pipeline object, then assign the weights from the passed
    'pipe' via load_state_dict, so both pipelines share the same tensors.
    """
    copy = pipe.__class__(**pipe.components)
    ctx = init_empty_weights if is_accelerate_available() else nullcontext
    with ctx():
        for key, component in copy.components.items():
            if getattr(copy, key) is None:
                continue
            if key in ('tokenizer', 'tokenizer_2', 'feature_extractor'):
                setattr(copy, key, cp.deepcopy(getattr(copy, key)))
                continue
            cls = getattr(copy, key).__class__
            if hasattr(cls, 'from_config'):
                setattr(copy, key, cls.from_config(getattr(copy, key).config))
            else:
                setattr(copy, key, cls(getattr(copy, key).config))
            pipe_component = getattr(pipe, key)
            if getattr(pipe_component, IS_QUANTIZED, False):
                # Quantize the component in the copy using the same dtype,
                # so its state_dict layout matches the source component
                component_obj = getattr(copy, key)
                _quantize(component_obj, weights=pipe_component._quantization_dtype)
                setattr(component_obj, IS_QUANTIZED, True)
                component_obj._quantization_dtype = pipe_component._quantization_dtype
    # assign=True is needed since our copy is on the "meta" device, i.e. its weights are empty
    for key, component in copy.components.items():
        if key in ('tokenizer', 'tokenizer_2'):
            continue
        obj = getattr(copy, key)
        if hasattr(obj, 'load_state_dict'):
            obj.load_state_dict(getattr(pipe, key).state_dict(), assign=True)
    # some buffers might not be transferred from pipe to copy
    copy.to(pipe.device)
    return copy
path = "/home/imgen/models/flux-1-dev/"
pipe_params = {
    "width": 1024,
    "height": 1024,
    "guidance_scale": 3.5,
    "num_inference_steps": 50,
}
pipe = FluxPipeline.from_pretrained(path, torch_dtype=torch.bfloat16)
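# Optional sanity check (a sketch, assuming FluxPipeline exposes a `transformer`
# component): parameters of a weight-shared copy should alias the original's
# storage, so extra copies add no weight memory. Run this before
# enable_sequential_cpu_offload, which re-homes the weights via accelerate hooks.
_check = weightshare_copy(pipe)
_src = dict(pipe.transformer.named_parameters())
_dst = dict(_check.transformer.named_parameters())
assert all(_src[n].data_ptr() == _dst[n].data_ptr() for n in _src)
del _check, _src, _dst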
# One weight-sharing copy per GPU; each copy sequentially offloads to its own device
pipe1 = weightshare_copy(pipe)
pipe1.enable_sequential_cpu_offload(1)
pipe2 = weightshare_copy(pipe)
pipe2.enable_sequential_cpu_offload(2)
pipe3 = weightshare_copy(pipe)
pipe3.enable_sequential_cpu_offload(3)

executor = ThreadPoolExecutor(max_workers=3)
pipes = [pipe1, pipe2, pipe3]

def gen(pipe, prompt, **kwargs):
    return pipe(prompt=prompt, **kwargs).images

futures = []
for i in range(3):
    futures.append(executor.submit(gen, pipes[i], prompt=f"{i} sheeps jumping over a fence", **pipe_params))

for i, fut in enumerate(futures):
    for k, img in enumerate(fut.result()):
        img.save(f"img_{i}_{k}.png")
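A note on the offload targets: in recent diffusers releases the first positional argument of enable_sequential_cpu_offload is gpu_id, so the three copies above execute against cuda:1, cuda:2 and cuda:3, leaving cuda:0 free. If your diffusers version differs, check the signature of enable_sequential_cpu_offload before relying on this.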