
vllm.model_executor.layers.fused_moe.runner.default_moe_runner

DefaultMoERunner

Bases: MoERunner

Default implementation of the MoE runner for executing Mixture of Experts layers.

This class provides a comprehensive implementation for running MoE computations with support for:
- Expert routing and token dispatching
- Shared experts computation with optional parallel execution using CUDA streams
- Data parallel (DP) chunking for large batch processing
- Tensor model parallel and expert parallel operations
- Various quantization methods and custom operators
- Both monolithic and decomposed expert execution paths
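The DP chunking item is implemented by `_forward_impl_chunked`, which walks the batch in slices of `moe_config.max_num_tokens`. The following is a minimal pure-Python sketch of only its slice arithmetic; `dp_chunk_schedule` is a hypothetical helper, not a vLLM API, and the copies into the staging buffers and the kernel calls are omitted. It assumes (the source does not say so explicitly) that every rank must run the same number of iterations because each chunk involves collective communication, which would explain why the loop bound is the maximum token count across DP ranks rather than the local count.

```python
def cdiv(a: int, b: int) -> int:
    """Ceiling division."""
    return -(-a // b)


def dp_chunk_schedule(
    num_tokens: int,
    max_tokens_across_dp: int,
    chunk_size: int,
    sp_size: int = 1,
) -> list[tuple[int, int]]:
    """Hypothetical helper: (start, end) token slices one rank uses per chunk.

    Mirrors the clamping in _forward_impl_chunked. Ranks with fewer local
    tokens still run every iteration; once they run out of tokens they keep
    revisiting their last token.
    """
    # Pass sp_size > 1 only when the MoE input is sequence parallel.
    max_tokens = cdiv(max_tokens_across_dp, sp_size)
    schedule = []
    for chunk_start in range(0, max_tokens, chunk_size):
        chunk_end = min(chunk_start + chunk_size, max_tokens)
        # Clamp the slice to this rank's local token count.
        start = min(chunk_start, num_tokens - 1)
        end = min(chunk_end, num_tokens)
        schedule.append((start, end))
    return schedule


# A rank holding 5 tokens when the largest DP rank holds 10, chunk size 4:
print(dp_chunk_schedule(5, 10, 4))  # [(0, 4), (4, 5), (4, 5)]
```

Because the start index is clamped to `num_tokens - 1`, the `chunk_start < num_tokens` guard in the original loop holds whenever the rank has at least one token.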

The runner handles the complete MoE forward pass including routing tokens to experts, executing expert computations, and combining results. It supports advanced features like overlapped execution of shared experts and optimized kernels for different parallel execution modes.
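To make "routing tokens to experts, executing expert computations, and combining results" concrete, here is a minimal pure-Python sketch for a single token using plain softmax plus top-k routing. It is a generic illustration, not vLLM's routing: in the runner, expert selection is delegated to `self.router.select_experts` (or folded into `quant_method.apply_monolithic`), which may use other scoring or grouping schemes. `select_experts_softmax_topk` and `moe_token_forward` are hypothetical names.

```python
import math


def softmax(logits: list[float]) -> list[float]:
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def select_experts_softmax_topk(router_logits, top_k, renormalize=True):
    """Hypothetical softmax + top-k routing for one token."""
    probs = softmax(router_logits)
    # Highest-probability experts first; ties keep their original order.
    topk_ids = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:top_k]
    topk_weights = [probs[i] for i in topk_ids]
    if renormalize:
        s = sum(topk_weights)
        topk_weights = [w / s for w in topk_weights]
    return topk_weights, topk_ids


def moe_token_forward(x, router_logits, experts, top_k):
    """Route one token, run only the selected experts, and combine their outputs."""
    topk_weights, topk_ids = select_experts_softmax_topk(router_logits, top_k)
    out = [0.0] * len(x)
    for weight, expert_id in zip(topk_weights, topk_ids):
        y = experts[expert_id](x)
        out = [o + weight * yi for o, yi in zip(out, y)]
    return out


# Three toy experts that scale their input by 1, 2, and 3.
experts = [lambda x, s=s: [s * v for v in x] for s in (1.0, 2.0, 3.0)]
print(moe_token_forward([2.0], [1.0, 1.0, -100.0], experts, top_k=2))  # [3.0]
```

Only the `top_k` selected experts are evaluated per token; the combine step is a weighted sum of their outputs.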

Eventually, this class will be split up and specialized for different configurations, e.g. the presence or absence of shared experts, a gate, etc.

Source code in vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py
class DefaultMoERunner(MoERunner):
    """
    Default implementation of the MoE runner for executing Mixture of Experts layers.

    This class provides a comprehensive implementation for running MoE computations
    with support for:
    - Expert routing and token dispatching
    - Shared experts computation with optional parallel execution using CUDA streams
    - Data parallel (DP) chunking for large batch processing
    - Tensor model parallel and expert parallel operations
    - Various quantization methods and custom operators
    - Both monolithic and decomposed expert execution paths

    The runner handles the complete MoE forward pass including routing tokens to
    experts, executing expert computations, and combining results. It supports
    advanced features like overlapped execution of shared experts and optimized
    kernels for different parallel execution modes.

    Eventually, this class will be split up and specialized for different
    configurations, e.g. the presence or absence of shared experts, a gate, etc.
    """

    def __init__(
        self,
        layer: torch.nn.Module,
        moe_config: FusedMoEConfig,
        router: FusedMoERouter,
        routed_input_transform: torch.nn.Module | None,
        gate: torch.nn.Module | None,
        shared_experts: torch.nn.Module | None,
        quant_method: FusedMoEMethodBase,
        reduce_results: bool,
        enable_dbo: bool,
    ):
        super().__init__()
        self.moe_config = moe_config
        self.router = router
        self.routed_input_transform = routed_input_transform
        self.gate = gate
        self.quant_method = quant_method
        self.reduce_results = reduce_results
        self.enable_dbo = enable_dbo

        self.shared_experts: SharedExperts | None = None
        if shared_experts is not None:
            self.shared_experts = SharedExperts(
                shared_experts,
                moe_config=moe_config,
                # Note: For now we must pass quant_method along to SharedExperts so it
                # can properly determine where the shared experts are supposed to be
                # called, i.e. by a MK or by the MoERunner.
                # Once the MK can be created upfront, we can just pass in the proper
                # flags derived from the quant_method's MK.
                reduce_results=reduce_results,
                quant_method=quant_method,
                enable_dbo=enable_dbo,
            )

        # Chunked all2all staging tensor
        # These need to exist ahead of time due to CUDAgraph construction
        # needing a fixed buffer address.
        self.use_dp_chunking = self.moe_config.moe_parallel_config.use_dp_chunking
        self.batched_hidden_states: torch.Tensor | None = None
        self.batched_router_logits: torch.Tensor | None = None
        self._maybe_init_dp_chunking()

        # Needed for string -> FusedMoE layer lookup in custom ops.
        self.layer_name = layer.layer_name

        self.forward_entry, self.forward_impl = self._select_forward(layer)

    def _select_forward(self, layer: torch.nn.Module) -> tuple[Callable, Callable]:
        # Select implementation based on presence of DP chunking.
        forward_impl_fn = (
            self._forward_impl_chunked if self.use_dp_chunking else self._forward_impl
        )

        if current_platform.is_tpu() or current_platform.is_cpu():
            # TODO: Once the OOM issue for the TPU backend is resolved, we
            # will switch to using the moe_forward custom op.
            # Note: CPU doesn't require wrapped forward_impl.
            return (
                _moe_forward if self.shared_experts is None else _moe_forward_shared,
                forward_impl_fn,
            )

        return (
            torch.ops.vllm.moe_forward
            if self.shared_experts is None
            else torch.ops.vllm.moe_forward_shared,
            forward_impl_fn,
        )

    # TODO(bnell): temporary hack, do not call this method.
    def _replace_quant_method(self, quant_method: FusedMoEMethodBase):
        if self.shared_experts is not None:
            self.shared_experts._quant_method = quant_method
        self.quant_method = quant_method

    def is_internal_router(self) -> bool:
        return self.gate is not None

    def _maybe_init_dp_chunking(self):
        if not self.use_dp_chunking:
            return

        assert self.batched_hidden_states is None
        states_shape: tuple[int, ...]
        logits_shape: tuple[int, ...]

        moe = self.moe_config

        if self.enable_dbo:
            states_shape = (2, moe.max_num_tokens, self.moe_config.hidden_dim)
            logits_shape = (2, moe.max_num_tokens, self.moe_config.num_logical_experts)
        else:
            states_shape = (moe.max_num_tokens, self.moe_config.hidden_dim)
            logits_shape = (moe.max_num_tokens, self.moe_config.num_logical_experts)

        device = torch.accelerator.current_device_index()
        self.batched_hidden_states = torch.zeros(
            states_shape,
            dtype=moe.in_dtype,
            device=device,
        )

        self.batched_router_logits = torch.zeros(
            logits_shape,
            dtype=moe.router_logits_dtype,
            device=device,
        )

    def must_reduce_shared_expert_outputs(self) -> bool:
        """
        The shared_experts are typically computed using the RowParallelLinear
        layer. The result of this function is typically used as
        the reduce_results argument to the module.
        When just tensor-parallel is used, it is not required to reduce
        the shared_experts results immediately. Instead we reduce once
        at the end of the MoE op. (Refer to DeepSeekV2MoE module)
        With EP and all2all kernels this is no longer viable, as all
        GPU ranks in DP produce the complete set of hidden_states.
        Therefore it is required that we reduce the shared_experts output
        early.
        """
        return (
            self.quant_method.moe_kernel is not None
            and self.quant_method.moe_kernel.output_is_reduced()
        )

    def maybe_all_reduce_tensor_model_parallel(self, final_hidden_states: torch.Tensor):
        """
        Some combine kernels reduce across GPU ranks by default.
        """
        if self.must_reduce_shared_expert_outputs():
            return final_hidden_states
        else:
            return tensor_model_parallel_all_reduce(final_hidden_states)

    def apply_routed_input_transform(
        self, hidden_states: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Apply transform for routed experts (e.g., latent projection).

        This is called by FusedMoE.forward_native. The original hidden_states
        is saved separately so shared experts get [S, hidden_size] while
        routed experts get the transformed [S, moe_latent_size].

        TODO: For latent MoE bandwidth optimization, fc2_latent_proj could be
        moved inside SharedFusedMoE to all-reduce on the smaller latent
        dimension.

        Returns (possibly transformed) hidden states and the input for shared
        experts (or None if there are no shared experts).
        """
        if self.routed_input_transform is not None:
            result = self.routed_input_transform(hidden_states)
            # ReplicatedLinear returns (output, extra_bias) tuple.
            # We only need the output tensor; extra_bias is not used here.
            if isinstance(result, tuple):
                return result[0], hidden_states
            return result, hidden_states

        return (
            hidden_states,
            hidden_states if self.shared_experts is not None else None,
        )

    def _maybe_reduce_output(
        self,
        states: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
        trunc_sizes: list[int],
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        def trunc(x: torch.Tensor, trunc_size: int) -> torch.Tensor:
            return x[..., :trunc_size]

        def reduce_and_trunc(x: torch.Tensor, trunc_size: int) -> torch.Tensor:
            return trunc(self.maybe_all_reduce_tensor_model_parallel(x), trunc_size)

        if (
            not self.moe_config.is_sequence_parallel
            and not self.use_dp_chunking
            and self.reduce_results
            and (self.moe_config.tp_size > 1 or self.moe_config.ep_size > 1)
        ):
            func = reduce_and_trunc
        else:
            func = trunc

        if isinstance(states, tuple):
            return tuple(
                [func(s, trunc_size) for s, trunc_size in zip(states, trunc_sizes)]
            )
        else:
            assert len(trunc_sizes) == 1
            return func(states, trunc_sizes[0])

    def _encode_layer_name(self) -> str | ModuleName:
        if HAS_OPAQUE_TYPE:
            return ModuleName(self.layer_name)
        # Can be unavailable or None in unittests
        if (
            is_forward_context_available()
            and get_forward_context().all_moe_layers is not None
        ):
            return "from_forward_context"
        return self.layer_name

    def _maybe_pad_hidden_states(
        self,
        shared_experts_input: torch.Tensor | None,
        hidden_states: torch.Tensor,
    ) -> tuple[torch.Tensor, list[int]]:
        shared_experts_hidden_dim = (
            shared_experts_input.shape[-1] if shared_experts_input is not None else 0
        )
        transformed_hidden_dim = hidden_states.shape[-1]
        if (
            not self.quant_method.skip_forward_padding
            and self.moe_config.hidden_dim != transformed_hidden_dim
        ):
            hidden_states = F.pad(
                hidden_states,
                (0, self.moe_config.hidden_dim - transformed_hidden_dim),
                mode="constant",
                value=0.0,
            )

        if self.shared_experts is not None:
            orig_hidden_dims = [shared_experts_hidden_dim, transformed_hidden_dim]
        else:
            orig_hidden_dims = [transformed_hidden_dim]

        return hidden_states, orig_hidden_dims

    def _maybe_apply_shared_experts(
        self,
        shared_experts_input: torch.Tensor | None,
        order: SharedExpertsOrder,
    ):
        if self.shared_experts is not None:
            assert shared_experts_input is not None
            self.shared_experts.apply(shared_experts_input, order)

    def _apply_quant_method(
        self,
        layer: torch.nn.Module,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        shared_experts_input: torch.Tensor | None,
    ) -> tuple[torch.Tensor | None, torch.Tensor]:
        # Run this before quant_method to avoid inplace issues.
        # TODO(bnell): probably not needed anymore since inplace is
        # disabled when shared experts are present.
        self._maybe_apply_shared_experts(
            shared_experts_input, SharedExpertsOrder.NO_OVERLAP
        )

        if self.quant_method.is_monolithic:
            fused_out = self.quant_method.apply_monolithic(
                layer=layer,
                x=hidden_states,
                router_logits=router_logits,
            )
        else:
            topk_weights, topk_ids = self.router.select_experts(
                hidden_states=hidden_states,
                router_logits=router_logits,
            )

            # Passing shared_experts_input in case SharedExpertsOrder is
            # NO_OVERLAP or MK_INTERNAL_OVERLAPPED.
            fused_out = self.quant_method.apply(
                layer=layer,
                x=hidden_states,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                shared_experts_input=shared_experts_input,
            )

        self._maybe_apply_shared_experts(
            shared_experts_input,
            SharedExpertsOrder.MULTI_STREAM_OVERLAPPED,
        )

        return (
            self.shared_experts.output if self.shared_experts is not None else None,
            fused_out,
        )

    def _sequence_parallel_context(self):
        ctx = get_forward_context()
        return (
            ctx.dp_metadata.sp_local_sizes(self.moe_config.sp_size)
            if ctx.dp_metadata
            else nullcontext()
        )

    def _allocate_dp_chunking_outputs(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> tuple[torch.Tensor | None, torch.Tensor]:
        assert self.use_dp_chunking

        # Assert the inputs are of the proper type and shape.
        assert self.batched_hidden_states is not None
        assert self.batched_router_logits is not None

        assert self.batched_hidden_states.dtype == hidden_states.dtype, (
            f"{self.batched_hidden_states.dtype} == {hidden_states.dtype}"
        )
        assert self.batched_router_logits.dtype == router_logits.dtype, (
            f"{self.batched_router_logits.dtype} == {router_logits.dtype}"
        )

        # Check size compatibility.
        assert self.batched_hidden_states.size(-1) == hidden_states.size(-1)
        assert self.batched_router_logits.size(-1) == router_logits.size(-1)

        final_fused_hidden_states = torch.empty_like(hidden_states)
        if self.shared_experts is not None:
            final_shared_hidden_states = torch.empty_like(hidden_states)
        else:
            final_shared_hidden_states = None

        return final_shared_hidden_states, final_fused_hidden_states

    def _maybe_sync_shared_experts_stream(
        self,
        shared_experts_input: torch.Tensor | None,
    ):
        # Sync the aux and main streams for shared expert multi-stream overlap.
        # (Note: This code runs only when "overlapped mode" is on to allow
        #        parallel execution of shared experts with the FusedMoE via
        #        separate cuda stream)
        if self.shared_experts is not None:
            self.shared_experts.maybe_sync_shared_experts_stream(shared_experts_input)

    @property
    def do_naive_dispatch_combine(self) -> bool:
        return (
            self.moe_config.dp_size > 1 and not self.quant_method.supports_internal_mk
        )

    def _maybe_dispatch(
        self,
        layer: torch.nn.Module,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # For naive DP/EP dispatch/combine, dispatch the hidden states and
        # router logits to all experts.
        # NOTE: this will be removed once all kernels are migrated into the
        # MoEKernel framework.
        if self.do_naive_dispatch_combine:
            hidden_states, router_logits = get_ep_group().dispatch_router_logits(
                hidden_states,
                router_logits,
                self.moe_config.is_sequence_parallel,
            )

        # NOTE: Similar to DP, PCP also needs dispatch and combine. For
        # simplicity, AgRsAll2All was added separately for PCP here. Maybe
        # we should modify All2AllManager abstraction to better support PCP.
        if self.moe_config.pcp_size > 1:
            hidden_states = get_pcp_group().all_gather(
                hidden_states,
                dim=0,
            )
            router_logits = get_pcp_group().all_gather(
                router_logits,
                dim=0,
            )

        return hidden_states, router_logits

    def _maybe_combine(
        self,
        shared_output: torch.Tensor | None,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor | None]:
        if self.do_naive_dispatch_combine:
            hidden_states = get_ep_group().combine(
                hidden_states, self.moe_config.is_sequence_parallel
            )

        if self.moe_config.pcp_size > 1:
            hidden_states = get_pcp_group().reduce_scatter(
                hidden_states,
                dim=0,
            )

        if self.shared_experts is not None:
            assert shared_output is not None
            return shared_output, hidden_states
        else:
            return hidden_states

    def forward(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Invoke the fused moe layer.

        Input:
        - hidden_states
        - router_logits

        Output:
        - The new hidden_states.
        or
        - A tuple of (shared experts output, new hidden_states).

        Calling sequence
        - forward
          - self.forward_entry (_moe_forward or _moe_forward_shared custom op)
            - forward_dispatch
              - forward_impl (_forward_impl or _forward_impl_chunked)

        Note: The _moe_forward and _moe_forward_shared custom ops exist for the
        following reasons:
        1. the chunking loop in _forward_impl_chunked cannot be compiled by
           torch.compile
        2. PyTorch cannot handle union types in custom op signatures, so
           _moe_forward and _moe_forward_shared must be split.

        If _forward_impl_chunked can be implemented via torch.scan we can potentially
        get rid of _moe_forward and _moe_forward_shared and collapse the whole sequence
        into the 'forward' method.
        """

        # Apply transform for routed experts (e.g., latent projection for latent MoE)
        hidden_states, shared_experts_input = self.apply_routed_input_transform(
            hidden_states
        )

        hidden_states, og_hidden_dims = self._maybe_pad_hidden_states(
            shared_experts_input,
            hidden_states,
        )

        fused_output = self.forward_entry(
            hidden_states,
            router_logits,
            shared_experts_input,
            self._encode_layer_name(),
        )

        return self._maybe_reduce_output(fused_output, og_hidden_dims)

    def forward_dispatch(
        self,
        layer: torch.nn.Module,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        shared_experts_input: torch.Tensor | None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        # TODO(bnell): this can be removed after MK migration is complete.
        layer.ensure_moe_quant_config_init()

        # Sync aux and main stream for shared expert multi-stream overlap.
        self._maybe_sync_shared_experts_stream(shared_experts_input)

        # If the Runner holds the gate, apply it after the stream sync,
        # so it can run overlapped with the shared experts on the aux stream.
        # NOTE: in future PR, MoE runner will always hold the gate.
        if self.gate is not None:
            router_logits, _ = self.gate(hidden_states)

        self._maybe_apply_shared_experts(
            shared_experts_input,
            SharedExpertsOrder.EXTERNAL,
        )

        with self._sequence_parallel_context():
            return self.forward_impl(
                layer,
                hidden_states,
                router_logits,
                shared_experts_input,
            )

    def _slice_and_copy_input(
        self,
        out_slice: torch.Tensor,
        orig: torch.Tensor | None,
        start: int,
        end: int,
    ) -> torch.Tensor:
        assert orig is not None
        slice_size = end - start
        orig_slice = orig[start:end, :]
        if self.enable_dbo:
            assert out_slice.dim() == 3
            batch_buffer_idx = dbo_current_ubatch_id()
            out_slice = out_slice[batch_buffer_idx, :]

        assert out_slice.size(0) >= slice_size
        out_slice = out_slice[:slice_size, :]
        out_slice.copy_(orig_slice, non_blocking=True)
        return out_slice

    def _forward_impl_chunked(
        self,
        layer: torch.nn.Module,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        shared_experts_input: torch.Tensor | None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        final_shared_hidden_states, final_fused_hidden_states = (
            self._allocate_dp_chunking_outputs(hidden_states, router_logits)
        )

        ctx = get_forward_context()
        # flashinfer_cutlass_kernels can handle: optional DP + TP/EP
        max_tokens_across_dispatchers = ctx.dp_metadata.max_tokens_across_dp_cpu
        moe_dp_chunk_size_per_rank = self.moe_config.max_num_tokens

        # If the input to the MoE is sequence parallel then divide by sp_size
        # to find the maximum number of tokens for any individual dispatcher.
        if self.moe_config.is_sequence_parallel:
            max_tokens_across_dispatchers = cdiv(
                max_tokens_across_dispatchers, self.moe_config.sp_size
            )

        num_tokens = hidden_states.size(0)
        for chunk_idx, chunk_start_ in enumerate(
            range(0, max_tokens_across_dispatchers, moe_dp_chunk_size_per_rank)
        ):
            chunk_start = chunk_start_
            chunk_end = min(
                chunk_start + moe_dp_chunk_size_per_rank, max_tokens_across_dispatchers
            )
            # clamp start and end
            chunk_start = min(chunk_start, num_tokens - 1)
            chunk_end = min(chunk_end, num_tokens)
            chunk_sizes = ctx.dp_metadata.chunked_sizes(
                self.moe_config.sp_size, moe_dp_chunk_size_per_rank, chunk_idx
            )
            with chunk_sizes:
                hidden_states_chunk = self._slice_and_copy_input(
                    self.batched_hidden_states,
                    hidden_states,
                    chunk_start,
                    chunk_end,
                )

                router_logits_chunk = self._slice_and_copy_input(
                    self.batched_router_logits,
                    router_logits,
                    chunk_start,
                    chunk_end,
                )

                shared_experts_input_chunk = (
                    shared_experts_input[chunk_start:chunk_end, :]
                    if shared_experts_input is not None
                    else None
                )

                shared_output_chunk, hidden_states_chunk = self._apply_quant_method(
                    layer=layer,
                    hidden_states=hidden_states_chunk,
                    router_logits=router_logits_chunk,
                    shared_experts_input=shared_experts_input_chunk,
                )

                # Store outputs
                # TODO(bnell): document when chunk_start >= num_tokens
                if chunk_start < num_tokens:
                    final_fused_hidden_states[chunk_start:chunk_end, :].copy_(
                        hidden_states_chunk, non_blocking=True
                    )
                    if self.shared_experts is not None:
                        assert shared_output_chunk is not None
                        assert final_shared_hidden_states is not None
                        final_shared_hidden_states[chunk_start:chunk_end, :].copy_(
                            shared_output_chunk, non_blocking=True
                        )

        if self.shared_experts is None:
            return final_fused_hidden_states
        else:
            assert final_shared_hidden_states is not None
            return (final_shared_hidden_states, final_fused_hidden_states)

    def _forward_impl(
        self,
        layer: torch.nn.Module,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        shared_experts_input: torch.Tensor | None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        # TODO(bnell): parts of the dispatch/combine steps will go away once
        # #32567 lands and the remaining kernels are made MKs.  The PCP
        # code will probably remain
        hidden_states, router_logits = self._maybe_dispatch(
            layer,
            hidden_states,
            router_logits,
        )

        shared_output, hidden_states = self._apply_quant_method(
            layer=layer,
            hidden_states=hidden_states,
            router_logits=router_logits,
            shared_experts_input=shared_experts_input,
        )

        return self._maybe_combine(
            shared_output,
            hidden_states,
        )
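`_maybe_dispatch` and `_maybe_combine` bracket the expert computation in `_forward_impl`: when `pcp_size > 1`, tokens are all-gathered along dim 0 before the experts run and the outputs are reduce-scattered along dim 0 afterwards. The sketch below simulates that pattern with Python lists standing in for per-rank tensors; it does not use `torch.distributed`, and the assumption that each rank contributes a partial expert output (its own experts' share) is purely illustrative. The naive DP/EP path goes through `get_ep_group().dispatch_router_logits` and `get_ep_group().combine`, whose internals are not shown in this section.

```python
def all_gather(per_rank: list[list[float]]) -> list[list[float]]:
    """Every rank receives the concatenation of all ranks' rows (dim 0)."""
    gathered = [row for rows in per_rank for row in rows]
    return [list(gathered) for _ in per_rank]


def reduce_scatter(per_rank: list[list[float]]) -> list[list[float]]:
    """Sum element-wise across ranks, then hand each rank an equal slice."""
    world_size = len(per_rank)
    total = [sum(vals) for vals in zip(*per_rank)]
    chunk = len(total) // world_size
    return [total[r * chunk:(r + 1) * chunk] for r in range(world_size)]


# Two ranks, each holding two scalar "tokens".
local_tokens = [[1.0, 2.0], [3.0, 4.0]]
gathered = all_gather(local_tokens)  # dispatch: every rank sees all 4 tokens
# Hypothetical experts: the full MoE output is 10 * x, split as 4 * x on
# rank 0 and 6 * x on rank 1 (each rank only computes its own experts).
partial = [[4.0 * x for x in gathered[0]], [6.0 * x for x in gathered[1]]]
combined = reduce_scatter(partial)  # combine: sum partials, keep local rows
print(combined)  # [[10.0, 20.0], [30.0, 40.0]]
```

After the round trip each rank again holds only its own tokens, now carrying the fully combined expert output.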

apply_routed_input_transform

apply_routed_input_transform(
    hidden_states: Tensor,
) -> tuple[Tensor, Tensor | None]

Apply transform for routed experts (e.g., latent projection).

This is called by FusedMoE.forward_native. The original hidden_states is saved separately so shared experts get [S, hidden_size] while routed experts get the transformed [S, moe_latent_size].

TODO: For latent MoE bandwidth optimization, fc2_latent_proj could be moved inside SharedFusedMoE to all-reduce on the smaller latent dimension.

Returns (possibly transformed) hidden states and the input for shared experts (or None if there are no shared experts).
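The split between routed and shared inputs can be illustrated without torch. The sketch below re-expresses the method's branching as a free function over nested lists; `toy_latent_proj` is a hypothetical stand-in for a projection layer that, like `ReplicatedLinear`, returns an `(output, extra_bias)` tuple. Here `hidden_size` is 4 and `moe_latent_size` is 2.

```python
def apply_routed_input_transform(
    hidden_states, routed_input_transform=None, has_shared_experts=False
):
    """Free-function sketch of DefaultMoERunner.apply_routed_input_transform."""
    if routed_input_transform is not None:
        result = routed_input_transform(hidden_states)
        # Layers such as ReplicatedLinear return (output, extra_bias); keep output.
        if isinstance(result, tuple):
            return result[0], hidden_states
        return result, hidden_states
    # No transform: routed experts see the input unchanged, and the shared
    # experts input is only populated when shared experts exist.
    return hidden_states, (hidden_states if has_shared_experts else None)


def toy_latent_proj(rows):
    """Hypothetical projection [S, 4] -> [S, 2], returning (output, extra_bias)."""
    return [[r[0] + r[1], r[2] + r[3]] for r in rows], None


hidden = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]  # S = 2, hidden_size = 4
routed_in, shared_in = apply_routed_input_transform(hidden, toy_latent_proj, True)
print(routed_in)  # [[3.0, 7.0], [11.0, 15.0]]  (S x moe_latent_size)
print(shared_in is hidden)  # True  (shared experts keep S x hidden_size)
```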


forward

forward(
    hidden_states: Tensor, router_logits: Tensor
) -> Tensor | tuple[Tensor, Tensor]

Invoke the fused moe layer.

Input:
- hidden_states
- router_logits

Output:
- The new hidden_states, or
- A tuple of (shared experts output, new hidden_states).

Calling sequence:
- forward
  - self.forward_entry (_moe_forward or _moe_forward_shared custom op)
    - forward_dispatch
      - forward_impl (_forward_impl or _forward_impl_chunked)

Note: The _moe_forward and _moe_forward_shared custom ops exist for the following reasons:
1. The chunking loop in _forward_impl_chunked cannot be compiled by torch.compile.
2. PyTorch cannot handle union types in custom op signatures, so _moe_forward and _moe_forward_shared must be split.

If _forward_impl_chunked can be implemented via torch.scan, we can potentially get rid of _moe_forward and _moe_forward_shared and collapse the whole sequence into the 'forward' method.
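Around the custom-op call, `forward` pads the routed hidden states to `moe_config.hidden_dim` (unless `quant_method.skip_forward_padding` is set), and `_maybe_reduce_output` later slices each output back to its original last dimension with `x[..., :trunc_size]`. A minimal list-based sketch of that pad-then-truncate round trip follows; the tensor-parallel all-reduce is omitted, the routed input transform is assumed to be absent, and `forward_sketch`, `experts_fn`, and `shared_fn` are hypothetical names.

```python
def pad_rows(rows, target_dim):
    """Zero-pad each row on the right up to target_dim (like F.pad(x, (0, k)))."""
    return [r + [0.0] * (target_dim - len(r)) for r in rows]


def trunc_rows(rows, size):
    """Keep the first `size` columns of each row (like x[..., :size])."""
    return [r[:size] for r in rows]


def forward_sketch(hidden_states, moe_hidden_dim, experts_fn, shared_fn=None):
    shared_input = hidden_states if shared_fn is not None else None
    # Remember the original widths: [shared_dim, routed_dim] or [routed_dim].
    orig_dims = [len(hidden_states[0])]
    if shared_input is not None:
        orig_dims = [len(shared_input[0])] + orig_dims
    padded = pad_rows(hidden_states, moe_hidden_dim)
    fused_out = experts_fn(padded)  # the experts see the padded width
    if shared_fn is None:
        return trunc_rows(fused_out, orig_dims[0])
    shared_out = shared_fn(shared_input)
    return trunc_rows(shared_out, orig_dims[0]), trunc_rows(fused_out, orig_dims[1])


def double(rows):
    return [[2.0 * v for v in r] for r in rows]


def plus_one(rows):
    return [[v + 1.0 for v in r] for r in rows]


print(forward_sketch([[1.0, 2.0, 3.0]], 4, double))
# [[2.0, 4.0, 6.0]]
print(forward_sketch([[1.0, 2.0, 3.0]], 4, double, plus_one))
# ([[2.0, 3.0, 4.0]], [[2.0, 4.0, 6.0]])
```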


maybe_all_reduce_tensor_model_parallel

maybe_all_reduce_tensor_model_parallel(
    final_hidden_states: Tensor,
)

Some combine kernels reduce across GPU ranks by default.
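The reduction decision is split across two places: `_maybe_reduce_output` only considers an all-reduce when the input is not sequence parallel, DP chunking is off, `reduce_results` is set, and the TP or EP size exceeds 1; `maybe_all_reduce_tensor_model_parallel` then skips it if `quant_method.moe_kernel` reports `output_is_reduced()`. Below is a list-based simulation of both checks, with an element-wise sum across simulated ranks standing in for `tensor_model_parallel_all_reduce`. The function names are hypothetical.

```python
def wants_reduce(is_sequence_parallel, use_dp_chunking, reduce_results, tp_size, ep_size):
    """Mirror of the condition in _maybe_reduce_output."""
    return (
        not is_sequence_parallel
        and not use_dp_chunking
        and reduce_results
        and (tp_size > 1 or ep_size > 1)
    )


def all_reduce(per_rank):
    """Stand-in for an all-reduce: every rank ends up with the element-wise sum."""
    total = [sum(vals) for vals in zip(*per_rank)]
    return [list(total) for _ in per_rank]


def maybe_all_reduce(per_rank, kernel_output_is_reduced):
    """Skip the all-reduce when the combine kernel already reduced across ranks."""
    if kernel_output_is_reduced:
        return per_rank
    return all_reduce(per_rank)


partials = [[1.0, 2.0], [3.0, 4.0]]  # two TP ranks, one partial result each
print(maybe_all_reduce(partials, kernel_output_is_reduced=False))
# [[4.0, 6.0], [4.0, 6.0]]
```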


must_reduce_shared_expert_outputs

must_reduce_shared_expert_outputs() -> bool

The shared_experts are typically computed using the RowParallelLinear layer. The result of this function is typically used as the reduce_results argument to the module.

When just tensor parallelism is used, it is not required to reduce the shared_experts results immediately. Instead, we reduce once at the end of the MoE op (refer to the DeepSeekV2MoE module).

With EP and all2all kernels this is no longer viable, as all GPU ranks in DP produce the complete set of hidden_states. Therefore it is required that we reduce the shared_experts output early.
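The argument for deferring the shared-experts reduction under plain tensor parallelism rests on the all-reduce being a sum: reducing the shared and routed partial outputs separately gives the same result as reducing their sum once. When the combine kernel already returns a fully reduced routed output on every rank (the case this method detects through `moe_kernel.output_is_reduced()`), adding an unreduced shared partial would leave each rank with a different, incorrect result. The simulation below checks both claims with lists standing in for per-rank tensors; all names and numbers are hypothetical.

```python
def all_reduce(per_rank):
    """Stand-in for an all-reduce: every rank ends up with the element-wise sum."""
    total = [sum(vals) for vals in zip(*per_rank)]
    return [list(total) for _ in per_rank]


def add(a, b):
    return [x + y for x, y in zip(a, b)]


# Per-rank partial outputs for two ranks.
shared_partial = [[1.0, 2.0], [3.0, 4.0]]
routed_partial = [[0.5, 0.5], [1.5, 2.5]]

# TP only: one all-reduce at the end equals reducing each part separately.
late = all_reduce([add(s, r) for s, r in zip(shared_partial, routed_partial)])
early = [add(s, r) for s, r in zip(all_reduce(shared_partial), all_reduce(routed_partial))]
print(late == early)  # True

# EP + all2all: the combine kernel already returns the reduced routed output.
routed_full = all_reduce(routed_partial)
wrong = [add(s, r) for s, r in zip(shared_partial, routed_full)]
right = [add(s, r) for s, r in zip(all_reduce(shared_partial), routed_full)]
print(wrong)  # [[3.0, 5.0], [5.0, 7.0]]  (ranks disagree)
print(right)  # [[6.0, 9.0], [6.0, 9.0]]
```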
