Useful interpretation functions for tabular models, such as feature importance. We can pass in sections of a DataFrame, but not a DataLoader. perm_func dictates how to calculate our importance, and reverse determines how to sort the output. This, along with plot_dendrogram and any helper functions along the way, is based upon this post by Pack911 on the fastai forums.
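To make the perm_func idea concrete, here is a minimal sketch of permutation importance with plain NumPy: shuffle one column at a time and record how far the metric drops from its baseline. The function name and signature are hypothetical illustrations, not the library's API; reverse mirrors the sorting behaviour described above.

```python
import numpy as np

def permutation_importance(predict, X, y, metric, reverse=True, seed=42):
    # Hypothetical helper illustrating the technique, not fastai/wwf API.
    # Importance of a column = baseline score minus score after shuffling it.
    rng = np.random.default_rng(seed)
    base = metric(y, predict(X))
    scores = {}
    for col in range(X.shape[1]):
        Xp = X.copy()
        Xp[:, col] = rng.permutation(Xp[:, col])  # break this column's signal
        scores[col] = base - metric(y, predict(Xp))
    # reverse=True puts the most important columns first
    return dict(sorted(scores.items(), key=lambda kv: kv[1], reverse=reverse))
```

A column the model ignores scores ~0; a column the predictions depend on scores well above it.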
from fastai.tabular.all import *
path = untar_data(URLs.ADULT_SAMPLE)
df = pd.read_csv(path/'adult.csv')
splits = RandomSplitter()(range_of(df))
cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race']
cont_names = ['age', 'fnlwgt', 'education-num']
procs = [Categorify, FillMissing, Normalize]
y_names = 'salary'
to = TabularPandas(df, procs=procs, cat_names=cat_names, cont_names=cont_names,
y_names=y_names, splits=splits)
dls = to.dataloaders()
learn = tabular_learner(dls, layers=[200,100], metrics=accuracy)
learn.fit(3)
After fitting, let's first calculate the relative feature importance on the first 1,000 rows:
fi = learn.feature_importance(df=df.iloc[:1000])
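If the returned importances are a plain mapping of column name to relative score (an assumption about the return type; the scores below are made-up stand-ins for fi), a quick way to inspect the ranking is to sort it into a pandas Series:

```python
import pandas as pd

# Hypothetical scores standing in for the real `fi` result
fi = {'age': 0.08, 'education-num': 0.12, 'fnlwgt': 0.01, 'relationship': 0.21}

# Sort descending so the most influential features come first
top = pd.Series(fi).sort_values(ascending=False)
print(top)
```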
Next we'll calculate the correlation matrix, and then plot its dendrogram:
corr_dict = learn.get_top_corr_dict(df, thresh=0.3); corr_dict
learn.plot_dendrogram(df)
This allows us to see which families of features are closely related based on our thresh, and also shows (in combination with the feature importance) how our model uses each variable.
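The two steps above can be sketched without the library: report column pairs whose absolute correlation exceeds thresh, then turn the correlation matrix into a distance matrix and a hierarchical linkage for a dendrogram. Both function names are hypothetical, and the choice of Spearman correlation and average linkage is an assumption about what the helpers do internally.

```python
import numpy as np
import pandas as pd
from scipy.cluster import hierarchy
from scipy.spatial.distance import squareform

def top_corr_pairs(df, thresh=0.3):
    # Pairs of columns whose absolute Spearman correlation exceeds thresh
    corr = df.corr(method='spearman').abs()
    cols = corr.columns
    pairs = {}
    for i in range(len(cols)):
        for j in range(i + 1, len(cols)):
            if corr.iloc[i, j] > thresh:
                pairs[(cols[i], cols[j])] = round(float(corr.iloc[i, j]), 3)
    return pairs

def dendrogram_linkage(df):
    # Highly correlated columns get a small distance, so they merge early
    corr = df.corr(method='spearman').values
    dist = 1 - np.abs(corr)
    np.fill_diagonal(dist, 0)  # squareform needs an exactly-zero diagonal
    return hierarchy.linkage(squareform(dist, checks=False), method='average')
```

Passing the linkage to scipy.cluster.hierarchy.dendrogram then draws the tree of related feature families.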