-
Notifications
You must be signed in to change notification settings - Fork 21
/
run.py
98 lines (88 loc) · 3.1 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
from lib.config import cfg, args
import numpy as np
from fusion import fusion
def run_dataset():
    """Iterate once over the test split to smoke-test the data pipeline.

    Loads nothing onto the GPU and produces no output; useful for
    checking that every batch can be constructed without error.
    """
    from lib.datasets import make_data_loader
    import tqdm

    # Single-process loading keeps any dataset exception easy to trace.
    cfg.train.num_workers = 0
    loader = make_data_loader(cfg, is_train=False)
    for _ in tqdm.tqdm(loader):
        pass
def run_network():
    """Benchmark the trained network's forward pass on the test split.

    Loads the checkpoint given by ``cfg.test.epoch``, runs one forward
    pass per batch under ``torch.no_grad()``, and prints the mean
    per-batch latency in seconds.
    """
    import time

    import torch
    import tqdm

    from lib.datasets import make_data_loader
    from lib.networks import make_network
    from lib.utils.data_utils import to_cuda
    from lib.utils.net_utils import load_network

    network = make_network(cfg).cuda()
    load_network(network, cfg.trained_model_dir, epoch=cfg.test.epoch)
    network.eval()

    data_loader = make_data_loader(cfg, is_train=False)
    elapsed = 0
    for batch in tqdm.tqdm(data_loader):
        batch = to_cuda(batch)
        with torch.no_grad():
            # Synchronize on both sides so the timer brackets only the
            # GPU work of this forward pass, not queued prior kernels.
            torch.cuda.synchronize()
            tic = time.time()
            network(batch)
            torch.cuda.synchronize()
            elapsed += time.time() - tic
    print(elapsed / len(data_loader))
def run_evaluate():
    """Evaluate the trained network on the test split.

    Loads the checkpointed model, runs it over the test data loader,
    feeds each output to the configured evaluator, and prints summary
    metrics plus measured FPS.  When ``cfg.save_video`` is set, only the
    forward passes run (evaluation and timing are skipped).  When
    ``cfg.save_ply`` is set, a fused point cloud is written per scene.
    """
    from lib.datasets import make_data_loader
    from lib.evaluators import make_evaluator
    import tqdm
    import torch
    from lib.networks import make_network
    from lib.utils import net_utils
    import time

    network = make_network(cfg).cuda()
    net_utils.load_network(network,
                           cfg.trained_model_dir,
                           resume=cfg.resume,
                           epoch=cfg.test.epoch)
    network.eval()

    data_loader = make_data_loader(cfg, is_train=False)
    evaluator = make_evaluator(cfg)
    net_time = []
    scenes = []
    for batch in tqdm.tqdm(data_loader):
        _batch_to_cuda(batch)
        if cfg.save_video:
            with torch.no_grad():
                network(batch)
        else:
            with torch.no_grad():
                # Synchronize around the forward pass so the timing
                # covers GPU work, not just the kernel launch.
                torch.cuda.synchronize()
                start_time = time.time()
                output = network(batch)
                torch.cuda.synchronize()
                net_time.append(time.time() - start_time)
            evaluator.evaluate(output, batch)
        scenes.append(batch['meta']['scene'][0])

    if not cfg.save_video:
        evaluator.summarize()
        # BUG FIX: with an empty data loader np.mean([]) emits a
        # RuntimeWarning and this printed 'FPS: nan'; now prints nothing.
        # The first measurement is dropped when possible because it
        # includes CUDA warm-up overhead.
        if len(net_time) > 1:
            print('FPS: ', 1. / np.mean(net_time[1:]))
        elif net_time:
            print('FPS: ', 1. / np.mean(net_time))
    if cfg.save_ply:
        for scene in scenes:
            fusion(cfg.dir_ply, scene)


def _batch_to_cuda(batch):
    """Move every tensor in *batch* onto the GPU, in place.

    Handles the three container layouts the datasets produce:
    nested dicts for keys containing ``'novel_view'``, a list of dicts
    under ``'rendering_video_meta'``, and plain tensors for all other
    keys.  The ``'meta'`` entry is bookkeeping and stays on the CPU.
    """
    for k in batch:
        if k == 'meta':
            continue
        if 'novel_view' in k:
            for v in batch[k]:
                batch[k][v] = batch[k][v].cuda()
        elif k == 'rendering_video_meta':
            for item in batch[k]:
                for v in item:
                    item[v] = item[v].cuda()
        else:
            batch[k] = batch[k].cuda()
if __name__ == '__main__':
    # Dispatch on the --type argument to one of the run_* functions
    # defined above (e.g. 'evaluate' -> run_evaluate()).
    globals()[f'run_{args.type}']()