Undoing a wrapped function
Motivation
A few days ago there was an issue in Accelerate where a prepared model couldn’t be pickled. At first I thought it had to do with the wrapper function Accelerate uses to make the model return FP32 outputs even when running in FP16.
The Meat
Okay so, let’s pretend we have the following situation in PyTorch:
#| language: python
class MathClass:
    "A super basic class that performs math"
    def __init__(self, a:int):
        self.a = a

    def addition(self, b):
        return self.a+b

    def subtraction(self, b):
        return self.a-b
Let’s then say that I want to wrap the addition function of this class with a function that takes the output of addition and divides it by two (logic-wise it doesn’t make sense, but code-wise it does). We can use functools.wraps to do this:
#| language: python
from functools import wraps

def addition_with_div(addition_func):
    @wraps(addition_func)
    def inner(*args, **kwargs):
        result = addition_func(*args, **kwargs)
        return result / 2
    return inner
@wraps(addition_func)
A decorator that copies the metadata of the original function (such as its name and docstring) onto the function defined under it, and stores a reference to the original function in that function's __wrapped__ attribute.
def inner(*args, **kwargs):
    result = addition_func(*args, **kwargs)
    return result / 2
Inside the decorated function we pass all the args to the original function and return its result divided by 2.
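To make the effect of functools.wraps concrete, here is a minimal sketch. The names `halve` and `add` are hypothetical and exist only for this illustration; the point is that the wrapper takes on the original's name and docstring and keeps a `__wrapped__` reference back to it.

```python
from functools import wraps

def halve(func):
    # Hypothetical decorator, used only for this illustration
    @wraps(func)
    def inner(*args, **kwargs):
        return func(*args, **kwargs) / 2
    return inner

def add(a, b):
    "Adds two numbers"
    return a + b

wrapped_add = halve(add)

print(wrapped_add.__name__)            # add (copied from the original, not "inner")
print(wrapped_add.__doc__)             # Adds two numbers
print(wrapped_add.__wrapped__ is add)  # True
print(wrapped_add(2, 10))              # 6.0
```

That `__wrapped__` attribute is what makes it possible to find the original function again later on.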
And finally use it:
#| language: python
math = MathClass(a=2)
math.addition = addition_with_div(math.addition)
Now if we call math.addition we get:
#| language: python
math.addition(10)
(2 + 10) divided by 2 gives 6.0, like we expect! But then what did I have to solve? Pickling.
Pickling, the beloved destructor
Let’s try pickling this (and use torch because I’m lazy):
#| language: python
import torch
torch.save(math, "mymaththing.pth")
This fails with a weird pickling error. For the life of me I couldn’t figure out why, until I finally did.
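The same failure can be reproduced without PyTorch. The sketch below uses only the standard pickle module and assumes the code runs as a plain script; the exact exception type and message depend on the Python version, so it only records that an error occurred.

```python
import pickle
from functools import wraps

class MathClass:
    def __init__(self, a: int):
        self.a = a

    def addition(self, b):
        return self.a + b

def addition_with_div(addition_func):
    @wraps(addition_func)
    def inner(*args, **kwargs):
        return addition_func(*args, **kwargs) / 2
    return inner

math = MathClass(a=2)
math.addition = addition_with_div(math.addition)

# `inner` now carries the qualified name `MathClass.addition`, but it is a
# different object from the real method, so pickle cannot save it by reference.
try:
    pickle.dumps(math)
    error = None
except Exception as exc:  # exact type and message vary by Python version
    error = exc

print(type(error).__name__, error)
```

The wrapped function lives in the instance's `__dict__`, so pickling the object means pickling the wrapper too, and that is where it breaks.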
To save our object, I needed to remove the wrappers I had added, as they weren’t needed in the end result. Here’s how I did so:
#| language: python
import pickle
from functools import update_wrapper

class AdditionWithDiv:
    """
    Decorator which will perform addition then divide the result by two
    """
    def __init__(self, addition_func):
        self.addition_func = addition_func
        update_wrapper(self, addition_func)

    def __call__(self, *args, **kwargs):
        result = self.addition_func(*args, **kwargs)
        return result / 2

    def __getstate__(self):
        raise pickle.PicklingError(
            "This wrapper cannot be pickled! Remove it before doing so"
        )

addition_with_div = AdditionWithDiv
def __init__(self, addition_func):
    self.addition_func = addition_func
    update_wrapper(self, addition_func)
The init function first stores the function, then calls functools.update_wrapper to wrap self around addition_func. It’s the same thing that functools.wraps did for us, but we can make use of a custom class instead.
def __call__(self, *args, **kwargs):
    result = self.addition_func(*args, **kwargs)
    return result / 2
Here we do what inner did earlier: get our result and divide it by two.
def __getstate__(self):
    raise pickle.PicklingError(
        "This wrapper cannot be pickled! Remove it before doing so"
    )
This is a very important custom error that is raised when someone tries to pickle this object, letting them know that this shouldn’t happen and cannot be done. It replaces the weird error from before, which told us nothing, with a clear one.
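As a quick check of that behaviour, here is a small sketch using plain pickle rather than torch.save. `HalvingWrapper` and `add` are hypothetical stand-ins that mirror the class above; the sketch shows the wrapper still works as a callable, keeps its `__wrapped__` reference, and fails with the custom message when pickled.

```python
import pickle
from functools import update_wrapper

class HalvingWrapper:
    """Hypothetical wrapper class, mirroring AdditionWithDiv"""
    def __init__(self, func):
        self.func = func
        update_wrapper(self, func)  # copies metadata and sets self.__wrapped__

    def __call__(self, *args, **kwargs):
        return self.func(*args, **kwargs) / 2

    def __getstate__(self):
        raise pickle.PicklingError(
            "This wrapper cannot be pickled! Remove it before doing so"
        )

def add(a, b):
    return a + b

wrapped = HalvingWrapper(add)
print(wrapped(2, 10))              # 6.0
print(wrapped.__wrapped__ is add)  # True

try:
    pickle.dumps(wrapped)
    message = None
except pickle.PicklingError as exc:
    message = str(exc)

print(message)  # This wrapper cannot be pickled! Remove it before doing so
```

Raising from `__getstate__` works because pickle asks the object for its state before writing anything, so the error surfaces immediately and with our own wording.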
#| language: python
math = MathClass(a=2)
math.addition = addition_with_div(math.addition)
#| language: python
torch.save(math, "mymaththing.pth")
Better, now to remove the wrapper:
#| language: python
math = MathClass(a=2)
math._original_addition = math.addition
math.addition = addition_with_div(math.addition)

addition = getattr(math, "addition")
original_addition = math.__dict__.pop("_original_addition", None)
if original_addition is not None:
    while hasattr(addition, "__wrapped__"):
        if addition != original_addition:
            addition = addition.__wrapped__
        else:
            break
    math.addition = addition
math = MathClass(a=2)
math._original_addition = math.addition
math.addition = addition_with_div(math.addition)
We instantiate a new MathClass object and store a reference to the original addition function before wrapping it in our addition_with_div.
addition = getattr(math, "addition")
original_addition = math.__dict__.pop("_original_addition", None)
We need to get both the wrapped addition function and the _original_addition reference, if it exists.
while hasattr(addition, "__wrapped__"):
    if addition != original_addition:
        addition = addition.__wrapped__
    else:
        break
We traverse the layers of __wrapped__ functions (as this can go several levels deep, such as d(c(b(a())))), and as long as the current function isn’t the same as the original, we take the function it wraps and keep going.
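The traversal can be tried on its own with a few stacked wrappers. In this sketch `passthrough` and `a` are hypothetical names; it walks the `__wrapped__` chain by hand, as in the loop above, and compares the result with the standard library's inspect.unwrap, which follows the same chain.

```python
import inspect
from functools import wraps

def passthrough(func):
    # Hypothetical decorator that adds one layer of wrapping
    @wraps(func)
    def inner(*args, **kwargs):
        return func(*args, **kwargs)
    return inner

def a():
    return "a"

d = passthrough(passthrough(passthrough(a)))  # three layers, like d(c(b(a)))

# Manual traversal of the __wrapped__ chain
func = d
depth = 0
while hasattr(func, "__wrapped__"):
    func = func.__wrapped__
    depth += 1

print(depth)                   # 3
print(func is a)               # True
print(inspect.unwrap(d) is a)  # True
```

inspect.unwrap always goes to the bottom of the chain unless you pass its `stop` callback, whereas the loop in this post stops as soon as it reaches the stored original, which matters if the original was itself already wrapped.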
math.addition = addition
Finally, set the function to the addition we found.
We now have the old function again and can pickle it!
#| language: python
torch.save(math, "mymaththing.pth")