Generating Comparative Baselines for CAMVID with fastai’s Dynamic Unet

Published

September 19, 2020

Exploring how baselines are being made and where fastai can fit in

This post is also a Jupyter notebook that can be run from top to bottom, and its code snippets can be run in any environment. At the time of writing, I am running the following versions of fastai and fastcore:

  • fastai: 2.0.13

  • fastcore: 1.0.13

Note: this post is the result of a joint effort between Juvian and me on the forums.

CAMVID Benchmarks, Can’t We Just Use the Code from Class?

In the fastai course, we are walked through semantic segmentation on the CAMVID dataset, which is filmed from a car’s point of view. Ideally, we would then like to compare our results to the current state-of-the-art benchmarks.

However! In its current state, this cannot be done.

Why, you might ask? Recently the benchmarks have adopted a few “weird” changes, including making the dataset slightly easier, and even then comparing against them is not as straightforward as you would hope (more on this at the end).

The Metric:

First, the reported metrics are different. Instead of accuracy, papers report the Mean Intersection over Union (mIOU), as well as the individual IOU of each class.
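For reference, the per-class IOU and the mIOU can be written as follows, where P_c is the set of pixels predicted as class c, T_c is the set of pixels labelled as class c, and C is the number of evaluated classes (11 here, with void pixels excluded). This is the same TP / (TP + FP + FN) form that the MIOU metric in this post computes from a confusion matrix:

```latex
\mathrm{IOU}_c = \frac{|P_c \cap T_c|}{|P_c \cup T_c|} = \frac{TP_c}{TP_c + FP_c + FN_c},
\qquad
\mathrm{mIOU} = \frac{1}{C} \sum_{c=1}^{C} \mathrm{IOU}_c
```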

The Number of Classes:

In the original fastai version of the dataset, 31 classes are present, plus an additional void class that is ignored in the resulting benchmarks.

Researchers have since reduced this to 11 classes: Building, Tree, Sky, Car, Sign, Road, Pedestrian, Fence, Pole, Sidewalk, and Cyclist, plus a twelfth void class that, again, is not taken into account.

Reducing the classes allows a higher mIOU to be reported without the rarely seen classes skewing the results. So if you were computing mIOU in the course notebooks, getting ~20%, and were confused why it doesn’t line up with the papers, this is why!

The Splits

When we train with fastai, we wind up mixing the baseline evaluation data into the training data! Not something we want at all! The train/validation/test split in most papers is 367/101/233. That is correct: there are more than twice as many test images as validation images.
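Once the SegNet-Tutorial repository is cloned, a quick way to confirm the split sizes is to count the images in each split folder. This is a sketch that assumes the CamVid/train, CamVid/val and CamVid/test layout used in this post and .png images; count_split_images is a hypothetical helper, not part of fastai:

```python
from pathlib import Path

def count_split_images(root, splits=("train", "val", "test"), exts=(".png",)):
    """Count image files directly inside each split folder under `root`.

    Assumes the SegNet-Tutorial layout, where CamVid/train, CamVid/val and
    CamVid/test hold the input images (the *annot folders hold the labels).
    """
    root = Path(root)
    counts = {}
    for split in splits:
        folder = root / split
        if not folder.is_dir():
            counts[split] = 0  # missing folder: report zero instead of failing
            continue
        counts[split] = sum(1 for p in folder.iterdir()
                            if p.is_file() and p.suffix.lower() in exts)
    return counts

# Usage after cloning (compare the result against the 367/101/233 split):
# print(count_split_images("SegNet-Tutorial/CamVid"))
```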

The SegNet Version

This version has images and labels at 360x480 pixels, half the size of fastai’s source dataset, and its labels use the 11 classes. What differs from paper to paper, however, is how the dataset is used, which can lead to issues. Let’s look at the current options and their pros and cons:

Using the SegNet Dataset:

If we use only this dataset, there is not much room for fastai’s tricks (such as progressive resizing and presizing). That being said, there are papers which use it. If you look at the CAMVID leaderboard, however, you’ll notice that the best model using it is placed 8th. So what’s next?

Well, what is the SOTA we’re comparing against then?

Below is a benchmark table, although we can’t truly compare against it. If we do wish to, we will focus on the models that have an ImageNet backbone:

Using the fastai Images with Smaller Labels

fastai uses the high-quality 720x960 images and labels, so it makes logical sense to train on those images while using the smaller SegNet masks as the labels. This is what all of the top benchmarks do.

The Issue

There is a very big issue with this, though, which Jeremy pointed out to us while we were discussing these new benchmark approaches: simply upscaling the labels, without any adjustment to the fastai images, sounds “weird” on its own. Instead, we first resize the images back down to 360x480 and then upsample them, so the images go through the same upscaling as the labels. This winds up increasing the final accuracy.
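A minimal numpy sketch of the idea, using random arrays as stand-ins for a 720x960 image and a 360x480 mask. It swaps in simple 2x2 block averaging for downscaling and pixel repetition for upscaling; fastai’s Resize (by default bilinear for images and nearest for masks) does not resample exactly like this. After the round trip, every 2x2 block of the image holds a single value, just like the upscaled mask, so both carry detail at the same resolution:

```python
import numpy as np

def downscale_2x(img):
    """Halve height and width by averaging each 2x2 block of pixels."""
    h, w = img.shape[:2]
    return img.reshape(h // 2, 2, w // 2, 2, *img.shape[2:]).mean(axis=(1, 3))

def upscale_2x_nearest(arr):
    """Double height and width by repeating every pixel (nearest-neighbour upsampling)."""
    return arr.repeat(2, axis=0).repeat(2, axis=1)

rng = np.random.default_rng(0)
image = rng.random((720, 960, 3))            # random stand-in for a full-size fastai CAMVID image
mask = rng.integers(0, 12, size=(360, 480))  # random stand-in for a half-size SegNet mask (labels 0-11)

# Bring the image down to the mask's resolution, then upscale both the same way.
presized_image = upscale_2x_nearest(downscale_2x(image))
presized_mask = upscale_2x_nearest(mask)

print(presized_image.shape, presized_mask.shape)
```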

Can We Train Now?

Okay, enough talking, can we see some code to back up your claims?

Sure, let’s do it! Here is an outline of what we will do throughout this post:

  1. Download the SegNet version of the dataset (alongside fastai’s CAMVID images)
  2. Make a DataBlock which pre-sizes our images to the proper size
  3. Make a unet_learner which:
     • Uses a pretrained ResNet34 backbone
     • Uses the ranger optimizer function
     • Compares the use of ReLU and Mish in the head
     • Uses both IOU and mIOU metrics to properly allow us to benchmark the results
  4. Make a test_dl with the proper test set to evaluate with

Downloading the Dataset

The dataset currently lives in the SegNet-Tutorial repository, so we will clone it and make it our working directory:

!git clone https://github.com/alexgkendall/SegNet-Tutorial.git
%cd SegNet-Tutorial/

We still want to use fastai’s input images, so we’ll also download fastai’s CAMVID dataset. First, let’s import fastai’s vision module:

from fastai.vision.all import *

Then grab the data:

path_i = untar_data(URLs.CAMVID)

Let’s see how both datasets are formatted:

path_l = Path('')
path_i.ls()
path_l.ls()

We can see that fastai has the usual images and labels folders, while it is hard to tell where the annotations are in the second one. Let’s narrow down to the CamVid folder:

path_l = path_l/'CamVid'
path_l.ls()

And now the dataset looks much better! The three folders we care about are trainannot, valannot and testannot, as these are where the labels live.

DataBlock

Given how the data is split up, fastai doesn’t currently have a splitter that works along those lines; the closest is GrandparentSplitter. We’ll write something similar called FolderSplitter, which accepts names for the train and validation folders:

def _folder_idxs(items, name):
    "Indices of `items` whose immediate parent folder is called `name` (or any of the names in `name`)"
    def _inner(items, name): return mask2idxs(Path(o).parents[0].name == name for o in items)
    return [i for n in L(name) for i in _inner(items, n)]

def FolderSplitter(train_name='train', valid_name='valid'):
    "Split `items` by the name of their parent folder; items in any other folder (e.g. `*annot`, `test`) go to neither split"
    def _inner(o):
        return _folder_idxs(o, train_name),_folder_idxs(o, valid_name)
    return _inner
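To check what FolderSplitter keeps, here is a pure-Python sketch of the same logic, run on made-up file paths; folder_split is a hypothetical stand-in that avoids fastai’s mask2idxs and L. Items in trainannot, valannot or test belong to neither split, which matters because get_image_files also picks up the annotation PNGs and the test images:

```python
from pathlib import Path

def folder_split(items, train_name="train", valid_name="valid"):
    """Return (train_idxs, valid_idxs) based on each item's immediate parent folder name.

    Mirrors FolderSplitter without fastai: a name may be a string or a list of
    strings, and items whose parent matches neither argument are left out.
    """
    def idxs(names):
        names = [names] if isinstance(names, str) else list(names)
        return [i for n in names for i, o in enumerate(items) if Path(o).parent.name == n]
    return idxs(train_name), idxs(valid_name)

# Hypothetical file list resembling what get_image_files returns for the CamVid folder
items = ["CamVid/train/a.png", "CamVid/trainannot/a.png", "CamVid/val/b.png",
         "CamVid/valannot/b.png", "CamVid/test/c.png"]
train_idxs, valid_idxs = folder_split(items, valid_name="val")
print(train_idxs, valid_idxs)  # [0] [2]
```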

Next we need a way to get our x images, since they live in a different place than our labels. We can use a custom function to do so:

def get_x(o): return path_i/'images'/o.name

Finally, we need a get_y that uses the same filename to grab the matching mask:

def get_mask(o): return o.parent.parent/(o.parent.name + 'annot')/o.name
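Both functions only rewrite paths, so they can be checked without any data. A sketch using pathlib, where path_i and the file name are hypothetical placeholders. Because get_mask builds the annotation folder from the parent folder’s name, the same function serves train, val and test items:

```python
from pathlib import Path

# Hypothetical locations for illustration: path_i stands in for the folder returned by
# untar_data(URLs.CAMVID), and the items mimic files found under SegNet-Tutorial/CamVid.
path_i = Path("/data/camvid")

def get_x(o): return path_i/'images'/o.name
def get_mask(o): return o.parent.parent/(o.parent.name + 'annot')/o.name

for split in ("train", "val", "test"):
    item = Path("SegNet-Tutorial/CamVid")/split/"0001TP_006690.png"
    # The input comes from fastai's images folder, the label from the matching *annot folder
    print(get_x(item), get_mask(item))
```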

We now have almost all the pieces for making our dataset. We’ll use fastai’s progressive resizing when training, and pass in a set of codes for our dataset:

codes = ['Sky', 'Building', 'Pole', 'Road', 'Pavement', 'Tree', 'SignSymbol', 'Fence', 'Car', 'Pedestrian', 'Bicyclist', 'Unlabelled']
half, full = (360, 480), (720, 960)

Now for the transforms. As mentioned earlier, we will downscale and then upscale the images, so that the same upscaling is applied to both our labels and our images, even though the images start from a higher quality. Since we want to train small first, we’ll resize back down to half size in the batch transforms, and also normalize our inputs:

item_tfms = [Resize(half), Resize(full)]
batch_tfms = [*aug_transforms(size=half), Normalize.from_stats(*imagenet_stats)]

And with this we can now build the DataBlock and DataLoaders:

camvid = DataBlock(blocks=(ImageBlock, MaskBlock(codes=codes)),
                   get_items=get_image_files,
                   splitter=FolderSplitter(valid_name='val'),
                   get_x=get_x,
                   get_y=get_mask,
                   item_tfms=item_tfms,
                   batch_tfms=batch_tfms)

We’ll call .summary() to make sure our images and masks end up at the half size:

camvid.summary(path_l)

We can see the final input and mask size is (360,480), which is what we want! Now let’s build the DataLoaders:

dls = camvid.dataloaders(path_l, bs=4)

Since the void class is ignored, the c attribute of our DataLoaders (the number of classes the model predicts) needs to be one less than the number of codes:

dls.c = len(codes) - 1

Metrics

For this next part, Juvian was the one who brought it to life! We want class-wise IOU as well as mIOU, which are defined below:

from sklearn.metrics import confusion_matrix

class IOU(AvgMetric):
    "Intersection over Union for a single class, accumulated over the whole validation set"
    def __init__(self, class_index, class_label, axis, ignore_index=-1): store_attr('axis,class_index,class_label,ignore_index')
    def accumulate(self, learn):
        pred, targ = learn.pred.argmax(dim=self.axis), learn.y
        # Pixels where both the prediction and the target are this class
        intersec = ((pred == targ) & (targ == self.class_index)).sum().item()
        # Pixels where either the prediction or the target is this class, skipping void pixels
        union = (((pred == self.class_index) | (targ == self.class_index)) & (targ != self.ignore_index)).sum().item()
        if union: self.total += intersec
        self.count += union

    @property
    def name(self): return self.class_label

class MIOU(AvgMetric):
    "Mean Intersection over Union, computed from a confusion matrix accumulated over the validation set"
    def __init__(self, classes, axis): store_attr()

    def accumulate(self, learn):
        pred, targ = learn.pred.argmax(dim=self.axis).cpu(), learn.y.cpu()
        pred, targ = pred.flatten().numpy(), targ.flatten().numpy()
        # Pixels labelled outside range(classes), such as void, are left out of the matrix
        self.total += confusion_matrix(targ, pred, labels=list(range(self.classes)))

    @property
    def value(self):
        conf_matrix = self.total
        per_class_TP = np.diagonal(conf_matrix).astype(float)
        per_class_FP = conf_matrix.sum(axis=0) - per_class_TP
        per_class_FN = conf_matrix.sum(axis=1) - per_class_TP
        iou_index = per_class_TP / (per_class_TP + per_class_FP + per_class_FN)
        iou_index = np.nan_to_num(iou_index)
        return np.mean(iou_index)

    @property
    def name(self): return 'miou'
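To see how MIOU.value turns a confusion matrix into a score, here is a numpy-only sketch on a tiny example. It swaps sklearn’s confusion_matrix for np.bincount to stay dependency-free (confusion_matrix_np and mean_iou are hypothetical helpers), drops any pixel whose label falls outside range(n_classes), and then computes TP / (TP + FP + FN) per class:

```python
import numpy as np

def confusion_matrix_np(targ, pred, n_classes):
    """Confusion matrix (rows = target, columns = prediction) that skips any pixel whose
    target or prediction is outside range(n_classes), e.g. a void label."""
    targ, pred = np.asarray(targ).ravel(), np.asarray(pred).ravel()
    keep = (targ >= 0) & (targ < n_classes) & (pred >= 0) & (pred < n_classes)
    flat = targ[keep] * n_classes + pred[keep]  # encode each (target, prediction) pair as one index
    return np.bincount(flat, minlength=n_classes * n_classes).reshape(n_classes, n_classes)

def mean_iou(conf):
    tp = np.diagonal(conf).astype(float)
    fp = conf.sum(axis=0) - tp  # predicted as the class but labelled otherwise
    fn = conf.sum(axis=1) - tp  # labelled as the class but predicted otherwise
    with np.errstate(divide="ignore", invalid="ignore"):
        iou = tp / (tp + fp + fn)
    return np.nan_to_num(iou).mean()

# Three real classes plus a void label (3) on the last pixel
conf = confusion_matrix_np([0, 0, 1, 1, 2, 3], [0, 1, 1, 1, 2, 0], n_classes=3)
print(conf)
print(mean_iou(conf))  # (1/2 + 2/3 + 1) / 3
```

Summing the per-batch matrices before computing the score, as MIOU does with self.total, gives a dataset-level mIOU rather than an average of per-batch values. Note that np.nan_to_num turns the IOU of a class absent from both targets and predictions into 0, which still counts towards the mean.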

With our metric classes defined, let’s combine them all. We’ll want an mIOU, as well as an IOU for each of the 11 classes:

metrics = [MIOU(11, axis=1)]

Note: we do not need to pass in an ignore_index here, as confusion_matrix drops any label values larger than 10, including the void class 11.

Now let’s declare our IOUs. Since there are so many, we’ll generate them with a loop over our codes:

for x in range(11): metrics.append(IOU(x, codes[x], axis=1, ignore_index=11))

With this we can finally move over to our model and training:

The Model and Training

For the model we will use a pretrained ResNet34 backbone architecture that has Mish on the head of the Dynamic Unet:

config = unet_config(self_attention=False, act_cls=Mish)
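Mish is defined as x · tanh(softplus(x)), where softplus(x) = log(1 + e^x). A small numpy sketch for intuition, comparing it with ReLU; this is not fastai’s PyTorch Mish module:

```python
import numpy as np

def softplus(x):
    # log(1 + exp(x)), computed stably as logaddexp(0, x)
    return np.logaddexp(0.0, x)

def mish(x):
    """Mish activation: x * tanh(softplus(x))."""
    x = np.asarray(x, dtype=float)
    return x * np.tanh(softplus(x))

def relu(x):
    return np.maximum(0.0, np.asarray(x, dtype=float))

xs = np.array([-3.0, -1.0, 0.0, 1.0, 3.0])
print(relu(xs))
print(mish(xs))
```

Unlike ReLU, Mish is smooth and lets small negative values through (mish(-1) is roughly -0.30), while behaving almost like the identity for large positive inputs.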

Our optimizer will be ranger:

opt_func = ranger

And finally, since we have an ignore_index, we need to pass it into our loss function as well. Otherwise the void label 11 falls outside the model’s 11 outputs (indices 0 to 10) and will trigger a CUDA error: device-side assert triggered.

loss_func = CrossEntropyLossFlat(ignore_index=11, axis=1)
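What ignore_index does is drop the void pixels from the loss entirely: they add nothing to the sum and are not counted in the average. A numpy sketch of that behaviour for a (batch, classes, height, width) logit array, assuming mean reduction; cross_entropy_ignore is a hypothetical helper, not fastai’s CrossEntropyLossFlat:

```python
import numpy as np

def cross_entropy_ignore(logits, targets, ignore_index=11):
    """Mean pixel-wise cross-entropy over the class axis (axis=1), skipping void pixels.

    logits: (N, C, H, W) raw scores; targets: (N, H, W) integer labels.
    """
    # Log-softmax over the class axis, shifted by the max for numerical stability
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    keep = targets != ignore_index
    # Swap void labels for a valid index so the lookup below never goes out of range
    safe_targets = np.where(keep, targets, 0)
    picked = np.take_along_axis(log_probs, safe_targets[:, None], axis=1)[:, 0]
    # Average only over the non-void pixels
    return -picked[keep].mean()
```

Because void labels are replaced by a placeholder before the lookup and then masked out, the label 11 never indexes past the 11 class scores, which is the out-of-range access behind the CUDA assert.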

Now let’s pass this all into unet_learner:

learn = unet_learner(dls, resnet34, metrics=metrics, opt_func=opt_func, 
                     loss_func=loss_func, config=config)

Phase 1

We’ll find a good learning rate, fit for ten epochs frozen with GradientAccumulation to help with stability before unfreezing and training for a few more:

learn.lr_find()

A good learning rate is around 2e-3, so we’ll train with it using fit_flat_cos, since the ranger optimizer is meant to be paired with that schedule:

lr = 2e-3
learn.fit_flat_cos(10, slice(lr), cbs=[GradientAccumulation(n_acc=16)])
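fit_flat_cos keeps the learning rate flat for most of training and then anneals it with a cosine curve. A minimal sketch of that schedule, assuming fastai’s defaults of pct_start=0.75 and div_final=1e5 (worth checking against the docs for your fastai version); flat_cos_lr is a hypothetical helper, not fastai’s implementation:

```python
import math

def flat_cos_lr(pct, lr, pct_start=0.75, div_final=1e5):
    """Learning rate at training progress `pct` (0 to 1) for a flat-then-cosine schedule.

    Flat at `lr` for the first `pct_start` of training, then cosine-annealed
    down to lr / div_final by the end.
    """
    if pct < pct_start:
        return lr
    pos = (pct - pct_start) / (1 - pct_start)  # progress within the annealing phase
    end = lr / div_final
    return end + (lr - end) * (1 + math.cos(math.pi * pos)) / 2

for pct in (0.0, 0.5, 0.75, 0.9, 1.0):
    print(pct, flat_cos_lr(pct, 2e-3))
```

With lr=2e-3, the rate stays at 2e-3 for the first three quarters of training and decays to 2e-8 by the end.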

Next we’ll unfreeze and train for 12 more epochs, using discriminative learning rates that range from lr/400 for the earliest layers to lr/4 for the head:

lrs = slice(lr/400, lr/4)
learn.unfreeze()
learn.fit_flat_cos(12, lrs, cbs=[GradientAccumulation(n_acc=16)])
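Both fits above use GradientAccumulation(n_acc=16). As we understand fastai’s callback, n_acc counts samples rather than batches: gradients keep accumulating until at least n_acc samples have been seen, and only then are the weights updated. The pure-Python sketch below simulates just that counting logic (accumulation_steps is a hypothetical helper, not part of fastai):

```python
def accumulation_steps(batch_sizes, n_acc=16):
    """Return the indices of the batches after which the optimizer steps.

    Assumes n_acc counts samples: gradients accumulate until at least n_acc
    samples have been seen since the last step, then the counter resets.
    """
    steps, count = [], 0
    for i, bs in enumerate(batch_sizes):
        count += bs
        if count >= n_acc:
            steps.append(i)
            count = 0
    return steps

print(accumulation_steps([4] * 8))   # bs=4 as in Phase 1: [3, 7]
print(accumulation_steps([2] * 16))  # bs=2 as in Phase 2: [7, 15]
```

Under that assumption, the weights are updated every 4 batches with bs=4 and every 8 batches with bs=2, so the effective batch size stays at 16 in both phases.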

We’ll save away this model and quickly check how it’s doing on our test set:

learn.save("360")
fnames = get_image_files(path_l/'test')
# with_labels=True also runs get_x/get_mask, pairing each test image with its testannot mask
test_dl = learn.dls.test_dl(fnames, with_labels=True)
results = learn.validate(dl=test_dl)[1:]  # drop the loss, keep the metric values
names = [m.name for m in learn.metrics]
for name, value in zip(names, results):
    print(name, value)

We can see a starting mIOU of 65%, almost matching the mid-tier performers. Let’s see if we can take it further by using the full-sized images.

Phase 2:

First let’s free up our memory:

import gc
del learn
gc.collect()
torch.cuda.empty_cache()

We’ll adjust our transforms to instead keep our full sized images:

item_tfms = [Resize(half), Resize(full)]
batch_tfms = [*aug_transforms(size=full), Normalize.from_stats(*imagenet_stats)]

Then we rebuild the DataBlock and DataLoaders, halving the batch size to 2 to make room for the larger images:

camvid = DataBlock(blocks=(ImageBlock, MaskBlock(codes=codes)),
                   get_items=get_image_files,
                   splitter=FolderSplitter(valid_name='val'),
                   get_x=get_x,
                   get_y=get_mask,
                   item_tfms=item_tfms,
                   batch_tfms=batch_tfms)

dls = camvid.dataloaders(path_l, bs=2)
dls.c = len(codes) - 1

We’ll need to re-declare our metrics, as the current ones still hold state from our last training session:

metrics = [MIOU(11, axis=1)]
for x in range(11): metrics.append(IOU(x, codes[x], axis=1, ignore_index=11))

And now let’s train:

learn = unet_learner(dls, resnet34, metrics=metrics, opt_func=opt_func,
                     config=config, loss_func=loss_func)
learn.load('360');
learn.freeze()

lr = 1e-3
learn.fine_tune(12, lr, cbs=[GradientAccumulation(n_acc=16), EarlyStoppingCallback()])

Let’s check its final results on the test set:

fnames = get_image_files(path_l/'test')
test_dl = learn.dls.test_dl(fnames, with_labels=True)
results = learn.validate(dl=test_dl)[1:]  # drop the loss, keep the metric values
names = [m.name for m in learn.metrics]
for name, value in zip(names, results):
    print(name, value)

Results and Discussion

At first we tried a standard Unet without any special tricks, and we got a test mIOU of around 59%. From this baseline we tried applying Self-Attention, Label Smoothing, and the Mish activation function (as the default is ReLU).

We found that simply applying Mish boosted that 59% to around 64% (the highest we reached was around 67% mIOU). Note that Mish was only applied to the head of the model, not to the ResNet backbone.

Self-attention did not seem to help: even combined with the Mish activation function, it brought the mIOU down to 62%.

Applying label smoothing had a very different effect, visible in the individual IOUs: while the mIOU was not as high as with Mish alone, the distribution of the per-class IOUs changed.

When applying the presizing technique demonstrated here, we saw a boost of 10% mIOU, supporting the idea that simply blowing up your masks to match the original image resolution can diminish the value inside of them.

Conclusions

What conclusions can we actually draw from this study? Not as many as you would think, and the reason lies in current issues in academia. Right now there are three different versions of the dataset being used:

  • fastai images with SegNet masks
  • SegNet images and masks
  • fastai images and labels while ignoring all the other classes

Well… who is right, then? Technically the second and third are right, but the three cannot be compared equally. Remember the benchmark table shown earlier? If you read the papers, each one uses one of these three setups.

So… what can we make of this?

There is one direct conclusion we can draw: using Mish in the head of our Dynamic Unet boosted the mIOU by about 5% in our experiments, so it is absolutely worth trying in your own projects.

Where do we go from here?

A better, much more consistent dataset is CityScapes. It is for research use only, and you must upload your test set predictions to its website, making it essentially a Kaggle competition for researchers, a format I believe works much better. Researchers compare their performance on both the validation set and the test set. This is certainly an easier benchmark for folks to tackle with the fastai Unet, so hopefully one day someone will try it and see how it does!