Skip to content

vllm.v1.attention.backends.mla.prefill

MLA prefill backend implementations.

This package provides modular prefill backends for MLA (Multi-head Latent Attention) with priority-based selection.

Available backends:

- FlashAttention (default, works on all hardware)
- FlashInfer (optimized for Blackwell)
- cuDNN (optimized for Blackwell, requires NVIDIA artifactory)
- TRT-LLM Ragged (optimized for Blackwell, DeepSeek-specific)

Metadata classes (FlashInferPrefillMetadata, CudnnPrefillMetadata) are available through their respective backend modules to avoid circular imports.

Modules:

Name Description
base

Abstract base classes for MLA prefill backends.

cudnn

cuDNN backend for MLA prefill.

flash_attn

FlashAttention backend for MLA prefill.

flashinfer

FlashInfer backend for MLA prefill.

registry

Registry for MLA prefill backends.

selector

Selector for MLA prefill backends.

trtllm_ragged

TRT-LLM Ragged backend for MLA prefill.

MLAPrefillBackend

Bases: ABC

Abstract base class for MLA prefill backends.

Each prefill backend declares its capabilities (supported dtypes, compute capabilities, etc.) and provides a factory method for creating the implementation class.

Source code in vllm/v1/attention/backends/mla/prefill/base.py
class MLAPrefillBackend(ABC):
    """Base interface implemented by every MLA prefill backend.

    A backend advertises its capabilities (supported dtypes, compute
    capabilities, dependency availability, model-dimension requirements)
    and exposes factory hooks that produce its implementation and
    metadata classes.
    """

    # Dtypes accepted by default; subclasses may narrow or extend this.
    supported_dtypes: ClassVar[list[torch.dtype]] = [
        torch.float16,
        torch.bfloat16,
    ]
    # True for backends that only work with DeepSeek R1 MLA head dims.
    requires_r1_mla_dimensions: ClassVar[bool] = False

    @staticmethod
    @abstractmethod
    def get_name() -> str:
        """Human-readable name of this prefill backend."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_prefill_impl_cls() -> type["MLAPrefillImpl"]:
        """Implementation class that runs prefill attention for this backend."""
        raise NotImplementedError

    @staticmethod
    def get_prefill_metadata_cls() -> type["MLACommonPrefillMetadata"]:
        """Metadata class used by this backend.

        Subclasses override this when they need a specialized metadata
        type (e.g., FlashInferPrefillMetadata).
        """
        # Imported lazily to avoid a circular import with the common MLA module.
        from vllm.model_executor.layers.attention.mla_attention import (
            MLACommonPrefillMetadata,
        )

        return MLACommonPrefillMetadata

    @classmethod
    def supports_compute_capability(cls, device_capability: "DeviceCapability") -> bool:
        """Whether this backend can run on the given compute capability.

        Args:
            device_capability: The device's compute capability.

        The default accepts every device; backends with specific hardware
        requirements override this.
        """
        return True

    @classmethod
    def supports_dtype(cls, dtype: torch.dtype) -> bool:
        """Whether ``dtype`` is among the backend's supported dtypes."""
        return dtype in cls.supported_dtypes

    @classmethod
    def is_available(cls) -> bool:
        """Whether this backend's dependencies are importable.

        The default assumes availability; override to probe for
        required libraries/imports.
        """
        return True

    @classmethod
    def validate_configuration(
        cls,
        device_capability: "DeviceCapability",
        selector_config: "MLAPrefillSelectorConfig",
    ) -> list[str]:
        """Collect every reason this backend cannot serve the configuration.

        Args:
            device_capability: The device's compute capability.
            selector_config: Hashable configuration for backend selection.

        Returns:
            Reasons the configuration is invalid; empty when it is valid.
        """
        reasons: list[str] = []

        if not cls.supports_compute_capability(device_capability):
            reasons.append(
                f"compute capability {device_capability.major}."
                f"{device_capability.minor} not supported"
            )
        if not cls.supports_dtype(selector_config.dtype):
            reasons.append(f"dtype {selector_config.dtype} not supported")
        if not cls.is_available():
            reasons.append("required dependencies not available")
        if cls.requires_r1_mla_dimensions and not selector_config.is_r1_compatible:
            reasons.append(
                "model does not have DeepSeek R1 MLA dimensions "
                "(qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128)"
            )
        return reasons

    @classmethod
    def create_builder_state(
        cls,
        vllm_config: "VllmConfig",
        kv_cache_spec: "AttentionSpec",
        layer_names: list[str],
        device: torch.device,
    ) -> MLAPrefillBuilderState:
        """Allocate per-builder resources for this backend.

        Called once when the metadata builder is initialized; override
        to allocate workspaces, create wrappers, etc.

        Args:
            vllm_config: The vLLM configuration.
            kv_cache_spec: The attention specification.
            layer_names: Names of attention layers.
            device: The device to allocate tensors on.

        Returns:
            A state object holding backend-specific resources.
        """
        return MLAPrefillBuilderState()

    @staticmethod
    def get_chunked_context_metadata_cls() -> type:
        """ChunkedContextMetadata class used by this backend.

        Override when the backend needs a specialized variant.
        """
        # Imported lazily to avoid a circular import with the common MLA module.
        from vllm.model_executor.layers.attention.mla_attention import (
            MLACommonPrefillMetadata,
        )

        return MLACommonPrefillMetadata.ChunkedContextMetadata

    @classmethod  # noqa: B027
    def post_process_prefill_metadata(
        cls,
        prefill_metadata: "MLACommonPrefillMetadata",
        builder_state: MLAPrefillBuilderState,
        prefill_query_start_loc: torch.Tensor,
    ) -> None:
        """Hook run after the prefill metadata is created.

        Runs before the metadata is attached to the attention metadata;
        use it to set backend-specific fields. Default: no-op.
        """
        pass

    @classmethod  # noqa: B027
    def finalize_attention_metadata(
        cls,
        attn_metadata: Any,
        builder_state: MLAPrefillBuilderState,
        num_prefills: int,
        num_heads: int,
        kv_cache_spec: "AttentionSpec",
        mla_dims: Any,
        model_config: Any,
    ) -> None:
        """Hook run after the full attention metadata is constructed.

        Use it for any final processing (e.g., building FlashInfer
        wrappers). Default: no-op.
        """
        pass

