Pytorch to fastai, Bridging the Gap
Understanding how to bring Pytorch code into the fastai space with minimal headache
#hide_input
from wwf.utils import state_versions
state_versions(['fastai', 'fastcore', 'torch', 'torchvision'])

Addressing the Elephant in the Room
I recently posted a tweet asking what people struggle with the most in fastai, and the resounding answer was how to integrate it minimally with Pytorch. The impression seems to be that to use fastai you must use the complete fastai API, and nothing else.
Let’s clear up that misconception now:
Important: fastai at its core is a training loop, designed to be framework agnostic. You can use any flavor of Pytorch you want, and only use fastai to quickly and effectively train a model with state-of-the-art practices.
The Plan
Now that the misconception has been addressed, let's walk through just how this is going to happen. We're going to follow the official Pytorch CIFAR10 tutorial and show the minimum needed in the fastai framework to take full advantage of the Learner. This will include:
- The Dataset
- The DataLoaders
- The model
- The optimizer
The Dataset and DataLoaders
Following from the tutorial, we’re going to load in the dataset using only torchvision. First we’ll grab our imports:
import torch
import torchvision
import torchvision.transforms as transforms

Next we're going to define some minimal transforms:
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
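As a quick aside (not from the tutorial), this is the arithmetic Normalize applies per channel: ToTensor scales pixels to [0, 1], and subtracting 0.5 then dividing by 0.5 maps them to [-1, 1]:

x = torch.tensor([0.0, 0.5, 1.0])  # pixel values as ToTensor would produce them
print((x - 0.5) / 0.5)             # tensor([-1., 0., 1.]) -- what Normalize does per channel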
Before downloading our train and test sets:

Note: I'm using naming conventions similar to how fastai names things, so you can see how these relate to each other
dset_train = torchvision.datasets.CIFAR10(root='./data', train=True,
                                          download=True, transform=transform)

dset_test = torchvision.datasets.CIFAR10(root='./data', train=False,
                                         download=True, transform=transform)

Next we'll make our DataLoaders:
trainloader = torch.utils.data.DataLoader(dset_train, batch_size=4,
                                          shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(dset_test, batch_size=4,
                                         shuffle=False, num_workers=2)
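As a quick sanity check (not part of the tutorial), we can pull one batch and confirm the shapes look like CIFAR10 with a batch size of 4:

xb, yb = next(iter(trainloader))
print(xb.shape, yb.shape)  # torch.Size([4, 3, 32, 32]) torch.Size([4])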
And that's as far as we'll go from there for now, so let's move on to the model next.

The Model
We’ll bring in the architecture from the tutorial and use it here:
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

And finally we'll make an instance of it:
net = Net()
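As another optional sanity check, a fake CIFAR-sized batch pushed through the network should come back with one logit per class:

fake_batch = torch.randn(4, 3, 32, 32)  # 4 fake RGB images of 32x32
print(net(fake_batch).shape)            # torch.Size([4, 10]), one logit per CIFAR10 class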
Loss Function and Optimizer

Next we'll bring in their loss function and optimizer.
The loss function is simple enough:
criterion = nn.CrossEntropyLoss()

However, the optimizer requires a little bit of fastai magic, specifically in the form of an OptimWrapper. Our optimizer function should be defined as below:
from fastai.optimizer import OptimWrapper
from torch import optim

def opt_func(params, **kwargs): return OptimWrapper(optim.SGD(params, lr=0.001))
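Just to illustrate the contract (an illustrative sketch, not something you need to run): when the Learner later builds its optimizer, it calls opt_func with the model's trainable parameters and a lr keyword argument, which is what the **kwargs above is there to absorb:

# Roughly how the Learner will call our function internally (sketch only)
opt = opt_func(net.parameters(), lr=0.001)  # the lr kwarg lands in **kwargs and is ignored here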
Training

Now that we have everything needed to train a model, let's bring in fastai's training loop, also known as the Learner.
fastai's Learner expects a DataLoaders object rather than a single DataLoader, so let's make one:
Note: fastai also expects a validation DataLoader to be present, so we'll be tying the testloader in here
from fastai.data.core import DataLoaders

dls = DataLoaders(trainloader, testloader)

Finally we're going to wrap it all up in a Learner. As mentioned before, the Learner is the glue that merges everything together and enables users to utilize Leslie Smith's One-Cycle Policy, the learning rate finder, and other fastai training goodies.
Let’s make it by passing in our dls, the model, the optimizer, and the loss function:
from fastai.learner import Learner

To get fastai's fancy-looking progress bar, we need to import the ProgressCallback:
from fastai.callback.progress import ProgressCallback

We also need to pass in the CudaCallback so our batches can be pushed to the GPU (fastai's own DataLoaders do this automatically, but our raw Pytorch ones won't):
from fastai.callback.data import CudaCallback

learn = Learner(dls, net, loss_func=criterion, opt_func=opt_func, cbs=[CudaCallback])
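With the Learner built, the One-Cycle Policy and the learning rate finder mentioned earlier are available too. Here's a sketch of how you might use them in place of a plain fit, assuming you import fastai.callback.schedule (which patches these methods onto the Learner):

from fastai.callback.schedule import lr_find, fit_one_cycle  # patches these onto Learner

learn.lr_find()                        # plot loss vs. learning rate to help pick one
learn.fit_one_cycle(2, lr_max=1e-3)    # train for 2 epochs with the One-Cycle Policy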
Finally, let's do some minimal training. We now have everything needed to do a basic fit:

Note: Since we already set a learning rate of 0.001 in our opt_func, we don't need to pass one in here
learn.fit(2)

What's Next?
Great, so now we've trained our model, but what do we do with it? How do we get it out?
Your model lives in learn.model, and we’ve already seen that we passed in a regular Pytorch model earlier. Since we’re using fastai’s base Learner class, the model itself was untouched. As a result, it’s still a regular Pytorch model we can save away:
torch.save(learn.model.state_dict(), './cifar_net.pth')

And that's really it! As you can see, the most minimal setup for using the fastai framework is:
- Pytorch DataLoader
- Pytorch model
- fastai Learner
- fastai Optimizer
Closing Remarks
I hope this has enlightened you on just how flexible the fastai framework can truly be for your training needs, with the simple goal of getting a model out there.
Since we've removed most of the fastai magic, from here on out you should rely on standard Pytorch: fastai-specific functions like test_dl and predict will no longer work, as you didn't use a fastai DataLoader.
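For instance, here's a minimal sketch of plain Pytorch inference with the weights we saved above (nothing fastai-specific is needed):

net2 = Net()
net2.load_state_dict(torch.load('./cifar_net.pth', map_location='cpu'))
net2.eval()

with torch.no_grad():                        # plain Pytorch inference, no fastai involved
    images, labels = next(iter(testloader))
    preds = net2(images).argmax(dim=1)       # predicted class index per image
print(preds, labels)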
Thank you for reading!