Skip to content

Commit

Permalink
Simplify the code base; saved fitted rf_plus; add other FI
Browse files Browse the repository at this point in the history
  • Loading branch information
zyliang2001 committed May 22, 2024
1 parent a7070d2 commit 3430b61
Show file tree
Hide file tree
Showing 29 changed files with 106,811 additions and 12,392 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,5 @@ experiments/*.txt
_site

temp.ipynb
**.png
data
4 changes: 2 additions & 2 deletions feature_importance/01_ablation_classification_script.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
#SBATCH --partition=yugroup

source activate mdi
command="01_run_ablation_classification.py --nreps 1 --config mdi_local.real_data_classification --split_seed 1 --ignore_cache --create_rmd --result_name diabetes_classification"

command="01_run_ablation_classification.py --nreps 1 --config mdi_local.real_data_classification --split_seed ${1} --ignore_cache --create_rmd --result_name diabetes_simplify"
# command="01_run_ablation_classification.py --nreps 1 --config mdi_local.real_data_classification --split_seed ${1} --ignore_cache --create_rmd --result_name Enhancer --ablate_features 20"
# Execute the command
python $command
3 changes: 2 additions & 1 deletion feature_importance/01_ablation_regression_script.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
#SBATCH --partition=yugroup

source activate mdi
command="01_run_ablation_regression.py --nreps 1 --config mdi_local.real_data_regression --split_seed ${1} --ignore_cache --create_rmd --result_name diabetes_regr_new"
command="01_run_ablation_regression.py --nreps 1 --config mdi_local.real_data_regression --split_seed ${1} --ignore_cache --create_rmd --result_name diabetes_test_new"
#command="01_run_ablation_regression.py --nreps 1 --config mdi_local.real_data_regression --split_seed ${1} --ignore_cache --create_rmd --result_name CCLE_AZD0530_new --ablate_features 20"

# Execute the command
python $command
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

slurm_script="01_ablation_classification_script.sh"

for rep in {1..5}
for rep in {1..2}
do
sbatch $slurm_script $rep # Submit SLURM job using the specified script
done
8 changes: 8 additions & 0 deletions feature_importance/01_ablation_script_regr.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#!/bin/bash
#
# Submit the regression ablation experiment to SLURM, one job per repetition.
#
# Each iteration calls `sbatch` with the job script below and the repetition
# index as the job script's first positional argument.
#
# The job script path is relative, so run this from the directory that
# contains it.
#
# Exit status: 0 if every submission succeeded, 1 if at least one failed.

slurm_script="01_ablation_regression_script.sh"

# Remember whether any submission failed; keep submitting the remaining
# repetitions rather than aborting on the first failure.
failed=0

for rep in {1..2}; do
  # Quote both expansions so the arguments are never word-split or globbed.
  if ! sbatch "$slurm_script" "$rep"; then
    printf 'sbatch submission failed for repetition %s\n' "$rep" >&2
    failed=1
  fi
done

exit "$failed"
440 changes: 235 additions & 205 deletions feature_importance/01_run_ablation_classification.py

Large diffs are not rendered by default.

387 changes: 213 additions & 174 deletions feature_importance/01_run_ablation_regression.py

Large diffs are not rendered by default.

Binary file removed feature_importance/diabetes_classification_test.png
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file removed feature_importance/diabetes_regression_test.png
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file removed feature_importance/diabetes_regression_train.png
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,36 @@
"data_name": "diabetes",
"sample_row_n": None
}
# X_PARAMS_DICT = {
# "source": "imodels",
# "data_name": "juvenile",
# "sample_row_n": None
# }

# X_PARAMS_DICT = {
# "source": "csv",
# "file_path": "/accounts/projects/binyu/zhongyuan_liang/local_MDI+/imodels-experiments/feature_importance/data/Enhancer/X_enhancer_cleaned.csv",
# "sample_row_n": 2000,
# "normalize": False
# }

Y_DGP = sample_real_data_y
Y_PARAMS_DICT = {
"source": "imodels",
"data_name": "diabetes"
}
# Y_PARAMS_DICT = {
# "source": "imodels",
# "data_name": "juvenile"
# }

# Y_PARAMS_DICT = {
# "source": "csv",
# "file_path": "/accounts/projects/binyu/zhongyuan_liang/local_MDI+/imodels-experiments/feature_importance/data/Enhancer/y_enhancer.csv",
# "sample_row_n": 2000
# }


# vary one parameter
VARY_PARAM_NAME = "sample_row_n"
VARY_PARAM_NAME = "normalize"
VARY_PARAM_VALS = {"keep_all_rows": None}
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@
]

FI_ESTIMATORS = [
[FIModelConfig('TreeSHAP_RF', tree_shap_evaluation_RF, model_type='tree', splitting_strategy = "train-test")],
[FIModelConfig('LFI_fit_on_inbag_RF', LFI_evaluation_RF_MDI_classification, model_type='tree', splitting_strategy = "train-test", ascending = False, other_params={"include_raw":False, "fit_on":"inbag", "prediction_model": Ridge(alpha=1e-6)})],
[FIModelConfig('LFI_fit_on_OOB_RF', LFI_evaluation_RF_OOB, model_type='tree', splitting_strategy = "train-test", ascending = False, other_params={"fit_on":"oob"})],
[FIModelConfig('LFI_evaluate_on_all_RF_plus', LFI_evaluation_RF_plus, model_type='tree', splitting_strategy = "train-test", ascending = False)],
[FIModelConfig('LFI_evaluate_on_oob_RF_plus', LFI_evaluation_RF_plus_OOB, model_type='tree', splitting_strategy = "train-test", ascending = False)],
[FIModelConfig('Kernel_SHAP_RF_plus', kernel_shap_evaluation_RF_plus, model_type='tree', splitting_strategy = "train-test")],
[FIModelConfig('LIME_RF_plus', lime_evaluation_RF_plus, model_type='tree', splitting_strategy = "train-test")],
[FIModelConfig('TreeSHAP_RF', tree_shap_evaluation_RF, model_type='tree', base_model="RF", splitting_strategy = "train-test")],
[FIModelConfig('LFI_fit_on_inbag_RFPlus', LFI_evaluation_RFPlus_inbag, model_type='tree', base_model="RFPlus_inbag", splitting_strategy = "train-test", ascending = False)],
[FIModelConfig('LFI_fit_on_OOB_RFPlus', LFI_evaluation_RFPlus_oob, model_type='tree', base_model="RFPlus_oob", splitting_strategy = "train-test", ascending = False)],
[FIModelConfig('LFI_fit_on_all_evaluate_on_all_RFPlus', LFI_evaluation_RFPlus_all, model_type='tree', base_model="RFPlus_default", splitting_strategy = "train-test", ascending = False)],
[FIModelConfig('LFI_fit_on_all_evaluate_on_oob_RFPlus', LFI_evaluation_RFPlus_oob, model_type='tree', base_model="RFPlus_default", splitting_strategy = "train-test", ascending = False)],
[FIModelConfig('Kernel_SHAP_RF_plus', kernel_shap_evaluation_RF_plus, model_type='tree', base_model="RFPlus_default", splitting_strategy = "train-test")],
[FIModelConfig('LIME_RF_plus', lime_evaluation_RF_plus, model_type='tree', base_model="RFPlus_default", splitting_strategy = "train-test")],
[FIModelConfig('Random', None, model_type='tree', base_model="None", splitting_strategy = "train-test")],
[FIModelConfig('Oracle_test_RFPlus', LFI_evaluation_oracle_RF_plus, base_model="RFPlus_default", model_type='tree', splitting_strategy = "train-test", ascending = False)],
]
14 changes: 12 additions & 2 deletions feature_importance/fi_config/mdi_local/real_data_regression/dgp.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@
# }
# X_PARAMS_DICT = {
# "source": "openml",
# "task_id": 359946,
# "task_id": 361236,
# "sample_row_n": None
# }
# X_PARAMS_DICT = {
# "source": "csv",
# "file_path": "/accounts/projects/binyu/zhongyuan_liang/local_MDI+/imodels-experiments/feature_importance/data/CCLE/X_ccle_rnaseq_cleaned_filtered5000.csv",
# "sample_row_n": None
# }

Expand All @@ -31,7 +36,12 @@
# }
# Y_PARAMS_DICT = {
# "source": "openml",
# "task_id": 359946
# "task_id": 361236
# }

# Y_PARAMS_DICT = {
# "source": "csv",
# "file_path": "/accounts/projects/binyu/zhongyuan_liang/local_MDI+/imodels-experiments/feature_importance/data/CCLE/y_ccle_rnaseq_AZD0530.csv",
# }

# vary one parameter
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@
]

FI_ESTIMATORS = [
[FIModelConfig('TreeSHAP_RF', tree_shap_evaluation_RF, model_type='tree', splitting_strategy = "train-test")],
[FIModelConfig('LFI_fit_on_inbag_RF', LFI_evaluation_RF_MDI, model_type='tree', splitting_strategy = "train-test", ascending = False, other_params={"include_raw":False, "fit_on":"inbag", "prediction_model": Ridge(alpha=1e-6)})],
[FIModelConfig('LFI_fit_on_OOB_RF', LFI_evaluation_RF_OOB, model_type='tree', splitting_strategy = "train-test", ascending = False, other_params={"fit_on":"oob"})],
[FIModelConfig('LFI_evaluate_on_all_RF_plus', LFI_evaluation_RF_plus, model_type='tree', splitting_strategy = "train-test", ascending = False)],
[FIModelConfig('LFI_evaluate_on_oob_RF_plus', LFI_evaluation_RF_plus_OOB, model_type='tree', splitting_strategy = "train-test", ascending = False)],
[FIModelConfig('Kernel_SHAP_RF_plus', kernel_shap_evaluation_RF_plus, model_type='tree', splitting_strategy = "train-test")],
[FIModelConfig('LIME_RF_plus', lime_evaluation_RF_plus, model_type='tree', splitting_strategy = "train-test")],
[FIModelConfig('TreeSHAP_RF', tree_shap_evaluation_RF, model_type='tree', base_model="RF", splitting_strategy = "train-test")],
[FIModelConfig('LFI_fit_on_inbag_RFPlus', LFI_evaluation_RFPlus_inbag, model_type='tree', base_model="RFPlus_inbag", splitting_strategy = "train-test", ascending = False)],
[FIModelConfig('LFI_fit_on_OOB_RFPlus', LFI_evaluation_RFPlus_oob, model_type='tree', base_model="RFPlus_oob", splitting_strategy = "train-test", ascending = False)],
[FIModelConfig('LFI_fit_on_all_evaluate_on_all_RFPlus', LFI_evaluation_RFPlus_all, model_type='tree', base_model="RFPlus_default", splitting_strategy = "train-test", ascending = False)],
[FIModelConfig('LFI_fit_on_all_evaluate_on_oob_RFPlus', LFI_evaluation_RFPlus_oob, model_type='tree', base_model="RFPlus_default", splitting_strategy = "train-test", ascending = False)],
[FIModelConfig('Kernel_SHAP_RF_plus', kernel_shap_evaluation_RF_plus, model_type='tree', base_model="RFPlus_default", splitting_strategy = "train-test")],
[FIModelConfig('LIME_RF_plus', lime_evaluation_RF_plus, model_type='tree', base_model="RFPlus_default", splitting_strategy = "train-test")],
[FIModelConfig('Random', None, model_type='tree', base_model="None", splitting_strategy = "train-test")],
[FIModelConfig('Oracle_test_RFPlus', LFI_evaluation_oracle_RF_plus, base_model="RFPlus_default", model_type='tree', splitting_strategy = "train-test", ascending = False)],
]
Loading

0 comments on commit 3430b61

Please sign in to comment.