Skip to content

vllm.v1.attention.backends.mla.prefill.base

Abstract base classes for MLA prefill backends.

This module defines the interface for MLA prefill backends, enabling priority-based selection similar to how MLA decode backends work.

MLAPrefillBackend

Bases: ABC

Abstract base class for MLA prefill backends.

Each prefill backend declares its capabilities (supported dtypes, compute capabilities, etc.) and provides a factory method for creating the implementation class.

Source code in vllm/v1/attention/backends/mla/prefill/base.py
class MLAPrefillBackend(ABC):
    """Interface that every MLA prefill backend implements.

    A concrete backend advertises its capabilities (supported dtypes,
    hardware requirements, optional model-shape constraints) and exposes
    factory hooks that the selector uses to obtain the implementation and
    metadata classes.
    """

    # Dtypes accepted by default; backends may narrow or extend this list.
    supported_dtypes: ClassVar[list[torch.dtype]] = [
        torch.float16,
        torch.bfloat16,
    ]
    # True for backends usable only with DeepSeek-R1-shaped MLA heads.
    requires_r1_mla_dimensions: ClassVar[bool] = False

    @staticmethod
    @abstractmethod
    def get_name() -> str:
        """Identifier of this prefill backend."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_prefill_impl_cls() -> type["MLAPrefillImpl"]:
        """Concrete ``MLAPrefillImpl`` subclass provided by this backend."""
        raise NotImplementedError

    @staticmethod
    def get_prefill_metadata_cls() -> type["MLACommonPrefillMetadata"]:
        """Metadata class used by this backend.

        The common metadata type is returned by default; backends that
        need extra fields (e.g. FlashInferPrefillMetadata) override this.
        """
        # Local import keeps the dependency confined to this call.
        from vllm.model_executor.layers.attention.mla_attention import (
            MLACommonPrefillMetadata,
        )

        return MLACommonPrefillMetadata

    @classmethod
    def supports_compute_capability(cls, device_capability: "DeviceCapability") -> bool:
        """Whether this backend can run on the given device.

        Args:
            device_capability: The device's compute capability.

        Accepts every device by default; override when the backend has
        specific hardware requirements.
        """
        return True

    @classmethod
    def supports_dtype(cls, dtype: torch.dtype) -> bool:
        """Whether ``dtype`` is in this backend's supported list."""
        return dtype in cls.supported_dtypes

    @classmethod
    def is_available(cls) -> bool:
        """Whether this backend's dependencies are available.

        Assumes availability by default; override to probe for required
        libraries/imports.
        """
        return True

    @classmethod
    def validate_configuration(
        cls,
        device_capability: "DeviceCapability",
        selector_config: "MLAPrefillSelectorConfig",
    ) -> list[str]:
        """Validate this backend against the given configuration.

        Args:
            device_capability: The device's compute capability.
            selector_config: Hashable configuration for backend selection.

        Returns:
            Reasons the configuration is invalid; empty list when valid.
        """
        # Pair each passing condition with the message reported on failure;
        # the reported order matches the order of this table.
        checks: list[tuple[bool, str]] = [
            (
                cls.supports_compute_capability(device_capability),
                f"compute capability {device_capability.major}."
                f"{device_capability.minor} not supported",
            ),
            (
                cls.supports_dtype(selector_config.dtype),
                f"dtype {selector_config.dtype} not supported",
            ),
            (
                cls.is_available(),
                "required dependencies not available",
            ),
            (
                # Short-circuits like the original `and`: the config flag is
                # only read when the backend requires R1 dimensions.
                not cls.requires_r1_mla_dimensions
                or selector_config.is_r1_compatible,
                "model does not have DeepSeek R1 MLA dimensions "
                "(qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128)",
            ),
        ]
        return [reason for ok, reason in checks if not ok]

    @classmethod
    def create_builder_state(
        cls,
        vllm_config: "VllmConfig",
        kv_cache_spec: "AttentionSpec",
        layer_names: list[str],
        device: torch.device,
    ) -> MLAPrefillBuilderState:
        """Create backend-specific state for the metadata builder.

        Called once when the metadata builder is initialized; override to
        allocate workspaces, create wrappers, etc.

        Args:
            vllm_config: The vLLM configuration.
            kv_cache_spec: The attention specification.
            layer_names: Names of attention layers.
            device: The device to allocate tensors on.

        Returns:
            A state object containing backend-specific resources.
        """
        # The base backend carries no extra resources.
        return MLAPrefillBuilderState()

    @staticmethod
    def get_chunked_context_metadata_cls() -> type:
        """ChunkedContextMetadata class used by this backend.

        Override if the backend needs a specialized ChunkedContextMetadata.
        """
        # Local import keeps the dependency confined to this call.
        from vllm.model_executor.layers.attention.mla_attention import (
            MLACommonPrefillMetadata,
        )

        return MLACommonPrefillMetadata.ChunkedContextMetadata

    @classmethod  # noqa: B027
    def post_process_prefill_metadata(
        cls,
        prefill_metadata: "MLACommonPrefillMetadata",
        builder_state: MLAPrefillBuilderState,
        prefill_query_start_loc: torch.Tensor,
    ) -> None:
        """Post-process the prefill metadata after creation.

        Runs after the prefill metadata is created but before it is
        attached to the attention metadata; use it to set backend-specific
        fields. Deliberately a no-op here (hence the B027 suppression).
        """
        pass

    @classmethod  # noqa: B027
    def finalize_attention_metadata(
        cls,
        attn_metadata: Any,
        builder_state: MLAPrefillBuilderState,
        num_prefills: int,
        num_heads: int,
        kv_cache_spec: "AttentionSpec",
        mla_dims: Any,
        model_config: Any,
    ) -> None:
        """Finalize the attention metadata after all components are built.

        Runs after the full attention metadata is constructed; use it for
        any final processing (e.g. building FlashInfer wrappers).
        Deliberately a no-op here (hence the B027 suppression).
        """
        pass

