# train.py
# Trains a convolutional VAE and then a policy network for behaviour cloning
# on the CarRacing dataset.
import torch
from torch.utils.data import DataLoader, random_split
from Datasets import CarRacingDataset
from Models import ConvVAEWrapper, LiNetWrapper
from Utils import saveLogData, parseArguments, getValueFromDict
import sys
import os
from Globals import *
def is_truthy(value):
    """Return True when ``str(value)`` is "true" or "1", ignoring case.

    Used to turn command-line flag values into booleans.
    """
    return str(value).lower() in ("true", "1")


def split_lengths(num_samples, train_ratio):
    """Return ``[num_train, num_valid]`` sizes for a train/validation split.

    ``num_train`` is ``int(num_samples * train_ratio)`` (truncated); the
    remainder goes to validation, so the two sizes always sum to
    ``num_samples``.
    """
    num_train = int(num_samples * train_ratio)
    return [num_train, num_samples - num_train]


# Script entry point: parse command-line arguments, build the dataset and
# loaders, train (or load) the VAE, then train the policy network while
# saving checkpoints and log data.
if __name__ == "__main__":
    # -------------------
    # Command-line arguments
    # -------------------
    cmd_args = parseArguments(sys.argv[1:])
    num_epochs_pn = int(getValueFromDict(cmd_args, 'num_epochs', 1))
    num_epochs_vae = int(getValueFromDict(cmd_args, 'num_epochs_vae', 1))
    checkpoint_folder = str(getValueFromDict(cmd_args, 'checkpoint_folder', "./checkpoint"))
    do_save_checkpoints = is_truthy(getValueFromDict(cmd_args, 'do_save_checkpoints', "True"))
    do_balance_dataset = is_truthy(getValueFromDict(cmd_args, 'do_balance_dataset', "False"))
    load_epoch = int(getValueFromDict(cmd_args, 'load_epoch', -1))
    model_file = str(getValueFromDict(cmd_args, 'model_file', ""))
    do_load_vae = is_truthy(getValueFromDict(cmd_args, 'do_load_vae', "False"))

    # Resuming the policy network needs a checkpoint file; fail early in the
    # same way as the other required arguments below.
    if load_epoch >= 0 and model_file == "":
        print("model_file argument is necessary!")
        sys.exit(2)

    if do_load_vae:
        vae_model_file = str(getValueFromDict(cmd_args, 'vae_model_file', ""))
        if vae_model_file == "":
            print("vae_model_file argument is necessary!")
            sys.exit(2)
        vae_checkpoint_folder, vae_checkpoint_file = os.path.split(vae_model_file)

    data_folder = str(getValueFromDict(cmd_args, 'data_folder', ""))
    if data_folder == "":
        print("data_folder argument is necessary!")
        sys.exit(2)

    batch_size = 128
    num_skip_frames_dataset = 0  # Used for skipping zooming-in frames in the beginning of each episode

    # VAE hyperparameters
    args_vae = {
        'in_channels': IM_CHANNELS,                  # Input dimensions
        'rows': IM_HEIGHT,
        'cols': IM_WIDTH,
        'num_hidden_features': [32, 64, 128, 256],   # Hidden block channels
        'num_latent_features': 32,                   # Latent space dims
        'strides': [2, 2, 2, 2],                     # Strides for each hidden block
        'do_use_cuda': True,                         # Use CUDA?
        'comments': 'VAE for behaviour cloning',
    }

    # Policy network hyperparameters
    args_pn = {
        'num_epochs': num_epochs_pn,                     # Number of epochs to train
        'lr': 0.001,                                     # Learning rate
        'grad_clip': 0.1,                                # Gradient clip
        'do_use_cuda': True,                             # Use CUDA?
        'num_classes': NUM_DISCRETE_ACTIONS,             # Output dimension of the network, number of actions in this case
        'in_channels': args_vae['num_latent_features'],  # Input channels
        'num_channels': 64,                              # Channels of the first block
        'category_weights': None,                        # Category weights for loss function (filled in below)
        'comments': 'policy network for behaviour cloning',
    }

    log_data = {}
    log_data['args_vae'] = args_vae  # Save model hyperparams for further reference
    log_data['args_pn'] = args_pn    # Save model hyperparams for further reference

    # -------------------
    # Datasets and loaders
    # -------------------
    dataset = CarRacingDataset(data_folder, action_space='discrete', num_skip_frames=num_skip_frames_dataset)
    if do_balance_dataset:
        dataset.balance_dataset()

    if R_TRAIN < 1.0:
        train_set, valid_set = random_split(dataset, split_lengths(len(dataset), R_TRAIN))
        train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
        valid_loader = DataLoader(valid_set, batch_size=batch_size, shuffle=False)
    else:
        # No validation split: train on everything and skip validation steps.
        train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
        valid_loader = None

    # args_pn is the same dict object stored in log_data, so this update is
    # visible in the logged hyperparameters as well.
    args_pn['category_weights'] = dataset.calculate_weights()*100

    # -------------------
    # VAE
    # -------------------
    vae = ConvVAEWrapper(args_vae)
    # None means "not computed"; prevents a NameError in the logging code when
    # the VAE is loaded or when num_epochs_vae is 0.
    train_loss_vae = None
    valid_loss_vae = None
    if do_load_vae:
        # Continue training not implemented, just load existing model to skip VAE training
        print("Loading VAE")
        vae.load_checkpoint(vae_checkpoint_folder, vae_checkpoint_file)
        # Only evaluate when a validation split exists.
        if valid_loader is not None:
            valid_loss_vae = vae.test_epoch(valid_loader)
    else:
        # Training
        print("Training VAE")
        for epoch in range(num_epochs_vae):
            print("Epoch: ", epoch)
            train_loss_vae = vae.train_epoch(train_loader)
            # Validate
            if valid_loader is not None:
                valid_loss_vae = vae.test_epoch(valid_loader)
            if do_save_checkpoints:
                vae_checkpoint_filename = 'checkpoint.vae.epoch.'+str(epoch)+'.tar'
                vae.save_checkpoint(folder=checkpoint_folder, filename=vae_checkpoint_filename)
        # After training
        # NOTE(review): placed in the training branch to match the comment
        # above; confirm it should not also run for a loaded VAE.
        if valid_loader is not None:
            vae.visualize_decoder(valid_loader)

    # -------------------
    # Policy
    # -------------------
    policy_network = LiNetWrapper(args_pn)
    if load_epoch >= 0:
        print("Loading PN")
        # NOTE(review): this rebinds checkpoint_folder, so later checkpoints
        # and the log are written to the loaded model's folder rather than to
        # the checkpoint_folder argument -- confirm this is intended.
        checkpoint_folder, checkpoint_file = os.path.split(model_file)
        policy_network.load_checkpoint(checkpoint_folder, checkpoint_file)
        epoch_start = load_epoch + 1
    else:
        epoch_start = 0

    print("Training PN")
    for epoch in range(epoch_start, num_epochs_pn):
        print("Epoch: ", epoch)
        # Train
        train_loss_pn = policy_network.train_epoch(train_loader, vae)
        # Validate
        valid_loss_pn = None
        if valid_loader is not None:
            valid_loss_pn = policy_network.test_epoch(valid_loader, vae)
        # Save checkpoint
        if do_save_checkpoints:
            checkpoint_filename = 'checkpoint.policy.epoch.'+str(epoch)+'.tar'
            policy_network.save_checkpoint(folder=checkpoint_folder, filename=checkpoint_filename)
        # Logging
        epoch_log = log_data.setdefault(epoch, {})
        log_data['last_iteration'] = epoch
        epoch_log['train_loss_pn'] = train_loss_pn.avg
        if valid_loss_pn is not None:
            epoch_log['valid_loss_pn'] = valid_loss_pn.avg
        # The VAE losses are the values from the last VAE epoch (or from the
        # evaluation of the loaded VAE); they are repeated for every PN epoch.
        # NOTE(review): the valid-loss entry is logged for a loaded VAE too --
        # confirm against the original nesting.
        if train_loss_vae is not None:
            epoch_log['train_loss_vae'] = train_loss_vae.avg
        if valid_loss_vae is not None:
            epoch_log['valid_loss_vae'] = valid_loss_vae.avg
        # NOTE(review): saved every epoch so the log on disk stays current;
        # confirm the original did not save only once after the loop.
        saveLogData(log_data, checkpoint_folder)