@remixer-dec
Created January 15, 2026 20:59
SeedVR2 patch adding denoise_strength to ComfyUI, based on numz/ComfyUI-SeedVR2_VideoUpscaler version 2.5.24
  1. Update UI Definition

Add the input to the ComfyUI node and pass it to the execution logic.

File: src/interfaces/video_upscaler.py

<<<<
                io.Combo.Input("color_correction",
                    options=["lab", "wavelet", "wavelet_adaptive", "hsv", "adain", "none"],
                    default="lab",
                    tooltip=(
                        "Corrects color shifts in upscaled output to match original input (default: lab).\n"
                        "The upscaling process may alter colors; this applies color grading to restore them.\n"
                        "\n"
                        "• lab: Perceptual color matching with detail preservation (recommended)\n"
                        "• wavelet: Frequency-based natural colors, preserves fine details\n"
                        "• wavelet_adaptive: Wavelet base with targeted saturation correction\n"
                        "• hsv: Hue-conditional saturation matching\n"
                        "• adain: Statistical style transfer approach\n"
                        "• none: No color correction applied"
                    )
                ),
                io.Float.Input("input_noise_scale",
====
                io.Combo.Input("color_correction",
                    options=["lab", "wavelet", "wavelet_adaptive", "hsv", "adain", "none"],
                    default="lab",
                    tooltip=(
                        "Corrects color shifts in upscaled output to match original input (default: lab).\n"
                        "The upscaling process may alter colors; this applies color grading to restore them.\n"
                        "\n"
                        "• lab: Perceptual color matching with detail preservation (recommended)\n"
                        "• wavelet: Frequency-based natural colors, preserves fine details\n"
                        "• wavelet_adaptive: Wavelet base with targeted saturation correction\n"
                        "• hsv: Hue-conditional saturation matching\n"
                        "• adain: Statistical style transfer approach\n"
                        "• none: No color correction applied"
                    )
                ),
                io.Float.Input("denoise_strength",
                    default=1.0,
                    min=0.0,
                    max=1.0,
                    step=0.01,
                    tooltip=(
                        "Controls how much of the diffusion process is applied (default: 1.0).\n"
                        "• 1.0: Full denoising from pure noise (standard for upscaling).\n"
                        "• < 1.0: Start diffusion from a later step. Lower values retain more structure\n"
                        "  but may not fully resolve high-res details."
                    )
                ),
                io.Float.Input("input_noise_scale",
>>>>
<<<<
    @classmethod
    def execute(cls, image: torch.Tensor, dit: Dict[str, Any], vae: Dict[str, Any], 
                seed: int, resolution: int = 1080, max_resolution: int = 0, batch_size: int = 5,
                uniform_batch_size: bool = False, temporal_overlap: int = 0, prepend_frames: int = 0,
                color_correction: str = "wavelet", input_noise_scale: float = 0.0,
                latent_noise_scale: float = 0.0, offload_device: str = "none", 
                enable_debug: bool = False) -> io.NodeOutput:
        """
        Execute SeedVR2 video upscaling with progress reporting
====
    @classmethod
    def execute(cls, image: torch.Tensor, dit: Dict[str, Any], vae: Dict[str, Any], 
                seed: int, resolution: int = 1080, max_resolution: int = 0, batch_size: int = 5,
                uniform_batch_size: bool = False, temporal_overlap: int = 0, prepend_frames: int = 0,
                color_correction: str = "wavelet", denoise_strength: float = 1.0,
                input_noise_scale: float = 0.0, latent_noise_scale: float = 0.0, 
                offload_device: str = "none", enable_debug: bool = False) -> io.NodeOutput:
        """
        Execute SeedVR2 video upscaling with progress reporting
>>>>

<<<<
            # Phase 2: Upscale
            ctx = upscale_all_batches(
                runner,
                ctx=ctx,
                debug=debug,
                progress_callback=progress_callback,
                seed=seed,
                latent_noise_scale=latent_noise_scale,
                cache_model=dit_cache
            )
====
            # Phase 2: Upscale
            ctx = upscale_all_batches(
                runner,
                ctx=ctx,
                debug=debug,
                progress_callback=progress_callback,
                seed=seed,
                latent_noise_scale=latent_noise_scale,
                denoise_strength=denoise_strength,
                cache_model=dit_cache
            )
>>>>
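The execute() signature above inserts denoise_strength between color_correction and input_noise_scale. That is only safe if the node is invoked with named inputs, which is assumed here to be how ComfyUI passes node inputs. The toy functions below are hypothetical, trimmed stand-ins for the old and new signatures. They show that a keyword call is unaffected by the insertion, while a positional call would shift input_noise_scale into the new slot.

```python
# Hypothetical stand-ins for the node's execute() before and after the patch,
# trimmed to the parameters around the insertion point.
def execute_old(image, seed, color_correction="wavelet", input_noise_scale=0.0):
    return {"color_correction": color_correction, "input_noise_scale": input_noise_scale}


def execute_new(image, seed, color_correction="wavelet", denoise_strength=1.0,
                input_noise_scale=0.0):
    return {
        "color_correction": color_correction,
        "denoise_strength": denoise_strength,
        "input_noise_scale": input_noise_scale,
    }


# Named inputs, as a node-execution framework would supply them (assumption).
inputs = {"image": None, "seed": 0, "color_correction": "lab", "input_noise_scale": 0.1}

by_name = execute_new(**inputs)                 # every value lands in the right parameter
by_position = execute_new(None, 0, "lab", 0.1)  # 0.1 lands in denoise_strength instead
```

If any external code calls execute() positionally, it should be switched to keyword arguments before applying this patch.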
  2. Update Inference Runner

Update the inference method to accept the parameter and pass it to the sampler.

File: src/core/infer.py

<<<<
    @torch.no_grad()
    def inference(
        self,
        noises: List[Tensor],
        conditions: List[Tensor],
        texts_pos: Union[List[str], List[Tensor], List[Tuple[Tensor]]],
        texts_neg: Union[List[str], List[Tensor], List[Tuple[Tensor]]],
        cfg_scale: Optional[float] = None,
    ) -> List[Tensor]:
====
    @torch.no_grad()
    def inference(
        self,
        noises: List[Tensor],
        conditions: List[Tensor],
        texts_pos: Union[List[str], List[Tensor], List[Tuple[Tensor]]],
        texts_neg: Union[List[str], List[Tensor], List[Tuple[Tensor]]],
        cfg_scale: Optional[float] = None,
        denoise_strength: float = 1.0,
    ) -> List[Tensor]:
>>>>

<<<<
        latents = self.sampler.sample(
            x=latents,
            f=lambda args: classifier_free_guidance_dispatcher(
                pos=lambda: self.dit(
                    vid=torch.cat([args.x_t, latents_cond], dim=-1),
                    txt=text_pos_embeds,
                    vid_shape=latents_shapes,
                    txt_shape=text_pos_shapes,
                    timestep=args.t.repeat(batch_size),
                ).vid_sample,
====
        latents = self.sampler.sample(
            x=latents,
            denoise_strength=denoise_strength,
            f=lambda args: classifier_free_guidance_dispatcher(
                pos=lambda: self.dit(
                    vid=torch.cat([args.x_t, latents_cond], dim=-1),
                    txt=text_pos_embeds,
                    vid_shape=latents_shapes,
                    txt_shape=text_pos_shapes,
                    timestep=args.t.repeat(batch_size),
                ).vid_sample,
>>>>
  3. Update Sampler Base Class

Update the abstract method signature.

File: src/common/diffusion/samplers/base.py

<<<<
    @abstractmethod
    def sample(
        self,
        x: torch.Tensor,
        f: Callable[[SamplerModelArgs], torch.Tensor],
    ) -> torch.Tensor:
        """
        Generate a new sample given the the intial sample x and score function f.
        """
====
    @abstractmethod
    def sample(
        self,
        x: torch.Tensor,
        f: Callable[[SamplerModelArgs], torch.Tensor],
        denoise_strength: float = 1.0,
    ) -> torch.Tensor:
        """
        Generate a new sample given the the intial sample x and score function f.
        """
>>>>
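Python's abc module only checks that an abstract method is overridden, not that the override has the same signature. So if the repository contains other concrete samplers besides the Euler one, each of them also needs the denoise_strength keyword. Otherwise the sampler.sample(..., denoise_strength=...) call added in infer.py raises a TypeError. The classes below are hypothetical stand-ins, not the repository's classes.

```python
from abc import ABC, abstractmethod
from typing import Callable


class BaseSampler(ABC):
    """Hypothetical stand-in for the abstract sampler in base.py."""

    @abstractmethod
    def sample(self, x: float, f: Callable[[float], float],
               denoise_strength: float = 1.0) -> float:
        ...


class PatchedSampler(BaseSampler):
    """Accepts the new keyword, like the patched EulerSampler."""

    def sample(self, x, f, denoise_strength=1.0):
        self.last_strength = denoise_strength
        return f(x)


class UnpatchedSampler(BaseSampler):
    """Keeps the old signature; abc only checks that sample() is defined."""

    def sample(self, x, f):
        return f(x)


def call_like_runner(sampler: BaseSampler) -> float:
    # The patched runner always passes denoise_strength by keyword.
    return sampler.sample(x=2.0, f=lambda v: v + 1.0, denoise_strength=0.5)
```

Calling call_like_runner(UnpatchedSampler()) fails with a TypeError about an unexpected keyword argument, even though UnpatchedSampler() itself instantiates without complaint.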
  4. Update Euler Sampler Logic

Skip the earliest timesteps of the schedule in proportion to (1.0 - denoise_strength), so sampling starts partway through the schedule instead of at its beginning.

File: src/common/diffusion/samplers/euler.py

<<<<
    def sample(
        self,
        x: torch.Tensor,
        f: Callable[[SamplerModelArgs], torch.Tensor],
    ) -> torch.Tensor:
        timesteps = self.timesteps.timesteps
        progress = self.get_progress_bar()
        i = 0
        
        # Keep native dtype throughout sampling
        # The DiT model already handles dtype internally via compatibility wrapper
        for t, s in zip(timesteps[:-1], timesteps[1:]):
            pred = f(SamplerModelArgs(x, t, i))
            
            # Next step
            x = self.step_to(pred, x, t, s)
            
            # Clean up temporary tensors
            del pred
            
            i += 1
            progress.update()

        if self.return_endpoint:
            t = timesteps[-1]
            pred = f(SamplerModelArgs(x, t, i))
            x = self.get_endpoint(pred, x, t)
            del pred
            progress.update()
            
        return x
====
    def sample(
        self,
        x: torch.Tensor,
        f: Callable[[SamplerModelArgs], torch.Tensor],
        denoise_strength: float = 1.0,
    ) -> torch.Tensor:
        full_timesteps = self.timesteps.timesteps
        
        # Calculate the starting index from denoise_strength:
        # strength 1.0 starts at index 0 (full schedule); lower values skip
        # proportionally more of the early timesteps
        total_steps = len(full_timesteps)
        # round() rather than int(): int(10 * (1.0 - 0.9)) truncates to 0
        start_idx = round(total_steps * (1.0 - denoise_strength))

        # Clamp so at least two timesteps (one Euler step) remain, even at
        # strength 0.0, and the index never goes negative
        start_idx = max(0, min(start_idx, total_steps - 2))

        # Slice timesteps
        timesteps = full_timesteps[start_idx:]

        # Start the step counter at start_idx so SamplerModelArgs.i matches
        # the position in the full schedule
        i = start_idx
        
        # Adjust progress bar
        progress = self.get_progress_bar()
        if start_idx > 0:
            progress.update(start_idx)
        
        # Keep native dtype throughout sampling
        # The DiT model already handles dtype internally via compatibility wrapper
        for t, s in zip(timesteps[:-1], timesteps[1:]):
            pred = f(SamplerModelArgs(x, t, i))
            
            # Next step
            x = self.step_to(pred, x, t, s)
            
            # Clean up temporary tensors
            del pred
            
            i += 1
            progress.update()

        if self.return_endpoint:
            t = timesteps[-1]
            pred = f(SamplerModelArgs(x, t, i))
            x = self.get_endpoint(pred, x, t)
            del pred
            progress.update()
            
        return x
>>>>
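The start-index arithmetic can be checked in isolation. The sketch below reimplements it outside the sampler on a hypothetical 11-entry schedule (10 Euler steps). It uses round() rather than plain truncation so that a value like 0.9 is not pulled down by floating-point error. The schedule values are illustrative only, not SeedVR2's actual timesteps.

```python
from typing import List, Sequence, Tuple


def denoise_start_index(total_steps: int, denoise_strength: float) -> int:
    """Start index into a timestep schedule for a given denoise_strength."""
    # round() instead of int(): int(10 * (1.0 - 0.9)) is 0, because
    # 1.0 - 0.9 is slightly below 0.1 in binary floating point.
    start_idx = round(total_steps * (1.0 - denoise_strength))
    # Keep at least two timesteps (one Euler step) and never go negative.
    return max(0, min(start_idx, total_steps - 2))


def sliced_schedule(timesteps: Sequence[float],
                    denoise_strength: float) -> Tuple[int, List[float]]:
    start_idx = denoise_start_index(len(timesteps), denoise_strength)
    return start_idx, list(timesteps[start_idx:])


if __name__ == "__main__":
    # Hypothetical 11-entry schedule (10 Euler steps) descending from 1000 to 0.
    schedule = [1000.0 - 100.0 * k for k in range(11)]
    for strength in (1.0, 0.7, 0.3, 0.0):
        start, ts = sliced_schedule(schedule, strength)
        print(f"strength={strength:.1f} start_idx={start} "
              f"euler_steps={len(ts) - 1} first_t={ts[0]}")
```

Note that with a schedule of two or fewer entries the clamp always yields 0, so denoise_strength cannot change which timesteps run. How many entries the real schedule has depends on the configured step count. When return_endpoint is set, the sampler also makes one more model call after the loop.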
  5. Important: Update Generation Phases

The file src/core/generation_phases.py contains the upscale_all_batches function referenced in video_upscaler.py.

Find the definition of upscale_all_batches and ensure it accepts denoise_strength and passes it to runner.inference:

# In src/core/generation_phases.py (pseudo-code; unrelated parameters and logic elided)

def upscale_all_batches(
    runner,
    ctx,
    # ... other args ...
    denoise_strength: float = 1.0,  # add this argument
    # ...
):
    # ... inside the batch processing loop ...

    # When calling inference:
    batch_samples = runner.inference(
        ...,  # existing arguments, unchanged
        denoise_strength=denoise_strength,  # pass it here
        **kwargs_param,
    )
    # runner.inference is called twice in this function;
    # pass denoise_strength=denoise_strength to the second call as well.

    # ...
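Because the patched inference() defaults denoise_strength to 1.0, a missed call site does not raise an error. It silently runs the full schedule and ignores the value chosen in the node. The sketch below illustrates this with a hypothetical RecordingRunner and a simplified upscale function that has two inference call sites. The branch condition that picks between them is arbitrary and purely illustrative, not the repository's logic.

```python
from typing import List


class RecordingRunner:
    """Hypothetical stand-in for the runner: records the strength each call receives."""

    def __init__(self) -> None:
        self.seen: List[float] = []

    def inference(self, noises, conditions, denoise_strength: float = 1.0):
        self.seen.append(denoise_strength)
        return noises  # placeholder result


def upscale_sketch(runner, batches, denoise_strength=1.0, patch_second_call=True):
    """Hypothetical simplification with two inference call sites."""
    results = []
    for idx, batch in enumerate(batches):
        if idx % 2 == 0:
            # First call site: patched.
            out = runner.inference(batch, batch, denoise_strength=denoise_strength)
        elif patch_second_call:
            # Second call site, patched as this step instructs.
            out = runner.inference(batch, batch, denoise_strength=denoise_strength)
        else:
            # Second call site left unpatched: falls back to the default 1.0.
            out = runner.inference(batch, batch)
        results.append(out)
    return results
```

With patch_second_call=False, the runner records 1.0 for every batch routed through the second call site, even though 0.6 (for example) was requested. Grepping generation_phases.py for every runner.inference( call is a quick way to make sure none were missed.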