Skip to content

vllm.model_executor.layers.quantization.utils.marlin_utils_fp8

mxfp8_marlin_process_scales

mxfp8_marlin_process_scales(
    marlin_scales: Tensor,
) -> Tensor

Reorder scales for e8m0 kernel layout and convert to float8_e8m0fnu.

Source code in vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py
def mxfp8_marlin_process_scales(marlin_scales: torch.Tensor) -> torch.Tensor:
    """Reorder scales for e8m0 kernel layout and convert to float8_e8m0fnu."""
    # fit the layout of fp8 dequantization
    marlin_scales = marlin_scales.view(-1, 4)[:, [0, 2, 1, 3]].view(
        marlin_scales.size(0), -1
    )
    marlin_scales = marlin_scales.to(torch.float8_e8m0fnu)
    return marlin_scales

pack_fp8_to_int32

pack_fp8_to_int32(
    fp8_tensor: Tensor, size_k_first: bool = True
) -> Tensor

Repack FP8 weights to gptq format (packed int32 elements)

Source code in vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py
def pack_fp8_to_int32(
    fp8_tensor: torch.Tensor, size_k_first: bool = True
) -> torch.Tensor:
    """
    Repack FP8 weights to gptq format (packed int32 elements)
    """
    assert fp8_tensor.dtype == torch.float8_e4m3fn
    assert fp8_tensor.ndim == 2

    fp8_tensor = fp8_tensor.T if size_k_first else fp8_tensor
    fp8_tensor = fp8_tensor.contiguous()
    # fp8_tensor is contiguous and have shape (N, K) now
    # with `.view(torch.int32)`, it become (N, K // 4)
    int32_tensor = fp8_tensor.view(torch.int32)
    return int32_tensor.T.contiguous() if size_k_first else int32_tensor

prepare_fp8_moe_layer_for_marlin

prepare_fp8_moe_layer_for_marlin(
    layer: Module,
    w13_weight: Tensor,
    w2_weight: Tensor,
    w13_weight_scale: Tensor,
    w2_weight_scale: Tensor,
) -> tuple[Tensor, Tensor, Tensor, Tensor]

Shuffle weights and scales into marlin format.

Note that this function has the side effect of adding a workspace attribute to the layer. This workspace does not need to be registered as a Parameter as it is not used during weight reloading.

