Warmup Stable Decay LR
import math


def warmup_stable_decay(*, W: int, S: int, D: int, min_lr_scale_factor: float = 0.1):
    """
    Returns a lambda function for PyTorch's LambdaLR scheduler implementing the
    WSD (Warmup-Stable-Decay) learning rate schedule.

    Parameters:
    - W: The last step of the warmup phase.
    - S: The last step of the stable phase.
    - D: The last step of the decay phase (must satisfy D > S).
    - min_lr_scale_factor: The minimum learning rate is eta * min_lr_scale_factor.

    Returns:
    - A lambda function to be used with torch.optim.lr_scheduler.LambdaLR.
    """

    def lr_lambda(current_step: int):
        if current_step <= W:
            # Linear warmup from 0 up to the base learning rate.
            return current_step / W if W > 0 else 1.0
        elif current_step <= S:
            # Stable phase: hold the base learning rate.
            return 1.0
        elif current_step <= D:
            # Cosine decay from the base learning rate down to
            # min_lr_scale_factor * base learning rate.
            return (
                min_lr_scale_factor
                + (1 - min_lr_scale_factor)
                * (1 + math.cos(math.pi * (current_step - S) / (D - S)))
                / 2
            )
        else:
            # Past the decay phase: hold the minimum learning rate, since the
            # cosine term would otherwise start rising again beyond step D.
            return min_lr_scale_factor

    return lr_lambda
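A minimal usage sketch with LambdaLR. The model, optimizer, and step counts below (100 warmup steps, stable until step 900, decay ending at step 1000) are illustrative, not part of the gist:

import torch

# Hypothetical model and optimizer for demonstration purposes.
model = torch.nn.Linear(10, 10)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=warmup_stable_decay(W=100, S=900, D=1000)
)

for step in range(1000):
    optimizer.step()   # one training step (loss computation and backward omitted)
    scheduler.step()   # advance the schedule after each optimizer step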