Skip to content

vllm.v1.attention.ops.triton_unified_attention

_get_tile_size

_get_tile_size(
    head_size: int,
    sliding_window: int,
    element_size: int,
    is_prefill: bool,
) -> int

Select tile size with Gemma3-specific optimization.

For Gemma3, use 32 for both prefill and decode to better utilize the larger head dimension (128/256). For other models, use the default vLLM behavior: 32 for prefill, and for decode 16 when element_size >= 2, otherwise 32.

Source code in vllm/v1/attention/ops/triton_unified_attention.py
def _get_tile_size(
    head_size: int,
    sliding_window: int,
    element_size: int,
    is_prefill: bool,
) -> int:
    """Select tile size with Gemma3-specific optimization.

    Gemma3 attention (as recognised by ``_is_gemma3_attention``) always
    gets a tile size of 32, for prefill and decode alike, to better
    utilize its larger head dimension (128/256). Every other model
    follows the default vLLM policy:

    - prefill: 32
    - decode: 16 when ``element_size >= 2``, otherwise 32

    Args:
        head_size: Attention head dimension (used for Gemma3 detection).
        sliding_window: Sliding-window size (used for Gemma3 detection).
        element_size: Element size of the data being tiled — presumably
            bytes per element; TODO confirm against the caller.
        is_prefill: Whether the tile size is requested for prefill.

    Returns:
        The tile size, either 16 or 32.
    """
    if _is_gemma3_attention(head_size, sliding_window) or is_prefill:
        # Gemma3 in either phase, and every prefill launch, use 32.
        return 32
    # Non-Gemma3 decode: only elements of size >= 2 drop to a 16-wide tile.
    if element_size >= 2:
        return 16
    return 32

_is_gemma3_attention

_is_gemma3_attention(
    head_size: int, sliding_window: int
) -> bool

Detect Gemma3 models via unique (head_size, sliding_window) signature.

Gemma3 models are the only ones using sliding_window=1024 with head_size 128 (27B) or 256 (1B, 4B, 12B). Other SWA models use different window sizes (Mistral=4096, Phi-3=2047).

Source code in vllm/v1/attention/ops/triton_unified_attention.py
def _is_gemma3_attention(head_size: int, sliding_window: int) -> bool:
    """Detect Gemma3 models via unique (head_size, sliding_window) signature.

    Gemma3 models are the only ones using sliding_window=1024 with
    head_size 128 (27B) or 256 (1B, 4B, 12B). Other SWA models use
    different window sizes (Mistral=4096, Phi-3=2047).
    """
    return sliding_window == 1024 and head_size in (128, 256)

_prepare_kv_tile

_prepare_kv_tile(
    data,
    Q,
    tensor_scale,
    scale_cache_ptr,
    physical_block_idx,
    seq_offset,
    kv_head_idx,
    stride_s_blk,
    stride_s_slot,
    stride_s_head,
    tile_mask,
    BLOCK_SIZE: constexpr,
    KV_QUANT_MODE: constexpr,
)

Prepare a loaded KV tile for attention computation.

Casts the raw KV data to Q's dtype and loads per-token-head scales when applicable:

  • KV_QUANT_MODE == 0: cast only (no-op for bf16/fp16).
  • KV_QUANT_MODE == 1 (FP8 per-tensor): if Q is itself FP8, cast only; otherwise dequantize inline using the tensor-wide scale.
  • KV_QUANT_MODE >= 2 (per-token-head int8/fp8): cast to Q's dtype and return per-head scales separately — the caller applies them after the dot product for better numerical efficiency.

Returns (data, token_head_scales). token_head_scales is only meaningful when KV_QUANT_MODE >= 2; callers gate its use on the same constexpr so the compiler eliminates dead code.

