Skip to content

vllm.model_executor.models.gemma4_mm

Gemma 4 multimodal model (image + audio + video support).

Adds vision tower, audio tower, and multimodal embedders on top of the text-only Gemma4ForCausalLM. The vision/audio encoders are loaded via AutoModel.from_config and run in eager mode while the language model uses the vLLM-optimized path.

Video support: Gemma4 does not have a native video tower. Videos are decomposed into timestamped image frames (up to 32 frames at 70 soft tokens each) and fed through the same vision tower as regular images. The processor inserts mm:ss timestamps between frames so the model can reason about temporal order.

Gemma4AudioInputs

Bases: TensorSchema

Dimensions
  • bn: Batch size * number of audios
  • s: Sequence length (MEL spectrogram frames)
  • f: Number of features (MEL bins)
Source code in vllm/model_executor/models/gemma4_mm.py
class Gemma4AudioInputs(TensorSchema):
    """
    Audio-tower inputs produced by the Gemma4 feature extractor.

    Dimensions:
        - bn: Batch size * number of audios
        - s: Sequence length (MEL spectrogram frames)
        - f: Number of features (MEL bins)
    """

    # Discriminator tag used when dispatching multimodal inputs by kind.
    type: Literal["audio"] = "audio"
    # Zero-padded MEL spectrogram features, one row per audio clip.
    input_features_padded: Annotated[torch.Tensor, TensorShape("bn", "s", "f")]
    # Validity mask over spectrogram frames (True = real frame,
    # False = padding — standard HF convention).
    input_features_mask: Annotated[torch.Tensor, TensorShape("bn", "s")]

Gemma4ForConditionalGeneration

Bases: Module, SupportsMultiModal, SupportsPP

