-
Notifications
You must be signed in to change notification settings - Fork 6
/
utils.py
96 lines (80 loc) · 3.59 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import torch
import numpy as np
import matplotlib
import scipy.ndimage as ndimage
from scipy.misc import imresize
from scipy.ndimage.filters import gaussian_filter
import os
def blur(patch, scale_factor):
    """Smooth ``patch`` with a Gaussian filter.

    Args:
        patch: Array handed straight to ``gaussian_filter``.
        scale_factor: Used as the filter's ``sigma``.

    Returns:
        The filtered array produced by ``gaussian_filter``.
    """
    return gaussian_filter(patch, sigma=scale_factor)
def upscale(patch, scale_factor):
    """Resize ``patch`` to a square of side ``36 * scale_factor``.

    The resize is delegated to ``imresize`` with cubic interpolation.

    NOTE(review): ``scipy.misc.imresize`` was deprecated in SciPy 1.0 and
    removed in SciPy 1.3 -- confirm which SciPy version this project pins.

    Args:
        patch: Image array to resize.
        scale_factor: Multiplier applied to the base side length of 36.

    Returns:
        The resized array returned by ``imresize``.
    """
    # To disable upscaling, uncomment the next line:
    # return patch
    side = 36 * scale_factor
    target_size = (side, side)
    return imresize(patch, target_size, interp='cubic')
def save_checkpoint(state, output_dir, filename):
    """Write ``state`` to ``output_dir/filename`` with ``torch.save``.

    Args:
        state: Object to serialize (passed unchanged to ``torch.save``).
        output_dir: Directory that receives the checkpoint. It is created,
            including intermediate directories, when it does not exist.
            An empty string means "current directory" and is left alone.
        filename: File name of the checkpoint inside ``output_dir``.
    """
    # torch.save opens the target path directly and does not create missing
    # parent directories, so make sure the destination exists first.
    if output_dir and not os.path.isdir(output_dir):
        os.makedirs(output_dir)
    torch.save(state, os.path.join(output_dir, filename))
def frac_eq_to(image, value=0):
    """Return the fraction of elements of ``image`` that equal ``value``.

    Args:
        image: Array whose elements are compared against ``value``.
        value: Value to count (defaults to 0).

    Returns:
        Number of matching elements divided by the total element count.
    """
    n_total = float(np.prod(image.shape))
    n_equal = (image == value).sum()
    return n_equal / n_total
def extract_patches(image, patchshape, overlap_allowed=0.5, cropvalue=None,
                    crop_fraction_allowed=0.1):
    """
    Extract patches of a given shape from an image with a greedy raster scan.

    A window of size ``patchshape`` is moved left-to-right, top-to-bottom
    over the image. A window is accepted as a patch when

    * at most ``overlap_allowed`` of its pixels already belong to a
      previously accepted patch, and
    * (only if ``cropvalue`` is given) at most ``crop_fraction_allowed`` of
      its pixels equal ``cropvalue``.

    After an accepted patch the window jumps ahead by
    ``int(patchshape * overlap_allowed)`` pixels (at least 1); after a
    rejected window it advances by a single pixel. The same rule is applied
    to rows: jump if the row produced at least one patch, otherwise move
    down by one pixel.

    Args:
        image: Array to cut patches from. The first two axes are treated as
            rows and columns. When ``cropvalue`` is given the image must be
            2-D, because the non-cropped pixels are located with a
            two-index ``np.where``.
        patchshape: ``(rows, cols)`` of every patch.
        overlap_allowed: Maximum fraction of a window that may already be
            covered by accepted patches; also determines the jump size.
        cropvalue: Optional flag value marking cropped pixels. When given,
            the scan is restricted to the bounding rectangle of all pixels
            that differ from it.
        crop_fraction_allowed: Maximum fraction of a window's pixels that
            may equal ``cropvalue``.

    Returns:
        Array with the patch index as the first dimension, i.e. shape
        ``(n_patches, patchshape[0], patchshape[1])`` plus any trailing
        axes of ``image``. ``n_patches`` is 0 when no window qualifies
        (for example when the image is smaller than the patch or consists
        only of ``cropvalue``).
    """
    patch_rows, patch_cols = patchshape[0], patchshape[1]

    # Stride used after an accepted patch. Clamped to 1 so that a small
    # overlap_allowed (e.g. 0) cannot leave the window stuck in place.
    jump_cols = max(1, int(patch_cols * overlap_allowed))
    jump_rows = max(1, int(patch_rows * overlap_allowed))

    # Restrict ourselves to the rectangle containing non-cropped pixels.
    if cropvalue is not None:
        rows, cols = np.where(image != cropvalue)
        if rows.size:
            # Slice ends are exclusive, hence the +1 to keep the last
            # non-cropped row and column inside the active region.
            active = image[rows.min():rows.max() + 1,
                           cols.min():cols.max() + 1]
        else:
            # Every pixel is cropped: nothing to scan.
            active = image[:0, :0]
    else:
        active = image

    # Largest start offsets at which a full patch still fits (inclusive).
    last_rowstart = active.shape[0] - patch_rows
    last_colstart = active.shape[1] - patch_cols

    # Array tracking where we've already taken patches.
    covered = np.zeros(active.shape, dtype=bool)
    patches = []

    rowstart = 0
    while rowstart <= last_rowstart:
        # Record whether or not we've found a patch in this row,
        # so we know whether to skip ahead.
        got_a_patch_this_row = False
        colstart = 0
        while colstart <= last_colstart:
            # Slice tuple indexing the region of our proposed patch.
            region = (slice(rowstart, rowstart + patch_rows),
                      slice(colstart, colstart + patch_cols))
            patch = active[region]

            # The two tests are evaluated separately on purpose: `and`
            # binds tighter than `or`, so chaining them as
            # `cropvalue is None or crop_test and overlap_test` would skip
            # the overlap test whenever cropvalue is None.
            crop_ok = (cropvalue is None or
                       np.mean(patch == cropvalue) <= crop_fraction_allowed)
            overlap_ok = covered[region].mean() <= overlap_allowed

            if crop_ok and overlap_ok:
                # Accept the patch, mask its area and jump ahead in x.
                patches.append(patch)
                covered[region] = True
                colstart += jump_cols
                got_a_patch_this_row = True
            else:
                # Otherwise, shift the window across by one pixel.
                colstart += 1

        # Jump ahead in y if this row yielded a patch, else move one pixel.
        rowstart += jump_rows if got_a_patch_this_row else 1

    if not patches:
        # Keep the documented layout even when nothing was accepted.
        empty_shape = (0, patch_rows, patch_cols) + tuple(active.shape[2:])
        return np.empty(empty_shape, dtype=active.dtype)

    # Stack with the patch index first so that patch pixels stay contiguous
    # in memory (C-ordered array).
    return np.stack(patches, axis=0)