This is the decorator we will use for all of our scheduling functions, as it transforms a function taking `(start, end, pos)` into something taking `(start, end)` and returning a function that depends on `pos`.
# Plot every built-in annealing function, going from 2 to 1e-2, on one figure.
p = torch.linspace(0.,1,100)
annealings = ["NO", "LINEAR", "COS", "EXP"]
fns = [SchedNo, SchedLin, SchedCos, SchedExp]
for t, fn in zip(annealings, fns):
    # One scheduler per curve, then sample it at every position in `p`.
    curve = fn(2, 1e-2)
    plt.plot(p, [curve(o) for o in p], label=t)
# The polynomial scheduler takes an extra exponent, so it is plotted on its own.
f = SchedPoly(2,1e-2,0.5)
plt.plot(p, list(map(f, p)), label="POLY(0.5)")
plt.legend();
# Sample each scheduler at 0%, 25%, 50%, 75% and 100% and compare with the
# expected values. Each row is (scheduler, comparison helper, expected values).
positions = [0., 0.25, 0.5, 0.75, 1.]
cases = [
    (SchedLin(0, 2),     test_eq,    [0., 0.5, 1., 1.5, 2.]),
    (SchedCos(0, 2),     test_close, [0., 0.29289, 1., 1.70711, 2.]),
    (SchedNo(0, 2),      test_close, [0., 0., 0., 0., 0.]),
    (SchedExp(1, 2),     test_close, [1., 1.18921, 1.41421, 1.68179, 2.]),
    (SchedPoly(0, 2, 2), test_close, [0., 0.125, 0.5, 1.125, 2.]),
]
for sched, check, expected in cases:
    check(L(map(sched, positions)), expected)
# Compare polynomial annealing from 2 to 0 for several exponents.
p = torch.linspace(0.,1,100)
pows = [0.5,1.,2.]
for e in pows:
    f = SchedPoly(2, 0, e)
    ys = list(map(f, p))
    plt.plot(p, ys, label=f'power {e}')
plt.legend();
`pcts` must be a list of positive numbers that add up to 1 and has the same length as `scheds`. The generated function will use `scheds[0]` from 0 to `pcts[0]`, then `scheds[1]` from `pcts[0]` to `pcts[0]+pcts[1]`, and so forth.
# Two cosine phases: the first 30% goes 0.3 -> 0.6, the remaining 70% goes 0.6 -> 0.2.
p = torch.linspace(0.,1,100)
phase_pcts = [0.3,0.7]
phase_scheds = [SchedCos(0.3,0.6), SchedCos(0.6,0.2)]
f = combine_scheds(phase_pcts, phase_scheds)
plt.plot(p, list(map(f, p)));
# Three phases: linear 0 -> 1 for 30%, constant 1 for 20%, cosine 1 -> 0 for the last 50%.
p = torch.linspace(0.,1,100)
phase_pcts = [0.3,0.2,0.5]
phase_scheds = [SchedLin(0.,1.), SchedNo(1.,1.), SchedCos(1., 0.)]
f = combine_scheds(phase_pcts, phase_scheds)
plt.plot(p, list(map(f, p)));
This is a useful helper function for the 1cycle policy. `pct` is used for the `start` to `middle` part, `1-pct` for the `middle` to `end` part. Handles floats or collections of floats. For example:
# Plot a combined cosine schedule; arguments are presumably (pct, start, middle, end) -- confirm
# against `combined_cos`. `p` is the grid of positions defined by the previous plotting cell.
f = combined_cos(0.25,0.5,1.,0.)
plt.plot(p, [f(o) for o in p]);
`scheds` is a dictionary with one key for each hyper-parameter you want to schedule, with either a scheduler or a list of schedulers as values (in the second case, the list must have the same length as the number of parameter groups of the optimizer).
# Train once with a linear lr schedule from 1e-3 to 1e-2, then check that the
# recorded learning rates follow that line step by step.
learn = synth_learner()
sched = {'lr': SchedLin(1e-3, 1e-2)}
learn.fit(1, cbs=ParamScheduler(sched))
n = len(learn.dls.train)
expected_lrs = [1e-3 + (1e-2-1e-3) * i/n for i in range(n)]
test_close(learn.recorder.hps['lr'], expected_lrs)
The 1cycle policy was introduced by Leslie N. Smith et al. in Super-Convergence: Very Fast Training of Neural Networks Using Large Learning Rates. It schedules the learning rate with a cosine annealing from `lr_max/div` to `lr_max`, then to `lr_max/div_final` (pass an array to `lr_max` if you want to use differential learning rates), and the momentum with cosine annealing according to the values in `moms`. The first phase takes `pct_start` of the training. You can optionally pass additional `cbs` and `reset_opt`.
learn = synth_learner(lr=1e-2)