Source code in vllm/model_executor/models/gemma4_mm.py
 846
 847
 848
 849
 850
 851
 852
 853
 854
 855
 856
 857
 858
 859
 860
 861
 862
 863
 864
 865
 866
 867
 868
 869
 870
 871
 872
 873
 874
 875
 876
 877
 878
 879
 880
 881
 882
 883
 884
 885
 886
 887
 888
 889
 890
 891
 892
 893
 894
 895
 896
 897
 898
 899
 900
 901
 902
 903
 904
 905
 906
 907
 908
 909
 910
 911
 912
 913
 914
 915
 916
 917
 918
 919
 920
 921
 922
 923
 924
 925
 926
 927
 928
 929
 930
 931
 932
 933
 934
 935
 936
 937
 938
 939
 940
 941
 942
 943
 944
 945
 946
 947
 948
 949
 950
 951
 952
 953
 954
 955
 956
 957
 958
 959
 960
 961
 962
 963
 964
 965
 966
 967
 968
 969
 970
 971
 972
 973
 974
 975
 976
 977
 978
 979
 980
 981
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
@MULTIMODAL_REGISTRY.register_processor(
    Gemma4MultiModalProcessor,
    info=Gemma4ProcessingInfo,
    dummy_inputs=Gemma4DummyInputsBuilder,
)
class Gemma4ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
    """Gemma 4 multimodal model (image + audio + video support).

    Combines HF-eager vision/audio towers (loaded via
    ``AutoModel.from_config``) with a vLLM-optimized
    ``Gemma4ForCausalLM`` language model.  Videos have no dedicated
    tower: frames are routed through the vision tower like ordinary
    images and concatenated per video.
    """

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    # Maps checkpoint prefixes to vLLM module paths.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "model.embed_audio.": "embed_audio.",
            "model.embed_vision.": "embed_vision.",
            "model.language_model.": "language_model.model.",
            "model.vision_tower.": "vision_tower.",
            "model.audio_tower.": "audio_tower.",
            "lm_head.": "language_model.lm_head.",
            "model": "language_model.model",
        }
    )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.quant_config = quant_config
        self.multimodal_config = multimodal_config

        # ---- Vision tower (shared by image and video) ----
        with self._mark_tower_model(vllm_config, {"image", "video"}):
            self.vision_tower = AutoModel.from_config(config=config.vision_config)
            self.embed_vision = Gemma4MultimodalEmbedder(
                config.vision_config, config.text_config
            )

        # ---- Audio tower (variants with audio_config) ----
        if config.audio_config is not None:
            with self._mark_tower_model(vllm_config, "audio"):
                self.audio_tower = AutoModel.from_config(config=config.audio_config)
                # AutoModel.from_config does NOT call post_init(),
                # which is needed to initialize buffers that are absent
                # from the checkpoint (e.g. inv_timescales for relative
                # position embeddings, softcap, gradient_clipping).
                self.audio_tower.post_init()
                self.embed_audio = Gemma4MultimodalEmbedder(
                    config.audio_config, config.text_config
                )
        else:
            self.audio_tower = None
            self.embed_audio = None

        # ---- Language model (vLLM optimised) ----
        with self._mark_language_model(vllm_config):
            self.language_model: Gemma4ForCausalLM = init_vllm_registered_model(
                vllm_config=vllm_config,
                hf_config=config.text_config,
                prefix=maybe_prefix(prefix, "language_model"),
                architectures=["Gemma4ForCausalLM"],
            )

            # Pre-allocate PLE buffer for CUDA graph compatibility.
            # Some variants have hidden_size_per_layer_input=None (no PLE).
            ple_dim = config.text_config.hidden_size_per_layer_input
            if ple_dim is not None:
                self.per_layer_embeddings = torch.zeros(
                    vllm_config.scheduler_config.max_num_batched_tokens,
                    config.text_config.num_hidden_layers,
                    ple_dim,
                    device=(self.language_model.model.embed_tokens.weight.device),
                    dtype=(self.language_model.model.embed_tokens.weight.dtype),
                )
            else:
                self.per_layer_embeddings = None

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

        # --- MixtureOfExperts delegation to language_model ---
        self.expert_weights = self.language_model.expert_weights
        self.moe_layers = self.language_model.moe_layers
        self.num_moe_layers = self.language_model.num_moe_layers
        self.num_logical_experts = self.language_model.num_logical_experts
        self.num_physical_experts = self.language_model.num_physical_experts
        self.num_local_physical_experts = self.language_model.num_local_physical_experts
        self.num_routed_experts = self.language_model.num_routed_experts
        self.num_expert_groups = self.language_model.num_expert_groups
        self.num_shared_experts = self.language_model.num_shared_experts
        self.num_redundant_experts = self.language_model.num_redundant_experts

    # ------------------------------------------------------------------ #
    # Input parsing
    # ------------------------------------------------------------------ #

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> Gemma4ImageInputs | None:
        """Extract image inputs from processor kwargs, or None if absent.

        Raises:
            ValueError: if ``image_embeds`` are supplied — Gemma4 only
                accepts pre-patchified pixel values.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        pixel_position_ids = kwargs.pop("pixel_position_ids", None)
        image_embeds = kwargs.pop("image_embeds", None)
        # Raise explicitly instead of using `assert`: assertions are
        # stripped under `python -O`, and ValueError is the right type
        # for rejecting unsupported input.
        if image_embeds is not None:
            raise ValueError("Gemma4 does not support image_embeds.")
        if pixel_values is None:
            return None
        return Gemma4ImagePixelInputs(
            pixel_values=pixel_values,
            pixel_position_ids=pixel_position_ids,
        )

    def _parse_and_validate_audio_input(
        self, **kwargs: object
    ) -> Gemma4AudioInputs | None:
        """Extract audio inputs from processor kwargs, or None if absent."""
        input_features_padded = kwargs.pop("input_features_padded", None)
        if input_features_padded is None:
            return None
        input_features_mask = kwargs.pop("input_features_mask", None)
        if input_features_mask is None:
            return None
        return Gemma4AudioInputs(
            input_features_padded=input_features_padded,
            input_features_mask=input_features_mask,
        )

    def _parse_and_validate_video_input(
        self, **kwargs: object
    ) -> dict[str, torch.Tensor] | None:
        """Extract video inputs from processor kwargs, or None if absent."""
        pixel_values_videos = kwargs.pop("pixel_values_videos", None)
        pixel_position_ids_videos = kwargs.pop("pixel_position_ids_videos", None)
        video_frame_counts = kwargs.pop("video_frame_counts", None)
        if pixel_values_videos is None:
            return None
        return {
            "pixel_values_videos": pixel_values_videos,
            "pixel_position_ids_videos": pixel_position_ids_videos,
            "video_frame_counts": video_frame_counts,
        }

    def _parse_and_validate_multimodal_inputs(
        self, **kwargs: object
    ) -> dict[str, Gemma4ImageInputs | Gemma4AudioInputs | Gemma4VideoInputs | None]:
        """Group incoming kwargs into at most one parsed input per modality."""
        mm_input_by_modality = {}
        for input_key in list(kwargs):
            if (
                input_key in ("pixel_values", "image_embeds")
                and "image" not in mm_input_by_modality
            ):
                mm_input_by_modality["image"] = self._parse_and_validate_image_input(
                    **kwargs
                )
            if (
                input_key == "pixel_values_videos"
                and "video" not in mm_input_by_modality
            ):
                mm_input_by_modality["video"] = self._parse_and_validate_video_input(
                    **kwargs
                )
            if (
                input_key == "input_features_padded"
                and "audio" not in mm_input_by_modality
            ):
                mm_input_by_modality["audio"] = self._parse_and_validate_audio_input(
                    **kwargs
                )
        return mm_input_by_modality

    # ------------------------------------------------------------------ #
    # Image processing
    # ------------------------------------------------------------------ #

    def _process_image_input(
        self,
        image_input: Gemma4ImageInputs,
    ) -> list[torch.Tensor]:
        """Encode pre-patchified images and project them into LM space.

        Returns one variable-length embedding tensor per image.
        """
        pixel_values = image_input["pixel_values"]
        pixel_position_ids = image_input["pixel_position_ids"]

        # The HF image processor now outputs pre-patchified data:
        #   pixel_values:       (num_images, max_patches, patch_pixels)
        #   pixel_position_ids: (num_images, max_patches, 2)
        # We call the vision tower's forward() directly, which handles
        # patch embedding, encoding, pooling, padding removal, and
        # optional standardization internally.
        vt = self.vision_tower
        pooling_k2 = self.config.vision_config.pooling_kernel_size**2

        # TODO: Move this per-image loop into the input processor to
        # reduce dynamism at the model runner / engine core. This
        # requires spatially padding all images to uniform (H_max,
        # W_max) in _call_hf_processor() so they arrive as a single
        # stacked tensor, tracking padded regions via image_sizes
        # metadata, and validating numerical equivalence with the
        # current per-image path.
        #
        # Process each image individually through the vision tower.
        # The vision tower's forward() strips padding and returns a
        # flat tensor of valid tokens. We process per-image to get
        # variable-length outputs matching the dynamic token count
        # from get_image_repl.
        per_image_features = []
        for i in range(pixel_values.shape[0]):
            pv = pixel_values[i].unsqueeze(0)  # (1, max_patches, patch_pixels)
            pp = pixel_position_ids[i].unsqueeze(0)  # (1, max_patches, 2)

            # Derive the pooler's output_length from the total patch
            # count (including padding).  The vision tower encoder
            # processes ALL patches — padding patches get zero hidden
            # states but still occupy sequence positions.  The pooler's
            # _avg_pool_by_positions requires:
            #     input_seq_len / output_length == k²
            # where k == pooling_kernel_size.  The image processor
            # allocates max_patches = max_soft_tokens * k² total slots,
            # so output_length = max_patches / k² == max_soft_tokens.
            # Without this, the pooler falls back to
            # config.image_seq_length (e.g. 280), which fails when a
            # different max_soft_tokens was used at preprocessing time.
            max_patches = pv.shape[1]
            output_length = max_patches // pooling_k2

            vt_output = vt(pv, pp, output_length=output_length)
            # last_hidden_state: (num_valid_tokens, hidden_size)
            # — already flat with padding stripped by the vision tower
            per_image_features.append(vt_output.last_hidden_state)

        # Project each image's features into LM embedding space.
        # Per-image loop is required because images have variable
        # token counts after padding removal.
        # Cast to match the projection layer's dtype (model may be
        # bf16 while the vision tower outputs fp32).
        target_dtype = self.embed_vision.embedding_projection.weight.dtype
        return [
            self.embed_vision(inputs_embeds=img.unsqueeze(0).to(target_dtype)).squeeze(
                0
            )
            for img in per_image_features
        ]

    # ------------------------------------------------------------------ #
    # Video processing (frames through vision tower)
    # ------------------------------------------------------------------ #

    def _process_video_input(
        self,
        video_input: dict[str, torch.Tensor],
    ) -> list[torch.Tensor]:
        """Process video frames through the vision tower.

        Reuses the image processing pipeline — Gemma4 has no separate
        video tower; video frames are just images at lower resolution
        (max_soft_tokens=70).

        Returns one concatenated embedding tensor per video (not per
        frame), because vLLM treats one video as one multimodal item.
        The flat_from_sizes field config groups all frames of a video
        together, so embed_multimodal must return one tensor per video.
        """
        pixel_values = video_input["pixel_values_videos"]
        pixel_position_ids = video_input["pixel_position_ids_videos"]
        frame_counts = video_input["video_frame_counts"]

        vt = self.vision_tower
        pooling_k2 = self.config.vision_config.pooling_kernel_size**2
        target_dtype = self.embed_vision.embedding_projection.weight.dtype

        # Split flat tensors into per-video chunks
        if isinstance(frame_counts, torch.Tensor):
            fc_list = frame_counts.tolist()
        else:
            fc_list = list(frame_counts)

        pv_per_video = torch.split(pixel_values, fc_list, dim=0)
        pp_per_video = torch.split(pixel_position_ids, fc_list, dim=0)

        per_video_embeddings = []
        for pv_chunk, pp_chunk in zip(pv_per_video, pp_per_video):
            frame_embs = []
            for i in range(pv_chunk.shape[0]):
                pv = pv_chunk[i].unsqueeze(0)
                pp = pp_chunk[i].unsqueeze(0)

                # Same output_length derivation as _process_image_input.
                max_patches = pv.shape[1]
                output_length = max_patches // pooling_k2

                vt_output = vt(pv, pp, output_length=output_length)
                frame_emb = self.embed_vision(
                    inputs_embeds=(
                        vt_output.last_hidden_state.unsqueeze(0).to(target_dtype)
                    )
                ).squeeze(0)
                frame_embs.append(frame_emb)

            # Concatenate all frames of this video into one tensor.
            per_video_embeddings.append(torch.cat(frame_embs, dim=0))

        return per_video_embeddings

    # ------------------------------------------------------------------ #
    # Audio processing
    # ------------------------------------------------------------------ #

    def _process_audio_input(
        self,
        audio_input: Gemma4AudioInputs,
    ) -> list[torch.Tensor]:
        """Encode audio features and return one tensor per audio clip,
        with padding positions removed."""
        input_features = audio_input["input_features_padded"].squeeze(1)
        input_features_mask = audio_input["input_features_mask"].squeeze(1)

        # Run audio tower — mask uses standard HF convention
        # (True=valid, False=padding).
        audio_outputs = self.audio_tower(input_features, input_features_mask)
        if isinstance(audio_outputs, tuple):
            audio_encodings, audio_mask = audio_outputs
        else:
            audio_encodings = audio_outputs.last_hidden_state
            audio_mask = audio_outputs.attention_mask

        # Project into LM embedding space.
        audio_features = self.embed_audio(inputs_embeds=audio_encodings)

        # Strip padding per-batch element: only keep real (non-padding)
        # tokens. audio_mask is True for valid positions (HF convention).
        per_audio = []
        for enc, mask in zip(audio_features, audio_mask, strict=True):
            per_audio.append(enc[mask])  # [num_real, hidden_size]

        return per_audio

    # ------------------------------------------------------------------ #
    # MultiModalEmbeddings interface
    # ------------------------------------------------------------------ #

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Encode all multimodal inputs into a flat list of embedding
        tensors (one per multimodal item, in modality-grouped order)."""
        mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)
        multimodal_embeddings: list[torch.Tensor] = []

        for modality, multimodal_input in mm_input_by_modality.items():
            if multimodal_input is None:
                continue
            if modality == "image":
                multimodal_embeddings.extend(
                    self._process_image_input(multimodal_input)
                )
            elif modality == "video":
                multimodal_embeddings.extend(
                    self._process_video_input(multimodal_input)
                )
            elif modality == "audio":
                multimodal_embeddings.extend(
                    self._process_audio_input(multimodal_input)
                )

        return multimodal_embeddings

    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # Cache per-layer embeddings (PLE) for the language model's
        # forward pass.  During profiling embed_input_ids is not called,
        # so the pre-allocated zeros are used instead.
        if self.per_layer_embeddings is not None:
            # Mask multimodal tokens (image/audio) to 0 for PLE
            # computation (using token_type_ids == 0 as text_mask).
            # Replicate this: map image token positions to token 0.
            if is_multimodal is not None:
                is_multimodal = is_multimodal.to(input_ids.device)
                ple_input_ids = torch.where(
                    is_multimodal, torch.zeros_like(input_ids), input_ids
                )
            else:
                ple_input_ids = input_ids

            per_layer_inputs = self.language_model.model.get_per_layer_inputs(
                ple_input_ids
            )
            if per_layer_inputs is not None:
                per_layer_inputs = per_layer_inputs.reshape(
                    -1,
                    self.config.text_config.num_hidden_layers,
                    self.config.text_config.hidden_size_per_layer_input,
                )
                self.per_layer_embeddings[: per_layer_inputs.shape[0]].copy_(
                    per_layer_inputs
                )

        if multimodal_embeddings is None or is_multimodal is None:
            return super().embed_input_ids(input_ids)

        return super().embed_input_ids(
            input_ids,
            multimodal_embeddings=multimodal_embeddings,
            is_multimodal=is_multimodal,
        )

    # ------------------------------------------------------------------ #
    # Forward
    # ------------------------------------------------------------------ #

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> IntermediateTensors:
        if intermediate_tensors is not None:
            inputs_embeds = None

        # Select the pre-cached PLEs for this batch (None when PLE
        # is disabled for variants without PLE).
        per_layer_inputs = (
            self.per_layer_embeddings[: inputs_embeds.shape[0]]
            if self.per_layer_embeddings is not None and inputs_embeds is not None
            else None
        )

        hidden_states = self.language_model.model(
            input_ids,
            positions,
            per_layer_inputs=per_layer_inputs,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
            **kwargs,
        )

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Delegate logits computation to the language model."""
        return self.language_model.compute_logits(hidden_states)

    # ------------------------------------------------------------------ #
    # Weight loading
    # ------------------------------------------------------------------ #

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, skipping vestigial / inapplicable
        tensors, and return the set of loaded parameter names."""
        # Some checkpoints have vestigial embed_vision.embedding and
        # embed_audio.embedding weights from the Gemma3n architecture
        # that are not used by Gemma4's MultimodalEmbedder (which only
        # has embedding_projection + embedding_post_projection_norm).
        ignore_prefixes = [
            "embed_vision.embedding.",
            "embed_audio.embedding.",
        ]
        # Models without audio tower should skip
        # audio weights entirely.
        if self.audio_tower is None:
            ignore_prefixes.extend(
                [
                    "audio_tower.",
                    "embed_audio.",
                ]
            )
        loader = AutoWeightsLoader(
            self,
            ignore_unexpected_prefixes=ignore_prefixes,
        )
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

    # ------------------------------------------------------------------ #
    # LoRA / multimodal mapping
    # ------------------------------------------------------------------ #

    def get_mm_mapping(self) -> MultiModelKeys:
        """Get the module prefix mapping for multimodal models."""
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector=["embed_vision", "embed_audio"],
            tower_model=["vision_tower", "audio_tower"],
        )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        if modality == "image":
            return "<image_soft_token>"
        if modality == "audio":
            return "<audio_soft_token>"
        if modality == "video":
            return "<|video|>"
        raise ValueError(f"Unsupported modality: {modality}")

