import torch
from diffusers import FluxPipeline, AutoencoderKL
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from optimum.quanto import freeze, qfloat8, qint4, quantize
from transformers import T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel

dtype = torch.bfloat16
flux_dir = "/zdata/models/img/flux"

# Load the individual pipeline components in bfloat16.
print('loading encoders/transformer/vae...')
text_encoder = CLIPTextModel.from_pretrained(flux_dir, subfolder="text_encoder", torch_dtype=dtype)
tokenizer = CLIPTokenizer.from_pretrained(flux_dir, subfolder="tokenizer")
text_encoder_2 = T5EncoderModel.from_pretrained(flux_dir, subfolder="text_encoder_2", torch_dtype=dtype)
tokenizer_2 = T5TokenizerFast.from_pretrained(flux_dir, subfolder="tokenizer_2")
vae = AutoencoderKL.from_pretrained(flux_dir, subfolder="vae", torch_dtype=dtype)
transformer = FluxTransformer2DModel.from_pretrained(
    flux_dir,
    subfolder="transformer",
    torch_dtype=dtype,
)

# Assemble the pipeline from the pre-loaded components.
pipeline = FluxPipeline.from_pretrained(
    flux_dir,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    text_encoder_2=text_encoder_2,
    tokenizer_2=tokenizer_2,
    transformer=transformer,
    vae=vae,
    torch_dtype=dtype,
)

# Quantize the weights in place, then freeze each module so the
# quantized weights are kept (and original weights are discarded).
print('quantizing text_encoder...')
quantize(pipeline.text_encoder, weights=qfloat8)
freeze(pipeline.text_encoder)

print('quantizing text_encoder_2...')
quantize(pipeline.text_encoder_2, weights=qint4)
freeze(pipeline.text_encoder_2)

print('quantizing transformer...')
quantize(pipeline.transformer, weights=qint4)
freeze(pipeline.transformer)

print('moving encoders/transformer/vae to GPU 0')
pipeline.to("cuda:0")
print('pipeline loaded.')


def generate(prompt, seed, height, width):
    print(f'generating seed {seed}...')
    image = pipeline(
        prompt=prompt,
        prompt_2=None,
        max_sequence_length=256,
        num_inference_steps=4,
        guidance_scale=0.0,
        height=height,
        width=width,
        generator=torch.Generator("cuda").manual_seed(seed),
    ).images[0]
    image.save(f'test-{seed}.png')


prompt1 = 'TODO: your prompt here'
for i in range(1000, 1020):
    generate(prompt1, i, 768, 1024)
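
If the quantized pipeline still does not fit on a single GPU, one possible variation (not part of the original gist) is to replace the `pipeline.to("cuda:0")` call with diffusers' model CPU offload, which keeps each sub-model on the CPU and moves it to the GPU only while it is being used. This is a minimal sketch under that assumption; it requires `accelerate`, trades some speed for lower peak VRAM, and may need testing against the quanto-frozen modules:

# Alternative placement (sketch): instead of pipeline.to("cuda:0"),
# let diffusers shuttle each sub-model to the GPU on demand.
# Requires the `accelerate` package to be installed.
pipeline.enable_model_cpu_offload()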