create_builder_state classmethod

create_builder_state(
    vllm_config: VllmConfig,
    kv_cache_spec: AttentionSpec,
    layer_names: list[str],
    device: device,
) -> MLAPrefillBuilderState

Create backend-specific state for the metadata builder.

This is called once when the metadata builder is initialized. Override to allocate workspaces, create wrappers, etc.

Parameters:

Name Type Description Default
vllm_config VllmConfig

The vLLM configuration.

required
kv_cache_spec AttentionSpec

The attention specification.

required
layer_names list[str]

Names of attention layers.

required
device device

The device to allocate tensors on.

required

Returns:

Type Description
MLAPrefillBuilderState

A state object containing backend-specific resources.

Source code in vllm/v1/attention/backends/mla/prefill/base.py
@classmethod
def create_builder_state(
    cls,
    vllm_config: "VllmConfig",
    kv_cache_spec: "AttentionSpec",
    layer_names: list[str],
    device: torch.device,
) -> MLAPrefillBuilderState:
    """Create backend-specific state for the metadata builder.

    This is called once when the metadata builder is initialized.
    Override to allocate workspaces, create wrappers, etc.

    Args:
        vllm_config: The vLLM configuration.
        kv_cache_spec: The attention specification.
        layer_names: Names of attention layers.
        device: The device to allocate tensors on.

    Returns:
        A state object containing backend-specific resources.
    """
    # Default: no backend-specific resources; subclasses that need
    # workspaces or wrappers return a richer state object instead.
    return MLAPrefillBuilderState()

finalize_attention_metadata classmethod

finalize_attention_metadata(
    attn_metadata: Any,
    builder_state: MLAPrefillBuilderState,
    num_prefills: int,
    num_heads: int,
    kv_cache_spec: AttentionSpec,
    mla_dims: Any,
    model_config: Any,
) -> None

Finalize the attention metadata after all components are built.

This is called after the full attention metadata is constructed. Use this for any final processing (e.g., building FlashInfer wrappers).

Source code in vllm/v1/attention/backends/mla/prefill/base.py
@classmethod  # noqa: B027
def finalize_attention_metadata(
    cls,
    attn_metadata: Any,
    builder_state: MLAPrefillBuilderState,
    num_prefills: int,
    num_heads: int,
    kv_cache_spec: "AttentionSpec",
    mla_dims: Any,
    model_config: Any,
) -> None:
    """Finalize the attention metadata after all components are built.

    This is called after the full attention metadata is constructed.
    Use this for any final processing (e.g., building FlashInfer wrappers).
    """
    # Intentional no-op in the base class; the B027 suppression marks
    # this empty non-abstract hook as deliberate.
    pass

