Skip to content

vllm.config.attention

AttentionConfig

Configuration for attention mechanisms in vLLM.

Source code in vllm/config/attention.py
@config
class AttentionConfig:
    """Configuration for attention mechanisms in vLLM."""

    backend: AttentionBackendEnum | None = None
    """Attention backend to use. Use "auto" or None for automatic selection."""

    flash_attn_version: Literal[2, 3, 4] | None = None
    """Force vllm to use a specific flash-attention version (2, 3, or 4).
    Only valid when using the flash-attention backend."""

    use_prefill_decode_attention: bool = False
    """Use separate prefill and decode kernels for attention instead of
    the unified triton kernel."""

    flash_attn_max_num_splits_for_cuda_graph: int = 32
    """Flash Attention max number splits for cuda graph decode."""

    use_cudnn_prefill: bool = False
    """Whether to use cudnn prefill."""

    # NOTE(review): defaults to True, so the deprecation migration in
    # _migrate_deprecated_mla_prefill_flags fires even for untouched
    # configs — confirm this is intended.
    use_trtllm_ragged_deepseek_prefill: bool = True
    """Whether to use TRTLLM ragged deepseek prefill."""

    use_trtllm_attention: bool | None = None
    """If set to True/False, use or don't use the TRTLLM attention backend
    in flashinfer. If None, auto-detect the attention backend in flashinfer."""

    # NOTE(review): also defaults to True; only consulted by the migration
    # when neither use_cudnn_prefill nor use_trtllm_ragged_deepseek_prefill
    # selected a backend first.
    disable_flashinfer_prefill: bool = True
    """Whether to disable flashinfer prefill."""

    disable_flashinfer_q_quantization: bool = False
    """If set, when using fp8 kv, do not quantize Q to fp8."""

    mla_prefill_backend: MLAPrefillBackendEnum | None = None
    """MLA prefill backend to use. If None, will be selected automatically.
    Valid options: FLASH_ATTN, FLASHINFER, CUDNN, TRTLLM_RAGGED.
    This option supersedes use_cudnn_prefill, use_trtllm_ragged_deepseek_prefill,
    and disable_flashinfer_prefill which are deprecated."""

    use_prefill_query_quantization: bool = False
    """If set, quantize query for attention in prefill."""

    def compute_hash(self) -> str:
        """
        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.
        """
        # Imported lazily here rather than at module top — presumably to
        # avoid an import cycle with vllm.config; TODO confirm.
        from vllm.config.utils import get_hash_factors, hash_factors

        # No fields are excluded: every field of this config participates
        # in the hash.
        ignored_factors: set[str] = set()
        factors = get_hash_factors(self, ignored_factors)
        return hash_factors(factors)

    @field_validator("backend", mode="before")
    @classmethod
    def validate_backend_before(cls, value: Any) -> Any:
        """Enable parsing of the `backend` enum type from string.

        The special value "auto" is treated as None, which triggers
        automatic backend selection.
        """
        # mode="before": runs prior to standard validation, so raw CLI/env
        # strings can be coerced into the enum here.
        if isinstance(value, str):
            if value.lower() == "auto":
                return None
            # Enum lookup by member name; raises KeyError for unknown names.
            return AttentionBackendEnum[value.upper()]
        return value

    @field_validator("mla_prefill_backend", mode="before")
    @classmethod
    def validate_mla_prefill_backend_before(cls, value: Any) -> Any:
        """Enable parsing of the `mla_prefill_backend` enum type from string."""
        # Same string-to-enum coercion as validate_backend_before, but with
        # no "auto" special case.
        if isinstance(value, str):
            return MLAPrefillBackendEnum[value.upper()]
        return value

    def __post_init__(self) -> None:
        # Runs after dataclass initialization: fold any deprecated MLA
        # prefill flags into the consolidated mla_prefill_backend option.
        self._migrate_deprecated_mla_prefill_flags()

    def _migrate_deprecated_mla_prefill_flags(self) -> None:
        """Migrate deprecated MLA prefill flags to mla_prefill_backend.

        Precedence (first match wins):
          1. mla_prefill_backend explicitly set -> no migration at all
          2. use_cudnn_prefill              -> CUDNN
          3. use_trtllm_ragged_deepseek_prefill -> TRTLLM_RAGGED
          4. disable_flashinfer_prefill     -> FLASH_ATTN

        Each migrated flag emits a one-time deprecation warning.
        """
        # If the new option is already set, it takes precedence
        if self.mla_prefill_backend is not None:
            return

        # Check for deprecated flags and migrate them
        # Tracks whether an earlier (higher-precedence) deprecated flag
        # already selected a backend.
        deprecated_flag_used = False

        if self.use_cudnn_prefill:
            deprecated_flag_used = True
            self.mla_prefill_backend = MLAPrefillBackendEnum.CUDNN
            logger.warning_once(
                "use_cudnn_prefill is deprecated and will be removed in "
                "v0.17. Use --attention-config.mla_prefill_backend="
                "CUDNN instead."
            )

        if self.use_trtllm_ragged_deepseek_prefill:
            if deprecated_flag_used:
                # Both cudnn and trtllm flags set: cudnn wins, warn only.
                logger.warning_once(
                    "Multiple deprecated MLA prefill flags are set. "
                    "use_trtllm_ragged_deepseek_prefill will be ignored in "
                    "favor of use_cudnn_prefill. Use "
                    "--attention-config.mla_prefill_backend instead."
                )
            else:
                deprecated_flag_used = True
                self.mla_prefill_backend = MLAPrefillBackendEnum.TRTLLM_RAGGED
                logger.warning_once(
                    "use_trtllm_ragged_deepseek_prefill is deprecated and "
                    "will be removed in v0.17. Use "
                    "--attention-config.mla_prefill_backend=TRTLLM_RAGGED "
                    "instead."
                )

        if self.disable_flashinfer_prefill and not deprecated_flag_used:
            # disable_flashinfer_prefill means "use FLASH_ATTN instead"
            # This is only relevant if no other backend was explicitly selected
            self.mla_prefill_backend = MLAPrefillBackendEnum.FLASH_ATTN
            logger.warning_once(
                "disable_flashinfer_prefill is deprecated and will be removed "
                "in v0.17. Use --attention-config.mla_prefill_backend="
                "FLASH_ATTN instead."
            )

