diff --git a/optimum/graphcore/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_mixin.py b/optimum/graphcore/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_mixin.py
index 6119edf20..77eef33b4 100644
--- a/optimum/graphcore/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_mixin.py
+++ b/optimum/graphcore/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_mixin.py
@@ -19,7 +19,7 @@ import torch
 from diffusers import AutoencoderKL, StableDiffusionPipeline, UNet2DConditionModel
 from diffusers.models.autoencoder_kl import AutoencoderKLOutput
-from diffusers.models.cross_attention import CrossAttention
+from diffusers.models.cross_attention import CrossAttention, LoRACrossAttnProcessor
 from diffusers.models.unet_2d_condition import UNet2DConditionOutput
 from diffusers.models.vae import DecoderOutput, DiagonalGaussianDistribution
 from transformers import CLIPTextModel
@@ -58,25 +58,32 @@ def _nearest_divisor(target, start, end):
             return divisor
         raise ValueError(f"No divisor found in range [{start}, {end}].")
 
-    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+    @staticmethod
+    def _forward(attn: CrossAttention, hidden_states, attn_matrix_target_mem_mb, encoder_hidden_states=None,
+                 attention_mask=None, lora_cross_attn_processor=None, scale=1.0):
         batch_size, sequence_length, _ = hidden_states.shape
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
 
         query = attn.to_q(hidden_states)
+        if lora_cross_attn_processor is not None:
+            query += scale * lora_cross_attn_processor.to_q_lora(hidden_states)
         query = attn.head_to_batch_dim(query)
 
         encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
+        if lora_cross_attn_processor is not None:
+            key += scale * lora_cross_attn_processor.to_k_lora(encoder_hidden_states)
+            value += scale * lora_cross_attn_processor.to_v_lora(encoder_hidden_states)
         key = attn.head_to_batch_dim(key)
         value = attn.head_to_batch_dim(value)
 
         # Begin IPU modifications.
         attn_matrix_mem = query.element_size() * query.shape[0] * query.shape[1] * key.shape[1]
-        num_slices = attn_matrix_mem // (self._attn_matrix_target_mem_mb * 1024 * 1024)
+        num_slices = attn_matrix_mem // (attn_matrix_target_mem_mb * 1024 * 1024)
         num_slices = max(num_slices, 1)
-        num_slices = self._nearest_divisor(query.shape[1], num_slices, 2 * num_slices)
+        num_slices = IPUSlicedAttnProcessor._nearest_divisor(query.shape[1], num_slices, 2 * num_slices)
         slice_size = query.shape[1] // num_slices
 
         hidden_states = []
@@ -101,11 +108,31 @@ def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=No
 
         # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
+        if lora_cross_attn_processor is not None:
+            # to_out_lora must see the pre-projection hidden states, matching diffusers' LoRACrossAttnProcessor.
+            hidden_states = attn.to_out[0](hidden_states) + scale * lora_cross_attn_processor.to_out_lora(hidden_states)
+        else:
+            hidden_states = attn.to_out[0](hidden_states)
         # dropout
         hidden_states = attn.to_out[1](hidden_states)
 
         return hidden_states
 
+    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+        return self._forward(attn, hidden_states, self._attn_matrix_target_mem_mb, encoder_hidden_states, attention_mask)
+
+
+class IPULoRASlicedAttnProcessor(torch.nn.Module):
+    def __init__(self, attn_matrix_target_mem_mb: int, lora_cross_attn_processor: LoRACrossAttnProcessor):
+        super().__init__()
+        self._attn_matrix_target_mem_mb = attn_matrix_target_mem_mb
+        self._lora_cross_attn_processor = lora_cross_attn_processor
+
+    def __call__(
+        self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0
+    ):
+        return IPUSlicedAttnProcessor._forward(attn, hidden_states, self._attn_matrix_target_mem_mb, encoder_hidden_states,
+                                               attention_mask, self._lora_cross_attn_processor, scale)
+
 
 class IPUCLIPTextModel(CLIPTextModel, PipelineMixin):
     def parallelize(self):
@@ -148,15 +173,20 @@ def forward(
 
 
 class IPUUNet2DConditionModel(UNet2DConditionModel, PipelineMixin):
-    def change_cross_attention_processor(self, attn_matrix_target_mem_mb):
-        for module in self.modules():
-            if isinstance(module, CrossAttention):
-                module.set_processor(IPUSlicedAttnProcessor(attn_matrix_target_mem_mb))
+    def change_cross_attention_processor(self, attn_matrix_target_mem_mb, lora_name_or_path_or_dict=None):
+        attn_processors = {}
+        for attn_processor_name, attn_processor in self.attn_processors.items():
+            if lora_name_or_path_or_dict is not None:
+                attn_processors[attn_processor_name] = IPULoRASlicedAttnProcessor(attn_matrix_target_mem_mb, attn_processor)
+            else:
+                attn_processors[attn_processor_name] = IPUSlicedAttnProcessor(attn_matrix_target_mem_mb)
+        self.set_attn_processor(attn_processors)
+
 
-    def parallelize(self, attn_matrix_target_mem_mb=None):
+    def parallelize(self, attn_matrix_target_mem_mb=None, lora_name_or_path_or_dict=None):
         super().parallelize()
-        self.change_cross_attention_processor(attn_matrix_target_mem_mb)
+        self.change_cross_attention_processor(attn_matrix_target_mem_mb, lora_name_or_path_or_dict=lora_name_or_path_or_dict)
 
         self.conv_in = poptorch.BeginBlock(self.conv_in, "conv_in", ipu_id=0)
         self.down_blocks[2].downsamplers[0] = poptorch.BeginBlock(
@@ -269,6 +299,7 @@ def __init__(
         vae_ipu_config=None,
         safety_checker_ipu_config=None,
         common_ipu_config_kwargs=None,
+        lora_name_or_path_or_dict=None,
     ):
         default_common_ipu_config_kwargs = {
             "enable_half_partials": True,
@@ -399,7 +430,12 @@ def run_safety_checker(self, image, device, dtype):
         unet_ipu = copy.deepcopy(unet)
         unet_ipu.__class__ = IPUUNet2DConditionModel
         unet_ipu.ipu_config = unet_ipu_config
-        unet_ipu.parallelize(attn_matrix_target_mem_mb=attn_matrix_target_mem_mb)
+        if lora_name_or_path_or_dict is not None:
+            unet_ipu.load_attn_procs(lora_name_or_path_or_dict)
+        unet_ipu.parallelize(
+            attn_matrix_target_mem_mb=attn_matrix_target_mem_mb,
+            lora_name_or_path_or_dict=lora_name_or_path_or_dict
+        )
         override_module_eps(unet_ipu)
 
         opts = unet_ipu_config.to_options(for_inference=True)
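
A minimal usage sketch of the new lora_name_or_path_or_dict argument. This is an illustration only: it assumes IPUStableDiffusionPipeline.from_pretrained forwards keyword arguments matching the mixin's __init__ signature (as diffusers' DiffusionPipeline.from_pretrained does), and the checkpoint id, LoRA path and cache directory below are placeholders.

import torch

from optimum.graphcore.diffusers import IPUStableDiffusionPipeline

# Placeholder identifiers: any Stable Diffusion 1.x checkpoint plus LoRA
# attention-processor weights loadable via UNet2DConditionModel.load_attn_procs.
pipe = IPUStableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
    lora_name_or_path_or_dict="path/to/lora_attn_procs",
    common_ipu_config_kwargs={"executable_cache_dir": "./exe_cache"},  # optional
)
image = pipe("a photo of an astronaut riding a horse on mars").images[0]
image.save("astronaut.png")

Passing the same value to both load_attn_procs and parallelize means the loaded LoRACrossAttnProcessor instances end up wrapped by IPULoRASlicedAttnProcessor, so the sliced-attention memory targeting is preserved when LoRA weights are applied.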