PyTorch to fastai, Bridging the Gap

fastai
Published

February 14, 2021

Understanding how to bring PyTorch code into the fastai space with minimal headache

#hide_input
from wwf.utils import state_versions
state_versions(['fastai', 'fastcore', 'torch', 'torchvision'])

Addressing the Elephant in the Room

I recently posted a tweet asking what people struggle with the most in fastai, and the resounding answer was how to integrate it minimally with PyTorch. There seems to be an impression that to use fastai you must use the complete fastai API and nothing else.

Let’s clear up that misconception now:

Important: fastai at its core is a training loop, designed to be framework agnostic. You can use any flavor of PyTorch you want, and only use fastai to quickly and effectively train a model with state-of-the-art practices.

The Plan

Now that the misconception has been addressed, let’s walk through just how that is going to happen. We’re going to follow the official PyTorch CIFAR10 tutorial and show the minimum that needs to happen in the fastai framework to take full advantage of the Learner. This will include:

  • The Dataset
  • The DataLoaders
  • The model
  • The optimizer

The Dataset and DataLoaders

Following from the tutorial, we’re going to load in the dataset using only torchvision. First we’ll grab our imports:

import torch
import torchvision
import torchvision.transforms as transforms

Next we’re going to define some minimal transforms:

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))])
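
To see what these two transforms do to a single pixel value, here is a minimal sketch in plain Python. It assumes the documented behavior of ToTensor (scaling values into [0, 1]) and Normalize (subtracting the mean and dividing by the standard deviation, per channel). The normalize helper is a hypothetical name used only for illustration.

```python
def normalize(value, mean=0.5, std=0.5):
    # Same arithmetic transforms.Normalize applies to each channel:
    # (value - mean) / std
    return (value - mean) / std

# ToTensor has already scaled pixels into [0, 1], so these are the
# darkest, middle, and brightest possible inputs.
low, mid, high = normalize(0.0), normalize(0.5), normalize(1.0)
print(low, mid, high)
```

With a mean and standard deviation of 0.5, every channel ends up in the range [-1, 1].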

Before downloading our train and test sets:

Note: I’m using naming conventions similar to how fastai names things, so you can see how these can relate to each other

dset_train = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
dset_test = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)

Next we’ll make our DataLoaders:

trainloader = torch.utils.data.DataLoader(dset_train, batch_size=4,
                                          shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(dset_test, batch_size=4,
                                         shuffle=False, num_workers=2)
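
As a rough sanity check on what these loaders will produce, the sketch below counts the batches per epoch. It assumes the standard CIFAR10 split of 50,000 training and 10,000 test images, and the DataLoader default of keeping the last partial batch. num_batches is a hypothetical helper, not part of PyTorch.

```python
import math

def num_batches(num_items, batch_size):
    # With drop_last=False (the DataLoader default) a final partial
    # batch is kept, so we round up.
    return math.ceil(num_items / batch_size)

train_batches = num_batches(50_000, 4)
test_batches = num_batches(10_000, 4)
print(train_batches, test_batches)
```

A batch size of 4 comes straight from the tutorial. It means a lot of small steps per epoch, so a larger value is a reasonable thing to experiment with.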

And that’s as far as we’ll go with the data for now. Let’s move on to the model next.

The Model

We’ll bring in the architecture from the tutorial and use it here:

import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

And finally we’ll make an instance of it:

net = Net()
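
The 16 * 5 * 5 in fc1 and in the view call comes from how the convolutions and pooling layers shrink a 32x32 CIFAR10 image. The sketch below traces that arithmetic using the usual output-size formula for unpadded convolutions. conv_out is a hypothetical helper written for this illustration.

```python
def conv_out(size, kernel, stride=1, padding=0):
    # Output size of a convolution or pooling layer along one dimension.
    return (size + 2 * padding - kernel) // stride + 1

size = 32                                  # CIFAR10 images are 32x32
size = conv_out(size, kernel=5)            # conv1: 32 -> 28
size = conv_out(size, kernel=2, stride=2)  # pool:  28 -> 14
size = conv_out(size, kernel=5)            # conv2: 14 -> 10
size = conv_out(size, kernel=2, stride=2)  # pool:  10 -> 5

flat_features = 16 * size * size           # conv2 outputs 16 channels
print(size, flat_features)
```

If you change the input resolution or the kernel sizes, this is the number that has to be updated in both fc1 and the view call.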

Loss Function and Optimizer

Next we’ll bring in their loss function and optimizer.

The loss function is simple enough:

criterion = nn.CrossEntropyLoss()
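
Note that Net returns raw scores from fc3 with no softmax, which is what nn.CrossEntropyLoss expects. As a rough illustration of what the loss computes for one sample, here is a plain-Python sketch (log-softmax followed by negative log-likelihood). The cross_entropy function is a hypothetical stand-in, not PyTorch's implementation.

```python
import math

def cross_entropy(logits, target):
    # Log-softmax followed by negative log-likelihood for one sample.
    shifted = [z - max(logits) for z in logits]  # for numerical stability
    log_sum_exp = math.log(sum(math.exp(z) for z in shifted))
    return -(shifted[target] - log_sum_exp)

# A model that scores all 10 classes equally.
uniform_loss = cross_entropy([0.0] * 10, target=3)
# A model that is confident in the correct class.
confident_loss = cross_entropy([0.0] * 3 + [10.0] + [0.0] * 6, target=3)
print(uniform_loss, confident_loss)
```

The uniform case works out to ln(10), roughly 2.3, so that is about the loss to expect from a 10-class model before it has learned anything.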

However, the optimizer requires a little bit of fastai magic, specifically in the form of an OptimWrapper. Our optimizer function should be defined as below:

from fastai.optimizer import OptimWrapper
from torch import optim
def opt_func(params, **kwargs): return OptimWrapper(optim.SGD(params, lr=0.001))
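
For reference, the update that plain SGD performs with these settings (no momentum, no weight decay) is small enough to sketch by hand. sgd_step below is a hypothetical helper operating on Python floats, not fastai's or PyTorch's implementation.

```python
def sgd_step(params, grads, lr=0.001):
    # Plain SGD: move each parameter against its gradient, scaled by lr.
    return [p - lr * g for p, g in zip(params, grads)]

params = [1.0, -2.0]
grads = [10.0, -10.0]
updated = sgd_step(params, grads)
print(updated)
```

The wrapper does not change this math. It only gives the PyTorch optimizer the interface the Learner expects.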

Training

Now we have everything needed to train a model, so let’s bring in fastai’s training loop, also known as the Learner.

fastai’s Learner expects DataLoaders to be used, rather than simply one DataLoader, so let’s make that:

Note: fastai also expects a validation DataLoader to be present, so we’ll be tying the testloader in here

from fastai.data.core import DataLoaders
dls = DataLoaders(trainloader, testloader)

Finally we’re going to wrap it all up in a Learner. As mentioned before, the Learner is the glue that merges everything together and enables users to utilize Leslie Smith’s One-Cycle Policy, the learning rate finder, and other fastai training goodies.

Let’s make it by passing in our dls, the model, the optimizer, and the loss function:

from fastai.learner import Learner

To get fastai’s fancy-looking progress bar, we need to import the ProgressCallback:

from fastai.callback.progress import ProgressCallback

We also need to pass in the CudaCallback so our batches are pushed to the GPU (fastai’s own DataLoaders can do this automatically, but here we are using plain PyTorch ones):

from fastai.callback.data import CudaCallback
learn = Learner(dls, net, loss_func=criterion, opt_func=opt_func, cbs=[CudaCallback])

Finally, let’s do some minimal training.

Now we have everything needed to do a basic fit:

Note: Since we already set a learning rate in our optimizer function, we don’t need to pass one in here

learn.fit(2)

What’s Next?

Great, so now we’ve trained our model, but what do we do with it? How do we get it out?

Your model lives in learn.model, and we’ve already seen that we passed in a regular PyTorch model earlier. Since we’re using fastai’s base Learner class, the model itself was untouched. As a result, it’s still a regular PyTorch model we can save away:

torch.save(learn.model.state_dict(), './cifar_net.pth')
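
Getting the weights back later is the usual PyTorch round trip: rebuild the architecture, then call load_state_dict. The sketch below uses a small nn.Linear as a stand-in for the trained network and a temporary directory instead of ./cifar_net.pth, both of which are assumptions made to keep the example self-contained.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Stand-in for the trained network; in the tutorial this would be learn.model.
trained = nn.Linear(4, 2)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "cifar_net.pth")
    torch.save(trained.state_dict(), path)

    # Rebuild the same architecture, then load the saved weights into it.
    restored = nn.Linear(4, 2)
    restored.load_state_dict(torch.load(path))

x = torch.randn(3, 4)
same_output = torch.equal(trained(x), restored(x))
print(same_output)
```

With the real model, you would create a fresh Net() and load the saved file into it in the same way.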

And that’s really it! As you can see, the absolute minimum you can get away with when using the fastai framework is:

  • PyTorch DataLoader
  • PyTorch model
  • fastai Learner
  • fastai Optimizer

Closing Remarks

I hope this has enlightened you on just how flexible the fastai framework can truly be for your training needs, with the idealistic goal of simply getting a model out there.

As we’ve removed most of the fastai magic, from here on out you should use standard PyTorch. fastai-specific functions like test_dl and predict can no longer be used, because you didn’t use a fastai DataLoader.
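
In practice, that means prediction follows the usual PyTorch pattern: put the model in eval mode, disable gradient tracking, and take the argmax over the class scores. The sketch below is one way to do it, using a small stand-in model and a random batch with CIFAR10's shapes rather than the trained Net and the real testloader.

```python
import torch
import torch.nn as nn

# Stand-in with the same input and output shapes as Net.
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
batch = torch.randn(4, 3, 32, 32)   # fake batch in place of testloader images

model.eval()                        # put layers such as dropout in inference mode
with torch.no_grad():               # no gradients needed for prediction
    logits = model(batch)
    preds = logits.argmax(dim=1)    # index of the highest-scoring class

print(preds.shape)
```

Each entry in preds is a class index from 0 to 9, which can be mapped back to a label through the dataset's class list.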

Thank you for reading!