forked from h2oai/driverlessai-recipes
-
Notifications
You must be signed in to change notification settings - Fork 1
/
morris_sensitivity_explainer.py
209 lines (185 loc) · 8.78 KB
/
morris_sensitivity_explainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
"""Morris Sensitivity Analysis Explainer"""
from functools import partial
import datatable as dt
import numpy as np
import pandas as pd
from h2oaicore.mli.oss.byor.core.explainers import (
CustomDaiExplainer,
CustomExplainer,
)
from h2oaicore.mli.oss.byor.core.explanations import GlobalFeatImpExplanation
from h2oaicore.mli.oss.byor.core.representations import (
GlobalFeatImpJSonDatatableFormat,
GlobalFeatImpJSonFormat,
)
from h2oaicore.mli.oss.byor.explainer_utils import clean_dataset
# Explainer MUST extend abstract CustomExplainer class to be discovered and
# deployed. In addition it inherits common metadata and (default) functionality. The
# explainer must implement fit() and explain() methods.
#
# Explainer CAN extend CustomDaiExplainer class if it will run on Driverless AI server
# and use experiments. CustomDaiExplainer class provides easy access/handle to the
# dataset and model (metadata and artifacts), filesystem, ... and common logic.
class MorrisSensitivityLeExplainer(CustomExplainer, CustomDaiExplainer):
    """InterpretML: Morris sensitivity (https://github.com/interpretml/interpret)"""

    # explainer display name (used e.g. in UI explainer listing)
    _display_name = "Morris Sensitivity Analysis"
    # NOTE: the strings below are concatenated implicitly — each segment must
    # end with a space (fixed: "value. This", "is based on", "InterpretML library")
    _description = (
        "Morris sensitivity analysis explainer provides Morris SA based feature "
        "importance which is a measure of the contribution of an input variable "
        "to the overall predictions of the Driverless AI model. In applied "
        "statistics, the Morris method for global sensitivity analysis is a so-called "
        "one-step-at-a-time method (OAT), meaning that in each run only one "
        "input parameter is given a new value. "
        "This Morris sensitivity analysis explainer is based on the InterpretML "
        "library (http://interpret.ml)."
    )
    # declaration of supported experiments: regression / binary / multiclass
    _regression = True
    _binary = True
    # declaration of provided explanations: global, local or both
    _global_explanation = True
    # declaration of explanation types this explainer creates e.g. feature importance
    _explanation_types = [GlobalFeatImpExplanation]
    # Python package dependencies (can be installed using pip)
    _modules_needed_by_name = ["interpret==0.3.2"]

    # explainer constructor must not have any required parameters
    def __init__(self):
        """Initialize both parent classes and the fields computed by explain()."""
        CustomExplainer.__init__(self)
        CustomDaiExplainer.__init__(self)
        # list of categorical columns detected by clean_dataset(); set in explain()
        self.cat_variables = None
        # fitted multi-column label encoder (or None); set in explain()
        self.mcle = None

    # setup() method is used to initialize the explainer based on provided parameters
    # which are passed from client/UI. See parent classes setup() methods docstrings
    # and source to check the list of instance fields which are initialized for the
    # explainer
    def setup(self, model, persistence, key=None, params=None, **e_params):
        """Initialize the explainer with the model, persistence layer and params."""
        CustomExplainer.setup(self, model, persistence, key, params, **e_params)
        CustomDaiExplainer.setup(self, **e_params)

    # abstract fit() method must be implemented - its purpose is to pre-compute
    # any artifacts e.g. surrogate models, to be used by explain() method
    def fit(self, X: dt.Frame, y: dt.Frame = None, **kwargs):
        """No-op: Morris SA needs no pre-computed artifacts."""
        return self

    # explain() method is responsible for the creation of the explanations
    def explain(
        self, X, y=None, explanations_types: list = None, **kwargs
    ) -> list:
        """Compute Morris sensitivity feature importance for dataset ``X``.

        Returns a list with one GlobalFeatImpExplanation (normalized to the
        Grammar of MLI format for the Driverless AI UI).
        """
        # 3rd party Morris SA library import
        from interpret.blackbox import MorrisSensitivity

        # DATASET: categorical features encoding (for 3rd party libraries which
        # support numeric features only), rows w/ missing values filtering, ...
        X = X[:, self.used_features] if self.used_features else X
        x, self.cat_variables, self.mcle, _ = clean_dataset(
            frame=X.to_pandas(),
            le_map_file=self.persistence.get_explainer_working_file("mcle"),
            logger=self.logger,
        )

        # PREDICT FUNCTION: Driverless AI scorer -> library compliant predict function
        def predict_function(
            pred_fn, col_names, cat_variables, label_encoder, X
        ):
            X = pd.DataFrame(X, columns=col_names)
            # categorical features inverse label encoding used in case of 3rd party
            # libraries which support numeric only
            if label_encoder:
                X[cat_variables] = X[cat_variables].astype(np.int64)
                label_encoder.inverse_transform(X)
            # score
            preds = pred_fn(X)
            # scoring output conversion to the format expected by 3rd party library
            if isinstance(preds, pd.core.frame.DataFrame):
                preds = preds.to_numpy()
            if preds.ndim == 2:
                preds = preds.flatten()
            return preds

        predict_fn = partial(
            predict_function,
            self.model.predict_method,
            self.used_features,
            self.cat_variables,
            self.mcle,
        )

        # CALCULATION of the Morris SA explanation
        sensitivity: MorrisSensitivity = MorrisSensitivity(
            model=predict_fn, data=x, feature_names=list(x.columns)
        )
        morris_explanation = sensitivity.explain_global(name=self.display_name)

        # NORMALIZATION of proprietary Morris SA library data to explanation w/
        # Grammar of MLI format for the visualization in Driverless AI UI
        explanations = [self._normalize_to_gom(morris_explanation)]

        # explainer MUST return declared explanation(s) (_explanation_types)
        return explanations

    #
    # optional NORMALIZATION to Grammar of MLI
    #
    """
    explainer_morris_sensitivity_explainer_..._MorrisSensitivityExplainer_<UUID>
    ├── global_feature_importance
    │   ├── application_json
    │   │   ├── explanation.json
    │   │   └── feature_importance_class_0.json
    │   └── application_vnd_h2oai_json_datatable_jay
    │       ├── explanation.json
    │       └── feature_importance_class_0.jay
    ├── log
    │   ├── explainer_job.log
    │   └── logger.lock
    └── work
    """

    # Normalization of the data to the Grammar of MLI defined format. Normalized data
    # can be visualized using Grammar of MLI UI components in Driverless AI web UI.
    #
    # This method creates explanation (data) and its representations (JSon, datatable)
    def _normalize_to_gom(self, morris_explanation) -> GlobalFeatImpExplanation:
        """Convert an InterpretML global explanation to a GoM feature-importance
        explanation with JSon+datatable and JSon representations."""
        # EXPLANATION
        explanation = GlobalFeatImpExplanation(
            explainer=self,
            # display name of explanation's tile in UI
            display_name=self.display_name,
            # tab name where to put explanation's tile in UI
            display_category=GlobalFeatImpExplanation.DISPLAY_CAT_CUSTOM,
        )

        # FORMAT: explanation representation as JSon+datatable (JSon index file which
        # references datatable frame for each class)
        jdf = GlobalFeatImpJSonDatatableFormat
        # data normalization: 3rd party frame to Grammar of MLI defined frame
        # conversion - see GlobalFeatImpJSonDatatableFormat docstring for format
        # documentation and source for helpers to create the representation easily
        explanation_frame = dt.Frame(
            {
                jdf.COL_NAME: morris_explanation.data()["names"],
                jdf.COL_IMPORTANCE: list(morris_explanation.data()["scores"]),
                jdf.COL_GLOBAL_SCOPE: [True]
                * len(morris_explanation.data()["scores"]),
            }
        ).sort(-dt.f[jdf.COL_IMPORTANCE])
        # index file (of per-class data files)
        (
            idx_dict,
            idx_str,
        ) = GlobalFeatImpJSonDatatableFormat.serialize_index_file(
            ["global"],
            doc=MorrisSensitivityLeExplainer._description,
        )
        json_dt_format = GlobalFeatImpJSonDatatableFormat(explanation, idx_str)
        json_dt_format.update_index_file(
            idx_dict, total_rows=explanation_frame.shape[0]
        )
        # data file
        json_dt_format.add_data_frame(
            format_data=explanation_frame,
            file_name=idx_dict[jdf.KEY_FILES]["global"],
        )
        # JSon+datatable format can be added as explanation's representation
        explanation.add_format(json_dt_format)

        # FORMAT: explanation representation as JSon
        #
        # Having JSon+datatable formats it's easy to get other formats like CSV,
        # datatable, ZIP, ... using helpers - adding JSon representation:
        explanation.add_format(
            explanation_format=GlobalFeatImpJSonFormat.from_json_datatable(
                json_dt_format
            )
        )
        return explanation