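# Module docstring added for readability; the summary is inferred from the
# code below rather than taken from upstream documentation.
"""Configuration module.

Config merges parameters from internal YAML defaults, external config
files, an in-code config dict, and the command line, and exposes the
merged result through dict-style and attribute-style access.
"""
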
import os
import re
import sys
from enum import Enum
from logging import getLogger

import yaml

from evaluator import metric_types, smaller_metrics


class Config(object):
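    """Configuration container.

    Merges parameters from four sources with increasing priority:
    internal YAML defaults < config files < config dict < command line.
    """
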
def __init__(
self, model=None, dataset=None, config_file_list=None, config_dict=None
):
self.yaml_loader = self._build_yaml_loader()
self.file_config_dict = self._load_config_files(config_file_list)
self.variable_config_dict = self._load_variable_config_dict(config_dict)
self.cmd_config_dict = self._load_cmd_line()
self._merge_external_config_dict()
self.model, self.dataset = self._get_model_and_dataset(model, dataset)
self._load_internal_config_dict(self.model, self.dataset)
self.final_config_dict = self._get_final_config_dict()
self._set_default_parameters()
self._init_device()

    def _build_yaml_loader(self):
        # PyYAML's default resolver does not recognize scientific notation
        # without a decimal point (e.g. "1e-3") as a float; this implicit
        # resolver is added so such values load as floats, not strings.
        loader = yaml.FullLoader
loader.add_implicit_resolver(
"tag:yaml.org,2002:float",
re.compile(
"""^(?:
[-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
|[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
|\\.[0-9_]+(?:[eE][-+][0-9]+)?
|[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
|[-+]?\\.(?:inf|Inf|INF)
|\\.(?:nan|NaN|NAN))$""",
re.X,
),
list("-+0123456789."),
)
return loader

    def _convert_config_dict(self, config_dict):
        r"""Convert string parameters back to their original types."""
for key in config_dict:
param = config_dict[key]
if not isinstance(param, str):
continue
try:
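                # eval() restores Python literals (lists, dicts, numbers,
                # booleans); anything that evaluates to an unexpected type
                # is kept as the raw string instead.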
value = eval(param)
if value is not None and not isinstance(
value, (str, int, float, list, tuple, dict, bool, Enum)
):
value = param
except (NameError, SyntaxError, TypeError):
if isinstance(param, str):
if param.lower() == "true":
value = True
elif param.lower() == "false":
value = False
else:
value = param
else:
value = param
config_dict[key] = value
return config_dict

    def _load_config_files(self, file_list):
        r"""Load and merge the given YAML config files in order."""
file_config_dict = dict()
if file_list:
for file in file_list:
with open(file, "r", encoding="utf-8") as f:
file_config_dict.update(
yaml.load(f.read(), Loader=self.yaml_loader)
)
return file_config_dict

    def _load_variable_config_dict(self, config_dict):
        # HyperTuning may pass parameters such as mlp_hidden_size in NeuMF
        # in the form ['[]', '[]']; config_dict then receives the string
        # '[]' where the list [] is meant. _convert_config_dict is used as
        # a temporary workaround for this.
return self._convert_config_dict(config_dict) if config_dict else dict()

    def _load_cmd_line(self):
        r"""Read parameters from the command line and convert them to their original types."""
cmd_config_dict = dict()
unrecognized_args = []
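        # Only arguments of the form "--key=value" are recognized; anything
        # else is collected and reported as a warning below.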
if "ipykernel_launcher" not in sys.argv[0]:
for arg in sys.argv[1:]:
if not arg.startswith("--") or len(arg[2:].split("=")) != 2:
unrecognized_args.append(arg)
continue
cmd_arg_name, cmd_arg_value = arg[2:].split("=")
if (
cmd_arg_name in cmd_config_dict
and cmd_arg_value != cmd_config_dict[cmd_arg_name]
):
                raise SyntaxError(
                    "Duplicate command line arg '%s' with a different value."
                    % arg
                )
else:
cmd_config_dict[cmd_arg_name] = cmd_arg_value
if len(unrecognized_args) > 0:
logger = getLogger()
logger.warning(
"command line args [{}] will not be used in RecBole".format(
" ".join(unrecognized_args)
)
)
cmd_config_dict = self._convert_config_dict(cmd_config_dict)
return cmd_config_dict

    def _merge_external_config_dict(self):
        # Later updates win: command line overrides the config dict, which
        # overrides config files.
external_config_dict = dict()
external_config_dict.update(self.file_config_dict)
external_config_dict.update(self.variable_config_dict)
external_config_dict.update(self.cmd_config_dict)
self.external_config_dict = external_config_dict

    def _get_model_and_dataset(self, model, dataset):
if model is None:
try:
final_model = self.external_config_dict["model"]
except KeyError:
                raise KeyError(
                    "model needs to be specified in at least one of these ways: "
                    "[model variable, config file, config dict, command line]"
                )
else:
final_model = model
if dataset is None:
try:
final_dataset = self.external_config_dict["dataset"]
except KeyError:
                raise KeyError(
                    "dataset needs to be specified in at least one of these ways: "
                    "[dataset variable, config file, config dict, command line]"
                )
else:
final_dataset = dataset
return final_model, final_dataset

    def _update_internal_config_dict(self, file):
        r"""Load a single YAML file into the internal config dict and return its contents."""
with open(file, "r", encoding="utf-8") as f:
config_dict = yaml.load(f.read(), Loader=self.yaml_loader)
if config_dict is not None:
self.internal_config_dict.update(config_dict)
return config_dict

    def _load_internal_config_dict(self, model, dataset):
        current_path = os.path.dirname(os.path.realpath(__file__))
        overall_init_file = os.path.join(current_path, "config", "overall.yaml")
        model_init_file = os.path.join(current_path, "config", model + ".yaml")
        dataset_init_file = os.path.join(current_path, "config", dataset + ".yaml")
self.internal_config_dict = dict()
for file in [
overall_init_file,
model_init_file,
dataset_init_file,
]:
if os.path.isfile(file):
self._update_internal_config_dict(file)

    def _get_final_config_dict(self):
        # External (user-supplied) settings override internal defaults.
final_config_dict = dict()
final_config_dict.update(self.internal_config_dict)
final_config_dict.update(self.external_config_dict)
return final_config_dict

    def _set_default_parameters(self):
self.final_config_dict["dataset"] = self.dataset
self.final_config_dict["model"] = self.model
self.final_config_dict["data_path"] = os.path.join(
self.final_config_dict["data_path"], self.dataset
)
metrics = self.final_config_dict["metrics"]
if isinstance(metrics, str):
self.final_config_dict["metrics"] = [metrics]
eval_type = set()
for metric in self.final_config_dict["metrics"]:
if metric.lower() in metric_types:
eval_type.add(metric_types[metric.lower()])
else:
raise NotImplementedError(f"There is no metric named '{metric}'")
if len(eval_type) > 1:
raise RuntimeError(
"Ranking metrics and value metrics can not be used at the same time."
)
self.final_config_dict["eval_type"] = eval_type.pop()
valid_metric = self.final_config_dict["valid_metric"].split("@")[0]
self.final_config_dict["valid_metric_bigger"] = (
False if valid_metric.lower() in smaller_metrics else True
)
topk = self.final_config_dict["topk"]
if isinstance(topk, (int, list)):
if isinstance(topk, int):
topk = [topk]
for k in topk:
if k <= 0:
                    raise ValueError(
                        f"topk must be a positive integer or a list of positive integers, but got `{k}`"
                    )
self.final_config_dict["topk"] = topk
else:
raise TypeError(f"The topk [{topk}] must be a integer, list")
# train_neg_sample_args checking
if "train_neg_sample_num" not in self.final_config_dict:
self.final_config_dict["train_neg_sample_num"] = 1
# eval_candidate_num checking
if "eval_candidate_num" not in self.final_config_dict:
self.final_config_dict["eval_candidate_num"] = 50

    def _init_device(self):
self.final_config_dict["gpu_id"] = str(self.final_config_dict["gpu_id"])
gpu_id = self.final_config_dict["gpu_id"]
os.environ["CUDA_VISIBLE_DEVICES"] = gpu_id
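        # torch is imported only after CUDA_VISIBLE_DEVICES is set, so the
        # device restriction is already in place when CUDA is initialized.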
import torch
self.final_config_dict["device"] = (
torch.device("cpu")
if len(gpu_id) == 0 or not torch.cuda.is_available()
else torch.device("cuda")
)

    def change_direction(self):
        r"""Swap the user and item ID fields, reversing the recommendation direction."""
self.final_config_dict['USER_ID_FIELD'], self.final_config_dict['ITEM_ID_FIELD'] = \
self.final_config_dict['ITEM_ID_FIELD'], self.final_config_dict['USER_ID_FIELD']

    def __setitem__(self, key, value):
        if not isinstance(key, str):
            raise TypeError("key must be a str.")
        self.final_config_dict[key] = value

    def __getattr__(self, item):
        if "final_config_dict" not in self.__dict__:
            raise AttributeError(
                "'Config' object has no attribute 'final_config_dict'"
            )
        if item in self.final_config_dict:
            return self.final_config_dict[item]
        raise AttributeError(f"'Config' object has no attribute '{item}'")

    def __getitem__(self, item):
        if item in self.final_config_dict:
            return self.final_config_dict[item]
        else:
            return None

    def __contains__(self, key):
        if not isinstance(key, str):
            raise TypeError("key must be a str.")
        return key in self.final_config_dict

    def __str__(self):
        args_info = "\n\t".join(
            ["Parameters:"]
            + [f"{arg}={value}" for arg, value in self.final_config_dict.items()]
        )
        return args_info

    def __repr__(self):
        return self.__str__()
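

# A minimal usage sketch, not part of the original module. It assumes the
# internal ./config/overall.yaml defaults (data_path, metrics, valid_metric,
# topk, gpu_id, ...) exist, and the names "BPR" and "ml-100k" are purely
# illustrative placeholders for a model and a dataset.
if __name__ == "__main__":
    config = Config(
        model="BPR",
        dataset="ml-100k",
        config_dict={"topk": [10, 20]},
    )
    print(config["topk"])    # dict-style access -> [10, 20]
    print(config.model)      # attribute-style access -> 'BPR'
    print("topk" in config)  # membership test -> True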