@surmenok
Last active November 19, 2018 01:01
import matplotlib.pyplot as plt


def plot_loss_change(sched, sma=1, n_skip=20, y_lim=(-0.01, 0.01)):
    """
    Plots the rate of change of the loss function.

    Parameters:
        sched - learning rate scheduler, an instance of the LR_Finder class.
        sma - number of batches for the simple moving average used to smooth the curve.
        n_skip - number of batches to skip on the left.
        y_lim - limits for the y axis.
    """
    # Pad with zeros so derivatives[i] lines up with sched.lrs[i].
    derivatives = [0] * (sma + 1)
    for i in range(1 + sma, len(sched.lrs)):
        # Average change in loss over the last `sma` batches.
        derivative = (sched.losses[i] - sched.losses[i - sma]) / sma
        derivatives.append(derivative)

    plt.ylabel("d/loss")
    plt.xlabel("learning rate (log scale)")
    plt.plot(sched.lrs[n_skip:], derivatives[n_skip:])
    plt.xscale('log')
    plt.ylim(y_lim)


learn.lr_find()
plot_loss_change(learn.sched, sma=20, n_skip=20, y_lim=(-0.02, 0.01))
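
# Optional: a minimal standalone sketch (not part of the original gist) showing
# how to call plot_loss_change without a fastai learner, using a hypothetical
# stand-in object `fake_sched` that only carries the `lrs` and `losses` lists
# the function reads. The values below are purely illustrative.
import math
from types import SimpleNamespace

lrs = [10 ** (-5 + 4 * i / 99) for i in range(100)]  # 1e-5 .. ~1e-1, log-spaced
losses = [2.0 - 1.5 * math.exp(-((math.log10(lr) + 3) ** 2)) for lr in lrs]  # toy loss curve

fake_sched = SimpleNamespace(lrs=lrs, losses=losses)
plot_loss_change(fake_sched, sma=5, n_skip=10, y_lim=(-0.05, 0.05))
plt.show()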