add LoRA layer support to IPU SD Pipeline
roscisz committed Sep 28, 2023
1 parent 8c4a1dd commit e1ccb00
Showing 1 changed file with 66 additions and 11 deletions.
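
The diff below wires optional LoRA weights into the IPU Stable Diffusion pipeline: when `lora_name_or_path_or_dict` is supplied, the LoRA attention processors produced by `UNet2DConditionModel.load_attn_procs` are wrapped in the new `IPULoRASlicedAttnProcessor`, so their low-rank deltas are applied inside the memory-sliced attention used on IPU. A minimal sketch of that mechanism, outside the pipeline wrapper, might look as follows; the checkpoint ids, the 100 MB memory budget, and the availability of `IPULoRASlicedAttnProcessor` in scope are illustrative assumptions, not part of the commit.

import torch
from diffusers import UNet2DConditionModel

# Illustrative checkpoint id and memory budget; only the wrapping pattern mirrors the commit.
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet", torch_dtype=torch.float16
)

# load_attn_procs replaces the UNet's attention processors with LoRACrossAttnProcessor
# instances carrying the low-rank q/k/v/out adapters.
unet.load_attn_procs("path-or-hub-id-of-lora-weights")

# The commit then wraps each LoRA processor in IPULoRASlicedAttnProcessor (defined in the
# modified file; its import path depends on where that file lives) so the adapters run
# inside the sliced attention computation.
attn_processors = {
    name: IPULoRASlicedAttnProcessor(100, proc)  # 100 MB target is an assumed example value
    for name, proc in unet.attn_processors.items()
}
unet.set_attn_processor(attn_processors)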
@@ -19,7 +19,7 @@
 import torch
 from diffusers import AutoencoderKL, StableDiffusionPipeline, UNet2DConditionModel
 from diffusers.models.autoencoder_kl import AutoencoderKLOutput
-from diffusers.models.cross_attention import CrossAttention
+from diffusers.models.cross_attention import CrossAttention, LoRACrossAttnProcessor
 from diffusers.models.unet_2d_condition import UNet2DConditionOutput
 from diffusers.models.vae import DecoderOutput, DiagonalGaussianDistribution
 from transformers import CLIPTextModel
@@ -58,25 +58,39 @@ def _nearest_divisor(target, start, end):
                 return divisor
         raise ValueError(f"No divisor found in range [{start}, {end}].")
 
-    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+    @staticmethod
+    def _forward(
+        attn: CrossAttention,
+        hidden_states,
+        attn_matrix_target_mem_mb,
+        encoder_hidden_states=None,
+        attention_mask=None,
+        lora_cross_attn_processor=None,
+        scale=1.0,
+    ):
         batch_size, sequence_length, _ = hidden_states.shape
 
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
 
         query = attn.to_q(hidden_states)
+        if lora_cross_attn_processor is not None:
+            query += scale * lora_cross_attn_processor.to_q_lora(hidden_states)
         query = attn.head_to_batch_dim(query)
 
         encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
+        if lora_cross_attn_processor is not None:
+            key += scale * lora_cross_attn_processor.to_k_lora(encoder_hidden_states)
+            value += scale * lora_cross_attn_processor.to_v_lora(encoder_hidden_states)
         key = attn.head_to_batch_dim(key)
         value = attn.head_to_batch_dim(value)
 
         # Begin IPU modifications.
         attn_matrix_mem = query.element_size() * query.shape[0] * query.shape[1] * key.shape[1]
-        num_slices = attn_matrix_mem // (self._attn_matrix_target_mem_mb * 1024 * 1024)
+        num_slices = attn_matrix_mem // (attn_matrix_target_mem_mb * 1024 * 1024)
         num_slices = max(num_slices, 1)
-        num_slices = self._nearest_divisor(query.shape[1], num_slices, 2 * num_slices)
+        num_slices = IPUSlicedAttnProcessor._nearest_divisor(query.shape[1], num_slices, 2 * num_slices)
         slice_size = query.shape[1] // num_slices
 
         hidden_states = []
@@ -101,11 +115,38 @@ def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None
 
         # linear proj
         hidden_states = attn.to_out[0](hidden_states)
+        if lora_cross_attn_processor is not None:
+            hidden_states += scale * lora_cross_attn_processor.to_out_lora(hidden_states)
         # dropout
         hidden_states = attn.to_out[1](hidden_states)
 
         return hidden_states
 
+    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+        return self._forward(
+            attn, hidden_states, self._attn_matrix_target_mem_mb, encoder_hidden_states, attention_mask
+        )
+
+
+class IPULoRASlicedAttnProcessor(torch.nn.Module):
+    def __init__(self, attn_matrix_target_mem_mb: int, lora_cross_attn_processor: LoRACrossAttnProcessor):
+        super().__init__()
+        self._attn_matrix_target_mem_mb = attn_matrix_target_mem_mb
+        self._lora_cross_attn_processor = lora_cross_attn_processor
+
+    def __call__(
+        self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0
+    ):
+        return IPUSlicedAttnProcessor._forward(
+            attn,
+            hidden_states,
+            self._attn_matrix_target_mem_mb,
+            encoder_hidden_states,
+            attention_mask,
+            self._lora_cross_attn_processor,
+            scale,
+        )
+
 
 class IPUCLIPTextModel(CLIPTextModel, PipelineMixin):
     def parallelize(self):
@@ -148,15 +189,23 @@ def forward(
 
 
 class IPUUNet2DConditionModel(UNet2DConditionModel, PipelineMixin):
-    def change_cross_attention_processor(self, attn_matrix_target_mem_mb):
-        for module in self.modules():
-            if isinstance(module, CrossAttention):
-                module.set_processor(IPUSlicedAttnProcessor(attn_matrix_target_mem_mb))
+    def change_cross_attention_processor(self, attn_matrix_target_mem_mb, lora_name_or_path_or_dict=None):
+        attn_processors = {}
+        for attn_processor_name, attn_processor in self.attn_processors.items():
+            if lora_name_or_path_or_dict is not None:
+                attn_processors[attn_processor_name] = IPULoRASlicedAttnProcessor(
+                    attn_matrix_target_mem_mb, attn_processor
+                )
+            else:
+                attn_processors[attn_processor_name] = IPUSlicedAttnProcessor(attn_matrix_target_mem_mb)
+        self.set_attn_processor(attn_processors)
 
-    def parallelize(self, attn_matrix_target_mem_mb=None):
+    def parallelize(self, attn_matrix_target_mem_mb=None, lora_name_or_path_or_dict=None):
         super().parallelize()
 
-        self.change_cross_attention_processor(attn_matrix_target_mem_mb)
+        self.change_cross_attention_processor(
+            attn_matrix_target_mem_mb, lora_name_or_path_or_dict=lora_name_or_path_or_dict
+        )
 
         self.conv_in = poptorch.BeginBlock(self.conv_in, "conv_in", ipu_id=0)
         self.down_blocks[2].downsamplers[0] = poptorch.BeginBlock(
@@ -269,6 +318,7 @@ def __init__(
         vae_ipu_config=None,
         safety_checker_ipu_config=None,
         common_ipu_config_kwargs=None,
+        lora_name_or_path_or_dict=None,
     ):
         default_common_ipu_config_kwargs = {
             "enable_half_partials": True,
@@ -399,7 +449,12 @@ def run_safety_checker(self, image, device, dtype):
         unet_ipu = copy.deepcopy(unet)
         unet_ipu.__class__ = IPUUNet2DConditionModel
         unet_ipu.ipu_config = unet_ipu_config
-        unet_ipu.parallelize(attn_matrix_target_mem_mb=attn_matrix_target_mem_mb)
+        if lora_name_or_path_or_dict is not None:
+            unet_ipu.load_attn_procs(lora_name_or_path_or_dict)
+        unet_ipu.parallelize(
+            attn_matrix_target_mem_mb=attn_matrix_target_mem_mb,
+            lora_name_or_path_or_dict=lora_name_or_path_or_dict,
+        )
         override_module_eps(unet_ipu)
 
         opts = unet_ipu_config.to_options(for_inference=True)
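For reference, the `to_q_lora` / `to_k_lora` / `to_v_lora` / `to_out_lora` terms added in `_forward` are rank-decomposed linear layers: each contributes `scale * up(down(x))` on top of the frozen projection, so the effective weight becomes W + scale * (up @ down). The simplified stand-in below (not the actual diffusers `LoRALinearLayer`; dimensions are example values) illustrates the composition.

import torch

class LoRALinearLayer(torch.nn.Module):
    # Simplified stand-in for diffusers' LoRA layer: a rank-r down-projection followed by
    # an up-projection initialised to zero, so the adapter starts as a no-op.
    def __init__(self, in_features, out_features, rank=4):
        super().__init__()
        self.down = torch.nn.Linear(in_features, rank, bias=False)
        self.up = torch.nn.Linear(rank, out_features, bias=False)
        torch.nn.init.normal_(self.down.weight, std=1 / rank)
        torch.nn.init.zeros_(self.up.weight)

    def forward(self, hidden_states):
        return self.up(self.down(hidden_states))

# Composition as applied in IPUSlicedAttnProcessor._forward (example dimensions):
attn_to_q = torch.nn.Linear(320, 320, bias=False)   # stands in for the frozen attn.to_q
to_q_lora = LoRALinearLayer(320, 320, rank=4)        # stands in for lora_cross_attn_processor.to_q_lora
hidden_states = torch.randn(2, 77, 320)
scale = 1.0
query = attn_to_q(hidden_states) + scale * to_q_lora(hidden_states)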

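The `attn_matrix_target_mem_mb` budget threaded through both processors caps the size of each attention score matrix slice. A quick worked example of the arithmetic in `_forward`, under assumed shapes (fp16, batch 2 with classifier-free guidance, 8 heads, 64x64 latents, i.e. sequence length 4096 for self-attention) and an assumed 100 MB budget:

# Worked example of the slicing arithmetic; all shapes and the budget are assumptions.
element_size = 2                       # bytes per fp16 value
batch_times_heads, seq_len = 16, 4096
attn_matrix_mem = element_size * batch_times_heads * seq_len * seq_len  # 536_870_912 bytes (512 MiB)
attn_matrix_target_mem_mb = 100

num_slices = attn_matrix_mem // (attn_matrix_target_mem_mb * 1024 * 1024)  # -> 5
num_slices = max(num_slices, 1)
# _nearest_divisor(4096, 5, 10) returns 8, the first divisor of 4096 in [5, 10]
num_slices = 8
slice_size = seq_len // num_slices     # -> 512 query rows per attention slice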