create_builder_state classmethod

create_builder_state(
    vllm_config: VllmConfig,
    kv_cache_spec: AttentionSpec,
    layer_names: list[str],
    device: device,
) -> MLAPrefillBuilderState

Create backend-specific state for the metadata builder.

This is called once when the metadata builder is initialized. Override to allocate workspaces, create wrappers, etc.

Parameters:

Name Type Description Default
vllm_config VllmConfig

The vLLM configuration.

required
kv_cache_spec AttentionSpec

The attention specification.

required
layer_names list[str]

Names of attention layers.

required
device device

The device to allocate tensors on.

required

Returns:

Type Description
MLAPrefillBuilderState

A state object containing backend-specific resources.

Source code in vllm/v1/attention/backends/mla/prefill/base.py
# Doc-rendered copy of MLAPrefillBackend.create_builder_state (base.py).
@classmethod
def create_builder_state(
    cls,
    vllm_config: "VllmConfig",
    kv_cache_spec: "AttentionSpec",
    layer_names: list[str],
    device: torch.device,
) -> MLAPrefillBuilderState:
    """Create backend-specific state for the metadata builder.

    This is called once when the metadata builder is initialized.
    Override to allocate workspaces, create wrappers, etc.

    Args:
        vllm_config: The vLLM configuration.
        kv_cache_spec: The attention specification.
        layer_names: Names of attention layers.
        device: The device to allocate tensors on.

    Returns:
        A state object containing backend-specific resources.
    """
    return MLAPrefillBuilderState()

finalize_attention_metadata classmethod

finalize_attention_metadata(
    attn_metadata: Any,
    builder_state: MLAPrefillBuilderState,
    num_prefills: int,
    num_heads: int,
    kv_cache_spec: AttentionSpec,
    mla_dims: Any,
    model_config: Any,
) -> None

Finalize the attention metadata after all components are built.

This is called after the full attention metadata is constructed. Use this for any final processing (e.g., building FlashInfer wrappers).

