Writing high-performance matrix multiplication kernels for Blackwell

Summary

In this guide, we'll progressively iterate on a matrix multiplication kernel. The first implementation will be very simple, but also quite slow. However, in just a few simple steps it can be turned into a state-of-the-art kernel that matches or exceeds highly optimized implementations such as cuBLAS and CUTLASS.

Warning: The utilization shown in the table below might differ from what you see online, but the differences can likely be explained by a different input data distribution. All our benchmarks here use arrays with iid normal float16 entries, which turn out to be one of the slower distributions you can choose. You can reproduce the numbers yourself by running our test file after setting the BENCHMARK variable to True. tl;dr: don't believe matmul benchmarks that don't specify the input data distribution.

Implementation                      TensorCore utilization   % of cuBLAS utilization
0. Basic kernel                     37.62%                   59.4%
1. Warp specialization              45.47%                   71.7%
2. Tiled epilogue                   55.82%                   88.1%
3. Collective (2CTA) MMA            59.41%                   93.7%
4. Persistent kernel                61.46%                   97.0%
5. Dedicated epilogue warpgroup     63.38%                   100.0%
6. Grid tiling                      69.44%                   109.6%
cuBLAS                              63.38%                   100.0%
CUTLASS                             69.30%                   109.3%

The cuBLAS baseline is obtained by measuring the performance of jax.dot. The CUTLASS performance is the best result (excluding sparse matmuls) from the following cutlass_profiler invocation:

    cutlass_profiler --dist=gaussian,mean:0,stddev:1,scale:-1 --output=results.csv --accumulator-type=f32 --m=4096 --k=4096 --n=8192 --kernels='*sm100*' --A=f16 --B=f16 --C=void --D=f16

At each step, we will show either the full implementation of the kernel or the difference between the code listings of the previous and current steps. Full implementations can be found in our test file. You can also find a full standalone optimized kernel implementation in the Pallas ops package.

0. Basic kernel

We begin with a simple single-CTA...
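In the utilization table above, the "% of cuBLAS utilization" column appears to be each row's TensorCore utilization divided by the cuBLAS figure of 63.38%. This small stdlib-Python sketch recomputes that column from the table's numbers as a sanity check. The dictionary is only a transcription of the table, and the helper name is illustrative.

```python
# TensorCore utilization (%) transcribed from the table above.
UTILIZATION = {
    "0. Basic kernel": 37.62,
    "1. Warp specialization": 45.47,
    "2. Tiled epilogue": 55.82,
    "3. Collective (2CTA) MMA": 59.41,
    "4. Persistent kernel": 61.46,
    "5. Dedicated epilogue warpgroup": 63.38,
    "6. Grid tiling": 69.44,
    "CUTLASS": 69.30,
}

# The cuBLAS row is the 100% reference point.
CUBLAS_UTILIZATION = 63.38


def relative_to_cublas(utilization: float, baseline: float = CUBLAS_UTILIZATION) -> float:
    """Express a utilization as a percentage of the cuBLAS baseline, rounded to 0.1%."""
    return round(100.0 * utilization / baseline, 1)


if __name__ == "__main__":
    for name, util in UTILIZATION.items():
        print(f"{name:<34} {util:6.2f}%  {relative_to_cublas(util):6.1f}%")
```

Running it prints the same relative percentages as the table, for example 59.4% for the basic kernel and 109.6% for grid tiling.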
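You can turn a measured runtime into a utilization figure as follows. One plausible definition, which is an assumption here because this section does not spell it out, is achieved FLOP/s divided by the GPU's peak dense float16 TensorCore FLOP/s. Multiplying an M×K matrix by a K×N matrix performs 2·M·N·K floating-point operations (one multiply and one add per term). The shape in the cutlass_profiler command above (M=4096, K=4096, N=8192) therefore costs 2^38 FLOPs. The runtime and peak throughput are parameters you would supply yourself. No Blackwell-specific peak value is assumed.

```python
def matmul_flops(m: int, n: int, k: int) -> int:
    """FLOPs for C[m, n] = A[m, k] @ B[k, n]: one multiply and one add per term."""
    return 2 * m * n * k


def tensorcore_utilization(m: int, n: int, k: int, seconds: float, peak_flops_per_s: float) -> float:
    """Achieved FLOP/s as a fraction of a user-supplied peak FLOP/s (assumed definition)."""
    achieved_flops_per_s = matmul_flops(m, n, k) / seconds
    return achieved_flops_per_s / peak_flops_per_s


# Shape from the profiler invocation: --m=4096 --k=4096 --n=8192.
flops = matmul_flops(4096, 8192, 4096)
print(flops)  # 274877906944 (= 2**38)
```

Under this definition, a kernel at 63.38% utilization on a GPU with peak P FLOP/s would take about 2^38 / (0.6338 · P) seconds for this shape.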
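To approximate the benchmark setup in plain JAX, here is a minimal sketch. It assumes that `jax.numpy.dot` is a reasonable stand-in for the `jax.dot` baseline mentioned above. The sketch draws iid standard-normal float16 inputs, accumulates in float32, and writes a float16 result to mirror the `--accumulator-type=f32` and `--D=f16` profiler flags. The helper names (`make_inputs`, `matmul_f16`, `time_matmul`) are hypothetical and do not come from the guide's test file. The timing loop is a rough wall-clock measurement, not the harness used to produce the table.

```python
import time

import jax
import jax.numpy as jnp


def make_inputs(m: int, k: int, n: int, seed: int = 0):
    """Return A[m, k] and B[k, n] with iid standard-normal float16 entries."""
    key_a, key_b = jax.random.split(jax.random.PRNGKey(seed))
    a = jax.random.normal(key_a, (m, k), dtype=jnp.float16)
    b = jax.random.normal(key_b, (k, n), dtype=jnp.float16)
    return a, b


@jax.jit
def matmul_f16(a, b):
    # Accumulate in float32, then store the result as float16
    # (mirrors --accumulator-type=f32 and --D=f16 in the profiler command).
    return jnp.dot(a, b, preferred_element_type=jnp.float32).astype(jnp.float16)


def time_matmul(a, b, iters: int = 10) -> float:
    """Average wall-clock seconds per call, excluding compilation."""
    matmul_f16(a, b).block_until_ready()  # compile and warm up
    start = time.perf_counter()
    for _ in range(iters):
        out = matmul_f16(a, b)
    # JAX dispatch is asynchronous, so block before stopping the clock.
    out.block_until_ready()
    return (time.perf_counter() - start) / iters
```

For example, `time_matmul(*make_inputs(4096, 4096, 8192))` times the profiled shape on whichever device JAX selects by default. Dividing `matmul_flops(4096, 8192, 4096)` by that time gives achieved FLOP/s.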

First seen: 2025-10-06 21:07

Last seen: 2025-10-07 02:08