_process_video_input

_process_video_input(
    video_input: dict[str, Tensor],
) -> list[Tensor]

Process video frames through the vision tower.

Reuses the image processing pipeline — Gemma4 has no separate video tower; video frames are just images at lower resolution (max_soft_tokens=70).

Returns one concatenated embedding tensor per video (not per frame), because vLLM treats one video as one multimodal item. The flat_from_sizes field config groups all frames of a video together, so embed_multimodal must return one tensor per video.

Source code in vllm/model_executor/models/gemma4_mm.py
def _process_video_input(
    self,
    video_input: dict[str, torch.Tensor],
) -> list[torch.Tensor]:
    """Encode video frames with the shared vision tower.

    Gemma4 has no dedicated video encoder: every frame is treated as a
    low-resolution image (max_soft_tokens=70) and pushed through the
    same vision tower used for still images.

    One tensor is returned per video (all of its frame embeddings
    concatenated), because vLLM models a whole video as a single
    multimodal item and the flat_from_sizes field config groups a
    video's frames together — embed_multimodal must therefore yield
    exactly one tensor per video.
    """
    flat_pixels = video_input["pixel_values_videos"]
    flat_positions = video_input["pixel_position_ids_videos"]
    raw_counts = video_input["video_frame_counts"]

    tower = self.vision_tower
    kernel_sq = self.config.vision_config.pooling_kernel_size**2
    proj_dtype = self.embed_vision.embedding_projection.weight.dtype

    # Normalise frame counts to a plain Python list for torch.split.
    counts = (
        raw_counts.tolist()
        if isinstance(raw_counts, torch.Tensor)
        else list(raw_counts)
    )

    # Carve the flat frame tensors into one chunk per video.
    per_video_pairs = zip(
        torch.split(flat_pixels, counts, dim=0),
        torch.split(flat_positions, counts, dim=0),
    )

    results: list[torch.Tensor] = []
    for pixels_chunk, positions_chunk in per_video_pairs:
        frame_embeddings = []
        for frame_pixels, frame_positions in zip(pixels_chunk, positions_chunk):
            pv = frame_pixels.unsqueeze(0)
            pp = frame_positions.unsqueeze(0)

            # output_length = total patch slots / k², matching the
            # pooler's expectation (see _process_image_input).
            out_len = pv.shape[1] // kernel_sq

            tower_out = tower(pv, pp, output_length=out_len)
            hidden = tower_out.last_hidden_state.unsqueeze(0).to(proj_dtype)
            frame_embeddings.append(
                self.embed_vision(inputs_embeds=hidden).squeeze(0)
            )

        # One tensor per video: concatenate its frames along dim 0.
        results.append(torch.cat(frame_embeddings, dim=0))

    return results