Source code in vllm/v1/attention/backends/mla/prefill/base.py
# Doc-rendered copy of MLAPrefillBackend.finalize_attention_metadata (base.py).
@classmethod  # noqa: B027
def finalize_attention_metadata(
    cls,
    attn_metadata: Any,
    builder_state: MLAPrefillBuilderState,
    num_prefills: int,
    num_heads: int,
    kv_cache_spec: "AttentionSpec",
    mla_dims: Any,
    model_config: Any,
) -> None:
    """Finalize the attention metadata after all components are built.

    This is called after the full attention metadata is constructed.
    Use this for any final processing (e.g., building FlashInfer wrappers).
    """
    pass

get_chunked_context_metadata_cls staticmethod

get_chunked_context_metadata_cls() -> type

Return the ChunkedContextMetadata class for this backend.

Override if the backend needs a specialized ChunkedContextMetadata.

Source code in vllm/v1/attention/backends/mla/prefill/base.py
# Doc-rendered copy of MLAPrefillBackend.get_chunked_context_metadata_cls (base.py).
@staticmethod
def get_chunked_context_metadata_cls() -> type:
    """Return the ChunkedContextMetadata class for this backend.

    Override if the backend needs a specialized ChunkedContextMetadata.
    """
    from vllm.model_executor.layers.attention.mla_attention import (
        MLACommonPrefillMetadata,
    )

    return MLACommonPrefillMetadata.ChunkedContextMetadata

get_name abstractmethod staticmethod

get_name() -> str

Return the name of this prefill backend.

Source code in vllm/v1/attention/backends/mla/prefill/base.py
# Doc-rendered copy of MLAPrefillBackend.get_name (base.py).
@staticmethod
@abstractmethod
def get_name() -> str:
    """Return the name of this prefill backend."""
    raise NotImplementedError

get_prefill_impl_cls abstractmethod staticmethod

get_prefill_impl_cls() -> type[MLAPrefillImpl]

Return the implementation class for this prefill backend.

Source code in vllm/v1/attention/backends/mla/prefill/base.py
# Doc-rendered copy of MLAPrefillBackend.get_prefill_impl_cls (base.py).
@staticmethod
@abstractmethod
def get_prefill_impl_cls() -> type["MLAPrefillImpl"]:
    """Return the implementation class for this prefill backend."""
    raise NotImplementedError

get_prefill_metadata_cls staticmethod

get_prefill_metadata_cls() -> type[
    MLACommonPrefillMetadata
]

Return the metadata class for this prefill backend.

Override this method if the backend requires a specialized metadata class (e.g., FlashInferPrefillMetadata).

Source code in vllm/v1/attention/backends/mla/prefill/base.py
# Doc-rendered copy of MLAPrefillBackend.get_prefill_metadata_cls (base.py).
@staticmethod
def get_prefill_metadata_cls() -> type["MLACommonPrefillMetadata"]:
    """Return the metadata class for this prefill backend.

    Override this method if the backend requires a specialized
    metadata class (e.g., FlashInferPrefillMetadata).
    """
    from vllm.model_executor.layers.attention.mla_attention import (
        MLACommonPrefillMetadata,
    )

    return MLACommonPrefillMetadata

is_available classmethod

is_available() -> bool

Check if this backend's dependencies are available.

Override this method to check for required libraries/imports.

Source code in vllm/v1/attention/backends/mla/prefill/base.py
# Doc-rendered copy of MLAPrefillBackend.is_available (base.py).
@classmethod
def is_available(cls) -> bool:
    """Check if this backend's dependencies are available.

    Override this method to check for required libraries/imports.
    """
    return True

post_process_prefill_metadata classmethod

post_process_prefill_metadata(
    prefill_metadata: MLACommonPrefillMetadata,
    builder_state: MLAPrefillBuilderState,
    prefill_query_start_loc: Tensor,
) -> None

Post-process the prefill metadata after creation.

This is called after the prefill metadata is created but before it's attached to the attention metadata. Use this to set backend-specific fields on the metadata.

Source code in vllm/v1/attention/backends/mla/prefill/base.py
# Doc-rendered copy of MLAPrefillBackend.post_process_prefill_metadata (base.py).
@classmethod  # noqa: B027
def post_process_prefill_metadata(
    cls,
    prefill_metadata: "MLACommonPrefillMetadata",
    builder_state: MLAPrefillBuilderState,
    prefill_query_start_loc: torch.Tensor,
) -> None:
    """Post-process the prefill metadata after creation.

    This is called after the prefill metadata is created but before
    it's attached to the attention metadata. Use this to set
    backend-specific fields on the metadata.
    """
    pass

