@finbarrtimbers
Last active March 9, 2025 18:40
Adam
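
For reference, one step of the update rule that both classes below implement (this summary is an addition, not part of the original gist): given gradient g_t at step t,

$$
\begin{aligned}
m_t &= \beta_1 m_{t-1} + (1-\beta_1)\, g_t \\
v_t &= \beta_2 v_{t-1} + (1-\beta_2)\, g_t^2 \\
\hat{m}_t &= \frac{m_t}{1-\beta_1^t}, \qquad \hat{v}_t = \frac{v_t}{1-\beta_2^t} \\
\theta_t &= \theta_{t-1} - \frac{\mathrm{lr}\, \hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}
\end{aligned}
$$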
import torch


class SimpleAdam(torch.optim.Optimizer):
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8):
        super().__init__(params, defaults={'lr': lr})
        self.t = 0  # global step counter, shared by every parameter
        self.betas = betas
        self.eps = eps
        # The base Optimizer already provides self.state; seed it with
        # zeroed moment estimates for each parameter.
        for group in self.param_groups:
            for p in group['params']:
                self.state[p] = {
                    'first_moment': torch.zeros_like(p),
                    'second_moment': torch.zeros_like(p),
                }

    @torch.no_grad()
    def step(self):
        self.t += 1
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                assert p in self.state, f"{p} not in state"
                first_moment = self.state[p]['first_moment']
                second_moment = self.state[p]['second_moment']
                # Exponential moving averages of the gradient and the squared gradient.
                first_moment = self.betas[0] * first_moment + (1 - self.betas[0]) * p.grad
                second_moment = self.betas[1] * second_moment + (1 - self.betas[1]) * p.grad ** 2
                self.state[p]['first_moment'] = first_moment
                self.state[p]['second_moment'] = second_moment
                # Bias correction: both averages start at zero, so rescale
                # them before applying the update.
                first_moment_corrected = first_moment / (1 - self.betas[0] ** self.t)
                second_moment_corrected = second_moment / (1 - self.betas[1] ** self.t)
                p -= group['lr'] * first_moment_corrected / (second_moment_corrected.sqrt() + self.eps)
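
A quick way to check the class above (my addition, not part of the gist): run a few steps of SimpleAdam and torch.optim.Adam from identical starting points. With matching hyperparameters, the two should agree up to floating-point noise.

# Sanity check: SimpleAdam vs. torch.optim.Adam on a toy quadratic loss.
torch.manual_seed(0)
w_ref = torch.nn.Parameter(torch.randn(3))
w_simple = torch.nn.Parameter(w_ref.detach().clone())

reference = torch.optim.Adam([w_ref], lr=1e-3)
simple = SimpleAdam([w_simple], lr=1e-3)

for _ in range(5):
    for w, opt in ((w_ref, reference), (w_simple, simple)):
        opt.zero_grad()
        (w ** 2).sum().backward()
        opt.step()

print(torch.allclose(w_ref, w_simple, atol=1e-6))  # expected: True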
class SimpleAdamW(torch.optim.Optimizer):
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay: float = 1e-5):
        super().__init__(params, defaults={'lr': lr})
        self.t = 0
        self.betas = betas
        self.eps = eps
        self.weight_decay = weight_decay
        for group in self.param_groups:
            for p in group['params']:
                self.state[p] = {
                    'first_moment': torch.zeros_like(p),
                    'second_moment': torch.zeros_like(p),
                }

    @torch.no_grad()
    def step(self):
        self.t += 1
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                assert p in self.state, f"{p} not in state"
                first_moment = self.state[p]['first_moment']
                second_moment = self.state[p]['second_moment']
                first_moment = self.betas[0] * first_moment + (1 - self.betas[0]) * p.grad
                second_moment = self.betas[1] * second_moment + (1 - self.betas[1]) * p.grad ** 2
                self.state[p]['first_moment'] = first_moment
                self.state[p]['second_moment'] = second_moment
                first_moment_corrected = first_moment / (1 - self.betas[0] ** self.t)
                second_moment_corrected = second_moment / (1 - self.betas[1] ** self.t)
                # Decoupled weight decay (the AdamW difference): shrink the
                # weights directly instead of adding an L2 term to the gradient.
                p -= group['lr'] * self.weight_decay * p
                p -= group['lr'] * first_moment_corrected / (second_moment_corrected.sqrt() + self.eps)
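
The only difference from SimpleAdam is the extra decoupled weight-decay step, which shrinks the weights directly rather than folding an L2 penalty into the gradient. A minimal usage sketch (my addition; model shape and hyperparameters are illustrative):

model = torch.nn.Linear(4, 1)
opt = SimpleAdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
x, y = torch.randn(16, 4), torch.randn(16, 1)
for _ in range(100):
    opt.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    opt.step()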
ket395 commented Mar 9, 2025

Passing comment

See the SimpleAdam class.

This proves yet again the apophthegm that code in any language is nigh unreadable, even when the language does not forbid it.

Also, this most probably doesn't work as expected of an Adam implementation.

Do consider using enums here to swap out dataset-specific decay-rate values; a sketch follows this comment.
Where do you check for completion?
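
One possible reading of the enum suggestion above (the preset names and non-default beta values are hypothetical; it assumes SimpleAdam from the gist is in scope):

from enum import Enum

import torch

class DecayPreset(Enum):
    # (beta1, beta2) pairs per dataset; non-default values are illustrative only.
    DEFAULT = (0.9, 0.999)
    NOISY_GRADIENTS = (0.95, 0.9999)

w = torch.nn.Parameter(torch.zeros(3))
opt = SimpleAdam([w], lr=1e-3, betas=DecayPreset.DEFAULT.value)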

