MoMoE: Memory-optimized Mixture of Experts
Summary
- Mixture-of-Experts (MoE) architectures scale model parameter counts efficiently by activating only a fraction of weights per token.
- However, most open-source MoE implementations suffer from performance bottlenecks, failing to achieve the theoretical throughput and memory benefits, especially during training/finetuning.
- To improve the speed of training and fine-tuning MoE models, we introduce Memory optimized Mixture of Experts (MoMoE), an MoE implementation built with fused Triton kernels, optimized memory packing, and a configurable backward pass for memory-compute trade-offs.
- MoMoE significantly outperforms popular baselines (NVIDIA's TEGrouped, ScatterMoE, etc.) in speed and memory efficiency while ensuring numerical correctness in both micro-benchmarks and at-scale model training.
- We release & open-source MoMoE as a public, memory-optimized Triton kernel for high-performance training and inference of Mixture-of-Experts models.
MoMoE beats other performant MoE kernels on fwd+bwd, even for large workloads.
"Mo' Money, Mo(E)' Problems"
Introduction
Scaling laws demonstrate that model quality improves consistently as both the volume of training data and the number of active parameters grow [1] [2]. Increasing active parameters, however, raises computational requirements in lock-step. Mixture-of-Experts (MoE) architectures, first popularized by Shazeer et al. [3] and subsequently deployed in systems such as DeepSeek-v3 [4], K2 [5], Qwen-3 [6], and Llama-4 [7], seek to decouple these two trends. By activating only a small subset of expert subnetworks per token, MoE layers enable large parameter counts without incurring dense-compute costs for every forward pass [8].
MoE architectures work by replacing the standard feed-forward (MLP) layers in transformer blocks with MoE layers composed of multiple parallel expert MLPs. A lightweight gating network selects the top-K experts for each token, routing the token’s representation only through those selected experts. This sparse activation allows the model to scale up the total number of parameters dramatically, while keeping the per-token computation cost comparable to that of a dense MLP. Recent designs, beginning with DeepSeek-MoE [9], have shifted toward using a larger number of total experts and increasing the number of active experts per token, while correspondingly reducing the hidden size of each expert. As a result, efficiency becomes even more critical, both in terms of expert batching and memory access patterns, since the cost of poor implementation scales with the number of experts involved.
In theory, this architecture offers substantial wall-clock and memory efficiency gains. However, in practice the arithmetic cost may stay low while wall-clock performance suffers dramatically. Most open-source implementations fail to deliver on MoE’s promised throughput benefits for several reasons:
- Heavy reliance on Python control flow: Implementations often resort to token-wise for-loops in Python, which are intrinsically slow and difficult to scale.
- Inadequate batching strategies: Poor handling of dynamic expert assignment leads to fragmented tensor operations that under-utilize hardware.
- Overgeneralized kernel design: Even efficient Triton-based implementations like ScatterMoE [10] and Qwen3-MoE-Fused [11] prioritize abstraction and reuse over specialization, leading to wasted memory or redundant computation.
- Suboptimal memory layout and indexing: Many implementations lack the precise memory indexing and kernel fusion required to achieve high utilization of modern accelerators.
- Backward pass inefficiencies: Intermediate tensors are often saved excessively during the forward pass, imposing scalability limits during training due to memory overhead.
MoMoE (Memory-optimized Mixture of Experts) was developed to directly address these inefficiencies. Specifically, we designed our kernel to deliver large speed-ups during training and fine-tuning compared to existing open-source MoE implementations.
In this post, we discuss in more detail how the above bottlenecks emerge, share some of the unsuccessful strategies we explored along the way, and ultimately present the techniques that led to meaningful speedups on modern GPU clusters.
Shortcomings with Existing Open-Source MoE
Despite the widespread availability of performant implementations for several common layers, modules, and operations (e.g., cuBLAS GEMMs, PyTorch topk, softmax), Mixture of Experts has surprisingly few, given its popularity as an architectural choice.
Below, we discuss several commonly used forward and backward MoE implementations (and their limitations). The list may not be comprehensive, and we include discussion of some adjacent implementations in Appendix F.
PyTorch-based Implementations
The simplest and most common way to build an MoE layer is with vanilla PyTorch ops, as done in the open-source implementation of Qwen3. This approach yields highly readable and adaptable code but suffers from a critical performance gap: the forward pass relies on an explicit Python for-loop that iterates through the selected experts. Rather than applying the experts in parallel, this loop forces repeated kernel launches, which prevents the layer from approaching theoretical throughput during training and inference.
A single for‑loop over experts iterates down the column: for each token the router picks its top‑k experts, the loop runs their SwiGLU blocks, and the outputs are accumulated back into the token’s slot.
Triton-based Implementations
Among the few open-source Triton implementations, ScatterMoE enjoys the most popularity due to its efficiency and simplicity of integration. It is a solid improvement over the PyTorch version on almost every axis, including speed and memory.
Another similar Triton-based implementation, Qwen3 MoE Fused, was originally created for efficient training of Qwen3.
Unfortunately, these efficient implementations have their shortcomings. Both kernels suffer from overgeneralization of their operations, whereby they apply the same kernels for different linear layers. This design choice simplifies the implementation at the cost of either extra memory or extra computation.
Hybrid Implementations
Megatron LM is NVIDIA’s official training library and one of the most popular training/finetuning codebases [12]. As such, it features highly optimized kernels for MoE training. The newest implementation (TEGrouped), relies on permute/unpermute operations defined through NVIDIA’s Transformer Engine in Triton, and a Grouped GEMM operation in Cutlass [13].
While theoretically a sound idea, the cost of permuting and unpermuting is not as low as its time complexity might suggest. Therefore, a forward pass with scattering/gathering fused into the kernels themselves (spoiler alert), as well as with a fused SwiGLU, tends to work wonders in non-asymptotic cases (which is almost all practical use cases).
Note that Megatron LM affords several flexibilities that our kernel avoids, such as expert parallelism (EP), FP8 support, and token dropping. These flexibilities make it more customizable for pretraining, but make it hard to optimize surgically.
Objective
Rather than abstracting out different parts of the MoE operation, we opt to enable our code to do one thing - and do it really well.
Past implementations tend to avoid recomputation and save many intermediate tensors for the backward pass, making them memory expensive and difficult to scale to large-scale training.
Instead of fixing the memory/speed tradeoff in recomputation, ours is the first MoE kernel to give the user full control over it: they can freely choose how much to save or recompute for the backward pass, allowing training of large MoE models even in resource-constrained settings. This flexibility empowers applications across large-scale training, small-scale training, and fast inference.
We present below the journey that led to the creation of MoMoE.
The interested reader will find several of our more granular learnings from the journey in Appendix A.
MoMoE Design
MoMoE was created to address the shortcomings of existing implementations, providing a performant, memory-efficient, and scalable MoE implementation for public use.
MoMoE was optimized for training and fine-tuning workloads (although can also be used during inference), and provides large throughput improvements compared to existing kernels.
Note: for simplicity, we decided to only implement a SwiGLU [16] version of the MoE, but only a few changes are required to make it compatible with other GLU variants.
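For reference, each expert is a SwiGLU MLP of the following form (a minimal PyTorch sketch; W_gate, W_up, and W_down are illustrative names for a single expert's weights):

```python
import torch
import torch.nn.functional as F

def swiglu_mlp(x, W_gate, W_up, W_down):
    # x: (n, D); W_gate, W_up: (D, H); W_down: (H, D)
    return (F.silu(x @ W_gate) * (x @ W_up)) @ W_down
```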
MoE Algorithm
Before diving into the implementation, we will formalize the problem by presenting pseudocode for Mixture of Experts.
Note: we annotate tensor shapes and dimensions where appropriate to aid understanding (versions with fully annotated dimensions are in Appendix G).
Forward pass pseudocode for MoE.
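As a rough Python-style reconstruction of this pseudocode (our notation, not the exact listing from the figure; it reuses the SwiGLU expert sketched earlier):

```python
import torch
import torch.nn.functional as F

def moe_forward(X, W_router, W_gate, W_up, W_down, K):
    # X: (T, D) tokens; W_router: (D, E); W_gate, W_up: (E, D, H); W_down: (E, H, D)
    probs = (X @ W_router).softmax(dim=-1)               # (T, E) routing probabilities
    gates, topk_idx = probs.topk(K, dim=-1)              # (T, K) gate values and expert ids
    gates = gates / gates.sum(dim=-1, keepdim=True)      # optional post-TopK renormalization (see below)
    Y = torch.zeros_like(X)
    for e in range(W_gate.shape[0]):                     # Python loop over experts: the bottleneck
        idx_e, slot_e = (topk_idx == e).nonzero(as_tuple=True)  # idx_e: tokens routed to expert e
        if idx_e.numel() == 0:
            continue
        X_e = X[idx_e]                                   # (M_e, D) gather
        H_e = F.silu(X_e @ W_gate[e]) * (X_e @ W_up[e])  # (M_e, H) SwiGLU hidden state
        Y_e = H_e @ W_down[e]                            # (M_e, D) down projection
        Y.index_add_(0, idx_e, gates[idx_e, slot_e, None] * Y_e)  # weighted scatter back
    return Y
```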
Most PyTorch MoE forward implementations bear strong similarity to the above, since PyTorch has dutifully implemented and optimized most of these individual operations. However, although polished, the direct implementation requires a Python for loop - which is exceedingly slow.
The backward faces the same problem when implemented in PyTorch:
Backward pass pseudocode for MoE.
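A similarly hedged sketch of the per-expert backward (our notation; the gradient through the softmax/renormalization of the routing weights is omitted for brevity):

```python
import torch
import torch.nn.functional as F

def moe_backward(dY, X, W_gate, W_up, W_down, gates, topk_idx):
    # dY: (T, D) gradient of the layer output; other shapes as in the forward sketch
    dX = torch.zeros_like(X)
    dW_gate, dW_up, dW_down = (torch.zeros_like(W) for W in (W_gate, W_up, W_down))
    dGates = torch.zeros_like(gates)                     # (T, K) grad w.r.t. routing weights
    for e in range(W_gate.shape[0]):                     # again a slow Python loop over experts
        idx_e, slot_e = (topk_idx == e).nonzero(as_tuple=True)
        if idx_e.numel() == 0:
            continue
        X_e = X[idx_e]
        A = X_e @ W_gate[e]                              # recompute (or load saved) intermediates
        B = X_e @ W_up[e]
        H_e = F.silu(A) * B
        Y_e = H_e @ W_down[e]
        g = gates[idx_e, slot_e, None]                   # (M_e, 1) routing weights
        dY_e = g * dY[idx_e]                             # grad reaching this expert's output
        dGates[idx_e, slot_e] = (dY[idx_e] * Y_e).sum(-1)
        dH = dY_e @ W_down[e].T
        dW_down[e] = H_e.T @ dY_e
        s = torch.sigmoid(A)
        dA = dH * B * s * (1 + A * (1 - s))              # backward through silu(A)
        dB = dH * F.silu(A)
        dX.index_add_(0, idx_e, dA @ W_gate[e].T + dB @ W_up[e].T)
        dW_gate[e] = X_e.T @ dA
        dW_up[e] = X_e.T @ dB
    return dX, dW_gate, dW_up, dW_down, dGates
```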
Similarly, in the backward pass, a PyTorch implementation is rife with inefficiencies which are unfortunately unavoidable in PyTorch.
We address these shortcomings by repacking the sparse token‑expert layout, and fusing the expert loop into Triton kernels for both the forward and backward pass of MoE.
The Beauty of Inherent Sparsity
In our implementation, we opt for the simplest MoE router, which uses Softmax + Top-K. Mathematically:

$$y_t = \sum_{i \in \mathrm{TopK}(p_t,\,K)} p_{t,i}\, E_i(x_t), \qquad p_t = \mathrm{softmax}(W_r x_t)$$

Here, the final output for a token ($y_t$) is the weighted sum of the outputs from its top-K selected experts ($E_i(x_t)$), where each expert's weight is its gate value ($p_{t,i}$), calculated by applying the softmax function to the router's logits ($W_r x_t$).
Note: for the actual model, we also implemented a post-TopK normalization for the gate values, dividing each gate value by their total sum.
The Top-K router makes the computation inherently sparse: a batch of tokens may interact with all experts but every token interacts with only K of them. Unfortunately, we can’t exploit this sparsity with batched PyTorch BLAS because the chosen experts differ per token.
With the magic of Triton kernels, we can avoid unnecessary operations by embracing online selectivity, while simultaneously batching our operations (i.e., using gather to only apply used experts).
Embracing Selectivity
Converting a padded ragged matrix of shape M_max × N into a compact packed array of length M_sum using a cumulative‑sum offset table to eliminate wasted padding.
In the pseudocode above, we referenced idx_e as the list of token indices in the batch where expert e must be applied. If we pack all such lists (across experts) into one long list of length M_sum = num_tokens × K, we can then construct the cumulative sum of the per-expert counts. Given an expert index, we can now easily find the exact slice of the packed list relevant to our expert by using the cumsum offsets to index into it.
This is strongly preferable to maintaining an assignment tensor (which is enormous and wasteful). Since we can obtain any expert’s token indices easily, parallelism can finally manifest its full strength.
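A minimal PyTorch sketch of this packing (illustrative names; the Triton kernels consume these index tensors directly):

```python
import torch

T, E, K = 4096, 64, 8
topk_idx = torch.randint(0, E, (T, K))                       # expert chosen for each (token, slot)

flat_experts = topk_idx.flatten()                            # (T*K,) expert id per token-expert assignment
order = flat_experts.argsort()                               # pack assignments so each expert's tokens are contiguous
packed_token_ids = order // K                                # token index of each packed assignment
packed_slot_ids = order % K                                  # which of the K slots the assignment came from

counts = torch.zeros(E, dtype=torch.long)
counts.scatter_add_(0, flat_experts, torch.ones_like(flat_experts))   # tokens per expert, no data-dependent shapes
offsets = torch.cat([counts.new_zeros(1), counts.cumsum(0)])          # cumulative-sum offset table

# Expert e's token indices live at packed_token_ids[offsets[e] : offsets[e + 1]]
```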
The Forward Pass
Two‑kernel MoE forward pass.
Subsequently, we can split the MLP computation of each expert into the following kernels:
- Fused Gather + Gate + Up-Projection + SwiGLU
- Fused Down-Projection + Scatter
Triton makes gathering from different indices straightforward; our first kernel fuses the index gather operation with matrix multiplication and applies the SwiGLU activation on-line without materializing up- and gate-projection outputs.
The second linear kernel is close to a vanilla GeMM, differing only by its fused weighted scatter of the per-token expert outputs. We evaluate two methods to perform this scatter:
- Map-Reduce: write each output to a temporary buffer, then perform a batched reduce in a second pass.
- Atomic-Add: write each expert’s output directly to the target scatter location using atomic primitives.
Empirically, atomics edge out map-reduce for FP32 scatter, but map-reduce is faster in BF16. Following existing implementations, we choose to perform the reduction in BF16 using map-reduce.
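As a toy PyTorch-level illustration of the two strategies (token-major packing and random data purely for simplicity; in MoMoE this is fused into the second Triton kernel):

```python
import torch

T, K, D = 1024, 8, 256
expert_out = torch.randn(T * K, D)                     # packed per-assignment expert outputs
gate_vals = torch.rand(T * K, 1)                       # routing weight for each assignment
token_ids = torch.arange(T).repeat_interleave(K)       # token each assignment belongs to (token-major toy layout)
slot_ids = torch.arange(K).repeat(T)                   # which of the K slots it fills

# Atomic-Add style: scatter-add each weighted row straight into the output.
out_atomic = torch.zeros(T, D)
out_atomic.index_add_(0, token_ids, gate_vals * expert_out)

# Map-Reduce style: write into a (T, K, D) buffer, then reduce over K in a second pass.
buf = torch.zeros(T, K, D)
buf[token_ids, slot_ids] = gate_vals * expert_out
out_mapreduce = buf.sum(dim=1)

assert torch.allclose(out_atomic, out_mapreduce, atol=1e-5)
```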
And that’s it for the forward! With a few simple tricks, we were able to batch the experts together and fuse operations, heavily speeding up the process.
The Backward Pass
For the backward, we reuse several ideas, although the implementation becomes more complicated. The first consideration is the choice of what to save from the forward pass vs. what to recompute in the backward pass. ScatterMoE, for example, chooses to save many intermediate tensors for the backward, which makes scaling difficult.
Rather than a one-size-fits-all policy, we prioritize configurability. When training, it is beneficial to save as much as possible, as long as the batch size can still be scaled as desired. When performing inference, however, storing intermediate tensors is unnecessary, so their materialization can be avoided completely. We therefore let the user choose a save_percent between 0 and 100, where 0 means saving as little as possible and 100 means saving everything to avoid recomputation in the backward pass.
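As a purely illustrative usage sketch (the module and argument names below are hypothetical; see the released code for the actual API):

```python
import torch
from momoe import MoMoELayer   # hypothetical import path, for illustration only

# Hypothetical constructor; only save_percent is taken from the text above.
moe = MoMoELayer(d_model=2048, d_hidden=768, num_experts=64, top_k=8,
                 save_percent=50).cuda()      # save ~half the intermediates, recompute the rest

x = torch.randn(8, 4096, 2048, device="cuda", dtype=torch.bfloat16, requires_grad=True)
y = moe(x)                    # forward stores only the configured fraction of activations
y.float().mean().backward()   # backward recomputes whatever was not saved
```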
Five‑kernel MoE backward pass.
Keeping this in mind, we split the expert application (MLP) part of our backward pass into 5 kernels:
- A scatter kernel for the original input and the gradient of the output
- Recomputing H + Down Projection Backward + SwiGLU Backward
- Updating Down Projection weights
- Recomputing Y + Up Projection Backward + Gate Projection Backward + Routing Weightage Backward + Scatters
- Updating Up-Projection and Gate-Projection weights
Note: We observed that one gather + scatter at the beginning worked faster than performing gathers in every future kernel, which is why we added kernel 1.
The guiding principles for the backward kernel follow the forward’s approach: we fuse aggressively yet practically, and utilize a combination of BF16 reductions and FP32 atomics to scatter back into the output tensors (chosen based on empirical performance).
Results
Benchmarking
Micro-correctness
We perform rigorous validation to ensure that our MoE kernel produces correct numerical results.
To begin, we use a PyTorch MoE implementation as a correctness baseline and run standalone correctness tests on the forward and backward outputs. We record the absolute and relative errors over a number of samples after each pass. The resulting errors reveal only speckled noise, with the maximum absolute and relative errors for the output and all gradient tensors meeting our acceptance thresholds. We visualize them in the figures below.
Relative errors for sample inputs with various MoE kernels.
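For concreteness, the per-tensor error metrics we report are of the standard form below (a sketch; the epsilon guard is illustrative):

```python
import torch

def error_stats(ref, test, eps=1e-8):
    # ref: baseline PyTorch output or gradient; test: the corresponding MoMoE tensor
    abs_err = (ref.float() - test.float()).abs()
    rel_err = abs_err / (ref.float().abs() + eps)
    return abs_err.max().item(), rel_err.max().item()
```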
All observed discrepancies fall within expected bounds and stem solely from mixed‑precision accumulation and the ordering of fused reductions. If you are curious about how we justify this claim please consult Appendix D.
Throughput & Memory Gains
We benchmark our MoMoE kernel to evaluate its practical performance, focusing on two key aspects: computational speed and memory usage. For this analysis, we compare our implementation against a standard implementation in PyTorch and ScatterMoE. The benchmarks are conducted across a diverse range of model configurations to ensure a thorough evaluation. The results are shown below:
MoMoE forward+backward pass throughput profiling for various choices of workload size, load balancing distribution, and MoE kernel. Higher is better.
MoMoE forward pass throughput profiling for various choices of workload size, load balancing distribution, and MoE kernel. Higher is better.
First, we assess the computational speed, measured in Tera Floating Point Operations per Second (TFLOPS). As illustrated in the figure above, MoMoE delivers a substantial performance uplift across all tested configurations. The MoMoE variant with 100% saving consistently achieves a higher TFLOP rate than both the Scatter and Torch baselines during the combined forward and backward passes, with the 50% and 0% saving variants not far behind.
MoMoE forward+backward pass peak memory profiling for various choices of workload size, load balancing distribution, and MoE kernel. Lower is better.
In addition to speed, memory efficiency is a critical consideration for training large models. The figure above compares the memory footprint of our MoMoE kernel against the baselines. The results reveal a significant advantage for MoMoE, which consumes considerably less memory across all scenarios, reducing peak memory by more than 10x when full recomputation is applied.
In the comparison above, balanced, random, and skewed refers to the load balancing between the experts. Balanced means all experts get the exact same number of tokens, random means that there is no explicit load balancing enforced, and skewed means 25% of the experts receive 80% of the tokens. In particular, regardless of the token routing distribution, MoMoE offers strong savings on memory consumption.
Efficiency vs Sparsity
A prudent question one may ask is how the performance of a sparse kernel scales as the sparsity is interpolated from 0 (dense MLP, all experts active) to 1 (sparse MLP, one expert active per token). To test this, we fixed an MoE layer size and profiled our kernel over increasing values of K, the number of active experts.
Speed vs K (number of selected experts) for MoMoE versus other kernels. In the highlighted window where almost all current MoE models reside, MoMoE is the fastest.
As can be seen, the runtime of our kernel scales linearly in K, and within the highlighted range of K where almost all current MoE models reside, MoMoE is the fastest implementation. Our implementation beats ScatterMoE everywhere, but loses to Megatron's (NVIDIA's) MoE implementation only at the largest values of K.
We will note however that:
- Almost all current MoE models use a small number of active experts relative to the total, and our kernel was built with high sparsity in mind.
- Megatron's implementation makes use of efficient CUTLASS grouped GEMMs (which are significantly more optimized than Triton matrix multiplication). Since at low sparsity/high density the compute becomes dominated by GEMMs rather than MoE logic, it is reasonable to expect TEGrouped to outperform MoMoE there.
Pre-training
To further validate the practical benefits of MoMoE, we conduct a pre-training experiment comparing it against a standard dense Transformer model. The training configurations for this experiment are detailed in Appendix B.
Loss curves for pretrained models. FLOP-matched MoE models outperform their dense counterparts.
For a fair comparison, we use a FLOP-matched setup: the MoE model has 1.3B total parameters but only 450M active parameters per forward pass, which we compare against a dense model with 450M parameters. Since both models have the same number of active parameters, they perform the same number of floating-point operations per training step.
The results of this comparison are presented in the loss comparison figure above. After the first 1k steps, the MoE model consistently achieves a lower training loss than its dense counterpart for the remainder of the run.
Throughput and memory comparison between saving vs. recomputing. As expected, saving activations increases memory overhead but improves throughput.
Furthermore, we analyze the performance trade-offs within our own kernel implementation, specifically comparing a "Full Saving" strategy (where activations are cached) against a "Full Recomputation" one. As shown in the throughput comparison figure, the "Full Saving" approach yields a higher throughput. This performance gain comes at the cost of increased memory usage (so one has flexibility on which strategy fits their use case better).
It is important to note that the MoE model in this experiment is still considered undertrained. We expect that a longer training run would further amplify the observed benefits of our MoE implementation.
Finetuning
The efficiency afforded by MoMoE enables the training of larger models or the use of larger batch sizes within the same hardware constraints, providing a practical benefit for researchers and practitioners.
We consider the Qwen 3 series, which showcases strong MoE models. For concreteness, we take the smaller member of the family, Qwen3 30B A3B (30B total parameters, 3B active per token), and finetune it on a private dataset with sequence length 4k.
Finetuning speed/memory results for MoMoE on Qwen3 30B A3B.
As an aside, we chose to benchmark Forward and Full (forward + backward) since benchmarking the backward by itself makes little sense (who ever runs just the backward pass?).
MoMoE dramatically improves both memory efficiency and speed relative to the PyTorch baseline. In resource-constrained (e.g., academic) settings, MoMoE allows for finetuning with a much larger token batch size or MoE parameter count than otherwise possible.
Inference
To validate the performance/correctness of our forward kernel at scale, as well as in a setting useful for practitioners, we apply our kernel for inference. The throughput/memory results are shown below in the single GPU setting without a specialized serving engine.
Inference speed as a function of sequence length for MoMoE and PyTorch baseline.
For smaller context lengths, the benefit from MoMoE is much stronger (due to higher effective sparsity). Asymptotically, MoMoE and the PyTorch baseline approach the same performance, as do all MoE kernels once they become GEMM-bound. For most context regimes practically used during evaluation, MoMoE exhibits strong gains.
We then compare the loss on 2048-length samples of UltraFineweb [14] and performance on various LM evaluation harness tasks between the Huggingface transformers implementation and a version using our MoE kernel.
Distribution of delta cross-entropy between MoMoE and ground truth implementation for samples of UltraFineWeb and Qwen3 30B A3B.
For the vast majority of samples, MoMoE gets exact correctness on loss. Even where there is a difference, it is very small, and completely explainable by the noise described in Appendix D.
Below we visualize the MMLU scores, including subtask scores. Specifically, we first evaluate the pretrained Qwen3-30B-A3B model across MMLU’s 57 different subtasks under zero-shot settings, obtaining a weighted average of 77.85%. Then, we replace the original MoE layers with MoMoE and run the same evaluation again, receiving a weighted average of 77.92%.
MMLU performance by subtask between ground truth and MoMoE. The leftmost column is the weighted average (overall MMLU performance).
Intra-task score variation is caused by precision noise as described in Appendix D. The total performance on MMLU is unchanged, with reductions in wall-clock time thanks to MoMoE.
Conclusion
Mixture-of-Experts offers a compelling promise: intelligence at scale without proportional cost. But to deliver on that promise, we need efficient implementations.
By embracing simplicity, fusing where it counts, and giving users fine-grained control over memory and compute (for the first time!), MoMoE delivers the speed, efficiency, and correctness necessary for modern large-scale training and inference.
In particular, we seek to enable academics, private labs, and independent researchers alike to train, finetune, and run frontier-grade MoE models. We're excited to share MoMoE with the community, and to ease adoption we plan to submit PRs to a few training/inference codebases shortly, as availability allows.
Good science should be cheap and accessible!
You can find our code here: https://github.com/tilde-research/MoMoE-impl
Please cite as:
@article{costin2025momoe,
title={MoMoE: Memory-optimized Mixture of Experts},
author={Costin, Bobby and Averbuch, Timor and Pai, Dhruv and Chen, Nathan and Keigwin, Ben},
journal={Tilde Research Blog},
year={2025},
month={7},
url={https://www.tilderesearch.com/blog/momoe},
note={Blog post}
}
Acknowledgements
Our implementation is written to fit inside the Flash‑Linear‑Attention project, and we thank the contributors of FLA for their work. We also extend our gratitude to the developers of TorchTitan for providing a platform for LLM pre‑training.
For testing, we thank the Qwen team for the open-sourced MoE models tested. We also wish to thank the Megatron LM, ScatterMoE, and MegaBlocks authors for pioneering the landscape of open-source MoE kernels.
We would like to thank Quinn McIntyre and members of the technical staff for feedback on the post.
References
- Kaplan, Jared and McCandlish, Sam and Henighan, Tom and Brown, Tom B. and Chess, Benjamin and Child, Rewon and Gray, Scott and Radford, Alec and Wu, Jeffrey and Amodei, Dario (2020).
- Hoffmann, Jordan and Borgeaud, Sebastian and Mensch, Arthur and Buchatskaya, Elena and Cai, Trevor and Rutherford, Eliza and de Las Casas, Diego and Hendricks, Lisa Anne and Welbl, Johannes and Clark, Aidan and more (2022).
- Shazeer, Noam and Mirhoseini, Azalia and Maziarz, Krzysztof and Davis, Andy and Le, Quoc and Hinton, Geoffrey and Dean, Jeff (2017).
- DeepSeek-AI and Liu, Aixin and Feng, Bei and Xue, Bing and Wang, Bingxuan and Wu, Bochao and Lu, Chengda and Zhao, Chenggang and Deng, Chengqi and Zhang, Chenyu and more (2025).
- Yang, An and Li, Anfeng and Yang, Baosong and Zhang, Beichen and Hui, Binyuan and Zheng, Bo and Yu, Bowen and Gao, Chang and Huang, Chengen and Lv, Chenxu and more (2025).
- Fedus, William and Zoph, Barret and Shazeer, Noam (2022).
- Dai, Damai and Deng, Chengqi and Zhao, Chenggang and Xu, R. X. and Gao, Huazuo and Chen, Deli and Li, Jiashi and Zeng, Wangding and Yu, Xingkai and Wu, Y. and more (2024).
- Tan, Shawn and Shen, Yikang and Panda, Rameswar and Courville, Aaron (2024).
- Huggingface Transformers.
- Mohammad Shoeybi and Mostofa Patwary and Raul Puri and Patrick LeGresley and Jared Casper and Bryan Catanzaro.
- NVIDIA.
- Wang, Yudong and Fu, Zixuan and Cai, Jie and Tang, Peijun and Lyu, Hongya and Fang, Yewei and Zheng, Zhi and Zhou, Jie and Zeng, Guoyang and Xiao, Chaojun and more (2025).
- Wang, Lean and Gao, Huazuo and Zhao, Chenggang and Sun, Xu and Dai, Damai.
- Shazeer, Noam (2020).
Appendix
A. Learnings
Building MoMoE was a useful and instructive learning experience. We present below some of our key takeaways from the project.
- Simplicity is very often the answer.
For example, we experimented with launching exactly the number of blocks needed for each kernel launch. We hypothesized that launching blocks under the worst-case assumption that every expert receives the maximum number of tokens would be suboptimal, since we would have to return early out of the unnecessary blocks.
The first idea required some computation at the start of every block to figure out which expert we were handling, and the exact block of tokens for that expert - which should still have been a net gain by avoiding empty launches.
This reasoning turned out to be unfounded: the cost of these computations was higher than the cost of the empty blocks. In retrospect, the cost of an empty block is close to zero, since it terminates almost immediately after execution begins.
As another lesson in simplicity, we spent some time investigating fusing the Softmax + Top-K operations into a kernel. However, the cost of the routing computation is absurdly low compared to the MoE MLPs, and we ultimately opted for running the Softmax + Top-K as plain PyTorch operations for simplicity.
- Even in Triton, the .to operation is costly, and should not be overused.
We assumed that type conversions would be extremely cheap compared to the multiplications in a GEMM - which turned out to be another unfounded claim. Type conversion slowed our kernels down by a whopping factor of ~2x! We removed all repeated .to calls from the GEMM kernels and created the scatter kernel in the backward to ensure proper typing.
- Atomics outperformed reduction in FP32, most likely due to the sparsity of the experts.
We thought that store-all + reduce would work faster than atomics, even with FP32, but choosing simplicity and letting the hardware handle the reduction for us actually ends up being quicker. Atomics win even when compared to an optimized store-all kernel that efficiently keeps track of the outputs of the k experts chosen for every token, followed by a PyTorch .sum reduction.
The one way to make reductions faster is to effectively reduce their accuracy by storing intermediates in BF16 and then reducing there. We cannot compare this to an equivalent atomics implementation, since BF16 is not supported for the CUDA atomic add operation.
- Avoid CPU-blocking operations unless entirely necessary.
torch.nonzero is a very convenient operation for mixture of experts, as a means to find the indices of the tokens each expert should be applied to. Problematically, it blocks the CPU from queueing more jobs until it completes: the shape of nonzero's output is unknown until the kernel finishes. However, with a top-k router, we know the exact number of total token-expert assignments ahead of time (num_tokens × K), and we can exploit this to make our kernel much faster.
In general, it is paramount to ensure CPU operations are non-blocking unless absolutely necessary, which enables strong GPU utilization. This is a further shortcoming of ScatterMoE, since its use of aten::bincount blocks the CPU.
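A small illustration of the difference (illustrative sizes; the real kernel consumes the packed indices described in the main text):

```python
import torch
import torch.nn.functional as F

E, T, K = 64, 4096, 8
flat_experts = torch.randint(0, E, (T * K,), device="cuda")

# Blocking: the output shape of nonzero() is data-dependent, so the host must wait for the GPU.
idx_e = (flat_experts == 3).nonzero(as_tuple=True)[0]

# Non-blocking: the total number of assignments is known up front (T * K), so an argsort plus
# on-GPU counts and offsets gives every expert's index range without a host synchronization.
order = flat_experts.argsort()
counts = F.one_hot(flat_experts, num_classes=E).sum(dim=0)
offsets = counts.cumsum(dim=0)
```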
- Always be mindful of tensor memory layouts.
For the reduction across chosen experts per token in the second linear layer, we need to reduce over the K (number of selected experts) dimension. So we have T = batch size × sequence length groups of down-projection outputs, each of size D (the embedding dimension). Originally, we stored this as a (T, D, K) tensor, since we believed the reduction would be very fast if it was over the last dimension. While this is true, this layout completely tanked the runtime of our lin2_kernel, which computes the down projection.
It turns out this is caused by the fact that the blocks we were writing to were not at all contiguous in memory anymore: a 2D block in the T and D dimensions would be scattered all over the place, split up by the final K dimension. Instead, it is much smoother to lay the output out as (T, K, D), so that the D dimension is contiguous in memory. This sped up our kernel by almost a whole order of magnitude!
This is not some insane never-before-understood trick, but rather just a notice to always be wary of memory access patterns, and a warning about the dangers of non-contiguous memory accesses in GPU kernels.
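A tiny illustration of the two layouts discussed above (illustrative sizes; the comments describe the write/reduce pattern each implies):

```python
import torch

T, K, D = 4096, 8, 1024

# K innermost: the reduction over K runs over contiguous memory, but a (tokens x dims)
# tile written by the down-projection kernel is strided by K between consecutive dims.
out_tdk = torch.empty(T, D, K)
y1 = out_tdk.sum(dim=-1)

# D innermost: each (tokens x dims) tile is written contiguously along D,
# and the reduction over K remains cheap.
out_tkd = torch.empty(T, K, D)
y2 = out_tkd.sum(dim=1)
```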
- For loops in Triton have a very specific nature, and simplicity once more works like a charm.
In general, we noticed that the best order of operations to choose in Triton for loops is the one that looks nicest. This sounds stupid, but let us explain. Take for example the fused SwiGLU matrix multiplications for the up and gate projections in our lin1_kernel:

acc_a_bMibH = tl.zeros((BLOCK_SIZE_Mi, BLOCK_SIZE_H), dtype=tl.float32)
acc_b_bMibH = tl.zeros((BLOCK_SIZE_Mi, BLOCK_SIZE_H), dtype=tl.float32)
for d in range(0, D, BLOCK_SIZE_D):
    x_bMibD = tl.load(x_ptrs_bMibD, mask=offset_D[None, :] < D - d, other=0.0)
    Wl1_a_bDbH = tl.load(Wl1_ptrs_a_bDbH, mask=offset_D[:, None] < D - d, other=0.0, cache_modifier=".cg")
    Wl1_b_bDbH = tl.load(Wl1_ptrs_b_bDbH, mask=offset_D[:, None] < D - d, other=0.0, cache_modifier=".cg")
    acc_a_bMibH = tl.dot(x_bMibD, Wl1_a_bDbH, acc_a_bMibH)
    acc_b_bMibH = tl.dot(x_bMibD, Wl1_b_bDbH, acc_b_bMibH)
    x_ptrs_bMibD += BLOCK_SIZE_D * stride_xD
    Wl1_ptrs_a_bDbH += BLOCK_SIZE_D * stride_WD
    Wl1_ptrs_b_bDbH += BLOCK_SIZE_D * stride_WD

This ended up being by far the most performant loop order, probably sped up by Triton's automatic staging (i.e., for this kernel our num_stages is generally 4 or 5). If you are unfamiliar with Triton loop staging, it is essentially software pipelining of for loops, where multiple iterations are in flight 'at once'. For example, we can load into shared memory for iteration 0, then load for iteration 1 while doing the computation for iteration 0, then load for iteration 2 while computing iteration 1 and storing iteration 0, and so on. We speculate that this layout, where the loading, computation, and pointer incrementation are clearly separated, benefits the most from loop staging, since Triton can easily optimize the process.
To be clear, the other methods we tried are:
- Having two for loops, one for the up projection and one for the gate projection.
- Having only one shared memory block for the weights, Wl1_bDbH, and first doing the up-projection loading + computation, followed by the gate-projection loading + computation.
Even though the first method uses extremely simple for loops (likely to be highly optimizable by Triton), and the second method uses less shared memory, both of these were significantly slower in practice from what we observed.
B. Training Details
We follow the standard recipe from the literature. Our models are MHA transformers with RoPE, PeriLN, AdamW, weight decay 0.1, etc. They are trained on 1 node over 15B tokens of the Ultra Fineweb corpus released by OpenBMB.
C. Load Balancing
Expert collapse in MoE models is a phenomenon whereby experts become underutilized during training. Expert collapse is problematic as it results in inefficient hardware utilization and poor learning dynamics. To mitigate it, various papers have proposed different load-balancing mechanisms for MoEs. For example, Megatron's TEGrouped uses an auxiliary load-balancing (auxk) loss.
We instead opt for an alternative load-balancing strategy: aux-free bias. Originally pioneered by DeepSeek AI, this strategy replaces an explicit load balancing loss with a separate set of per-expert routing biases which act as discretely-optimized controllers, tuned based on expert imbalance at each iteration [15]. This approach is much faster than auxk loss and improves training dynamics, with better loss and load balancing.
The kernel we release supports fast computation of the controller updates by default, though it is easy to modify to support any kind of routing and load-balancing formulation. To stress: the choice of load balancing, and of the router more generally, is freely customizable in MoMoE.
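For intuition, a rough sketch of an aux-free bias controller in the spirit of [15] is shown below; the function names, the sign-based update rule, and the step size gamma are illustrative, and the released controller may differ in detail:

```python
import torch

def route_with_aux_free_bias(probs, bias, k):
    # probs: (T, E) softmax routing probabilities; bias: (E,) per-expert controller bias.
    # The bias only influences which experts are selected, not the gate values themselves.
    _, expert_ids = (probs + bias).topk(k, dim=-1)
    gates = probs.gather(-1, expert_ids)
    return gates / gates.sum(dim=-1, keepdim=True), expert_ids

def update_bias(bias, tokens_per_expert, gamma=1e-3):
    # Nudge underloaded experts up and overloaded experts down after each iteration.
    load = tokens_per_expert.float()
    return bias + gamma * torch.sign(load.mean() - load)
```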
D: Why do we expect error?
To be precise, how can we justify relative differences as large as a couple percentage points? Rather than ignoring these entirely, let’s take a deeper dive into why they make sense and are nothing to fear.
The first hint we have that these errors are explainable by precision differences is that everywhere we see errors, they appear to be somewhat normally distributed.
That is a good start, but far from concrete evidence. Empirically, we further find that precision differences are normally distributed around zero with non-negligible standard deviation. Here is a histogram of the relative errors (in %) of summing only 10 FP32 numbers generated with torch.randn, with the output in BF16, computed in two ways:
- Do the summation in FP32, cast afterward
- Cast at the beginning, do summation in BF16
Histogram of relative differences between cast+sum strategies for random inputs.
Code
import torch
import matplotlib.pyplot as plt

errors = []
bases = []
for i in range(10000):
    N = 10
    a = torch.randn((N,), device="cpu", dtype=torch.float32)
    bases.append((a.sum().to(torch.bfloat16)).to(torch.float32))
    errors.append((a.sum().to(torch.bfloat16) - a.to(torch.bfloat16).sum()).to(torch.float32))

print(torch.abs(torch.tensor(errors)).mean())
print(torch.abs(torch.tensor(errors)).std())

# check if the error is normally distributed (only plot the values that are between -3 and 3,
# deleting the outliers for ease of visualization)
rel_errs_clamped = torch.clamp(torch.tensor(errors) / (torch.abs(torch.tensor(bases)) + 1e-6) * 100, min=-3, max=3)
rel_errs_new = rel_errs_clamped[torch.abs(rel_errs_clamped) < 3]
plt.hist(rel_errs_new, bins=50)

# save plot
plt.savefig("error_distribution.png")
We can see that the error created by adding a few numbers in bfloat16 is nontrivial. For SwiGLU, here are the relative errors we get on two numbers generated with torch.randn (in %), again by doing the operations on the casted numbers vs. casting after the operations:
Histogram of relative differences between cast+sum strategies for SwiGLU.
Code
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

bases = []
errors = []
for i in range(10000):
    a = torch.randn((1,), device="cpu", dtype=torch.float32)
    b = torch.randn((1,), device="cpu", dtype=torch.float32)
    swiglu_fp32 = a * F.silu(b)
    swiglu_bf16 = a.to(torch.bfloat16) * F.silu(b.to(torch.bfloat16).to(torch.float32))
    bases.append(swiglu_fp32)
    errors.append(swiglu_fp32 - swiglu_bf16)

print((torch.abs(torch.tensor(errors)) / (torch.abs(torch.tensor(bases)) + 1e-10)).mean())
print((torch.abs(torch.tensor(errors)) / (torch.abs(torch.tensor(bases)) + 1e-10)).std())

# check if the error is normally distributed
plt.hist(torch.tensor(errors) / (torch.abs(torch.tensor(bases)) + 1e-10) * 100, bins=50)

# save plot
plt.savefig("error_distribution.png")
Both of these operations (which boil down to multiplication and addition) are heavily used in MoMoE. To make their respective effects worse, recall that the sum of two independent Gaussian random variables is another Gaussian with mean $\mu = \mu_1 + \mu_2$ and variance $\sigma^2 = \sigma_1^2 + \sigma_2^2$. This implies that our precision differences will have a mean of 0 but a variance that grows linearly in the number of operations we perform!
Our errors are explained by the differences in the precision of calculations, and we see the errors distributed similarly (though our calculation is technically the more precise one!).
E. Correctness Heatmaps
We visualize here heatmaps of the differences between our kernel results and a PyTorch baseline for a sample input. As shown, the errors are small and, most importantly, random and isotropic without problematic structure.
Heatmaps of absolute and relative differences between MoMoE and ground truth for a sample tensor.
F. Related Implementations
vLLM - vLLM has a kernel implementation for mixture of experts. We do not directly compare to vLLM as it does not support a backward pass through MoE, making it unsuitable for training. However, we note that for inference the kernel suffers from the same issues that other Triton implementations face, stemming from over-generalizations meant to allow for more configuration flexibility but ultimately leading to weaker performance.
Liger/Unsloth - Liger/Unsloth does not have any specialized MoE kernel support. It provides a fused SwiGLU implementation, but the SwiGLU operation is not the computational bottleneck in Mixture of Experts implementations. As such, it performs only marginally better than PyTorch baselines.
DeepEP - DeepSeek's communication library is compatible with MoMoE. We do not handle EP or any communication primitives; our focus is solely on optimizing the MoE computation itself.
G. Fully Dimension Annotated MoE Algorithms
We believe that fully showing tensor dimensions at every step of computation is crucial to catching mistakes early. We apply these practices as much as possible in our code and derivations. Below, we include the forward and backward passes of MoE, annotated with shapes.