supports_compute_capability classmethod

supports_compute_capability(
    device_capability: DeviceCapability,
) -> bool

Check if this backend supports the given compute capability.

Parameters:

Name Type Description Default
device_capability DeviceCapability

The device's compute capability.

required

Override this method if the backend has specific hardware requirements.

Source code in vllm/v1/attention/backends/mla/prefill/base.py
# Doc-rendered copy of MLAPrefillBackend.supports_compute_capability (base.py).
@classmethod
def supports_compute_capability(cls, device_capability: "DeviceCapability") -> bool:
    """Check if this backend supports the given compute capability.

    Args:
        device_capability: The device's compute capability.

    Override this method if the backend has specific hardware requirements.
    """
    return True

supports_dtype classmethod

supports_dtype(dtype: dtype) -> bool

Check if this backend supports the given dtype.

Source code in vllm/v1/attention/backends/mla/prefill/base.py
# Doc-rendered copy of MLAPrefillBackend.supports_dtype (base.py).
@classmethod
def supports_dtype(cls, dtype: torch.dtype) -> bool:
    """Check if this backend supports the given dtype."""
    return dtype in cls.supported_dtypes

validate_configuration classmethod

validate_configuration(
    device_capability: DeviceCapability,
    selector_config: MLAPrefillSelectorConfig,
) -> list[str]

Validate if this backend can be used with the given configuration.

Parameters:

Name Type Description Default
device_capability DeviceCapability

The device's compute capability.

required
selector_config MLAPrefillSelectorConfig

Hashable configuration for backend selection.

required

Returns:

Type Description
list[str]

A list of invalid reasons. Empty list if configuration is valid.

Source code in vllm/v1/attention/backends/mla/prefill/base.py
# Doc-rendered copy of MLAPrefillBackend.validate_configuration (base.py).
@classmethod
def validate_configuration(
    cls,
    device_capability: "DeviceCapability",
    selector_config: "MLAPrefillSelectorConfig",
) -> list[str]:
    """Validate if this backend can be used with the given configuration.

    Args:
        device_capability: The device's compute capability.
        selector_config: Hashable configuration for backend selection.

    Returns:
        A list of invalid reasons. Empty list if configuration is valid.
    """
    invalid_reasons: list[str] = []

    if not cls.supports_compute_capability(device_capability):
        invalid_reasons.append(
            f"compute capability {device_capability.major}."
            f"{device_capability.minor} not supported"
        )

    if not cls.supports_dtype(selector_config.dtype):
        invalid_reasons.append(f"dtype {selector_config.dtype} not supported")

    if not cls.is_available():
        invalid_reasons.append("required dependencies not available")

    if cls.requires_r1_mla_dimensions and not selector_config.is_r1_compatible:
        invalid_reasons.append(
            "model does not have DeepSeek R1 MLA dimensions "
            "(qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128)"
        )

    return invalid_reasons

MLAPrefillBackendEnum

Bases: Enum

Enumeration of all supported MLA prefill backends.

Source code in vllm/v1/attention/backends/mla/prefill/registry.py
class MLAPrefillBackendEnum(Enum, metaclass=_MLAPrefillBackendEnumMeta):
    """All MLA prefill backends known to the registry.

    Each member's value is the fully qualified path of its backend
    class; the class itself is imported lazily via get_class().
    """

    FLASH_ATTN = (
        "vllm.v1.attention.backends.mla.prefill.flash_attn.FlashAttnPrefillBackend"
    )
    FLASHINFER = (
        "vllm.v1.attention.backends.mla.prefill.flashinfer.FlashInferPrefillBackend"
    )
    CUDNN = "vllm.v1.attention.backends.mla.prefill.cudnn.CudnnPrefillBackend"
    TRTLLM_RAGGED = (
        "vllm.v1.attention.backends.mla.prefill.trtllm_ragged."
        "TrtllmRaggedPrefillBackend"
    )

    def get_path(self) -> str:
        """Fully qualified class path of this backend."""
        return self.value

    def get_class(self) -> "type[MLAPrefillBackend]":
        """Import (lazily) and return the backend class."""
        return resolve_obj_by_qualname(self.get_path())

get_class

get_class() -> type[MLAPrefillBackend]

Lazy load and return the backend class.

