Skip to content

vllm.v1.attention.backends.mla.prefill.flash_attn

FlashAttention backend for MLA prefill.

FlashAttnPrefillBackend

Bases: MLAPrefillBackend

FlashAttention backend for MLA prefill.

This is the default/fallback backend that works on most hardware.

Source code in vllm/v1/attention/backends/mla/prefill/flash_attn.py
class FlashAttnPrefillBackend(MLAPrefillBackend):
    """MLA prefill backend that dispatches to FlashAttention.

    It acts as the default/fallback choice, since FlashAttention runs on
    most hardware.
    """

    @staticmethod
    def get_name() -> str:
        """Return the identifier string for this backend."""
        return "FLASH_ATTN_PREFILL"

    @staticmethod
    def get_prefill_impl_cls() -> type["FlashAttnPrefillImpl"]:
        """Return the implementation class that this backend hands out."""
        return FlashAttnPrefillImpl

    @classmethod
    def is_available(cls) -> bool:
        """Report availability.

        Delegates entirely to ``is_flash_attn_varlen_func_available()``.
        """
        available = is_flash_attn_varlen_func_available()
        return available

FlashAttnPrefillImpl

Bases: MLAPrefillImpl

FlashAttention implementation for MLA prefill.

Source code in vllm/v1/attention/backends/mla/prefill/flash_attn.py
class FlashAttnPrefillImpl(MLAPrefillImpl):
    """MLA prefill implementation built on ``flash_attn_varlen_func``.

    It supports two kernels:

    * vLLM's own FlashAttention build, used on CUDA/XPU.
    * The upstream ``flash_attn`` package, used on ROCm.

    V is zero-padded up to the Q/K head dim whenever the selected kernel
    cannot handle mismatched head dims.
    """

    def __init__(
        self,
        num_heads: int,
        scale: float,
        kv_lora_rank: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        vllm_config: "VllmConfig",
        device: torch.device,
    ) -> None:
        """Bind the FA kernel and decide whether V must be padded.

        The kernel is pinned to the detected FA version when one is
        reported.
        """
        super().__init__(
            num_heads=num_heads,
            scale=scale,
            kv_lora_rank=kv_lora_rank,
            qk_nope_head_dim=qk_nope_head_dim,
            qk_rope_head_dim=qk_rope_head_dim,
            v_head_dim=v_head_dim,
            vllm_config=vllm_config,
            device=device,
        )

        # The backend's is_available() gate is expected to rule this out.
        assert flash_attn_varlen_func is not None, (
            "FlashAttnPrefillImpl requires flash_attn_varlen_func. "
            "Ensure FlashAttnPrefillBackend.is_available() is checked first."
        )

        # Ask for the FA version using the combined Q/K head size. When a
        # version is reported, pin it onto every kernel call.
        fa_version = get_flash_attn_version(
            head_size=qk_nope_head_dim + qk_rope_head_dim
        )
        self.vllm_flash_attn_version = fa_version
        if fa_version is None:
            self.flash_attn_varlen_func = flash_attn_varlen_func
        else:
            self.flash_attn_varlen_func = functools.partial(
                flash_attn_varlen_func, fa_version=fa_version
            )

        # MLA's V head dim is smaller than its Q/K head dim. Only two
        # configurations accept mismatched head dims natively:
        #   * FA3 on Hopper (capability major 9)
        #   * FA4
        # Everything else needs V zero-padded to the Q/K head dim. That
        # includes the case where no FA version is reported.
        capability = current_platform.get_device_capability()
        native_diff_headdims = fa_version == 4 or (
            fa_version == 3
            and capability is not None
            and capability[0] == 9
        )
        self.requires_v_padding = fa_version is None or not native_diff_headdims

        # Platform selects the FA build, whose keyword arguments differ:
        #   * CUDA / XPU: vLLM's own FA build
        #   * elsewhere (ROCm): the upstream flash_attn package
        self._is_vllm_fa = current_platform.is_cuda() or current_platform.is_xpu()

    def _flash_attn_varlen_diff_headdims(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        return_softmax_lse: bool = False,
        softmax_scale: float | None = None,
        **kwargs,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Call the varlen FA kernel, tolerating a V head dim unlike Q/K's.

        Args:
            q, k, v: Attention inputs. The last dim is treated as the head
                dim when padding V.
            return_softmax_lse: Whether to also return the softmax LSE.
            softmax_scale: Forwarded to the kernel as ``softmax_scale``.
            **kwargs: Extra kernel arguments, e.g. cu_seqlens, max_seqlen,
                causal.

        Returns:
            ``attn_out`` alone, or ``(attn_out, lse)`` when
            ``return_softmax_lse`` is True. If V was padded, ``attn_out``
            still carries the padded head dim; it is not trimmed here.
        """
        pad_width = q.shape[-1] - v.shape[-1]
        v_in = (
            torch.nn.functional.pad(v, [0, pad_width], value=0)
            if self.requires_v_padding
            else v
        )

        # vLLM's FA and upstream flash_attn name the LSE flag differently.
        lse_flag = "return_softmax_lse" if self._is_vllm_fa else "return_attn_probs"
        kwargs[lse_flag] = return_softmax_lse
        if envs.VLLM_BATCH_INVARIANT:
            # Batch-invariant mode forces a single split.
            kwargs["num_splits"] = 1

        result = self.flash_attn_varlen_func(
            q=q, k=k, v=v_in, softmax_scale=softmax_scale, **kwargs
        )

        # When the kernel returns a tuple, its first two entries are the
        # output and the LSE.
        out = result
        lse = None
        if isinstance(result, tuple):
            out, lse = result[0], result[1]

        # Keep the legacy flash_attn_varlen_func contract: a bare tensor
        # unless the LSE was explicitly requested.
        return (out, lse) if return_softmax_lse else out

    def run_prefill_new_tokens(
        self,
        prefill_metadata: "MLACommonPrefillMetadata",
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        return_softmax_lse: bool,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run causal attention over the new prefill tokens.

        The query offsets and max query length are reused for both the
        Q side and the K/V side.
        """
        cu_seqlens = prefill_metadata.query_start_loc
        max_seqlen = prefill_metadata.max_query_len
        return self._flash_attn_varlen_diff_headdims(
            q=q,
            k=k,
            v=v,
            cu_seqlens_q=cu_seqlens,
            cu_seqlens_k=cu_seqlens,
            max_seqlen_q=max_seqlen,
            max_seqlen_k=max_seqlen,
            softmax_scale=self.scale,
            causal=True,
            return_softmax_lse=return_softmax_lse,
        )

    def run_prefill_context_chunk(
        self,
        prefill_metadata: "MLACommonPrefillMetadata",
        chunk_idx: int,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Attend the new-token queries to context chunk ``chunk_idx``.

        The attention is unmasked (non-causal). The LSE is always
        requested, so the result is ``(attn_out, lse)``.
        """
        chunked = prefill_metadata.chunked_context
        assert chunked is not None
        return self._flash_attn_varlen_diff_headdims(
            q=q,
            k=k,
            v=v,
            cu_seqlens_q=prefill_metadata.query_start_loc,
            cu_seqlens_k=chunked.cu_seq_lens[chunk_idx],
            max_seqlen_q=prefill_metadata.max_query_len,
            max_seqlen_k=chunked.max_seq_lens[chunk_idx],
            softmax_scale=self.scale,
            causal=False,  # context tokens are not masked
            return_softmax_lse=True,
        )

_flash_attn_varlen_diff_headdims

_flash_attn_varlen_diff_headdims(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    return_softmax_lse: bool = False,
    softmax_scale: float | None = None,
    **kwargs,
) -> Tensor | tuple[Tensor, Tensor]

Run flash attention with potentially different Q/K and V head dims.

Source code in vllm/v1/attention/backends/mla/prefill/flash_attn.py
def _flash_attn_varlen_diff_headdims(
    self,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    return_softmax_lse: bool = False,
    softmax_scale: float | None = None,
    **kwargs,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    """Call the varlen FA kernel, tolerating a V head dim unlike Q/K's.

    Args:
        q, k, v: Attention inputs. The last dim is treated as the head dim
            when padding V.
        return_softmax_lse: Whether to also return the softmax LSE.
        softmax_scale: Forwarded to the kernel as ``softmax_scale``.
        **kwargs: Extra kernel arguments, e.g. cu_seqlens, max_seqlen,
            causal.

    Returns:
        ``attn_out`` alone, or ``(attn_out, lse)`` when
        ``return_softmax_lse`` is True. If V was padded, ``attn_out`` still
        carries the padded head dim; it is not trimmed here.
    """
    pad_width = q.shape[-1] - v.shape[-1]
    v_in = (
        torch.nn.functional.pad(v, [0, pad_width], value=0)
        if self.requires_v_padding
        else v
    )

    # vLLM's FA and upstream flash_attn name the LSE flag differently.
    lse_flag = "return_softmax_lse" if self._is_vllm_fa else "return_attn_probs"
    kwargs[lse_flag] = return_softmax_lse
    if envs.VLLM_BATCH_INVARIANT:
        # Batch-invariant mode forces a single split.
        kwargs["num_splits"] = 1

    result = self.flash_attn_varlen_func(
        q=q, k=k, v=v_in, softmax_scale=softmax_scale, **kwargs
    )

    # When the kernel returns a tuple, its first two entries are the output
    # and the LSE.
    out = result
    lse = None
    if isinstance(result, tuple):
        out, lse = result[0], result[1]

    # Keep the legacy flash_attn_varlen_func contract: a bare tensor unless
    # the LSE was explicitly requested.
    return (out, lse) if return_softmax_lse else out