Pytorch to fastai, Bridging the Gap
Understanding how to bring Pytorch code into the fastai space with minimal headache
from wwf.utils import state_versions
state_versions(['fastai', 'fastcore', 'torch', 'torchvision'])
Addressing the Elephant in the Room
I recently posted a tweet asking about what people struggle with the most in fastai, and the resounding answer was how to integrate minimally with Pytorch. An impression seems to have been made that to use fastai you must use the complete fastai API only, and nothing else.
Let’s clear up that misconception now:
Important: fastai at its core is a training loop, designed to be framework agnostic. You can use any flavor of Pytorch you want, and only use fastai to quickly and effectively train a model with state-of-the-art practices.
The Plan
Now that the misconceptions have been addressed, let's walk through just how that is going to happen. We're going to follow the official Pytorch CIFAR10 tutorial and show what minimally needs to happen in the fastai framework to take full advantage of the Learner. This will include:
- The Dataset
- The DataLoaders
- The model
- The optimizer
The Dataset and DataLoaders
Following from the tutorial, we’re going to load in the dataset using only torchvision. First we’ll grab our imports:
import torch
import torchvision
import torchvision.transforms as transforms
Next we're going to define some minimal transforms:
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
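Just as a quick illustration (this check isn't part of the tutorial), applying transform to a single PIL image returns a normalized 3x32x32 tensor, which is exactly what the network below expects:

from PIL import Image
import numpy as np

# Build a random 32x32 RGB image purely for demonstration purposes
img = Image.fromarray(np.uint8(np.random.rand(32, 32, 3) * 255))
print(transform(img).shape)  # torch.Size([3, 32, 32])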
Before downloading our train and test sets:
Note: I'm using naming conventions similar to how fastai names things, so you can see how these can relate to each other
dset_train = torchvision.datasets.CIFAR10(root='./data', train=True,
                                          download=True, transform=transform)

dset_test = torchvision.datasets.CIFAR10(root='./data', train=False,
                                         download=True, transform=transform)
Next we'll make our DataLoaders:
trainloader = torch.utils.data.DataLoader(dset_train, batch_size=4,
                                          shuffle=True, num_workers=2)

testloader = torch.utils.data.DataLoader(dset_test, batch_size=4,
                                         shuffle=False, num_workers=2)
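As a quick sanity check (not part of the tutorial itself), we can pull one batch out of trainloader to confirm the shapes match CIFAR10 with our batch size of 4:

images, labels = next(iter(trainloader))
print(images.shape)  # torch.Size([4, 3, 32, 32])
print(labels.shape)  # torch.Size([4])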
And that's as far as we'll go from there for now; let's move on to the model next.
The Model
We’ll bring in the architecture from the tutorial and use it here:
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
And finally we’ll make an instance of it:
net = Net()
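If you'd like to verify the architecture before training (an optional check, not from the tutorial), a dummy batch can be pushed through it; the output should have one logit per CIFAR10 class:

dummy = torch.randn(4, 3, 32, 32)  # a fake batch in the same shape as our data
print(net(dummy).shape)            # torch.Size([4, 10])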
Loss Function and Optimizer
Next we’ll bring in their loss function and optimizer.
The loss function is simple enough:
criterion = nn.CrossEntropyLoss()
However, the optimizer requires a little bit of fastai magic, specifically in the form of an OptimWrapper. Our optimizer function should be defined as below:
from fastai.optimizer import OptimWrapper
from torch import optim
def opt_func(params, **kwargs): return OptimWrapper(optim.SGD(params, lr=0.001))
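To make the indirection concrete (this snippet is illustrative only), calling opt_func on the model's parameters is roughly what the Learner will do internally, and it hands back a fastai-compatible optimizer:

opt = opt_func(net.parameters())
print(type(opt))  # fastai's OptimWrapper around a torch.optim.SGD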
Training
Now we have everything needed to train a model, so let's bring in fastai's training loop, also known as the Learner.
fastai's Learner expects DataLoaders to be used, rather than simply one DataLoader, so let's make that:
Note: fastai also expects a validation DataLoader to be present, so we'll be tying the testloader in here
from fastai.data.core import DataLoaders
dls = DataLoaders(trainloader, testloader)
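Note that DataLoaders here is only a thin container; our raw Pytorch loaders are stored as-is (a quick illustrative check):

print(type(dls.train))  # torch.utils.data.dataloader.DataLoader
print(type(dls.valid))  # torch.utils.data.dataloader.DataLoader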
Finally we're going to wrap it all up in a Learner. As mentioned before, the Learner is the glue that merges everything together and enables users to utilize Leslie Smith's One-Cycle Policy, the learning rate finder, and other fastai training goodies.
Let's make it by passing in our dls, the model, the optimizer, and the loss function:
from fastai.learner import Learner
To get fastai's fancy-looking progress bar, we need to import the ProgressCallback:
from fastai.callback.progress import ProgressCallback
We also need to pass in the CudaCallback so our batches can be pushed to the GPU (fastai's own DataLoaders do this automatically, but our raw Pytorch ones do not):
from fastai.callback.data import CudaCallback
learn = Learner(dls, net, loss_func=criterion, opt_func=opt_func, cbs=[CudaCallback])
Finally, let's do some minimal training. Now we have everything needed to do a basic fit:
Note: Since we already set a learning rate inside our optimizer function, we don't need to pass one in here
learn.fit(2)
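Although fit is all we need here, the same Learner can also use the other fastai goodies mentioned earlier. A rough sketch (this assumes you import fastai.callback.schedule, which patches these methods onto Learner):

import fastai.callback.schedule  # importing this module adds lr_find/fit_one_cycle to Learner

learn.lr_find()                      # run the learning rate finder
learn.fit_one_cycle(2, lr_max=1e-3)  # train with the One-Cycle Policy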
What’s Next?
Great, so now we've trained our model, but what do we do with it? How do we get it out?
Your model lives in learn.model, and we've already seen that we passed in a regular Pytorch model earlier. Since we're using fastai's base Learner class, the model itself was untouched. As a result, it's still a regular Pytorch model we can save away:
torch.save(learn.model.state_dict(), './cifar_net.pth')
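Loading it back later works the plain Pytorch way too (a minimal sketch): rebuild the architecture and load the saved weights into it.

net2 = Net()
net2.load_state_dict(torch.load('./cifar_net.pth'))
net2.eval()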
And that's really it! As you can see, the most minimal setup you can get away with while still using the fastai framework is:
- Pytorch DataLoader
- Pytorch model
- fastai Learner
- fastai Optimizer
Closing Remarks
I hope this has enlightened you on just how flexible the fastai framework can truly be for your training needs, when the goal is simply getting a model out there.
As we've removed most of the fastai magic, from here on out you should be using standard Pytorch, since fastai-specific functions like test_dl and predict can no longer be used: you didn't use a fastai DataLoader.
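For inference you would simply fall back to plain Pytorch as well. A minimal sketch, assuming you want predictions for one batch from testloader:

device = next(learn.model.parameters()).device  # wherever the CudaCallback left the model
learn.model.eval()
with torch.no_grad():
    images, labels = next(iter(testloader))
    preds = learn.model(images.to(device)).argmax(dim=1)
print(preds)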
Thank you for reading!