Skip to content

Commit

Permalink
Add demo notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
zyliang2001 committed Feb 23, 2024
1 parent 7caf5c4 commit 4d74f40
Showing 1 changed file with 197 additions and 0 deletions.
197 changes: 197 additions & 0 deletions feature_importance/ablation_demo.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import copy\n",
"import os\n",
"from os.path import join as oj\n",
"import glob\n",
"import argparse\n",
"import pickle as pkl\n",
"import time\n",
"import warnings\n",
"from scipy import stats\n",
"import dask\n",
"from dask.distributed import Client\n",
"import numpy as np\n",
"import pandas as pd\n",
"from tqdm import tqdm\n",
"import sys\n",
"from collections import defaultdict\n",
"from typing import Callable, List, Tuple\n",
"import itertools\n",
"from sklearn.metrics import roc_auc_score, f1_score, recall_score, precision_score, mean_squared_error\n",
"\n",
"sys.path.append(\".\")\n",
"sys.path.append(\"..\")\n",
"sys.path.append(\"../../imodels/\")\n",
"\n",
"warnings.filterwarnings(\"ignore\", message=\"Bins whose width\")\n",
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier\n",
"from sklearn.metrics import r2_score, mean_absolute_error, accuracy_score, roc_auc_score, mean_squared_error\n",
"\n",
"from imodels.importance import RandomForestPlusRegressor, RandomForestPlusClassifier, \\\n",
" RidgeRegressorPPM, LassoRegressorPPM, IdentityTransformer\n",
"from imodels.importance.rf_plus import _fast_r2_score\n",
"import seaborn as sns\n",
"from util import ModelConfig, FIModelConfig, tp, fp, neg, pos, specificity_score, auroc_score, auprc_score, compute_nsg_feat_corr_w_sig_subspace, apply_splitting_strategy\n",
"import shap\n",
"import sklearn"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def sample_normal_X_subgroups(n, d, mean, scale):\n",
" \"\"\"\n",
" :param n: Number of samples\n",
" :param d: Number of features\n",
" :param mean: Nested list of mean of normal distribution for each subgroup\n",
" :param scale: Nested ist of scale of normal distribution for each subgroup\n",
" :return:\n",
" \"\"\"\n",
" assert len(mean[0]) == len(scale[0]) == d\n",
" num_groups = len(mean)\n",
" result = []\n",
" group_size = n // num_groups\n",
" for i in range(num_groups):\n",
" X = np.zeros((group_size, d))\n",
" for j in range(d):\n",
" X[:, j] = np.random.normal(mean[i][j], scale[i][j], size=group_size)\n",
" result.append(X)\n",
" return np.vstack(result)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def generate_random_shuffle(data, seed):\n",
" \"\"\"\n",
" Randomly shuffle each column of the data.\n",
" \"\"\"\n",
" np.random.seed(seed)\n",
" return np.array([np.random.permutation(data[:, i]) for i in range(data.shape[1])]).T\n",
"\n",
"\n",
"def ablation(data, feature_importance, mode, num_features, seed):\n",
" \"\"\"\n",
" Replace the top num_features max feature importance data with random shuffle for each sample\n",
" \"\"\"\n",
" assert mode in [\"max\", \"min\"]\n",
" fi = feature_importance.to_numpy()\n",
" shuffle = generate_random_shuffle(data, seed)\n",
" if mode == \"max\":\n",
" indices = np.argsort(-fi)\n",
" else:\n",
" indices = np.argsort(fi)\n",
" data_copy = data.copy()\n",
" for i in range(data.shape[0]):\n",
" for j in range(num_features):\n",
" data_copy[i, indices[i,j]] = shuffle[i, indices[i,j]]\n",
" return data_copy\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"ename": "ValueError",
"evalue": "operands could not be broadcast together with shapes (66,13) (134,13) ",
"output_type": "error",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mValueError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[1;32mIn[5], line 26\u001b[0m\n\u001b[0;32m 23\u001b[0m metric_results[\u001b[39m'\u001b[39m\u001b[39mMSE_before_ablation\u001b[39m\u001b[39m'\u001b[39m] \u001b[39m=\u001b[39m mean_squared_error(y_test, y_pred)\n\u001b[0;32m 25\u001b[0m \u001b[39m# Ablation\u001b[39;00m\n\u001b[1;32m---> 26\u001b[0m score \u001b[39m=\u001b[39m rf_plus_model\u001b[39m.\u001b[39;49mget_mdi_plus_scores(X_test, y_test, lfi\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m, lfi_abs \u001b[39m=\u001b[39;49m \u001b[39m\"\u001b[39;49m\u001b[39moutside\u001b[39;49m\u001b[39m\"\u001b[39;49m, sample_split\u001b[39m=\u001b[39;49m\u001b[39mNone\u001b[39;49;00m)\n\u001b[0;32m 27\u001b[0m local_fi_score \u001b[39m=\u001b[39m score[\u001b[39m\"\u001b[39m\u001b[39mlfi\u001b[39m\u001b[39m\"\u001b[39m]\n\u001b[0;32m 28\u001b[0m ascending \u001b[39m=\u001b[39m \u001b[39mTrue\u001b[39;00m \u001b[39m# False for MDI\u001b[39;00m\n",
"File \u001b[1;32md:\\local_MDI+\\imodels-experiments\\feature_importance\\../../imodels\\imodels\\importance\\rf_plus.py:379\u001b[0m, in \u001b[0;36m_RandomForestPlus.get_mdi_plus_scores\u001b[1;34m(self, X, y, scoring_fns, local_scoring_fns, sample_split, mode, version, lfi, lfi_abs)\u001b[0m\n\u001b[0;32m 367\u001b[0m mdi_plus_obj \u001b[39m=\u001b[39m ForestMDIPlus(estimators\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mestimators_,\n\u001b[0;32m 368\u001b[0m transformers\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mtransformers_,\n\u001b[0;32m 369\u001b[0m scoring_fns\u001b[39m=\u001b[39mscoring_fns,\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 376\u001b[0m normalize\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mnormalize,\n\u001b[0;32m 377\u001b[0m version\u001b[39m=\u001b[39mversion)\n\u001b[0;32m 378\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmdi_plus_ \u001b[39m=\u001b[39m mdi_plus_obj\n\u001b[1;32m--> 379\u001b[0m mdi_plus_scores \u001b[39m=\u001b[39m mdi_plus_obj\u001b[39m.\u001b[39;49mget_scores(X_array, y, lfi\u001b[39m=\u001b[39;49mlfi,\n\u001b[0;32m 380\u001b[0m lfi_abs\u001b[39m=\u001b[39;49mlfi_abs)\n\u001b[0;32m 381\u001b[0m \u001b[39mif\u001b[39;00m lfi \u001b[39mand\u001b[39;00m local_scoring_fns:\n\u001b[0;32m 382\u001b[0m mdi_plus_lfi \u001b[39m=\u001b[39m mdi_plus_scores[\u001b[39m\"\u001b[39m\u001b[39mlfi\u001b[39m\u001b[39m\"\u001b[39m]\n",
"File \u001b[1;32md:\\local_MDI+\\imodels-experiments\\feature_importance\\../../imodels\\imodels\\importance\\mdi_plus.py:126\u001b[0m, in \u001b[0;36mForestMDIPlus.get_scores\u001b[1;34m(self, X, y, lfi, lfi_abs)\u001b[0m\n\u001b[0;32m 124\u001b[0m \u001b[39m# print(\"IN 'get_scores' METHOD WITHIN THE FOREST MDI PLUS OBJECT\")\u001b[39;00m\n\u001b[0;32m 125\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mlfi_abs \u001b[39m=\u001b[39m lfi_abs\n\u001b[1;32m--> 126\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_fit_importance_scores(X, y)\n\u001b[0;32m 127\u001b[0m \u001b[39mif\u001b[39;00m lfi:\n\u001b[0;32m 128\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mlocal_scoring_fns:\n",
"File \u001b[1;32md:\\local_MDI+\\imodels-experiments\\feature_importance\\../../imodels\\imodels\\importance\\mdi_plus.py:223\u001b[0m, in \u001b[0;36mForestMDIPlus._fit_importance_scores\u001b[1;34m(self, X, y)\u001b[0m\n\u001b[0;32m 208\u001b[0m \u001b[39mfor\u001b[39;00m estimator, transformer, tree_random_state \u001b[39min\u001b[39;00m \\\n\u001b[0;32m 209\u001b[0m \u001b[39mzip\u001b[39m(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mestimators, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mtransformers, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mtree_random_states):\n\u001b[0;32m 210\u001b[0m tree_mdi_plus \u001b[39m=\u001b[39m TreeMDIPlus(estimator\u001b[39m=\u001b[39mestimator,\n\u001b[0;32m 211\u001b[0m transformer\u001b[39m=\u001b[39mtransformer,\n\u001b[0;32m 212\u001b[0m scoring_fns\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mscoring_fns,\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 221\u001b[0m num_iters\u001b[39m=\u001b[39mnum_iters,\n\u001b[0;32m 222\u001b[0m lfi_abs\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mlfi_abs)\n\u001b[1;32m--> 223\u001b[0m scores \u001b[39m=\u001b[39m tree_mdi_plus\u001b[39m.\u001b[39;49mget_scores(X, y)\n\u001b[0;32m 224\u001b[0m lfi_matrix_lst\u001b[39m.\u001b[39mappend(tree_mdi_plus\u001b[39m.\u001b[39mlfi_matrix)\n\u001b[0;32m 225\u001b[0m \u001b[39mif\u001b[39;00m scores \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n",
"File \u001b[1;32md:\\local_MDI+\\imodels-experiments\\feature_importance\\../../imodels\\imodels\\importance\\mdi_plus.py:383\u001b[0m, in \u001b[0;36mTreeMDIPlus.get_scores\u001b[1;34m(self, X, y)\u001b[0m\n\u001b[0;32m 366\u001b[0m \u001b[39m\u001b[39m\u001b[39m\"\"\"\u001b[39;00m\n\u001b[0;32m 367\u001b[0m \u001b[39mObtain the MDI+ feature importances for a single tree.\u001b[39;00m\n\u001b[0;32m 368\u001b[0m \n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 380\u001b[0m \u001b[39m The MDI+ feature importances.\u001b[39;00m\n\u001b[0;32m 381\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[0;32m 382\u001b[0m \u001b[39m# print(\"IN 'get_scores' METHOD WITHIN THE TREE MDI PLUS OBJECT\")\u001b[39;00m\n\u001b[1;32m--> 383\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_fit_importance_scores(X, y)\n\u001b[0;32m 384\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mlocal_scoring_fns:\n\u001b[0;32m 385\u001b[0m \u001b[39mreturn\u001b[39;00m {\u001b[39m\"\u001b[39m\u001b[39mglobal\u001b[39m\u001b[39m\"\u001b[39m: \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mfeature_importances_,\n\u001b[0;32m 386\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mlocal\u001b[39m\u001b[39m\"\u001b[39m: \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mfeature_importances_local_}\n",
"File \u001b[1;32md:\\local_MDI+\\imodels-experiments\\feature_importance\\../../imodels\\imodels\\importance\\mdi_plus.py:436\u001b[0m, in \u001b[0;36mTreeMDIPlus._fit_importance_scores\u001b[1;34m(self, X, y)\u001b[0m\n\u001b[0;32m 432\u001b[0m \u001b[39mif\u001b[39;00m train_blocked_data\u001b[39m.\u001b[39mget_all_data()\u001b[39m.\u001b[39mshape[\u001b[39m1\u001b[39m] \u001b[39m!=\u001b[39m \u001b[39m0\u001b[39m:\n\u001b[0;32m 433\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mhasattr\u001b[39m(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mestimator, \u001b[39m\"\u001b[39m\u001b[39mpredict_full\u001b[39m\u001b[39m\"\u001b[39m) \u001b[39mand\u001b[39;00m \\\n\u001b[0;32m 434\u001b[0m \u001b[39mhasattr\u001b[39m(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mestimator, \u001b[39m\"\u001b[39m\u001b[39mpredict_partial\u001b[39m\u001b[39m\"\u001b[39m):\n\u001b[0;32m 435\u001b[0m \u001b[39m# print(\"IN IF STATEMENT IN LINE 389\")\u001b[39;00m\n\u001b[1;32m--> 436\u001b[0m full_preds \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mestimator\u001b[39m.\u001b[39;49mpredict_full(test_blocked_data)\n\u001b[0;32m 437\u001b[0m \u001b[39m# print(\"IN STATEMENT LINE 391\")\u001b[39;00m\n\u001b[0;32m 438\u001b[0m partial_preds \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mestimator\u001b[39m.\u001b[39mpredict_partial(test_blocked_data, mode\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmode, zero_values\u001b[39m=\u001b[39mzero_values)\n",
"File \u001b[1;32md:\\local_MDI+\\imodels-experiments\\feature_importance\\../../imodels\\imodels\\importance\\ppms.py:308\u001b[0m, in \u001b[0;36m_GlmPPM.predict_full\u001b[1;34m(self, blocked_data)\u001b[0m\n\u001b[0;32m 306\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mpredict_full\u001b[39m(\u001b[39mself\u001b[39m, blocked_data):\n\u001b[0;32m 307\u001b[0m \u001b[39m# print(\"IN 'predict_full' method of _GlmPPM\")\u001b[39;00m\n\u001b[1;32m--> 308\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mpredict_loo(blocked_data\u001b[39m.\u001b[39;49mget_all_data())\n",
"File \u001b[1;32md:\\local_MDI+\\imodels-experiments\\feature_importance\\../../imodels\\imodels\\importance\\ppms.py:296\u001b[0m, in \u001b[0;36m_GlmPPM.predict_loo\u001b[1;34m(self, X)\u001b[0m\n\u001b[0;32m 294\u001b[0m \u001b[39mfor\u001b[39;00m j \u001b[39min\u001b[39;00m \u001b[39mrange\u001b[39m(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_n_outputs):\n\u001b[0;32m 295\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mloo:\n\u001b[1;32m--> 296\u001b[0m preds_j \u001b[39m=\u001b[39m _get_preds(X, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mloo_coefficients_[j], \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49minv_link_fn)\n\u001b[0;32m 297\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m 298\u001b[0m preds_j \u001b[39m=\u001b[39m _get_preds(X, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mcoefficients_[j], \u001b[39mself\u001b[39m\u001b[39m.\u001b[39minv_link_fn)\n",
"File \u001b[1;32md:\\local_MDI+\\imodels-experiments\\feature_importance\\../../imodels\\imodels\\importance\\ppms.py:626\u001b[0m, in \u001b[0;36m_get_preds\u001b[1;34m(data_block, coefs, inv_link_fn, intercept)\u001b[0m\n\u001b[0;32m 624\u001b[0m intercept \u001b[39m=\u001b[39m coefs[:, \u001b[39m-\u001b[39m\u001b[39m1\u001b[39m]\n\u001b[0;32m 625\u001b[0m coefs \u001b[39m=\u001b[39m coefs[:, :\u001b[39m-\u001b[39m\u001b[39m1\u001b[39m]\n\u001b[1;32m--> 626\u001b[0m lin_preds \u001b[39m=\u001b[39m np\u001b[39m.\u001b[39msum(data_block \u001b[39m*\u001b[39;49m coefs, axis\u001b[39m=\u001b[39m\u001b[39m1\u001b[39m) \u001b[39m+\u001b[39m intercept\n\u001b[0;32m 627\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m 628\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mlen\u001b[39m(coefs) \u001b[39m==\u001b[39m (data_block\u001b[39m.\u001b[39mshape[\u001b[39m1\u001b[39m] \u001b[39m+\u001b[39m \u001b[39m1\u001b[39m):\n",
"\u001b[1;31mValueError\u001b[0m: operands could not be broadcast together with shapes (66,13) (134,13) "
]
}
],
"source": [
"# Define the data\n",
"n = 200\n",
"d = 10\n",
"mean = [[0]*10, [10]*10]\n",
"scale = [[1]*10,[1]*10]\n",
"s = 5\n",
"X = sample_normal_X_subgroups(n, d, mean, scale)\n",
"beta = np.concatenate((np.ones(s), np.zeros(d-s)))\n",
"y = np.matmul(X, beta)\n",
"split_seed = 0\n",
"X_train, X_tune, X_test, y_train, y_tune, y_test = apply_splitting_strategy(X, y, \"train-test\", split_seed)\n",
"\n",
"#Define the model and fit\n",
"rf_regressor = RandomForestRegressor(n_estimators=100, min_samples_leaf=5, max_features=0.33, random_state=331)\n",
"rf_plus_model = RandomForestPlusRegressor(rf_model=rf_regressor, include_raw=False)\n",
"rf_plus_model.fit(X_train, y_train)\n",
"\n",
"\n",
"# initialize the metric results\n",
"metric_results = {}\n",
"\n",
"y_pred = rf_plus_model.predict(X_test)\n",
"metric_results['MSE_before_ablation'] = mean_squared_error(y_test, y_pred)\n",
"\n",
"# Ablation\n",
"score = rf_plus_model.get_mdi_plus_scores(X_test, y_test, lfi=True, lfi_abs = \"outside\")\n",
"local_fi_score = score[\"lfi\"]\n",
"ascending = True # False for MDI\n",
"imp_vals = copy.deepcopy(local_fi_score)\n",
"imp_vals[imp_vals == float(\"-inf\")] = -sys.maxsize - 1\n",
"imp_vals[imp_vals == float(\"inf\")] = sys.maxsize - 1\n",
"seed = np.random.randint(0, 100000)\n",
"for i in range(X_test.shape[1]):\n",
" if ascending:\n",
" ablation_X_test = ablation(X_test, imp_vals, \"max\", i+1, seed)\n",
" else:\n",
" ablation_X_test = ablation(X_test, imp_vals, \"min\", i+1, seed)\n",
" metric_results[f'MSE_after_ablation_{i+1}'] = mean_squared_error(y_test, est.predict(ablation_X_test))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}

0 comments on commit 4d74f40

Please sign in to comment.