get_chunked_context_metadata_cls staticmethod

get_chunked_context_metadata_cls() -> type

Return the ChunkedContextMetadata class for this backend.

Override if the backend needs a specialized ChunkedContextMetadata.

Source code in vllm/v1/attention/backends/mla/prefill/base.py
@staticmethod
def get_chunked_context_metadata_cls() -> type:
    """Return the ChunkedContextMetadata class for this backend.

    Override if the backend needs a specialized ChunkedContextMetadata.
    """
    # Local import: the metadata type is only needed when this hook runs.
    from vllm.model_executor.layers.attention.mla_attention import (
        MLACommonPrefillMetadata,
    )

    # The default is the nested class on the common metadata type.
    return MLACommonPrefillMetadata.ChunkedContextMetadata

get_name abstractmethod staticmethod

get_name() -> str

Return the name of this prefill backend.

Source code in vllm/v1/attention/backends/mla/prefill/base.py
@staticmethod
@abstractmethod
def get_name() -> str:
    """Return the name of this prefill backend."""
    # Abstract stub: concrete backends must override; the raise guards
    # against a direct call on the base class.
    raise NotImplementedError

get_prefill_impl_cls abstractmethod staticmethod

get_prefill_impl_cls() -> type[MLAPrefillImpl]

Return the implementation class for this prefill backend.

Source code in vllm/v1/attention/backends/mla/prefill/base.py
@staticmethod
@abstractmethod
def get_prefill_impl_cls() -> type["MLAPrefillImpl"]:
    """Return the implementation class for this prefill backend."""
    # Abstract stub: concrete backends must override; the raise guards
    # against a direct call on the base class.
    raise NotImplementedError

get_prefill_metadata_cls staticmethod

get_prefill_metadata_cls() -> type[
    MLACommonPrefillMetadata
]

Return the metadata class for this prefill backend.

Override this method if the backend requires a specialized metadata class (e.g., FlashInferPrefillMetadata).

Source code in vllm/v1/attention/backends/mla/prefill/base.py
@staticmethod
def get_prefill_metadata_cls() -> type["MLACommonPrefillMetadata"]:
    """Return the metadata class for this prefill backend.

    Override this method if the backend requires a specialized
    metadata class (e.g., FlashInferPrefillMetadata).
    """
    # Local import: the metadata type is only needed when this hook runs.
    from vllm.model_executor.layers.attention.mla_attention import (
        MLACommonPrefillMetadata,
    )

    return MLACommonPrefillMetadata

is_available classmethod

is_available() -> bool

Check if this backend's dependencies are available.

Override this method to check for required libraries/imports.

Source code in vllm/v1/attention/backends/mla/prefill/base.py
@classmethod
def is_available(cls) -> bool:
    """Check if this backend's dependencies are available.

    Override this method to check for required libraries/imports.
    """
    # Base class assumes availability; no import probing is done here.
    return True

post_process_prefill_metadata classmethod

post_process_prefill_metadata(
    prefill_metadata: MLACommonPrefillMetadata,
    builder_state: MLAPrefillBuilderState,
    prefill_query_start_loc: Tensor,
) -> None

Post-process the prefill metadata after creation.

This is called after the prefill metadata is created but before it's attached to the attention metadata. Use this to set backend-specific fields on the metadata.

Source code in vllm/v1/attention/backends/mla/prefill/base.py
@classmethod  # noqa: B027
def post_process_prefill_metadata(
    cls,
    prefill_metadata: "MLACommonPrefillMetadata",
    builder_state: MLAPrefillBuilderState,
    prefill_query_start_loc: torch.Tensor,
) -> None:
    """Post-process the prefill metadata after creation.

    This is called after the prefill metadata is created but before
    it's attached to the attention metadata. Use this to set
    backend-specific fields on the metadata.
    """
    # Intentional no-op in the base class; the B027 suppression marks
    # this empty non-abstract hook as deliberate.
    pass

supports_compute_capability classmethod

supports_compute_capability(
    device_capability: DeviceCapability,
) -> bool

Check if this backend supports the given compute capability.

Parameters:

