First, let's train a model to analyze:
from fastai.tabular.all import *

path = untar_data(URLs.ADULT_SAMPLE)
df = pd.read_csv(path/'adult.csv')
dep_var = 'salary'
cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race']
cont_names = ['age', 'fnlwgt', 'education-num']
procs = [Categorify, FillMissing, Normalize]
splits = IndexSplitter(list(range(800,1000)))(range_of(df))
to = TabularPandas(df, procs, cat_names, cont_names, y_names="salary", splits=splits)
dls = to.dataloaders()
And fit it:
learn = tabular_learner(dls, layers=[200,100], metrics=accuracy)
learn.fit(1, 1e-2)
This class gives you access to several plotting methods from the SHAP interpretation library. Currently `summary_plot`, `dependence_plot`, `waterfall_plot`, `force_plot`, and `decision_plot` are supported.
- `test_data` should be either a Pandas dataframe or a `TabularDataLoader`. If it is not provided, 100 random rows of the training data are used instead.
- `link` can be either "identity" or "logit". It is a generalized linear model link that connects the feature importance values to the model output. Since the feature importance values, phi, sum up to the model output, it often makes sense to connect them to the output with a link function where link(output) = sum(phi). If the model output is a probability, the logit link function gives the feature importance values log-odds units.
- `n_samples` can be either "auto" or an integer. It is the number of times the model is re-evaluated when explaining each prediction. More samples lead to lower-variance estimates of the SHAP values.
- `l1_reg` can be an integer representing the number of features, "auto", "aic", "bic", or a float. It is the l1 regularization used for feature selection (the estimation procedure is based on a debiased lasso). The "auto" option currently uses "aic" when less than 20% of the possible sample space is enumerated; otherwise it uses no regularization.
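The relationship link(output) = sum(phi) can be checked on a toy example. The sketch below is not how SHAP or `ShapInterpretation` computes its values: it assumes a hypothetical three-feature `toy_model`, a single all-zeros baseline row instead of a background dataset, and brute-force enumeration of feature orderings. It only illustrates that the contributions, added to the base value, recover the probability under the identity link and the log-odds under the logit link.

```python
import math
from itertools import permutations


def logit(p):
    """Log-odds of a probability."""
    return math.log(p / (1.0 - p))


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def toy_model(x):
    """Hypothetical binary classifier: returns P(class 1) for a 3-feature row."""
    return sigmoid(0.8 * x[0] - 0.5 * x[1] + 0.3 * x[0] * x[2])


def exact_shap(f, x, baseline):
    """Exact Shapley values: average each feature's marginal contribution
    over every ordering in which features are switched from baseline to x."""
    n = len(x)
    phi = [0.0] * n
    orderings = list(permutations(range(n)))
    for order in orderings:
        current = list(baseline)
        prev = f(current)
        for i in order:
            current[i] = x[i]
            new = f(current)
            phi[i] += new - prev
            prev = new
    return [p / len(orderings) for p in phi]


x = [1.0, 2.0, -1.0]          # illustrative row, not from the adult dataset
baseline = [0.0, 0.0, 0.0]    # assumption: one reference row stands in for the background data

# Identity link: contributions are in probability units.
phi_identity = exact_shap(toy_model, x, baseline)
# Logit link: contributions are in log-odds units.
phi_logit = exact_shap(lambda row: logit(toy_model(row)), x, baseline)

base_p = toy_model(baseline)
print("identity:", base_p + sum(phi_identity), "vs", toy_model(x))
print("logit:   ", logit(base_p) + sum(phi_logit), "vs", logit(toy_model(x)))
```

In both cases the base value plus the contributions matches the (linked) model output up to floating-point error; only the units of the contributions change.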
exp = ShapInterpretation(learn)
exp = ShapInterpretation(learn, df.iloc[:2])
exp = ShapInterpretation(learn, learn.dls.test_dl(df.iloc[:100]))
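To see why a larger `n_samples` gives lower-variance SHAP estimates, here is a rough sketch using plain Monte Carlo sampling over feature orderings. This is a simpler stand-in, not the weighted-regression estimator that SHAP's kernel method uses, and `model`, `sampled_shap`, and `spread` are hypothetical names introduced only for this illustration.

```python
import random
import statistics


def model(x):
    """Hypothetical model with an interaction between features 0 and 1."""
    return x[0] * x[1] + x[2]


def sampled_shap(f, x, baseline, feature, n_samples, rng):
    """Estimate one feature's Shapley value from n_samples random orderings."""
    n = len(x)
    total = 0.0
    for _ in range(n_samples):
        order = list(range(n))
        rng.shuffle(order)
        current = list(baseline)
        # Switch on every feature that precedes `feature` in this ordering.
        for i in order:
            if i == feature:
                break
            current[i] = x[i]
        without = f(current)
        current[feature] = x[feature]
        total += f(current) - without
    return total / n_samples


rng = random.Random(0)
x, baseline = [1.0, 2.0, 3.0], [0.0, 0.0, 0.0]


def spread(n_samples, repeats=200):
    """Standard deviation of the estimate across repeated runs."""
    estimates = [sampled_shap(model, x, baseline, 0, n_samples, rng)
                 for _ in range(repeats)]
    return statistics.pstdev(estimates)


spread_few = spread(10)
spread_many = spread(200)
print("std with 10 samples: ", spread_few)
print("std with 200 samples:", spread_many)
```

The estimate with 200 samples scatters far less around the true value than the one with 10, at the cost of proportionally more model evaluations.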
decision_plot
Visualizes a model's decisions using cumulative SHAP values. Accepts a `class_id`, which indicates the class of interest for a classification model; it can be either an `int` or a `str` representation of the class. Each colored line in the plot represents the model's prediction for a single observation. If no row index is passed, the plot defaults to the first ten samples of the test set. Note: plotting too many samples at once can make the plot illegible. For an up-to-date list of parameters and for more information, see the SHAP documentation for `decision_plot`.
exp.decision_plot(class_id=0, row_idx=10)
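The "cumulative SHAP values" behind each line can be sketched as a running total that starts at the base (expected) value and ends at the model's output for that observation. The numbers below are made up for illustration, and `decision_path` is a hypothetical helper, not part of SHAP or of this class. Accumulating from the least to the most influential feature is also an assumption about the ordering.

```python
from itertools import accumulate


def decision_path(base_value, shap_values):
    """Running totals for one observation: base value first, prediction last."""
    return list(accumulate(shap_values, initial=base_value))


# Hypothetical SHAP values for one row (illustrative numbers only).
base_value = 0.2
contributions = {"age": 0.30, "education-num": -0.10, "fnlwgt": 0.05}

# Assumption: accumulate from the least to the most influential feature.
ordered = sorted(contributions.items(), key=lambda item: abs(item[1]))
path = decision_path(base_value, [value for _, value in ordered])

for name, point in zip(["(base)"] + [name for name, _ in ordered], path):
    print(f"{name:>14}: {point:.2f}")
```

The last point of the path equals the base value plus all contributions, which is the model output being explained.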
dependence_plot
Plots the value of a variable on the x-axis and the SHAP value of the same variable on the y-axis. Accepts a `class_id` and a `variable_name`. `class_id` indicates the class of interest for a classification model; it can be either an `int` or a `str` representation of the class. This plot shows how the model depends on the given variable. Vertical dispersion of the datapoints represents interaction effects. Gray ticks along the y-axis are datapoints where the variable's value was `NaN`. For an up-to-date list of parameters and for more information, see the SHAP documentation for `dependence_plot`.
exp.dependence_plot('age', class_id=0)
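Why interaction effects show up as vertical dispersion can be seen with a two-feature toy model. The sketch below assumes a hypothetical `model` in which `age` interacts with `education-num`, a zero baseline, and exact two-feature Shapley values; none of it comes from the trained learner above.

```python
def model(age, edu):
    """Hypothetical score with an age x education-num interaction term."""
    return 0.1 * age + 0.02 * age * edu


def shap_for_age(age, edu, base_age=0.0, base_edu=0.0):
    """Exact Shapley value of `age` in a two-feature model:
    the average of its contribution when added first and when added second."""
    added_first = model(age, base_edu) - model(base_age, base_edu)
    added_second = model(age, edu) - model(base_age, edu)
    return 0.5 * (added_first + added_second)


# Illustrative rows: the same ages paired with different education-num values.
rows = [(40, 9), (40, 13), (25, 9), (25, 13)]

# Each (x, y) pair is one dot in a dependence plot for `age`.
points = [(age, shap_for_age(age, edu)) for age, edu in rows]
for age, phi in points:
    print(f"age={age}  SHAP(age)={phi:.2f}")
```

Rows with the same age get different SHAP values because the contribution of `age` depends on `education-num`; without the interaction term, all dots for a given age would coincide.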
force_plot
Visualizes the SHAP values with an added force layout. Accepts a `class_id`, which indicates the class of interest for a classification model; it can be either an `int` or a `str` representation of the class. `matplotlib` determines whether the plot is rendered with matplotlib or in JavaScript. For an up-to-date list of parameters, see the SHAP documentation for `force_plot`.
exp.force_plot(class_id=0)
summary_plot
Displays the SHAP values, which can be interpreted as feature importance. For an up-to-date list of parameters, see the SHAP documentation for `summary_plot`.
exp.summary_plot()
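One common way to read SHAP values as feature importance is to rank features by their mean absolute SHAP value across rows. The sketch below uses invented numbers and a hypothetical `rank_features` helper. It illustrates the idea only and is not the plotting code itself.

```python
def rank_features(feature_names, shap_rows):
    """Rank features by mean absolute SHAP value, most important first."""
    n_rows = len(shap_rows)
    scores = [
        sum(abs(row[j]) for row in shap_rows) / n_rows
        for j in range(len(feature_names))
    ]
    return sorted(zip(feature_names, scores), key=lambda pair: pair[1], reverse=True)


# Hypothetical SHAP values: one list per row, one entry per feature.
feature_names = ["age", "fnlwgt", "education-num"]
shap_rows = [
    [0.30, -0.02, 0.10],
    [-0.25, 0.01, 0.20],
    [0.10, 0.03, -0.15],
]

ranking = rank_features(feature_names, shap_rows)
for name, score in ranking:
    print(f"{name:>14}: {score:.3f}")
```

Taking absolute values matters: a feature that pushes some predictions up and others down is still influential, even though its signed contributions may average out near zero.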
waterfall_plot
Plots an explanation of a single prediction as a waterfall plot. Accepts a `row_idx` and a `class_id`. `row_idx` is the index of the row in `test_data` to analyze, and defaults to zero. `class_id` indicates the class of interest for a classification model; it can be either an `int` or a `str` representation of the class. For an up-to-date list of parameters, see the SHAP documentation for `waterfall_plot`.
exp.waterfall_plot(row_idx=10)