-
Notifications
You must be signed in to change notification settings - Fork 287
/
config.py
277 lines (239 loc) · 11.4 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
from math import ceil
from typing import Optional
import logging
from argparse import ArgumentParser
import sys
import os
class Config:
    """Run-time settings and model hyper-parameters for a code2vec run.

    An instance starts with neutral placeholder values. Depending on the
    constructor flags it is then populated with the built-in defaults
    (`set_defaults`), with command-line options (`load_from_args`) and finally
    sanity-checked (`verify`).

    Iterating over an instance yields `(name, value)` pairs of its
    non-callable attributes and properties (see `__iter__`).
    """

    @classmethod
    def arguments_parser(cls) -> ArgumentParser:
        """Build the command-line parser whose results `load_from_args` reads.

        Returns:
            A fresh `ArgumentParser` with all supported options registered.
        """
        parser = ArgumentParser()
        parser.add_argument("-d", "--data", dest="data_path",
                            help="path to preprocessed dataset", required=False)
        parser.add_argument("-te", "--test", dest="test_path",
                            help="path to test file", metavar="FILE", required=False, default='')
        parser.add_argument("-s", "--save", dest="save_path",
                            help="path to save the model file", metavar="FILE", required=False)
        parser.add_argument("-w2v", "--save_word2v", dest="save_w2v",
                            help="path to save the tokens embeddings file", metavar="FILE", required=False)
        parser.add_argument("-t2v", "--save_target2v", dest="save_t2v",
                            help="path to save the targets embeddings file", metavar="FILE", required=False)
        parser.add_argument("-l", "--load", dest="load_path",
                            help="path to load the model from", metavar="FILE", required=False)
        # NOTE: the next two options write to the same `dest` as
        # `-w2v` / `-t2v` above; they are kept as alternative spellings.
        parser.add_argument('--save_w2v', dest='save_w2v', required=False,
                            help="save word (token) vectors in word2vec format")
        parser.add_argument('--save_t2v', dest='save_t2v', required=False,
                            help="save target vectors in word2vec format")
        parser.add_argument('--export_code_vectors', action='store_true', required=False,
                            help="export code vectors for the given examples")
        parser.add_argument('--release', action='store_true',
                            help='if specified and loading a trained model, release the loaded model for a lower model '
                                 'size.')
        parser.add_argument('--predict', action='store_true',
                            help='execute the interactive prediction shell')
        parser.add_argument("-fw", "--framework", dest="dl_framework", choices=['keras', 'tensorflow'],
                            default='tensorflow', help="deep learning framework to use.")
        parser.add_argument("-v", "--verbose", dest="verbose_mode", type=int, required=False, default=1,
                            help="verbose mode (should be in {0,1,2}).")
        parser.add_argument("-lp", "--logs-path", dest="logs_path", metavar="FILE", required=False,
                            help="path to store logs into. if not given logs are not saved to file.")
        parser.add_argument('-tb', '--tensorboard', dest='use_tensorboard', action='store_true',
                            help='use tensorboard during training')
        return parser

    def set_defaults(self):
        """Assign the built-in default training settings and hyper-parameters."""
        self.NUM_TRAIN_EPOCHS = 20
        self.SAVE_EVERY_EPOCHS = 1
        self.TRAIN_BATCH_SIZE = 1024
        self.TEST_BATCH_SIZE = self.TRAIN_BATCH_SIZE
        self.TOP_K_WORDS_CONSIDERED_DURING_PREDICTION = 10
        self.NUM_BATCHES_TO_LOG_PROGRESS = 100
        self.NUM_TRAIN_BATCHES_TO_EVALUATE = 1800
        self.READER_NUM_PARALLEL_BATCHES = 6  # cpu cores [for tf.contrib.data.map_and_batch() in the reader]
        self.SHUFFLE_BUFFER_SIZE = 10000
        self.CSV_BUFFER_SIZE = 100 * 1024 * 1024  # 100 MB
        self.MAX_TO_KEEP = 10

        # model hyper-params
        self.MAX_CONTEXTS = 200
        self.MAX_TOKEN_VOCAB_SIZE = 1301136
        self.MAX_TARGET_VOCAB_SIZE = 261245
        self.MAX_PATH_VOCAB_SIZE = 911417
        self.DEFAULT_EMBEDDINGS_SIZE = 128
        self.TOKEN_EMBEDDINGS_SIZE = self.DEFAULT_EMBEDDINGS_SIZE
        self.PATH_EMBEDDINGS_SIZE = self.DEFAULT_EMBEDDINGS_SIZE
        # Must come after the token/path sizes: `context_vector_size` is
        # derived from them.
        self.CODE_VECTOR_SIZE = self.context_vector_size
        self.TARGET_EMBEDDINGS_SIZE = self.CODE_VECTOR_SIZE
        self.DROPOUT_KEEP_RATE = 0.75
        self.SEPARATE_OOV_AND_PAD = False

    def load_from_args(self, argv: Optional[list] = None):
        """Fill the argument-driven settings from command-line options.

        Args:
            argv: Option strings to parse. When `None` (the default) the
                process command line is parsed, as `ArgumentParser.parse_args`
                does when called without arguments.
        """
        args = self.arguments_parser().parse_args(argv)
        # Automatically filled, do not edit:
        self.PREDICT = args.predict
        self.MODEL_SAVE_PATH = args.save_path
        self.MODEL_LOAD_PATH = args.load_path
        self.TRAIN_DATA_PATH_PREFIX = args.data_path
        self.TEST_DATA_PATH = args.test_path
        self.RELEASE = args.release
        self.EXPORT_CODE_VECTORS = args.export_code_vectors
        self.SAVE_W2V = args.save_w2v
        self.SAVE_T2V = args.save_t2v
        self.VERBOSE_MODE = args.verbose_mode
        self.LOGS_PATH = args.logs_path
        self.DL_FRAMEWORK = 'tensorflow' if not args.dl_framework else args.dl_framework
        self.USE_TENSORBOARD = args.use_tensorboard

    def __init__(self, set_defaults: bool = False, load_from_args: bool = False, verify: bool = False):
        """Create a configuration with placeholder values, then optionally populate it.

        Args:
            set_defaults: If true, call `set_defaults()`.
            load_from_args: If true, call `load_from_args()` (parses the
                process command line).
            verify: If true, call `verify()`; it raises `ValueError` on an
                inconsistent configuration.
        """
        self.NUM_TRAIN_EPOCHS: int = 0
        self.SAVE_EVERY_EPOCHS: int = 0
        self.TRAIN_BATCH_SIZE: int = 0
        self.TEST_BATCH_SIZE: int = 0
        self.TOP_K_WORDS_CONSIDERED_DURING_PREDICTION: int = 0
        self.NUM_BATCHES_TO_LOG_PROGRESS: int = 0
        self.NUM_TRAIN_BATCHES_TO_EVALUATE: int = 0
        self.READER_NUM_PARALLEL_BATCHES: int = 0
        self.SHUFFLE_BUFFER_SIZE: int = 0
        self.CSV_BUFFER_SIZE: int = 0
        self.MAX_TO_KEEP: int = 0

        # model hyper-params
        self.MAX_CONTEXTS: int = 0
        self.MAX_TOKEN_VOCAB_SIZE: int = 0
        self.MAX_TARGET_VOCAB_SIZE: int = 0
        self.MAX_PATH_VOCAB_SIZE: int = 0
        self.DEFAULT_EMBEDDINGS_SIZE: int = 0
        self.TOKEN_EMBEDDINGS_SIZE: int = 0
        self.PATH_EMBEDDINGS_SIZE: int = 0
        self.CODE_VECTOR_SIZE: int = 0
        self.TARGET_EMBEDDINGS_SIZE: int = 0
        self.DROPOUT_KEEP_RATE: float = 0
        self.SEPARATE_OOV_AND_PAD: bool = False

        # Automatically filled by `args`.
        self.PREDICT: bool = False  # TODO: update README;
        self.MODEL_SAVE_PATH: Optional[str] = None
        self.MODEL_LOAD_PATH: Optional[str] = None
        self.TRAIN_DATA_PATH_PREFIX: Optional[str] = None
        self.TEST_DATA_PATH: Optional[str] = ''
        self.RELEASE: bool = False
        self.EXPORT_CODE_VECTORS: bool = False
        self.SAVE_W2V: Optional[str] = None  # TODO: update README;
        self.SAVE_T2V: Optional[str] = None  # TODO: update README;
        self.VERBOSE_MODE: int = 0
        self.LOGS_PATH: Optional[str] = None
        self.DL_FRAMEWORK: str = ''  # in {'keras', 'tensorflow'}
        self.USE_TENSORBOARD: bool = False

        # Automatically filled by `Code2VecModelBase._init_num_of_examples()`.
        self.NUM_TRAIN_EXAMPLES: int = 0
        self.NUM_TEST_EXAMPLES: int = 0

        # Created lazily by `get_logger()`.
        self.__logger: Optional[logging.Logger] = None

        if set_defaults:
            self.set_defaults()
        if load_from_args:
            self.load_from_args()
        if verify:
            self.verify()

    @property
    def context_vector_size(self) -> int:
        """Size of one context vector: path embedding plus two token embeddings."""
        # The context vector is actually a concatenation of the embedded
        # source & target vectors and the embedded path vector.
        return self.PATH_EMBEDDINGS_SIZE + 2 * self.TOKEN_EMBEDDINGS_SIZE

    @property
    def is_training(self) -> bool:
        """True when a training data path prefix is set."""
        return bool(self.TRAIN_DATA_PATH_PREFIX)

    @property
    def is_loading(self) -> bool:
        """True when a model load path is set."""
        return bool(self.MODEL_LOAD_PATH)

    @property
    def is_saving(self) -> bool:
        """True when a model save path is set."""
        return bool(self.MODEL_SAVE_PATH)

    @property
    def is_testing(self) -> bool:
        """True when a test data path is set."""
        return bool(self.TEST_DATA_PATH)

    @property
    def train_steps_per_epoch(self) -> int:
        """Number of training batches per epoch (rounded up); 0 if the batch size is 0."""
        return ceil(self.NUM_TRAIN_EXAMPLES / self.TRAIN_BATCH_SIZE) if self.TRAIN_BATCH_SIZE else 0

    @property
    def test_steps(self) -> int:
        """Number of test batches (rounded up); 0 if the batch size is 0."""
        return ceil(self.NUM_TEST_EXAMPLES / self.TEST_BATCH_SIZE) if self.TEST_BATCH_SIZE else 0

    def data_path(self, is_evaluating: bool = False) -> Optional[str]:
        """Return the test data path when evaluating, else the training data path."""
        return self.TEST_DATA_PATH if is_evaluating else self.train_data_path

    def batch_size(self, is_evaluating: bool = False) -> int:
        """Return the test batch size when evaluating, else the training batch size."""
        return self.TEST_BATCH_SIZE if is_evaluating else self.TRAIN_BATCH_SIZE  # take min with NUM_TRAIN_EXAMPLES?

    @property
    def train_data_path(self) -> Optional[str]:
        """`<prefix>.train.c2v`, or `None` when not training."""
        if not self.is_training:
            return None
        return '{}.train.c2v'.format(self.TRAIN_DATA_PATH_PREFIX)

    @property
    def word_freq_dict_path(self) -> Optional[str]:
        """`<prefix>.dict.c2v`, or `None` when not training."""
        if not self.is_training:
            return None
        return '{}.dict.c2v'.format(self.TRAIN_DATA_PATH_PREFIX)

    @classmethod
    def get_vocabularies_path_from_model_path(cls, model_file_path: str) -> str:
        """Return the `dictionaries.bin` path located next to `model_file_path`.

        The last '/'-separated component of `model_file_path` is replaced by
        the vocabularies file name.
        """
        vocabularies_save_file_name = "dictionaries.bin"
        return '/'.join(model_file_path.split('/')[:-1] + [vocabularies_save_file_name])

    @classmethod
    def get_entire_model_path(cls, model_path: str) -> str:
        """Return `model_path` with the `__entire-model` suffix appended."""
        return model_path + '__entire-model'

    @classmethod
    def get_model_weights_path(cls, model_path: str) -> str:
        """Return `model_path` with the `__only-weights` suffix appended."""
        return model_path + '__only-weights'

    @property
    def model_load_dir(self) -> Optional[str]:
        """Directory part of the model load path, or `None` when not loading.

        The result is everything before the last '/' of `MODEL_LOAD_PATH`;
        it is an empty string when the path contains no directory part.
        """
        # Guard like the sibling load/save properties instead of calling
        # `.split` on None.
        if not self.is_loading:
            return None
        return '/'.join(self.MODEL_LOAD_PATH.split('/')[:-1])

    @property
    def entire_model_load_path(self) -> Optional[str]:
        """Load path of the entire model, or `None` when not loading."""
        if not self.is_loading:
            return None
        return self.get_entire_model_path(self.MODEL_LOAD_PATH)

    @property
    def model_weights_load_path(self) -> Optional[str]:
        """Load path of the weights-only model, or `None` when not loading."""
        if not self.is_loading:
            return None
        return self.get_model_weights_path(self.MODEL_LOAD_PATH)

    @property
    def entire_model_save_path(self) -> Optional[str]:
        """Save path of the entire model, or `None` when not saving."""
        if not self.is_saving:
            return None
        return self.get_entire_model_path(self.MODEL_SAVE_PATH)

    @property
    def model_weights_save_path(self) -> Optional[str]:
        """Save path of the weights-only model, or `None` when not saving."""
        if not self.is_saving:
            return None
        return self.get_model_weights_path(self.MODEL_SAVE_PATH)

    def verify(self):
        """Check that the configuration is usable.

        Raises:
            ValueError: If neither training nor loading is configured, if the
                model load directory does not exist, or if `DL_FRAMEWORK` is
                not one of 'tensorflow' / 'keras'.
        """
        if not self.is_training and not self.is_loading:
            raise ValueError("Must train or load a model.")
        # A load path without any '/' has an empty directory part, which
        # refers to the current working directory; `os.path.isdir('')` would
        # wrongly report it as missing.
        if self.is_loading and not os.path.isdir(self.model_load_dir or '.'):
            raise ValueError("Model load dir `{model_load_dir}` does not exist.".format(
                model_load_dir=self.model_load_dir))
        if self.DL_FRAMEWORK not in {'tensorflow', 'keras'}:
            raise ValueError("config.DL_FRAMEWORK must be in {'tensorflow', 'keras'}.")

    def __iter__(self):
        """Yield `(name, value)` for every non-callable attribute and property.

        Names starting with a double underscore are skipped. A property whose
        getter raises is reported with the value `None`.
        """
        for attr_name in dir(self):
            if attr_name.startswith("__"):
                continue
            try:
                attr_value = getattr(self, attr_name, None)
            except Exception:
                # Best effort: a failing property getter must not abort the
                # iteration, but interpreter-exit signals still propagate.
                attr_value = None
            if callable(attr_value):
                continue
            yield attr_name, attr_value

    def get_logger(self) -> logging.Logger:
        """Return the 'code2vec' logger, configuring it on first use.

        On the first call the logger's handlers are reset; a stdout handler
        is attached when `VERBOSE_MODE >= 1` and a file handler when
        `LOGS_PATH` is set. Later calls return the same logger object.
        """
        if self.__logger is None:
            self.__logger = logging.getLogger('code2vec')
            self.__logger.setLevel(logging.INFO)
            self.__logger.handlers = []
            # Do not pass records on to ancestor loggers' handlers.
            self.__logger.propagate = False

            formatter = logging.Formatter('%(asctime)s %(levelname)-8s %(message)s')

            if self.VERBOSE_MODE >= 1:
                ch = logging.StreamHandler(sys.stdout)
                ch.setLevel(logging.INFO)
                ch.setFormatter(formatter)
                self.__logger.addHandler(ch)

            if self.LOGS_PATH:
                fh = logging.FileHandler(self.LOGS_PATH)
                fh.setLevel(logging.INFO)
                fh.setFormatter(formatter)
                self.__logger.addHandler(fh)

        return self.__logger

    def log(self, msg):
        """Log `msg` at INFO level through `get_logger()`."""
        self.get_logger().info(msg)