-
Notifications
You must be signed in to change notification settings - Fork 9
/
parse_args.py
69 lines (58 loc) · 2.72 KB
/
parse_args.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import configargparse
def str2bool(v):
""" from https://stackoverflow.com/a/43357954/1361529 """
if isinstance(v, bool):
return v
if v.lower() in ('yes', 'true', 't', 'y', '1'):
return True
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
return False
else:
raise configargparse.ArgumentTypeError('Boolean value expected.')
def parse_args(config_file):
parser = configargparse.ArgParser(default_config_files=[config_file])
parser.add("--name", type=str, default="main")
parser.add("--train_data_path", action="append")
parser.add("--val_data_path", action="append")
parser.add("--test_data_path", action="append")
parser.add("--model_save_path", required=True)
parser.add("--num_mfcc", type=int, default=14)
parser.add("--pose_representation", type=str, default='3d_vec')
parser.add("--mean_dir_vec", action="append", type=float, nargs='*')
parser.add("--mean_pose", action="append", type=float, nargs='*')
parser.add("--random_seed", type=int, default=-1)
parser.add("--save_result_video", type=str2bool, default=True)
# word embedding
parser.add("--wordembed_path", type=str, default=None)
parser.add("--wordembed_dim", type=int, default=100)
parser.add("--freeze_wordembed", type=str2bool, default=False)
# model
parser.add("--model", type=str, required=True)
parser.add("--epochs", type=int, default=10)
parser.add("--batch_size", type=int, default=50)
parser.add("--dropout_prob", type=float, default=0.3)
parser.add("--n_layers", type=int, default=2)
parser.add("--hidden_size", type=int, default=200)
parser.add("--hidden_size_s2eg", type=int, default=8)
parser.add("--z_type", type=str, default='none')
parser.add("--input_context", type=str, default='both')
# dataset
parser.add("--motion_resampling_framerate", type=int, default=24)
parser.add("--n_poses", type=int, default=50)
parser.add("--n_pre_poses", type=int, default=5)
parser.add("--subdivision_stride", type=int, default=5)
parser.add("--loader_workers", type=int, default=0)
# GAN parameter
parser.add("--GAN_noise_size", type=int, default=0)
# training
parser.add("--learning_rate", type=float, default=0.001)
parser.add("--discriminator_lr_weight", type=float, default=0.2)
parser.add("--loss_regression_weight", type=float, default=50)
parser.add("--loss_gan_weight", type=float, default=1.0)
parser.add("--loss_kld_weight", type=float, default=0.1)
parser.add("--loss_reg_weight", type=float, default=0.01)
parser.add("--loss_warmup", type=int, default=-1)
# eval
parser.add("--val_net_path", type=str, default='')
args, _ = parser.parse_known_args()
return args