get_mm_mapping

get_mm_mapping() -> MultiModelKeys

Get the module prefix mapping for multimodal models.

Source code in vllm/model_executor/models/gemma4_mm.py
def get_mm_mapping(self) -> MultiModelKeys:
    """Return the module-prefix mapping (language model, connector
    embedders, tower encoders) for multimodal models."""
    prefixes = {
        "language_model": "language_model",
        "connector": ["embed_vision", "embed_audio"],
        "tower_model": ["vision_tower", "audio_tower"],
    }
    return MultiModelKeys.from_string_field(**prefixes)

Gemma4ImagePixelInputs

Bases: TensorSchema

Pre-patchified image inputs from the Gemma4 image processor.

Dimensions
  • bn: Batch size * number of images
  • np: Number of patches (max_patches = max_soft_tokens * pooling_kernel_size²)
  • pp: Patch pixels (patch_size² * 3)

The HF Gemma4ImageProcessor outputs pixel_values as (batch, max_patches, patch_pixels) — already patchified with zero-padding for patches beyond the real image content. pixel_position_ids provides (x, y) coordinates per patch, with (-1, -1) for padding patches.

Source code in vllm/model_executor/models/gemma4_mm.py
class Gemma4ImagePixelInputs(TensorSchema):
    """
    Pre-patchified image inputs from the Gemma4 image processor.

    Dimensions:
        - bn: Batch size * number of images
        - np: Number of patches (max_patches = max_soft_tokens * pooling_kernel_size²)
        - pp: Patch pixels (patch_size² * 3)

    The HF Gemma4ImageProcessor outputs pixel_values as
    (batch, max_patches, patch_pixels) — already patchified with
    zero-padding for patches beyond the real image content.
    pixel_position_ids provides (x, y) coordinates per patch,
    with (-1, -1) for padding patches.
    """

    # Discriminator tag used when dispatching multimodal inputs by kind.
    type: Literal["pixel_values"] = "pixel_values"
    # Flattened patch pixels; padding patches are zero-filled.
    pixel_values: Annotated[
        torch.Tensor,
        TensorShape("bn", "np", "pp"),
    ]
    # (x, y) grid coordinate per patch; (-1, -1) marks padding patches.
    pixel_position_ids: Annotated[
        torch.Tensor,
        TensorShape("bn", "np", 2),
    ]