Source code in vllm/v1/attention/ops/triton_unified_attention.py
@triton.jit
def _prepare_kv_tile(
    data,
    Q,
    tensor_scale,
    scale_cache_ptr,
    physical_block_idx,
    seq_offset,
    kv_head_idx,
    stride_s_blk,
    stride_s_slot,
    stride_s_head,
    tile_mask,
    BLOCK_SIZE: tl.constexpr,
    KV_QUANT_MODE: tl.constexpr,
):
    """Prepare a loaded KV tile for attention computation.

    Casts the raw KV data to Q's dtype and loads per-token-head scales
    when applicable:

    - ``KV_QUANT_MODE == 0``: cast only (no-op for bf16/fp16).
    - ``KV_QUANT_MODE == 1`` (FP8 per-tensor): if Q is itself FP8, cast
      only (``tensor_scale`` is not applied in this helper); otherwise
      dequantize inline as ``data * tensor_scale`` in float32, then cast.
    - ``KV_QUANT_MODE >= 2`` (per-token-head int8/fp8): cast to Q's
      dtype and return per-head scales separately — the caller applies
      them after the dot product for better numerical efficiency.

    Args:
        data: Raw KV tile as loaded from the cache.
        Q: Query tile; only its dtype is used here (as the cast target).
        tensor_scale: Pointer to the tensor-wide scale; loaded only in
            mode 1 when Q is not FP8.
        scale_cache_ptr: Base pointer of the per-token-head scale cache;
            read only when ``KV_QUANT_MODE >= 2``.
        physical_block_idx: Physical KV-cache block index used to index
            the scale cache (scalar or per-position — TODO confirm).
        seq_offset: Sequence positions of the tile; the slot within a
            block is taken as ``seq_offset % BLOCK_SIZE``.
        kv_head_idx: KV head index.
        stride_s_blk: Scale-cache stride along the block dimension.
        stride_s_slot: Scale-cache stride along the slot dimension.
        stride_s_head: Scale-cache stride along the head dimension.
        tile_mask: Mask of valid tile positions; masked-out scale loads
            yield 1.0. Also cast to float32 as the placeholder scales.
        BLOCK_SIZE: Block size used to map ``seq_offset`` to a slot.
        KV_QUANT_MODE: 0=none, 1=fp8 per-tensor, 2=int8 per-token-head,
            3=fp8 per-token-head.

    Returns ``(data, token_head_scales)``.  *token_head_scales* is only
    meaningful when ``KV_QUANT_MODE >= 2``; callers gate its use on
    the same constexpr so the compiler eliminates dead code.
    """
    # KV_QUANT_MODE values: 0=none, 1=fp8 per-tensor,
    #                       2=int8 per-token-head, 3=fp8 per-token-head

    # Placeholder scales (float32) — never read when KV_QUANT_MODE < 2.
    # Derived from tile_mask so every branch returns a tensor of the same
    # type/shape as the real scales.
    unused_scales = tile_mask.to(tl.float32)

    if KV_QUANT_MODE == 1:  # FP8 per-tensor
        # Q is already FP8: cast only. tensor_scale is deliberately not
        # applied here (NOTE(review): presumably handled by the caller —
        # confirm).
        if Q.dtype.is_fp8():
            return data.to(Q.dtype), unused_scales
        # Otherwise dequantize in float32 with the tensor-wide scale,
        # then cast to Q's dtype.
        return (data.to(tl.float32) * tl.load(tensor_scale)).to(Q.dtype), unused_scales
    if KV_QUANT_MODE >= 2:  # per-token-head (int8 or fp8)
        # Offset into the scale cache for each
        # (block, slot-within-block, kv head) triple.
        scale_idx = (
            physical_block_idx * stride_s_blk
            + (seq_offset % BLOCK_SIZE) * stride_s_slot
            + kv_head_idx * stride_s_head
        )
        # Masked-out lanes get a neutral scale of 1.0.
        token_head_scales = tl.load(
            scale_cache_ptr + scale_idx, mask=tile_mask, other=1.0
        )
        # Scales are returned unapplied; the caller multiplies them in
        # after the dot product.
        return data.to(Q.dtype), token_head_scales
    # .to(Q.dtype) is a no-op when data is already Q's type (bf16/fp16),
    # but required so Triton sees consistent return types across branches.
    return data.to(Q.dtype), unused_scales