-
Notifications
You must be signed in to change notification settings - Fork 5
/
utils.py
34 lines (27 loc) · 1.4 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
from functools import reduce
from contrastyou.configure import dictionary_merge_by_hierachy, extract_dictionary_from_anchor, \
extract_params_with_key_prefix
def separate_pretrain_finetune_configs(config_manager):
input_params = config_manager.parsed_config
base_config = config_manager.base_config
opt_params = reduce(dictionary_merge_by_hierachy, config_manager.optional_configs)
pretrain_config = dictionary_merge_by_hierachy(base_config, opt_params)
# extract the input_params for both settings
pretrain_config = dictionary_merge_by_hierachy(
pretrain_config,
extract_dictionary_from_anchor(target_dictionary=input_params,
anchor_dictionary=pretrain_config,
prune_anchor=True))
# extract input_params for pre_
pretrain_config = dictionary_merge_by_hierachy(
pretrain_config,
extract_params_with_key_prefix(input_params, prefix="pre_"))
base_config = dictionary_merge_by_hierachy(
base_config,
extract_dictionary_from_anchor(target_dictionary=input_params,
anchor_dictionary=base_config,
prune_anchor=True))
base_config = dictionary_merge_by_hierachy(
base_config,
extract_params_with_key_prefix(input_params, prefix="ft_"))
return pretrain_config, base_config