Enable LoRA support for Intel Gaudi
Signed-off-by: Sanju C Sudhakaran <[email protected]>
SanjuCSudhakaran committed Nov 22, 2024
1 parent a111d01 commit 965c60a
Showing 3 changed files with 35 additions and 20 deletions.
23 changes: 19 additions & 4 deletions vllm/lora/layers.py
@@ -28,6 +28,10 @@
     LinearScalingRotaryEmbedding, RotaryEmbedding)
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
+from vllm.platforms import current_platform
+
+if current_platform.is_hpu():
+    from vllm_hpu_extension.punica_hpu import GaudiPunicaWrapper
 
 if TYPE_CHECKING:
     pass
@@ -308,10 +312,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             )
 
         # Embedding layer only need expand op
-        self.punica_wrapper.add_expand(full_output,
-                                       full_lora_a_embeddings,
-                                       self.lora_b_stacked,
-                                       add_input=True)
+        if current_platform.is_hpu():
+            assert isinstance(self.punica_wrapper, GaudiPunicaWrapper)
+            # HPU handles LoRA-B multiplication differently when compared to
+            # `PunicaWrapper.add_expand`
+            self.punica_wrapper.add_lora_embedding(full_output,
+                                                   full_lora_a_embeddings,
+                                                   self.lora_b_stacked,
+                                                   add_input=True)
+        else:
+            self.punica_wrapper.add_expand(full_output,
+                                           full_lora_a_embeddings,
+                                           self.lora_b_stacked,
+                                           add_input=True)
         return full_output.view_as(full_output_org)
 
     @classmethod
@@ -1491,6 +1504,8 @@ def _get_logits(
         ).index_select(0, indices_padded).nan_to_num_(nan=float("-inf"),
                                                       posinf=float("inf"),
                                                       neginf=float("-inf")))
+        if current_platform.is_hpu():
+            lora_logits = lora_logits[:logits.shape[0], :]
         logits[:,
                self.base_layer.org_vocab_size:self.base_layer.org_vocab_size +
                lora_logits.shape[1]] = lora_logits
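For readers skimming the diff, a minimal standalone sketch (hypothetical shapes, not vLLM code) of why the HPU branch in `_get_logits` slices `lora_logits` to `logits.shape[0]`: Gaudi kernels operate on bucket-padded batches, so the LoRA logits can carry extra padded rows that the unpadded base logits do not have, and those rows must be trimmed before the column-slice assignment.

import torch

# Hypothetical sizes, for illustration only.
num_tokens, padded_tokens = 6, 8      # HPU pads the batch up to a bucket of 8
org_vocab, lora_vocab = 100, 16       # base vocab plus LoRA-added vocab

logits = torch.zeros(num_tokens, org_vocab + lora_vocab)
lora_logits = torch.randn(padded_tokens, lora_vocab)   # includes 2 padded rows

# Trim the padded rows so the assignment below lines up row-for-row,
# mirroring `lora_logits = lora_logits[:logits.shape[0], :]` in the diff.
lora_logits = lora_logits[:logits.shape[0], :]
logits[:, org_vocab:org_vocab + lora_logits.shape[1]] = lora_logits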
15 changes: 12 additions & 3 deletions vllm/lora/models.py
@@ -331,9 +331,18 @@ def __init__(
         self.lora_index_to_id: List[Optional[int]] = [None] * self.lora_slots
         self.vocab_size = vocab_size
         self.long_lora_context: Optional[LongContextLoRAContext] = None
-        self.punica_wrapper = PunicaWrapper(max_num_batched_tokens,
-                                            max_batches=self.max_num_seqs,
-                                            device=self.device)
+        if self.device == torch.device("hpu"):
+            # Increasing max_num_batched_tokens by 3x to handle increase in
+            # tensor size due to padding.
+            from vllm_hpu_extension.punica_hpu import GaudiPunicaWrapper
+            self.punica_wrapper = GaudiPunicaWrapper(
+                3 * max_num_batched_tokens,
+                max_batches=self.max_num_seqs,
+                device=self.device)
+        else:
+            self.punica_wrapper = PunicaWrapper(max_num_batched_tokens,
+                                                max_batches=self.max_num_seqs,
+                                                device=self.device)
         # Scaling factor -> offset to the sin_cos_cache to it.
         # Used for long context lora.
         self.scaling_factor_to_offset: Dict[float, int] = {}
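As a rough illustration of the sizing comment above (the bucket step and sequence lengths below are invented; only the 3x factor comes from the diff): HPU prompts are padded up to bucket boundaries before batching, so a batch can occupy more token slots than it has real tokens, which is why the Gaudi wrapper is allocated against 3 * max_num_batched_tokens.

import math

# Hypothetical values, for illustration only.
max_num_batched_tokens = 2048
bucket_step = 128      # assumed prompt-length bucket granularity on HPU

def padded_token_slots(seq_lens):
    # Each sequence is padded up to the next bucket boundary before batching.
    return sum(math.ceil(n / bucket_step) * bucket_step for n in seq_lens)

real_lens = [13, 70, 200, 515]
print(sum(real_lens), padded_token_slots(real_lens))   # 798 real vs. 1152 padded slots

# The Gaudi wrapper therefore reserves headroom for this inflation:
hpu_wrapper_capacity = 3 * max_num_batched_tokens      # 6144 token slots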
17 changes: 4 additions & 13 deletions vllm/worker/hpu_model_runner.py
@@ -1282,11 +1282,9 @@ def create_dummy_seq_group_metadata(self,
     def profile_run(self) -> None:
         num_layers = self.model_config.get_num_layers(self.parallel_config)
         kv_caches = [None] * num_layers
-        max_batch_size = self.bucketing_global_state.prompt_bs_bucket_cfg[-1]
-        max_seq_len = min(
-            self.bucketing_global_state.prompt_seq_bucket_cfg[-1],
-            self.max_num_batched_tokens // max_batch_size)
-
+        max_seq_len = self.bucketing_global_state.prompt_seq_bucket_cfg[-1]
+        max_batch_size = min(self.max_num_batched_tokens // max_seq_len,
+                             self.scheduler_config.max_num_seqs)
         self.warmup_scenario(max_batch_size, max_seq_len, True, kv_caches,
                              False, True)
         return
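To make the new sizing concrete, a small worked example with invented numbers (the bucket configuration and limits are hypothetical): the warmup batch now takes the largest prompt-sequence bucket as the sequence length and derives the batch size from the token budget, capped by the scheduler's max_num_seqs.

# Hypothetical configuration values, for illustration only.
max_num_batched_tokens = 8192
prompt_seq_bucket_cfg = [128, 256, 1024]    # last entry is the largest bucket
max_num_seqs = 16

max_seq_len = prompt_seq_bucket_cfg[-1]                       # 1024
max_batch_size = min(max_num_batched_tokens // max_seq_len,   # 8192 // 1024 = 8
                     max_num_seqs)                            # capped at 16
print(max_seq_len, max_batch_size)                            # 1024 8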
@@ -1304,7 +1302,6 @@ def warmup_scenario(self,
                          f"bs{batch_size}_"
                          f"seq{seq_len}_"
                          f"graphs{'T' if use_graphs else 'F'}")
-        max_num_seqs = self.scheduler_config.max_num_seqs
         # This represents the maximum number of different requests
         # that will have unique loras, an therefore the max amount of memory
         # consumption create dummy lora request copies from the lora request
@@ -1326,16 +1323,10 @@
                     dummy_lora_requests.append(dummy_lora_request)
             dummy_lora_requests_per_seq = [
                 dummy_lora_requests[idx % len(dummy_lora_requests)]
-                for idx in range(max_num_seqs)
+                for idx in range(batch_size)
             ]
         self.profiler.start('internal', scenario_name)
         times = 3 if use_graphs or is_pt_profiler_run else 1
-        if self.lora_config and not is_lora_profile_run:
-            lora_mapping = LoRAMapping(
-                **dict(index_mapping=[0] * batch_size * seq_len,
-                       prompt_mapping=[0] * batch_size * seq_len,
-                       is_prefill=is_prompt))
-            self.set_active_loras(set(), lora_mapping)
         if is_prompt:
             seqs = [
                 self.create_dummy_seq_group_metadata(
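Finally, a tiny sketch of the change in how dummy LoRA requests are spread over the warmup batch (names and counts are illustrative): one entry is now produced per sequence in the warmed-up batch (`batch_size`) rather than per scheduler slot (`max_num_seqs`), cycling through the available dummy LoRAs.

# Hypothetical counts, for illustration only.
max_loras = 2
batch_size = 5

dummy_lora_requests = [f"dummy-lora-{idx + 1}" for idx in range(max_loras)]
dummy_lora_requests_per_seq = [
    dummy_lora_requests[idx % len(dummy_lora_requests)]
    for idx in range(batch_size)
]
print(dummy_lora_requests_per_seq)
# ['dummy-lora-1', 'dummy-lora-2', 'dummy-lora-1', 'dummy-lora-2', 'dummy-lora-1']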
