@support_torch_compile
class Gemma4Model(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = _get_text_config(vllm_config.model_config.hf_config)
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
        # PLE config values. hidden_size_per_layer_input defaults to 0 when
        # absent, which disables PLE; vocab_size_per_layer_input falls back
        # to the main vocab size.
self.hidden_size_per_layer_input = getattr(
config, "hidden_size_per_layer_input", 0
)
self.vocab_size_per_layer_input = getattr(
config, "vocab_size_per_layer_input", config.vocab_size
)
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.embed_tokens",
)
# Per-Layer Embedding (PLE) components
if (
self.hidden_size_per_layer_input is not None
and self.hidden_size_per_layer_input > 0
):
total_ple_dim = self.hidden_size_per_layer_input * config.num_hidden_layers
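            # Example with hypothetical sizes: hidden_size_per_layer_input=256
            # and num_hidden_layers=30 give total_ple_dim = 256 * 30 = 7680,
            # i.e. one 256-dim slice per decoder layer packed into a single
            # embedding row.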
self.embed_tokens_per_layer = VocabParallelEmbedding(
self.vocab_size_per_layer_input,
total_ple_dim,
quant_config=quant_config,
prefix=f"{prefix}.embed_tokens_per_layer",
)
            # Per-layer embedding scale: sqrt(hidden_size_per_layer_input),
            # derived from the config.
# Register as buffer so it moves to GPU with the model
# and interacts correctly with torch.compile AOT caching.
self.register_buffer(
"embed_scale_per_layer",
torch.tensor(self.hidden_size_per_layer_input**0.5),
persistent=False,
)
# Projection: hidden_size → total_ple_dim
# ColumnParallelLinear with gather_output=True
self.per_layer_model_projection = ColumnParallelLinear(
config.hidden_size,
total_ple_dim,
bias=False,
gather_output=True,
return_bias=False,
quant_config=quant_config,
prefix=f"{prefix}.per_layer_model_projection",
)
# PLE projection norm: output = norm(x) * weight
self.per_layer_projection_norm = RMSNorm(
self.hidden_size_per_layer_input,
eps=config.rms_norm_eps,
)
# Scale factor for combining projection + per_layer_inputs
# Register as buffer so it moves to GPU with the model
# and interacts correctly with torch.compile AOT caching.
self.register_buffer(
"per_layer_input_scale",
torch.rsqrt(torch.tensor(2.0)),
persistent=False,
)
# Scaled projection: multiply output by hidden_size**-0.5.
# Register as buffer for GPU placement and torch.compile.
self.register_buffer(
"per_layer_projection_scale",
torch.tensor(config.hidden_size**-0.5),
persistent=False,
)
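            # Net effect of the PLE input path (see project_per_layer_inputs):
            #   proj = rms_norm(per_layer_model_projection(x) * hidden_size**-0.5)
            #   out  = (proj + per_layer_embeds) * rsqrt(2)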
else:
self.embed_tokens_per_layer = None
self.embed_scale_per_layer = None
self.per_layer_model_projection = None
self.per_layer_projection_norm = None
self.per_layer_input_scale = None
self.per_layer_projection_scale = None
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: Gemma4DecoderLayer(
config,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix,
),
prefix=f"{prefix}.layers",
)
# Final norm: output = norm(x) * weight
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # Embedding scale = sqrt(hidden_size).
        # Downcast to the model dtype (bfloat16 etc.) for numerical parity
        # with the reference implementation.
        self.register_buffer(
            "normalizer",
            torch.tensor(
                config.hidden_size**0.5, dtype=vllm_config.model_config.dtype
            ),
            persistent=False,
        )
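        # e.g. hidden_size=4096 gives normalizer = sqrt(4096) = 64.0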
# Custom factory that includes per_layer_inputs for PLE-enabled PP.
# per_layer_inputs has shape (batch, num_layers, per_layer_dim),
# which differs from the standard (batch, hidden_size) shape,
# so we can't use the default factory.
ple_dim = self.hidden_size_per_layer_input
num_layers = config.num_hidden_layers
hidden_size = config.hidden_size
def _make_empty_intermediate_tensors(
batch_size: int,
dtype: torch.dtype,
device: torch.device,
) -> IntermediateTensors:
tensors: dict[str, torch.Tensor] = {
"hidden_states": torch.zeros(
(batch_size, hidden_size),
dtype=dtype,
device=device,
),
"residual": torch.zeros(
(batch_size, hidden_size),
dtype=dtype,
device=device,
),
}
if ple_dim and ple_dim > 0:
tensors["per_layer_inputs"] = torch.zeros(
(batch_size, num_layers, ple_dim),
dtype=dtype,
device=device,
)
return IntermediateTensors(tensors)
self.make_empty_intermediate_tensors = _make_empty_intermediate_tensors
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
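        """Embed input_ids and scale by the sqrt(hidden_size) normalizer."""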
return self.embed_tokens(input_ids) * self.normalizer
    def get_per_layer_inputs(self, input_ids: torch.Tensor) -> torch.Tensor | None:
        """Get per-layer embeddings from embed_tokens_per_layer.

        Returns:
            Per-layer embeddings of shape (num_tokens, num_layers,
            hidden_size_per_layer_input), or None when PLE is disabled.
        """
if self.embed_tokens_per_layer is None:
return None
# Handle out-of-vocab tokens for PLE (vocab_size_per_layer_input may
# be smaller than the main vocab_size).
per_layer_inputs_mask = torch.logical_and(
input_ids >= 0,
input_ids < self.vocab_size_per_layer_input,
)
per_layer_inputs_tokens = torch.where(
per_layer_inputs_mask, input_ids, torch.zeros_like(input_ids)
)
# Get packed per-layer embeddings: (num_tokens, total_ple_dim)
per_layer_embeds = self.embed_tokens_per_layer(per_layer_inputs_tokens)
# Apply embed_scale (sqrt of per-layer hidden dim)
per_layer_embeds = per_layer_embeds * self.embed_scale_per_layer
# Reshape to (num_tokens, num_layers, hidden_size_per_layer_input)
per_layer_embeds = per_layer_embeds.reshape(
*input_ids.shape,
self.config.num_hidden_layers,
self.hidden_size_per_layer_input,
)
return per_layer_embeds
def project_per_layer_inputs(
self,
inputs_embeds: torch.Tensor,
per_layer_inputs: torch.Tensor | None,
    ) -> torch.Tensor | None:
        """Project inputs_embeds and combine with per_layer_inputs.

        Steps:
        1. Project inputs_embeds: hidden_size → total_ple_dim
        2. Scale by hidden_size**-0.5
        3. Reshape to (num_tokens, num_layers, per_layer_dim)
        4. Normalize with per_layer_projection_norm
        5. Combine: (projection + per_layer_inputs) * 1/sqrt(2)

        Returns None when PLE is disabled.
        """
if self.per_layer_model_projection is None:
return None
# Project from hidden_size to total_ple_dim
# Scaled projection: output = linear(input, weight) * scale
per_layer_projection = self.per_layer_model_projection(inputs_embeds)
per_layer_projection = per_layer_projection * self.per_layer_projection_scale
# Reshape to (num_tokens, num_layers, hidden_size_per_layer_input)
per_layer_projection = per_layer_projection.reshape(
*inputs_embeds.shape[:-1],
self.config.num_hidden_layers,
self.hidden_size_per_layer_input,
)
# Normalize
per_layer_projection = self.per_layer_projection_norm(per_layer_projection)
if per_layer_inputs is None:
return per_layer_projection
# Combine: (projection + per_layer_inputs) * scale
return (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale
def forward(
self,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None,
inputs_embeds: torch.Tensor | None = None,
per_layer_inputs: torch.Tensor | None = None,
**kwargs,
) -> torch.Tensor | IntermediateTensors:
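        """Run the decoder stack.

        On the first PP rank, token (or precomputed multimodal) embeddings
        and PLE inputs are built here; later ranks receive hidden_states,
        residual, and per_layer_inputs via IntermediateTensors.
        """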
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
                # When called from the multimodal wrapper, raw PLE
                # embeddings are pre-computed and passed in explicitly;
                # project inputs_embeds and combine them with those
                # embeddings here.
per_layer_inputs = self.project_per_layer_inputs(
hidden_states, per_layer_inputs
)
else:
hidden_states = self.embed_input_ids(input_ids)
# Compute per-layer inputs for PLE
per_layer_embeds = self.get_per_layer_inputs(input_ids)
per_layer_inputs = self.project_per_layer_inputs(
hidden_states, per_layer_embeds
)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
per_layer_inputs = intermediate_tensors.get("per_layer_inputs")
for layer_idx, layer in enumerate(
islice(self.layers, self.start_layer, self.end_layer)
):
# Extract the per-layer embedding for this specific layer
if per_layer_inputs is not None:
actual_layer_idx = self.start_layer + layer_idx
layer_per_input = per_layer_inputs[
:, actual_layer_idx, :
] # (num_tokens, per_layer_dim)
else:
layer_per_input = None
hidden_states, residual = layer(
positions,
hidden_states,
residual,
per_layer_input=layer_per_input,
**kwargs,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors(
{
"hidden_states": hidden_states,
"residual": residual,
"per_layer_inputs": per_layer_inputs,
}
)
        # Final norm. When a residual stream is present, the fused RMSNorm
        # adds it into hidden_states before normalizing; the extra residual
        # output is discarded.
if residual is None:
hidden_states = self.norm(hidden_states)
else:
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
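        # e.g. a checkpoint tensor whose name contains "q_proj" is loaded
        # into the corresponding packed "qkv_proj" parameter with shard "q".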
# MoE expert weight mapping: checkpoint 3D packed tensors are
# exploded in _weight_iterator to per-expert 2D weights like:
# moe.experts.{id}.gate_proj → FusedMoE w1 (shard of w13)
# moe.experts.{id}.up_proj → FusedMoE w3 (shard of w13)
# moe.experts.{id}.down_proj → FusedMoE w2
# We build the mapping directly since Gemma4 uses bare param
# names (no .weight suffix) unlike standard MoE checkpoints.
num_experts = getattr(self.config, "num_experts", None) or 0
expert_params_mapping = [
# (param_name, weight_name, expert_id, shard_id)
(
"experts.w13_weight"
if proj_name in ["gate_proj", "up_proj"]
else "experts.w2_weight",
f"experts.{expert_id}.{proj_name}",
expert_id,
shard_id,
)
for expert_id in range(num_experts)
for shard_id, proj_name in [
("w1", "gate_proj"),
("w2", "down_proj"),
("w3", "up_proj"),
]
]
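        # e.g. a checkpoint tensor named "...moe.experts.3.gate_proj" maps to
        # the fused "...experts.w13_weight" parameter with shard_id="w1" and
        # expert_id=3.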
params_dict = dict(self.named_parameters())
# Include buffers (e.g. layer_scalar) so they can be loaded too
params_dict.update(dict(self.named_buffers()))
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if self.quant_config is not None and (
scale_name := self.quant_config.get_cache_scale(name)
):
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
                # kv-cache scales may be stored as 0-d or 1-element tensors
                if loaded_weight.dim() > 0:
                    loaded_weight = loaded_weight[0]
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
if name.endswith((".k_scale", ".v_scale", ".q_scale", ".prob_scale")):
remapped_name = maybe_remap_kv_scale_name(name, params_dict)
if remapped_name is not None and remapped_name in params_dict:
param = params_dict[remapped_name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)
loaded_params.add(remapped_name)
continue
for param_name, shard_name, shard_id in stacked_params_mapping:
if shard_name not in name:
continue
stacked_name = name.replace(shard_name, param_name)
# k_eq_v layers use separate q_proj/k_proj instead of
# packed qkv_proj. If the stacked param doesn't exist,
# skip this mapping and fall through to direct load.
if stacked_name not in params_dict:
continue
if is_pp_missing_parameter(stacked_name, self):
continue
param = params_dict[stacked_name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
loaded_params.add(stacked_name)
break
else:
for (
param_name,
weight_name,
expert_id,
shard_id,
) in expert_params_mapping:
if weight_name not in name:
continue
moe_name = name.replace(weight_name, param_name)
if moe_name not in params_dict:
continue
if is_pp_missing_parameter(moe_name, self):
continue
param = params_dict[moe_name]
# Expert weights are already in the correct
# orientation for FusedMoE after _weight_iterator:
# gate/up: [I, H] → w1/w3 expects [I, H]
# down: [H, I] → w2 expects [H, I]
assert loaded_weight.dim() == 2, (
f"Expected 2D expert weight for {weight_name}, "
f"got shape {loaded_weight.shape}"
)
weight_loader = param.weight_loader
weight_loader(
param,
loaded_weight,
weight_name + ".weight",
shard_id=shard_id,
expert_id=expert_id,
)
loaded_params.add(moe_name)
break
else:
if name.endswith(".bias") and name not in params_dict:
continue
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params