Gemma4MultimodalEmbedder

Bases: Module

Projects vision/audio soft tokens into LM embedding space.

Architecture

inputs_embeds → embedding_projection → embedding_post_projection_norm

Unlike Gemma3n which has separate hard/soft embedding paths with per-path normalization and a learned embedding table, Gemma4 uses a simplified 2-layer design: a linear projection followed by RMSNorm (without learnable scale). The checkpoint confirms this — only embedding_projection.weight exists; there is no embedding table or pre-projection norm weights.

Source code in vllm/model_executor/models/gemma4_mm.py
class Gemma4MultimodalEmbedder(nn.Module):
    """Maps vision/audio soft tokens into the LM embedding space.

    Architecture:
        inputs_embeds → embedding_projection → embedding_post_projection_norm

    In contrast to Gemma3n (separate hard/soft embedding paths with
    per-path normalization plus a learned embedding table), Gemma4 keeps
    a minimal two-stage design: one linear projection followed by a
    scale-free RMSNorm.  Checkpoints confirm this layout — only
    ``embedding_projection.weight`` is present; there is no embedding
    table and no pre-projection norm weights.
    """

    def __init__(
        self,
        multimodal_config: Gemma4VisionConfig | Gemma4AudioConfig,
        text_config: Gemma4TextConfig,
    ):
        super().__init__()

        self.eps = multimodal_config.rms_norm_eps
        self.text_hidden_size = text_config.hidden_size

        # The audio tower projects out to output_proj_dims (1536) rather
        # than its hidden_size (1024); the vision tower has no such
        # attribute and uses hidden_size (768) directly.
        in_features = (
            getattr(multimodal_config, "output_proj_dims", None)
            or multimodal_config.hidden_size
        )

        self.embedding_projection = ReplicatedLinear(
            in_features,
            self.text_hidden_size,
            bias=False,
        )
        self.embedding_post_projection_norm = RMSNorm(
            self.text_hidden_size,
            eps=self.eps,
            has_weight=False,
        )

    def forward(self, inputs_embeds: torch.Tensor) -> torch.Tensor:
        """Project soft tokens from a multimodal tower into LM space."""
        projected, _ = self.embedding_projection(inputs_embeds)
        return self.embedding_post_projection_norm(projected)

forward

forward(inputs_embeds: Tensor) -> Tensor

Project soft tokens from a multimodal tower into LM space.

