I am currently digging into the MoE EP implementation within vLLM.
I have traced the code down to FusedMoE.select_experts (and related kernel launchers), where I can see that the global physical_expert_id is resolved.
However, I am struggling to locate the specific code implementation that handles the cross-rank communication. Specifically:
After identifying the global expert ID, where does the system dispatch/send the hidden states to the target rank (machine/GPU) that holds the corresponding expert?
Is there an explicit AlltoAll or send/recv mechanism implemented for token routing, or does vLLM handle MoE differently (e.g., via Tensor Parallelism where weights are sharded and tokens remain local until the final reduction)?
Any pointers to the relevant file or function would be greatly appreciated!
After the global expert IDs are selected in FusedMoE.select_experts, the cross-rank communication, i.e. dispatching hidden states to the rank that holds the target expert, is handled by the All2All communication backends, which are invoked from the FusedMoEPrepareAndFinalize subclasses. The dispatch of hidden states is performed via explicit All2All collectives, not via tensor-parallel sharding. The key logic lives in two methods of the relevant FusedMoEPrepareAndFinalize subclass: prepare, which calls the All2All backend to route tokens to the correct ranks before expert computation, and finalize, which combines the results afterwards. See the modular kernel design for how dispatch and combine are separated from expert computation and are backend-dependent (Fused MoE Modular Kernel, Fused MoE Modular All2All backends).
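To make the division of labor concrete, here is a minimal single-process sketch of the prepare / expert-compute / finalize pattern. The bucketing by destination rank stands in for the All2All dispatch; everything here (function names, the contiguous expert placement, the toy expert function) is illustrative and not vLLM's actual API.

```python
# Illustrative single-process simulation of EP token routing; all names are
# hypothetical and stand in for what a FusedMoEPrepareAndFinalize-style
# backend does with real collectives.

def expert_rank(expert_id: int, experts_per_rank: int) -> int:
    """Rank owning a global expert ID (contiguous placement assumed)."""
    return expert_id // experts_per_rank

def moe_dispatch_compute_combine(tokens, expert_ids, num_ranks, experts_per_rank):
    # prepare / dispatch: bucket tokens by the rank that owns their expert
    # (in the real system this is the All2All send)
    send_buckets = {r: [] for r in range(num_ranks)}
    for i, (tok, eid) in enumerate(zip(tokens, expert_ids)):
        send_buckets[expert_rank(eid, experts_per_rank)].append((i, tok, eid))

    # expert computation on each "rank"; a toy expert that scales its input
    partial = {}
    for bucket in send_buckets.values():
        for i, tok, eid in bucket:
            partial[i] = tok * (eid + 1)

    # finalize / combine: return results to the original token order
    # (in the real system this is the All2All combine)
    return [partial[i] for i in range(len(tokens))]
```

For example, with 2 ranks holding 2 experts each, a token routed to expert 3 is processed on rank 1 and its result lands back in the token's original slot.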
For example, in the DeepEP and PPLX backends, the All2All dispatch happens in the prepare step of DeepEPLLPrepareAndFinalize or PplxPrepareAndFinalize, which send the relevant hidden states to the expert ranks via All2All collectives. The code for these mechanisms lives in files like deepep_ll_prepare_finalize.py and pplx_prepare_finalize.py. The modular kernel framework lets different All2All implementations be plugged in, but all of them perform explicit token routing via collective communication rather than purely local computation with sharded weights (deepep_ll_prepare_finalize.py, modular_kernel.py).
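Independent of which backend is used, the core bookkeeping before an all-to-all is computing the per-destination split sizes, i.e. how many tokens this rank sends to each peer (the receive counts come from a symmetric exchange of these splits), plus a permutation that groups tokens by destination. A hedged sketch of that bookkeeping, written as a helper that could feed something like torch.distributed.all_to_all_single (the helper itself is hypothetical, not vLLM code):

```python
# Hypothetical helper: given each token's destination rank, compute the
# input_split_sizes an all-to-all collective expects, plus the stable
# permutation that groups tokens contiguously by destination rank.

def all_to_all_splits(target_ranks, num_ranks):
    # number of tokens this rank sends to each peer
    splits = [0] * num_ranks
    for r in target_ranks:
        splits[r] += 1
    # stable sort by destination: tokens bound for the same rank stay
    # in their original relative order, so the combine can invert it
    order = sorted(range(len(target_ranks)), key=lambda i: target_ranks[i])
    return splits, order
```

With target_ranks = [1, 0, 1, 2] and 3 ranks, this yields splits [1, 2, 1] and permutation [1, 0, 2, 3]; the receiving side uses the exchanged splits to size its buffers, and the inverse permutation restores token order during combine.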
Would you like a step-by-step breakdown of how the dispatch and combine calls are made in the code?