Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support softcap in ROCm Flash Attention #10500

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions vllm/attention/backends/rocm_flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,10 +342,12 @@
if blocksparse_params is not None:
raise ValueError(
"ROCmFlashAttention does not support blocksparse attention.")
if logits_soft_cap is not None:
raise ValueError(
"ROCmFlashAttention does not support attention logits soft "
"capping.")

if logits_soft_cap is None:
# In flash-attn, setting logits_soft_cap as 0 means no soft cap.
logits_soft_cap = 0
self.logits_soft_cap = logits_soft_cap

self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
Expand All @@ -370,6 +372,14 @@
# NOTE: Allow for switching between Triton and CK. Defaulting to triton.
self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN
if self.use_triton_flash_attn:
if logits_soft_cap is not None:
raise ValueError(
"ROCm Triton FlashAttention does not support attention"
"logits soft capping."
"please try using the ROCm CK "
"FA backend instead by setting the env var "
"`VLLM_USE_TRITON_FLASH_ATTN=0`")

from vllm.attention.ops.triton_flash_attention import ( # noqa: F401
triton_attention)
self.attn_func = triton_attention
Expand All @@ -388,12 +398,17 @@
else:
try:
from flash_attn import flash_attn_varlen_func # noqa: F401
self.attn_func = flash_attn_varlen_func

Check failure on line 401 in vllm/attention/backends/rocm_flash_attn.py

View workflow job for this annotation

GitHub Actions / mypy (3.9)

Cannot determine type of "attn_func" [has-type]

Check failure on line 401 in vllm/attention/backends/rocm_flash_attn.py

View workflow job for this annotation

GitHub Actions / mypy (3.10)

Cannot determine type of "attn_func" [has-type]

Check failure on line 401 in vllm/attention/backends/rocm_flash_attn.py

View workflow job for this annotation

GitHub Actions / mypy (3.11)

Cannot determine type of "attn_func" [has-type]

Check failure on line 401 in vllm/attention/backends/rocm_flash_attn.py

View workflow job for this annotation

GitHub Actions / mypy (3.12)

Cannot determine type of "attn_func" [has-type]
logger.debug("Using CK FA in ROCmBackend")
except ModuleNotFoundError:
self.use_naive_attn = True

if self.use_naive_attn:
if logits_soft_cap is not None:
raise ValueError(
"ROCm Naive FlashAttention does not support attention"
"logits soft capping.")

self.attn_func = _sdpa_attention
logger.debug("Using naive attention in ROCmBackend")

Expand Down Expand Up @@ -492,7 +507,7 @@
query.dtype,
attn_metadata.seq_lens,
make_attn_mask=False) # type: ignore
out, _ = self.attn_func(

Check failure on line 510 in vllm/attention/backends/rocm_flash_attn.py

View workflow job for this annotation

GitHub Actions / mypy (3.9)

Cannot determine type of "attn_func" [has-type]

Check failure on line 510 in vllm/attention/backends/rocm_flash_attn.py

View workflow job for this annotation

GitHub Actions / mypy (3.10)

Cannot determine type of "attn_func" [has-type]

Check failure on line 510 in vllm/attention/backends/rocm_flash_attn.py

View workflow job for this annotation

GitHub Actions / mypy (3.11)

Cannot determine type of "attn_func" [has-type]

Check failure on line 510 in vllm/attention/backends/rocm_flash_attn.py

View workflow job for this annotation

GitHub Actions / mypy (3.12)

Cannot determine type of "attn_func" [has-type]
query,
key,
value,
Expand Down Expand Up @@ -521,7 +536,7 @@
key = key.movedim(0, key.dim() - 2)
value = value.movedim(0, value.dim() - 2)
# sdpa math backend attention
out = self.attn_func(

Check failure on line 539 in vllm/attention/backends/rocm_flash_attn.py

View workflow job for this annotation

GitHub Actions / mypy (3.9)

Cannot determine type of "attn_func" [has-type]

Check failure on line 539 in vllm/attention/backends/rocm_flash_attn.py

View workflow job for this annotation

GitHub Actions / mypy (3.10)

Cannot determine type of "attn_func" [has-type]

Check failure on line 539 in vllm/attention/backends/rocm_flash_attn.py

View workflow job for this annotation

GitHub Actions / mypy (3.11)

Cannot determine type of "attn_func" [has-type]

Check failure on line 539 in vllm/attention/backends/rocm_flash_attn.py

View workflow job for this annotation

GitHub Actions / mypy (3.12)

Cannot determine type of "attn_func" [has-type]
query,
key,
value,
Expand All @@ -533,7 +548,7 @@
attn_masks,
)
else:
out = self.attn_func(

Check failure on line 551 in vllm/attention/backends/rocm_flash_attn.py

View workflow job for this annotation

GitHub Actions / mypy (3.9)

Cannot determine type of "attn_func" [has-type]

Check failure on line 551 in vllm/attention/backends/rocm_flash_attn.py

View workflow job for this annotation

GitHub Actions / mypy (3.10)

Cannot determine type of "attn_func" [has-type]

Check failure on line 551 in vllm/attention/backends/rocm_flash_attn.py

View workflow job for this annotation

GitHub Actions / mypy (3.11)

Cannot determine type of "attn_func" [has-type]

Check failure on line 551 in vllm/attention/backends/rocm_flash_attn.py

View workflow job for this annotation

GitHub Actions / mypy (3.12)

Cannot determine type of "attn_func" [has-type]
q=query,
k=key,
v=value,
Expand All @@ -545,6 +560,7 @@
causal=True,
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
softcap=self.logits_soft_cap,
)

# common code for prefill
Expand Down
Loading