Useful interpretation functions for tabular models, such as Feature Importance

base_error[source]

base_error(err, val)
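The library defines the exact formula, but a `perm_func` of this shape typically reports how much worse the permuted error `err` is relative to a baseline `val`. A minimal sketch (this relative-difference form is an assumption, not the library's verbatim implementation):

```python
def base_error(err, val):
    # Hypothetical sketch: the relative increase of the permuted error `err`
    # over the baseline error `val`, normalized by the permuted error.
    return (err - val) / err

base_error(0.5, 0.25)  # 0.5 -> permuting this feature doubled the error
```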

TabularLearner.feature_importance[source]

TabularLearner.feature_importance(x:TabularLearner, df=None, dl=None, perm_func='base_error', metric='accuracy', bs=None, reverse=True, plot=True)

Calculate and plot the Feature Importance based on df

We can pass in either a DataFrame (or a section of one) via df, or a DataLoader via dl. perm_func dictates how to calculate our importance, and reverse determines how the output is sorted
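The permutation scheme behind this can be sketched in a few lines: score the model on untouched data, then shuffle one column at a time and measure how much the score drops. This toy version uses a stand-in scoring function rather than a fastai learner, so the names here (`permutation_importance`, `score_fn`) are illustrative, not the library's API:

```python
import numpy as np
import pandas as pd

def permutation_importance(df, target, score_fn, seed=42):
    # Shuffle each column independently and record the drop in score
    # versus the baseline; bigger drop -> more important feature.
    rng = np.random.default_rng(seed)
    base = score_fn(df, target)  # baseline score on untouched data
    importances = {}
    for col in df.columns:
        shuffled = df.copy()
        shuffled[col] = rng.permutation(shuffled[col].values)
        importances[col] = base - score_fn(shuffled, target)
    return dict(sorted(importances.items(), key=lambda kv: kv[1], reverse=True))

# Toy data: `target` is fully determined by `signal`; `noise` is irrelevant.
df = pd.DataFrame({'signal': [0, 1, 0, 1, 0, 1, 0, 1] * 5,
                   'noise':  [0, 0, 1, 1, 0, 1, 1, 0] * 5})
target = df['signal'].values
acc = lambda d, y: float((d['signal'].values == y).mean())  # "model" reads signal only
imp = permutation_importance(df, target, acc)
```

Shuffling `signal` hurts accuracy while shuffling `noise` changes nothing, so `signal` comes out on top.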

TabularLearner.get_top_corr_dict[source]

TabularLearner.get_top_corr_dict(x:TabularLearner, df:DataFrame, thresh:float=0.8)

Grabs the top pairs of correlated features from a correlation matrix computed on df, filtered by thresh

This, along with plot_dendrogram and any helper functions along the way, is based upon this post by Pack911 on the fastai forums.

TabularLearner.plot_dendrogram[source]

TabularLearner.plot_dendrogram(x:TabularLearner, df:DataFrame, figsize=None, leaf_font_size=16)

Plots dendrogram for a calculated correlation matrix
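The usual recipe behind a correlation dendrogram is to turn the correlation matrix into a distance matrix (highly correlated features become close together), run hierarchical clustering on it, and draw the tree with scipy. A self-contained sketch under those assumptions (function name and linkage method are illustrative choices, not necessarily what the library uses):

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.cluster import hierarchy
from scipy.spatial.distance import squareform

def plot_corr_dendrogram(df, figsize=(10, 6), leaf_font_size=16):
    # Correlated features -> small distance, so 1 - |corr| works as a metric.
    corr = df.corr().values
    dist = 1 - np.abs(corr)
    np.fill_diagonal(dist, 0)  # squareform expects an exactly-zero diagonal
    link = hierarchy.linkage(squareform(dist, checks=False), method='average')
    fig, ax = plt.subplots(figsize=figsize)
    hierarchy.dendrogram(link, labels=df.columns.tolist(),
                         leaf_font_size=leaf_font_size, ax=ax)
    return link
```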

Example Usage

We'll run an example on the ADULT_SAMPLE dataset

from fastai.tabular.all import *
path = untar_data(URLs.ADULT_SAMPLE)
df = pd.read_csv(path/'adult.csv')
splits = RandomSplitter()(range_of(df))
cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race']
cont_names = ['age', 'fnlwgt', 'education-num']
procs = [Categorify, FillMissing, Normalize]
y_names = 'salary'
to = TabularPandas(df, procs=procs, cat_names=cat_names, cont_names=cont_names,
                   y_names=y_names, splits=splits)
dls = to.dataloaders()
learn = tabular_learner(dls, layers=[200,100], metrics=accuracy)
learn.fit(3)
epoch train_loss valid_loss accuracy time
0 0.360690 0.360770 0.832463 00:40
1 0.358146 0.355560 0.834152 00:38
2 0.346212 0.353760 0.834152 00:39

After fitting, let's first calculate the relative feature importance over the DataFrame:

fi = learn.feature_importance(df=df)
Calculating Permutation Importance
100.00% [9/9 00:22<00:00]

Next we'll calculate the correlation matrix, and then we will plot its dendrogram:

corr_dict = learn.get_top_corr_dict(df, thresh=0.3); corr_dict
100.00% [45/45 00:33<00:00]
OrderedDict([('workclass vs sex', 0.991),
             ('marital-status vs race', 0.506),
             ('education vs occupation', 0.493),
             ('fnlwgt vs education-num', 0.488),
             ('age vs education', 0.397),
             ('relationship vs race', 0.363),
             ('education-num vs race', 0.305)])
learn.plot_dendrogram(df)
100.00% [45/45 00:32<00:00]

This allows us to see which families of features are closely related based on our thresh, and also to show (in combination with the feature importance) how our model uses each variable.
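Combining the two outputs is a common next step: for each highly correlated pair, the less important member is a candidate to drop. A hypothetical helper sketching that idea (it assumes the importances are available as a {feature: score} mapping, which may differ from what feature_importance actually returns):

```python
def redundant_drop_candidates(corr_dict, importance, thresh=0.5):
    # For each pair above `thresh`, suggest dropping the member with the
    # lower importance score; pairs use the 'a vs b' key format shown above.
    drops = set()
    for pair, corr in corr_dict.items():
        if corr < thresh:
            continue
        a, b = pair.split(' vs ')
        if a in importance and b in importance:
            drops.add(a if importance[a] < importance[b] else b)
    return sorted(drops)

corr = {'workclass vs sex': 0.991, 'age vs education': 0.397}
imp = {'workclass': 0.02, 'sex': 0.05, 'age': 0.10, 'education': 0.04}
redundant_drop_candidates(corr, imp)  # ['workclass']
```

Retraining after each drop (rather than dropping several features at once) is the safer way to confirm the feature really was redundant.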