Lesson 1, The forward and backward passes (part 2)

Building our first model

Initial Helpers



normalize

 normalize (x, mean, std)


get_data

 get_data ()
x_train,y_train,x_valid,y_valid = get_data()
train_mean,train_std = x_train.mean(), x_train.std()
train_mean,train_std
(tensor(0.1304), tensor(0.3073))

We normalize both the training and validation sets using the training set's mean and standard deviation (so the training data ends up with mean ~ 0, std ~ 1), which keeps both sets on the same scale. If we normalized each set with its own statistics, they would effectively become two different datasets rather than two parts of the same one.
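Here is a minimal sketch of the idea. It assumes that `normalize` simply computes `(x - mean) / std`, since the helper's body isn't shown above. It uses random stand-in tensors rather than MNIST.

```python
import torch

def normalize(x, mean, std):
    # Assumed body of the helper: shift and scale by the statistics passed in
    return (x - mean) / std

torch.manual_seed(0)
# Stand-in data: both splits drawn from the same distribution (mean 0.13, std 0.31)
x_train = torch.randn(5000, 784) * 0.31 + 0.13
x_valid = torch.randn(1000, 784) * 0.31 + 0.13

# Statistics come from the training set only...
train_mean, train_std = x_train.mean(), x_train.std()
# ...and the same transformation is applied to both splits
x_train_n = normalize(x_train, train_mean, train_std)
x_valid_n = normalize(x_valid, train_mean, train_std)

print(x_train_n.mean().item(), x_train_n.std().item())  # ~0 and ~1 by construction
print(x_valid_n.mean().item(), x_valid_n.std().item())  # close to 0 and 1 as well
```

The validation set only lands near (0, 1) because it shares the training distribution. It is not forced there, and that is the point: both sets go through the exact same transformation.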

x_train = normalize(x_train, train_mean, train_std)
x_valid = normalize(x_valid, train_mean, train_std)
train_mean,train_std = x_train.mean(), x_train.std()
train_mean,train_std
(tensor(2.1425e-08), tensor(1.))


test_near_zero

 test_near_zero (a, tol=0.001)
test_near_zero(x_train.mean())
test_near_zero(1-x_train.std())
n,m = x_train.shape
c = y_train.max()+1

n

Size of the training set


m

The length of one flattened input image (28 * 28 = 784 pixels)


c

The number of classes, i.e. the number of output activations we will eventually classify with

n,m,c
(50000, 784, tensor(10))

Foundations version

Basic architecture

  • One hidden layer
  • Mean squared error rather than cross entropy, to keep things simple

We initialize the weights with a simplified version of Kaiming init (also called He init)

nh = 50
w1 = torch.randn(m,nh)/math.sqrt(m)
b1 = torch.zeros(nh)
w2 = torch.randn(nh,1)/math.sqrt(nh)
b2 = torch.zeros(1)

nh

The number of nodes in our fully-connected hidden layer


w1

The weight matrix of the first layer, shape (784, 50)


b1

The bias of the first layer


w2

The weight matrix of the second layer, shape (50, 1)


b2

The bias of the second layer


torch.randn(a,b)/math.sqrt(a)

Simplified Kaiming/He init: standard normal weights divided by the square root of the fan-in a (the number of inputs to the layer)
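Why divide by sqrt(m)? Each output of `x @ w` is a sum of m products, so with unit-variance inputs and unit-variance weights its variance is about m. The quick check below uses random stand-in inputs, which is an assumption: real MNIST pixels are not Gaussian.

```python
import math
import torch

torch.manual_seed(0)
m, nh = 784, 50
x = torch.randn(2000, m)                      # stand-in inputs: mean ~0, std ~1

w_plain = torch.randn(m, nh)                  # unscaled standard normal weights
w_scaled = torch.randn(m, nh) / math.sqrt(m)  # scaled by 1/sqrt(fan_in)

plain_std = (x @ w_plain).std().item()
scaled_std = (x @ w_scaled).std().item()
print(plain_std)   # roughly sqrt(784) = 28
print(scaled_std)  # roughly 1
```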

w1.shape, b1.shape, w2.shape, b2.shape
(torch.Size([784, 50]), torch.Size([50]), torch.Size([50, 1]), torch.Size([1]))
test_near_zero(w1.mean())
test_near_zero(w1.std()-1/math.sqrt(m))
# This should be ~ (0,1) (mean,std)
x_valid.mean(),x_valid.std()
(tensor(-0.0059), tensor(0.9924))
def lin(inp, weight, bias): return inp@weight + bias
t = lin(x_valid, w1, b1)
# So should this because we used kaiming init which is designed to have this effect
t.mean(), t.std()
(tensor(-0.0417), tensor(1.0341))
def relu(inp): return inp.clamp_min(0.)

.clamp_min

A ReLU activation turns every negative value into zero, which is exactly what clamp_min(0.) does

While there are other ways of writing that, if you can find a function attached to a tensor for the thing you want to do, it will almost always be faster because it will be written in C - Jeremy Howard
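A quick check that `clamp_min(0.)` really is ReLU, comparing it against PyTorch's built-in `torch.relu` and against an explicit `torch.where` version of the same rule:

```python
import torch

x = torch.tensor([-2.0, -0.5, 0.0, 1.5, 3.0])
relu_clamp = x.clamp_min(0.)                              # negatives become 0
relu_where = torch.where(x > 0, x, torch.zeros_like(x))   # the same rule, spelled out
print(relu_clamp)
print(torch.equal(relu_clamp, torch.relu(x)), torch.equal(relu_clamp, relu_where))  # True True
```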

t = relu(lin(x_valid, w1, b1))
t.mean(), t.std()
(tensor(0.3898), tensor(0.5947))

Uh oh! What went wrong?

ReLU sets every negative value to zero. Roughly half of our activations were negative, so we threw away a lot of information. As a result, the mean and standard deviation swung far from 0 and 1.

\[\operatorname{std}=\sqrt{\frac{2}{\left(1+a^2\right) \times \text{fan\_in}}}\]

Here a is the negative slope of a leaky ReLU, so a = 0 for a plain ReLU. The solution is simply to put a 2 in the numerator instead of 1:
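Plugging in the numbers for our first layer (fan_in = m = 784). Here `kaiming_std` is a hypothetical helper written for illustration. `torch.nn.init.calculate_gain` is PyTorch's function for the sqrt(2 / (1 + a^2)) factor, which it calls the gain.

```python
import math
from torch.nn import init

def kaiming_std(fan_in, a=0.0):
    # Hypothetical helper: std = sqrt(2 / ((1 + a^2) * fan_in)),
    # where a is the negative slope of a leaky ReLU (a = 0 for a plain ReLU)
    return math.sqrt(2.0 / ((1 + a ** 2) * fan_in))

fan_in = 784
print(kaiming_std(fan_in))          # sqrt(2/784), about 0.0505
print(kaiming_std(fan_in, a=0.01))  # leaky ReLU with a small slope: nearly the same

# PyTorch's gain for ReLU is sqrt(2); dividing by sqrt(fan_in) gives the same std
print(init.calculate_gain("relu") / math.sqrt(fan_in))
```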

std = math.sqrt(2/m)
w1 = torch.randn(m,nh)*std
t = relu(lin(x_valid, w1,b1))

t.mean(), t.std()
(tensor(0.5535), tensor(0.8032))

This brings the standard deviation much closer to 1, but the mean is now around 0.5, because ReLU still deletes everything below zero
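Why does the mean land near 0.55? With unit-variance inputs and weights drawn with std sqrt(2/m), each pre-activation z is roughly N(0, 2). For z ~ N(0, sigma^2), E[ReLU(z)] = sigma / sqrt(2*pi) and E[ReLU(z)^2] = sigma^2 / 2. With sigma^2 = 2 that gives a mean of 1/sqrt(pi), about 0.56, and a std of sqrt(1 - 1/pi), about 0.83. Both are close to the numbers above. The simulation below checks this under an idealized assumption: the pre-activations are exactly Gaussian.

```python
import math
import torch

torch.manual_seed(0)
sigma = math.sqrt(2.0)                 # std of the pre-activations under this init
z = torch.randn(1_000_000) * sigma     # idealized Gaussian pre-activations
r = z.clamp_min(0.)                    # ReLU

print(r.mean().item(), 1 / math.sqrt(math.pi))     # both about 0.564
print(r.std().item(), math.sqrt(1 - 1 / math.pi))  # both about 0.826
print((r - 0.5).mean().item())                     # shifting by 0.5 leaves about 0.064
```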

# What if...?
def relu_v2(x): return x.clamp_min(0.) - 0.5
def relu_v3(x): return (torch.pow(x.clamp_min(0.), 0.9)) - 0.5
w1 = torch.randn(m,nh)*std
t = relu_v2(lin(x_valid, w1,b1))

t.mean(), t.std()
(tensor(0.0372), tensor(0.8032))
t = relu_v3(lin(x_valid, w1,b1))

t.mean(), t.std()
(tensor(0.0181), tensor(0.7405))

Jeremy tried subtracting 0.5 from the ReLU output, and it does seem to help bring the mean back toward zero.

How well does this work in practice? To test it, I should build a very basic CNN, train it on ImageWoof, and vary only the ReLU variant used.

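from torch.nn import init
w1 = torch.zeros(m,nh)
# kaiming_normal_ reads dim 0 of a 2-D tensor as fan_out, which is m for our (m, nh) layout
init.kaiming_normal_(w1, mode="fan_out")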
t = relu(lin(x_valid, w1, b1))
w1.mean(),w1.std()
(tensor(9.4735e-05), tensor(0.0506))
t.mean(),t.std()
(tensor(0.4818), tensor(0.7318))
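The `fan_out` choice above looks backwards, but it follows from how PyTorch computes fans. For a 2-D tensor it treats dim 1 as fan_in and dim 0 as fan_out, because `nn.Linear` stores its weight as (out_features, in_features). Our `w1` is laid out as (m, nh), so `mode="fan_out"` is the one that divides by m = 784. A small check with freshly initialized tensors:

```python
import math
import torch
from torch.nn import init

torch.manual_seed(0)
m, nh = 784, 50

# Our layout: (m, nh), used as inp @ w
w_fan_out = torch.zeros(m, nh)
init.kaiming_normal_(w_fan_out, mode="fan_out")  # fan taken from dim 0 = m
w_fan_in = torch.zeros(m, nh)
init.kaiming_normal_(w_fan_in, mode="fan_in")    # fan taken from dim 1 = nh

# nn.Linear's layout: (nh, m) = (out_features, in_features)
w_linear_layout = torch.zeros(nh, m)
init.kaiming_normal_(w_linear_layout)           # defaults: a=0, mode="fan_in"

print(w_fan_out.std().item(), math.sqrt(2 / m))        # both about 0.0505
print(w_fan_in.std().item(), math.sqrt(2 / nh))        # both about 0.2
print(w_linear_layout.std().item(), math.sqrt(2 / m))  # both about 0.0505
```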
w1 = torch.randn(m,nh)*math.sqrt(2./m)
t = relu_v2(lin(x_valid, w1,b1))

t.mean(), t.std()
(tensor(-0.0279), tensor(0.7500))
t = relu_v3(lin(x_valid, w1,b1))

t.mean(), t.std()
(tensor(-0.0422), tensor(0.6948))
def model(xb, v2=True):
    l1 = lin(xb, w1, b1)
    l2 = relu_v2(l1) if v2 else relu_v3(l1)
    l3 = lin(l2, w2, b2)
    return l3
2.56 ms ± 578 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
3.3 ms ± 104 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
assert model(x_valid).shape == torch.Size([x_valid.shape[0],1])

Loss function: MSE

model(x_valid).shape
torch.Size([10000, 1])
def mse(output, targ): return (output.squeeze(-1) - targ).pow(2).mean()


mse

 mse (output, targ)

.squeeze()

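The opposite of unsqueeze: it removes a dimension of size 1. We use it to drop the trailing [1] from the model's output, turning shape [n, 1] into [n]

Note: it is safer to pass the dimension, e.g. squeeze(-1) or squeeze(1), than to call squeeze() with no arguments, which removes every size-1 dimension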
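To see why the squeeze matters, and why passing the dimension is safer, here is a small shape check with random stand-in tensors. Without the squeeze, broadcasting silently turns an (n, 1) minus (n,) subtraction into an (n, n) matrix.

```python
import torch

preds = torch.randn(5, 1)   # model output: shape (n, 1)
targ = torch.randn(5)       # targets: shape (n,)

print((preds - targ).shape)              # torch.Size([5, 5]): broadcasting, not what we want
print((preds.squeeze(-1) - targ).shape)  # torch.Size([5])

# squeeze() drops *every* size-1 dim, which bites when the batch has one item
single = torch.randn(1, 1)
print(single.squeeze().shape)    # torch.Size([]): a scalar
print(single.squeeze(-1).shape)  # torch.Size([1])
```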

y_train,y_valid = y_train.float(),y_valid.float()
preds_a = model(x_train)
preds_b = model(x_train,False)
preds_a.shape
torch.Size([50000, 1])
mse(preds_a, y_train)
tensor(28.0614)
mse(preds_b, y_train)
tensor(27.8693)

Gradients and backward pass

Chain rule, chain rule, chain rule!

Start with our last function and go backwards:

def mse_grad(inp, targ):
    # grad of loss with respect to output of previous layer
    inp.g = 2. * (inp.squeeze() - targ).unsqueeze(-1) / inp.shape[0]

inp.g

Each gradient is stored on its input tensor (as .g). The output of the previous layer is the input of the current layer. So the gradient we attach here is exactly the out.g that the previous layer's backward function will read


2. * (inp.squeeze() - targ).unsqueeze(-1)/ inp.shape[0]

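This is the derivative of mean((inp - targ)^2) with respect to inp: 2 * (inp - targ) / N, where N = inp.shape[0]. The unsqueeze(-1) restores the trailing dimension so inp.g has the same shape as inp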
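Written out for a batch of N predictions y_i with targets t_i (N = inp.shape[0]), only the i-th term of the sum depends on y_i:

\[L=\frac{1}{N}\sum_{j=1}^{N}\left(y_j-t_j\right)^2 \quad\Longrightarrow\quad \frac{\partial L}{\partial y_i}=\frac{2\left(y_i-t_i\right)}{N}\]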

def relu_grad(inp, out):
    # grad of relu with respect to input activations
    inp.g = (inp>0).float() * out.g

(inp>0).float() * out.g

The inp>0 mask is the derivative of ReLU itself. By the chain rule, we multiply it by the gradient flowing back from the next layer (out.g)


inp>0

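ReLU sets every negative input to 0, so its slope there is 0. The derivative is 0 for negative inputs and 1 for positive ones, which is exactly what the mask (inp>0) captures. The constant -0.5 shift in relu_v2 does not change the derivative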
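A quick check of `relu_grad` against PyTorch autograd on random stand-in tensors. The `upstream` tensor plays the role of the gradient coming back from the next layer. It also confirms that the -0.5 shift in `relu_v2` leaves the gradient unchanged.

```python
import torch

def relu_v2(x):
    return x.clamp_min(0.) - 0.5

def relu_grad(inp, out):
    # d(relu)/d(inp) is 1 where inp > 0 and 0 elsewhere; the chain rule multiplies by out.g
    inp.g = (inp > 0).float() * out.g

torch.manual_seed(0)
x = torch.randn(4, 3)
upstream = torch.randn(4, 3)   # stand-in for the gradient from the next layer

# Manual version
out = relu_v2(x)
out.g = upstream
relu_grad(x, out)

# Autograd version
xa = x.clone().requires_grad_(True)
relu_v2(xa).backward(upstream)

print(torch.allclose(x.g, xa.grad))  # True
```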

def lin_grad(inp, out, w, b):
    # grad of matmul with respect to input
    inp.g = out.g @ w.t() # transpose
    w.g = (inp.unsqueeze(-1) * out.g.unsqueeze(1)).sum(0)
    b.g = out.g.sum(0)

 out.g @ w.t()

For out = inp @ w, the gradient with respect to inp is out.g multiplied by the transpose of the weight matrix


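w.g = (inp.unsqueeze(-1) * out.g.unsqueeze(1)).sum(0)

We also need the gradient with respect to the weights: for each sample, the outer product of its input and its output gradient, summed over the batch


b.g = out.g.sum(0)

And the gradient with respect to the bias, which is just out.g summed over the batch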
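The weight gradient can be written three equivalent ways. The outer-product version builds a (batch, in, out) tensor just to sum it. `inp.t() @ out.g` avoids that intermediate, and `einsum` expresses the same contraction. A quick check with random stand-in tensors:

```python
import torch

torch.manual_seed(0)
inp = torch.randn(8, 5)    # batch of 8 inputs with 5 features
out_g = torch.randn(8, 3)  # gradient of the loss w.r.t. the layer's 3 outputs

# 1. Per-sample outer products, summed over the batch (as in lin_grad)
wg_outer = (inp.unsqueeze(-1) * out_g.unsqueeze(1)).sum(0)
# 2. The same sum as a single matrix product
wg_matmul = inp.t() @ out_g
# 3. The einsum form
wg_einsum = torch.einsum("bi,bj->ij", inp, out_g)

bg = out_g.sum(0)          # bias gradient: sum over the batch

print(wg_outer.shape)      # torch.Size([5, 3]), same shape as the weight
print(torch.allclose(wg_outer, wg_matmul, atol=1e-6),
      torch.allclose(wg_outer, wg_einsum, atol=1e-6))  # True True
```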

def forward_and_backward(inp, targ):
    # forward pass:
    l1 = inp @ w1 + b1
    l2 = relu_v2(l1)
    out = l2 @ w2 + b2
    # We don't actually need the loss in backward
    loss = mse(out, targ)
    
    # backward pass:
    mse_grad(out, targ)
    lin_grad(l2, out, w2, b2)
    relu_grad(l1, l2)
    lin_grad(inp, l1, w1, b1)

l1 = inp @ w1 + b1
lin_grad(inp, l1, w1, b1)

Each gradient function takes the layer's original input, its output, and any other arguments from the forward pass (here the weights and bias)


l2 = relu_v2(l1)
relu_grad(l1, l2)

This pattern repeats for every layer. We travel through the model and loss function twice, once forward and once backward, and the backward pass ends on the first linear layer

Backprop is just the chain rule, plus making sure every intermediate result is saved somewhere so the backward pass can use it

forward_and_backward(x_train, y_train)
# Save for testing against later
w1g = w1.g.clone()
w2g = w2.g.clone()
b1g = b1.g.clone()
b2g = b2.g.clone()
ig = x_train.g.clone()

And now we "cheat" with PyTorch autograd to check our results:

xt2 = x_train.clone().requires_grad_(True)
w12 = w1.clone().requires_grad_(True)
w22 = w2.clone().requires_grad_(True)
b12 = b1.clone().requires_grad_(True)
b22 = b2.clone().requires_grad_(True)
def forward(inp, targ):
    # forward pass
    l1 = inp @ w12 + b12
    l2 = relu_v2(l1)
    out = l2 @ w22 + b22
    return mse(out, targ)
loss = forward(xt2, y_train)
loss.backward()
# And now test
test_close(w22.grad, w2g)
test_close(b22.grad, b2g)
test_close(w12.grad, w1g)
test_close(b12.grad, b1g)
test_close(xt2.grad, ig)

Layers as classes

class ReLU():
    def __call__(self, inp):
        self.inp = inp
        self.out = inp.clamp_min(0.)-0.5
        return self.out
    
    def backward(self): 
        self.inp.g = (self.inp>0).float() * self.out.g

__call__

Lets an instance be called like a function, e.g. ReLU()(x), to run the forward pass and store its input and output


def backward

This is our backward pass from earlier, but it stores the gradient in self.inp.g


class Linear():
    def __init__(self, w, b):
        self.w, self.b = w, b
    
    def __call__(self, inp):
        self.inp = inp
        self.out = inp @ self.w + self.b
        return self.out
    
    def backward(self):
        self.inp.g = self.out.g @ self.w.t()
        # Inefficient: builds a giant (batch, in, out) outer product just to sum it over the batch
        self.w.g = (self.inp.unsqueeze(-1) * self.out.g.unsqueeze(1)).sum(0)
        self.b.g = self.out.g.sum(0)
class MSE():
    def __call__(self, inp, targ):
        self.inp = inp
        self.targ = targ
        self.out = (inp.squeeze() - targ).pow(2).mean()
        return self.out
    
    def backward(self):
        self.inp.g = 2. * (self.inp.squeeze(-1) - self.targ).unsqueeze(-1) / self.targ.shape[0]
class Model():
    def __init__(self, w1, b1, w2, b2):
        self.layers = [Linear(w1,b1), ReLU(), Linear(w2,b2)]
        self.loss = MSE()
    
    def __call__(self, x, targ):
        for layer in self.layers:
            x = layer(x)
        return self.loss(x, targ)
    
    def backward(self):
        self.loss.backward()
        for layer in reversed(self.layers):
            layer.backward()
# Reset our gradients:
w1.g, b1.g, w2.g, b2.g = [None]*4
# And define the model

model = Model(w1, b1, w2, b2)
CPU times: user 77.4 ms, sys: 26.8 ms, total: 104 ms
Wall time: 13.2 ms
CPU times: user 3.62 s, sys: 2.41 s, total: 6.03 s
Wall time: 893 ms
# Check the gradients align
test_close(w2g, w2.g)
test_close(b2g, b2.g)
test_close(w1g, w1.g)
test_close(b1g, b1.g)
test_close(ig, x_train.g)

Refactor again

class Module():
    "Basic class that implements .backward() and stores the args and outputs from the forward function"
    def __call__(self, *args):
        self.args = args
        self.out = self.forward(*args)
        return self.out
    
    def forward(self, *args): 
        raise NotImplementedError("You need to define the forward function still!")
    
    def backward(self):
        self.bwd(self.out, *self.args)
class ReLU(Module):
    def forward(self, inp):
        return inp.clamp_min(0.)-0.5
    
    def bwd(self, out, inp):
        inp.g = (inp>0).float() * out.g
class Linear(Module):
    def __init__(self, w, b):
        self.w, self.b = w, b
    
    def forward(self, inp):
        return inp@self.w + self.b
    
    def bwd(self, out, inp):
        inp.g = out.g @ self.w.t()
        # Creating a giant outer product just to sum it together is very inefficient. Do it all at once!
        self.w.g = torch.einsum("bi,bj->ij",inp,out.g)
        self.b.g = out.g.sum(0)
class MSE(Module):
    def forward(self, inp, targ):
        return (inp.squeeze() - targ).pow(2).mean()
    
    def bwd(self, out, inp, targ):
        inp.g = 2. * (inp.squeeze(-1) - targ).unsqueeze(-1) / targ.shape[0]
class Model():
    def __init__(self, w1, b1, w2, b2):
        self.layers = [Linear(w1,b1), ReLU(), Linear(w2,b2)]
        self.loss = MSE()
    
    def __call__(self, x, targ):
        for layer in self.layers:
            x = layer(x)
        return self.loss(x, targ)
    
    def backward(self):
        self.loss.backward()
        for layer in reversed(self.layers):
            layer.backward()
w1.g, b1.g, w2.g, b2.g = [None]*4
model = Model(w1, b1, w2, b2)
CPU times: user 142 ms, sys: 0 ns, total: 142 ms
Wall time: 19.9 ms
CPU times: user 234 ms, sys: 157 ms, total: 391 ms
Wall time: 49.7 ms
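As a sketch of how little the `Module` pattern asks of a new layer, here is a hypothetical `ReLUv3` module wrapping the earlier `relu_v3`. Only `forward` and `bwd` need writing. The base class is repeated so the snippet runs on its own. The autograd comparison uses strictly positive inputs, so autograd never has to differentiate x**0.9 at 0.

```python
import torch

class Module():
    # Same base class as above, repeated so this snippet runs on its own
    def __call__(self, *args):
        self.args = args
        self.out = self.forward(*args)
        return self.out

    def forward(self, *args):
        raise NotImplementedError

    def backward(self):
        self.bwd(self.out, *self.args)

class ReLUv3(Module):
    # Hypothetical module version of relu_v3: clamp_min(x, 0) ** 0.9 - 0.5
    def forward(self, inp):
        return torch.pow(inp.clamp_min(0.), 0.9) - 0.5

    def bwd(self, out, inp):
        # Local derivative: 0.9 * x ** -0.1 where x > 0, and 0 elsewhere.
        # clamp_min(1e-12) keeps pow(-0.1) finite before the mask zeroes those entries.
        mask = (inp > 0).float()
        local = 0.9 * inp.clamp_min(1e-12).pow(-0.1) * mask
        inp.g = local * out.g

torch.manual_seed(0)

# Positive inputs: compare against autograd
x = torch.rand(4, 3) + 0.1
layer = ReLUv3()
out = layer(x)
out.g = torch.ones_like(out)        # pretend the upstream gradient is all ones
layer.backward()

xa = x.clone().requires_grad_(True)
(torch.pow(xa.clamp_min(0.), 0.9) - 0.5).backward(torch.ones_like(xa))
print(torch.allclose(x.g, xa.grad))  # True

# Negative inputs: no gradient flows through
neg = -torch.rand(2, 2) - 0.1
layer_neg = ReLUv3()
out_neg = layer_neg(neg)
out_neg.g = torch.ones_like(out_neg)
layer_neg.backward()
print(neg.g)                         # all zeros
```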

nn.Linear and nn.Module

We have now implemented both of these ourselves, so we're allowed to use PyTorch's versions

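class Model(nn.Module):
    def __init__(self, n_in, nh, n_out):
        super().__init__()
        # nn.ModuleList (not a plain list) registers the layers, so model.parameters() can see them
        self.layers = nn.ModuleList([nn.Linear(n_in,nh), nn.ReLU(), nn.Linear(nh,n_out)])
        self.loss = mse
    
    # Override forward rather than __call__, so nn.Module's own __call__ (hooks etc.) still runs
    def forward(self, x, targ):
        for layer in self.layers:
            x = layer(x)
        return self.loss(x.squeeze(-1), targ)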
model = Model(m, nh, 1)
CPU times: user 129 ms, sys: 2.81 ms, total: 131 ms
Wall time: 19.7 ms
CPU times: user 105 ms, sys: 0 ns, total: 105 ms
Wall time: 16 ms
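PyTorch's own containers make parameter registration explicit. Below is a minimal sketch of the same architecture with `nn.Sequential` on random stand-in data, with sizes taken from above. It also includes a hypothetical `ListModel`, which shows that layers kept in a plain Python list are not registered as parameters.

```python
import torch
from torch import nn

torch.manual_seed(0)
m, nh = 784, 50

# nn.Sequential registers each layer, so its weights appear in model.parameters()
model = nn.Sequential(nn.Linear(m, nh), nn.ReLU(), nn.Linear(nh, 1))
n_params = sum(p.numel() for p in model.parameters())
print(n_params)  # 784*50 + 50 + 50*1 + 1 = 39301

class ListModel(nn.Module):
    # Hypothetical counter-example: layers kept in a plain Python list
    def __init__(self):
        super().__init__()
        self.layers = [nn.Linear(m, nh), nn.ReLU(), nn.Linear(nh, 1)]

print(len(list(ListModel().parameters())))  # 0: nothing was registered

x = torch.randn(64, m)   # random stand-in batch
y = torch.randn(64)
loss = ((model(x).squeeze(-1) - y) ** 2).mean()  # the same MSE as before
loss.backward()                                   # autograd fills .grad on every parameter
print(model[0].weight.grad.shape)                 # torch.Size([50, 784]): stored as (out, in)
```

Note that `nn.Linear` stores its weight transposed relative to our hand-written `w1`, which is why the default `fan_in` mode of `kaiming_normal_` works directly on it.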