# Loss on one batch before training...
xb,yb = learn.dls.one_batch()
init_loss = learn.loss_func(learn.model(xb), yb)
learn.fit_one_cycle(2)
# ...and on one batch afterwards: training must have reduced it.
xb,yb = learn.dls.one_batch()
final_loss = learn.loss_func(learn.model(xb), yb)
assert final_loss < init_loss

# The recorded lr and momentum must follow the expected combined-cosine schedules.
hps = learn.recorder.hps
lrs,moms = hps['lr'],hps['mom']
n_iter = 20  # assumes the run above performs 20 training iterations -- TODO confirm
lr_sched = combined_cos(0.25,1e-2/25,1e-2,1e-7)
mom_sched = combined_cos(0.25,0.95,0.85,0.95)
test_close(lrs, [lr_sched(i/n_iter) for i in range(n_iter)])
test_close(moms, [mom_sched(i/n_iter) for i in range(n_iter)])
# Visual check: run `fit_one_cycle(2)` on a fresh synthetic learner and plot
# what the recorder captured of the schedule.
learn = synth_learner()
learn.fit_one_cycle(2)
learn.recorder.plot_sched()
# Visual check: run `fit_flat_cos(2)` on a fresh synthetic learner and plot
# what the recorder captured of the schedule.
learn = synth_learner()
learn.fit_flat_cos(2)
learn.recorder.plot_sched()
This schedule was introduced by Ilya Loshchilov et al. in SGDR: Stochastic Gradient Descent with Warm Restarts. It consists of `n_cycles` that are cosine annealings from `lr_max` (defaults to the `Learner` lr) to 0, with a length of `cycle_len * cycle_mult**i` for the `i`-th cycle (the first one is `cycle_len`-long, then we multiply the length by `cycle_mult` at each new cycle). You can optionally pass additional `cbs` and `reset_opt`.
learn = synth_learner()
with learn.no_logging():
    learn.fit_sgdr(3, 1)
test_eq(learn.n_epoch, 7)

# Iteration indices at which each of the three cycles starts/ends (epochs 0, 1, 3 and 7).
steps_per_epoch = len(learn.dls.train)
iters = [k * steps_per_epoch for k in [0,1,3,7]]
for start, stop in zip(iters, iters[1:]):
    n = stop - start
    cycle_sched = SchedCos(learn.lr, 0)
    # Rounding errors can blend the first step of a cycle with the final 0 of the
    # previous one, so the comparison starts one step into each cycle.
    test_close(learn.recorder.lrs[start+1:stop], [cycle_sched(k/n) for k in range(1,n)])

learn.recorder.plot_sched()
# NOTE(review): this runs on whichever `learn` is current at this point -- in this file
# that is the synthetic learner created for the SGDR check just above; confirm that is intended.
learn.fine_tune(1)
from fastai.vision.all import *
# Fix the seed, then list the image files of the PETS dataset.
set_seed(99, True)
path = untar_data(URLs.PETS)/'images'
image_files = get_image_files(path)

