-
Notifications
You must be signed in to change notification settings - Fork 17
/
generate.py
71 lines (60 loc) · 2.5 KB
/
generate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import argparse
import os
import torch
from utils import get_pipeline
from vars import PROMPT_TEMPLATES
def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument(
"-m", "--model", type=str, default="schnell", choices=["schnell", "dev"], help="Which FLUX.1 model to use"
)
parser.add_argument(
"-p", "--precision", type=str, default="int4", choices=["int4", "bf16"], help="Which precision to use"
)
parser.add_argument(
"--prompt", type=str, default="A cat holding a sign that says hello world", help="Prompt for the image"
)
parser.add_argument("--seed", type=int, default=2333, help="Random seed (-1 for random)")
parser.add_argument("-t", "--num-inference-steps", type=int, default=4, help="Number of inference steps")
parser.add_argument("-o", "--output-path", type=str, default="output.png", help="Image output path")
parser.add_argument("-g", "--guidance-scale", type=float, default=0, help="Guidance scale.")
parser.add_argument("--use-qencoder", action="store_true", help="Whether to use 4-bit text encoder")
known_args, _ = parser.parse_known_args()
if known_args.model == "dev":
parser.set_defaults(num_inference_steps=50, guidance_scale=3.5)
parser.add_argument(
"-n",
"--lora-name",
type=str,
default="None",
choices=PROMPT_TEMPLATES.keys(),
help="Name of the LoRA layer",
)
parser.add_argument("-a", "--lora-weight", type=float, default=1, help="Weight of the LoRA layer")
args = parser.parse_args()
return args
def main():
args = get_args()
pipeline = get_pipeline(
model_name=args.model,
precision=args.precision,
use_qencoder=args.use_qencoder,
lora_name=getattr(args, "lora_name", "None"),
lora_weight=getattr(args, "lora_weight", 1),
device="cuda",
)
if args.model == "dev":
prompt = PROMPT_TEMPLATES[args.lora_name].format(prompt=args.prompt)
else:
prompt = args.prompt
image = pipeline(
prompt=prompt,
num_inference_steps=args.num_inference_steps,
guidance_scale=args.guidance_scale,
generator=torch.Generator().manual_seed(args.seed) if args.seed >= 0 else None,
).images[0]
output_dir = os.path.dirname(os.path.abspath(os.path.expanduser(args.output_path)))
os.makedirs(output_dir, exist_ok=True)
image.save(args.output_path)
if __name__ == "__main__":
main()