Decorators

Published

July 6, 2022

An introduction to decorators including when they can be useful and how they’re written

TL;DR

Here’s the basic decorator written in this article (though I recommend still reading and going through it so it makes sense!)

from functools import partial

def addition_decorator(function:callable = None, verbose:bool=False):
    "A simple decorator that will ensure `function` is only called with `function(1,1)`"
    if function is None:
        return partial(addition_decorator, verbose=verbose)
    def decorator(*args, **kwargs):
        while True:
            try:
                return function(*args, **kwargs)
            except ValueError as e:
                is_a = "`a`" in e.args[0]
                a,b = args
                args = (a-1, b) if is_a else (a, b-1)
                if verbose: 
                    print(f'Args are now {args}')
    return decorator

@addition_decorator(verbose=True)
def addition(a,b):
    "Adds `a` and `b` together"
    if a > 1: 
        raise ValueError("`a` is greater than 1")
    if b > 1: 
        raise ValueError("`b` is greater than 1")
    return a+b

What is a decorator?

What is a decorator, and why is it useful?

A decorator can be thought of as code that “wraps” around other code. It encapsulates it, kind of like a blanket! Decorators are extremely helpful for setting up a particular configuration for the “decorated” function (a function that has a decorator applied), or for performing some particular behavior when something occurs in the code it’s wrapped around.

In Python, a decorator is generally applied with an @ symbol followed by the decorator function name, placed directly above the function definition.
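As a minimal sketch of what the @ syntax does (the names `shout`, `greet`, and `greet_plain` are made up for illustration), decorating a function is shorthand for passing it to the decorator and using whatever comes back:

```python
def shout(function):
    "Hypothetical decorator: upper-cases whatever `function` returns"
    def inner(*args, **kwargs):
        return function(*args, **kwargs).upper()
    return inner

@shout
def greet(name):
    return f"hello {name}"

def greet_plain(name):
    return f"hello {name}"

# Applying the decorator by hand is equivalent to using `@shout`
greet_by_hand = shout(greet_plain)

print(greet("world"))          # HELLO WORLD
print(greet_by_hand("world"))  # HELLO WORLD
```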

What is this article going to show?

I’m not going to dive into “how to make your own decorators from scratch” and so forth. There’s a W3 article on this.

Instead, we’re going to focus on writing a decorator the way you’d find most people writing them, a few applications of decorators, and how to understand just what a decorator is doing. What I’m hoping you take away are patterns for writing decorators, so that you can understand them more easily when reading and writing them, as well as when debugging functions to see what outside behaviour is being performed.

Example case: a Retry Decorator

The decorator we’ll focus on is a “retry” loop. Say I have a function and I want to be able to catch a particular exception if it gets raised while the function runs.

From a deep learning perspective, this could be something like a CUDA Out-of-Memory error.

If this Exception has been raised, I want to be able to run the code again slightly modifying one aspect of it to potentially avoid the error being raised.

For this example I’ll make a simple addition function that will only run if neither input is greater than 1, so reducing larger inputs step by step eventually lands on 1+1.

(Does this make sense in the real world? Probably not. But you can get the simple idea!)

def addition(a,b):
    "Adds `a` and `b` together"
    if a > 1: 
        raise ValueError("`a` is greater than 1")
    if b > 1: 
        raise ValueError("`b` is greater than 1")
    return a+b

Now logically let’s think of how we’d want to catch this. We raise the same error type, but how do we know what input to change?

We can read Exception.args to get the actual message being sent, and use it to see which argument we should adjust.
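Here is a small sketch of that lookup on its own. The `check` function is a hypothetical stand-in for `addition`, and matching on message text is only as reliable as the messages themselves, so treat it as an illustration of the idea rather than a robust pattern:

```python
def check(a):
    "Hypothetical stand-in that raises the same style of error as `addition`"
    if a > 1:
        raise ValueError("`a` is greater than 1")
    return a

try:
    check(5)
except ValueError as e:
    # `args` is a tuple of whatever was passed to `ValueError(...)`
    message = e.args[0]
    print(e.args)            # ('`a` is greater than 1',)
    print("`a`" in message)  # True
    print("`b`" in message)  # False
```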

Generally, decorators are written as a function with an inner function. The inner function is what actually runs whenever the decorated function is called. Meanwhile, the decorator takes the function in as its first parameter:

def addition_decorator(function: callable, verbose:bool=False):
    """
    A simple decorator that will ensure `function` is only called with `function(1,1)`
    """
    def decorator(*args, **kwargs):
        # This contains the args and kwargs for our `function`
        # We then do a `while` loop:
        while True:
            try:
                return function(*args, **kwargs)
            except ValueError as e:
                # We can then see if we need to adjust `a` or `b`
                is_a = "`a`" in e.args[0]
                # and then we adjust the current `args` based on this result:
                a,b = args
                # We can also print our attempt here:
                if verbose:
                    print(f'Tried to do {a} + {b}, but at least one argument is greater than 1!')
                if is_a:
                    if verbose:
                        print(f'Reducing `a` by 1: {a-1}')
                    args = (a-1, b)
                else:
                    if verbose:
                        print(f'Reducing `b` by 1: {b-1}')
                    args = (a, b-1)
    # Finally we return the inner function! *Very* important!
    return decorator

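With this simple decorator, we will keep looping and calling function, reducing an argument by 1 after each failure, until neither a nor b is greater than 1 (for integer inputs above 1, that means both end up equal to 1).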

Now how do we actually apply this decorator?

We can do this one of two ways. We can “wrap” around the function and call it as a normal function. For example:

func = addition_decorator(addition)
func(2,2)

Now let’s pass in verbose=True to see just how it was really called:

func = addition_decorator(addition, verbose=True)
func(2,2)

You can see that we continuously reduced an input we passed in and passed it to the function before finally getting the value we want!

Now how do I write this in such a way that doesn’t require me to build this new function func and call it? How can I just call addition and still have addition_decorator?

First we would declare addition_decorator, and then add @addition_decorator to the top of our addition function like below:

@addition_decorator
def addition(a,b):
    "Adds `a` and `b` together"
    if a > 1: 
        raise ValueError("`a` is greater than 1")
    if b > 1: 
        raise ValueError("`b` is greater than 1")
    return a+b

Now when we inspect addition, we can see it points to the inner decorator function:

addition

And when we call addition it will have the same effect:

addition(2,2)

So how is it doing this?

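Placing @addition_decorator above the definition passes our function in as the decorator’s first parameter and rebinds the name addition to whatever comes back, the same as writing addition = addition_decorator(addition). The remaining parameters are left at their default values, however.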
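Because the name addition now points to the inner function, it also reports that function's name and docstring instead of its own. A common remedy, which the decorator in this article does not use, is `functools.wraps`. A minimal sketch, with the hypothetical names `plain` and `preserving`:

```python
from functools import wraps

def plain(function):
    "Hypothetical decorator without `wraps`"
    def decorator(*args, **kwargs):
        return function(*args, **kwargs)
    return decorator

def preserving(function):
    "Hypothetical decorator that keeps the original metadata"
    @wraps(function)  # copies __name__, __doc__, etc. from `function`
    def decorator(*args, **kwargs):
        return function(*args, **kwargs)
    return decorator

@plain
def add_one(x):
    "Adds one to `x`"
    return x + 1

@preserving
def add_two(x):
    "Adds two to `x`"
    return x + 2

print(add_one.__name__, add_one.__doc__)  # decorator None
print(add_two.__name__, add_two.__doc__)  # add_two Adds two to `x`
```

Both versions behave the same when called; the difference only shows up when inspecting or debugging the decorated function.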

Making it a bit more complex, passing in arguments

So how could we configure that verbose argument then?

Here’s where a bit of magic comes in through partial functions.

Partials allow us to create a pre-loaded function with some values already filled in for us. They come from the functools module in the standard library. First we’ll write a small function to test our point:

def subtraction(a,b): 
    "Subtract two numbers"
    return a-b

Next we’ll import partial and create our partial version of subtraction. For any values you wish to fill you should pass them as keyword arguments:

from functools import partial
partial_subtraction = partial(subtraction, a=2)

Now if I call partial_subtraction, I can just pass in b and it will work:

partial_subtraction(b=1)
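To see why the keyword matters here, this sketch (the names `by_keyword` and `by_position` are mine) compares filling `a` by keyword with filling it positionally. Once `a` is frozen by keyword, a positional argument still lands in the `a` slot and Python rejects the call:

```python
from functools import partial

def subtraction(a, b):
    "Subtract two numbers"
    return a - b

by_keyword = partial(subtraction, a=2)
print(by_keyword(b=1))  # 1

# A positional argument would fill `a` a second time, so this fails
try:
    by_keyword(1)
except TypeError as e:
    print(f"TypeError: {e}")

# Filling positionally instead freezes the leading argument
by_position = partial(subtraction, 2)
print(by_position(1))   # 1
```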


We can apply a similar idea to our decorator: if `function` is `None`, we return a partial of the decorator itself with the values we want already filled in.

The reason for this is that `@addition_decorator(verbose=True)` calls the decorator *before* it has been handed a function, so `function` is still `None` at that point and we set up the parameters we want first. The partial that comes back is then applied to the function being decorated. I've added a print to our decorator so you can see when this occurs, and I've also simplified the retry logic to keep the focus on this behavior:

> Note: Since everything can now be passed in as kwargs, each value in the function parameters **must** contain a default of some sort. Usually this would be `None`, which I've done here
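Stripped of the retry logic, the two-step call can be sketched like this. The decorator `tag` and the functions it wraps are hypothetical and exist only to show the order of events:

```python
from functools import partial

def tag(function=None, label="default"):
    "Hypothetical decorator that returns `label` next to the result"
    if function is None:
        # Called as `tag(label=...)`: no function yet, so hand back
        # a version of `tag` with `label` already filled in
        return partial(tag, label=label)
    def decorator(*args, **kwargs):
        return (label, function(*args, **kwargs))
    return decorator

@tag(label="custom")
def double(x):
    return x * 2

def triple(x):
    return x * 3

# What the `@tag(label="custom")` line does, written out by hand
configured = tag(label="custom")    # step 1: returns a partial of `tag`
triple_tagged = configured(triple)  # step 2: the partial receives the function

@tag  # no parentheses: `function` is filled in directly, defaults apply
def square(x):
    return x * x

print(double(2))         # ('custom', 4)
print(triple_tagged(2))  # ('custom', 6)
print(square(3))         # ('default', 9)
```

The full decorator follows the same shape.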

```python
def addition_decorator(function:callable = None, verbose:bool=False):
    "A simple decorator that will ensure `function` is only called with `function(1,1)`"
    if function is None:
        # No function yet: return this decorator with `verbose` already filled in
        print(f'Creating a new `addition_decorator` function with verbose={verbose}')
        return partial(addition_decorator, verbose=verbose)
    def decorator(*args, **kwargs):
        while True:
            try:
                return function(*args, **kwargs)
            except ValueError as e:
                is_a = "`a`" in e.args[0]
                a,b = args
                if is_a:
                    if verbose:
                        print(f'Reducing `a` by 1: {a-1}')
                    args = (a-1, b)
                else:
                    if verbose:
                        print(f'Reducing `b` by 1: {b-1}')
                    args = (a, b-1)
    # Finally we return the inner function! *Very* important!
    return decorator
```

To set this parameter, we pass it in on the @ line itself. Once we’ve declared our addition function, we should see the print statement immediately:

@addition_decorator(verbose=True)
def addition(a,b):
    "Adds `a` and `b` together"
    if a > 1: 
        raise ValueError("`a` is greater than 1")
    if b > 1: 
        raise ValueError("`b` is greater than 1")
    return a+b

And now if we call addition:

addition(2,2)

We can see that it prints out the same information we had earlier!

You now know almost enough to be on your way with decorators

The most extreme example with nonlocal.

I’m going to walk through a real example of what nonlocal (that weird Python keyword almost no one uses?) can actually do for you.

I chose a real example because I couldn’t think of a simpler made-up situation that shows it as well.

The example is the “CUDA out of memory” case I mentioned earlier, as handled in accelerate.

The API can be thought of as so:

  1. Write a training function that takes a batch_size as the first argument
  2. Decorate this training function with the find_executable_batch_size decorator
  3. Have it continuously try to run, and if CUDA OOM is hit, retry the loop after cutting the batch size in half.

Let’s see how this is implemented:

> Note: this will be a simplified version of the official decorator for teaching purposes

import gc, torch
from functools import partial

def find_executable_batch_size(function:callable = None, starting_batch_size:int = 128):
    """
    Decorator that will attempt to execute `function` with `starting_batch_size`.
    If CUDA Out-of-Memory is reached, the batch size is reduced in half and tried again.

    `function` must take in `batch_size` as its first parameter
    """
    if function is None:
        return partial(find_executable_batch_size, starting_batch_size=starting_batch_size)
    
    # Keep a `batch_size` variable that gets updated and modified
    batch_size = starting_batch_size

    def decorator(*args, **kwargs):
        # Bring it into context
        nonlocal batch_size
        gc.collect()
        torch.cuda.empty_cache()

        while True:
            if batch_size == 0:
                raise RuntimeError("No executable batch size found")
            try:
                return function(batch_size, *args, **kwargs)
            except RuntimeError as e:
                if "CUDA out of memory" in e.args[0]:
                    # We reduce the batch size and clear the memory
                    gc.collect()
                    torch.cuda.empty_cache()
                    batch_size //= 2
                else:
                    # Otherwise raise the original error
                    raise

    return decorator

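Here batch_size starts out at starting_batch_size and lives in the decorator’s enclosing scope. Declaring it nonlocal lets the inner function reassign it, so it changes and adapts based on what happened in the try/except loop above, and the reduced value is kept for later calls to the decorated function.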
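To watch that value change without needing a GPU, here is a torch-free sketch of the same pattern. Everything in it is a stand-in I made up for illustration: `halve_until_it_fits` mirrors the decorator above, and `fake_train` simulates an out-of-memory error for any size above 16:

```python
from functools import partial

def halve_until_it_fits(function=None, starting_size: int = 128):
    "Torch-free stand-in for the decorator above; all names are hypothetical"
    if function is None:
        return partial(halve_until_it_fits, starting_size=starting_size)

    size = starting_size  # lives in the enclosing scope, shared across calls

    def decorator(*args, **kwargs):
        # Without this, assigning to `size` below would make it a local
        # variable and the function would fail with UnboundLocalError
        nonlocal size
        while True:
            if size == 0:
                raise RuntimeError("No executable size found")
            try:
                return function(size, *args, **kwargs)
            except RuntimeError as e:
                if "out of memory" in e.args[0]:
                    size //= 2
                else:
                    raise
    return decorator

attempts = []

@halve_until_it_fits(starting_size=128)
def fake_train(size):
    "Pretends anything above 16 does not fit in memory"
    attempts.append(size)
    if size > 16:
        raise RuntimeError("simulated out of memory")
    return size

print(fake_train())  # 16
print(attempts)      # [128, 64, 32, 16]
print(fake_train())  # 16
print(attempts)      # [128, 64, 32, 16, 16]
```

The second call starts straight at 16, because the reduced value is remembered between calls rather than reset to the starting size.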

Conclusion

Hopefully this gave you some better insight into decorators! My next article will discuss context managers and when you should use one versus the other. Thanks for reading!