Source code in vllm/model_executor/models/gemma4_mm.py
def forward(self, inputs_embeds: torch.Tensor) -> torch.Tensor:
    """Project soft tokens from a multimodal tower into LM space."""
    # ReplicatedLinear returns (output, bias); the bias slot is unused.
    projected, _unused_bias = self.embedding_projection(inputs_embeds)
    normed = self.embedding_post_projection_norm(projected)
    return normed

Gemma4ProcessingInfo

Bases: BaseProcessingInfo

Source code in vllm/model_executor/models/gemma4_mm.py
class Gemma4ProcessingInfo(BaseProcessingInfo):
    """Multimodal processing metadata for Gemma4 (image/audio/video)."""

    def get_hf_config(self):
        """Return the HF ``Gemma4Config`` for this model."""
        return self.ctx.get_hf_config(Gemma4Config)

    def get_default_tok_params(self):
        """Gemma4's chat template already embeds a literal ``<bos>`` token in
        the rendered text.  If ``add_special_tokens=True`` (the base-class
        default), the tokenizer prepends *another* BOS, producing a
        ``[2, 2, ...]`` double-BOS sequence that the model was not trained on.

        Setting ``add_special_tokens=False`` here prevents the duplicate and
        ensures both ``llm.generate()`` and the chat/completions API behave
        correctly.
        """
        params = super().get_default_tok_params()
        params = params.with_kwargs(add_special_tokens=False)
        return params

    def get_hf_processor(self, **kwargs: object) -> Gemma4Processor:
        """Return the HF ``Gemma4Processor``, forwarding any overrides."""
        return self.ctx.get_hf_processor(
            Gemma4Processor,
            **kwargs,
        )

    def validate_num_items(self, modality: str, num_items: int) -> None:
        """Reject audio inputs for checkpoints that lack an audio tower."""
        if (
            modality == "audio"
            and num_items > 0
            and self.get_hf_config().audio_config is None
        ):
            model = self.ctx.model_config.model
            raise ValueError(
                f"Audio input was provided but the model "
                f"'{model}' does not have an audio tower. "
                f"Audio inference is only supported for Gemma4 "
                f"models that include an audio_config."
            )
        super().validate_num_items(modality, num_items)

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Unbounded image/video counts; audio only when a tower exists."""
        limits: dict[str, int | None] = {"image": None}
        if self.get_hf_config().audio_config is not None:
            limits["audio"] = None
        limits["video"] = None
        return limits

    def get_mm_max_tokens_per_item(
        self, seq_len: int, mm_counts: Mapping[str, int]
    ) -> Mapping[str, int] | None:
        """Worst-case soft-token counts per item, for memory planning."""
        config = self.get_hf_config()
        # Upper bound: the pooler outputs default_output_length slots
        # per image (280).  After padding is stripped the actual count
        # is ≤ this value, but vLLM needs the max for memory planning.
        tokens_per_image = config.vision_config.default_output_length
        tokens: dict[str, int] = {"image": tokens_per_image}
        if config.audio_config is not None:
            # Audio max tokens from the processor's audio_seq_length.
            processor = self.get_hf_processor()
            tokens["audio"] = processor.audio_seq_length
        # Video: each frame ≤ 70 soft tokens + boi + eoi + ~6 ts tokens.
        tokens["video"] = _VIDEO_MAX_FRAMES * (_VIDEO_MAX_SOFT_TOKENS + 2 + 6)
        return tokens

    def get_data_parser(self) -> MultiModalDataParser:
        """Build the data parser; audio resampling only when a tower exists."""
        config = self.get_hf_config()
        kwargs: dict[str, Any] = {"video_needs_metadata": True}
        if getattr(config, "audio_config", None) is not None:
            processor = self.get_hf_processor()
            kwargs["target_sr"] = processor.feature_extractor.sampling_rate
        return MultiModalDataParser(**kwargs)

    def _compute_num_soft_tokens(
        self,
        image_width: int,
        image_height: int,
        max_soft_tokens: int | None = None,
    ) -> int:
        """Compute the number of soft tokens the vision tower produces
        for an image of the given dimensions, after padding is stripped.

        Args:
            max_soft_tokens: Override for the vision config's
                ``default_output_length``.  When *None*, the value from
                the model config is used.
        """
        vision_cfg = self.get_hf_config().vision_config
        patch_size = vision_cfg.patch_size
        pooling_kernel_size = vision_cfg.pooling_kernel_size

        if max_soft_tokens is None:
            max_soft_tokens = vision_cfg.default_output_length

        # Rescale the image so its patch count fits the token budget,
        # snapping each side down to a multiple of patch*pool (never
        # below one unit).  NOTE(review): assumes non-zero dimensions;
        # a 0-sized image would divide by zero below.
        unit = patch_size * pooling_kernel_size
        max_patches = max_soft_tokens * pooling_kernel_size**2
        num_patches_orig = (image_height / patch_size) * (image_width / patch_size)
        scale = math.sqrt(max_patches / num_patches_orig)
        target_h = max(unit, int(math.floor(image_height * scale / unit)) * unit)
        target_w = max(unit, int(math.floor(image_width * scale / unit)) * unit)
        num_patches = (target_h // patch_size) * (target_w // patch_size)
        return num_patches // (pooling_kernel_size**2)

    def get_image_repl(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: Gemma4Processor | None,
        max_soft_tokens: int | None = None,
    ) -> PromptUpdateDetails[list[int]]:
        """Return the dynamic image token sequence for this image.

        Computes the exact number of soft tokens the vision tower will
        produce after stripping padding.

        Args:
            max_soft_tokens: Override for the default token budget.
                When *None*, falls back to the model config value.
        """
        if processor is None:
            processor = self.get_hf_processor()

        num_soft = self._compute_num_soft_tokens(
            image_width,
            image_height,
            max_soft_tokens=max_soft_tokens,
        )
        config = self.get_hf_config()
        token_ids = (
            [config.boi_token_id]
            + [processor.image_token_id] * num_soft
            + [config.eoi_token_id]
        )
        return PromptUpdateDetails.select_token_id(token_ids, processor.image_token_id)

    def get_audio_repl(
        self,
        *,
        audio_len: int,
        processor: Gemma4Processor | None,
    ) -> PromptUpdateDetails[list[int]]:
        """Return the dynamic audio token sequence for this audio.

        Computes the number of soft tokens from the audio waveform
        length using ``ceil(duration_ms / audio_ms_per_token)``.
        """
        if processor is None:
            processor = self.get_hf_processor()

        sampling_rate = processor.feature_extractor.sampling_rate
        num_tokens = processor._compute_audio_num_tokens(
            torch.zeros(audio_len), sampling_rate
        )
        config = self.get_hf_config()
        token_ids = (
            [config.boa_token_id]
            + [processor.audio_token_id] * num_tokens
            + [config.eoa_token_id]
        )
        return PromptUpdateDetails.select_token_id(token_ids, processor.audio_token_id)

    def get_video_repl(
        self,
        *,
        timestamps: list[float],
        num_soft_tokens_per_frame: list[int],
        processor: Gemma4Processor,
    ) -> PromptUpdateDetails[list[int]]:
        """Build the full token replacement for one video.

        Produces the same interleaved sequence as the HF Gemma4Processor:
            mm:ss <boi><|video|>*N<eoi> mm:ss <boi><|video|>*N<eoi> ...
        """
        tokenizer = self.ctx.get_tokenizer()
        config = self.get_hf_config()

        boi_token_id = config.boi_token_id
        eoi_token_id = config.eoi_token_id
        video_token_id = processor.video_token_id

        all_token_ids: list[int] = []
        # strict=True: a length mismatch here indicates a processor bug;
        # fail loudly rather than silently dropping frames.
        for i, (ts, n_tokens) in enumerate(
            zip(timestamps, num_soft_tokens_per_frame, strict=True)
        ):
            # mm:ss timestamp — matches transformers: int-truncated,
            # zero-padded.
            minutes = int(ts // 60)
            seconds = int(ts % 60)
            ts_str = f"{minutes:02d}:{seconds:02d}"

            prefix = f" {ts_str} " if i > 0 else f"{ts_str} "
            ts_token_ids = tokenizer.encode(prefix, add_special_tokens=False)
            all_token_ids.extend(ts_token_ids)

            all_token_ids.append(boi_token_id)
            all_token_ids.extend([video_token_id] * n_tokens)
            all_token_ids.append(eoi_token_id)

        return PromptUpdateDetails.select_token_id(all_token_ids, video_token_id)

_compute_num_soft_tokens

_compute_num_soft_tokens(
    image_width: int,
    image_height: int,
    max_soft_tokens: int | None = None,
) -> int

Compute the number of soft tokens the vision tower produces for an image of the given dimensions, after padding is stripped.

Parameters:

Name Type Description Default
max_soft_tokens int | None

Override for the vision config's default_output_length. When None, the value from the model config is used.

None
Source code in vllm/model_executor/models/gemma4_mm.py
def _compute_num_soft_tokens(
    self,
    image_width: int,
    image_height: int,
    max_soft_tokens: int | None = None,
) -> int:
    """Compute the number of soft tokens the vision tower produces
    for an image of the given dimensions, after padding is stripped.

    Args:
        max_soft_tokens: Override for the vision config's
            ``default_output_length``.  When *None*, the value from
            the model config is used.
    """
    vision_cfg = self.get_hf_config().vision_config
    patch_size = vision_cfg.patch_size
    pooling_kernel_size = vision_cfg.pooling_kernel_size

    if max_soft_tokens is None:
        max_soft_tokens = vision_cfg.default_output_length

    unit = patch_size * pooling_kernel_size
    max_patches = max_soft_tokens * pooling_kernel_size**2
    num_patches_orig = (image_height / patch_size) * (image_width / patch_size)
    scale = math.sqrt(max_patches / num_patches_orig)
    target_h = max(unit, int(math.floor(image_height * scale / unit)) * unit)
    target_w = max(unit, int(math.floor(image_width * scale / unit)) * unit)
    num_patches = (target_h // patch_size) * (target_w // patch_size)
    return num_patches // (pooling_kernel_size**2)

get_audio_repl

get_audio_repl(
    *, audio_len: int, processor: Gemma4Processor | None
) -> PromptUpdateDetails[list[int]]

Return the dynamic audio token sequence for this audio.

Computes the number of soft tokens from the audio waveform length using ceil(duration_ms / audio_ms_per_token).

Source code in vllm/model_executor/models/gemma4_mm.py
def get_audio_repl(
    self,
    *,
    audio_len: int,
    processor: Gemma4Processor | None,
) -> PromptUpdateDetails[list[int]]:
    """Return the dynamic audio token sequence for this audio clip.

    The soft-token count is delegated to the processor, which derives it
    from the waveform length (``ceil(duration_ms / audio_ms_per_token)``).
    """
    proc = self.get_hf_processor() if processor is None else processor

    # A zero waveform of the right length is enough: only the length
    # matters for the token count.
    sr = proc.feature_extractor.sampling_rate
    n_soft = proc._compute_audio_num_tokens(torch.zeros(audio_len), sr)

    cfg = self.get_hf_config()
    ids = [cfg.boa_token_id]
    ids.extend([proc.audio_token_id] * n_soft)
    ids.append(cfg.eoa_token_id)
    return PromptUpdateDetails.select_token_id(ids, proc.audio_token_id)

get_default_tok_params

get_default_tok_params()

Gemma4's chat template already embeds a literal <bos> token in the rendered text. If add_special_tokens=True (the base-class default), the tokenizer prepends another BOS, producing a [2, 2, ...] double-BOS sequence that the model was not trained on.

Setting add_special_tokens=False here prevents the duplicate and ensures both llm.generate() and the chat/completions API behave correctly.

Source code in vllm/model_executor/models/gemma4_mm.py
def get_default_tok_params(self):
    """Disable automatic BOS insertion for Gemma4 prompts.

    The chat template renders a literal ``<bos>`` into the prompt text,
    so leaving the base-class default of ``add_special_tokens=True`` in
    place would make the tokenizer prepend a second BOS — a
    ``[2, 2, ...]`` prefix the model never saw in training.  Forcing
    ``add_special_tokens=False`` keeps ``llm.generate()`` and the
    chat/completions API consistent.
    """
    base_params = super().get_default_tok_params()
    return base_params.with_kwargs(add_special_tokens=False)

get_image_repl

get_image_repl(
    *,
    image_width: int,
    image_height: int,
    processor: Gemma4Processor | None,
    max_soft_tokens: int | None = None,
) -> PromptUpdateDetails[list[int]]

Return the dynamic image token sequence for this image.

Computes the exact number of soft tokens the vision tower will produce after stripping padding.

Parameters:

Name Type Description Default
max_soft_tokens int | None

Override for the default token budget. When None, falls back to the model config value.

None
Source code in vllm/model_executor/models/gemma4_mm.py
def get_image_repl(
    self,
    *,
    image_width: int,
    image_height: int,
    processor: Gemma4Processor | None,
    max_soft_tokens: int | None = None,
) -> PromptUpdateDetails[list[int]]:
    """Return the dynamic image token sequence for this image.

    The soft-token count is the exact number the vision tower will emit
    for these dimensions after padding removal.

    Args:
        max_soft_tokens: Token budget override; *None* falls back to the
            model config value.
    """
    if processor is None:
        processor = self.get_hf_processor()

    soft_count = self._compute_num_soft_tokens(
        image_width,
        image_height,
        max_soft_tokens=max_soft_tokens,
    )
    cfg = self.get_hf_config()

    ids = [cfg.boi_token_id]
    ids.extend([processor.image_token_id] * soft_count)
    ids.append(cfg.eoi_token_id)
    return PromptUpdateDetails.select_token_id(ids, processor.image_token_id)

get_video_repl

get_video_repl(
    *,
    timestamps: list[float],
    num_soft_tokens_per_frame: list[int],
    processor: Gemma4Processor,
) -> PromptUpdateDetails[list[int]]

Build the full token replacement for one video.

Produces the same interleaved sequence as the HF Gemma4Processor:

mm:ss <boi><|video|>*N<eoi> mm:ss <boi><|video|>*N<eoi> ...

Source code in vllm/model_executor/models/gemma4_mm.py
def get_video_repl(
    self,
    *,
    timestamps: list[float],
    num_soft_tokens_per_frame: list[int],
    processor: Gemma4Processor,
) -> PromptUpdateDetails[list[int]]:
    """Assemble the token replacement for an entire video.

    Mirrors the HF Gemma4Processor output exactly:
        mm:ss <boi><|video|>*N<eoi> mm:ss <boi><|video|>*N<eoi> ...
    """
    tokenizer = self.ctx.get_tokenizer()
    cfg = self.get_hf_config()
    vid_id = processor.video_token_id

    ids: list[int] = []
    first_frame = True
    for ts, frame_tokens in zip(timestamps, num_soft_tokens_per_frame):
        # Timestamp as zero-padded mm:ss, truncating fractional seconds
        # (same arithmetic as transformers).
        mins = int(ts // 60)
        secs = int(ts % 60)
        stamp = f"{mins:02d}:{secs:02d}"

        # Frames after the first get a leading space separator.
        prefix = f"{stamp} " if first_frame else f" {stamp} "
        first_frame = False
        ids.extend(tokenizer.encode(prefix, add_special_tokens=False))

        ids.append(cfg.boi_token_id)
        ids.extend([vid_id] * frame_tokens)
        ids.append(cfg.eoi_token_id)

    return PromptUpdateDetails.select_token_id(ids, vid_id)

Gemma4VideoInputs

Bases: TensorSchema

Video frame inputs — same tensor format as image inputs.

Gemma4 has no separate video tower; video frames are processed through the vision tower at lower resolution (max_soft_tokens=70).

Source code in vllm/model_executor/models/gemma4_mm.py
class Gemma4VideoInputs(TensorSchema):
    """Video frame inputs — same tensor format as image inputs.

    Gemma4 has no separate video tower; video frames are processed
    through the vision tower at lower resolution (max_soft_tokens=70).
    """

    # Discriminator distinguishing video inputs from image/audio inputs.
    type: Literal["pixel_values_videos"] = "pixel_values_videos"
    # Flattened frame patches.  NOTE(review): by analogy with the image
    # inputs, bn = batch * num frames, np = patches per frame, pp = pixel
    # values per patch — confirm against the image input schema.
    pixel_values_videos: Annotated[
        torch.Tensor,
        TensorShape("bn", "np", 2),
    ]
    # Per-patch position ids; trailing dim of 2 presumably holds
    # (row, col) coordinates — TODO confirm.
    pixel_position_ids_videos: Annotated[
        torch.Tensor,
        TensorShape("bn", "np", 2),
    ]