= "https://figshare.com/ndownloader/files/25635053"
url = FastDownload(base="~/.fastai")
fd = fd.download(url); path path
Path('/home/zach/.fastai/archive/25635053')
The game: Recreate fastai, while only being able to use:
datasets
The game I will also be playing:
import *
The difference between effective people in Deep Learning and the rest is who can make things in code that can work properly, and there’s very few of those people - Jeremy Howard
3 steps to training a really good model:
How to avoid overfitting from A -> F
4 & 5 both have the least impact, start with the first 3
First we need to download the dataset we are using, which will be MNIST
url = "https://figshare.com/ndownloader/files/25635053"
fd = FastDownload(base="~/.fastai")
path = fd.download(url); path
Path('/home/zach/.fastai/archive/25635053')
deeplearning.net is no longer up, so we use a hosted copy of Yann LeCun’s MNIST dataset
We utilize fastdownload’s FastDownload class to handle downloading the data.
from fastai import datasets is no longer a thing.
Perform the actual downloading
The downloaded data contains numpy arrays, which are not allowed under the rules, so they must be converted to tensors
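The figshare file is the classic mnist.pkl.gz pickle; a minimal loading sketch, assuming that format and the standard gzip/pickle approach (the unpacked variable names are assumptions):

import gzip, pickle
# Assumes the downloaded file is the standard mnist.pkl.gz pickle containing
# ((x_train, y_train), (x_valid, y_valid), test) as numpy arrays
with gzip.open(path, 'rb') as f:
    ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding='latin-1')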
x_train,y_train,x_valid,y_valid = map(tensor, (x_train,y_train,x_valid,y_valid))
n,c = x_train.shape
x_train, x_train.shape, y_train, y_train.shape, y_train.min(), y_train.max()
(tensor([[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]]),
torch.Size([50000, 784]),
tensor([5, 0, 4, ..., 8, 4, 8]),
torch.Size([50000]),
tensor(0),
tensor(9))
Applies torch.tensor across the four arrays, converting them all into tensors
n = the number of rows in the training set, c = the number of columns in the training set
Verify there are 50,000 items in the dataset
Verify that each item is 28x28 numbers
Verify the lowest class in the y labels is 0
Verify the highest class in the y labels is 9
Get one set of data from the dataset
Check after viewing it as a 28,28 (more on this next) that the type is still a FloatTensor
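A minimal sketch of those checks (n, c, x_train, y_train come from the earlier cells; img is an assumed name):

assert n == 50000                 # 50,000 items in the training set
assert c == 28 * 28               # each item is a flattened 28x28 image
assert y_train.min() == 0         # lowest class label is 0
assert y_train.max() == 9         # highest class label is 9
img = x_train[0]                  # one item from the dataset
img.view(28, 28).type()           # 'torch.FloatTensor' -- still a FloatTensor after the view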
Create a simple linear model, something akin to y = ax + b
This is the core of the basics of machine learning: “affine functions”.
a and b are two matrices which should be multiplied.
Matrix multiplication cannot occur unless the number of columns in a matches the number of rows in b.
c is the resulting matrix, which has a’s number of rows and b’s number of columns.
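As a point of reference for the timing further down, a plausible setup (the names weights, bias, m1, m2 are assumptions; the idea is a small validation batch against random weights for a 10-class MNIST affine layer):

import torch
weights = torch.randn(784, 10)   # the 'a' matrix: 784 inputs -> 10 classes
bias = torch.zeros(10)           # the 'b' term of y = ax + b
m1 = x_valid[:5]                 # 5 flattened images, shape (5, 784)
m2 = weights                     # shape (784, 10)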
Loop of matrix B as a whole scrolling down matrix A sideways, imagine going row by row like a curtain coming down slowly
Loop of each column in matrix B at each row in matrix A
The actual loop of multiplying and adding (matrix multiplication)
The actual multiplication being performed
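A sketch of the naive pure-Python version those four annotations describe:

import torch
def matmul(a, b):
    ar, ac = a.shape
    br, bc = b.shape
    assert ac == br                          # columns of a must match rows of b
    c = torch.zeros(ar, bc)
    for i in range(ar):                      # matrix B scrolling down matrix A, row by row
        for j in range(bc):                  # each column of B at each row of A
            for k in range(ac):              # the multiply-and-add loop
                c[i, j] += a[i, k] * b[k, j] # the actual multiplication being performed
    return c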
CPU times: user 440 ms, sys: 72.4 ms, total: 513 ms
Wall time: 421 ms
This is quite slow. To do a single epoch it would take ~20,000 seconds on the computer I’m using to take notes. (50,000 on Jeremy’s).
This is also why we don’t write things in Python. It’s unreasonably slow.
New goal: can we speed this up 50,000 times?
To speed things up, start with the innermost loop and make things just a little bit faster
The way to make Python faster is to remove python - Jeremy Howard
Element-wise operations (EWOs) include (+, -, *, /, >, <, ==)
Example with two tensors:
We performed c[0] = a[0] + b[0], c[1] = a[1] + b[1], …
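A small sketch of those element-wise operations (the example values are assumed for illustration):

from torch import tensor
a = tensor([10., 6, -4])
b = tensor([2., 8, 7])
a + b                     # tensor([12., 14.,  3.]) -- element-wise addition
(a < b).float().mean()    # tensor(0.6667) -- the fraction of a that is less than b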
Also known as what percentage of a is less than b. We could also perform the same on a rank 2 tensor (a tensor that has 2 dimensions), aka a matrix!
Note: We only convert the first number to a float as PyTorch will realize this and cast the rest as a float
Frobenius norm:
I have no idea what this is/remember what this is
\[\|A\|_F=\left(\sum_{i, j=1}^n\left|a_{i j}\right|^2\right)^{1 / 2}\]
\[\text{This is the first for loop, and goes from 1 }\rightarrow\text{ n}\]
\[\text{This is the second for loop, and goes from 1 }\rightarrow\text{ n as well}\]
\[\text{This correlates to }\left|a_{i j}\right|\]
\[\text{This aligns with }\sum_{i, j=1}^n\text{, which is equivalent to the combination of }\sum_{i=1}^m\text{ and }\sum_{j=1}^n\]
\[\text{This correlates to the 1/2 power, simplified as }\left(\text{result}\right)^{1/2}\]
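A sketch of the Frobenius norm, first as the written-out loops and then as the one-line tensor version (the matrix m is an assumed example; note only its first value needs the trailing dot, per the float-casting note above):

from torch import tensor
m = tensor([[1., 2, 3], [4, 5, 6], [7, 8, 9]])
sf = 0.
for i in range(3):          # the first for loop, i = 1..n
    for j in range(3):      # the second for loop, j = 1..n
        sf += m[i, j]**2    # |a_ij|^2
sf = sf**0.5                # the 1/2 power
(m*m).sum().sqrt()          # the same value, with the loops replaced by tensor operations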
\[\text{We replace the entire innermost for loop with this, and directly perform the matrix operation.}\newline\text{Remember that : selects everything from i}\rightarrow\text{ end! (Or the entirety of that axis)}\]
We select all of row i
And we select all of column j.
In numpy and PyTorch it goes 🎵 row by column 🎵
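A sketch of the matmul at this stage, with the innermost loop replaced by the row-by-column dot product (mirroring the earlier loop version):

import torch
def matmul(a, b):
    ar, ac = a.shape
    br, bc = b.shape
    assert ac == br
    c = torch.zeros(ar, bc)
    for i in range(ar):
        for j in range(bc):
            c[i, j] = (a[i, :] * b[:, j]).sum()   # all of row i times all of column j, summed
    return c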
570 µs ± 21.5 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
We are now 600x faster by removing a single loop and running it in C
Now we need to get rid of the second-most inner loop through broadcasting.
Get rid of all for loops and replace with implicit broadcasted loops
We just broadcast a > 0. In other words, the scalar 0 acts as though it were [0,0,0], an element-wise comparison is performed, and it runs at C or CUDA speed depending on the device
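A tiny sketch of broadcasting a scalar (values assumed for illustration):

from torch import tensor
a = tensor([10., 6, -4])
a > 0      # tensor([ True, False, False]) -- the 0 is broadcast across a
a + 1      # tensor([11.,  7., -3.])
2 * a      # tensor([20., 12., -8.])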
By the rules we have so far, we’d expect adding a rank 1 vector to a rank 2 matrix to not actually work. Instead the vector is broadcast down the matrix, row by row, and added to each row
This shows us that it’s only storing one copy of t
, and not a 3x3 copy of t
How to read this:
When going row by row, it should take zero steps through the memory/storage. And when going column by column, it should take one step.
This in turn is how it repeats 10,20,30 for every single row.
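A sketch of inspecting that behavior with expand_as, storage, and stride (the names c, m, t are assumptions matching the discussion):

from torch import tensor
c = tensor([10., 20, 30])
m = tensor([[1., 2, 3], [4, 5, 6], [7, 8, 9]])
t = c.expand_as(m)   # behaves like a 3x3 matrix of repeated rows
t.storage()          # only the three original values are actually stored
t.stride()           # (0, 1): zero steps to move down a row, one step to move across a column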
We can create tensors that behave like tensors much bigger than what they are.
What if we wanted to take a column instead of a row? In other words, a rank 2 tensor of shape (3,1)
d would have a shape of (1,3), which changed from just (3) by adding a dimension at position 0.
d would have a shape of (3,1), which changed from just (3) by adding a dimension at position 1.
This fails because we only have a 1d tensor, not a 2d tensor. E.g.:
(torch.Size([3]), torch.Size([1, 3]), torch.Size([3, 1]))
PyTorch and numpy use None in an index to insert a new dimension at that position, equivalent to unsqueeze()
This is equivalent to d.unsqueeze(0)
This is equivalent to d.unsqueeze(1)
This also works with multiple axes:
You can always skip trailing ‘:’s, and ‘…’ means ‘all preceding dimensions’:
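A quick sketch of these equivalences (d assumed to be a rank 1 tensor of three values):

from torch import tensor
d = tensor([10., 20, 30])
d.unsqueeze(0).shape     # torch.Size([1, 3]) -- same as d[None, :]
d.unsqueeze(1).shape     # torch.Size([3, 1]) -- same as d[:, None]
d[None, :, None].shape   # torch.Size([1, 3, 1]) -- multiple inserted axes
d[None].shape            # torch.Size([1, 3]) -- trailing ':' dropped
d[..., None].shape       # torch.Size([3, 1]) -- '...' covers all preceding dimensions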
From here, we visualize this in Excel. Follow the timestamp here
With this information now, we can use this to get rid of the loop:
def matmul(a,b):
    ar,ac = a.shape
    br,bc = b.shape
    assert ac==br
    c = torch.zeros(ar,bc)
    for i in range(ar):
        # c[i,j] = (a[i,:] * b[:,j]).sum()  # previous
        c[i] = (a[i].unsqueeze(-1) * b).sum(dim=0)
        # This is equivalent to c[i,:]
        # Rewritten in None form:
        # c[i] = (a[i][:,None] * b).sum(dim=0)
        # Rewritten again to avoid a second index altogether:
        # c[i] = (a[i,:,None] * b).sum(dim=0)
    return c
This takes a at index i and expands its last dimension by 1, so it is now a rank 2 tensor
This newly reshaped tensor can then be multiplied by b without issue
And finally we take the sum of that result along the first dimension (dim=0)
tensor([[-0., 0., 0., ..., 0., -0., 0.],
[-0., 0., 0., ..., -0., -0., -0.],
[0., 0., -0., ..., -0., 0., 0.],
...,
[-0., 0., -0., ..., -0., 0., -0.],
[-0., 0., -0., ..., 0., -0., -0.],
[0., 0., 0., ..., 0., -0., -0.]])
202 µs ± 66.7 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
We are now more than 2000x faster!
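For reference, the 3x3 result below is what you get from broadcasting a column vector against a row vector; a sketch (the name c is assumed):

from torch import tensor
c = tensor([10., 20, 30])
c[:, None] * c[None, :]   # (3,1) broadcast against (1,3): the outer product shown below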
tensor([[100., 200., 300.],
[200., 400., 600.],
[300., 600., 900.]])
Recall the innermost part of the for loops earlier:
And when we removed this, it looked like so:
We can rewrite this in Einstein Summation using the following steps:
Move i,j to the end and make an arrow point at it:
To the left of the arrow are the inputs, to the right of the arrow is the output
Inputs are delimited by comma, so there are two in this case
Rank is denoted by the number of letters there are; ik and kj are both rank 2
These inputs are read (shape-wise) as i by k and k by j
When a letter is repeated across inputs, it is assumed to be a dot product along that dimension
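A sketch of matmul written as an Einstein summation (this is the standard 'ik,kj->ij' spec for matrix multiplication):

import torch
def matmul(a, b):
    # the repeated letter k is summed over: a dot product along that dimension
    return torch.einsum('ik,kj->ij', a, b)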
101 µs ± 42.7 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Since we have now explored matmul to its fullest extent, we can utilize PyTorch’s operator directly for matrix multiplication:
The slowest run took 19.75 times longer than the fastest. This could mean that an intermediate result is being cached.
13.1 µs ± 22.9 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
The matmul is pushed down to a BLAS (Basic Linear Algebra Subprograms) library, e.g. cuBLAS for NVIDIA. This is what the M1 has, for example, and part of how Apple entered the space.
matmul is so common and useful that it has its own operator, @:
This is the exact same speed as m1.matmul(m2)
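For illustration, with m1 and m2 being the matrices used above, a quick equivalence check:

import torch
t1 = m1.matmul(m2)
t2 = m1 @ m2
torch.allclose(t1, t2)   # True -- '@' dispatches to the same matrix multiplication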