[draft] add LoRA layer support to IPU SD Pipeline #25

Open · wants to merge 1 commit into main
@@ -19,7 +19,7 @@
import torch
from diffusers import AutoencoderKL, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.models.autoencoder_kl import AutoencoderKLOutput
-from diffusers.models.cross_attention import CrossAttention
+from diffusers.models.cross_attention import CrossAttention, LoRACrossAttnProcessor
from diffusers.models.unet_2d_condition import UNet2DConditionOutput
from diffusers.models.vae import DecoderOutput, DiagonalGaussianDistribution
from transformers import CLIPTextModel
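
For context, the newly imported `LoRACrossAttnProcessor` carries one low-rank adapter per attention projection (`to_q_lora`, `to_k_lora`, `to_v_lora`, `to_out_lora`), which is what the IPU processor below taps into. A minimal sketch of the low-rank update those adapters compute; illustrative only, not diffusers' actual `LoRALinearLayer`:

```python
import torch

class LoRADelta(torch.nn.Module):
    """Conceptual stand-in for one of the *_lora adapters."""

    def __init__(self, in_features, out_features, rank=4):
        super().__init__()
        # Two thin projections whose product forms the low-rank weight update.
        self.down = torch.nn.Linear(in_features, rank, bias=False)
        self.up = torch.nn.Linear(rank, out_features, bias=False)

    def forward(self, x):
        return self.up(self.down(x))

# The wrapped processor then adds `scale * adapter(x)` on top of the frozen
# projection, e.g. query = attn.to_q(x) + scale * to_q_lora(x).
```
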
@@ -58,25 +58,39 @@ def _nearest_divisor(target, start, end):
return divisor
raise ValueError(f"No divisor found in range [{start}, {end}].")

-def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+@staticmethod
+def _forward(
+attn: CrossAttention,
+hidden_states,
+attn_matrix_target_mem_mb,
+encoder_hidden_states=None,
+attention_mask=None,
+lora_cross_attn_processor=None,
+scale=1.0,
+):
batch_size, sequence_length, _ = hidden_states.shape

attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)

query = attn.to_q(hidden_states)
+if lora_cross_attn_processor is not None:
+query += scale * lora_cross_attn_processor.to_q_lora(hidden_states)
query = attn.head_to_batch_dim(query)

encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
+if lora_cross_attn_processor is not None:
+key += scale * lora_cross_attn_processor.to_k_lora(encoder_hidden_states)
+value += scale * lora_cross_attn_processor.to_v_lora(encoder_hidden_states)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)

# Begin IPU modifications.
attn_matrix_mem = query.element_size() * query.shape[0] * query.shape[1] * key.shape[1]
-num_slices = attn_matrix_mem // (self._attn_matrix_target_mem_mb * 1024 * 1024)
+num_slices = attn_matrix_mem // (attn_matrix_target_mem_mb * 1024 * 1024)
num_slices = max(num_slices, 1)
-num_slices = self._nearest_divisor(query.shape[1], num_slices, 2 * num_slices)
+num_slices = IPUSlicedAttnProcessor._nearest_divisor(query.shape[1], num_slices, 2 * num_slices)
slice_size = query.shape[1] // num_slices

hidden_states = []
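
To make the slicing arithmetic above concrete, here is a worked example with illustrative numbers, assuming `_nearest_divisor` returns the smallest divisor of its target in the given range (as the error message suggests):

```python
# Illustrative numbers only: fp16 attention matrix for 16 batch*head rows,
# 4096 query tokens and 4096 key tokens.
element_size = 2                                  # bytes per fp16 element
batch_heads, q_len, k_len = 16, 4096, 4096
attn_matrix_mem = element_size * batch_heads * q_len * k_len   # 512 MiB

attn_matrix_target_mem_mb = 100
num_slices = attn_matrix_mem // (attn_matrix_target_mem_mb * 1024 * 1024)  # 5
num_slices = max(num_slices, 1)
# _nearest_divisor(4096, 5, 10) would then pick 8, the first divisor of 4096
# in that range, so each slice covers 4096 // 8 = 512 query positions.
```
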
@@ -101,11 +115,38 @@ def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=No

# linear proj
hidden_states = attn.to_out[0](hidden_states)
+if lora_cross_attn_processor is not None:
+hidden_states += scale * lora_cross_attn_processor.to_out_lora(hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)

return hidden_states

+def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+return self._forward(
+attn, hidden_states, self._attn_matrix_target_mem_mb, encoder_hidden_states, attention_mask
+)


+class IPULoRASlicedAttnProcessor(torch.nn.Module):
+def __init__(self, attn_matrix_target_mem_mb: int, lora_cross_attn_processor: LoRACrossAttnProcessor):
+super().__init__()
+self._attn_matrix_target_mem_mb = attn_matrix_target_mem_mb
+self._lora_cross_attn_processor = lora_cross_attn_processor
+
+def __call__(
+self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0
+):
+return IPUSlicedAttnProcessor._forward(
+attn,
+hidden_states,
+self._attn_matrix_target_mem_mb,
+encoder_hidden_states,
+attention_mask,
+self._lora_cross_attn_processor,
+scale,
+)


class IPUCLIPTextModel(CLIPTextModel, PipelineMixin):
def parallelize(self):
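
As a quick illustration of how the new processor composes with a diffusers LoRA processor on a single attention block; every size below is hypothetical and only meant to be mutually consistent:

```python
from diffusers.models.cross_attention import CrossAttention, LoRACrossAttnProcessor

# Hypothetical sizes roughly matching a Stable Diffusion v1.x down-block
# cross-attention layer (inner dim 320, CLIP text dim 768).
attn = CrossAttention(query_dim=320, cross_attention_dim=768, heads=8, dim_head=40)
lora_proc = LoRACrossAttnProcessor(hidden_size=320, cross_attention_dim=768, rank=4)

ipu_proc = IPULoRASlicedAttnProcessor(
    attn_matrix_target_mem_mb=100, lora_cross_attn_processor=lora_proc
)
attn.set_processor(ipu_proc)  # this block now runs sliced attention plus LoRA deltas
```
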
@@ -148,15 +189,23 @@ def forward(


class IPUUNet2DConditionModel(UNet2DConditionModel, PipelineMixin):
-def change_cross_attention_processor(self, attn_matrix_target_mem_mb):
-for module in self.modules():
-if isinstance(module, CrossAttention):
-module.set_processor(IPUSlicedAttnProcessor(attn_matrix_target_mem_mb))
+def change_cross_attention_processor(self, attn_matrix_target_mem_mb, lora_name_or_path_or_dict=None):
+attn_processors = {}
+for attn_processor_name, attn_processor in self.attn_processors.items():
+if lora_name_or_path_or_dict is not None:
+attn_processors[attn_processor_name] = IPULoRASlicedAttnProcessor(
+attn_matrix_target_mem_mb, attn_processor
+)
+else:
+attn_processors[attn_processor_name] = IPUSlicedAttnProcessor(attn_matrix_target_mem_mb)
+self.set_attn_processor(attn_processors)

-def parallelize(self, attn_matrix_target_mem_mb=None):
+def parallelize(self, attn_matrix_target_mem_mb=None, lora_name_or_path_or_dict=None):
super().parallelize()

-self.change_cross_attention_processor(attn_matrix_target_mem_mb)
+self.change_cross_attention_processor(
+attn_matrix_target_mem_mb, lora_name_or_path_or_dict=lora_name_or_path_or_dict
+)

self.conv_in = poptorch.BeginBlock(self.conv_in, "conv_in", ipu_id=0)
self.down_blocks[2].downsamplers[0] = poptorch.BeginBlock(
@@ -269,6 +318,7 @@ def __init__(
vae_ipu_config=None,
safety_checker_ipu_config=None,
common_ipu_config_kwargs=None,
+lora_name_or_path_or_dict=None,
):
default_common_ipu_config_kwargs = {
"enable_half_partials": True,
@@ -399,7 +449,12 @@ def run_safety_checker(self, image, device, dtype):
unet_ipu = copy.deepcopy(unet)
unet_ipu.__class__ = IPUUNet2DConditionModel
unet_ipu.ipu_config = unet_ipu_config
-unet_ipu.parallelize(attn_matrix_target_mem_mb=attn_matrix_target_mem_mb)
+if lora_name_or_path_or_dict is not None:
+unet_ipu.load_attn_procs(lora_name_or_path_or_dict)
+unet_ipu.parallelize(
+attn_matrix_target_mem_mb=attn_matrix_target_mem_mb,
+lora_name_or_path_or_dict=lora_name_or_path_or_dict,
+)
override_module_eps(unet_ipu)

opts = unet_ipu_config.to_options(for_inference=True)
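
A hypothetical end-to-end sketch of what the new keyword enables, assuming this file's pipeline class exposes a `from_pretrained` that forwards these arguments to `__init__`; the class name, model id, and LoRA path below are placeholders, not taken from this diff:

```python
# Placeholder names throughout; adjust to the actual pipeline class in this file.
pipe = IPUStableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    attn_matrix_target_mem_mb=100,
    # Any LoRA attention-processor checkpoint loadable via UNet.load_attn_procs.
    lora_name_or_path_or_dict="path/or/hub-id/of-lora-attn-procs",
)
image = pipe("a photo of an astronaut riding a horse").images[0]
```
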