backend class-attribute instance-attribute

backend: AttentionBackendEnum | None = None

Attention backend to use. Use "auto" or None for automatic selection.

disable_flashinfer_prefill class-attribute instance-attribute

disable_flashinfer_prefill: bool = True

Whether to disable flashinfer prefill.

disable_flashinfer_q_quantization class-attribute instance-attribute

disable_flashinfer_q_quantization: bool = False

If set, when using fp8 kv, do not quantize Q to fp8.

flash_attn_max_num_splits_for_cuda_graph class-attribute instance-attribute

flash_attn_max_num_splits_for_cuda_graph: int = 32

Maximum number of Flash Attention splits for CUDA graph decode.

flash_attn_version class-attribute instance-attribute

flash_attn_version: Literal[2, 3, 4] | None = None

Force vllm to use a specific flash-attention version (2, 3, or 4). Only valid when using the flash-attention backend.

mla_prefill_backend class-attribute instance-attribute

mla_prefill_backend: MLAPrefillBackendEnum | None = None

MLA prefill backend to use. If None, will be selected automatically. Valid options: FLASH_ATTN, FLASHINFER, CUDNN, TRTLLM_RAGGED. This option supersedes use_cudnn_prefill, use_trtllm_ragged_deepseek_prefill, and disable_flashinfer_prefill which are deprecated.

use_cudnn_prefill class-attribute instance-attribute

use_cudnn_prefill: bool = False

Whether to use cudnn prefill.

use_prefill_decode_attention class-attribute instance-attribute

use_prefill_decode_attention: bool = False

Use separate prefill and decode kernels for attention instead of the unified triton kernel.

use_prefill_query_quantization class-attribute instance-attribute

use_prefill_query_quantization: bool = False

If set, quantize query for attention in prefill.

use_trtllm_attention class-attribute instance-attribute

use_trtllm_attention: bool | None = None

If set to True/False, use or don't use the TRTLLM attention backend in flashinfer. If None, auto-detect the attention backend in flashinfer.

use_trtllm_ragged_deepseek_prefill class-attribute instance-attribute

use_trtllm_ragged_deepseek_prefill: bool = True

Whether to use TRTLLM ragged deepseek prefill.

_migrate_deprecated_mla_prefill_flags

_migrate_deprecated_mla_prefill_flags() -> None

Migrate deprecated MLA prefill flags to mla_prefill_backend.

