@dhbrojas
Created July 28, 2025 12:24
Warmup Stable Decay LR
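The function below returns a multiplier $\lambda(t)$ that `LambdaLR` applies to the optimizer's base learning rate $\eta$, so the effective rate at scheduler step $t$ is $\eta\,\lambda(t)$. Writing $m$ for `min_lr_scale_factor`, the schedule is piecewise. The last case ($t > D$) describes the intended behaviour after the decay phase. Note that it only holds if the decay progress $(t - S)/(D - S)$ is clamped at 1, because the raw cosine rises again past $D$.

```latex
\lambda(t) =
\begin{cases}
t / W & 0 \le t \le W \quad (\lambda = 1 \text{ if } W = 0) \\
1 & W < t \le S \\
m + (1 - m)\,\dfrac{1 + \cos\!\left(\pi\,\dfrac{t - S}{D - S}\right)}{2} & S < t \le D \\
m & t > D
\end{cases}
```

For example, with $W = 100$, $S = 800$, $D = 1000$ and $m = 0.1$, the midpoint of the decay ($t = 900$) gives $\lambda = 0.1 + 0.9 \cdot \frac{1 + \cos(\pi/2)}{2} = 0.55$.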
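import math


def warmup_stable_decay(*, W: int, S: int, D: int, min_lr_scale_factor: float = 0.1):
    """
    Returns a multiplier function for PyTorch's LambdaLR scheduler implementing
    the Warmup-Stable-Decay (WSD) learning rate schedule.

    Parameters:
    - W: The last step of the warmup phase (linear ramp from 0 to 1).
    - S: The last step of the stable phase (constant multiplier of 1).
    - D: The last step of the decay phase (cosine decay down to min_lr_scale_factor).
    - min_lr_scale_factor: The minimum learning rate is eta * min_lr_scale_factor,
      where eta is the optimizer's base learning rate.

    Returns:
    - A function to be used as lr_lambda with torch.optim.lr_scheduler.LambdaLR.
    """

    def lr_lambda(current_step: int) -> float:
        # Warmup: scale linearly from 0 at step 0 to 1 at step W.
        if current_step <= W:
            return current_step / W if W > 0 else 1.0
        # Stable: hold the base learning rate.
        if current_step <= S:
            return 1.0
        # Decay: cosine from 1 at step S to min_lr_scale_factor at step D.
        # Progress is clamped to 1 so steps past D stay at the minimum instead of
        # the cosine rising back up; the D > S guard avoids division by zero.
        progress = min(1.0, (current_step - S) / (D - S)) if D > S else 1.0
        return min_lr_scale_factor + (1 - min_lr_scale_factor) * (
            1 + math.cos(math.pi * progress)
        ) / 2

    return lr_lambda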
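A quick way to check the shape of the schedule without PyTorch is to evaluate the multiplier at the phase boundaries. The sketch below is self-contained, so it repeats a compact version of the schedule with the decay progress clamped at 1 for steps past `D`. The step counts (`W=100`, `S=800`, `D=1000`) and `base_lr` are arbitrary illustrative values, and the loop only mimics what `LambdaLR` does, which is multiply the base rate by the lambda's return value.

```python
import math
from typing import Callable


def warmup_stable_decay(
    *, W: int, S: int, D: int, min_lr_scale_factor: float = 0.1
) -> Callable[[int], float]:
    # Compact copy of the WSD multiplier, with decay progress clamped to [0, 1].
    def lr_lambda(step: int) -> float:
        if step <= W:
            return step / W if W > 0 else 1.0  # linear warmup
        if step <= S:
            return 1.0  # stable phase
        progress = min(1.0, (step - S) / (D - S)) if D > S else 1.0
        # Cosine decay from 1 down to min_lr_scale_factor.
        return min_lr_scale_factor + (1 - min_lr_scale_factor) * (
            1 + math.cos(math.pi * progress)
        ) / 2

    return lr_lambda


# Illustrative values: 100 warmup steps, stable until step 800, decay until step 1000.
base_lr = 3e-4
schedule = warmup_stable_decay(W=100, S=800, D=1000, min_lr_scale_factor=0.1)

# LambdaLR sets lr = base_lr * lr_lambda(step); reproduce that by hand here.
for step in (0, 50, 100, 500, 800, 900, 1000, 1200):
    multiplier = schedule(step)
    print(f"step {step:4d}: multiplier={multiplier:.3f} lr={base_lr * multiplier:.2e}")
```

In training code, the returned function is passed as `torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_stable_decay(W=..., S=..., D=...))`, and `scheduler.step()` is called once per optimizer step so that `current_step` counts optimizer updates rather than epochs.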