#################################################################################
# Copyright (c) 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# SPDX-FileCopyrightText: Copyright (c) 2020-2023 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: BSD-3-Clause
#################################################################################
# Visualizing and Communicating Errors in Rendered Images
# Ray Tracing Gems II, 2021,
# by Pontus Andersson, Jim Nilsson, and Tomas Akenine-Moller.
# Pointer to the chapter: https://research.nvidia.com/publication/2021-08_Visualizing-and-Communicating.
# Visualizing Errors in Rendered High Dynamic Range Images
# Eurographics 2021,
# by Pontus Andersson, Jim Nilsson, Peter Shirley, and Tomas Akenine-Moller.
# Pointer to the paper: https://research.nvidia.com/publication/2021-05_HDR-FLIP.
# FLIP: A Difference Evaluator for Alternating Images
# High Performance Graphics 2020,
# by Pontus Andersson, Jim Nilsson, Tomas Akenine-Moller,
# Magnus Oskarsson, Kalle Astrom, and Mark D. Fairchild.
# Pointer to the paper: https://research.nvidia.com/publication/2020-07_FLIP.
# Code by Pontus Andersson, Jim Nilsson, and Tomas Akenine-Moller.
import sys
import numpy as np
import torch
import torch.nn as nn
class HDRFLIPLoss(nn.Module):
""" Class for computing HDR-FLIP """
def __init__(self):
""" Init """
super().__init__()
self.qc = 0.7
self.qf = 0.5
self.pc = 0.4
self.pt = 0.95
self.tmax = 0.85
self.tmin = 0.85
self.eps = 1e-15
def forward(self, test, reference, pixels_per_degree=(0.7 * 3840 / 0.7) * np.pi / 180, tone_mapper="aces", start_exposure=None, stop_exposure=None):
"""
Computes the HDR-FLIP error map between two HDR images,
assuming the images are observed at a certain number of
pixels per degree of visual angle
:param test: test tensor (with NxCxHxW layout with nonnegative values)
:param reference: reference tensor (with NxCxHxW layout with nonnegative values)
:param pixels_per_degree: float describing the number of pixels per degree of visual angle of the observer,
default corresponds to viewing the images on a 0.7 meters wide 4K monitor at 0.7 meters from the display
:param tone_mapper: (optional) string describing which tone mapper HDR-FLIP should assume: "aces" (default), "reinhard", or "hable"
:param start_exposure: (optional) tensor (with Nx1x1x1 layout) with start exposures corresponding to each HDR reference/test pair
:param stop_exposure: (optional) tensor (with Nx1x1x1 layout) with stop exposures corresponding to each HDR reference/test pair
:return: float containing the mean FLIP error (in the range [0,1]) between the HDR reference and test images in the batch
"""
# HDR-FLIP expects nonnegative and non-NaN values in the input
reference = torch.clamp(reference, 0, 65536.0)
test = torch.clamp(test, 0, 65536.0)
# Compute start and stop exposures, if they are not given
if start_exposure is None or stop_exposure is None:
c_start, c_stop = compute_start_stop_exposures(reference, tone_mapper, self.tmax, self.tmin)
if start_exposure is None:
start_exposure = c_start
if stop_exposure is None:
stop_exposure = c_stop
# Compute number of exposures
num_exposures = torch.max(torch.tensor([2.0], requires_grad=False).cuda(), torch.ceil(stop_exposure - start_exposure))
most_exposures = int(torch.amax(num_exposures, dim=0).item())
# Compute exposure step size
step_size = (stop_exposure - start_exposure) / torch.max(num_exposures - 1, torch.tensor([1.0], requires_grad=False).cuda())
# Set the depth of the error tensor to the number of exposures given by the largest exposure range any reference image yielded.
# This allows us to do one loop for each image in our batch, while not affecting the HDR-FLIP error, as we fill up the error tensor with 0s.
# Note that the step size still depends on num_exposures and is therefore independent of most_exposures
dim = reference.size()
all_errors = torch.zeros(size=(dim[0], most_exposures, dim[2], dim[3])).cuda()
# Loop over exposures and compute LDR-FLIP for each pair of LDR reference and test
for i in range(0, most_exposures):
exposure = start_exposure + i * step_size
reference_tone_mapped = tone_map(reference, tone_mapper, exposure)
test_tone_mapped = tone_map(test, tone_mapper, exposure)
reference_opponent = color_space_transform(reference_tone_mapped, 'linrgb2ycxcz')
test_opponent = color_space_transform(test_tone_mapped, 'linrgb2ycxcz')
all_errors[:, i, :, :] = compute_ldrflip(
test_opponent, reference_opponent, pixels_per_degree,
self.qc, self.qf, self.pc, self.pt, self.eps
).squeeze(1)
# Take per-pixel maximum over all LDR-FLIP errors to get HDR-FLIP
hdrflip_error = torch.amax(all_errors, dim=1, keepdim=True)
return torch.mean(hdrflip_error)
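# Usage sketch (illustrative addition, not part of the original implementation). It assumes a
# CUDA device is available, since the class hard-codes .cuda() calls, and that `hdr_test` and
# `hdr_reference` are hypothetical NxCxHxW tensors holding nonnegative linear-RGB values:
#
#   hdr_flip = HDRFLIPLoss()
#   loss = hdr_flip(hdr_test, hdr_reference, tone_mapper="aces")
#   loss.backward()  # valid if hdr_test requires gradients; the mean HDR-FLIP error is differentiable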
class LDRFLIPLoss(nn.Module):
""" Class for computing LDR FLIP loss """
def __init__(self):
""" Init """
super().__init__()
self.qc = 0.7
self.qf = 0.5
self.pc = 0.4
self.pt = 0.95
self.eps = 1e-15
def forward(self, test, reference, pixels_per_degree=(0.7 * 3840 / 0.7) * np.pi / 180, mask = None):
"""
Computes the LDR-FLIP error map between two LDR images,
assuming the images are observed at a certain number of
pixels per degree of visual angle
:param test: test tensor (with NxCxHxW layout with values in the range [0, 1] in the sRGB color space)
:param reference: reference tensor (with NxCxHxW layout with values in the range [0, 1] in the sRGB color space)
:param pixels_per_degree: float describing the number of pixels per degree of visual angle of the observer,
default corresponds to viewing the images on a 0.7 meters wide 4K monitor at 0.7 meters from the display
:param mask: (optional) tensor used to mask the error map; if given, the squeezed per-pixel error map is multiplied elementwise by the mask
:return: tensor containing the per-pixel LDR-FLIP errors (Nx1xHxW layout, values in the range [0,1]) between the reference and test images; a squeezed, masked map if a mask is given
"""
# LDR-FLIP expects non-NaN values in [0,1] as input
reference = torch.clamp(reference, 0, 1)
test = torch.clamp(test, 0, 1)
# Transform reference and test to opponent color space
reference_opponent = color_space_transform(reference, 'srgb2ycxcz')
test_opponent = color_space_transform(test, 'srgb2ycxcz')
deltaE = compute_ldrflip(test_opponent, reference_opponent, pixels_per_degree, self.qc, self.qf, self.pc, self.pt, self.eps)
if mask is not None:
deltaE = deltaE.squeeze() * mask
return deltaE
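# Usage sketch (illustrative addition, not part of the original implementation). `ldr_test` and
# `ldr_reference` are hypothetical NxCxHxW sRGB tensors with values in [0, 1], already on the GPU:
#
#   ldr_flip = LDRFLIPLoss()
#   error_map = ldr_flip(ldr_test, ldr_reference)  # per-pixel error map with Nx1xHxW layout
#   loss = error_map.mean()                        # reduce to a scalar for training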
def compute_ldrflip(test, reference, pixels_per_degree, qc, qf, pc, pt, eps):
"""
Computes the LDR-FLIP error map between two LDR images,
assuming the images are observed at a certain number of
pixels per degree of visual angle
:param reference: reference tensor (with NxCxHxW layout with values in the YCxCz color space)
:param test: test tensor (with NxCxHxW layout with values in the YCxCz color space)
:param pixels_per_degree: float describing the number of pixels per degree of visual angle of the observer,
default corresponds to viewing the images on a 0.7 meters wide 4K monitor at 0.7 meters from the display
:param qc: float describing the q_c exponent in the LDR-FLIP color pipeline (see FLIP paper for details)
:param qf: float describing the q_f exponent in the LDR-FLIP feature pipeline (see FLIP paper for details)
:param pc: float describing the p_c exponent in the LDR-FLIP color pipeline (see FLIP paper for details)
:param pt: float describing the p_t exponent in the LDR-FLIP color pipeline (see FLIP paper for details)
:param eps: float containing a small value used to improve training stability
:return: tensor containing the per-pixel FLIP errors (with Nx1xHxW layout and values in the range [0, 1]) between LDR reference and test images
"""
# --- Color pipeline ---
# Spatial filtering
s_a, radius_a = generate_spatial_filter(pixels_per_degree, 'A')
s_rg, radius_rg = generate_spatial_filter(pixels_per_degree, 'RG')
s_by, radius_by = generate_spatial_filter(pixels_per_degree, 'BY')
radius = max(radius_a, radius_rg, radius_by)
filtered_reference = spatial_filter(reference, s_a, s_rg, s_by, radius)
filtered_test = spatial_filter(test, s_a, s_rg, s_by, radius)
# Perceptually Uniform Color Space
preprocessed_reference = hunt_adjustment(color_space_transform(filtered_reference, 'linrgb2lab'))
preprocessed_test = hunt_adjustment(color_space_transform(filtered_test, 'linrgb2lab'))
# Color metric
deltaE_hyab = hyab(preprocessed_reference, preprocessed_test, eps)
power_deltaE_hyab = torch.pow(deltaE_hyab, qc)
hunt_adjusted_green = hunt_adjustment(color_space_transform(torch.tensor([[[0.0]], [[1.0]], [[0.0]]]).unsqueeze(0), 'linrgb2lab'))
hunt_adjusted_blue = hunt_adjustment(color_space_transform(torch.tensor([[[0.0]], [[0.0]], [[1.0]]]).unsqueeze(0), 'linrgb2lab'))
cmax = torch.pow(hyab(hunt_adjusted_green, hunt_adjusted_blue, eps), qc).item()
deltaE_c = redistribute_errors(power_deltaE_hyab, cmax, pc, pt)
# --- Feature pipeline ---
# Extract and normalize Yy component
ref_y = (reference[:, 0:1, :, :] + 16) / 116
test_y = (test[:, 0:1, :, :] + 16) / 116
# Edge and point detection
edges_reference = feature_detection(ref_y, pixels_per_degree, 'edge')
points_reference = feature_detection(ref_y, pixels_per_degree, 'point')
edges_test = feature_detection(test_y, pixels_per_degree, 'edge')
points_test = feature_detection(test_y, pixels_per_degree, 'point')
# Feature metric
deltaE_f = torch.max(
torch.abs(torch.norm(edges_reference, dim=1, keepdim=True) - torch.norm(edges_test, dim=1, keepdim=True)),
torch.abs(torch.norm(points_test, dim=1, keepdim=True) - torch.norm(points_reference, dim=1, keepdim=True))
)
deltaE_f = torch.clamp(deltaE_f, min=eps) # clamp to stabilize training
deltaE_f = torch.pow(((1 / np.sqrt(2)) * deltaE_f), qf)
# --- Final error ---
return torch.pow(deltaE_c, 1 - deltaE_f)
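# Note on the final combination above (added comment): with deltaE_c in [0, 1] and the exponent
# 1 - deltaE_f in [0, 1], a strong feature difference (deltaE_f close to 1) drives the exponent
# toward 0 and hence the FLIP error toward 1 for any nonzero color error, while deltaE_f = 0
# leaves the color error deltaE_c unchanged.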
def tone_map(img, tone_mapper, exposure):
"""
Applies exposure compensation and tone mapping.
Refer to the Visualizing Errors in Rendered High Dynamic Range Images
paper for details about the formulas.
:param img: float tensor (with NxCxHxW layout) containing nonnegative values
:param tone_mapper: string describing the tone mapper to apply
:param exposure: float tensor (with Nx1x1x1 layout) describing the exposure compensation factor
"""
# Exposure compensation
x = (2 ** exposure) * img
# Set tone mapping coefficients depending on tone_mapper
if tone_mapper == "reinhard":
lum_coeff_r = 0.2126
lum_coeff_g = 0.7152
lum_coeff_b = 0.0722
Y = x[:, 0:1, :, :] * lum_coeff_r + x[:, 1:2, :, :] * lum_coeff_g + x[:, 2:3, :, :] * lum_coeff_b
return torch.clamp(torch.div(x, 1 + Y), 0.0, 1.0)
if tone_mapper == "hable":
# Source: https://64.github.io/tonemapping/
A = 0.15
B = 0.50
C = 0.10
D = 0.20
E = 0.02
F = 0.30
k0 = A * F - A * E
k1 = C * B * F - B * E
k2 = 0
k3 = A * F
k4 = B * F
k5 = D * F * F
W = 11.2
nom = k0 * torch.pow(W, torch.tensor([2.0]).cuda()) + k1 * W + k2
denom = k3 * torch.pow(W, torch.tensor([2.0]).cuda()) + k4 * W + k5
white_scale = torch.div(denom, nom) # = 1 / (nom / denom)
# Include white scale and exposure bias in rational polynomial coefficients
k0 = 4 * k0 * white_scale
k1 = 2 * k1 * white_scale
k2 = k2 * white_scale
k3 = 4 * k3
k4 = 2 * k4
# k5 = k5 # k5 is not changed
else:
# Source: ACES approximation: https://knarkowicz.wordpress.com/2016/01/06/aces-filmic-tone-mapping-curve/
# Include pre-exposure cancelation in constants
k0 = 0.6 * 0.6 * 2.51
k1 = 0.6 * 0.03
k2 = 0
k3 = 0.6 * 0.6 * 2.43
k4 = 0.6 * 0.59
k5 = 0.14
x2 = torch.pow(x, 2)
nom = k0 * x2 + k1 * x + k2
denom = k3 * x2 + k4 * x + k5
denom = torch.where(torch.isinf(denom), torch.Tensor([1.0]).cuda(), denom) # if denom is inf, then so is nom => nan. Pixel is very bright. It becomes inf here, but 1 after clamp below
y = torch.div(nom, denom)
return torch.clamp(y, 0.0, 1.0)
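# Note (added comment): the "hable" and "aces" paths are both evaluated as the rational
# polynomial y = (k0*x^2 + k1*x + k2) / (k3*x^2 + k4*x + k5) followed by a clamp to [0, 1];
# only the coefficients differ, while the "reinhard" path returns x / (1 + Y) directly.
# A minimal usage sketch, assuming `hdr_img` is a hypothetical NxCxHxW tensor of nonnegative
# linear-RGB values on the GPU:
#
#   zero_exposure = torch.zeros(hdr_img.size(0), 1, 1, 1).cuda()
#   ldr_img = tone_map(hdr_img, "aces", zero_exposure)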
def compute_start_stop_exposures(reference, tone_mapper, tmax, tmin):
"""
Computes start and stop exposure for HDR-FLIP based on given tone mapper and reference image.
Refer to the Visualizing Errors in Rendered High Dynamic Range Images
paper for details about the formulas
:param reference: float tensor (with NxCxHxW layout) containing reference images (nonnegative values)
:param tone_mapper: string describing which tone mapper should be assumed
:param tmax: float describing the t value used to find the start exposure
:param tmin: float describing the t value used to find the stop exposure
:return: two float tensors (with Nx1x1x1 layout) containing start and stop exposures, respectively, to use for HDR-FLIP
"""
if tone_mapper == "reinhard":
k0 = 0
k1 = 1
k2 = 0
k3 = 0
k4 = 1
k5 = 1
x_max = tmax * k5 / (k1 - tmax * k4)
x_min = tmin * k5 / (k1 - tmin * k4)
elif tone_mapper == "hable":
# Source: https://64.github.io/tonemapping/
A = 0.15
B = 0.50
C = 0.10
D = 0.20
E = 0.02
F = 0.30
k0 = A * F - A * E
k1 = C * B * F - B * E
k2 = 0
k3 = A * F
k4 = B * F
k5 = D * F * F
W = 11.2
nom = k0 * torch.pow(W, torch.tensor([2.0]).cuda()) + k1 * W + k2
denom = k3 * torch.pow(W, torch.tensor([2.0]).cuda()) + k4 * W + k5
white_scale = torch.div(denom, nom) # = 1 / (nom / denom)
# Include white scale and exposure bias in rational polynomial coefficients
k0 = 4 * k0 * white_scale
k1 = 2 * k1 * white_scale
k2 = k2 * white_scale
k3 = 4 * k3
k4 = 2 * k4
# k5 = k5 # k5 is not changed
c0 = (k1 - k4 * tmax) / (k0 - k3 * tmax)
c1 = (k2 - k5 * tmax) / (k0 - k3 * tmax)
x_max = - 0.5 * c0 + torch.sqrt(((torch.tensor([0.5]).cuda() * c0) ** 2) - c1)
c0 = (k1 - k4 * tmin) / (k0 - k3 * tmin)
c1 = (k2 - k5 * tmin) / (k0 - k3 * tmin)
x_min = - 0.5 * c0 + torch.sqrt(((torch.tensor([0.5]).cuda() * c0) ** 2) - c1)
else:
# Source: ACES approximation: https://knarkowicz.wordpress.com/2016/01/06/aces-filmic-tone-mapping-curve/
# Include pre-exposure cancelation in constants
k0 = 0.6 * 0.6 * 2.51
k1 = 0.6 * 0.03
k2 = 0
k3 = 0.6 * 0.6 * 2.43
k4 = 0.6 * 0.59
k5 = 0.14
c0 = (k1 - k4 * tmax) / (k0 - k3 * tmax)
c1 = (k2 - k5 * tmax) / (k0 - k3 * tmax)
x_max = - 0.5 * c0 + torch.sqrt(((torch.tensor([0.5]).cuda() * c0) ** 2) - c1)
c0 = (k1 - k4 * tmin) / (k0 - k3 * tmin)
c1 = (k2 - k5 * tmin) / (k0 - k3 * tmin)
x_min = - 0.5 * c0 + torch.sqrt(((torch.tensor([0.5]).cuda() * c0) ** 2) - c1)
# Convert reference to luminance
lum_coeff_r = 0.2126
lum_coeff_g = 0.7152
lum_coeff_b = 0.0722
Y_reference = reference[:, 0:1, :, :] * lum_coeff_r + reference[:, 1:2, :, :] * lum_coeff_g + reference[:, 2:3, :, :] * lum_coeff_b
# Compute start exposure
Y_hi = torch.amax(Y_reference, dim=(2, 3), keepdim=True)
start_exposure = torch.log2(x_max / Y_hi)
# Compute stop exposure
dim = Y_reference.size()
Y_ref = Y_reference.view(dim[0], dim[1], dim[2]*dim[3])
Y_lo = torch.median(Y_ref, dim=2).values.unsqueeze(2).unsqueeze(3)
stop_exposure = torch.log2(x_min / Y_lo)
return start_exposure, stop_exposure
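# Note (added comment): x_max and x_min are the pre-tone-mapping values that map to tmax and
# tmin, respectively. The start exposure scales the maximum reference luminance to x_max and the
# stop exposure scales the median reference luminance to x_min; log2 converts these scale
# factors into exposure stops.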
def generate_spatial_filter(pixels_per_degree, channel):
"""
Generates spatial contrast sensitivity filters with width depending on
the number of pixels per degree of visual angle of the observer
:param pixels_per_degree: float indicating number of pixels per degree of visual angle
:param channel: string describing which filter should be generated: "A" (achromatic), "RG" (red-green), or "BY" (blue-yellow)
:return: filter kernel corresponding to the spatial contrast sensitivity function of the given channel, and the kernel's radius
"""
a1_A = 1
b1_A = 0.0047
a2_A = 0
b2_A = 1e-5 # avoid division by 0
a1_rg = 1
b1_rg = 0.0053
a2_rg = 0
b2_rg = 1e-5 # avoid division by 0
a1_by = 34.1
b1_by = 0.04
a2_by = 13.5
b2_by = 0.025
if channel == "A": # Achromatic CSF
a1 = a1_A
b1 = b1_A
a2 = a2_A
b2 = b2_A
elif channel == "RG": # Red-Green CSF
a1 = a1_rg
b1 = b1_rg
a2 = a2_rg
b2 = b2_rg
elif channel == "BY": # Blue-Yellow CSF
a1 = a1_by
b1 = b1_by
a2 = a2_by
b2 = b2_by
# Determine evaluation domain
max_scale_parameter = max([b1_A, b2_A, b1_rg, b2_rg, b1_by, b2_by])
r = np.ceil(3 * np.sqrt(max_scale_parameter / (2 * np.pi**2)) * pixels_per_degree)
r = int(r)
deltaX = 1.0 / pixels_per_degree
x, y = np.meshgrid(range(-r, r + 1), range(-r, r + 1))
z = (x * deltaX)**2 + (y * deltaX)**2
# Generate weights
g = a1 * np.sqrt(np.pi / b1) * np.exp(-np.pi**2 * z / b1) + a2 * np.sqrt(np.pi / b2) * np.exp(-np.pi**2 * z / b2)
g = g / np.sum(g)
g = torch.Tensor(g).unsqueeze(0).unsqueeze(0).cuda()
return g, r
def spatial_filter(img, s_a, s_rg, s_by, radius):
"""
Filters an image with channel specific spatial contrast sensitivity functions
and clips result to the unit cube in linear RGB
:param img: image tensor to filter (with NxCxHxW layout in the YCxCz color space)
:param s_a: spatial filter matrix for the achromatic channel
:param s_rg: spatial filter matrix for the red-green channel
:param s_by: spatial filter matrix for the blue-yellow channel
:param radius: integer radius (in pixels) of the largest of the three filter kernels, used to pad the image before convolution
:return: input image (with NxCxHxW layout) transformed to linear RGB after filtering with spatial contrast sensitivity functions
"""
dim = img.size()
# Prepare image for convolution
img_pad = torch.zeros((dim[0], dim[1], dim[2] + 2 * radius, dim[3] + 2 * radius), device='cuda')
img_pad[:, 0:1, :, :] = nn.functional.pad(img[:, 0:1, :, :], (radius, radius, radius, radius), mode='replicate')
img_pad[:, 1:2, :, :] = nn.functional.pad(img[:, 1:2, :, :], (radius, radius, radius, radius), mode='replicate')
img_pad[:, 2:3, :, :] = nn.functional.pad(img[:, 2:3, :, :], (radius, radius, radius, radius), mode='replicate')
# Apply Gaussian filters
img_tilde_opponent = torch.zeros((dim[0], dim[1], dim[2], dim[3]), device='cuda')
img_tilde_opponent[:, 0:1, :, :] = nn.functional.conv2d(img_pad[:, 0:1, :, :], s_a.cuda(), padding=0)
img_tilde_opponent[:, 1:2, :, :] = nn.functional.conv2d(img_pad[:, 1:2, :, :], s_rg.cuda(), padding=0)
img_tilde_opponent[:, 2:3, :, :] = nn.functional.conv2d(img_pad[:, 2:3, :, :], s_by.cuda(), padding=0)
# Transform to linear RGB for clamp
img_tilde_linear_rgb = color_space_transform(img_tilde_opponent, 'ycxcz2linrgb')
# Clamp to RGB box
return torch.clamp(img_tilde_linear_rgb, 0.0, 1.0)
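# Usage sketch (illustrative addition): generate the three CSF kernels and filter a YCxCz image,
# mirroring what compute_ldrflip does. `img_ycxcz` is a hypothetical NxCxHxW tensor in the
# YCxCz color space on the GPU:
#
#   ppd = (0.7 * 3840 / 0.7) * np.pi / 180
#   s_a, r_a = generate_spatial_filter(ppd, 'A')
#   s_rg, r_rg = generate_spatial_filter(ppd, 'RG')
#   s_by, r_by = generate_spatial_filter(ppd, 'BY')
#   filtered_linrgb = spatial_filter(img_ycxcz, s_a, s_rg, s_by, max(r_a, r_rg, r_by))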
def hunt_adjustment(img):
"""
Applies Hunt-adjustment to an image
:param img: image tensor to adjust (with NxCxHxW layout in the L*a*b* color space)
:return: Hunt-adjusted image tensor (with NxCxHxW layout in the Hunt-adjusted L*A*B* color space)
"""
# Extract luminance component
L = img[:, 0:1, :, :]
# Apply Hunt adjustment
img_h = torch.zeros(img.size(), device='cuda')
img_h[:, 0:1, :, :] = L
img_h[:, 1:2, :, :] = torch.mul((0.01 * L), img[:, 1:2, :, :])
img_h[:, 2:3, :, :] = torch.mul((0.01 * L), img[:, 2:3, :, :])
return img_h
def hyab(reference, test, eps):
"""
Computes the HyAB distance between reference and test images
:param reference: reference image tensor (with NxCxHxW layout in the standard or Hunt-adjusted L*A*B* color space)
:param test: test image tensor (with NxCxHxW layout in the standard or Hunt-adjusted L*a*b* color space)
:param eps: float containing a small value used to improve training stability
:return: image tensor (with Nx1xHxW layout) containing the per-pixel HyAB distances between reference and test images
"""
delta = reference - test
root = torch.sqrt(torch.clamp(torch.pow(delta[:, 0:1, :, :], 2), min=eps))
delta_norm = torch.norm(delta[:, 1:3, :, :], dim=1, keepdim=True)
return root + delta_norm # sqrt of the clamped squared lightness difference is used instead of abs() to stabilize training
def redistribute_errors(power_deltaE_hyab, cmax, pc, pt):
"""
Redistributes exponentiated HyAB errors to the [0,1] range
:param power_deltaE_hyab: float tensor (with Nx1xHxW layout) containing the exponentiated HyAB distance
:param cmax: float containing the exponentiated, maximum HyAB difference between two colors in Hunt-adjusted L*A*B* space
:param pc: float containing the cmax multiplier p_c (see FLIP paper)
:param pt: float containing the target value, p_t, for p_c * cmax (see FLIP paper)
:return: image tensor (with Nx1xHxW layout) containing redistributed per-pixel HyAB distances (in range [0,1])
"""
# Re-map error to 0-1 range. Values between 0 and
# pccmax are mapped to the range [0, pt],
# while the rest are mapped to the range (pt, 1]
deltaE_c = torch.zeros(power_deltaE_hyab.size(), device='cuda')
pccmax = pc * cmax
deltaE_c = torch.where(power_deltaE_hyab < pccmax, (pt / pccmax) * power_deltaE_hyab, pt + ((power_deltaE_hyab - pccmax) / (cmax - pccmax)) * (1.0 - pt))
return deltaE_c
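# Worked example (added comment): with pc = 0.4 and pt = 0.95, errors in [0, 0.4*cmax] are
# mapped linearly onto [0, 0.95], so an input of 0.4*cmax lands exactly at 0.95, while errors
# in (0.4*cmax, cmax] are mapped linearly onto (0.95, 1], so an input of cmax lands at 1.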
def feature_detection(img_y, pixels_per_degree, feature_type):
"""
Detects edges and points (features) in the achromatic image
:param img_y: achromatic image tensor (with Nx1xHxW layout, containing normalized Y-values from YCxCz)
:param pixels_per_degree: float describing the number of pixels per degree of visual angle of the observer
:param feature_type: string indicating the type of feature to detect
:return: image tensor (with Nx2xHxW layout, with values in range [0,1]) containing large values where features were detected
"""
# Set peak to trough value (2x standard deviations) of human edge
# detection filter
w = 0.082
# Compute filter radius
sd = 0.5 * w * pixels_per_degree
radius = int(np.ceil(3 * sd))
# Compute 2D Gaussian
[x, y] = np.meshgrid(range(-radius, radius+1), range(-radius, radius+1))
g = np.exp(-(x ** 2 + y ** 2) / (2 * sd * sd))
if feature_type == 'edge': # Edge detector
# Compute partial derivative in x-direction
Gx = np.multiply(-x, g)
else: # Point detector
# Compute second partial derivative in x-direction
Gx = np.multiply(x ** 2 / (sd * sd) - 1, g)
# Normalize positive weights to sum to 1 and negative weights to sum to -1
negative_weights_sum = -np.sum(Gx[Gx < 0])
positive_weights_sum = np.sum(Gx[Gx > 0])
Gx = torch.Tensor(Gx)
Gx = torch.where(Gx < 0, Gx / negative_weights_sum, Gx / positive_weights_sum)
Gx = Gx.unsqueeze(0).unsqueeze(0).cuda()
# Detect features
featuresX = nn.functional.conv2d(nn.functional.pad(img_y, (radius, radius, radius, radius), mode='replicate'), Gx, padding=0)
featuresY = nn.functional.conv2d(nn.functional.pad(img_y, (radius, radius, radius, radius), mode='replicate'), torch.transpose(Gx, 2, 3), padding=0)
return torch.cat((featuresX, featuresY), dim=1)
def color_space_transform(input_color, fromSpace2toSpace):
"""
Transforms inputs between different color spaces
:param input_color: tensor of colors to transform (with NxCxHxW layout)
:param fromSpace2toSpace: string describing transform
:return: transformed tensor (with NxCxHxW layout)
"""
dim = input_color.size()
# Assume D65 standard illuminant
reference_illuminant = torch.tensor([[[0.950428545]], [[1.000000000]], [[1.088900371]]]).cuda()
inv_reference_illuminant = torch.tensor([[[1.052156925]], [[1.000000000]], [[0.918357670]]]).cuda()
if fromSpace2toSpace == "srgb2linrgb":
limit = 0.04045
transformed_color = torch.where(
input_color > limit,
torch.pow((torch.clamp(input_color, min=limit) + 0.055) / 1.055, 2.4),
input_color / 12.92
) # clamp to stabilize training
elif fromSpace2toSpace == "linrgb2srgb":
limit = 0.0031308
transformed_color = torch.where(
input_color > limit,
1.055 * torch.pow(torch.clamp(input_color, min=limit), (1.0 / 2.4)) - 0.055,
12.92 * input_color
)
elif fromSpace2toSpace in ["linrgb2xyz", "xyz2linrgb"]:
# Source: https://www.image-engineering.de/library/technotes/958-how-to-convert-between-srgb-and-ciexyz
# Assumes D65 standard illuminant
if fromSpace2toSpace == "linrgb2xyz":
a11 = 10135552 / 24577794
a12 = 8788810 / 24577794
a13 = 4435075 / 24577794
a21 = 2613072 / 12288897
a22 = 8788810 / 12288897
a23 = 887015 / 12288897
a31 = 1425312 / 73733382
a32 = 8788810 / 73733382
a33 = 70074185 / 73733382
else:
# Constants found by taking the inverse of the matrix
# defined by the constants for linrgb2xyz
a11 = 3.241003275
a12 = -1.537398934
a13 = -0.498615861
a21 = -0.969224334
a22 = 1.875930071
a23 = 0.041554224
a31 = 0.055639423
a32 = -0.204011202
a33 = 1.057148933
A = torch.Tensor([[a11, a12, a13],
[a21, a22, a23],
[a31, a32, a33]])
input_color = input_color.view(dim[0], dim[1], dim[2]*dim[3]).cuda() # NC(HW)
transformed_color = torch.matmul(A.cuda(), input_color)
transformed_color = transformed_color.view(dim[0], dim[1], dim[2], dim[3])
elif fromSpace2toSpace == "xyz2ycxcz":
input_color = torch.mul(input_color, inv_reference_illuminant)
y = 116 * input_color[:, 1:2, :, :] - 16
cx = 500 * (input_color[:, 0:1, :, :] - input_color[:, 1:2, :, :])
cz = 200 * (input_color[:, 1:2, :, :] - input_color[:, 2:3, :, :])
transformed_color = torch.cat((y, cx, cz), 1)
elif fromSpace2toSpace == "ycxcz2xyz":
y = (input_color[:, 0:1, :, :] + 16) / 116
cx = input_color[:, 1:2, :, :] / 500
cz = input_color[:, 2:3, :, :] / 200
x = y + cx
z = y - cz
transformed_color = torch.cat((x, y, z), 1)
transformed_color = torch.mul(transformed_color, reference_illuminant)
elif fromSpace2toSpace == "xyz2lab":
input_color = torch.mul(input_color, inv_reference_illuminant)
delta = 6 / 29
delta_square = delta * delta
delta_cube = delta * delta_square
factor = 1 / (3 * delta_square)
clamped_term = torch.pow(torch.clamp(input_color, min=delta_cube), 1.0 / 3.0).to(dtype=input_color.dtype)
div = (factor * input_color + (4 / 29)).to(dtype=input_color.dtype)
input_color = torch.where(input_color > delta_cube, clamped_term, div) # clamp to stabilize training
L = 116 * input_color[:, 1:2, :, :] - 16
a = 500 * (input_color[:, 0:1, :, :] - input_color[:, 1:2, :, :])
b = 200 * (input_color[:, 1:2, :, :] - input_color[:, 2:3, :, :])
transformed_color = torch.cat((L, a, b), 1)
elif fromSpace2toSpace == "lab2xyz":
y = (input_color[:, 0:1, :, :] + 16) / 116
a = input_color[:, 1:2, :, :] / 500
b = input_color[:, 2:3, :, :] / 200
x = y + a
z = y - b
xyz = torch.cat((x, y, z), 1)
delta = 6 / 29
delta_square = delta * delta
factor = 3 * delta_square
xyz = torch.where(xyz > delta, torch.pow(xyz, 3), factor * (xyz - 4 / 29))
transformed_color = torch.mul(xyz, reference_illuminant)
elif fromSpace2toSpace == "srgb2xyz":
transformed_color = color_space_transform(input_color, 'srgb2linrgb')
transformed_color = color_space_transform(transformed_color, 'linrgb2xyz')
elif fromSpace2toSpace == "srgb2ycxcz":
transformed_color = color_space_transform(input_color, 'srgb2linrgb')
transformed_color = color_space_transform(transformed_color, 'linrgb2xyz')
transformed_color = color_space_transform(transformed_color, 'xyz2ycxcz')
elif fromSpace2toSpace == "linrgb2ycxcz":
transformed_color = color_space_transform(input_color, 'linrgb2xyz')
transformed_color = color_space_transform(transformed_color, 'xyz2ycxcz')
elif fromSpace2toSpace == "srgb2lab":
transformed_color = color_space_transform(input_color, 'srgb2linrgb')
transformed_color = color_space_transform(transformed_color, 'linrgb2xyz')
transformed_color = color_space_transform(transformed_color, 'xyz2lab')
elif fromSpace2toSpace == "linrgb2lab":
transformed_color = color_space_transform(input_color, 'linrgb2xyz')
transformed_color = color_space_transform(transformed_color, 'xyz2lab')
elif fromSpace2toSpace == "ycxcz2linrgb":
transformed_color = color_space_transform(input_color, 'ycxcz2xyz')
transformed_color = color_space_transform(transformed_color, 'xyz2linrgb')
elif fromSpace2toSpace == "lab2srgb":
transformed_color = color_space_transform(input_color, 'lab2xyz')
transformed_color = color_space_transform(transformed_color, 'xyz2linrgb')
transformed_color = color_space_transform(transformed_color, 'linrgb2srgb')
elif fromSpace2toSpace == "ycxcz2lab":
transformed_color = color_space_transform(input_color, 'ycxcz2xyz')
transformed_color = color_space_transform(transformed_color, 'xyz2lab')
else:
sys.exit('Error: The color transform %s is not defined!' % fromSpace2toSpace)
return transformed_color
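# Minimal smoke test (illustrative addition, not part of the original file). It assumes a CUDA
# device is available, since the module hard-codes .cuda() calls, and uses random tensors only
# to exercise both loss classes end to end.
if __name__ == "__main__":
    torch.manual_seed(0)
    # LDR: sRGB values in [0, 1]
    ldr_reference = torch.rand(1, 3, 64, 64).cuda()
    ldr_test = torch.clamp(ldr_reference + 0.05 * torch.randn_like(ldr_reference), 0.0, 1.0)
    ldr_error_map = LDRFLIPLoss()(ldr_test, ldr_reference)
    print("Mean LDR-FLIP:", ldr_error_map.mean().item())
    # HDR: nonnegative linear-RGB radiance values
    hdr_reference = torch.rand(1, 3, 64, 64).cuda() * 16.0
    hdr_test = hdr_reference * torch.clamp(1.0 + 0.1 * torch.randn_like(hdr_reference), min=0.0)
    hdr_loss = HDRFLIPLoss()(hdr_test, hdr_reference)
    print("Mean HDR-FLIP:", hdr_loss.item())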