- Update UI Definition
Add the input to the ComfyUI node and pass it to the execution logic.
File: src/interfaces/video_upscaler.py
<<<<
io.Combo.Input("color_correction",
options=["lab", "wavelet", "wavelet_adaptive", "hsv", "adain", "none"],
default="lab",
tooltip=(
"Corrects color shifts in upscaled output to match original input (default: lab).\n"
"The upscaling process may alter colors; this applies color grading to restore them.\n"
"\n"
"• lab: Perceptual color matching with detail preservation (recommended)\n"
"• wavelet: Frequency-based natural colors, preserves fine details\n"
"• wavelet_adaptive: Wavelet base with targeted saturation correction\n"
"• hsv: Hue-conditional saturation matching\n"
"• adain: Statistical style transfer approach\n"
"• none: No color correction applied"
)
),
io.Float.Input("input_noise_scale",
====
io.Combo.Input("color_correction",
options=["lab", "wavelet", "wavelet_adaptive", "hsv", "adain", "none"],
default="lab",
tooltip=(
"Corrects color shifts in upscaled output to match original input (default: lab).\n"
"The upscaling process may alter colors; this applies color grading to restore them.\n"
"\n"
"• lab: Perceptual color matching with detail preservation (recommended)\n"
"• wavelet: Frequency-based natural colors, preserves fine details\n"
"• wavelet_adaptive: Wavelet base with targeted saturation correction\n"
"• hsv: Hue-conditional saturation matching\n"
"• adain: Statistical style transfer approach\n"
"• none: No color correction applied"
)
),
io.Float.Input("denoise_strength",
default=1.0,
min=0.0,
max=1.0,
step=0.01,
tooltip=(
"Controls how much of the diffusion process is applied (default: 1.0).\n"
"• 1.0: Full denoising from pure noise (standard for upscaling).\n"
"• < 1.0: Start diffusion from a later step. Lower values retain more structure\n"
" but may not fully resolve high-res details."
)
),
io.Float.Input("input_noise_scale",
>>>><<<<
@classmethod
def execute(cls, image: torch.Tensor, dit: Dict[str, Any], vae: Dict[str, Any],
            seed: int, resolution: int = 1080, max_resolution: int = 0, batch_size: int = 5,
            uniform_batch_size: bool = False, temporal_overlap: int = 0, prepend_frames: int = 0,
            color_correction: str = "wavelet", input_noise_scale: float = 0.0,
            latent_noise_scale: float = 0.0, offload_device: str = "none",
            enable_debug: bool = False) -> io.NodeOutput:
    """
    Execute SeedVR2 video upscaling with progress reporting
====
@classmethod
def execute(cls, image: torch.Tensor, dit: Dict[str, Any], vae: Dict[str, Any],
            seed: int, resolution: int = 1080, max_resolution: int = 0, batch_size: int = 5,
            uniform_batch_size: bool = False, temporal_overlap: int = 0, prepend_frames: int = 0,
            color_correction: str = "wavelet", denoise_strength: float = 1.0,
            input_noise_scale: float = 0.0, latent_noise_scale: float = 0.0,
            offload_device: str = "none", enable_debug: bool = False) -> io.NodeOutput:
    """
    Execute SeedVR2 video upscaling with progress reporting
>>>>
<<<<
# Phase 2: Upscale
ctx = upscale_all_batches(
    runner,
    ctx=ctx,
    debug=debug,
    progress_callback=progress_callback,
    seed=seed,
    latent_noise_scale=latent_noise_scale,
    cache_model=dit_cache
)
====
# Phase 2: Upscale
ctx = upscale_all_batches(
    runner,
    ctx=ctx,
    debug=debug,
    progress_callback=progress_callback,
    seed=seed,
    latent_noise_scale=latent_noise_scale,
    denoise_strength=denoise_strength,
    cache_model=dit_cache
)
>>>>- Update Inference Runner
Update the inference method to accept the parameter and pass it to the sampler.
File: src/core/infer.py
<<<<
@torch.no_grad()
def inference(
    self,
    noises: List[Tensor],
    conditions: List[Tensor],
    texts_pos: Union[List[str], List[Tensor], List[Tuple[Tensor]]],
    texts_neg: Union[List[str], List[Tensor], List[Tuple[Tensor]]],
    cfg_scale: Optional[float] = None,
) -> List[Tensor]:
====
@torch.no_grad()
def inference(
    self,
    noises: List[Tensor],
    conditions: List[Tensor],
    texts_pos: Union[List[str], List[Tensor], List[Tuple[Tensor]]],
    texts_neg: Union[List[str], List[Tensor], List[Tuple[Tensor]]],
    cfg_scale: Optional[float] = None,
    denoise_strength: float = 1.0,
) -> List[Tensor]:
>>>>
<<<<
latents = self.sampler.sample(
    x=latents,
    f=lambda args: classifier_free_guidance_dispatcher(
        pos=lambda: self.dit(
            vid=torch.cat([args.x_t, latents_cond], dim=-1),
            txt=text_pos_embeds,
            vid_shape=latents_shapes,
            txt_shape=text_pos_shapes,
            timestep=args.t.repeat(batch_size),
        ).vid_sample,
====
latents = self.sampler.sample(
    x=latents,
    denoise_strength=denoise_strength,
    f=lambda args: classifier_free_guidance_dispatcher(
        pos=lambda: self.dit(
            vid=torch.cat([args.x_t, latents_cond], dim=-1),
            txt=text_pos_embeds,
            vid_shape=latents_shapes,
            txt_shape=text_pos_shapes,
            timestep=args.t.repeat(batch_size),
        ).vid_sample,
>>>>- Update Sampler Base Class
Update the abstract method signature.
File: src/common/diffusion/samplers/base.py
<<<<
@abstractmethod
def sample(
    self,
    x: torch.Tensor,
    f: Callable[[SamplerModelArgs], torch.Tensor],
) -> torch.Tensor:
    """
    Generate a new sample given the the intial sample x and score function f.
    """
====
@abstractmethod
def sample(
    self,
    x: torch.Tensor,
    f: Callable[[SamplerModelArgs], torch.Tensor],
    denoise_strength: float = 1.0,
) -> torch.Tensor:
    """
    Generate a new sample given the initial sample x and score function f.
    """
>>>>- Update Euler Sampler Logic
Implement the logic to skip timesteps based on the strength value.
File: src/common/diffusion/samplers/euler.py
<<<<
def sample(
    self,
    x: torch.Tensor,
    f: Callable[[SamplerModelArgs], torch.Tensor],
) -> torch.Tensor:
    timesteps = self.timesteps.timesteps
    progress = self.get_progress_bar()
    i = 0
    # Keep native dtype throughout sampling
    # The DiT model already handles dtype internally via compatibility wrapper
    for t, s in zip(timesteps[:-1], timesteps[1:]):
        pred = f(SamplerModelArgs(x, t, i))
        # Next step
        x = self.step_to(pred, x, t, s)
        # Clean up temporary tensors
        del pred
        i += 1
        progress.update()
    if self.return_endpoint:
        t = timesteps[-1]
        pred = f(SamplerModelArgs(x, t, i))
        x = self.get_endpoint(pred, x, t)
        del pred
        progress.update()
    return x
====
def sample(
    self,
    x: torch.Tensor,
    f: Callable[[SamplerModelArgs], torch.Tensor],
    denoise_strength: float = 1.0,
) -> torch.Tensor:
    full_timesteps = self.timesteps.timesteps
    # Calculate starting index based on denoise_strength
    # Strength 1.0 starts at index 0 (full steps)
    # Strength 0.0 starts at the end (0 steps)
    total_steps = len(full_timesteps)
    start_idx = int(total_steps * (1.0 - denoise_strength))
    # Ensure we have at least 2 steps (start and end) if strength > 0,
    # but prevent out of bounds
    start_idx = max(0, min(start_idx, total_steps - 2))
    # Slice timesteps
    timesteps = full_timesteps[start_idx:]
    # Manually advance iterator to match schedule
    i = start_idx
    # Adjust progress bar
    progress = self.get_progress_bar()
    if start_idx > 0:
        progress.update(start_idx)
    # Keep native dtype throughout sampling
    # The DiT model already handles dtype internally via compatibility wrapper
    for t, s in zip(timesteps[:-1], timesteps[1:]):
        pred = f(SamplerModelArgs(x, t, i))
        # Next step
        x = self.step_to(pred, x, t, s)
        # Clean up temporary tensors
        del pred
        i += 1
        progress.update()
    if self.return_endpoint:
        t = timesteps[-1]
        pred = f(SamplerModelArgs(x, t, i))
        x = self.get_endpoint(pred, x, t)
        del pred
        progress.update()
    return x
>>>>- Important: Update Generation Phases
The file src/core/generation_phases.py contains the upscale_all_batches function referenced in video_upscaler.py.
Find the definition of upscale_all_batches and ensure it accepts denoise_strength and passes it to runner.inference:
# In src/core/generation_phases.py (Pseudo-code)
def upscale_all_batches(
    runner,
    ctx,
    # ... other args ...
    denoise_strength: float = 1.0,  # Add this argument
    # ...
):
    # ... inside the batch processing loop ...
    # When calling inference:
    batch_samples = runner.inference(
        ...,
        denoise_strength=denoise_strength,  # Pass it here
        *kwargs_param
    )
    # Note: runner.inference is called twice in this file; update the second call site as well.
    # ...
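As a quick sanity check of the step-skipping arithmetic introduced in the Euler sampler above, the start index can be exercised in isolation; a minimal standalone sketch (not part of the repository) that mirrors that slicing logic:
```python
def start_index(total_steps: int, denoise_strength: float) -> int:
    """Mirrors the start_idx computation added to the Euler sampler's sample()."""
    idx = int(total_steps * (1.0 - denoise_strength))
    # Clamp so at least two timesteps (one integration step) remain.
    return max(0, min(idx, total_steps - 2))

assert start_index(50, 1.0) == 0    # full schedule, behavior unchanged
assert start_index(50, 0.5) == 25   # skip the first half of the schedule
assert start_index(50, 0.0) == 48   # clamped: the final step still runs
```
With a 50-step schedule and strength 0.5, the sliced schedule keeps timesteps 25 through 49, and progress.update(start_idx) accounts for the skipped portion of the progress bar.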