Source code in vllm/v1/attention/backends/mla/prefill/registry.py
# Doc-rendered copy of MLAPrefillBackendEnum.get_class (registry.py).
def get_class(self) -> "type[MLAPrefillBackend]":
    """Lazy load and return the backend class."""
    return resolve_obj_by_qualname(self.get_path())

get_path

get_path() -> str

Get the fully qualified class path for this backend.

Source code in vllm/v1/attention/backends/mla/prefill/registry.py
# Doc-rendered copy of MLAPrefillBackendEnum.get_path (registry.py).
def get_path(self) -> str:
    """Get the fully qualified class path for this backend."""
    return self.value

MLAPrefillBuilderState dataclass

State created by a prefill backend for use during metadata building.

This class holds backend-specific resources (workspaces, wrappers, etc.) that persist across metadata build calls. Backends can subclass this to add their own state.

Source code in vllm/v1/attention/backends/mla/prefill/base.py
@dataclass
class MLAPrefillBuilderState:
    """State created by a prefill backend for use during metadata building.

    This class holds backend-specific resources (workspaces, wrappers, etc.)
    that persist across metadata build calls. Backends can subclass this
    to add their own state.
    """

    # Common state that may be used by multiple backends
    workspace_buffer: torch.Tensor | None = None

    # Generic storage for backend-specific state
    backend_state: dict[str, Any] = field(default_factory=dict)

MLAPrefillImpl

Bases: ABC

Abstract base class for MLA prefill implementations.

Each implementation provides the actual prefill attention computation for new tokens (causal) and context chunks (non-causal).

