Linear Gradient Accumulation Schedule
import torch


class GradientAccumulationSchedule:
    """
    A schedule that linearly increases the number of gradient accumulation
    steps throughout training to converge faster.
    """

    def __init__(self, *, min: int, max: int, steps: int, factor: int | None = None):
        factor = factor if factor is not None else 1
        assert min > 0 and max > min, f"expect 0 < {min} < {max}"
        assert min % factor == 0 and max % factor == 0, f"expect {min} % {factor} == 0 and {max} % {factor} == 0"
        self.min = min
        self.max = max
        self.factor = factor
        self.steps = steps
        # Interpolate linearly from min to max in multiples of `factor`.
        # Overshooting the endpoint by one unit and clamping back down gives
        # every value a roughly equal-sized bucket of steps; without the +1,
        # truncation to long would assign the final value almost no steps.
        self.values = torch.clamp(
            torch.linspace(min // factor, (max // factor) + 1, steps, dtype=torch.long),
            min // factor, max // factor
        ) * factor
        # Mean accumulation steps over the whole run.
        self.average = self.values.float().mean().item()
        # How many training steps each accumulation value is held for.
        values, counts = torch.unique(self.values, return_counts=True)
        self.counts = {
            int(value.item()): int(count.item())
            for value, count in zip(values, counts)
        }

    def value(self, step: int) -> int:
        """Number of gradient accumulation steps to use at `step`."""
        return int(self.values[min(step, self.steps - 1)].item())
# Example usage
schedule = GradientAccumulationSchedule(min=2, max=8, steps=25000, factor=2)

print(schedule.counts)
# {2: 6250, 4: 6250, 6: 6250, 8: 6250}

print(schedule.values)
# tensor([2, 2, 2, ..., 8, 8, 8])

print(schedule.value(10000))
# 4
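The gist stops at querying the schedule; a minimal sketch of how it might drive an actual training loop follows, assuming a `model`, `optimizer`, and micro-batch `loader` (all hypothetical names not defined in the gist). Each step pulls `schedule.value(step)` micro-batches and averages their losses so the accumulated gradient matches that of one large batch.

import itertools

# Hypothetical model/optimizer/loader; not part of the gist.
step = 0
batches = iter(loader)
while step < schedule.steps:
    accum = schedule.value(step)  # micro-batches to accumulate this step
    optimizer.zero_grad()
    for micro_batch in itertools.islice(batches, accum):
        loss = model(micro_batch)
        (loss / accum).backward()  # scale so gradients average, not sum
    optimizer.step()
    step += 1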