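"""Train an image super-resolution model (EDSR, quantized EDSR, or WDSR) on DIV2K.

Tracks MAE and PSNR during training, logs to TensorBoard, checkpoints into the
log directory, and saves the final weights under weights/.
"""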
import argparse
import datetime

import tensorflow.compat.v1 as tf

from data import DIV2K
from model.edsr import edsr
from model.qedsr import qedsr
from model.wdsr import wdsr_b

def PSNR(super_resolution, high_resolution):
    """Compute the peak signal-to-noise ratio, a measure of image quality."""
    # Pixel values are in [0, 255]; return the PSNR of the first image in the batch.
    psnr_value = tf.image.psnr(high_resolution, super_resolution, max_val=255)[0]
    return psnr_value

def main(model_name, downgrade, scale, downgrade_for_validation,
         scale_for_validation, precision, pretrained=None, batch_size=16,
         epochs=100, depth=16, eval_all_distortions=False,
         train_all_distortions=False):
    # The training set uses the downgrade type and scale given on the command line.
    downgrade_for_training = downgrade
    scale_for_training = scale
    div2k_train = DIV2K(scale=scale_for_training, subset='train',
                        downgrade=downgrade_for_training)
    train_ds = div2k_train.dataset(batch_size=batch_size, random_transform=True,
                                   all_distortions=train_all_distortions)

    # Build one validation dataset per distortion type: either all four DIV2K
    # distortions, or only the one requested for validation.
    if eval_all_distortions:
        div2k_valid = {distortion: DIV2K(scale=scale_for_validation, subset='valid',
                                         downgrade=distortion)
                       for distortion in ['bicubic', 'unknown', 'mild', 'difficult']}
    else:
        div2k_valid = {downgrade_for_validation: DIV2K(scale=scale_for_validation,
                                                       subset='valid',
                                                       downgrade=downgrade_for_validation)}
    valid_ds = {distortion: div2k_valid[distortion].dataset(batch_size=1,
                                                            random_transform=False,
                                                            repeat_count=1)
                for distortion in div2k_valid}
    if model_name == 'edsr':
        model = edsr(scale=scale, num_res_blocks=depth)
    elif model_name == 'qedsr':
        model = qedsr(scale=scale, num_res_blocks=depth, precision=precision)
    elif model_name == 'wdsr':
        model = wdsr_b(scale=scale, num_res_blocks=depth)
    else:
        raise NotImplementedError(f"Model ({model_name}) not implemented. "
                                  "Only (edsr), (qedsr) and (wdsr) are implemented.")
    model.summary()
    # Train with L1 pixel loss; the learning rate halves after 200k optimizer steps.
    loss_object = tf.keras.losses.MeanAbsoluteError()
    lr_schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
        boundaries=[200000], values=[1e-4, 5e-5])
    optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
    model.compile(loss=loss_object,
                  optimizer=optimizer,
                  metrics=['MAE', PSNR])
    if pretrained:
        print(f"Loading {pretrained}")
        model.load_weights(pretrained)

    log_dir = "logs/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)
    model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=log_dir + '/checkpoints/',
        monitor='MAE',
        mode='min')
    print(f"Training model on {downgrade} data with scale {scale_for_training} ...")
    # DIV2K ships 800 training and 100 validation images.
    train_samples = 800
    val_samples = 100
    if train_all_distortions:
        train_samples *= 4
    model.fit(train_ds,
              epochs=epochs,
              steps_per_epoch=train_samples // batch_size,
              validation_data=valid_ds[downgrade_for_validation],
              # The validation dataset is built with batch_size=1 above,
              # so one step per validation image.
              validation_steps=val_samples,
              callbacks=[tensorboard_callback,
                         model_checkpoint_callback])
    for distortion in valid_ds:
        print(f"Evaluating model on {distortion} data with scale {scale_for_validation} ...")
        model.evaluate(valid_ds[distortion])

    print("Saving model...")
    if train_all_distortions:
        model.save(f'weights/{model_name}-{depth}-all-x{scale}/weights.h5', save_format='h5')
    else:
        model.save(f'weights/{model_name}-{depth}-{downgrade}-x{scale}/weights.h5',
                   save_format='h5')

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train an image super-resolution model.')
    parser.add_argument('--scale', type=int, default=4,
                        help='super-resolution factor')
    parser.add_argument('--downgrade', type=str, default='bicubic',
                        help='downgrade type')
    parser.add_argument('--depth', type=int, default=16,
                        help='number of residual blocks')
    parser.add_argument('--batch-size', type=int, default=16,
                        help='batch size for training')
    parser.add_argument('--epochs', type=int, default=100,
                        help='number of epochs to train for')
    parser.add_argument('--model', type=str, default='edsr',
                        help='model name: edsr, qedsr or wdsr')
    parser.add_argument('--pretrained', type=str, default=None,
                        help='weights of a pretrained model to load before training')
    parser.add_argument('--downgrade_val', type=str, default='bicubic',
                        help='downgrade type for validation')
    parser.add_argument('--scale_val', type=int, default=4,
                        help='super-resolution factor for validation')
    parser.add_argument('--precision', type=int, default=8,
                        help='precision of the quantized convolution (qedsr only)')
    parser.add_argument('--eval_all_distortions', action='store_true',
                        help='evaluate on all distortion types during validation')
    parser.add_argument('--train_all_distortions', action='store_true',
                        help='train on all distortion types')
    args = parser.parse_args()
    if len(tf.config.list_physical_devices('GPU')) == 0:
        print("WARNING: No GPU found, running on CPU")
    main(model_name=args.model,
         downgrade=args.downgrade,
         scale=args.scale,
         downgrade_for_validation=args.downgrade_val,
         scale_for_validation=args.scale_val,
         precision=args.precision,
         pretrained=args.pretrained,
         batch_size=args.batch_size,
         epochs=args.epochs,
         depth=args.depth,
         eval_all_distortions=args.eval_all_distortions,
         train_all_distortions=args.train_all_distortions)
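
# A sketch of a typical invocation, assuming the DIV2K data package and the
# model packages from this repository are importable (all flags shown are the
# ones defined above; dataset paths are whatever the DIV2K loader defaults to):
#
#   python train_image_resolution_model.py --model edsr --depth 16 \
#       --scale 4 --downgrade bicubic --batch-size 16 --epochs 100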