@dknr
Created August 6, 2024 08:30

import torch
from diffusers import FluxPipeline, AutoencoderKL
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from optimum.quanto import freeze, qfloat8, qint4, quantize
from transformers import T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel

dtype = torch.bfloat16
flux_dir = "/zdata/models/img/flux"

# Load each pipeline component individually (on CPU) so it can be quantized
# before anything is moved to the GPU.
text_encoder = CLIPTextModel.from_pretrained(flux_dir, subfolder="text_encoder", torch_dtype=dtype)
tokenizer = CLIPTokenizer.from_pretrained(flux_dir, subfolder="tokenizer")
text_encoder_2 = T5EncoderModel.from_pretrained(flux_dir, subfolder="text_encoder_2", torch_dtype=dtype)
tokenizer_2 = T5TokenizerFast.from_pretrained(flux_dir, subfolder="tokenizer_2")
vae = AutoencoderKL.from_pretrained(flux_dir, subfolder="vae", torch_dtype=dtype)
transformer = FluxTransformer2DModel.from_pretrained(
    flux_dir,
    subfolder="transformer",
    torch_dtype=dtype,
)

print('assembling pipeline...')
pipeline = FluxPipeline.from_pretrained(
    flux_dir,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    text_encoder_2=text_encoder_2,
    tokenizer_2=tokenizer_2,
    transformer=transformer,
    vae=vae,
    torch_dtype=dtype,
)
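
# Quantize in place before moving anything to the GPU: the full bf16 weights
# only ever live in system RAM, and VRAM sees just the quantized copies.
# The small CLIP encoder keeps more precision (8-bit float weights), while
# the large T5 encoder and the Flux transformer are squeezed to 4-bit ints.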
print('quantizing text_encoder...')
quantize(pipeline.text_encoder, weights=qfloat8)
freeze(pipeline.text_encoder)

print('quantizing text_encoder_2...')
quantize(pipeline.text_encoder_2, weights=qint4)
freeze(pipeline.text_encoder_2)

print('quantizing transformer...')
quantize(pipeline.transformer, weights=qint4)
freeze(pipeline.transformer)

print('moving encoders/transformer/vae to GPU 0...')
pipeline.to("cuda:0")
print('pipeline ready.')
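
# Optional sanity check (not part of the original gist): report how much CUDA
# memory the quantized pipeline occupies after the move. memory_allocated()
# counts live tensor allocations only, so treat it as a lower bound on VRAM use.
print(f'CUDA memory allocated: ~{torch.cuda.memory_allocated(0) / 1e9:.2f} GB')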

def generate(prompt, seed, height, width):
    print('encoding and denoising...')
    # 4 steps with guidance_scale=0.0 matches the FLUX.1-schnell recipe
    # (the timestep-distilled variant runs without classifier-free guidance);
    # schnell also caps the T5 prompt at max_sequence_length=256.
    image = pipeline(
        prompt=prompt, prompt_2=None,
        max_sequence_length=256,
        num_inference_steps=4,
        guidance_scale=0.0,
        height=height,
        width=width,
        generator=torch.Generator("cuda").manual_seed(seed),
    ).images[0]
    image.save(f'test-{seed}.png')

prompt1 = 'TODO: your prompt here'

# Render 20 seeds at 1024x768 (width x height).
for i in range(1000, 1020):
    generate(prompt1, i, 768, 1024)
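
# Variant (an assumption, not from the original gist): if the quantized
# pipeline still overflows a single GPU, diffusers' per-module CPU offload
# can replace the explicit pipeline.to("cuda:0") above, trading throughput
# for VRAM:
#
#     pipeline.enable_model_cpu_offload()
#
# Whether the offload hooks compose cleanly with quanto-frozen modules may
# depend on the installed diffusers/optimum-quanto versions.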