Skip to content

Commit

Permalink
multiexperiments code.
Browse files Browse the repository at this point in the history
  • Loading branch information
juselara1 committed Feb 3, 2021
1 parent 22745a3 commit d63699a
Show file tree
Hide file tree
Showing 8 changed files with 60 additions and 43 deletions.
2 changes: 1 addition & 1 deletion examples/scripts/replication/hyperparams/dmae_params.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
"use_cov": false,
"n_clusters": 10,
"alpha": 100,
"dissimilarity": "kullback_leibler"
"dissimilarity": "euclidean"
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
"iters": 100,
"use_cov": false,
"n_clusters": 10,
"epochs": 2,
"epochs": 500,
"batch_size": 256
}
2 changes: 1 addition & 1 deletion examples/scripts/replication/hyperparams/train_params.json
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
{
"epochs": 2,
"epochs": 300,
"batch_size": 256
}
19 changes: 11 additions & 8 deletions examples/scripts/replication/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,31 +4,34 @@
class Logger:
#TODOC
def __init__(
self, models, metrics,
datasets, dataset_name,
self,
dataset_name,
evaluable_datasets = [
"pretrain", "clustering"
],
metric_names = ["uACC", "NMI", "ARS"],
save_path="./results/"
):

self.__models = models
self.__metrics = metrics
self.__datasets = datasets
self.__dataset_name = dataset_name
self.__evaluable_datasets = evaluable_datasets
self.__metric_names = metric_names
self.__save_path = save_path
self.__make_dataframes()

def add(self, models, metrics, datasets):
self.__models = models
self.__metrics = metrics
self.__datasets = datasets

def __make_dataframes(self):
#TODOC
self.__dfs = {}
for dataset in self.__evaluable_datasets:
metric_names = tuple(self.__metrics.keys())
self.__dfs[dataset] = _pd.DataFrame(
columns=[
"iteration",
*metric_names
*self.__metric_names
]
)
self.__dfs[dataset].set_index("iteration", inplace=True)
Expand All @@ -44,7 +47,7 @@ def __call__(
dataset.reset_gen()
dataset = dataset()

y_true = self.__datasets["labels"]
y_true = self.__datasets["labels"].astype("int32")
preds = model.predict(
dataset,
steps=self.__datasets["steps"]
Expand Down
42 changes: 23 additions & 19 deletions examples/scripts/replication/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,9 @@ def make_parser():
parser.add_argument("--iters", type=int, default=5)
return parser

if __name__ == '__main__':
parser = make_parser()
args = parser.parse_args()
arguments = utils.json_arguments(args)

def make_iteration(arguments, scorer, iteration):
# make models
models = models.make_models(
cur_models = models.make_models(
arguments["encoder_params"],
arguments["decoder_params"],
arguments["dmae_params"],
Expand All @@ -32,34 +28,42 @@ def make_parser():
)

# make datasets
datasets = datasets.make_datasets(
cur_datasets = datasets.make_datasets(
**arguments["dataset_params"]
)

# make metrics
metrics = metrics.make_metrics()
cur_metrics = metrics.make_metrics()

# logger
scorer = logger.Logger(
models, metrics,
datasets, "mnist"
)

scorer.add(cur_models, cur_metrics, cur_datasets)

# Pretrain
arguments["pretrain_params"]\
["steps_per_epoch"] = datasets["steps"]
["steps_per_epoch"] = cur_datasets["steps"]
arguments["train_params"]\
["steps_per_epoch"] = datasets["steps"]
["steps_per_epoch"] = cur_datasets["steps"]

train.pretrain(
models, datasets,
cur_models, cur_datasets,
arguments["pretrain_params"],
scorer, iteration=1,
scorer, iteration=iteration,
dissimilarity=arguments["dmae_params"]["dissimilarity"]
)
train.train(
models, datasets,
cur_models, cur_datasets,
arguments["train_params"],
scorer, iteration=1
scorer, iteration=iteration
)

if __name__ == '__main__':
parser = make_parser()
args = parser.parse_args()
arguments = utils.json_arguments(args)
scorer = logger.Logger(
arguments["dataset_params"]["dataset_name"]
)

for iteration in range(args.iters):
make_iteration(arguments, scorer, iteration)
scorer.save()
7 changes: 4 additions & 3 deletions examples/scripts/replication/results/mnist/clustering.csv
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
iteration,uACC,NMI,ARS
1,0.19112007783882784,0.1706133870001527,-0.17376299376086046
mean,0.19112007783882784,0.1706133870001527,-0.17376299376086046
std,0.0,0.0,0.0
0,0.3249198717948718,0.2974064119317825,0.16677801889403499
1,0.27691449175824173,0.25669355378788655,0.12515683728665603
mean,0.3009171817765568,0.2770499828598345,0.14596742809034552
std,0.024002690018315037,0.020356429071947985,0.02081059080368948
7 changes: 4 additions & 3 deletions examples/scripts/replication/results/mnist/pretrain.csv
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
iteration,uACC,NMI,ARS
1,0.34450835622710624,0.3124251856194935,0.16642473909353817
mean,0.34450835622710624,0.3124251856194935,0.16642473909353817
std,0.0,0.0,0.0
0,0.37102220695970695,0.3299324743100253,0.2076916216553968
1,0.32822516025641024,0.26351401699037746,0.13510127504054154
mean,0.3496236836080586,0.29672324565020136,0.1713964483479692
std,0.021398523351648352,0.033209228659823925,0.03629517330742764
22 changes: 15 additions & 7 deletions requirements_docker_gpu.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
absl-py==0.11.0
astunparse==1.6.3
cachetools==4.2.0
cachetools==4.2.1
certifi==2020.12.5
chardet==4.0.0
cycler==0.10.0
flatbuffers==1.12
gast==0.3.3
google-auth==1.24.0
Expand All @@ -11,17 +12,24 @@ google-pasta==0.2.0
grpcio==1.32.0
h5py==2.10.0
idna==2.10
importlib-metadata==3.4.0
joblib==1.0.0
Keras-Preprocessing==1.1.2
kiwisolver==1.3.1
Markdown==3.3.3
numpy==1.19.5
pandas==1.1.5
matplotlib==3.3.4
numpy==1.19.5
oauthlib==3.1.0
opt-einsum==3.3.0
pandas==1.1.5
Pillow==8.1.0
protobuf==3.14.0
pyasn1==0.4.8
pyasn1-modules==0.2.8
pyparsing==2.4.7
python-dateutil==2.8.1
python-ternary==1.0.7
pytz==2021.1
requests==2.25.1
requests-oauthlib==1.3.0
rsa==4.7
Expand All @@ -30,14 +38,14 @@ scipy==1.5.4
six==1.15.0
tensorboard==2.4.1
tensorboard-plugin-wit==1.8.0
tensorflow-gpu==2.4.1
tensorflow-addons==0.12.0
tensorflow-addons==0.12.1
tensorflow-estimator==2.4.0
tensorflow-gpu==2.4.1
termcolor==1.1.0
threadpoolctl==2.1.0
typeguard==2.10.0
typing-extensions==3.7.4.3
urllib3==1.26.2
urllib3==1.26.3
Werkzeug==1.0.1
wrapt==1.12.1
python-ternary==1.0.7
zipp==3.4.0

0 comments on commit d63699a

Please sign in to comment.