# In a notebook on Windows, keep only an eighth as many files to save time.
# NOTE(review): if `random` is the stdlib module, `random.choices` samples with
# replacement, so the reduced list may contain duplicates.
on_windows_notebook = sys.platform == "win32" and IN_NOTEBOOK
if on_windows_notebook:
    n_keep = int(len(image_files)/8)
    image_files = random.choices(image_files, k=n_keep)
    print("Randomly select 1/8 files in NOTEBOOK on Windows to save time")
# pickle can't serialize a lambda function, so the labeller is a named function.
def _label_func(x):
    "Return `True` when the first element of `x` (presumably a file name) is uppercase."
    first_char = x[0]
    return first_char.isupper()
# Build image dataloaders labelled by `_label_func`, with items resized to 224
# (valid_pct=0.2 is presumably the fraction held out for validation -- confirm).
dls = ImageDataLoaders.from_name_func(
path, image_files, valid_pct=0.2,
label_func=_label_func, item_tfms=Resize(224))
learn = cnn_learner(dls, resnet18)
learn.fit(1)
# The bare expressions below only display a value when run as notebook cells: each one
# shows the optimizer's `grad_avg` state entry at index 1.
learn.opt.state_dict()['state'][1]['grad_avg']
learn.lr_find()
# Same inspection after a first `lr_find` -- presumably to check by eye that the
# optimizer state is left unchanged; confirm.
learn.opt.state_dict()['state'][1]['grad_avg']
learn.lr_find()
# ...and again after a second `lr_find`.
learn.opt.state_dict()['state'][1]['grad_avg']
import tempfile
from fastcore.basics import range_of
from fastcore.xtras import Path
# End-to-end check of the `LRFinder` callback on a synthetic learner that writes
# to a throwaway directory.
with tempfile.TemporaryDirectory() as d:
    learn = synth_learner(path=Path(d))
    # Snapshot *copies* of the weights. Keeping references to `learn.model.a`/`.b`
    # would compare each tensor with itself, so the "loaded back" checks at the
    # end could never fail. (Assumes `a` and `b` are tensors/parameters -- confirm.)
    init_a,init_b = learn.model.a.detach().clone(),learn.model.b.detach().clone()
    with learn.no_logging(): learn.fit(20, cbs=LRFinder(num_it=100))
    rec = learn.recorder
    assert len(rec.lrs) <= 100
    test_eq(len(rec.lrs), len(rec.losses))
    #Check stop if diverge
    if len(rec.lrs) < 100: assert rec.losses[-1] > 4 * min(rec.losses)
    #Test schedule
    lr_sched = SchedExp(1e-7, 10)
    test_eq(rec.lrs, [lr_sched(i/100) for i in range_of(rec.lrs)])
    #No validation data
    test_eq([len(v) for v in rec.values], [1 for _ in range_of(rec.values)])
    #Model loaded back properly
    test_eq(learn.model.a, init_a)
    test_eq(learn.model.b, init_b)
    test_eq(learn.opt.state_dict()['state'], {})
First introduced by Leslie N. Smith in Cyclical Learning Rates for Training Neural Networks, the LR Finder trains the model with exponentially growing learning rates from `start_lr` to `end_lr` for `num_it` iterations and stops in case of divergence (unless `stop_div=False`), then plots the losses vs the learning rates with a log scale.
A good value for the learning rate is then either:
- one tenth of the minimum before the divergence
- the point where the slope is the steepest
Those two values are returned by default by the Learning Rate Finder.
# `lr_find` must leave the model weights exactly as it found them.
with tempfile.TemporaryDirectory() as d:
    learn = synth_learner(path=Path(d))
    # Clone the parameters before running the finder. `L(learn.model.parameters())`
    # alone only stores references to the live parameters, so the "before" and
    # "after" lists would hold the very same tensors and the comparison below
    # would pass even if the weights had changed.
    weights_pre_lr_find = L(w.detach().clone() for w in learn.model.parameters())
    lr_min,lr_steep = learn.lr_find()
    weights_post_lr_find = L(learn.model.parameters())
test_eq(weights_pre_lr_find, weights_post_lr_find)
print(f"Minimum/10: {lr_min:.2e}, steepest point: {lr_steep:.2e}")