Name Type Description Default
device_capability DeviceCapability

The device's compute capability.

required

Override this method if the backend has specific hardware requirements.

Source code in vllm/v1/attention/backends/mla/prefill/base.py
@classmethod
def supports_compute_capability(cls, device_capability: "DeviceCapability") -> bool:
    """Check if this backend supports the given compute capability.

    Args:
        device_capability: The device's compute capability.

    Override this method if the backend has specific hardware requirements.
    """
    # Base class accepts every device.
    return True

supports_dtype classmethod

supports_dtype(dtype: dtype) -> bool

Check if this backend supports the given dtype.

Source code in vllm/v1/attention/backends/mla/prefill/base.py
@classmethod
def supports_dtype(cls, dtype: torch.dtype) -> bool:
    """Check if this backend supports the given dtype."""
    # Membership test against the class-level allow-list, so subclasses
    # change behavior simply by redefining `supported_dtypes`.
    return dtype in cls.supported_dtypes

validate_configuration classmethod

validate_configuration(
    device_capability: DeviceCapability,
    selector_config: MLAPrefillSelectorConfig,
) -> list[str]

Validate if this backend can be used with the given configuration.

Parameters:

Name Type Description Default
device_capability DeviceCapability

The device's compute capability.

required
selector_config MLAPrefillSelectorConfig

Hashable configuration for backend selection.

required

Returns:

Type Description
list[str]

A list of invalid reasons. Empty list if configuration is valid.

Source code in vllm/v1/attention/backends/mla/prefill/base.py
@classmethod
def validate_configuration(
    cls,
    device_capability: "DeviceCapability",
    selector_config: "MLAPrefillSelectorConfig",
) -> list[str]:
    """Validate if this backend can be used with the given configuration.

    Args:
        device_capability: The device's compute capability.
        selector_config: Hashable configuration for backend selection.

    Returns:
        A list of invalid reasons. Empty list if configuration is valid.
    """
    # All failing checks are collected (not short-circuited) so the
    # caller can report every reason at once.
    invalid_reasons: list[str] = []

    if not cls.supports_compute_capability(device_capability):
        invalid_reasons.append(
            f"compute capability {device_capability.major}."
            f"{device_capability.minor} not supported"
        )

    if not cls.supports_dtype(selector_config.dtype):
        invalid_reasons.append(f"dtype {selector_config.dtype} not supported")

    if not cls.is_available():
        invalid_reasons.append("required dependencies not available")

    # Only consulted for backends that opted into the R1 shape constraint.
    if cls.requires_r1_mla_dimensions and not selector_config.is_r1_compatible:
        invalid_reasons.append(
            "model does not have DeepSeek R1 MLA dimensions "
            "(qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128)"
        )

    return invalid_reasons

MLAPrefillBuilderState dataclass

State created by a prefill backend for use during metadata building.

This class holds backend-specific resources (workspaces, wrappers, etc.) that persist across metadata build calls. Backends can subclass this to add their own state.

Source code in vllm/v1/attention/backends/mla/prefill/base.py
@dataclass
class MLAPrefillBuilderState:
    """State created by a prefill backend for use during metadata building.

    This class holds backend-specific resources (workspaces, wrappers, etc.)
    that persist across metadata build calls. Backends can subclass this
    to add their own state.
    """

    # Common state that may be used by multiple backends
    # (None when the backend allocates no shared workspace).
    workspace_buffer: torch.Tensor | None = None

    # Generic storage for backend-specific state, keyed by
    # backend-chosen names; default_factory gives each instance its
    # own empty dict (never a shared mutable default).
    backend_state: dict[str, Any] = field(default_factory=dict)

MLAPrefillImpl

Bases: ABC

Abstract base class for MLA prefill implementations.

Each implementation provides the actual prefill attention computation for new tokens (causal) and context chunks (non-causal).

