@dhbrojas
Created July 21, 2025 11:05
Linear Gradient Accumulation Schedule
import torch


class GradientAccumulationSchedule:
    """
    A schedule that linearly increases the number of gradient accumulation
    steps throughout training, trading smaller effective batches early for
    larger, more stable ones later, to converge faster.
    """

    def __init__(self, *, min: int, max: int, steps: int, factor: int | None = None):
        factor = factor if factor is not None else 1
        assert min > 0 and max > min, f"expect 0 < {min} < {max}"
        assert min % factor == 0 and max % factor == 0, \
            f"expect {min} % {factor} == 0 and {max} % {factor} == 0"
        self.min = min
        self.max = max
        self.factor = factor
        self.steps = steps
        # Sweep from min//factor to (max//factor) + 1 and truncate to long so
        # each integer level covers an equal share of the schedule; the clamp
        # keeps the final points from overshooting max//factor.
        self.values = torch.clamp(
            torch.linspace(min // factor, (max // factor) + 1, steps, dtype=torch.long),
            min // factor, max // factor,
        ) * factor
        self.average = self.values.float().mean().item()
        values, counts = torch.unique(self.values, return_counts=True)
        self.counts = {
            int(value.item()): int(count.item())
            for value, count in zip(values, counts)
        }

    def value(self, step: int) -> int:
        # Steps past the end of the schedule hold at the final value.
        return int(self.values[min(step, self.steps - 1)].item())
# Example usage
schedule = GradientAccumulationSchedule(min=2, max=8, steps=25000, factor=2)
print(schedule.counts)
# {2: 6250, 4: 6250, 6: 6250, 8: 6250}
print(schedule.values)
# tensor([2, 2, 2, ..., 8, 8, 8])
print(schedule.value(10000))
# 4
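
A minimal sketch of how the schedule might drive a training loop, assuming a toy setup: the tiny linear model, SGD optimizer, and random data below are hypothetical placeholders, and only GradientAccumulationSchedule comes from the gist itself.

import torch
from torch import nn

# Placeholder model and optimizer, not part of the gist.
model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
schedule = GradientAccumulationSchedule(min=2, max=8, steps=100, factor=2)

for step in range(schedule.steps):
    accum = schedule.value(step)
    optimizer.zero_grad()
    for _ in range(accum):
        x = torch.randn(8, 4)
        loss = model(x).pow(2).mean()
        # Scale each micro-batch loss so the accumulated gradient is the
        # mean over the micro-batches, not their sum.
        (loss / accum).backward()
    optimizer.step()

Because accum grows over training, early optimizer steps see fewer micro-batches (a smaller effective batch) and later steps see more, which is the linear ramp the schedule encodes.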