Source code in vllm/config/attention.py
def _migrate_deprecated_mla_prefill_flags(self) -> None:
    """Migrate deprecated MLA prefill flags to mla_prefill_backend.

    Precedence (first match wins):
      1. mla_prefill_backend explicitly set -> no migration at all
      2. use_cudnn_prefill              -> CUDNN
      3. use_trtllm_ragged_deepseek_prefill -> TRTLLM_RAGGED
      4. disable_flashinfer_prefill     -> FLASH_ATTN

    Each migrated flag emits a one-time deprecation warning.
    """
    # If the new option is already set, it takes precedence
    if self.mla_prefill_backend is not None:
        return

    # Check for deprecated flags and migrate them
    # Tracks whether a higher-precedence deprecated flag already selected
    # a backend.
    deprecated_flag_used = False

    if self.use_cudnn_prefill:
        deprecated_flag_used = True
        self.mla_prefill_backend = MLAPrefillBackendEnum.CUDNN
        logger.warning_once(
            "use_cudnn_prefill is deprecated and will be removed in "
            "v0.17. Use --attention-config.mla_prefill_backend="
            "CUDNN instead."
        )

    # NOTE(review): use_trtllm_ragged_deepseek_prefill defaults to True,
    # so this branch (and its warning) fires for untouched configs —
    # confirm intended.
    if self.use_trtllm_ragged_deepseek_prefill:
        if deprecated_flag_used:
            # Both cudnn and trtllm flags set: cudnn wins, warn only.
            logger.warning_once(
                "Multiple deprecated MLA prefill flags are set. "
                "use_trtllm_ragged_deepseek_prefill will be ignored in "
                "favor of use_cudnn_prefill. Use "
                "--attention-config.mla_prefill_backend instead."
            )
        else:
            deprecated_flag_used = True
            self.mla_prefill_backend = MLAPrefillBackendEnum.TRTLLM_RAGGED
            logger.warning_once(
                "use_trtllm_ragged_deepseek_prefill is deprecated and "
                "will be removed in v0.17. Use "
                "--attention-config.mla_prefill_backend=TRTLLM_RAGGED "
                "instead."
            )

    if self.disable_flashinfer_prefill and not deprecated_flag_used:
        # disable_flashinfer_prefill means "use FLASH_ATTN instead"
        # This is only relevant if no other backend was explicitly selected
        self.mla_prefill_backend = MLAPrefillBackendEnum.FLASH_ATTN
        logger.warning_once(
            "disable_flashinfer_prefill is deprecated and will be removed "
            "in v0.17. Use --attention-config.mla_prefill_backend="
            "FLASH_ATTN instead."
        )

compute_hash

compute_hash() -> str

Provide a hash that uniquely identifies all the configs that affect the structure of the computation graph from input ids/embeddings to the final hidden states, excluding anything before input ids/embeddings and after the final hidden states.

Source code in vllm/config/attention.py
def compute_hash(self) -> str:
    """Return a hash uniquely identifying the config factors that shape
    the computation graph from input ids/embeddings to the final hidden
    states (excluding anything before or after that span).
    """
    # Imported lazily, presumably to avoid an import cycle — TODO confirm.
    from vllm.config.utils import get_hash_factors, hash_factors

    # No field is excluded from the hash.
    return hash_factors(get_hash_factors(self, set()))

validate_backend_before classmethod

validate_backend_before(value: Any) -> Any

Enable parsing of the backend enum type from string.

The special value "auto" is treated as None, which triggers automatic backend selection.

Source code in vllm/config/attention.py
@field_validator("backend", mode="before")
@classmethod
def validate_backend_before(cls, value: Any) -> Any:
    """Coerce a raw string `backend` value into its enum member.

    The special value "auto" (case-insensitive) maps to None, which
    triggers automatic backend selection.
    """
    # Non-strings pass through untouched for normal validation.
    if not isinstance(value, str):
        return value
    if value.lower() == "auto":
        return None
    # Enum lookup by member name; unknown names raise KeyError.
    return AttentionBackendEnum[value.upper()]

validate_mla_prefill_backend_before classmethod

validate_mla_prefill_backend_before(value: Any) -> Any

Enable parsing of the mla_prefill_backend enum type from string.

Source code in vllm/config/attention.py
@field_validator("mla_prefill_backend", mode="before")
@classmethod
def validate_mla_prefill_backend_before(cls, value: Any) -> Any:
    """Coerce a raw string `mla_prefill_backend` value into its enum member."""
    # Non-strings pass through untouched for normal validation.
    if not isinstance(value, str):
        return value
    return MLAPrefillBackendEnum[value.upper()]