Source code in vllm/v1/attention/backends/mla/prefill/base.py
class MLAPrefillImpl(ABC):
    """Abstract base class for MLA prefill implementations.

    Each implementation provides the actual prefill attention computation
    for new tokens (causal) and context chunks (non-causal).
    """

    # Whether this backend needs to pad V to match Q/K head dim
    requires_v_padding: bool = True

    def __init__(
        self,
        num_heads: int,
        scale: float,
        kv_lora_rank: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        vllm_config: "VllmConfig",
        device: torch.device,
    ) -> None:
        """Initialize the prefill implementation.

        Args:
            num_heads: Number of attention heads.
            scale: Softmax scale factor.
            kv_lora_rank: Latent dimension for KV.
            qk_nope_head_dim: QK head dimension without RoPE.
            qk_rope_head_dim: QK head dimension with RoPE.
            v_head_dim: Value head dimension.
            vllm_config: vLLM configuration.
            device: Device to use for computation.
        """
        self.num_heads = num_heads
        self.scale = scale
        self.kv_lora_rank = kv_lora_rank
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.vllm_config = vllm_config
        self.device = device

    @abstractmethod
    def run_prefill_new_tokens(
        self,
        prefill_metadata: "MLACommonPrefillMetadata",
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        return_softmax_lse: bool,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run prefill attention for new tokens (causal).

        Args:
            prefill_metadata: Metadata for the prefill operation.
            q: Query tensor of shape [num_tokens, num_heads, qk_head_dim].
            k: Key tensor of shape [num_tokens, num_heads, qk_head_dim].
            v: Value tensor of shape [num_tokens, num_heads, v_head_dim].
            return_softmax_lse: Whether to return log-sum-exp values.

        Returns:
            If return_softmax_lse is False:
                Output tensor of shape [num_tokens, num_heads, v_head_dim].
            If return_softmax_lse is True:
                Tuple of (output, lse) where lse has shape [num_heads, num_tokens].
        """
        raise NotImplementedError

    @abstractmethod
    def run_prefill_context_chunk(
        self,
        prefill_metadata: "MLACommonPrefillMetadata",
        chunk_idx: int,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run prefill attention for context chunks (non-causal).

        This is used for chunked prefill where we process cached context
        in chunks to manage memory usage.

        Args:
            prefill_metadata: Metadata for the prefill operation.
            chunk_idx: Index of the current context chunk.
            q: Query tensor of shape [num_tokens, num_heads, qk_head_dim].
            k: Key tensor of shape [chunk_tokens, num_heads, qk_head_dim].
            v: Value tensor of shape [chunk_tokens, num_heads, v_head_dim].

        Returns:
            Tuple of (output, lse) where:
                output has shape [num_tokens, num_heads, v_head_dim]
                lse has shape [num_heads, num_tokens]
        """
        raise NotImplementedError

__init__

__init__(
    num_heads: int,
    scale: float,
    kv_lora_rank: int,
    qk_nope_head_dim: int,
    qk_rope_head_dim: int,
    v_head_dim: int,
    vllm_config: VllmConfig,
    device: device,
) -> None

Initialize the prefill implementation.

Parameters:

Name Type Description Default
num_heads int

Number of attention heads.

required
scale float

Softmax scale factor.

required
kv_lora_rank int

Latent dimension for KV.

required
qk_nope_head_dim int

QK head dimension without RoPE.

required
qk_rope_head_dim int

QK head dimension with RoPE.

required
v_head_dim int

Value head dimension.

required
vllm_config VllmConfig

vLLM configuration.

required
device device

Device to use for computation.

required
Source code in vllm/v1/attention/backends/mla/prefill/base.py
def __init__(
    self,
    num_heads: int,
    scale: float,
    kv_lora_rank: int,
    qk_nope_head_dim: int,
    qk_rope_head_dim: int,
    v_head_dim: int,
    vllm_config: "VllmConfig",
    device: torch.device,
) -> None:
    """Initialize the prefill implementation.

    Args:
        num_heads: Number of attention heads.
        scale: Softmax scale factor.
        kv_lora_rank: Latent dimension for KV.
        qk_nope_head_dim: QK head dimension without RoPE.
        qk_rope_head_dim: QK head dimension with RoPE.
        v_head_dim: Value head dimension.
        vllm_config: vLLM configuration.
        device: Device to use for computation.
    """
    # Pure attribute capture; no tensors are allocated here.
    self.num_heads = num_heads
    self.scale = scale
    self.kv_lora_rank = kv_lora_rank
    self.qk_nope_head_dim = qk_nope_head_dim
    self.qk_rope_head_dim = qk_rope_head_dim
    self.v_head_dim = v_head_dim
    self.vllm_config = vllm_config
    self.device = device

run_prefill_context_chunk abstractmethod

run_prefill_context_chunk(
    prefill_metadata: MLACommonPrefillMetadata,
    chunk_idx: int,
    q: Tensor,
    k: Tensor,
    v: Tensor,
) -> tuple[Tensor, Tensor]

Run prefill attention for context chunks (non-causal).

This is used for chunked prefill where we process cached context in chunks to manage memory usage.

Parameters:

Name Type Description Default
prefill_metadata MLACommonPrefillMetadata

Metadata for the prefill operation.

required
chunk_idx int

Index of the current context chunk.

required
q Tensor

Query tensor of shape [num_tokens, num_heads, qk_head_dim].

required
k Tensor

Key tensor of shape [chunk_tokens, num_heads, qk_head_dim].

required
v Tensor

Value tensor of shape [chunk_tokens, num_heads, v_head_dim].

required

Returns:

Type Description
tuple[Tensor, Tensor]

Tuple of (output, lse) where: output has shape [num_tokens, num_heads, v_head_dim] lse has shape [num_heads, num_tokens]

Source code in vllm/v1/attention/backends/mla/prefill/base.py
@abstractmethod
def run_prefill_context_chunk(
    self,
    prefill_metadata: "MLACommonPrefillMetadata",
    chunk_idx: int,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run prefill attention for context chunks (non-causal).

    This is used for chunked prefill where we process cached context
    in chunks to manage memory usage.

    Args:
        prefill_metadata: Metadata for the prefill operation.
        chunk_idx: Index of the current context chunk.
        q: Query tensor of shape [num_tokens, num_heads, qk_head_dim].
        k: Key tensor of shape [chunk_tokens, num_heads, qk_head_dim].
        v: Value tensor of shape [chunk_tokens, num_heads, v_head_dim].

    Returns:
        Tuple of (output, lse) where:
            output has shape [num_tokens, num_heads, v_head_dim]
            lse has shape [num_heads, num_tokens]
    """
    # Abstract: the concrete backend supplies the actual kernel call.
    raise NotImplementedError

run_prefill_new_tokens abstractmethod

run_prefill_new_tokens(
    prefill_metadata: MLACommonPrefillMetadata,
    q: Tensor,
    k: Tensor,
    v: Tensor,
    return_softmax_lse: bool,
) -> Tensor | tuple[Tensor, Tensor]

Run prefill attention for new tokens (causal).

Parameters:

Name Type Description Default
prefill_metadata MLACommonPrefillMetadata

Metadata for the prefill operation.

required
q Tensor

Query tensor of shape [num_tokens, num_heads, qk_head_dim].

required
k Tensor

Key tensor of shape [num_tokens, num_heads, qk_head_dim].

required
v Tensor

Value tensor of shape [num_tokens, num_heads, v_head_dim].

required
return_softmax_lse bool

Whether to return log-sum-exp values.

required

Returns:

Type Description
Tensor | tuple[Tensor, Tensor]

If return_softmax_lse is False: Output tensor of shape [num_tokens, num_heads, v_head_dim].

Tensor | tuple[Tensor, Tensor]

If return_softmax_lse is True: Tuple of (output, lse) where lse has shape [num_heads, num_tokens].

Source code in vllm/v1/attention/backends/mla/prefill/base.py
@abstractmethod
def run_prefill_new_tokens(
    self,
    prefill_metadata: "MLACommonPrefillMetadata",
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    return_softmax_lse: bool,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    """Run prefill attention for new tokens (causal).

    Args:
        prefill_metadata: Metadata for the prefill operation.
        q: Query tensor of shape [num_tokens, num_heads, qk_head_dim].
        k: Key tensor of shape [num_tokens, num_heads, qk_head_dim].
        v: Value tensor of shape [num_tokens, num_heads, v_head_dim].
        return_softmax_lse: Whether to return log-sum-exp values.

    Returns:
        If return_softmax_lse is False:
            Output tensor of shape [num_tokens, num_heads, v_head_dim].
        If return_softmax_lse is True:
            Tuple of (output, lse) where lse has shape [num_heads, num_tokens].
    """
    # Abstract: the concrete backend supplies the actual kernel call.
    raise NotImplementedError