Last active
November 19, 2018 01:01
import matplotlib.pyplot as plt


def plot_loss_change(sched, sma=1, n_skip=20, y_lim=(-0.01, 0.01)):
    """
    Plots rate of change of the loss function.
    Parameters:
        sched - learning rate scheduler, an instance of LR_Finder class.
        sma - number of batches for simple moving average to smooth out the curve.
        n_skip - number of batches to skip on the left.
        y_lim - limits for the y axis.
    """
    # Pad with zeros so that derivatives[i] lines up with sched.lrs[i].
    derivatives = [0] * (sma + 1)
    for i in range(1 + sma, len(sched.lrs)):
        # Average change in loss per batch over the last `sma` batches.
        derivative = (sched.losses[i] - sched.losses[i - sma]) / sma
        derivatives.append(derivative)

    plt.ylabel("d/loss")
    plt.xlabel("learning rate (log scale)")
    plt.plot(sched.lrs[n_skip:], derivatives[n_skip:])
    plt.xscale('log')
    plt.ylim(y_lim)


# Usage: `learn` is assumed to be an existing learner object whose
# lr_find() call populates learn.sched with `lrs` and `losses`.
learn.lr_find()
plot_loss_change(learn.sched, sma=20, n_skip=20, y_lim=(-0.02, 0.01))
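The loss-change calculation inside `plot_loss_change` can be tried without a trained model or a plotting backend. The sketch below is only an illustration: `loss_change` is a hypothetical helper that repeats the loop from the function above, and the `sched` object is a made-up stand-in with invented `lrs` and `losses` values, not output from a real LR_Finder run. Reading the most negative value as the point of fastest loss decrease is an assumption about how such a curve might be used, not something the snippet itself states.

```python
from types import SimpleNamespace


def loss_change(losses, sma=1):
    """Hypothetical helper: same loop as plot_loss_change, without plotting.

    Returns a list padded with sma + 1 leading zeros so that, for inputs
    longer than sma, result[i] lines up with losses[i].
    """
    derivatives = [0] * (sma + 1)
    for i in range(1 + sma, len(losses)):
        # Average change in loss per batch over the last `sma` batches.
        derivatives.append((losses[i] - losses[i - sma]) / sma)
    return derivatives


# Made-up stand-in for a scheduler: only the two attributes the loop reads.
sched = SimpleNamespace(
    lrs=[10 ** (-5 + 0.5 * k) for k in range(10)],
    losses=[2.0, 1.9, 1.8, 1.6, 1.3, 1.0, 0.9, 1.2, 2.5, 6.0],
)

changes = loss_change(sched.losses, sma=2)

# Index of the most negative change, i.e. where the loss drops fastest.
best = min(range(len(changes)), key=changes.__getitem__)
print("steepest drop at learning rate", sched.lrs[best])
```

A larger `sma` widens the window and flattens batch-to-batch noise, at the cost of more zero-padded entries at the start of the curve.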