enable_sequential_cpu_offload on multiple GPUs

Run one set of FluxPipeline weights across several GPUs: create weight-sharing copies of the pipeline, enable sequential CPU offload on a different GPU for each copy, then generate in parallel from a thread pool.
import copy as cp
from concurrent.futures import ThreadPoolExecutor
from contextlib import nullcontext

import torch
from diffusers import FluxPipeline
from diffusers.utils import is_accelerate_available
from optimum.quanto import freeze, qfloat8, quantize as _quantize

if is_accelerate_available():
    from accelerate import init_empty_weights
else:
    init_empty_weights = nullcontext
def quantize(pipe, dtype=qfloat8):
    """Quantize and freeze every model component of the pipeline in place."""
    components = ['unet', 'transformer', 'text_encoder', 'text_encoder_2', 'vae']
    for component in components:
        if hasattr(pipe, component):
            component_obj = getattr(pipe, component)
            _quantize(component_obj, weights=dtype)
            freeze(component_obj)
            # Tag the component so weightshare_copy can re-quantize its copies
            component_obj._is_quantized = True
            component_obj._quantization_dtype = dtype


IS_QUANTIZED = '_is_quantized'
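Note that quantize is defined but never called in the script below. If VRAM is tight, one option (a sketch, not exercised in this gist; variable names are illustrative) is to quantize the base pipeline once before making the copies; weightshare_copy then re-quantizes each copy with the same dtype via the _is_quantized tag:

quantized_pipe = FluxPipeline.from_pretrained(path, torch_dtype=torch.bfloat16)  # 'path' as defined below
quantize(quantized_pipe)                           # qfloat8 weights, frozen in place
quantized_copy = weightshare_copy(quantized_pipe)  # copy is re-quantized to match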
def weightshare_copy(pipe):
    """
    Create a new pipeline object of the same class, then share the weights of
    the passed 'pipe' with it via load_state_dict(..., assign=True).
    """
    copy = pipe.__class__(**pipe.components)
    ctx = init_empty_weights if is_accelerate_available() else nullcontext
    with ctx():
        for key, component in copy.components.items():
            if getattr(copy, key) is None:
                continue
            if key in ('tokenizer', 'tokenizer_2', 'feature_extractor'):
                setattr(copy, key, cp.deepcopy(getattr(copy, key)))
                continue
            cls = getattr(copy, key).__class__
            if hasattr(cls, 'from_config'):
                setattr(copy, key, cls.from_config(getattr(copy, key).config))
            else:
                setattr(copy, key, cls(getattr(copy, key).config))
            pipe_component = getattr(pipe, key)
            if getattr(pipe_component, IS_QUANTIZED, False):
                # Quantize the component in the copy using the same dtype
                component_obj = getattr(copy, key)
                _quantize(component_obj, weights=pipe_component._quantization_dtype)
                setattr(component_obj, IS_QUANTIZED, True)
                component_obj._quantization_dtype = pipe_component._quantization_dtype
    # assign=True is needed since our copy is on the "meta" device, i.e. its weights are empty
    for key, component in copy.components.items():
        if key in ('tokenizer', 'tokenizer_2'):
            continue
        obj = getattr(copy, key)
        if hasattr(obj, 'load_state_dict'):
            obj.load_state_dict(getattr(pipe, key).state_dict(), assign=True)
    # some buffers might not be transferred from pipe to copy
    copy.to(pipe.device)
    return copy
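To confirm that a copy really shares storage with the original (rather than doubling memory), one can compare tensor data pointers; a sketch for the unquantized case, run before offloading is enabled:

check = weightshare_copy(pipe)
src = pipe.transformer.state_dict()
dst = check.transformer.state_dict()
# Identical data pointers mean both pipelines reference the same underlying storage.
assert all(src[k].data_ptr() == dst[k].data_ptr() for k in src)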
| path = "/home/imgen/models/flux-1-dev/" | |
| pipe_params = { | |
| "width": 1024, | |
| "height": 1024, | |
| "guidance_scale": 3.5, | |
| "num_inference_steps": 50 | |
| } | |
pipe = FluxPipeline.from_pretrained(path, torch_dtype=torch.bfloat16)

# One weight-sharing copy per GPU; the positional argument of
# enable_sequential_cpu_offload is gpu_id, so each copy offloads to its own device.
pipe1 = weightshare_copy(pipe)
pipe1.enable_sequential_cpu_offload(1)
pipe2 = weightshare_copy(pipe)
pipe2.enable_sequential_cpu_offload(2)
pipe3 = weightshare_copy(pipe)
pipe3.enable_sequential_cpu_offload(3)

executor = ThreadPoolExecutor(max_workers=3)
pipes = [pipe1, pipe2, pipe3]


def gen(pipe, prompt, **kwargs):
    return pipe(prompt=prompt, **kwargs).images
futures = []
for i in range(3):
    futures.append(executor.submit(gen, pipes[i], prompt=f"{i} sheep jumping over a fence", **pipe_params))

for i, fut in enumerate(futures):
    for k, img in enumerate(fut.result()):
        img.save(f"img_{i}_{k}.png")