First, let's train a model to analyze:
from fastai.tabular.all import *

path = untar_data(URLs.ADULT_SAMPLE)
df = pd.read_csv(path/'adult.csv')
dep_var = 'salary'
cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race']
cont_names = ['age', 'fnlwgt', 'education-num']
procs = [Categorify, FillMissing, Normalize]
splits = IndexSplitter(list(range(800,1000)))(range_of(df))
to = TabularPandas(df, procs, cat_names, cont_names, y_names=dep_var, splits=splits)
dls = to.dataloaders()
Then fit it for one epoch:
learn = tabular_learner(dls, layers=[200,100], metrics=accuracy)
learn.fit(1, 1e-2)
This class lets you use several plotting methods from the SHAP interpretation library. Currently summary_plot, dependence_plot, waterfall_plot, force_plot, and decision_plot are supported.
test_data should be either a Pandas DataFrame or a TabularDataLoader. If it is not provided, 100 random rows of the training data are used instead.
link can be either "identity" or "logit". It is a generalized linear model link that connects the feature importance values to the model output. Since the feature importance values, phi, sum up to the model output, it often makes sense to connect them to the output with a link function where link(output) = sum(phi). If the model output is a probability, then the LogitLink link function makes the feature importance values have log-odds units.
n_samples can be either "auto" or an integer value. It is the number of times to re-evaluate the model when explaining each prediction. More samples lead to lower-variance estimates of the SHAP values.
l1_reg can be an integer value representing the number of features, "auto", "aic", "bic", or a float value. It is the l1 regularization to use for feature selection (the estimation procedure is based on a debiased lasso). The "auto" option currently uses "aic" when less than 20% of the possible sample space is enumerated; otherwise it uses no regularization.
exp = ShapInterpretation(learn)
exp = ShapInterpretation(learn, df.iloc[:2])
exp = ShapInterpretation(learn, learn.dls.test_dl(df.iloc[:100]))
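The relation link(output) = sum(phi) can be checked by hand. The following is a minimal sketch in plain Python, not the library's implementation: the attribution values are made up for illustration, with the first entry standing in for the base value, and the helper names logit and sigmoid are defined here only for the example.

```python
import math

def logit(p):
    """Log-odds of a probability p in the open interval (0, 1)."""
    return math.log(p / (1.0 - p))

def sigmoid(z):
    """Inverse of logit: maps log-odds back to a probability."""
    return 1.0 / (1.0 + math.exp(-z))

# Hypothetical attributions in log-odds units (made-up numbers);
# the first entry plays the role of the base value.
phi_log_odds = [-1.0, 0.8, 0.5, -0.1]

# A model whose output is a probability consistent with these attributions.
probability = sigmoid(sum(phi_log_odds))

# Under the "logit" link, link(output) equals the sum of the attributions.
link_gap = abs(logit(probability) - sum(phi_log_odds))
print(probability, link_gap)
```

With the "identity" link the same check is simply output == sum(phi), so the attributions stay in the units of the raw model output.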
decision_plot
Visualizes a model's decisions using cumulative SHAP values. Accepts a class_id, which indicates the class of interest for a classification model and can be either an int or a str representation of the class. Each colored line in the plot represents the model's prediction for a single observation. If no row index is passed in, the plot defaults to the first ten samples of the test set. Note: plotting too many samples at once can make the plot illegible. For an up-to-date list of parameters and for more information, see the SHAP documentation for decision_plot.
exp.decision_plot(class_id=0, row_idx=10)
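The cumulative values traced by one line of such a plot can be sketched in plain Python. This is only an illustration under stated assumptions: the base value and per-feature contributions are made-up numbers, not output from the model trained above, and the feature order is arbitrary.

```python
from itertools import accumulate

# Made-up base value (expected model output) and per-feature SHAP values.
base_value = 0.2
contributions = [("age", 0.05), ("fnlwgt", -0.10), ("education-num", 0.30)]

# Running total: start at the base value and add one contribution at a time.
# The last entry is the model output for this observation.
path = list(accumulate((value for _, value in contributions), initial=base_value))

for label, point in zip(["base"] + [name for name, _ in contributions], path):
    print(f"{label:>14}: {point:.2f}")
```

Each observation yields one such path, which is why plotting many rows at once quickly becomes hard to read.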
dependence_plot
Plots the value of a variable on the x-axis and the SHAP value of the same variable on the y-axis. Accepts a class_id and a variable_name. class_id indicates the class of interest for a classification model and can be either an int or a str representation of the class. This plot shows how the model depends on the given variable. Vertical dispersion of the data points represents interaction effects. Gray ticks along the y-axis are data points where the variable's value was NaN. For an up-to-date list of parameters and for more information, see the SHAP documentation for dependence_plot.
exp.dependence_plot('age', class_id=0)
force_plot
Visualizes the SHAP values with an added force layout. Accepts a class_id, which indicates the class of interest for a classification model and can be either an int or a str representation of the class. The matplotlib argument determines whether the plot is rendered with matplotlib or in JavaScript. For an up-to-date list of parameters, see the SHAP documentation for force_plot.
exp.force_plot(class_id=0)
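Several of these plots accept class_id as either an int or a str. The sketch below shows one way such an argument could be normalized to an integer index. It is not the library's implementation: resolve_class_id is a hypothetical helper, and the class_names list is an assumed vocabulary used only for illustration.

```python
def resolve_class_id(class_id, class_names):
    """Return the integer index of a class given as an int or a str (hypothetical helper)."""
    if isinstance(class_id, str):
        # A string is looked up by name; list.index raises ValueError if it is unknown.
        return class_names.index(class_id)
    if not 0 <= class_id < len(class_names):
        raise IndexError(f"class_id {class_id} is out of range")
    return class_id

# Assumed class labels, for illustration only.
class_names = ["<50k", ">=50k"]

index_from_str = resolve_class_id(">=50k", class_names)
index_from_int = resolve_class_id(0, class_names)
print(index_from_str, index_from_int)
```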
summary_plot
Displays the SHAP values, which can be interpreted as feature importance. For an up-to-date list of parameters, see the SHAP documentation for summary_plot.
exp.summary_plot()
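One common way to read SHAP values as feature importance is to average their absolute values per feature and rank the features by that average. The following is a minimal sketch of that idea with made-up SHAP values (three observations, three features); it does not reproduce the plot itself.

```python
feature_names = ["age", "fnlwgt", "education-num"]

# Made-up SHAP values: one row per observation, one column per feature.
shap_rows = [
    [0.10, -0.02, 0.30],
    [-0.20, 0.01, 0.25],
    [0.05, -0.03, -0.35],
]

# Mean absolute SHAP value per feature: sign is ignored, only magnitude counts.
mean_abs = {
    name: sum(abs(row[i]) for row in shap_rows) / len(shap_rows)
    for i, name in enumerate(feature_names)
}

# Features ordered from most to least important.
ranking = sorted(mean_abs, key=mean_abs.get, reverse=True)
print(ranking)
```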
waterfall_plot
Plots an explanation of a single prediction as a waterfall plot. Accepts a row_idx and a class_id. row_idx is the index of the row in test_data to be analyzed, which defaults to zero. class_id indicates the class of interest for a classification model and can be either an int or a str representation of the class. For an up-to-date list of parameters, see the SHAP documentation for waterfall_plot.
exp.waterfall_plot(row_idx=10)