Useful interpretation functions for tabular models, such as Feature Importance
We can pass in sections of a DataFrame, but not a DataLoader. The perm_func argument dictates how our importance is calculated, and reverse determines how the output is sorted.
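Permutation importance of the kind described above can be sketched in plain pandas/NumPy. Everything here (the `permutation_importance` helper, the toy model, and the `acc` metric) is a hypothetical illustration of the technique, not the library's actual implementation:

```python
import numpy as np
import pandas as pd

def permutation_importance(model, X, y, metric, n_shuffles=1):
    # Baseline score on unshuffled data
    baseline = metric(y, model(X))
    importances = {}
    rng = np.random.default_rng(42)
    for col in X.columns:
        drops = []
        for _ in range(n_shuffles):
            X_perm = X.copy()
            # Shuffle one column, leaving the rest intact
            X_perm[col] = rng.permutation(X_perm[col].values)
            drops.append(baseline - metric(y, model(X_perm)))
        importances[col] = float(np.mean(drops))
    # Sort descending, so the most important features come first
    return dict(sorted(importances.items(), key=lambda kv: kv[1], reverse=True))

# Toy model whose prediction depends only on column 'a'
X = pd.DataFrame({'a': [0, 1, 0, 1, 0, 1], 'b': [5, 5, 5, 5, 5, 5]})
y = np.array([0, 1, 0, 1, 0, 1])
acc = lambda y_true, y_pred: float((y_true == y_pred).mean())
model = lambda df: df['a'].values
fi = permutation_importance(model, X, y, acc)
```

Since the model ignores the constant column `b`, shuffling it never changes the metric, so its importance is 0; shuffling `a` can only hurt accuracy, so its drop is non-negative.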
This function, along with plot_dendrogram and the helper functions along the way, is based upon this post by Pack911 on the fastai forums.
from fastai.tabular.all import *
path = untar_data(URLs.ADULT_SAMPLE)
df = pd.read_csv(path/'adult.csv')
splits = RandomSplitter()(range_of(df))
cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race']
cont_names = ['age', 'fnlwgt', 'education-num']
procs = [Categorify, FillMissing, Normalize]
y_names = 'salary'
to = TabularPandas(df, procs=procs, cat_names=cat_names, cont_names=cont_names,
y_names=y_names, splits=splits)
dls = to.dataloaders()
learn = tabular_learner(dls, layers=[200,100], metrics=accuracy)
learn.fit(3)
After fitting, let's first calculate the relative feature importance on the first 1,000 rows:
fi = learn.feature_importance(df=df.iloc[:1000])
Next we'll calculate the correlation matrix, and then we will plot its dendrogram:
corr_dict = learn.get_top_corr_dict(df, thresh=0.3); corr_dict
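A thresholded correlation dictionary like the one returned above can be approximated with pandas alone. The `top_corr_pairs` helper below is an assumed stand-in for `get_top_corr_dict`, sketched under the assumption that it keeps pairs whose absolute Spearman correlation clears `thresh`:

```python
import pandas as pd

def top_corr_pairs(df, thresh=0.3):
    # Absolute Spearman correlation matrix
    corr = df.corr(method='spearman').abs()
    pairs = {}
    cols = corr.columns
    for i in range(len(cols)):
        for j in range(i + 1, len(cols)):  # upper triangle only, no self-pairs
            if corr.iloc[i, j] >= thresh:
                pairs[(cols[i], cols[j])] = round(float(corr.iloc[i, j]), 3)
    # Strongest correlations first
    return dict(sorted(pairs.items(), key=lambda kv: kv[1], reverse=True))

df = pd.DataFrame({'x': [1, 2, 3, 4, 5],
                   'y': [2, 4, 6, 8, 10],   # monotone in x, so Spearman corr is 1
                   'z': [5, 3, 4, 1, 2]})
pairs = top_corr_pairs(df, thresh=0.3)
```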
learn.plot_dendrogram(df)
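Dendrograms like this are typically built by turning the correlation matrix into a distance matrix (1 - |corr|) and running hierarchical clustering on it. A minimal sketch with SciPy, assuming average linkage (the actual helper's settings may differ):

```python
import numpy as np
import pandas as pd
from scipy.cluster import hierarchy
from scipy.spatial.distance import squareform

def corr_linkage(df):
    # Highly correlated features get a small distance, uncorrelated a large one
    corr = df.corr(method='spearman').abs()
    dist = 1 - corr
    np.fill_diagonal(dist.values, 0)  # guard against float noise on the diagonal
    condensed = squareform(dist.values, checks=False)
    return hierarchy.linkage(condensed, method='average')

df = pd.DataFrame({'x': [1, 2, 3, 4, 5],
                   'y': [2, 4, 6, 8, 10],   # perfectly rank-correlated with x
                   'z': [5, 3, 4, 1, 2]})
Z = corr_linkage(df)
# hierarchy.dendrogram(Z, labels=list(df.columns))  # draws the tree with matplotlib
```

Because `x` and `y` are perfectly rank-correlated, their distance is 0 and they merge first in the tree.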
This allows us to see which families of features are closely related based on our thresh, and also to show (in combination with the feature importance) how our model uses each variable.