class FlashAttnPrefillImpl(MLAPrefillImpl):
"""FlashAttention implementation for MLA prefill."""
def __init__(
self,
num_heads: int,
scale: float,
kv_lora_rank: int,
qk_nope_head_dim: int,
qk_rope_head_dim: int,
v_head_dim: int,
vllm_config: "VllmConfig",
device: torch.device,
) -> None:
super().__init__(
num_heads=num_heads,
scale=scale,
kv_lora_rank=kv_lora_rank,
qk_nope_head_dim=qk_nope_head_dim,
qk_rope_head_dim=qk_rope_head_dim,
v_head_dim=v_head_dim,
vllm_config=vllm_config,
device=device,
)
        # Handle the differences between flash_attn_varlen_func from upstream
        # flash_attn and the one from vllm_flash_attn.
assert flash_attn_varlen_func is not None, (
"FlashAttnPrefillImpl requires flash_attn_varlen_func. "
"Ensure FlashAttnPrefillBackend.is_available() is checked first."
)
qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
self.flash_attn_varlen_func = flash_attn_varlen_func
self.vllm_flash_attn_version = get_flash_attn_version(head_size=qk_head_dim)
if self.vllm_flash_attn_version is not None:
self.flash_attn_varlen_func = functools.partial(
flash_attn_varlen_func, fa_version=self.vllm_flash_attn_version
)
        # Determine whether we need to pad V.
        # For MLA the V head dim is smaller than the QK head dim, so we pad
        # V with zeros up to the QK head dim for attention backends that do
        # not support differing head dims.
        # FA3 on Hopper (SM90) and FA4 handle differing head dims natively.
        device_capability = current_platform.get_device_capability()
        supports_diff_headdims = (
            self.vllm_flash_attn_version == 4
            or (
                self.vllm_flash_attn_version == 3
                and device_capability is not None
                and device_capability[0] == 9
            )
        )
        self.requires_v_padding = not supports_diff_headdims
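        # Illustration (typical MLA shapes, not tied to any specific model):
        # with qk_nope_head_dim=128 and qk_rope_head_dim=64, qk_head_dim is
        # 192 while v_head_dim is 128, so on backends without diff-headdim
        # support V is zero-padded from 128 to 192 before the kernel call.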
        # Track whether we're using vLLM's flash-attn fork (CUDA/XPU) or the
        # upstream flash_attn package (ROCm).
        self._is_vllm_fa = current_platform.is_cuda() or current_platform.is_xpu()

    def _flash_attn_varlen_diff_headdims(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
return_softmax_lse: bool = False,
softmax_scale: float | None = None,
**kwargs,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""Run flash attention with potentially different Q/K and V head dims."""
maybe_padded_v = v
if self.requires_v_padding:
maybe_padded_v = torch.nn.functional.pad(
v, [0, q.shape[-1] - v.shape[-1]], value=0
)
if self._is_vllm_fa:
kwargs["return_softmax_lse"] = return_softmax_lse
else:
# ROCm leverages the upstream flash_attn, which takes a parameter
# called "return_attn_probs" instead of return_softmax_lse
kwargs["return_attn_probs"] = return_softmax_lse
if envs.VLLM_BATCH_INVARIANT:
kwargs["num_splits"] = 1
attn_out = self.flash_attn_varlen_func(
q=q,
k=k,
v=maybe_padded_v,
softmax_scale=softmax_scale,
**kwargs,
)
        # Unpack the output if there are multiple results
        lse = None
        if isinstance(attn_out, tuple):
            attn_out, lse = attn_out[0], attn_out[1]
        # Slice off the zero-padded columns so callers always receive
        # v_head_dim-wide output, regardless of whether V was padded.
        if self.requires_v_padding:
            attn_out = attn_out[..., : v.shape[-1]]
# Remain consistent with old `flash_attn_varlen_func` where there
# is only one output tensor if `return_softmax_lse` is False.
if return_softmax_lse:
return attn_out, lse
return attn_out

    def run_prefill_new_tokens(
self,
prefill_metadata: "MLACommonPrefillMetadata",
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
return_softmax_lse: bool,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
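        """Causal attention over the newly scheduled prefill tokens.

        Q, K and V all come from the current prefill tokens, so the key
        sequence layout (cu_seqlens/max_seqlen) mirrors the query layout.
        """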
return self._flash_attn_varlen_diff_headdims(
q=q,
k=k,
v=v,
cu_seqlens_q=prefill_metadata.query_start_loc,
cu_seqlens_k=prefill_metadata.query_start_loc,
max_seqlen_q=prefill_metadata.max_query_len,
max_seqlen_k=prefill_metadata.max_query_len,
softmax_scale=self.scale,
causal=True,
return_softmax_lse=return_softmax_lse,
)

    def run_prefill_context_chunk(
self,
prefill_metadata: "MLACommonPrefillMetadata",
chunk_idx: int,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
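        """Attend new queries to one chunk of previously cached context.

        The context fully precedes the queries, so attention is unmasked
        (non-causal), and the softmax LSE is returned alongside the output
        so per-chunk results can be combined by the caller.
        """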
assert prefill_metadata.chunked_context is not None
return self._flash_attn_varlen_diff_headdims(
q=q,
k=k,
v=v,
cu_seqlens_q=prefill_metadata.query_start_loc,
cu_seqlens_k=prefill_metadata.chunked_context.cu_seq_lens[chunk_idx],
max_seqlen_q=prefill_metadata.max_query_len,
max_seqlen_k=prefill_metadata.chunked_context.max_seq_lens[chunk_idx],
softmax_scale=self.scale,
causal=False, # Context is unmasked
return_softmax_lse=True,
)