@calebcase
Last active October 23, 2025 14:18
qwen plus: compile(fullgraph=True, mode='max-autotune')

Demonstrates a failure when compiling the Qwen Image Edit Plus transformer with torch.compile(fullgraph=True, mode='max-autotune').

accelerate==1.11.0
anyio==4.11.0
certifi==2025.10.5
charset-normalizer==3.4.4
diffusers @ git+https://github.com/huggingface/diffusers@9c3b58dcf16ebd027fd3d85ec703ad5a142b1e1e
filelock==3.20.0
fsspec==2025.9.0
h11==0.16.0
hf-xet==1.1.10
httpcore==1.0.9
httpx==0.28.1
huggingface-hub==0.36.0
idna==3.11
importlib_metadata==8.7.0
Jinja2==3.1.6
MarkupSafe==3.0.3
mpmath==1.3.0
networkx==3.5
numpy==2.3.4
nvidia-cublas-cu12==12.8.4.1
nvidia-cuda-cupti-cu12==12.8.90
nvidia-cuda-nvrtc-cu12==12.8.93
nvidia-cuda-runtime-cu12==12.8.90
nvidia-cudnn-cu12==9.10.2.21
nvidia-cufft-cu12==11.3.3.83
nvidia-cufile-cu12==1.13.1.3
nvidia-curand-cu12==10.3.9.90
nvidia-cusolver-cu12==11.7.3.90
nvidia-cusparse-cu12==12.5.8.93
nvidia-cusparselt-cu12==0.7.1
nvidia-nccl-cu12==2.27.5
nvidia-nvjitlink-cu12==12.8.93
nvidia-nvshmem-cu12==3.3.20
nvidia-nvtx-cu12==12.8.90
packaging==25.0
pillow==12.0.0
psutil==7.1.1
PyYAML==6.0.3
regex==2025.10.23
requests==2.32.5
safetensors==0.6.2
setuptools==80.9.0
sniffio==1.3.1
sympy==1.14.0
tokenizers==0.22.1
torch==2.9.0
torchvision==0.24.0
tqdm==4.67.1
transformers==4.57.1
triton==3.5.0
typing_extensions==4.15.0
urllib3==2.5.0
wheel==0.45.1
zipp==3.23.0
Loading pipeline components...: 17%|████████████ | 1/6 [00:00<00:01, 4.94it/s]`torch_dtype` is deprecated! Use `dtype` instead!
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 54.81it/s]
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 56.23it/s]
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████| 6/6 [00:01<00:00, 4.23it/s]
pipeline loaded
guidance_scale is passed as 1.0, but ignored since the model is not guidance-distilled.
0%| | 0/40 [00:00<?, ?it/s]/root/qwen/venv/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py:1692: UserWarning: Dynamo detected a call to a `functools.lru_cache`-wrapped function. Dynamo ignores the cache wrapper and directly traces the wrapped function. Silent incorrectness is only a *potential* risk, not something we have observed. Enable TORCH_LOGS="+dynamo" for a DEBUG stack trace.
torch._dynamo.utils.warn_once(msg)
/root/qwen/venv/lib/python3.12/site-packages/torch/backends/cuda/__init__.py:131: UserWarning: Please use the new API settings to control TF32 behavior, such as torch.backends.cudnn.conv.fp32_precision = 'tf32' or torch.backends.cuda.matmul.fp32_precision = 'ieee'. Old settings, e.g, torch.backends.cuda.matmul.allow_tf32 = True, torch.backends.cudnn.allow_tf32 = True, allowTF32CuDNN() and allowTF32CuBLAS() will be deprecated after Pytorch 2.9. Please see https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices (Triggered internally at /pytorch/aten/src/ATen/Context.cpp:80.)
return torch._C._get_cublas_allow_tf32()
/root/qwen/venv/lib/python3.12/site-packages/torch/_inductor/lowering.py:1988: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager.
warnings.warn(
Autotune Choices Stats:
{"num_choices": 20, "num_triton_choices": 19, "best_kernel": "mm", "best_time": 0.554207980632782, "best_triton_pos": 1, "best_triton_time": 0.8834879994392395, "best_triton_kernel": "triton_mm_207", "best_triton_kernel_desc": "ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4"}
AUTOTUNE mm(12288x3072, 3072x12288)
<--- snip --->
SingleProcess AUTOTUNE benchmarking takes 0.2783 seconds and 0.0005 seconds precompiling for 20 choices
cudagraph partition due to non gpu ops. Found from :
  File "/root/qwen/venv/lib/python3.12/site-packages/diffusers/models/transformers/transformer_qwenimage.py", line 633, in forward
    image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
  File "/root/qwen/venv/lib/python3.12/site-packages/diffusers/models/transformers/transformer_qwenimage.py", line 213, in forward
    self.pos_freqs = self.pos_freqs.to(device)
cudagraph partition due to non gpu ops. Found from :
  File "/root/qwen/venv/lib/python3.12/site-packages/diffusers/models/transformers/transformer_qwenimage.py", line 633, in forward
    image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
  File "/root/qwen/venv/lib/python3.12/site-packages/diffusers/models/transformers/transformer_qwenimage.py", line 214, in forward
    self.neg_freqs = self.neg_freqs.to(device)
cudagraph partition into 2 partitions
Autotune Choices Stats:
{"num_choices": 20, "num_triton_choices": 19, "best_kernel": "mm", "best_time": 0.030559999868273735, "best_triton_pos": 1, "best_triton_time": 0.05939200147986412, "best_triton_kernel": "triton_mm_13925", "best_triton_kernel_desc": "ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=8"}
AUTOTUNE mm(412x3072, 3072x12288)
<--- snip --->
SingleProcess AUTOTUNE benchmarking takes 0.5303 seconds and 0.0005 seconds precompiling for 20 choices
0%| | 0/40 [06:34<?, ?it/s]
Traceback (most recent call last):
  File "/root/qwen/main.py", line 31, in <module>
    output = pipeline(**inputs)
             ^^^^^^^^^^^^^^^^^^
  File "/root/qwen/venv/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/root/qwen/venv/lib/python3.12/site-packages/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py", line 828, in __call__
    comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
                ~~~~~~~~~~~^~~~~~~~~~~~~~~~
RuntimeError: Error: accessing tensor output of CUDAGraphs that has been overwritten by a subsequent run. Stack trace: File "/root/qwen/venv/lib/python3.12/site-packages/diffusers/models/transformers/transformer_qwenimage.py", line 664, in forward
    output = self.proj_out(hidden_states). To prevent overwriting, clone the tensor outside of torch.compile() or call torch.compiler.cudagraph_mark_step_begin() before each model invocation.
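
The RuntimeError itself names two mitigations: clone the tensor outside of torch.compile(), or call torch.compiler.cudagraph_mark_step_begin() before each model invocation. With true_cfg_scale set, the pipeline runs the transformer twice per step (positive and negative prediction), and the second CUDA graph replay overwrites the static output buffer the first call returned. The sketch below is an untested combination of both mitigations; the SteppedTransformer wrapper, its attribute fallthrough, and the tuple-output assumption are mine, not diffusers API. It would replace the pipeline.transformer.compile(...) line in the script below, and whether fullgraph tracing still succeeds through the wrapper is unverified.

import torch
from torch import nn

class SteppedTransformer(nn.Module):
    # Untested sketch: wrap the compiled transformer so each invocation
    # starts a fresh CUDA graph step and returns cloned outputs, so the
    # negative-CFG pass cannot overwrite the positive pass's tensors.
    def __init__(self, transformer):
        super().__init__()
        self.inner = torch.compile(transformer, fullgraph=True, mode='max-autotune')

    def __getattr__(self, name):
        # Fall through to the wrapped transformer for config, dtype, etc.
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.inner, name)

    def forward(self, *args, **kwargs):
        torch.compiler.cudagraph_mark_step_begin()  # runs outside the compiled region
        out = self.inner(*args, **kwargs)
        # The pipeline calls the transformer with return_dict=False, so a
        # tuple of tensors is assumed here; clone each so it survives the
        # next graph replay.
        return tuple(t.clone() if isinstance(t, torch.Tensor) else t for t in out)

pipeline.transformer = SteppedTransformer(pipeline.transformer)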
from PIL import Image
from diffusers import QwenImageEditPlusPipeline
from io import BytesIO
import datetime
import os
import requests
import torch

pipeline = QwenImageEditPlusPipeline.from_pretrained("Qwen/Qwen-Image-Edit-2509", torch_dtype=torch.bfloat16)
print("pipeline loaded")
pipeline.to('cuda')
pipeline.set_progress_bar_config(disable=None)
pipeline.transformer.compile(fullgraph=True, mode='max-autotune')

image1 = Image.open(BytesIO(requests.get("https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-Image/edit2509/edit2509_1.jpg").content))
image2 = Image.open(BytesIO(requests.get("https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-Image/edit2509/edit2509_2.jpg").content))

prompt = "The magician bear is on the left, the alchemist bear is on the right, facing each other in the central park square."
inputs = {
    "image": [image1, image2],
    "prompt": prompt,
    "generator": torch.manual_seed(0),
    "true_cfg_scale": 4.0,
    "negative_prompt": " ",
    "num_inference_steps": 40,
    "guidance_scale": 1.0,
    "num_images_per_prompt": 1,
}

# First call triggers compilation (and, here, the CUDAGraphs failure).
with torch.inference_mode():
    output = pipeline(**inputs)

output_image = output.images[0]
output_image.save("output_image_edit_plus.png")
print("image saved at", os.path.abspath("output_image_edit_plus.png"))

# Second call times the already-compiled pipeline.
with torch.inference_mode():
    print(f"start: {datetime.datetime.now()}")
    output = pipeline(**inputs)
    print(f"stop : {datetime.datetime.now()}")

output_image = output.images[0]
output_image.save("output_image_edit_plus.png")
print("image saved at", os.path.abspath("output_image_edit_plus.png"))
accelerate
diffusers @ git+https://github.com/huggingface/diffusers
pillow
torch
torchvision
transformers
wheel