PaddleMIX ppdiffusers Stable Diffusion 3 inference optimize #681

Open. Wants to merge 59 commits into base: develop.

Commits (59):
a6631e7
optimize SD3
chang-wenbin Aug 19, 2024
b0ea9ef
optimize SD3 transformer_SD3
chang-wenbin Aug 19, 2024
f06a61a
optimize SD3 transformer_SD3
chang-wenbin Aug 19, 2024
dcff90c
update SD3
chang-wenbin Aug 20, 2024
15c5e44
update triton & sim_SD3
chang-wenbin Aug 20, 2024
ab73a63
modify temb_silu && modify nvtx
chang-wenbin Aug 20, 2024
ed2b7b1
modify linear from fused_linear
chang-wenbin Aug 20, 2024
f4330d3
modify simplified_sd3
chang-wenbin Aug 20, 2024
cc1af0f
add split_concat triton kernel
chang-wenbin Aug 20, 2024
70e6b6e
modify split_concat triton kernel
chang-wenbin Aug 21, 2024
9543b11
update
chang-wenbin Aug 21, 2024
357b75a
update transformer_sd3
chang-wenbin Aug 21, 2024
f54bf84
update transformer_sd3
chang-wenbin Aug 21, 2024
3245b2f
update triton & simplified_sd3
chang-wenbin Aug 21, 2024
5516df6
update simplified_sd3
chang-wenbin Aug 22, 2024
874d5d7
update simplified_sd3
chang-wenbin Aug 22, 2024
111f4cd
delete context_pre_only=False
chang-wenbin Aug 22, 2024
18777b6
modify triton_optimize
chang-wenbin Aug 22, 2024
7a288e4
modify triton_optimize
chang-wenbin Aug 22, 2024
840b153
modify triton_optimize
chang-wenbin Aug 22, 2024
95c9e47
modify triton_fuse & fix performance issues affected by CUDA sy…
chang-wenbin Aug 22, 2024
84a9e7a
modify transformer_sd3 if optimize_origin
chang-wenbin Aug 23, 2024
9dd918d
update vae triton_split
chang-wenbin Aug 23, 2024
3a0b7e1
vae T5 d2s & transformer forward d2s
chang-wenbin Aug 26, 2024
6d02d79
update demo
chang-wenbin Aug 26, 2024
5d81b44
update five model d2s
chang-wenbin Aug 26, 2024
4bab118
update SD3 clip T5 vae
chang-wenbin Aug 27, 2024
5a14a0f
update clip
chang-wenbin Aug 27, 2024
cd2ef01
update T5
chang-wenbin Aug 27, 2024
624168c
update T5
chang-wenbin Aug 27, 2024
b009b9f
update scheduling_flow_match_euler_discrete
chang-wenbin Aug 27, 2024
8caa10a
update normalization
chang-wenbin Aug 28, 2024
377629a
update normalization
chang-wenbin Aug 28, 2024
6863054
Merge remote-tracking branch 'upstream/develop' into SD3_PaddleMIX_819
chang-wenbin Aug 28, 2024
15fda4e
update SD3
chang-wenbin Aug 29, 2024
cb993c5
merge develop
chang-wenbin Aug 30, 2024
0e90eaf
update cutlass gemm&fast_gelu
chang-wenbin Sep 2, 2024
c5bb81f
update per-mmdit
chang-wenbin Sep 4, 2024
2c8cc85
merge develop
chang-wenbin Sep 4, 2024
499752a
update triton op split_concat
chang-wenbin Sep 4, 2024
1084f4a
update embeddings
chang-wenbin Sep 5, 2024
e3a5d7c
merge
chang-wenbin Sep 6, 2024
fa84559
recovery
chang-wenbin Sep 6, 2024
27c62f9
recovery
chang-wenbin Sep 6, 2024
951f7a6
merge
chang-wenbin Sep 6, 2024
9515323
update normalization
chang-wenbin Sep 10, 2024
d61e4cb
update dtype
chang-wenbin Sep 10, 2024
d961a4a
add SD3 doc
chang-wenbin Sep 10, 2024
ac1e139
merge develop
chang-wenbin Sep 18, 2024
48c66a6
update SD3 doc
chang-wenbin Sep 18, 2024
24c3c9e
add 'del transformer_blocks'
chang-wenbin Sep 19, 2024
422f33b
update SD3
chang-wenbin Sep 19, 2024
c43d84f
update SD3
chang-wenbin Sep 19, 2024
9d03624
update Notes
chang-wenbin Sep 19, 2024
ded06bf
add Notes
chang-wenbin Sep 19, 2024
d845da2
update demo
chang-wenbin Sep 19, 2024
db6aad1
update doc
chang-wenbin Sep 19, 2024
3527954
update SD3
chang-wenbin Sep 19, 2024
e7848a3
merge zkk
chang-wenbin Sep 24, 2024
17 changes: 16 additions & 1 deletion ppdiffusers/deploy/sd3/README.md
@@ -11,14 +11,29 @@ python -c "import use_triton_in_paddle; use_triton_in_paddle.make_triton_compatible_with_paddle()"
# Install the develop build of Paddle; choose the wheel that matches your CUDA version (CUDA 12.3 is used here)
python -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu123/

Reviewer comment (Contributor): Add a sentence here: please use a PaddleNLP build from after September 6, 2024, because on that day we fixed a PaddleNLP bug.
https://github.com/PaddlePaddle/PaddleNLP/pull/9016/files
# Install PaddleNLP. Please use a PaddleNLP build from after September 6, 2024, because a PaddleNLP bug was fixed on that day:
# https://github.com/PaddlePaddle/PaddleNLP/pull/9016/files
python -m pip install paddlenlp==3.0.0b1

# Set the TensorRT lib path
export LD_LIBRARY_PATH=/your_TensorRT_dir/lib:$LD_LIBRARY_PATH

# Set the CUTLASS package path
export LD_LIBRARY_PATH=/your_dir/Paddle/paddle/phi/kernels/fusion/cutlass/conv2d/build:$LD_LIBRARY_PATH

# Set the path of libCutlassGemmEpilogue.so
# For details, see https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/phi/kernels/fusion/cutlass/gemm_epilogue/README.md
export LD_LIBRARY_PATH=/your_dir/Paddle/paddle/phi/kernels/fusion/cutlass/gemm_epilogue/build:$LD_LIBRARY_PATH
```

High-performance inference commands:
```shell
# Run FP16 inference
# Step 1: export the FP32 Paddle model and build the FP16 TensorRT engine from it.
python text_to_image_generation-stable_diffusion_3.py --dtype float32 --height 512 --width 512 \
--num-inference-steps 50 --inference_optimize 1 \
--benchmark 1

# Step 2: run FP16 inference
python text_to_image_generation-stable_diffusion_3.py --dtype float16 --height 512 --width 512 \
--num-inference-steps 50 --inference_optimize 1 \
--benchmark 1
ppdiffusers/deploy/sd3/text_to_image_generation-stable_diffusion_3.py
@@ -13,6 +13,11 @@
# limitations under the License.
import os

os.environ["FLAGS_use_cuda_managed_memory"] = "true"
import argparse
import datetime
@@ -61,18 +66,97 @@ def parse_args():
paddle_dtype=inference_dtype,
)

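# NOTE: paddle.incubate.jit.inference traces each sub-model to a static graph on its
# first call, caches the exported model under save_model_dir, and with with_trt=True
# runs it through Paddle-TensorRT with an FP16 engine on subsequent calls.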
pipe.text_encoder = paddle.incubate.jit.inference(
pipe.text_encoder,
save_model_dir="./tmp/text_encoder",
cache_static_model=True,
with_trt=True,
trt_precision_mode="float16",
trt_use_static=True,
)

pipe.text_encoder_2 = paddle.incubate.jit.inference(
pipe.text_encoder_2,
save_model_dir="./tmp/text_encoder_2",
cache_static_model=True,
with_trt=True,
trt_precision_mode="float16",
trt_use_static=True,
)


pipe.text_encoder_3 = paddle.incubate.jit.inference(
pipe.text_encoder_3,
save_model_dir="./tmp/text_encoder_3_T5",
cache_static_model=True,
with_trt=True,
trt_precision_mode="float16",
trt_use_static=True,
)

pipe.transformer = paddle.incubate.jit.inference(
pipe.transformer,
save_model_dir="./tmp/sd3",
enable_new_ir=True,
cache_static_model=False,
exp_enable_use_cutlass=True,
delete_pass_lists=["add_norm_fuse_pass"],
)
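# NOTE: the MM-DiT transformer is exported through the new IR executor with the
# CUTLASS gemm_epilogue kernels enabled; add_norm_fuse_pass is dropped from the pass
# list, presumably because it conflicts with the custom fused kernels used here.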


# for vae model
pipe.vae.decode = paddle.incubate.jit.inference(
pipe.vae.decode,
save_model_dir="./tmp/vae_static_models",
cache_static_model=True,
with_trt=True,
trt_precision_mode="float16",
trt_use_static=True,
)

generator = paddle.Generator().manual_seed(42)
prompt = "A cat holding a sign that says hello world"


image = pipe(
prompt, num_inference_steps=args.num_inference_steps, width=args.width, height=args.height, generator=generator
).images[0]

if args.benchmark:
# warmup
for i in range(3):
image = pipe(
prompt,
num_inference_steps=args.num_inference_steps,
width=args.width,
height=args.height,
generator=generator,
).images[0]

repeat_times = 10
sumtime = 0.0
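# NOTE: each timed run is bracketed by paddle.device.synchronize() so the measured
# wall-clock interval covers finished GPU work rather than asynchronous kernel launches.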
for i in range(repeat_times):
paddle.device.synchronize()
starttime = datetime.datetime.now()
image = pipe(
prompt,
num_inference_steps=args.num_inference_steps,
width=args.width,
height=args.height,
generator=generator,
).images[0]
paddle.device.synchronize()
endtime = datetime.datetime.now()
duringtime = endtime - starttime
duringtime = duringtime.seconds * 1000 + duringtime.microseconds / 1000.0
sumtime += duringtime
print("SD3 end to end time : ", duringtime, "ms")

print("SD3 ave end to end time : ", sumtime / repeat_times, "ms")
cuda_mem_after_used = paddle.device.cuda.max_memory_allocated() / (1024**3)
print(f"Max used CUDA memory : {cuda_mem_after_used:.3f} GiB")


image = pipe(
prompt, num_inference_steps=args.num_inference_steps, width=args.width, height=args.height, generator=generator
).images[0]
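Throughout this PR, modules wrapped with `paddle.incubate.jit.inference` return plain tensors (or flat tuples) instead of the dynamic-graph `ModelOutput` objects, which is why the pipeline code further down branches on `paddle.incubate.jit.is_inference_mode`. A minimal sketch of the pattern, using only arguments that appear in this diff; the `TinyNet` module is a hypothetical stand-in:

```python
import paddle

class TinyNet(paddle.nn.Layer):  # hypothetical stand-in for one of the wrapped sub-models
    def __init__(self):
        super().__init__()
        self.fc = paddle.nn.Linear(8, 8)

    def forward(self, x):
        return self.fc(x)

net = TinyNet()
net.eval()
# Trace to a static graph on first call and cache the exported model on disk,
# so later runs skip the dynamic-to-static conversion entirely.
net = paddle.incubate.jit.inference(net, save_model_dir="./tmp/tiny_net", cache_static_model=True)
y = net(paddle.randn([2, 8]))

# Callers can detect the wrapped module and adjust for the different return
# structure, exactly as the SD3 pipeline does below:
if paddle.incubate.jit.is_inference_mode(net):
    print("running the cached static graph")
```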
6 changes: 6 additions & 0 deletions ppdiffusers/ppdiffusers/models/autoencoder_kl.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Dict, Optional, Tuple, Union

import paddle
@@ -88,6 +89,9 @@ def __init__(
use_quant_conv: bool = True,
use_post_quant_conv: bool = True,
):
# NOTE:(changwenbin,zhoukangkang) the SD3 VAE uses the memory_efficient_attention op,
# which is not well supported by Paddle-TensorRT, so we set USE_PPXFORMERS=False to
# avoid that op.
os.environ["USE_PPXFORMERS"] = "False"
Reviewer comment (Contributor): Please explain here why we need to set this to False.

super().__init__()
# if down_block_out_channels not given, we will use block_out_channels
_down_block_out_channels = block_out_channels if down_block_out_channels is None else down_block_out_channels
@@ -116,6 +120,8 @@
norm_num_groups=norm_num_groups,
act_fn=act_fn,
)
# NOTE:(changwenbin,zhoukangkang) delete USE_PPXFORMERS again so the default behavior is restored.
del os.environ["USE_PPXFORMERS"]

self.quant_conv = nn.Conv2D(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None
self.post_quant_conv = nn.Conv2D(latent_channels, latent_channels, 1) if use_post_quant_conv else None
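The constructor above enables the flag only while the encoder and decoder are built, then deletes it. Note that `del os.environ[...]` also discards any value the user exported beforehand; a restore-on-exit context manager would make the temporary scope explicit. A minimal sketch (the `env_flag` helper is hypothetical, not part of this PR):

```python
import os
from contextlib import contextmanager

@contextmanager
def env_flag(name: str, value: str):
    """Temporarily set an environment variable, restoring the previous value on exit."""
    old = os.environ.get(name)
    os.environ[name] = value
    try:
        yield
    finally:
        if old is None:
            del os.environ[name]
        else:
            os.environ[name] = old

# Usage: build the VAE sub-modules with ppxformers disabled, then restore.
with env_flag("USE_PPXFORMERS", "False"):
    pass  # construct Encoder(...) / Decoder(...) here
```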
1 change: 1 addition & 0 deletions ppdiffusers/ppdiffusers/models/normalization.py
@@ -190,6 +190,7 @@
def forward(self, x: paddle.Tensor, conditioning_embedding: paddle.Tensor) -> paddle.Tensor:
# convert back to the original dtype in case `conditioning_embedding` is upcast to float32 (needed for HunyuanDiT)
emb = self.linear(self.silu(conditioning_embedding).cast(x.dtype))
scale, shift = paddle.chunk(emb, 2, axis=1)
if os.getenv("INFERENCE_OPTIMIZE_TRITON"):
# NOTE:(changwenbin,zhoukangkang)
36 changes: 36 additions & 0 deletions ppdiffusers/ppdiffusers/models/transformer_sd3.py
@@ -198,6 +198,7 @@ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)

def fn_recursive_attn_processor(name: str, module: paddle.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
@@ -278,6 +279,38 @@ def custom_forward(*inputs):
)
return encoder_hidden_states, hidden_states

def sd3_origin_transformer(
self,
hidden_states,
encoder_hidden_states,
temb,
):
for block in self.transformer_blocks:
if self.training and self.gradient_checkpointing and not use_old_recompute():

def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)

return custom_forward

ckpt_kwargs = {} if recompute_use_reentrant() else {"use_reentrant": False}
hidden_states = recompute(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
temb,
**ckpt_kwargs,
)
else:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
)
return encoder_hidden_states, hidden_states

def forward(
self,
hidden_states: paddle.Tensor,
@@ -321,6 +354,7 @@ def forward(
scale_lora_layers(self, lora_scale)
else:
logger.info("Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective.")

height, width = hidden_states.shape[-2:]

@@ -351,6 +385,8 @@
shape=(hidden_states.shape[0], height, width, patch_size, patch_size, self.out_channels)
)

hidden_states = paddle.transpose(hidden_states, [0, 5, 1, 3, 2, 4])
output = hidden_states.reshape(
shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size)
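For reference, the transpose/reshape pair above is the standard unpatchify step: the per-patch tensor of shape `(B, h, w, p, p, C)` is permuted to `(B, C, h, p, w, p)` and then collapsed into the image grid `(B, C, h*p, w*p)`. A quick shape check with toy sizes (a standalone sketch, not part of the PR):

```python
import paddle

B, h, w, p, C = 1, 2, 2, 2, 3
x = paddle.randn([B, h, w, p, p, C])          # patches laid out on an (h, w) grid
x = paddle.transpose(x, [0, 5, 1, 3, 2, 4])   # -> (B, C, h, p, w, p)
out = x.reshape([B, C, h * p, w * p])         # -> (B, C, h*p, w*p): stitched image
print(out.shape)  # [1, 3, 4, 4]
```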
ppdiffusers/ppdiffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
@@ -229,8 +229,13 @@ def _get_t5_prompt_embeds(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer_max_length} tokens: {removed_text}"
)

outputs = self.text_encoder_3(text_input_ids)
if paddle.incubate.jit.is_inference_mode(self.text_encoder_3):
# NOTE:(changwenbin,zhoukangkang) this is for paddle.incubate.jit.inference
prompt_embeds = outputs
else:
prompt_embeds = outputs[0]

dtype = self.text_encoder_3.dtype
prompt_embeds = prompt_embeds.astype(dtype=dtype)
@@ -277,13 +282,23 @@ def _get_clip_prompt_embeds(
f" {self.tokenizer_max_length} tokens: {removed_text}"
)
prompt_embeds = text_encoder(text_input_ids, output_hidden_states=True)
if paddle.incubate.jit.is_inference_mode(text_encoder):
    # NOTE:(changwenbin,zhoukangkang) this is for paddle.incubate.jit.inference
    pooled_prompt_embeds = prompt_embeds[-1]
    if clip_skip is None:
        prompt_embeds = prompt_embeds[:-2][-2]
    else:
        prompt_embeds = prompt_embeds[:-2][-(clip_skip + 2)]
else:
    pooled_prompt_embeds = prompt_embeds[0]
    if clip_skip is None:
        prompt_embeds = prompt_embeds.hidden_states[-2]
    else:
        prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
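# NOTE: under jit inference the static-graph CLIP encoder returns a flat tuple of
# tensors rather than a ModelOutput, so prompt_embeds[-1] takes the pooled output and
# the prompt_embeds[:-2][...] indexing is meant to mirror the hidden_states[...]
# lookups of the dynamic-graph branch above.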

pooled_prompt_embeds = pooled_prompt_embeds.astype(dtype=text_encoder.dtype)
prompt_embeds = prompt_embeds.astype(dtype=self.text_encoder.dtype)

_, seq_len, _ = prompt_embeds.shape
@@ -391,6 +406,9 @@ def encode_prompt(
clip_prompt_embeds,
(0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1]),
data_format="NCL",
)

prompt_embeds = paddle.concat([clip_prompt_embeds, t5_prompt_embed], axis=-2)
@@ -439,12 +457,16 @@
t5_negative_prompt_embed = self._get_t5_prompt_embeds(
prompt=negative_prompt_3,
num_images_per_prompt=num_images_per_prompt,
)

negative_clip_prompt_embeds = paddle.nn.functional.pad(
negative_clip_prompt_embeds,
(0, t5_negative_prompt_embed.shape[-1] - negative_clip_prompt_embeds.shape[-1]),
data_format="NCL",
)

negative_prompt_embeds = paddle.concat([negative_clip_prompt_embeds, t5_negative_prompt_embed], axis=-2)
@@ -850,7 +872,15 @@ def __call__(
else:
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor

# in order to support dynamic-to-static (d2s) inference
if paddle.incubate.jit.is_inference_mode(self.vae.decode):
latents = latents.cast("float32")
image_out = self.vae.decode(latents, return_dict=False)
if paddle.incubate.jit.is_inference_mode(self.vae.decode):
# NOTE:(changwenbin,zhoukangkang) this is for paddle.incubate.jit.inference
image = image_out
else:
image = image_out[0]
image = self.image_processor.postprocess(image, output_type=output_type)

# Offload all models
@@ -860,3 +890,4 @@
return (image,)

return StableDiffusion3PipelineOutput(images=image)

13 changes: 12 additions & 1 deletion ppdiffusers/ppdiffusers/transformers/t5/modeling.py
@@ -24,6 +24,7 @@
from paddle import nn
from paddle.amp.auto_cast import amp_state
from paddle.distributed import fleet
from paddle.framework import in_dynamic_or_pir_mode
from paddle.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from paddlenlp.transformers.activations import ACT2FN
from paddlenlp.transformers.conversion_utils import (
@@ -1555,6 +1556,12 @@ def __init__(self, config: T5Config):
# Initialize weights and apply final processing
self.post_init()

# NOTE:(changwenbin,zhoukangkang)
# When you use 'paddle.incubate.jit.inference' to speed up your model,
# if you have set 'cache_static_model=True',
# you can use 'del self.encoder' to reduce the global memory usage.
# del self.encoder

def get_input_embeddings(self):
return self.shared

@@ -1605,7 +1612,11 @@ def forward(
return_dict=return_dict,
)

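# NOTE: in_dynamic_or_pir_mode() is True during normal dynamic-graph execution and
# False while the model is being traced for dynamic-to-static (d2s) export, so the
# exported graph returns the bare last_hidden_state tensor instead of a ModelOutput.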
if in_dynamic_or_pir_mode():
return encoder_output
else:
# NOTE:(changwenbin,zhoukangkang) there is a bug in dy2s; we work around it here.
return encoder_output.last_hidden_state


class T5ForSequenceClassification(T5PretrainedModel):