@oscarknagg
Last active January 26, 2020 05:35
Key functionality for the supervised learning part of Model-Agnostic Meta-Learning (Finn et al., 2017)
from collections import OrderedDict
from typing import Callable, Union

import torch
import torch.nn.functional as F
from torch.nn import Module
from torch.optim import Optimizer
def replace_grad(parameter_gradients, parameter_name):
    """Creates a backward hook function that replaces the calculated gradient
    with a precomputed value when .backward() is called.

    See
    https://pytorch.org/docs/stable/autograd.html?highlight=hook#torch.Tensor.register_hook
    for more info
    """
    def replace_grad_(grad):
        # Hooks registered with Tensor.register_hook receive the freshly
        # computed gradient; this one ignores it and returns the precomputed
        # gradient for the named parameter instead.
        return parameter_gradients[parameter_name]

    return replace_grad_
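

# --- Illustrative demo (not part of the original gist) ---
# A minimal sketch of how `replace_grad` interacts with `Tensor.register_hook`:
# the hook discards the gradient computed by autograd and installs the
# precomputed one instead. The function name `_demo_replace_grad` and the
# values used here are assumptions for illustration only.
def _demo_replace_grad():
    w = torch.nn.Parameter(torch.ones(3))
    precomputed = {'w': torch.full((3,), 42.0)}
    handle = w.register_hook(replace_grad(precomputed, 'w'))

    w.sum().backward()  # the "true" gradient here would be all ones

    # The hook replaced the gradient accumulated into w.grad
    assert torch.allclose(w.grad, precomputed['w'])
    handle.remove()
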
def functional_forward(x: torch.Tensor, weights: dict):
    """Performs a forward pass of the network using the PyTorch functional API.

    Shown here as a standalone function; in `meta_gradient_step` below the same
    computation is exposed as a method of the model (`model.functional_forward`),
    so that the forward pass can be run with arbitrary "fast weights" rather
    than the parameters registered on the module.
    """
    for block in [1, 2, 3, 4]:
        x = functional_conv_block(x, weights[f'conv{block}.0.weight'], weights[f'conv{block}.0.bias'],
                                  weights.get(f'conv{block}.1.weight'), weights.get(f'conv{block}.1.bias'))

    x = x.view(x.size(0), -1)
    x = F.linear(x, weights['logits.weight'], weights['logits.bias'])

    return x
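

# --- Illustrative sketch (not part of the original gist) ---
# `functional_conv_block` is referenced above but defined elsewhere in the
# original code base. Judging from the arguments (conv weights/biases plus
# optional batch-norm weights/biases), a plausible implementation is the usual
# conv -> batch norm -> ReLU -> max-pool block used in few-shot CNNs. The
# padding and pooling parameters below are assumptions, not the author's
# confirmed values.
def functional_conv_block(x: torch.Tensor, weights, biases, bn_weights, bn_biases) -> torch.Tensor:
    x = F.conv2d(x, weights, biases, padding=1)
    # Batch norm without running statistics so the output depends only on the
    # current batch and the supplied (possibly fast) weights
    x = F.batch_norm(x, running_mean=None, running_var=None,
                     weight=bn_weights, bias=bn_biases, training=True)
    x = F.relu(x)
    x = F.max_pool2d(x, kernel_size=2, stride=2)
    return x
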
def meta_gradient_step(model: Module,
                       optimiser: Optimizer,
                       loss_fn: Callable,
                       x: torch.Tensor,
                       y: torch.Tensor,
                       n_shot: int,
                       k_way: int,
                       q_queries: int,
                       order: int,
                       inner_train_steps: int,
                       inner_lr: float,
                       train: bool,
                       device: Union[str, torch.device]):
    """
    Perform a gradient step on a meta-learner.

    # Arguments
        model: Base model of the meta-learner being trained
        optimiser: Optimiser used to take the gradient step calculated from the loss
        loss_fn: Loss function calculated between predictions and targets
        x: Input samples for all few-shot tasks
        y: Input labels for all few-shot tasks
        n_shot: Number of examples per class in the support set of each task
        k_way: Number of classes in the few-shot classification problem of each task
        q_queries: Number of examples per class in the query set of each task. The query set is used to calculate
            meta-gradients after applying the inner-loop update to the fast weights
        order: Whether to use 1st order MAML (update meta-learner weights with gradients of the updated weights on the
            query set) or 2nd order MAML (use 2nd order updates by differentiating through the gradients of the updated
            weights on the query set with respect to the original weights)
        inner_train_steps: Number of gradient steps used to fit the fast weights during each inner update
        inner_lr: Learning rate used to update the fast weights in the inner update
        train: Whether to update the meta-learner weights at the end of the episode
        device: Device on which to run the computation
    """
    data_shape = x.shape[2:]
    create_graph = (order == 2) and train

    task_gradients = []
    task_losses = []
    task_predictions = []
    for meta_batch_samples, meta_batch_labels in zip(x, y):
        # By construction x is a 5D tensor of shape: (meta_batch_size, n*k + q*k, channels, width, height)
        # Hence when we iterate over the first dimension we are iterating through the meta batches
        # Equivalently y is a tensor of shape: (meta_batch_size, n*k + q*k, 1)
        x_task_train = meta_batch_samples[:n_shot * k_way]
        x_task_val = meta_batch_samples[n_shot * k_way:]
        y_task_train = meta_batch_labels[:n_shot * k_way]
        y_task_val = meta_batch_labels[n_shot * k_way:]

        # Create a fast model using the current meta model weights
        fast_weights = OrderedDict(model.named_parameters())

        # Train the model for `inner_train_steps` iterations
        for inner_batch in range(inner_train_steps):
            # Perform update of model weights
            logits = model.functional_forward(x_task_train, fast_weights)
            loss = loss_fn(logits, y_task_train)
            gradients = torch.autograd.grad(loss, fast_weights.values(), create_graph=create_graph)

            # Update weights manually (one SGD step on the fast weights)
            fast_weights = OrderedDict(
                (name, param - inner_lr * grad)
                for ((name, param), grad) in zip(fast_weights.items(), gradients)
            )
        # Do a pass of the model on the validation (query) data from the current task
        logits = model.functional_forward(x_task_val, fast_weights)
        loss = loss_fn(logits, y_task_val)
        loss.backward(retain_graph=True)

        # Get post-update predictions
        y_pred = logits.softmax(dim=1)
        task_predictions.append(y_pred)

        # Accumulate losses and gradients
        task_losses.append(loss)
        gradients = torch.autograd.grad(loss, fast_weights.values(), create_graph=create_graph)
        named_grads = {name: g for ((name, _), g) in zip(fast_weights.items(), gradients)}
        task_gradients.append(named_grads)
    if order == 1:
        if train:
            # Average the query-set gradients over all tasks in the meta batch
            sum_task_gradients = {k: torch.stack([grad[k] for grad in task_gradients]).mean(dim=0)
                                  for k in task_gradients[0].keys()}
            hooks = []
            for name, param in model.named_parameters():
                hooks.append(
                    param.register_hook(replace_grad(sum_task_gradients, name))
                )

            model.train()
            optimiser.zero_grad()
            # Dummy pass in order to create a `loss` variable
            # Replace dummy gradients with mean task gradients using hooks
            logits = model(torch.zeros((k_way, ) + data_shape).to(device, dtype=torch.double))
            loss = loss_fn(logits, create_nshot_task_label(k_way, 1).to(device))
            loss.backward()
            optimiser.step()

            for h in hooks:
                h.remove()

        return torch.stack(task_losses).mean(), torch.cat(task_predictions)
    elif order == 2:
        model.train()
        optimiser.zero_grad()
        meta_batch_loss = torch.stack(task_losses).mean()

        if train:
            # Second-order MAML: backpropagate through the inner-loop updates
            meta_batch_loss.backward()
            optimiser.step()

        return meta_batch_loss, torch.cat(task_predictions)
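

# --- Illustrative sketch (not part of the original gist) ---
# `create_nshot_task_label` is used in the first-order dummy pass above but
# defined elsewhere in the original code base. A minimal compatible
# implementation (an assumption, not the author's exact code) produces
# query labels ordered by class: [0]*q + [1]*q + ... + [k-1]*q.
def create_nshot_task_label(k: int, q: int) -> torch.Tensor:
    """Creates a label tensor for a k-way task with q queries per class."""
    return torch.arange(0, k, 1 / q).long()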