ComfyUI-FlashVSR_Stable with video streaming via disk I/O. Can process long videos with lower system memory usage.
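For example, a long video can be streamed through in 50-frame chunks so that only one chunk of frames is resident in memory at a time (a minimal invocation sketch, assuming the script below is saved as cli_main.py, as in its own epilog examples):

    python cli_main.py --input long_video.mp4 --output upscaled.mp4 --frame_chunk_size 50 --mode tiny-long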
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import argparse
import os
import sys
import gc
import torch
import numpy as np
import cv2
import time


# =============================================================================
# CLI argument parsing - EXHAUSTIVE mapping from ComfyUI node INPUT_TYPES
# =============================================================================
def parse_args():
    """
    Parse command-line arguments.

    Every argument corresponds directly to a parameter in the ComfyUI node
    INPUT_TYPES (FlashVSRNode, FlashVSRNodeAdv, FlashVSRNodeInitPipe).
    """
    parser = argparse.ArgumentParser(
        description="FlashVSR CLI - Video Super Resolution",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Basic 2x upscale with defaults
  python cli_main.py --input video.mp4 --output upscaled.mp4 --scale 2

  # 4x upscale with tiling enabled for lower VRAM
  python cli_main.py --input video.mp4 --output upscaled.mp4 --scale 4 \\
      --tiled_vae --tiled_dit --tile_size 256 --tile_overlap 24

  # Long video with chunking to prevent OOM
  python cli_main.py --input long_video.mp4 --output upscaled.mp4 \\
      --frame_chunk_size 50 --mode tiny-long

  # Low VRAM mode (8GB GPUs)
  python cli_main.py --input video.mp4 --output upscaled.mp4 --scale 2 \\
      --vae_model LightVAE_W2.1 --tiled_vae --tiled_dit \\
      --frame_chunk_size 20 --resize_factor 0.5

For more information, visit: https://github.com/naxci1/ComfyUI-FlashVSR_Stable
"""
    )

    # ==========================================================================
    # Required arguments
    # ==========================================================================
    parser.add_argument(
        '--input', '-i',
        type=str,
        required=True,
        help='Input video file path (e.g., video.mp4)'
    )
    parser.add_argument(
        '--output', '-o',
        type=str,
        required=True,
        help='Output video file path (e.g., upscaled.mp4)'
    )

    # ==========================================================================
    # FlashVSRNodeInitPipe parameters (Pipeline Initialization)
    # ==========================================================================
    parser.add_argument(
        '--model',
        type=str,
        choices=['FlashVSR', 'FlashVSR-v1.1'],
        default='FlashVSR-v1.1',
        help='FlashVSR model version. V1.1 is recommended for better stability. (default: FlashVSR-v1.1)'
    )
    parser.add_argument(
        '--mode',
        type=str,
        choices=['tiny', 'tiny-long', 'full'],
        default='tiny',
        help='Operation mode. "tiny": faster, standard memory. "tiny-long": optimized for long videos (lower VRAM). "full": higher quality but max VRAM. (default: tiny)'
    )
    parser.add_argument(
        '--vae_model',
        type=str,
        choices=['Wan2.1', 'Wan2.2', 'LightVAE_W2.1', 'TAE_W2.2', 'LightTAE_HY1.5'],
        default='Wan2.1',
        help='VAE model: Wan2.1 (default), Wan2.2, LightVAE_W2.1 (50%% less VRAM), TAE_W2.2, LightTAE_HY1.5. Auto-downloads if missing. (default: Wan2.1)'
    )
    parser.add_argument(
        '--force_offload',
        action='store_true',
        default=True,
        help='Force offloading of models to CPU RAM after execution to free up VRAM. (default: True)'
    )
    parser.add_argument(
        '--no_force_offload',
        action='store_true',
        help='Disable force offloading (keeps models in VRAM).'
    )
    parser.add_argument(
        '--precision',
        type=str,
        choices=['fp16', 'bf16', 'auto'],
        default='auto',
        help="Inference precision. 'auto' selects bf16 if supported (RTX 30/40/50 series), otherwise fp16. (default: auto)"
    )
    parser.add_argument(
        '--device',
        type=str,
        default='auto',
        help='Computation device (e.g., "cuda:0", "cuda:1", "cpu", "auto"). (default: auto)'
    )
    parser.add_argument(
        '--attention_mode',
        type=str,
        choices=['sparse_sage_attention', 'block_sparse_attention', 'flash_attention_2', 'sdpa'],
        default='sparse_sage_attention',
        help='Attention mechanism backend. "sparse_sage_attention"/"block_sparse_attention" use efficient sparse attention. "flash_attention_2"/"sdpa" use dense attention. (default: sparse_sage_attention)'
    )

    # ==========================================================================
    # FlashVSRNodeAdv parameters (Processing)
    # ==========================================================================
    parser.add_argument(
        '--scale',
        type=int,
        choices=[2, 4],
        default=2,
        help='Upscaling factor. 2x or 4x. Higher scale requires more VRAM and compute. (default: 2)'
    )
    parser.add_argument(
        '--color_fix',
        action='store_true',
        default=True,
        help='Apply wavelet-based color correction to match output colors with input. (default: True)'
    )
    parser.add_argument(
        '--no_color_fix',
        action='store_true',
        help='Disable color correction.'
    )
    parser.add_argument(
        '--tiled_vae',
        action='store_true',
        default=False,
        help='Enable spatial tiling for the VAE decoder. Reduces VRAM usage significantly but is slower.'
    )
    parser.add_argument(
        '--tiled_dit',
        action='store_true',
        default=False,
        help='Enable spatial tiling for the Diffusion Transformer (DiT). Crucial for saving VRAM on large inputs.'
    )
    parser.add_argument(
        '--tile_size',
        type=int,
        default=256,
        help='Size of the tiles for DiT processing (32-1024). Smaller = less VRAM, more tiles, slower. (default: 256)'
    )
    parser.add_argument(
        '--tile_overlap',
        type=int,
        default=24,
        help='Overlap pixels between tiles to blend seams (8-512). Higher = smoother transitions. (default: 24)'
    )
    parser.add_argument(
        '--unload_dit',
        action='store_true',
        default=False,
        help='Unload the DiT model from VRAM before VAE decoding starts. Use if VAE decode runs out of memory.'
    )
    parser.add_argument(
        '--sparse_ratio',
        type=float,
        default=2.0,
        help='Control for sparse attention (1.5-2.0). 1.5 is faster, 2.0 is more stable/higher quality. (default: 2.0)'
    )
    parser.add_argument(
        '--kv_ratio',
        type=float,
        default=3.0,
        help='Key/Value cache ratio (1.0-3.0). 1.0 uses less VRAM; 3.0 provides the highest quality retention. (default: 3.0)'
    )
    parser.add_argument(
        '--local_range',
        type=int,
        choices=[9, 11],
        default=11,
        help='Local attention range window. 9 = sharper details; 11 = more stable/consistent results. (default: 11)'
    )
    parser.add_argument(
        '--seed',
        type=int,
        default=0,
        help='Random seed for noise generation. Same seed + same settings = reproducible results. (default: 0)'
    )
    parser.add_argument(
        '--frame_chunk_size',
        type=int,
        default=0,
        help='Process video in chunks of N frames to prevent VRAM OOM. 0 = process all frames at once. (default: 0)'
    )
    parser.add_argument(
        '--enable_debug',
        action='store_true',
        default=False,
        help='Enable verbose logging to console. Shows VRAM usage, step times, tile info, and detailed progress.'
    )
    parser.add_argument(
        '--keep_models_on_cpu',
        action='store_true',
        default=True,
        help='Move models to CPU RAM instead of keeping them in VRAM when not in use. (default: True)'
    )
    parser.add_argument(
        '--no_keep_models_on_cpu',
        action='store_true',
        help='Keep models in VRAM (faster but uses more VRAM).'
    )
    parser.add_argument(
        '--resize_factor',
        type=float,
        default=1.0,
        help='Resize input frames before processing (0.1-1.0). Set to 0.5 for large 1080p+ videos. (default: 1.0)'
    )

    # ==========================================================================
    # Video I/O parameters
    # ==========================================================================
    parser.add_argument(
        '--fps',
        type=float,
        default=None,
        help='Output video FPS. If not specified, uses the input video FPS.'
    )
    parser.add_argument(
        '--codec',
        type=str,
        default='libx264',
        help='Video codec for output (e.g., libx264, libx265, h264_nvenc). (default: libx264)'
    )
    parser.add_argument(
        '--crf',
        type=int,
        default=18,
        help='Constant Rate Factor for quality (0-51, lower = better quality). (default: 18)'
    )
    parser.add_argument(
        '--start_frame',
        type=int,
        default=0,
        help='Start processing from this frame index (0-indexed). (default: 0)'
    )
    parser.add_argument(
        '--end_frame',
        type=int,
        default=-1,
        help='Stop processing at this frame index (-1 = process all). (default: -1)'
    )

    # ==========================================================================
    # Model paths (optional, for custom model locations)
    # ==========================================================================
    parser.add_argument(
        '--models_dir',
        type=str,
        default=None,
        help='Custom path to the FlashVSR models directory. If not set, uses the ComfyUI default or ./models'
    )

    return parser.parse_args()
def get_video_properties(video_path):
    """Return (fps, total_frame_count) for the given video file."""
    cap = cv2.VideoCapture(video_path)
    fps = cap.get(cv2.CAP_PROP_FPS)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    cap.release()
    return fps, total_frames


def load_video_chunk(cap, start_idx, chunk_size, end_limit):
    """Load the frames of the specified chunk range.

    Returns a float32 tensor of shape (N, H, W, C) with values in [0, 1],
    or None if no frames could be read.
    """
    frames = []
    actual_chunk_size = min(chunk_size, end_limit - start_idx)
    for _ in range(actual_chunk_size):
        ret, frame = cap.read()
        if not ret:
            break
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frame_normalized = frame_rgb.astype(np.float32) / 255.0
        frames.append(frame_normalized)
    if not frames:
        return None
    return torch.from_numpy(np.stack(frames, axis=0))
def main():
    args = parse_args()

    # --- Check the output path and refuse to overwrite an existing file ---
    output_path = args.output
    if os.path.exists(output_path):
        print(f"Error: Output file '{output_path}' already exists.")
        print("Aborting to prevent overwriting.")
        sys.exit(1)

    force_offload = args.force_offload and not args.no_force_offload
    color_fix = args.color_fix and not args.no_color_fix
    keep_models_on_cpu = args.keep_models_on_cpu and not args.no_keep_models_on_cpu

    print("=" * 60)
    print("FlashVSR CLI - Video Super Resolution")
    print("=" * 60)
    print(f"Input:  {args.input}")
    print(f"Output: {args.output}")
    print(f"Model:  {args.model}, Mode: {args.mode}")
    print(f"VAE:    {args.vae_model}, Scale: {args.scale}x")
    print("=" * 60)

    # ==========================================================================
    # Setup environment and imports
    # ==========================================================================
    # Mock ComfyUI modules for standalone CLI operation
    from unittest.mock import MagicMock

    # Create mock folder_paths module
    folder_paths_mock = MagicMock()
    if args.models_dir:
        folder_paths_mock.models_dir = args.models_dir
    else:
        # Default to ./models or the ComfyUI default
        folder_paths_mock.models_dir = os.path.join(os.path.dirname(__file__), "models")
    folder_paths_mock.get_filename_list = MagicMock(return_value=[])
    sys.modules['folder_paths'] = folder_paths_mock

    # Create mock comfy modules
    comfy_mock = MagicMock()
    comfy_utils_mock = MagicMock()
    comfy_utils_mock.ProgressBar = MagicMock()
    sys.modules['comfy'] = comfy_mock
    sys.modules['comfy.utils'] = comfy_utils_mock

    from nodes import init_pipeline, flashvsr
    from src.models import wan_video_dit

    # Device and precision selection: honor explicit fp16/bf16 choices;
    # "auto" prefers bf16 when the GPU supports it, otherwise fp16
    device = args.device
    if device == "auto":
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
    if args.precision == "bf16":
        dtype = torch.bfloat16
    elif args.precision == "fp16":
        dtype = torch.float16
    else:
        dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16

    # Pipeline initialization
    wan_video_dit.ATTENTION_MODE = args.attention_mode
    pipe = init_pipeline(model=args.model, mode=args.mode, device=device, dtype=dtype, vae_model=args.vae_model)

    # Read video metadata
    input_fps, total_frames = get_video_properties(args.input)
    output_fps = args.fps if args.fps is not None else input_fps
    start_f = args.start_frame
    end_f = args.end_frame if args.end_frame != -1 else total_frames

    # When the chunk size is 0, process all frames at once
    chunk_size = args.frame_chunk_size if args.frame_chunk_size > 0 else (end_f - start_f)

    cap = cv2.VideoCapture(args.input)
    cap.set(cv2.CAP_PROP_POS_FRAMES, start_f)
    video_writer = None
    current_idx = start_f

    # Record the processing start time
    start_time = time.time()
    total_to_process = end_f - start_f
    processed_count = 0

    print(f"\nProcessing video: {args.input}")
    print(f"Total frames to process: {total_to_process} (Chunk size: {chunk_size})")
    try:
        while current_idx < end_f:
            # 1. Load
            frames_chunk = load_video_chunk(cap, current_idx, chunk_size, end_f)
            if frames_chunk is None:
                break
            num_frames_in_chunk = len(frames_chunk)
            # print(f"\n--- Processing frames {current_idx} to {current_idx + len(frames_chunk) - 1} / {end_f} ---")

            # 2. Inference (FlashVSR)
            output_chunk = flashvsr(
                pipe=pipe,
                frames=frames_chunk,
                scale=args.scale,
                color_fix=color_fix,
                tiled_vae=args.tiled_vae,
                tiled_dit=args.tiled_dit,
                tile_size=args.tile_size,
                tile_overlap=args.tile_overlap,
                unload_dit=args.unload_dit,
                sparse_ratio=args.sparse_ratio,
                kv_ratio=args.kv_ratio,
                local_range=args.local_range,
                seed=args.seed,
                force_offload=keep_models_on_cpu,
                enable_debug=args.enable_debug,
                chunk_size=0,  # Disable FlashVSR's internal chunking (chunking is controlled here, outside)
                resize_factor=args.resize_factor,
                mode=args.mode
            )

            # 3. Prepare the writer (first chunk only)
            if video_writer is None:
                h, w = output_chunk.shape[1:3]
                os.makedirs(os.path.dirname(os.path.abspath(args.output)), exist_ok=True)
                fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # Use mp4v for broad compatibility
                video_writer = cv2.VideoWriter(args.output, fourcc, output_fps, (w, h))

            # 4. Write out
            out_np = (torch.clamp(output_chunk, 0, 1).cpu().numpy() * 255).astype(np.uint8)
            for i in range(len(out_np)):
                video_writer.write(cv2.cvtColor(out_np[i], cv2.COLOR_RGB2BGR))

            # Update and report progress
            processed_count += num_frames_in_chunk
            current_idx += num_frames_in_chunk
            elapsed_now = time.time() - start_time
            percent = (processed_count / total_to_process) * 100
            current_fps = processed_count / elapsed_now
            print(f"Progress: {percent:6.2f}% | Processed: {processed_count}/{total_to_process} | "
                  f"Elapsed: {elapsed_now:6.1f}s | Speed: {current_fps:4.2f} fps")

            # Free memory
            del frames_chunk, output_chunk
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
    finally:
        total_elapsed = time.time() - start_time
        avg_fps = processed_count / total_elapsed if total_elapsed > 0 else 0
        cap.release()
        if video_writer:
            video_writer.release()
        print("\n" + "=" * 60)
        print("Final Statistics:")
        print(f"  - Total processed frames: {processed_count}")
        print(f"  - Total processing time:  {total_elapsed:.2f} seconds")
        print(f"  - Average throughput:     {avg_fps:.2f} frames/sec")
        print(f"  - Output saved to:        {output_path}")
        print("=" * 60)


if __name__ == "__main__":
    main()
Based on commit 6995a80
I had Gemini rewrite the original code, so the diff includes many unnecessary changes.