Undoing a wrapped function

python
How to pickle an object that may have been wrapped
Published

11/20/22

Motivation

A few days ago there was an issue in Accelerate where a prepared model couldn’t be pickled. At first I thought it had to do with how in Accelerate we have a wrapper function that will make the model return FP32 outputs even on FP16.

The Meat

Okay so, let’s pretend we have the following situation in Python:

#| language: python
class MathClass:
    "A super basic class that performs math"
    def __init__(self, a:int):
        self.a = a
        
    def addition(self, b):
        return self.a+b
    
    def subtraction(self, b):
        return self.a-b

Let’s then say that I want to wrap the addition function of this class in another function that takes the output of addition and divides it by two (logic-wise it doesn’t make sense, but code-wise it does).

We can use functools.wraps to do this:

#| language: python
from functools import wraps

def addition_with_div(addition_func):
    @wraps(addition_func)
    def inner(*args, **kwargs):
        result = addition_func(*args, **kwargs)
        return result / 2
    return inner

    @wraps(addition_func)

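A decorator that takes in the original function and copies its metadata (name, docstring, and so on) onto the inner function defined under it. It also stores the original function on the wrapper as __wrapped__, which is what lets us undo the wrapping later.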
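To make that concrete, here is a small sketch of what wraps leaves behind on a wrapper. The shout and greet names are made up for illustration and have nothing to do with the Accelerate code:

```python
from functools import wraps

def shout(func):  # hypothetical decorator, only for illustration
    @wraps(func)
    def inner(*args, **kwargs):
        return func(*args, **kwargs).upper()
    return inner

def greet(name):
    "Return a greeting"
    return f"hello {name}"

wrapped = shout(greet)

print(wrapped.__name__)              # greet, copied from the original
print(wrapped.__doc__)               # Return a greeting
print(wrapped.__wrapped__ is greet)  # True, a reference back to the original
print(wrapped("world"))              # HELLO WORLD
```

The __wrapped__ attribute is the important one for the rest of this post.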


    def inner(*args, **kwargs):
        result = addition_func(*args, **kwargs)
        return result / 2

Inside the decorated function we pass all the args and kwargs to the original function and return its result divided by 2.

And finally use it:

#| language: python
math = MathClass(a=2)
math.addition = addition_with_div(math.addition)

Now if we try and do math.addition we get:

#| language: python
math.addition(10)

(2 + 10) divided by 2, which is 6.0, like we expect! But then what did I have to solve? Pickling.

Pickling, the beloved destructor

Let’s try pickling this (and use torch because I’m lazy):

#| language: python
import torch
torch.save(math, "mymaththing.pth")

Running this gives us a weird pickling error. For the life of me I couldn’t figure out why, until I finally did.
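If you don’t have torch handy, the same failure can be reproduced with the standard library’s pickle module, which torch.save uses by default. This is only a sketch: the exact exception type and message depend on the Python version, so it catches the two likely candidates and prints whatever comes back. The underlying reason is that pickle saves functions by reference to an importable name, and the inner function created inside addition_with_div can’t be found under the name it claims to have.

```python
import pickle
from functools import wraps

class MathClass:
    def __init__(self, a: int):
        self.a = a

    def addition(self, b):
        return self.a + b

def addition_with_div(addition_func):
    @wraps(addition_func)
    def inner(*args, **kwargs):
        return addition_func(*args, **kwargs) / 2
    return inner

math = MathClass(a=2)
math.addition = addition_with_div(math.addition)

# The wrapper now lives in the instance __dict__, so pickle tries to save it
try:
    pickle.dumps(math)
    error = None
except (pickle.PicklingError, AttributeError) as e:
    error = e

print(type(error).__name__, error)
```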

To save our object, I needed to remove the wrappers I had added, as they weren’t needed in the end result. Here’s how I did so:

#| language: python
import pickle
from functools import update_wrapper

class AdditionWithDiv:
    """
    Decorator which will perform addition then divide the result by two
    """
    def __init__(self, addition_func):
        self.addition_func = addition_func
        update_wrapper(self, addition_func)
    
    def __call__(self, *args, **kwargs):
        result = self.addition_func(*args, **kwargs)
        return result / 2

    def __getstate__(self):
        raise pickle.PicklingError(
            "This wrapper cannot be pickled! Remove it before doing so"
        )
        
addition_with_div = AdditionWithDiv

    def __init__(self, addition_func):
        self.addition_func = addition_func
        update_wrapper(self, addition_func)

The init function first stores the function and then calls functools.update_wrapper to copy the metadata of addition_func onto self and set self.__wrapped__. It’s the same thing that functools.wraps did for us, but we can make use of a custom class instead.
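A quick way to check that claim is to look at what ends up on an instance after update_wrapper runs. The Halve class and add function here are hypothetical stand-ins, kept tiny on purpose:

```python
from functools import update_wrapper

class Halve:  # hypothetical wrapper class, only for illustration
    def __init__(self, func):
        self.func = func
        update_wrapper(self, func)

    def __call__(self, *args, **kwargs):
        return self.func(*args, **kwargs) / 2

def add(a, b):
    "Add two numbers"
    return a + b

halved = Halve(add)

print(halved.__name__)            # add
print(halved.__doc__)             # Add two numbers
print(halved.__wrapped__ is add)  # True
print(halved(2, 10))              # 6.0
```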


    def __call__(self, *args, **kwargs):
        result = self.addition_func(*args, **kwargs)
        return result / 2

Here we do what inner did earlier: get our result and divide it by two.


    def __getstate__(self):
        raise pickle.PicklingError(
            "This wrapper cannot be pickled! Remove it before doing so"
        )

This is a very important custom error that is raised when someone tries to pickle this object, letting them know that the wrapper cannot be pickled and has to be removed first. It replaces the weird error from before, which told us nothing.
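To see the message without torch, the wrapper can be handed straight to the standard library’s pickle. A minimal sketch, using a plain add function as a stand-in for the method being wrapped:

```python
import pickle
from functools import update_wrapper

class AdditionWithDiv:
    def __init__(self, addition_func):
        self.addition_func = addition_func
        update_wrapper(self, addition_func)

    def __call__(self, *args, **kwargs):
        return self.addition_func(*args, **kwargs) / 2

    def __getstate__(self):
        raise pickle.PicklingError(
            "This wrapper cannot be pickled! Remove it before doing so"
        )

def add(a, b):  # stand-in for the wrapped method
    return a + b

wrapped = AdditionWithDiv(add)

# pickle asks the object for its state, which triggers our custom error
try:
    pickle.dumps(wrapped)
    message = None
except pickle.PicklingError as e:
    message = str(e)

print(message)
```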

#| language: python
math = MathClass(a=2)
math.addition = addition_with_div(math.addition)
#| language: python
torch.save(math, "mymaththing.pth")

Better, now to remove the wrapper:

#| language: python
math = MathClass(a=2)
math._original_addition = math.addition
math.addition = addition_with_div(math.addition)

addition = getattr(math, "addition")
original_addition = math.__dict__.pop("_original_addition", None)
if original_addition is not None:
    while hasattr(addition, "__wrapped__"):
        if addition != original_addition:
            addition = addition.__wrapped__
        else:
            break
    math.addition = addition

math = MathClass(a=2)
math._original_addition = math.addition
math.addition = addition_with_div(math.addition)

We instantiate a new MathClass object and keep a reference to the original addition function before wrapping it in our addition_with_div.


addition = getattr(math, "addition")
original_addition = math.__dict__.pop("_original_addition", None)

We need to extract both the wrapped addition function and, if it exists, the _original_addition reference we stored earlier. Popping it from __dict__ also removes that extra attribute from the object.


    while hasattr(addition, "__wrapped__"):
        if addition != original_addition:
            addition = addition.__wrapped__
        else:
            break

We traverse the layers of __wrapped__ functions (as there can be several of them, such as d(c(b(a)))) and, as long as the current function isn’t the same as the original, step down to the next __wrapped__ reference and keep going.


    math.addition = addition

Finally, we set math.addition to the function we found.
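For comparison, the standard library’s inspect.unwrap follows the same __wrapped__ chain, so the loop could also be sketched with it. This is an alternative to the hand-written loop rather than the approach used above, and halve is a made-up decorator:

```python
import inspect
from functools import wraps

class MathClass:
    def __init__(self, a: int):
        self.a = a

    def addition(self, b):
        return self.a + b

def halve(func):  # hypothetical decorator, only for illustration
    @wraps(func)
    def inner(*args, **kwargs):
        return func(*args, **kwargs) / 2
    return inner

math = MathClass(a=2)
original_addition = math.addition
math.addition = halve(halve(math.addition))  # two layers of wrapping
wrapped_result = math.addition(10)           # (2 + 10) / 2 / 2

# Follow __wrapped__ until we are back at the original bound method
unwrapped = inspect.unwrap(math.addition, stop=lambda f: f == original_addition)
math.addition = unwrapped

print(wrapped_result)     # 3.0
print(math.addition(10))  # 12
```

The stop callback matters when the original function was itself decorated somewhere else, since without it unwrap keeps going until nothing has a __wrapped__ attribute.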

We now have the old function again and can pickle it!

#| language: python
torch.save(math, "mymaththing.pth")