-
Notifications
You must be signed in to change notification settings - Fork 6
/
generate_flow.py
113 lines (82 loc) · 4 KB
/
generate_flow.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import argparse
import os
import cv2
import glob
import numpy as np
import torch
from PIL import Image
from RAFT.raft import RAFT
from RAFT.utils import flow_viz
from RAFT.utils.utils import InputPadder
from flow_utils import *
# Device string handed to `.to(...)` for both the input tensors (load_image)
# and the model (run). Hard-coded to CUDA.
DEVICE = 'cuda'
def create_dir(dir):
    """Create directory `dir` (including missing parents) if it does not exist.

    Calling this on a directory that already exists is a no-op.

    Args:
        dir: Path of the directory to create.

    Raises:
        FileExistsError: If `dir` exists but is not a directory.
    """
    # exist_ok=True replaces the separate os.path.exists() check, which could
    # race with another process creating the same directory in between.
    os.makedirs(dir, exist_ok=True)
def load_image(imfile):
    """Read an image file into a float tensor with a leading batch axis.

    The file is opened with PIL, converted to RGB, and turned into a uint8
    array; the channel axis is then moved to the front, the values are cast
    to float (no rescaling is applied), and the result is placed on DEVICE.

    Args:
        imfile: Path of the image file to read.

    Returns:
        Tensor of shape (1, 3, H, W) on DEVICE.
    """
    rgb = Image.open(imfile).convert('RGB')
    pixels = np.array(rgb).astype(np.uint8)
    # (H, W, C) -> (C, H, W)
    channels_first = torch.from_numpy(pixels).permute(2, 0, 1).float()
    batched = channels_first[None]
    return batched.to(DEVICE)
def warp_flow(img, flow):
    """Resample `img` at the positions each pixel is displaced to by `flow`.

    An absolute sampling map is built by adding the column index to
    flow[..., 0] and the row index to flow[..., 1]; `img` is then sampled at
    those coordinates with cv2.remap using bicubic interpolation and a
    constant border.

    Args:
        img: Array to sample from.
        flow: Array whose first two axes are (height, width) and whose last
            axis holds the x displacement at index 0 and the y displacement
            at index 1. It is not modified.

    Returns:
        The resampled array produced by cv2.remap.
    """
    height, width = flow.shape[:2]
    col_index = np.arange(width)
    row_index = np.arange(height)[:, np.newaxis]

    # Work on a copy so the caller's flow stays untouched.
    sample_map = flow.copy()
    sample_map[:, :, 0] += col_index
    sample_map[:, :, 1] += row_index

    return cv2.remap(
        img,
        sample_map,
        None,
        cv2.INTER_CUBIC,
        borderMode=cv2.BORDER_CONSTANT,
    )
def _flow_consistency_mask(flow, opposite_flow, alpha_1, alpha_2):
    """Return a boolean mask of pixels where `flow` agrees with `opposite_flow`."""
    # Sample the opposite-direction flow at the positions `flow` points to.
    warped_opposite = warp_flow(opposite_flow, flow)
    # For consistent pixels the two flows should roughly cancel out.
    round_trip_error = np.linalg.norm(flow + warped_opposite, axis=-1)
    tolerance = alpha_1 * (np.linalg.norm(flow, axis=-1)
                           + np.linalg.norm(warped_opposite, axis=-1)) + alpha_2
    return round_trip_error < tolerance


def compute_fwdbwd_mask(fwd_flow, bwd_flow, alpha_1=0.5, alpha_2=0.5):
    """Compute forward/backward consistency masks for a pair of flow fields.

    For each direction, the opposite flow is warped by the flow itself and
    added to it. A pixel is marked valid when the norm of that sum is below
    ``alpha_1 * (|flow| + |warped opposite flow|) + alpha_2``.

    Args:
        fwd_flow: Forward flow array; the last axis holds the flow vector.
        bwd_flow: Backward flow array with the same layout as `fwd_flow`.
        alpha_1: Relative tolerance, scaling the sum of the two flow
            magnitudes. Defaults to 0.5.
        alpha_2: Absolute tolerance added to the threshold. Defaults to 0.5.

    Returns:
        Tuple ``(fwd_mask, bwd_mask)`` of boolean arrays, True where the
        respective flow passed the consistency check.
    """
    fwd_mask = _flow_consistency_mask(fwd_flow, bwd_flow, alpha_1, alpha_2)
    bwd_mask = _flow_consistency_mask(bwd_flow, fwd_flow, alpha_1, alpha_2)
    return fwd_mask, bwd_mask