Source code in vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py
def prepare_fp8_moe_layer_for_marlin(
    layer: torch.nn.Module,
    w13_weight: torch.Tensor,
    w2_weight: torch.Tensor,
    w13_weight_scale: torch.Tensor,
    w2_weight_scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Shuffle weights and scales into marlin format.

    Note that this function has the side effect of adding a `workspace`
    attribute to the layer. This `workspace` does not need to be
    registered as a Parameter as it is not used during weight reloading.

    Args:
        layer: MoE layer; must expose `num_experts`, `hidden_size`,
            `intermediate_size_per_partition`, `orig_dtype`, and may expose
            `weight_block_size` for block-wise quantization.
        w13_weight: gate/up-proj weights, shape (e, w13_n, k) (asserted below).
        w2_weight: down-proj weights, shape (e, k, n) (asserted below).
        w13_weight_scale: w13 scales; tensor-, channel-, or block-wise.
        w2_weight_scale: w2 scales; tensor-, channel-, or block-wise.

    Returns:
        (w13_weight, w2_weight, w13_weight_scale, w2_weight_scale) repacked
        and permuted into the layout the Marlin kernel consumes.

    Raises:
        NotImplementedError: if the Marlin input dtype is 8-bit (W8A8).
    """

    logger.warning_once(
        "Your GPU does not have native support for FP8 computation but "
        "FP8 quantization is being used. Weight-only FP8 compression will "
        "be used leveraging the Marlin kernel. This may degrade "
        "performance for compute-heavy workloads."
    )
    input_dtype = get_marlin_input_dtype()
    # A 1-byte activation dtype would mean W8A8; only weight-only FP8 is
    # supported by this path.
    if input_dtype is not None and input_dtype.itemsize == 1:
        raise NotImplementedError("Marlin W8A8 is not supported.")

    e = layer.num_experts
    k = layer.hidden_size
    n = layer.intermediate_size_per_partition
    # w13 output dim read from the tensor itself (typically 2*n for fused
    # gate/up projections — presumed; not asserted here).
    w13_n = w13_weight.size(1)
    weight_block_size = getattr(layer, "weight_block_size", None)

    # WORKSPACE
    device = layer.w13_weight.device
    # NOTE(rob): we do not need to register the workspace as a param
    # because it is not used as part of the weight reloading process.
    layer.workspace = marlin_make_workspace_new(device, 4)
    # Empty permutation: gptq_marlin_repack applies no act-order reordering.
    perm = torch.empty(0, dtype=torch.int, device=device)

    # WEIGHT
    # Repack weights to marlin format
    def repack_weight(name: str, weight: torch.Tensor) -> torch.Tensor:
        # Repack one expert at a time; returns (e, ...) stacked result.
        tensor_list = []
        if "w13" in name:
            size_n, size_k = w13_n, k
        else:
            # w2 maps intermediate (n) back to hidden (k): weight is (k, n).
            size_n, size_k = k, n

        assert weight.shape == (e, size_n, size_k)

        for i in range(e):
            qweight = pack_fp8_to_int32(weight[i], size_k_first=False)
            # gptq_marlin_repack expects K-major packed int32 input.
            qweight = qweight.T.contiguous()

            marlin_qweight = ops.gptq_marlin_repack(
                b_q_weight=qweight, perm=perm, size_k=size_k, size_n=size_n, num_bits=8
            )
            tensor_list.append(marlin_qweight)

        return torch.cat([x.unsqueeze(0) for x in tensor_list], 0)

    w13_weight = repack_weight("w13", w13_weight)
    w2_weight = repack_weight("w2", w2_weight)

    # WEIGHT SCALES
    # Permute scales
    # -1 group_size signals channel-wise quantization to marlin_permute_scales.
    group_size = -1 if weight_block_size is None else weight_block_size[1]

    def permute_scales(scales: torch.Tensor, name: str) -> torch.Tensor:
        scales = scales.to(layer.orig_dtype)
        tensor_list = []
        if "w13" in name:
            size_n, size_k = w13_n, k
        else:
            size_n, size_k = k, n

        # marlin kernel only support channel-wise and group-wise quantization
        # we need to convert the scales
        if weight_block_size is None:
            if scales.nelement() == e:
                # tensor-wise quantization -> channel-wise quantization
                # (e, 1, 1) =>(repeat)=> (e, 1, size_n)
                scales = scales.view(e, 1, 1).repeat_interleave(size_n, 2)
            elif scales.nelement() > e and scales.nelement() != e * size_n:
                assert (e * size_n) % scales.nelement() == 0
                s_size = scales.nelement() // e
                # tensor-wise quantization (for gate-up proj)
                #     -> channel-wise quantization
                # (e, 1, s_size) =>(repeat)=> (e, 1, size_n)
                scales = scales.view(e, 1, s_size)
                scales = scales.repeat_interleave(size_n // s_size, 2)
            else:
                # channel-wise quantization
                # (e, 1, size_n)
                scales = scales.view(e, 1, size_n)
        else:
            # block-wise quantization -> group-wise quantization
            # (e, size_k // block_size[1], ceil(size_n / block_size[0]))
            #  =>(repeat)=> (e, size_k // block_size[1], size_n)
            scales = scales.permute(0, 2, 1)
            block_n = weight_block_size[0]
            scales = scales.repeat_interleave(block_n, 2)
            # size_n may not be divisible by block_size[0]; trim the overhang
            scales = scales[..., :size_n].contiguous()

        for i in range(e):
            marlin_scales = marlin_permute_scales(
                s=scales[i], size_k=size_k, size_n=size_n, group_size=group_size
            )
            tensor_list.append(marlin_scales)

        scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
        # For non-fp8 activations, the fp8 exponent bias is folded into the
        # scales so the kernel can skip it at runtime.
        if input_dtype != torch.float8_e4m3fn:
            scales = fp8_fused_exponent_bias_into_scales(scales)
        return scales

    w13_weight_scale = permute_scales(w13_weight_scale, "w13")
    w2_weight_scale = permute_scales(w2_weight_scale, "w2")

    return w13_weight, w2_weight, w13_weight_scale, w2_weight_scale

prepare_mxfp8_layer_for_marlin

prepare_mxfp8_layer_for_marlin(layer: Module) -> None

Repack MXFP8 weights and scales into Marlin kernel format.

Expects the layer to have:
  - weight: [N, K] float8_e4m3fn
  - weight_scale: [N, K//32] uint8 (e8m0 encoded)
  - input_size_per_partition / output_size_per_partition
Source code in vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py
def prepare_mxfp8_layer_for_marlin(layer: torch.nn.Module) -> None:
    """Repack MXFP8 weights and scales into Marlin kernel format.

    Expects the layer to have:
      - weight: [N, K] float8_e4m3fn
      - weight_scale: [N, K//32] uint8 (e8m0 encoded)
      - input_size_per_partition / output_size_per_partition
    """
    size_n = layer.output_size_per_partition
    size_k = layer.input_size_per_partition
    group_size = 32  # MX standard block size
    device = layer.weight.device

    # WORKSPACE
    layer.workspace = marlin_make_workspace_new(device)

    # WEIGHT: pack fp8 bytes into int32 words, then repack for Marlin.
    empty_perm = torch.empty(0, dtype=torch.int, device=device)
    packed = pack_fp8_to_int32(layer.weight, size_k_first=False)
    packed = packed.T.contiguous()
    repacked = ops.gptq_marlin_repack(
        b_q_weight=packed,
        perm=empty_perm,
        size_k=size_k,
        size_n=size_n,
        num_bits=8,
    )
    replace_parameter(layer, "weight", repacked)

    # WEIGHT SCALES: uint8 -> e8m0fnu -> param_dtype, and transpose
    # [N, K//32] into the [K//32, N] shape marlin_permute_scales expects.
    param_dtype = torch.get_default_dtype()
    raw_scales = layer.weight_scale.data[:size_n, : size_k // group_size]
    scales = raw_scales.contiguous().view(torch.float8_e8m0fnu).to(param_dtype)
    scales = scales.T.contiguous()

    permuted = marlin_permute_scales(
        s=scales,
        size_k=size_k,
        size_n=size_n,
        group_size=group_size,
    )

    # Final e8m0 kernel ordering + cast back to float8_e8m0fnu.
    permuted = mxfp8_marlin_process_scales(permuted)
    replace_parameter(layer, "weight_scale", permuted)

    # BIAS: permute into Marlin layout when present.
    bias = getattr(layer, "bias", None)
    if bias is not None:
        assert bias.shape == (size_n,)
        replace_parameter(layer, "bias", marlin_permute_bias(bias))

prepare_mxfp8_moe_layer_for_marlin

prepare_mxfp8_moe_layer_for_marlin(
    layer: Module,
    w13: Tensor,
    w2: Tensor,
    w13_scale: Tensor,
    w2_scale: Tensor,
) -> tuple[Tensor, Tensor, Tensor, Tensor]

Repack MXFP8 MoE weights and scales into Marlin kernel format.

Parameters:

- layer (Module, required): MoE layer (used to read params_dtype and attach workspace).
- w13 (Tensor, required): [E, 2*N, K] float8_e4m3fn weights.
- w2 (Tensor, required): [E, K, N] float8_e4m3fn weights.
- w13_scale (Tensor, required): [E, 2*N, K//32] uint8 e8m0 scales.
- w2_scale (Tensor, required): [E, K, N//32] uint8 e8m0 scales.

Returns:

- tuple[Tensor, Tensor, Tensor, Tensor]: (w13, w2, w13_scale, w2_scale) in Marlin format.

Source code in vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py
def prepare_mxfp8_moe_layer_for_marlin(
    layer: torch.nn.Module,
    w13: torch.Tensor,
    w2: torch.Tensor,
    w13_scale: torch.Tensor,
    w2_scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Repack MXFP8 MoE weights and scales into Marlin kernel format.

    Args:
        layer: MoE layer (used to read params_dtype and attach workspace).
        w13: [E, 2*N, K] float8_e4m3fn weights.
        w2:  [E, K, N] float8_e4m3fn weights.
        w13_scale: [E, 2*N, K//32] uint8 e8m0 scales.
        w2_scale:  [E, K, N//32] uint8 e8m0 scales.

    Returns:
        (w13, w2, w13_scale, w2_scale) in Marlin format.
    """
    group_size = 32
    num_experts, w13_n, hidden_k = w13.shape[0], w13.shape[1], w13.shape[2]
    inter_n = w2.shape[2]

    device = w13.device
    param_dtype = torch.get_default_dtype()
    empty_perm = torch.empty(0, dtype=torch.int, device=device)

    layer.workspace = marlin_make_workspace_new(device, 4)

    def _nk(name: str) -> tuple[int, int]:
        # (size_n, size_k) for the named projection.
        if "w13" in name:
            return w13_n, hidden_k
        return hidden_k, inter_n

    def repack_weight(weight: torch.Tensor, name: str) -> torch.Tensor:
        size_n, size_k = _nk(name)
        assert weight.shape == (num_experts, size_n, size_k)

        per_expert = []
        for expert_w in weight:
            # Pack fp8 bytes into int32 words, K-major for the repack kernel.
            qweight = pack_fp8_to_int32(expert_w, size_k_first=False)
            qweight = qweight.T.contiguous()
            per_expert.append(
                ops.gptq_marlin_repack(
                    b_q_weight=qweight,
                    perm=empty_perm,
                    size_k=size_k,
                    size_n=size_n,
                    num_bits=8,
                )
            )
        return torch.stack(per_expert, dim=0)

    w13 = repack_weight(w13, "w13")
    w2 = repack_weight(w2, "w2")

    def permute_scales(scales: torch.Tensor, name: str) -> torch.Tensor:
        size_n, size_k = _nk(name)

        per_expert = []
        for expert_s in scales:
            # uint8 -> e8m0fnu -> param_dtype, then [N, K//32] -> [K//32, N].
            s = expert_s[:size_n, : size_k // group_size].contiguous()
            s = s.view(torch.float8_e8m0fnu).to(param_dtype)
            s = s.T.contiguous()
            s = marlin_permute_scales(
                s=s,
                size_k=size_k,
                size_n=size_n,
                group_size=group_size,
            )
            per_expert.append(mxfp8_marlin_process_scales(s))
        return torch.stack(per_expert, dim=0)

    w13_scale = permute_scales(w13_scale, "w13")
    w2_scale = permute_scales(w2_scale, "w2")

    return w13, w2, w13_scale, w2_scale