NeuronMM: High-Performance Matrix Multiplication for LLM Inference on AWS Trainium
Trainium, as a typical systolic-array architecture, features a programmable memory hierarchy (including two types of on-chip SRAM and off-chip HBM). It also provides a rich set of specialized compute engines tailored for various AI operators. Such hardware heterogeneity gives programmers substantial flexibility to explore for higher performance. However, leveraging the Trainium architecture for high performance can be challenging, because of frequent data movement and the necessity of aligning a tensor's logical shape with Trainium's physical memory layout.
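To make this programming model concrete, the sketch below follows the pattern of AWS's publicly documented NKI matmul tutorial (it is not NeuronMM's kernel): data is explicitly staged from HBM into SBUF tiles whose leading dimension must map onto the 128 hardware partitions, the Tensor Engine accumulates into PSUM, and results are copied back to SBUF before being stored to HBM. API names and signatures follow the NKI documentation as we understand it.

```python
from neuronxcc import nki
import neuronxcc.nki.language as nl


@nki.jit
def tiny_matmul_kernel(lhsT, rhs):
    # Sketch modeled on the NKI matmul tutorial, not the NeuronMM kernel.
    # lhsT: [K=128, M<=128], the stationary operand stored transposed;
    # rhs:  [K=128, N<=512]. Both fit within a single Tensor Engine tile.
    result = nl.ndarray((lhsT.shape[1], rhs.shape[1]), dtype=lhsT.dtype,
                        buffer=nl.shared_hbm)

    # Explicit HBM -> SBUF staging. The first (partition) dimension of each
    # tile maps onto the 128 SBUF partitions, which is why a tensor's
    # logical shape must be arranged to match the physical memory layout.
    lhsT_tile = nl.load(lhsT)
    rhs_tile = nl.load(rhs)

    # The Tensor Engine consumes the transposed stationary operand and
    # accumulates the product in PSUM.
    result_psum = nl.matmul(lhsT_tile, rhs_tile, transpose_x=True)

    # PSUM -> SBUF copy, then SBUF -> HBM store.
    result_sbuf = nl.copy(result_psum, dtype=result.dtype)
    nl.store(result, value=result_sbuf)
    return result
```

Larger problem sizes require tiling over all three matmul dimensions and repeatedly moving tiles through this hierarchy, which is exactly where reducing data movement and avoiding transposes pays off.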
At Yotta Labs, enabling high-performance AI workloads across heterogeneous hardware is a core mission. Our research team, led by Chief Scientist Dong Li, has developed NeuronMM, a high-performance matrix multiplication (matmul) kernel that dramatically accelerates LLM inference on AWS Trainium. NeuronMM is open source and marks a key milestone for the Trainium ecosystem. It introduces a series of techniques customized to Trainium that reduce data movement across the software-managed memory hierarchy, maximize the utilization of SRAM and the compute engines, and avoid expensive matrix transposes.
Evaluation
Implementations. NeuronMM is developed on top of the NeuronX Distributed Inference library, and the MLP kernels are implemented with AWS NKI. We evaluate NeuronMM on a trn1.2xlarge instance of Amazon Elastic Compute Cloud (Amazon EC2) equipped with AWS Trainium accelerators, running the Deep Learning AMI Neuron (Ubuntu 22.04). We use a single NeuronCore with 16 GB of HBM on a Trainium chip for evaluation.
Datasets. We test LLMs that fit entirely into HBM and are currently supported by the NeuronX Distributed Inference library: Llama-3.2-1B, Llama-3.2-3B, Qwen3-1.7B, and Qwen3-4B. We evaluate NeuronMM on nine datasets, covering three language-modeling datasets (WikiText-2, PTB, and C4) and six common-sense reasoning datasets (OpenBookQA [38], WinoGrande [44], PIQA [9], HellaSwag [59], ARC-e, and ARC-c [11]). For fine-tuning with LoRA, we use the yahma/alpaca-cleaned dataset.
Results. Figure 1 shows the execution time and memory traffic of different matmul implementations across input sequence lengths. In the figure, "NKI-XW" computes the standard matmul without SVD (the compression technique used in NeuronMM), while "NKI-XUV" executes the matmuls using the low-rank factors U and V derived from the SVD of W, but without the TrainiumFusion optimization (a technique introduced by NeuronMM). We evaluate our kernel with matrices X ∈ R^(M×8192), W ∈ R^(8192×16384), U ∈ R^(8192×4096), and V ∈ R^(4096×16384), where U and V are the low-rank factors derived from the SVD of W. We vary the first dimension M of X from 1024 to 32768 to simulate different sequence lengths.
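For readers unfamiliar with the XUV path, the plain PyTorch sketch below (with toy shapes; the experiment above uses an 8192×16384 weight and rank 4096) shows how W can be replaced by rank-r factors U and V obtained from a truncated SVD, turning one large matmul into two smaller ones. It illustrates the compression idea only and is not NeuronMM's Trainium implementation.

```python
import torch

# Toy shapes for a quick run; the experiment above uses M x 8192 inputs,
# an 8192 x 16384 weight matrix, and rank r = 4096.
M, K, N, r = 64, 512, 1024, 128

X = torch.randn(M, K)
W = torch.randn(K, N)

# Truncated SVD: W ~= U_r diag(S_r) Vh_r. Folding the singular values into
# the left factor gives U (K x r) and V (r x N).
U_full, S, Vh = torch.linalg.svd(W, full_matrices=False)
U = U_full[:, :r] * S[:r]      # (K, r)
V = Vh[:r, :]                  # (r, N)

# Baseline "XW" path vs. low-rank "XUV" path (two smaller matmuls).
Y_xw = X @ W
Y_xuv = (X @ U) @ V

# For a random W the truncation error is large; real LLM weight matrices
# have faster-decaying spectra, which is what makes SVD compression viable.
print((Y_xw - Y_xuv).norm() / Y_xw.norm())
```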

Figure 1. Execution time and HBM-SBUF memory traffic of different matmul implementations across input sequence lengths.
Compared to the NKI-XW baseline, NeuronMM delivers an average 2.09× speedup, reaching 2.22× (84.15 ms vs. 186.60 ms) at a sequence length of 32K, driven by a 4.78× reduction in HBM-SBUF memory traffic. NeuronMM also outperforms the NKI-XUV baseline, achieving a 1.35× speedup with over 2.6× less memory traffic on average.
Table 1 demonstrates the generality of NeuronMM across four LLMs and nine datasets. We report mean accuracy (mAcc) and average end-to-end speedup. NeuronMM achieves significant end-to-end inference speedups (1.21×–2.49×), while γ (a metric that assesses the trade-off between inference speedup and accuracy degradation) remains low, ranging from 3.24% to 25.27% with most values below 10%, indicating a favorable trade-off between speedup and accuracy. For example, on Qwen3-1.7B, NeuronMM enables 1.74× faster inference with only a 0.03 drop in accuracy compared to standard LLM inference.
Table 1: Evaluation of NeuronMM across four LLMs and nine datasets under compression ratios of 0.1 and 0.2.
