Skip to content

Commit

Permalink
Update
Browse files (browse the repository at this point in the history)
  • Loading branch information
zyliang2001 committed Mar 8, 2024
1 parent a5259fb commit f8b56b9
Show file tree
Hide file tree
Showing 10 changed files with 1,940 additions and 493 deletions.
524 changes: 524 additions & 0 deletions feature_importance/01_run_importance_local_ablation_classification.py

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up
@@ -117,28 +117,34 @@ def compare_estimators(estimators: List[ModelConfig],
'splitting_strategy': splitting_strategy
}
start = time.time()
local_fi_score = fi_est.cls(X_test, y_test, copy.deepcopy(est), **fi_est.kwargs)
if fi_est.name in ["MDI_local_sub_stumps_evaluate", "MDI_local_all_stumps_evaluate", "LFI_absolute_sum_evaluate",
"MDI_local_sub_stumps_evaluate_without_raw", "MDI_local_all_stumps_evaluate_without_raw",
"LFI_absolute_sum_evaluate_without_raw"]:
local_fi_score = fi_est.cls(X_train, y_train, X_test, y_test, copy.deepcopy(est), **fi_est.kwargs)
else:
local_fi_score = fi_est.cls(X_test, y_test, copy.deepcopy(est), **fi_est.kwargs)
end = time.time()
metric_results['fi_time'] = end - start
feature_importance_list.append(local_fi_score)
support_df = pd.DataFrame({"var": np.arange(len(support)),
"true_support": support,
"cor_with_signal": x_cor})
metric_results['fi_scores'] = support_df

start = time.time()
if np.max(support) != np.min(support):
y_pred = est.predict(X_test)
metric_results['MSE_before_ablation'] = mean_squared_error(y_test, y_pred)
imp_vals = copy.deepcopy(local_fi_score)
imp_vals[imp_vals == float("-inf")] = -sys.maxsize - 1
imp_vals[imp_vals == float("inf")] = sys.maxsize - 1
for i in range(X_test.shape[1]):
if fi_est.ascending:
ablation_X_test = ablation(X_test, imp_vals, "max", i+1, seed)
else:
ablation_X_test = ablation(X_test, imp_vals, "min", i+1, seed)
metric_results[f'MSE_after_ablation_{i+1}'] = mean_squared_error(y_test, est.predict(ablation_X_test))
y_pred = est.predict(X_test)
metric_results['MSE_before_ablation'] = mean_squared_error(y_test, y_pred)
imp_vals = copy.deepcopy(local_fi_score)
imp_vals[imp_vals == float("-inf")] = -sys.maxsize - 1
imp_vals[imp_vals == float("inf")] = sys.maxsize - 1
for i in range(X_test.shape[1]):
if fi_est.ascending:
ablation_X_test = ablation(X_test, imp_vals, "max", i+1, seed)
else:
ablation_X_test = ablation(X_test, imp_vals, "min", i+1, seed)
metric_results[f'MSE_after_ablation_{i+1}'] = mean_squared_error(y_test, est.predict(ablation_X_test))
end = time.time()

metric_results['ablation_time'] = end - start
metric_results['test_size'] = X_test.shape[0]
print(f"data_size: {X_test.shape[0]}, fi: {fi_est.name}, done with time: {end - start}")
Expand Down
254 changes: 234 additions & 20 deletions feature_importance/ablation_demo.ipynb

Large diffs are not rendered by default.

Loading

0 comments on commit f8b56b9

Please sign in to comment.