Pytorch to fastai, Bridging the Gap
Understanding how to bring Pytorch code into the fastai space with minimal headache
Addressing the Elephant in the Room
I recently posted a tweet asking about what people struggle with the most in fastai, and the resounding answer was how to integrate minimally with Pytorch. An impression seems to have been made that to use fastai you must use the complete fastai API only, and nothing else.
Let's clear up that misconception now: fastai at its core is a training loop, designed to be framework agnostic. You can use any flavor of Pytorch you want, and use fastai only to quickly and effectively train a model with state-of-the-art practices.
The Plan
Now that the misconception has been addressed, let's walk through just how that is going to happen. We're going to follow the official Pytorch CIFAR10 tutorial and show what minimally needs to happen in the fastai framework to take full advantage of the Learner. This will include:
- The Dataset
- The DataLoaders
- The model
- The optimizer
import torch
import torchvision
import torchvision.transforms as transforms
Next we're going to define some minimal transforms:
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
Before downloading our train and test sets, note that we'll name them dset_train and dset_test, similar to how fastai names things, so you can see how the two frameworks relate to each other:
dset_train = torchvision.datasets.CIFAR10(root='./data', train=True,
                                          download=True, transform=transform)
dset_test = torchvision.datasets.CIFAR10(root='./data', train=False,
                                         download=True, transform=transform)
Next we'll make our DataLoaders:
trainloader = torch.utils.data.DataLoader(dset_train, batch_size=4,
                                          shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(dset_test, batch_size=4,
                                         shuffle=False, num_workers=2)
And that's as far as we'll go with the data for now; let's move on to the model next.
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # Two convolutional layers, with max pooling applied after each
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # Three fully-connected layers, down to our 10 CIFAR classes
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)  # flatten for the linear layers
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
And finally we'll make an instance of it:
net = Net()
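As a quick optional sanity check (not part of the original tutorial), we can push a fake batch through the network to confirm the output lines up with our 10 classes:

out = net(torch.randn(2, 3, 32, 32))  # a fake batch: 2 images, 3 channels, 32x32
print(out.shape)  # expected: torch.Size([2, 10])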
The loss function is simple enough:
criterion = nn.CrossEntropyLoss()
However, the optimizer requires a little bit of fastai magic, specifically in the form of an OptimWrapper. Our optimizer function should be defined as below:
from fastai.optimizer import OptimWrapper
from torch import optim
def opt_func(params, **kwargs): return OptimWrapper(optim.SGD(params, lr=0.001))
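The same pattern should work for any torch optimizer, not just SGD. For instance, a hypothetical Adam-based version (the adam_opt_func name and lr value are just placeholders):

def adam_opt_func(params, **kwargs): return OptimWrapper(optim.Adam(params, lr=0.001))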
Training
Now we have everything needed to train a model, so let's bring in fastai's training loop, also known as the Learner.
fastai's Learner expects DataLoaders to be used, rather than simply one DataLoader, so let's make that. It also expects a validation DataLoader to be present, so we'll be tying the testloader in here:
from fastai.data.core import DataLoaders
dls = DataLoaders(trainloader, testloader)
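If you want to double-check the wrapping (purely optional), DataLoaders simply stores our loaders and exposes them as .train and .valid; the shapes below assume our batch size of 4:

xb, yb = next(iter(dls.train))  # dls.train is just our trainloader
print(xb.shape, yb.shape)  # torch.Size([4, 3, 32, 32]) torch.Size([4])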
Finally we're going to wrap it all up in a Learner. As mentioned before, the Learner is the glue that merges everything together and enables users to utilize Leslie Smith's One-Cycle Policy, the learning rate finder, and other fastai training goodies. Let's make it by passing in our dls, the model, the optimizer, and the loss function:
from fastai.learner import Learner
To get fastai's fancy-looking progress bar, we need to import the ProgressCallback:
from fastai.callback.progress import ProgressCallback
We also need to pass in the CudaCallback so our batches can be pushed to the GPU (fastai's DataLoaders can do this automatically):
from fastai.callback.data import CudaCallback
learn = Learner(dls, net, loss_func=criterion, opt_func=opt_func, cbs=[CudaCallback])
Finally, let's do some minimal training. Now we have everything needed to do a basic fit. Since we already set our learning rate when we built the optimizer for our Learner, we don't need to pass one in here:
learn.fit(2)
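If you'd rather use the One-Cycle Policy mentioned earlier, here's a minimal sketch, assuming fastai 2.x where fastai.callback.schedule patches these methods onto the Learner (the lr_max value is just a placeholder):

import fastai.callback.schedule  # patches Learner with lr_find, fit_one_cycle, etc.
learn.fit_one_cycle(2, lr_max=1e-3)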
What's Next?
Great, so now we've trained our model, but what do we do with it? How do we get it out?
Your model lives in learn.model, and we've already seen that we passed in a regular Pytorch model earlier. Since we're using fastai's base Learner class, the model itself was untouched. As a result, it's still a regular Pytorch model we can save away:
torch.save(learn.model.state_dict(), './cifar_net.pth')
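Loading it back later is plain Pytorch too; a minimal sketch (net2 is just an illustrative name):

net2 = Net()
net2.load_state_dict(torch.load('./cifar_net.pth'))
net2.eval()  # switch to inference mode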
And that's really it! As you can see, the bare minimum you need to use the fastai framework is:
- a Pytorch DataLoader
- a Pytorch model
- a fastai Learner
- a fastai Optimizer
Closing Remarks
I hope this has enlightened you on just how flexible the fastai framework can truly be for your training needs when the goal is simply getting a model out there.
As we've removed most of the fastai magic, from here on out you should be using standard Pytorch, as fastai-specific functions like test_dl and predict can no longer be used, since you didn't use a fastai DataLoader.
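Inference is therefore regular Pytorch as well; a minimal sketch, pulling a batch from our testloader:

model = learn.model.eval()  # still a plain Pytorch model
xb, yb = next(iter(testloader))
with torch.no_grad():
    preds = model(xb.to(next(model.parameters()).device)).argmax(dim=1)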
Thank you for reading!