def run(args, input_path, output_path, output_img_path):
    """Estimate forward and backward flow for every consecutive image pair.

    All ``*.png`` and ``*.jpg`` files in `input_path` are sorted by path and
    processed pairwise (image i with image i + 1). For each pair, named after
    the first image's file stem, the function writes:

    * ``<stem>_fwd.npz`` / ``<stem>_bwd.npz`` in `output_path`, each holding
      the arrays ``flow`` and ``mask``;
    * ``<stem>_fwd.png`` / ``<stem>_bwd.png`` (flow visualizations) and
      ``<stem>_fwd_mask.png`` / ``<stem>_bwd_mask.png`` in `output_img_path`.

    Args:
        args: Parsed arguments; `args.model` is the checkpoint path and the
            whole object is passed to the RAFT constructor.
        input_path: Directory containing the input images.
        output_path: Existing directory for the ``.npz`` files.
        output_img_path: Existing directory for the ``.png`` files.
    """
    # The weights are loaded through a DataParallel wrapper and only the
    # wrapped module is kept for inference (presumably the checkpoint was
    # saved from a DataParallel model -- confirm against the checkpoint).
    wrapper = torch.nn.DataParallel(RAFT(args))
    wrapper.load_state_dict(torch.load(args.model))
    model = wrapper.module
    model.to(DEVICE)
    model.eval()

    with torch.no_grad():
        frame_paths = sorted(
            glob.glob(os.path.join(input_path, '*.png'))
            + glob.glob(os.path.join(input_path, '*.jpg'))
        )

        # Pair each frame with its successor; the last frame has no partner.
        for i, (path_a, path_b) in enumerate(zip(frame_paths, frame_paths[1:])):
            print(i)
            image_name = os.path.splitext(os.path.basename(path_a))[0]

            image1 = load_image(path_a)
            image2 = load_image(path_b)
            padder = InputPadder(image1.shape)
            image1, image2 = padder.pad(image1, image2)

            _, raw_fwd = model(image1, image2, iters=20, test_mode=True)
            _, raw_bwd = model(image2, image1, iters=20, test_mode=True)

            # Remove the padding, drop the batch axis and move channels last.
            flow_fwd = padder.unpad(raw_fwd[0]).cpu().numpy().transpose(1, 2, 0)
            flow_bwd = padder.unpad(raw_bwd[0]).cpu().numpy().transpose(1, 2, 0)

            mask_fwd, mask_bwd = compute_fwdbwd_mask(flow_fwd, flow_bwd)

            # Save flow
            fwd_npz = os.path.join(output_path, f'{image_name}_fwd.npz')
            np.savez(fwd_npz, flow=flow_fwd, mask=mask_fwd)
            bwd_npz = os.path.join(output_path, f'{image_name}_bwd.npz')
            np.savez(bwd_npz, flow=flow_bwd, mask=mask_bwd)

            # Save flow_img
            fwd_vis = Image.fromarray(flow_viz.flow_to_image(flow_fwd))
            fwd_vis.save(os.path.join(output_img_path, f'{image_name}_fwd.png'))
            bwd_vis = Image.fromarray(flow_viz.flow_to_image(flow_bwd))
            bwd_vis.save(os.path.join(output_img_path, f'{image_name}_bwd.png'))

            fwd_mask_img = Image.fromarray(mask_fwd)
            fwd_mask_img.save(os.path.join(output_img_path, f'{image_name}_fwd_mask.png'))
            bwd_mask_img = Image.fromarray(mask_bwd)
            bwd_mask_img.save(os.path.join(output_img_path, f'{image_name}_bwd_mask.png'))
if __name__ == '__main__':
    # Script entry point: parse the command line, derive the input and output
    # directories from the dataset path, create the output directories and
    # run flow estimation over the input images.
    parser = argparse.ArgumentParser()
    # These three options are mandatory. Without them the script would only
    # fail later inside os.path.join / torch.load with an unhelpful error, so
    # argparse is asked to reject the command line up front.
    parser.add_argument("--dataset_path", type=str, required=True, help='Dataset path')
    parser.add_argument("--input_dir", type=str, required=True, help='Input image directory')
    parser.add_argument('--model', required=True, help="restore RAFT checkpoint")
    parser.add_argument('--small', action='store_true', help='use small model')
    parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
    args = parser.parse_args()

    # Outputs are written next to the input directory inside the dataset:
    # <input_dir>_flow for the .npz files, <input_dir>_flow_png for images.
    input_path = os.path.join(args.dataset_path, args.input_dir)
    output_path = os.path.join(args.dataset_path, f'{args.input_dir}_flow')
    output_img_path = os.path.join(args.dataset_path, f'{args.input_dir}_flow_png')

    create_dir(output_path)
    create_dir(output_img_path)

    run(args, input_path, output_path, output_img_path)