-
Notifications
You must be signed in to change notification settings - Fork 61
/
testall.sh
executable file
·64 lines (51 loc) · 1.66 KB
/
testall.sh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
#!/bin/bash
# Conda environment names; the MODELS table below maps every model to one
# of these, and the training loop activates it before running the model.
SKLEARN_ENV='sklearn'
GBDT_ENV='gbdt'
TORCH_ENV='torch'
KERAS_ENV='tensorflow'

# Per-run budget forwarded to train.py as --n_trials and --epochs.
N_TRIALS='2'
EPOCHS='3'
# "LinearModel" "KNN" "DecisionTree" "RandomForest"
# "XGBoost" "CatBoost" "LightGBM"
# "MLP" "TabNet" "VIME"
# MODELS=( "LinearModel" "KNN" "DecisionTree" "RandomForest" "XGBoost" "CatBoost" "LightGBM" "MLP" "TabNet" "VIME")
# Lookup table: model name (as passed to train.py --model_name) -> name of the
# conda environment that is activated before that model is trained.
# Entries are inserted one by one, in the same order as before.
declare -A MODELS=()
MODELS[LinearModel]="$SKLEARN_ENV"
MODELS[KNN]="$SKLEARN_ENV"
# MODELS[SVM]="$SKLEARN_ENV"    # currently disabled
MODELS[DecisionTree]="$SKLEARN_ENV"
MODELS[RandomForest]="$SKLEARN_ENV"
MODELS[XGBoost]="$GBDT_ENV"
MODELS[CatBoost]="$GBDT_ENV"
MODELS[LightGBM]="$GBDT_ENV"
MODELS[MLP]="$TORCH_ENV"
MODELS[TabNet]="$TORCH_ENV"
MODELS[VIME]="$TORCH_ENV"
MODELS[TabTransformer]="$TORCH_ENV"
MODELS[ModelTree]="$GBDT_ENV"
MODELS[NODE]="$TORCH_ENV"
MODELS[DeepGBM]="$TORCH_ENV"
MODELS[RLN]="$KERAS_ENV"
MODELS[DNFNet]="$KERAS_ENV"
MODELS[STG]="$TORCH_ENV"
MODELS[NAM]="$TORCH_ENV"
MODELS[DeepFM]="$TORCH_ENV"
MODELS[SAINT]="$TORCH_ENV"
MODELS[DANet]="$TORCH_ENV"
# Configuration files passed to train.py through --config, one sweep over all
# models per entry, in the order listed here.
CONFIGS=(
  'config/adult.yml'
  'config/covertype.yml'
  'config/california_housing.yml'
  'config/higgs.yml'
)
# conda init bash
#######################################
# Train every model in MODELS on every config in CONFIGS, each run inside the
# conda environment that MODELS maps the model to.
# Globals:   MODELS, CONFIGS, N_TRIALS, EPOCHS (all read)
# Arguments: none
# Outputs:   a banner per run on stdout; diagnostics on stderr
# Returns:   0 if every run succeeded, 1 if the conda hook could not be
#            loaded or at least one run failed or was skipped
#######################################
run_all_trainings() {
  local hook config model env_name
  local failures=0

  # `conda activate` only works once conda's shell integration is loaded.
  # Capture the hook first so a missing/broken conda is detected instead of
  # being swallowed by eval.
  if ! hook="$(conda shell.bash hook)"; then
    printf 'error: could not load the conda shell hook; is conda on PATH?\n' >&2
    return 1
  fi
  eval "$hook"

  for config in "${CONFIGS[@]}"; do
    for model in "${!MODELS[@]}"; do
      env_name="${MODELS[$model]}"
      printf "\n\n----------------------------------------------------------------------------\n"
      printf 'Training %s with %s in env %s\n\n' "$model" "$config" "$env_name"

      # Never train in whatever environment happens to be active: if the
      # model's own environment cannot be activated, skip this run. Nothing
      # was activated, so there is nothing to deactivate either.
      if ! conda activate "$env_name"; then
        printf 'error: cannot activate conda env %s; skipping %s on %s\n' \
          "$env_name" "$model" "$config" >&2
        failures=$((failures + 1))
        continue
      fi

      # One failing model must not abort the sweep; record it and carry on.
      if ! python train.py --config "$config" --model_name "$model" \
        --n_trials "$N_TRIALS" --epochs "$EPOCHS"; then
        printf 'error: training failed for %s on %s\n' "$model" "$config" >&2
        failures=$((failures + 1))
      fi

      conda deactivate
    done
  done

  if ((failures > 0)); then
    printf '%d run(s) failed or were skipped\n' "$failures" >&2
    return 1
  fi
  return 0
}

# Last command of the script, so its status becomes the script's exit status.
run_all_trainings