Source code in vllm/v1/attention/backends/mla/prefill/base.py
class MLAPrefillImpl(ABC):
    """Abstract base class for MLA prefill implementations.

    Each implementation provides the actual prefill attention computation
    for new tokens (causal) and context chunks (non-causal).
    """

    # Whether this backend needs to pad V to match Q/K head dim
    requires_v_padding: bool = True

    def __init__(
        self,
        num_heads: int,
        scale: float,
        kv_lora_rank: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        vllm_config: "VllmConfig",
        device: torch.device,
    ) -> None:
        """Initialize the prefill implementation.

        Args:
            num_heads: Number of attention heads.
            scale: Softmax scale factor.
            kv_lora_rank: Latent dimension for KV.
            qk_nope_head_dim: QK head dimension without RoPE.
            qk_rope_head_dim: QK head dimension with RoPE.
            v_head_dim: Value head dimension.
            vllm_config: vLLM configuration.
            device: Device to use for computation.
        """
        self.num_heads = num_heads
        self.scale = scale
        self.kv_lora_rank = kv_lora_rank
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.vllm_config = vllm_config
        self.device = device

    @abstractmethod
    def run_prefill_new_tokens(
        self,
        prefill_metadata: "MLACommonPrefillMetadata",
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        return_softmax_lse: bool,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run prefill attention for new tokens (causal).

        Args:
            prefill_metadata: Metadata for the prefill operation.
            q: Query tensor of shape [num_tokens, num_heads, qk_head_dim].
            k: Key tensor of shape [num_tokens, num_heads, qk_head_dim].
            v: Value tensor of shape [num_tokens, num_heads, v_head_dim].
            return_softmax_lse: Whether to return log-sum-exp values.

        Returns:
            If return_softmax_lse is False:
                Output tensor of shape [num_tokens, num_heads, v_head_dim].
            If return_softmax_lse is True:
                Tuple of (output, lse) where lse has shape [num_heads, num_tokens].
        """
        raise NotImplementedError

    @abstractmethod
    def run_prefill_context_chunk(
        self,
        prefill_metadata: "MLACommonPrefillMetadata",
        chunk_idx: int,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run prefill attention for context chunks (non-causal).

        This is used for chunked prefill where we process cached context
        in chunks to manage memory usage.

        Args:
            prefill_metadata: Metadata for the prefill operation.
            chunk_idx: Index of the current context chunk.
            q: Query tensor of shape [num_tokens, num_heads, qk_head_dim].
            k: Key tensor of shape [chunk_tokens, num_heads, qk_head_dim].
            v: Value tensor of shape [chunk_tokens, num_heads, v_head_dim].

        Returns:
            Tuple of (output, lse) where:
                output has shape [num_tokens, num_heads, v_head_dim]
                lse has shape [num_heads, num_tokens]
        """
        raise NotImplementedError

__init__

__init__(
    num_heads: int,
    scale: float,
    kv_lora_rank: int,
    qk_nope_head_dim: int,
    qk_rope_head_dim: int,
    v_head_dim: int,
    vllm_config: VllmConfig,
    device: device,
) -> None

Initialize the prefill implementation.

Parameters:

Name Type Description Default
num_heads int

Number of attention heads.

required
scale float

Softmax scale factor.

required
kv_lora_rank int

Latent dimension for KV.

required
qk_nope_head_dim int

QK head dimension without RoPE.

required
qk_rope_head_dim int

QK head dimension with RoPE.

required
v_head_dim int

Value head dimension.

required
vllm_config VllmConfig

vLLM configuration.

required
device device

Device to use for computation.

required
Source code in vllm/v1/attention/backends/mla/prefill/base.py
# Doc-rendered copy of MLAPrefillImpl.__init__ (base.py).
def __init__(
    self,
    num_heads: int,
    scale: float,
    kv_lora_rank: int,
    qk_nope_head_dim: int,
    qk_rope_head_dim: int,
    v_head_dim: int,
    vllm_config: "VllmConfig",
    device: torch.device,
) -> None:
    """Initialize the prefill implementation.

    Args:
        num_heads: Number of attention heads.
        scale: Softmax scale factor.
        kv_lora_rank: Latent dimension for KV.
        qk_nope_head_dim: QK head dimension without RoPE.
        qk_rope_head_dim: QK head dimension with RoPE.
        v_head_dim: Value head dimension.
        vllm_config: vLLM configuration.
        device: Device to use for computation.
    """
    self.num_heads = num_heads
    self.scale = scale
    self.kv_lora_rank = kv_lora_rank
    self.qk_nope_head_dim = qk_nope_head_dim
    self.qk_rope_head_dim = qk_rope_head_dim
    self.v_head_dim = v_head_dim
    self.vllm_config = vllm_config
    self.device = device

run_prefill_context_chunk abstractmethod

run_prefill_context_chunk(
    prefill_metadata: MLACommonPrefillMetadata,
    chunk_idx: int,
    q: Tensor,
    k: Tensor,
    v: Tensor,
) -> tuple[Tensor, Tensor]

Run prefill attention for context chunks (non-causal).

This is used for chunked prefill where we process cached context in chunks to manage memory usage.

Parameters:

Name Type Description Default
prefill_metadata MLACommonPrefillMetadata

Metadata for the prefill operation.

required
chunk_idx int

Index of the current context chunk.

required
q Tensor

Query tensor of shape [num_tokens, num_heads, qk_head_dim].

required
k Tensor

Key tensor of shape [chunk_tokens, num_heads, qk_head_dim].

required
v Tensor

Value tensor of shape [chunk_tokens, num_heads, v_head_dim].

required

Returns:

Type Description
tuple[Tensor, Tensor]

Tuple of (output, lse) where: output has shape [num_tokens, num_heads, v_head_dim] lse has shape [num_heads, num_tokens]

Source code in vllm/v1/attention/backends/mla/prefill/base.py
# Doc-rendered copy of MLAPrefillImpl.run_prefill_context_chunk (base.py).
@abstractmethod
def run_prefill_context_chunk(
    self,
    prefill_metadata: "MLACommonPrefillMetadata",
    chunk_idx: int,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run prefill attention for context chunks (non-causal).

    This is used for chunked prefill where we process cached context
    in chunks to manage memory usage.

    Args:
        prefill_metadata: Metadata for the prefill operation.
        chunk_idx: Index of the current context chunk.
        q: Query tensor of shape [num_tokens, num_heads, qk_head_dim].
        k: Key tensor of shape [chunk_tokens, num_heads, qk_head_dim].
        v: Value tensor of shape [chunk_tokens, num_heads, v_head_dim].

    Returns:
        Tuple of (output, lse) where:
            output has shape [num_tokens, num_heads, v_head_dim]
            lse has shape [num_heads, num_tokens]
    """
    raise NotImplementedError

run_prefill_new_tokens abstractmethod

run_prefill_new_tokens(
    prefill_metadata: MLACommonPrefillMetadata,
    q: Tensor,
    k: Tensor,
    v: Tensor,
    return_softmax_lse: bool,
) -> Tensor | tuple[Tensor, Tensor]

Run prefill attention for new tokens (causal).

Parameters:

Name Type Description Default
prefill_metadata MLACommonPrefillMetadata

Metadata for the prefill operation.

required
q Tensor

Query tensor of shape [num_tokens, num_heads, qk_head_dim].

required
k Tensor

Key tensor of shape [num_tokens, num_heads, qk_head_dim].

required
v Tensor

Value tensor of shape [num_tokens, num_heads, v_head_dim].

required
return_softmax_lse bool

Whether to return log-sum-exp values.

required

Returns:

Type Description
Tensor | tuple[Tensor, Tensor]

If return_softmax_lse is False: Output tensor of shape [num_tokens, num_heads, v_head_dim].

Tensor | tuple[Tensor, Tensor]

If return_softmax_lse is True: Tuple of (output, lse) where lse has shape [num_heads, num_tokens].

Source code in vllm/v1/attention/backends/mla/prefill/base.py
# Doc-rendered copy of MLAPrefillImpl.run_prefill_new_tokens (base.py).
@abstractmethod
def run_prefill_new_tokens(
    self,
    prefill_metadata: "MLACommonPrefillMetadata",
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    return_softmax_lse: bool,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    """Run prefill attention for new tokens (causal).

    Args:
        prefill_metadata: Metadata for the prefill operation.
        q: Query tensor of shape [num_tokens, num_heads, qk_head_dim].
        k: Key tensor of shape [num_tokens, num_heads, qk_head_dim].
        v: Value tensor of shape [num_tokens, num_heads, v_head_dim].
        return_softmax_lse: Whether to return log-sum-exp values.

    Returns:
        If return_softmax_lse is False:
            Output tensor of shape [num_tokens, num_heads, v_head_dim].
        If return_softmax_lse is True:
            Tuple of (output, lse) where lse has shape [num_heads, num_tokens].
    """
    raise NotImplementedError

get_mla_prefill_backend

get_mla_prefill_backend(
    vllm_config: VllmConfig,
) -> type[MLAPrefillBackend]

Select the MLA prefill backend based on configuration and device.

This function first checks for explicit user preferences via mla.prefill_backend in AttentionConfig, then falls back to automatic priority-based selection.

Parameters:

Name Type Description Default
vllm_config VllmConfig

The vLLM configuration.

required

Returns:

Type Description
type[MLAPrefillBackend]

The selected prefill backend class.

Source code in vllm/v1/attention/backends/mla/prefill/selector.py
def get_mla_prefill_backend(
    vllm_config: "VllmConfig",
) -> "type[MLAPrefillBackend]":
    """Choose the MLA prefill backend for the given config and device.

    An explicit user preference (mla.prefill_backend in AttentionConfig)
    is honored when it validates; otherwise a priority-ordered
    auto-selection runs.

    Args:
        vllm_config: The vLLM configuration.

    Returns:
        The selected prefill backend class.
    """
    from vllm.platforms import current_platform

    capability = current_platform.get_device_capability()
    if capability is None:
        # No capability to inspect (non-CUDA platform, or profiling).
        logger.info_once(
            "Device capability not available, using FlashAttention MLA prefill"
        )
        return MLAPrefillBackendEnum.FLASH_ATTN.get_class()

    # Hashable selector config so downstream selection can be cached.
    selector_config = MLAPrefillSelectorConfig(
        dtype=vllm_config.model_config.dtype,
        is_r1_compatible=is_deepseek_r1_mla_compatible(vllm_config),
    )

    # Honor an explicit selection (includes migrated deprecated flags),
    # falling through to auto-selection when it cannot be used.
    requested = vllm_config.attention_config.mla_prefill_backend
    if requested is not None:
        try:
            candidate = requested.get_class()
            problems = candidate.validate_configuration(capability, selector_config)
        except ImportError as e:
            logger.warning(
                "Requested MLA prefill backend %s is not available: %s. "
                "Falling back to auto-selection.",
                requested.name,
                e,
            )
        else:
            if not problems:
                logger.info_once("Using %s for MLA prefill", candidate.get_name())
                return candidate
            logger.warning(
                "Requested MLA prefill backend %s is not valid: %s. "
                "Falling back to auto-selection.",
                requested.name,
                problems,
            )

    # Auto-select based on priority.
    return _auto_select_mla_prefill_backend(
        capability,
        selector_config,
    )