-
Notifications
You must be signed in to change notification settings - Fork 4
/
main.py
88 lines (65 loc) · 1.92 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
"""
Testing script
"""
# Pipeline: load every .jpg in ../data/room/, add them to a starster Scene
# using a MASt3R model, fit a 3D Gaussian Splatting (3DGS) model, then render
# the original views and write them to <index>.png in the working directory.
import sys
from pathlib import Path

import cv2
import numpy as np
import torch

# Make the vendored MASt3R / DUSt3R / CroCo source trees importable. This is
# kept ahead of `import starster`, as in the original ordering; presumably
# starster imports from these trees (TODO confirm).
sys.path.append("mast3r")
sys.path.append("mast3r/dust3r")
sys.path.append("mast3r/dust3r/croco")

import starster  # noqa: E402  (must follow the sys.path setup above)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
RES = 224
DATA_DIR = Path("../data/room/").absolute()
MODEL_PATH = "../models/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth"

# Path.iterdir() yields entries in an arbitrary, filesystem-dependent order.
# Sort so the image order, and therefore the imgs[:2] / imgs[2:] batching
# below, is reproducible across runs and machines.
files = sorted(
    str(path)
    for path in DATA_DIR.iterdir()
    if path.is_file() and path.suffix.lower() == ".jpg"
)
if not files:
    # Fail fast, before the (expensive) model load, with an actionable message.
    sys.exit(f"No .jpg images found in {DATA_DIR}")

imgs = [starster.load_image(file, RES) for file in files]

model = starster.Mast3rModel.from_pretrained(MODEL_PATH).to(DEVICE)
scene = starster.Scene(device=DEVICE)
# Images are added in two calls: first two, then the rest. Presumably this
# exercises incremental addition in Scene.add_images (TODO confirm).
scene.add_images(model, imgs[:2])
remaining = imgs[2:]
if remaining:
    scene.add_images(model, remaining)
print(scene.imgs[0].shape)

scene.init_3dgs()
# One optimization pass with pruning enabled, then a shorter pass without it.
scene.run_3dgs_optim(400, enable_pruning=True, verbose=True)
scene.run_3dgs_optim(100, enable_pruning=False, verbose=True)

renders, _alpha, _info = scene.render_3dgs_original(RES, RES)
print(renders.shape)
renders = torch.clip(renders.detach().cpu(), 0, 1)
# Scale [0, 1] floats to uint8 and reverse the last axis. This assumes renders
# are (N, H, W, 3) in RGB order (TODO confirm); cv2.imwrite expects BGR.
renders = (renders.numpy()[..., ::-1] * 255).astype(np.uint8)
for i, img in enumerate(renders):
    out_path = f"{i}.png"
    # cv2.imwrite reports failure via its return value rather than raising.
    if not cv2.imwrite(out_path, img):
        raise RuntimeError(f"cv2.imwrite failed to write {out_path}")