forked from LIU-Yuxin/SyncMVD
-
Notifications
You must be signed in to change notification settings - Fork 1
/
run_experiment.py
113 lines (90 loc) · 3.34 KB
/
run_experiment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import os
from os.path import join, isdir, abspath, dirname, basename, splitext
from IPython.display import display
from datetime import datetime
import torch
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
from diffusers import DDPMScheduler, UniPCMultistepScheduler
from src.pipeline import StableSyncMVDPipeline
from src.configs import *
from shutil import copy
# Parse the run configuration (parse_config comes from the star import of
# src.configs above).
opt = parse_config()

# Resolve the mesh file: either relative to the directory that holds the
# config file, or relative to the current working directory.
if opt.mesh_config_relative:
    mesh_path = join(dirname(opt.config), opt.mesh)
else:
    mesh_path = abspath(opt.mesh)

# Results go under the explicit output root if one was given, otherwise next
# to the config file.
if opt.output:
    output_root = abspath(opt.output)
else:
    output_root = dirname(opt.config)

# Output folder name is "<prefix>_<mesh name>_<timestamp>", each part optional.
# If every part is disabled the name is empty and output_dir resolves to the
# output root itself (which, when it exists, is reported as a collision).
output_name_components = []
if opt.prefix:
    output_name_components.append(opt.prefix)
if opt.use_mesh_name:
    # Spaces in the mesh file name are replaced so the folder name has none.
    mesh_name = splitext(basename(mesh_path))[0].replace(" ", "_")
    output_name_components.append(mesh_name)
if opt.timeformat:
    output_name_components.append(datetime.now().strftime(opt.timeformat))
output_name = "_".join(output_name_components)

output_dir = join(output_root, output_name)
if isdir(output_dir):
    # Never overwrite a previous run; adding a time-format component to the
    # name avoids the collision.
    print("Results exist in the output directory, use time string to avoid name collision.")
    raise SystemExit(0)
# makedirs (unlike mkdir) also creates an output root that does not exist yet.
os.makedirs(output_dir)
print(f"Saving to {output_dir}")

# Keep a copy of the exact config used for this run alongside its results.
copy(opt.config, join(output_dir, "config.yaml"))
# Logging options handed to StableSyncMVDPipeline: the run's output folder plus
# log_interval and the two fast-preview switches, copied unchanged from the
# parsed config (their exact meaning is defined by the pipeline).
logging_config = dict(
    output_dir=output_dir,
    log_interval=opt.log_interval,
    view_fast_preview=opt.view_fast_preview,
    tex_fast_preview=opt.tex_fast_preview,
)
# fp16 ControlNet checkpoint for each supported conditioning type.
_CONTROLNET_CHECKPOINTS = {
    "normal": "lllyasviel/control_v11p_sd15_normalbae",
    "depth": "lllyasviel/control_v11f1p_sd15_depth",
}
if opt.cond_type not in _CONTROLNET_CHECKPOINTS:
    # Fail with a clear message instead of a NameError on `controlnet` below.
    raise ValueError(
        f"Unsupported cond_type {opt.cond_type!r}; "
        f"expected one of {sorted(_CONTROLNET_CHECKPOINTS)}"
    )
controlnet = ControlNetModel.from_pretrained(
    _CONTROLNET_CHECKPOINTS[opt.cond_type],
    variant="fp16",
    torch_dtype=torch.float16,
)

# Stable Diffusion 1.5 combined with the selected ControlNet, in half precision.
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
)
# Swap the checkpoint's default scheduler for DDPM, reusing its configuration.
pipe.scheduler = DDPMScheduler.from_config(pipe.scheduler.config)
# Build the multi-view pipeline from the same already-loaded components.
syncmvd = StableSyncMVDPipeline(**pipe.components)
# Arguments for the synchronized multi-view texturing run, listed in the same
# order as they are evaluated and grouped by concern.
_syncmvd_kwargs = dict(
    # Prompting and sampling; view resolution is the latent size times 8.
    prompt=opt.prompt,
    height=opt.latent_view_size * 8,
    width=opt.latent_view_size * 8,
    num_inference_steps=opt.steps,
    guidance_scale=opt.guidance_scale,
    negative_prompt=opt.negative_prompt,
    generator=torch.manual_seed(opt.seed),
    max_batch_size=64,
    # ControlNet conditioning schedule.
    controlnet_guess_mode=opt.guess_mode,
    controlnet_conditioning_scale=opt.conditioning_scale,
    controlnet_conditioning_end_scale=opt.conditioning_scale_end,
    control_guidance_start=opt.control_guidance_start,
    control_guidance_end=opt.control_guidance_end,
    guidance_rescale=opt.guidance_rescale,
    use_directional_prompt=True,
    # Mesh loading and camera placement.
    mesh_path=mesh_path,
    mesh_transform={"scale": opt.mesh_scale},
    mesh_autouv=not opt.keep_mesh_uv,
    camera_azims=opt.camera_azims,
    top_cameras=not opt.no_top_cameras,
    # Texture and render resolutions.
    texture_size=opt.latent_tex_size,
    render_rgb_size=opt.rgb_view_size,
    texture_rgb_size=opt.rgb_tex_size,
    # Multi-view diffusion / attention / background schedule.
    multiview_diffusion_end=opt.mvd_end,
    exp_start=opt.mvd_exp_start,
    exp_end=opt.mvd_exp_end,
    ref_attention_end=opt.ref_attention_end,
    shuffle_background_change=opt.shuffle_bg_change,
    shuffle_background_end=opt.shuffle_bg_end,
    # Logging and miscellaneous options.
    logging_config=logging_config,
    cond_type=opt.cond_type,
    max_hits=opt.max_hits,
    style_prompt=opt.style_prompt,
)
result_tex_rgb, textured_views, v = syncmvd(**_syncmvd_kwargs)

# Show the pipeline's third return value with IPython's display helper.
display(v)