Skip to content

Commit

Permalink
update dgp
Browse files Browse the repository at this point in the history
  • Loading branch information
zyliang2001 committed Jan 14, 2024
1 parent cd3baac commit 730298f
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@
sys.path.append("../..")
from feature_importance.scripts.simulations_util import *


X_DGP = sample_real_X
### Update start for local MDI+
X_DGP = sample_normal_X
X_PARAMS_DICT = {
"fpath": "/mnt/d/local_MDI+/imodels-experiments/data/X_splicing_cleaned.csv",
"sample_row_n": None,
"sample_col_n": None
"n": 1200,
"d": 50,
"mean": 0,
"scale": 1
}
### Update start for local MDI+
Y_DGP = linear_model_two_groups
Y_PARAMS_DICT = {
"beta": 1,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
]

FI_ESTIMATORS = [
# [FIModelConfig('MDI_all_stumps', MDI_local_all_stumps, model_type='tree')],
[FIModelConfig('MDI_all_stumps', MDI_local_all_stumps, model_type='tree')],
[FIModelConfig('MDI_sub_stumps', MDI_local_sub_stumps, model_type='tree')],
[FIModelConfig('TreeSHAP', tree_shap_local, model_type='tree')],
[FIModelConfig('Permutation', permutation_local, model_type='tree')],
Expand Down
7 changes: 5 additions & 2 deletions feature_importance/scripts/competing_methods_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def MDI_local_sub_stumps(X, y, fit, scoring_fns="auto", return_stability_scores=
rf_plus_model.fit(X, y)

try:
mdi_plus_scores = rf_plus_model.get_mdi_plus_scores(X=X, y=y, local_scoring_fns=mean_squared_error)
mdi_plus_scores = rf_plus_model.get_mdi_plus_scores(X=X, y=y, local_scoring_fns=mean_squared_error, version = "zach")
if return_stability_scores:
stability_scores = rf_plus_model.get_mdi_plus_stability_scores(B=25)
except ValueError as e:
Expand All @@ -123,6 +123,7 @@ def MDI_local_sub_stumps(X, y, fit, scoring_fns="auto", return_stability_scores=
# if return_stability_scores:
# mdi_plus_scores = pd.concat([mdi_plus_scores, stability_scores], axis=1)
result = mdi_plus_scores["local"]
print(result)
# Convert the array to a DataFrame
result_table = pd.DataFrame(result, columns=[f'Feature_{i}' for i in range(num_features)])

Expand All @@ -146,6 +147,7 @@ def MDI_local_all_stumps(X, y, fit, scoring_fns="auto", return_stability_scores=
Var: variable name
Importance: MDI+ score
"""
num_samples, num_features = X.shape

if isinstance(fit, RegressorMixin):
RFPlus = RandomForestPlusRegressor
Expand All @@ -157,7 +159,7 @@ def MDI_local_all_stumps(X, y, fit, scoring_fns="auto", return_stability_scores=
rf_plus_model = RFPlus(rf_model=fit, **kwargs)
rf_plus_model.fit(X, y)
try:
mdi_plus_scores = rf_plus_model.get_mdi_plus_scores(X=X, y=y, local_scoring_fns=mean_squared_error)
mdi_plus_scores = rf_plus_model.get_mdi_plus_scores(X=X, y=y, local_scoring_fns=mean_squared_error, version = "tiffany")
if return_stability_scores:
stability_scores = rf_plus_model.get_mdi_plus_stability_scores(B=25)
except ValueError as e:
Expand All @@ -174,6 +176,7 @@ def MDI_local_all_stumps(X, y, fit, scoring_fns="auto", return_stability_scores=
# if return_stability_scores:
# mdi_plus_scores = pd.concat([mdi_plus_scores, stability_scores], axis=1)
result = mdi_plus_scores["local"]
print(result)
# Convert the array to a DataFrame
result_table = pd.DataFrame(result, columns=[f'Feature_{i}' for i in range(num_features)])

Expand Down

0 comments on commit 730298f

Please sign in to comment.