# baseline.py (forked from xdit-project/xDiT)
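"""Single-GPU baseline benchmark for Stable Diffusion 3 via diffusers.

Runs one warmup generation, then times a second generation and reports
wall-clock latency and peak CUDA memory. Intended as the no-parallelism
reference point for the xDiT benchmarks (inferred from the fork origin
and the file name).
"""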
import argparse
import os
import time

import torch
from diffusers import (
    FlowMatchEulerDiscreteScheduler,
    StableDiffusion3Pipeline,
)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_id",
        default="stabilityai/stable-diffusion-3-medium-diffusers",
        type=str,
        help="Path to the pretrained model.",
    )
    parser.add_argument(
        "--num_inference_steps",
        type=int,
        default=28,
    )
    parser.add_argument(
        "--height",
        type=int,
        default=1024,
        help="The height of the generated image.",
    )
    parser.add_argument(
        "--width",
        type=int,
        default=1024,
        help="The width of the generated image.",
    )
    parser.add_argument(
        "--output_type",
        type=str,
        default="pil",
        choices=["latent", "pil"],
        help="'latent' saves memory; 'pil' decodes through the VAE, causing a memory burst.",
    )
    parser.add_argument(
        "--scheduler",
        "-s",
        default="FM-ED",
        type=str,
        choices=["dpm-solver", "ddim", "FM-ED"],
        help="Scheduler to use.",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default="An astronaut riding a green horse",
    )
    parser.add_argument("--output_file", type=str, default=None)
    args = parser.parse_args()

    # torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = True

    # For DiT models the height and width are fixed by the model config.
    model_id = args.model_id
    if args.scheduler == "FM-ED":
        scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
            pretrained_model_name_or_path=model_id,
            subfolder="scheduler",
        )
        pipeline = StableDiffusion3Pipeline.from_pretrained(
            pretrained_model_name_or_path=model_id,
            scheduler=scheduler,
        ).to("cuda")
    else:
        # Only the flow-matching Euler scheduler is wired up for SD3 here;
        # without this guard the other choices fall through to an undefined
        # `pipeline` and crash with a NameError.
        raise NotImplementedError(f"scheduler {args.scheduler!r} is not supported")

    # Warmup run: keeps one-time costs (CUDA context creation, kernel
    # autotuning, allocator growth) out of the timed measurement below.
    output = pipeline(
        prompt=args.prompt,
        generator=torch.Generator(device="cuda").manual_seed(42),
        height=args.height,
        width=args.width,
        output_type=args.output_type,
    )
    torch.cuda.reset_peak_memory_stats()

    case_name = f"baseline_hw_{args.height}_base"
    if args.output_file:
        case_name = args.output_file + "_" + case_name

    start_time = time.time()
    output = pipeline(
        prompt=args.prompt,
        generator=torch.Generator(device="cuda").manual_seed(42),
        num_inference_steps=args.num_inference_steps,
        height=args.height,
        width=args.width,
        output_type=args.output_type,
    )
    # Make sure all queued GPU work has finished before reading the clock.
    torch.cuda.synchronize()
    end_time = time.time()
    elapsed_time = end_time - start_time
    peak_memory = torch.cuda.max_memory_allocated(device="cuda")
    print(
        f"{case_name} epoch time: {elapsed_time:.2f} sec, memory: {peak_memory / 1e9:.2f} GB"
    )

    if args.output_type == "pil":
        os.makedirs("./results", exist_ok=True)
        print(f"saving image to ./results/{case_name}.png")
        output.images[0].save(f"./results/{case_name}.png")


if __name__ == "__main__":
    main()
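
# Example invocation (every value below is the script's own default except
# --output_file, which is an arbitrary prefix for the result files):
#   python baseline.py \
#       --model_id stabilityai/stable-diffusion-3-medium-diffusers \
#       --num_inference_steps 28 --height 1024 --width 1024 \
#       --scheduler FM-ED --output_type pil --output_file demo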