ditxl_example.py
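"""Benchmark DistriDiTPipeline (DiT-XL) under different parallelism and
GroupNorm synchronization modes, reporting latency and peak GPU memory.

A typical multi-GPU launch (shown for illustration; the exact launcher
depends on how legacy.pipefuser initializes its process group):

    torchrun --nproc_per_node=2 ditxl_example.py -p patch --sync_mode corrected_async_gn
"""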
import argparse
import os
import time

import torch

from legacy.pipefuser.pipelines.dit import DistriDiTPipeline
from legacy.pipefuser.utils import DistriConfig

# yunchang provides long-context (sequence-parallel) attention; it is optional here.
HAS_LONG_CTX_ATTN = False
try:
    from yunchang import set_seq_parallel_pg

    HAS_LONG_CTX_ATTN = True
except ImportError:
    print("yunchang not found")
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_id",
        default="facebook/DiT-XL-2-256",
        type=str,
        help="Path to the pretrained model.",
    )
    parser.add_argument(
        "--parallelism",
        "-p",
        default="patch",
        type=str,
        choices=["patch", "naive_patch", "tensor"],
        help="Parallelism to use.",
    )
    parser.add_argument(
        "--sync_mode",
        type=str,
        default="corrected_async_gn",
        choices=[
            "separate_gn",
            "async_gn",
            "corrected_async_gn",
            "sync_gn",
            "full_sync",
            "no_sync",
        ],
        help="GroupNorm synchronization mode.",
    )
    parser.add_argument(
        "--use_seq_parallel_attn",
        action="store_true",
        default=False,
        help="Enable sequence parallel attention.",
    )
    parser.add_argument(
        "--ulysses_degree",
        type=int,
        default=1,
        help="Degree of Ulysses sequence parallelism.",
    )
    parser.add_argument(
        "--use_ulysses_low",
        action="store_true",
    )
    args = parser.parse_args()
    # For DiT, the height and width are fixed according to the model.
    distri_config = DistriConfig(
        height=1024,
        width=1024,
        warmup_steps=4,
        do_classifier_free_guidance=True,
        split_batch=False,
        parallelism=args.parallelism,
        mode=args.sync_mode,
        use_cuda_graph=False,
    )
    pipeline = DistriDiTPipeline.from_pretrained(
        distri_config=distri_config,
        pretrained_model_name_or_path=args.model_id,
        # variant="fp16",
        # use_safetensors=True,
    )
    # Only rank 0 shows the progress bar.
    pipeline.set_progress_bar_config(disable=distri_config.rank != 0)
    case_name = (
        f"{args.parallelism}_{args.sync_mode}"
        f"_sp_{args.use_seq_parallel_attn}"
        f"_u{args.ulysses_degree}_w{distri_config.world_size}"
    )

    # Warmup run so that the timed run below excludes one-time setup costs.
    output = pipeline(
        prompt=["panda"],
        generator=torch.Generator(device="cuda").manual_seed(42),
    )

    # Timed run.
    torch.cuda.reset_peak_memory_stats()
    start_time = time.time()
    output = pipeline(
        prompt=["panda"],
        generator=torch.Generator(device="cuda").manual_seed(42),
    )
    # Make sure all GPU work has finished before reading the clock.
    torch.cuda.synchronize()
    end_time = time.time()
    peak_memory = torch.cuda.max_memory_allocated(device="cuda")

    if distri_config.rank == 0:
        elapsed_time = end_time - start_time
        print(
            f"{case_name}: elapsed: {elapsed_time:.2f} sec, "
            f"memory: {peak_memory / 1e9:.2f} GB"
        )
        os.makedirs("./results", exist_ok=True)
        output.images[0].save(f"./results/{case_name}_panda.png")


if __name__ == "__main__":
    main()