-
Notifications
You must be signed in to change notification settings - Fork 6
/
evaluate_gen.py
executable file
·64 lines (59 loc) · 2.85 KB
/
evaluate_gen.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
from sklearn.decomposition import PCA
import numpy as np
import torch
import math
import gin
import os
import json
def write_text(result_dict, file):
    """Serialize ``result_dict`` as JSON into ``file``, replacing any previous contents.

    Args:
        result_dict: JSON-serializable mapping to store.
        file: Destination path; created if absent, truncated if present.
    """
    # Write-only mode: the file is never read back here, so "w+" was unnecessary.
    # Explicit UTF-8 avoids depending on the platform's locale encoding.
    with open(file, "w", encoding="utf-8") as handle:
        json.dump(result_dict, handle)
def eval_func(eval_dataset, metric_folder, it, preflix="", *,
              representation=None, beta_vae_score=True, dci_score=True,
              mig_score=True, factor_vae_score=True):
    """Run disentanglement metrics on a precomputed representation and save them.

    The representation function handed to every metric is a pure lookup:
    ``_representation(x)`` returns ``rep[x]``. NOTE(review): this assumes the
    observations produced by ``eval_dataset`` are valid indices into that
    array -- confirm against the dataset implementation.

    Args:
        eval_dataset: Ground-truth dataset passed unchanged to each metric.
        metric_folder: Directory the results JSON is written to (created if
            missing).
        it: Identifier (e.g. training iteration) used as the JSON file stem,
            i.e. ``<metric_folder>/<it>.json``.
        preflix: Suffix appended to every key of the results dict
            (e.g. ``"beta_VAE" + preflix``). Name kept for compatibility.
        representation: Precomputed codes, either as an array-like or as a
            path to a ``.npy`` file readable by ``np.load``. Required.
        beta_vae_score, dci_score, mig_score, factor_vae_score: Enable or
            disable the corresponding metric; all enabled by default.

    Returns:
        dict mapping ``<metric name><preflix>`` to the dict returned by that
        metric's ``compute_*`` function. The same dict is written as JSON.

    Raises:
        ValueError: If ``representation`` is not provided.
    """
    # The original called np.load() with no file, which always raised
    # TypeError; the source of the codes must now be supplied explicitly.
    if representation is None:
        raise ValueError(
            "eval_func requires `representation`: an array of codes or a "
            "path to a .npy file")
    if isinstance(representation, (str, os.PathLike)):
        pca_rep = np.load(representation)
    else:
        pca_rep = np.asarray(representation)

    total_results_dict = {}

    def _representation(x):
        # Pure lookup of the precomputed codes for the given indices.
        return pca_rep[x]

    # Each metric section: import the (project-local) metric lazily, override
    # its gin parameters, then run it with its own freshly seeded
    # RandomState(0). gin.unlock_config() permits bind_parameter even if the
    # gin config has already been locked.
    if beta_vae_score:
        from evaluation.metrics.beta_vae import compute_beta_vae_sklearn
        with gin.unlock_config():
            gin.bind_parameter("beta_vae_sklearn.batch_size", 64)
            gin.bind_parameter("beta_vae_sklearn.num_train", 10000)
            gin.bind_parameter("beta_vae_sklearn.num_eval", 5000)
        result_dict = compute_beta_vae_sklearn(
            eval_dataset, _representation,
            random_state=np.random.RandomState(0), artifact_dir=None)
        print("beta VAE score:" + str(result_dict))
        total_results_dict["beta_VAE" + preflix] = result_dict

    if dci_score:
        from evaluation.metrics.dci import compute_dci
        with gin.unlock_config():
            gin.bind_parameter("dci.num_train", 10000)
            gin.bind_parameter("dci.num_test", 5000)
        result_dict = compute_dci(
            eval_dataset, _representation,
            random_state=np.random.RandomState(0), artifact_dir=None)
        print("dci score:" + str(result_dict))
        total_results_dict["dci" + preflix] = result_dict

    if mig_score:
        from evaluation.metrics.mig import compute_mig
        from evaluation.metrics.utils import _histogram_discretize
        with gin.unlock_config():
            gin.bind_parameter("mig.num_train", 10000)
            gin.bind_parameter("discretizer.discretizer_fn", _histogram_discretize)
            gin.bind_parameter("discretizer.num_bins", 20)
        result_dict = compute_mig(
            eval_dataset, _representation,
            random_state=np.random.RandomState(0), artifact_dir=None)
        print("MIG score:" + str(result_dict))
        total_results_dict["MIG" + preflix] = result_dict

    if factor_vae_score:
        from evaluation.metrics.factor_vae import compute_factor_vae
        with gin.unlock_config():
            gin.bind_parameter("factor_vae_score.num_variance_estimate", 10000)
            gin.bind_parameter("factor_vae_score.num_train", 10000)
            gin.bind_parameter("factor_vae_score.num_eval", 5000)
            gin.bind_parameter("factor_vae_score.batch_size", 64)
            gin.bind_parameter("prune_dims.threshold", 0.05)
        result_dict = compute_factor_vae(
            eval_dataset, _representation,
            random_state=np.random.RandomState(0), artifact_dir=None)
        print("factor VAE score:" + str(result_dict))
        total_results_dict["factor_VAE" + preflix] = result_dict

    # Ensure the output folder exists so the (expensive) results above are
    # not lost to a FileNotFoundError at the very end.
    if metric_folder:
        os.makedirs(metric_folder, exist_ok=True)
    write_text(total_results_dict, os.path.join(metric_folder, f"{it}.json"))
    return total_results_dict