Porting Mamba2 to ROCm
Mamba2 introduced the SSD (state space duality) framework, which unified structured state space models with linear attention under a single mathematical lens. The original implementation targets CUDA via Triton. Getting it onto AMD GPUs — ROCm’s HIP ecosystem — is mostly mechanical, but there are a handful of places where the abstraction leaks and you need to understand what’s actually happening at the hardware level.
This post covers what I learned porting the core SSD kernels to ROCm for an internal benchmark. The goal was functional correctness first, performance parity second.
What Mamba2 actually does
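The SSD layer is a sequence-to-sequence transformation parameterized by four tensors: A (decay), B (input projection), C (output projection), and X (input). In Mamba2, A_t is a scalar per head (a scalar times the identity). That restriction is what makes the dual form possible. The recurrence is: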
h_t = A_t * h_{t-1} + B_t * x_t
y_t = C_t @ h_t
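Here is a minimal NumPy sketch of the recurrence for a single head, to serve as a concrete reference. It assumes Mamba2's per-step scalar decay, a state size N and a head dimension P. Under those assumptions h_t is an N × P matrix and B_t x_t is an outer product. The function name and argument layout are my own, not the mamba_ssm API.

```python
import numpy as np

def ssd_recurrent(a, B, C, x):
    """Naive sequential SSD scan for one head.

    a: (T,)   per-step scalar decay A_t
    B: (T, N) input projection B_t
    C: (T, N) output projection C_t
    x: (T, P) input x_t
    returns y: (T, P)
    """
    T, N = B.shape
    P = x.shape[1]
    h = np.zeros((N, P), dtype=np.result_type(a, B, C, x))
    y = np.empty((T, P), dtype=h.dtype)
    for t in range(T):
        # h_t = A_t * h_{t-1} + B_t x_t^T  (outer product into the N x P state)
        h = a[t] * h + np.outer(B[t], x[t])
        # y_t = C_t^T h_t
        y[t] = C[t] @ h
    return y
```

For example, take a = [0.5, 0.5] with all-ones B, C and x of width 1. The state goes from 1 to 1.5, so y is [[1.0], [1.5]].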
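In matrix form, unrolling the recurrence gives y = M x with M = L ∘ (C Bᵀ). Here L is the lower-triangular decay mask, L[i, j] = A_{j+1} ⋯ A_i for i ≥ j (1 on the diagonal, 0 above it). This structured masked matrix multiply has the same shape as masked linear attention, and it is the “dual” the paper refers to. Computing it efficiently requires a tiled, chunked algorithm that never materializes the full seq_len × seq_len matrix.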
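The duality is easy to check numerically. The sketch below covers a single head with per-step scalar decay a of shape (T,), B and C of shape (T, N), and x of shape (T, P); all names are mine. It builds the T × T masked matrix explicitly, which is exactly the object the chunked kernel avoids materializing, and compares the result with a plain sequential scan. It assumes a_t > 0, as with Mamba2's exp(Δ·A) parameterization with A < 0, so the decay products can be formed from cumulative sums of logs.

```python
import numpy as np

def ssd_dual(a, B, C, x):
    """Quadratic form of SSD for one head: y = (L * (C @ B.T)) @ x."""
    T = a.shape[0]
    s = np.cumsum(np.log(a))                     # s[i] = log(a_0 * ... * a_i)
    diff = s[:, None] - s[None, :]               # log(a_{j+1} * ... * a_i) when i >= j
    causal = np.tril(np.ones((T, T), dtype=bool))
    L = np.exp(np.where(causal, diff, -np.inf))  # exp(-inf) = 0 above the diagonal
    M = L * (C @ B.T)                            # M[i, j] = L[i, j] * <C_i, B_j>
    return M @ x

def ssd_recurrent(a, B, C, x):
    """Sequential scan, included only to check the dual form against."""
    h = np.zeros((B.shape[1], x.shape[1]))
    y = np.empty_like(x, dtype=float)
    for t in range(len(a)):
        h = a[t] * h + np.outer(B[t], x[t])
        y[t] = C[t] @ h
    return y
```

Both memory and compute are O(T²) in this form, versus O(T) for the scan. That is why the quadratic form only pays off inside small chunks.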
The Triton kernel for this is roughly 400 lines implementing the chunked SSD computation with careful management of SRAM vs. HBM access patterns.
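The chunked idea can be sketched in NumPy for a single head. This is my own simplification, not a port of the Triton kernel, and the names are hypothetical. Inside each chunk of Q steps it uses the quadratic masked form. Between chunks it carries only the N × P state. This follows the intra-chunk / inter-chunk split described in the SSD paper but collapses it into one sequential Python loop. The real kernel also tiles over batch and heads and manages SRAM vs. HBM placement, none of which is modeled here.

```python
import numpy as np

def ssd_chunked(a, B, C, x, chunk_size=64):
    """Chunked SSD for one head (NumPy sketch).

    a: (T,) per-step scalar decay in (0, 1]; B, C: (T, N); x: (T, P).
    Returns y: (T, P). The last chunk may be shorter than chunk_size.
    """
    T, N = B.shape
    P = x.shape[1]
    h = np.zeros((N, P))          # state carried across chunk boundaries
    y = np.empty((T, P))
    for start in range(0, T, chunk_size):
        sl = slice(start, min(start + chunk_size, T))
        a_c, B_c, C_c, x_c = a[sl], B[sl], C[sl], x[sl]
        Q = a_c.shape[0]
        s = np.cumsum(np.log(a_c))  # s[i] = log(a_start * ... * a_{start+i})

        # 1) Intra-chunk: masked quadratic form on a Q x Q block.
        causal = np.tril(np.ones((Q, Q), dtype=bool))
        L = np.exp(np.where(causal, s[:, None] - s[None, :], -np.inf))
        y_c = (L * (C_c @ B_c.T)) @ x_c

        # 2) Inter-chunk: the incoming state, decayed up to each step of the chunk.
        y_c += np.exp(s)[:, None] * (C_c @ h)

        # 3) State at the end of the chunk: decayed old state plus this chunk's inputs.
        decay_to_end = np.exp(s[-1] - s)  # a_{j+1} * ... * a_end for each step j
        h = np.exp(s[-1]) * h + (B_c * decay_to_end[:, None]).T @ x_c

        y[sl] = y_c
    return y
```

Setting chunk_size=1 reduces this to the sequential scan, and chunk_size=T reduces it to a single quadratic block. Comparing the three therefore gives a cheap self-check. Smaller chunks mean less quadratic work per block but more sequential state passing.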
The HIP translation
ROCm ships hipify-clang, which mechanically translates CUDA C++ to HIP. For straight CUDA kernels, the hit rate is high — intrinsics like __shfl_xor_sync map to __shfl_xor under HIP, grid/block launch syntax is identical, atomics translate cleanly.
Triton is a different story. Triton has a ROCm backend (triton-rocm) that is mostly functional but trails the CUDA backend by a few months on new features. For Mamba2, three things needed attention:
1. tl.dot precision. On CUDA, tl.dot with allow_tf32=True (spelled input_precision="tf32" in newer Triton releases) runs on TF32 tensor cores on Ampere and later. The inputs are reduced to a 10-bit mantissa and the products are accumulated in FP32. On ROCm, the same call lowers to MFMA instructions on CDNA architectures. The two paths round differently, so results differ in the low bits. For SSD this is a non-issue, since the recurrence is numerically forgiving. It did surface in my validation suite, though, because I was comparing against CUDA outputs at tight float32 tolerances. Loosening the tolerance to atol=1e-4 cleared it.
2. Warp size. AMD CDNA GPUs have a warp size of 64 (a “wavefront”) vs. NVIDIA’s 32. Triton handles this in the compiler, but if you have any hardcoded assumptions about warp size in your validation code or test harness, they’ll be wrong on ROCm. I had one.
3. Shared memory bank conflicts. The chunked SSD kernel does a transpose in shared memory (LDS on AMD). Because of the 64-wide wavefront, the bank conflict pattern differs between CDNA and NVIDIA's SM architectures. On CDNA, a naive transpose that's conflict-free on NVIDIA will hit partial bank conflicts. Adding a one-column padding to the shared allocation (the standard CUDA trick) still helps, but the optimal pad width differs. I used pad = wavefront_size // 32 (2 on CDNA) as a heuristic, and it recovered most of the throughput regression.
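For item 1, a small NumPy emulation shows why a TF32 reference trips tight float32 comparisons. round_to_tf32 is my own approximation: it rounds to nearest on the magnitude with ties away from zero and does not handle inf/NaN. It is not a bit-exact model of the CUDA side's conversion, and it says nothing about how MFMA rounds on ROCm.

```python
import numpy as np

def round_to_tf32(x):
    """Approximate TF32 input rounding: keep float32's sign and 8-bit exponent,
    round the 23-bit mantissa to 10 bits (ties away from zero on the magnitude).
    inf/NaN are not special-cased."""
    u = np.ascontiguousarray(x, dtype=np.float32).view(np.uint32)
    # Add half of the 13 dropped bits, then clear them; a carry rolls into the exponent.
    u = (u + np.uint32(0x1000)) & np.uint32(0xFFFFE000)
    return u.view(np.float32)

def tf32_matmul(a, b):
    """TF32-style matmul: TF32-rounded inputs, products accumulated in float32."""
    return round_to_tf32(a) @ round_to_tf32(b)
```

On random inputs, the TF32-style product's relative error against a float64 reference is orders of magnitude larger than plain float32's. A float32-tolerance check against a TF32-enabled CUDA reference runs into exactly that gap.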
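For item 2, here is an illustrative (hypothetical) version of that kind of harness bug. The reference side of a test assumes 32 lanes per warp while the kernel under test writes one partial result per warp. per_warp_partials is a made-up helper, not from Triton or mamba_ssm. It takes the warp size as an argument, so the same harness works for NVIDIA's 32-wide warps and CDNA's 64-wide wavefronts.

```python
import numpy as np

def per_warp_partials(values, warp_size):
    """Expected output of a hypothetical kernel that writes one sum per warp/wavefront.

    values: the per-lane inputs for one block, in lane order.
    warp_size: 32 for NVIDIA warps, 64 for CDNA wavefronts. Pass it in rather than
    hardcoding 32, or the harness expects twice as many partials as CDNA produces.
    """
    values = np.asarray(values)
    if values.size % warp_size:
        raise ValueError(f"{values.size} lanes is not a multiple of warp_size={warp_size}")
    # Lanes [0, warp_size) form warp 0, [warp_size, 2*warp_size) form warp 1, and so on.
    return values.reshape(-1, warp_size).sum(axis=1)
```

Read the warp size from your device query or test configuration rather than from a literal.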
Validation strategy
I kept the CUDA implementation as the reference and validated ROCm outputs against it at atol=1e-4, rtol=1e-4 on:
- Random inputs across shapes (batch×seq×dim): 1×256×512, 4×1024×1024, 8×2048×2048
- Edge cases: single token, maximum sequence length, zero-initialized state
- Gradient correctness via torch.autograd.gradcheck on a small instance
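The comparison loop is simple enough to sketch. Here candidate and reference are hypothetical callables standing in for the ROCm run and the CUDA outputs. In practice the reference side would likely be loaded from disk, since the two backends run on different machines. The shapes are the ones listed above plus a single-token case, and the tolerances match the atol=1e-4, rtol=1e-4 used for this port.

```python
import numpy as np

# (batch, seq, dim) shapes from the list above, plus a single-token edge case.
SHAPES = [(1, 256, 512), (4, 1024, 1024), (8, 2048, 2048), (1, 1, 512)]

def validate(candidate, reference, shapes, atol=1e-4, rtol=1e-4, seed=0):
    """Compare candidate(x) to reference(x) on seeded random float32 inputs.

    candidate/reference are hypothetical callables (e.g. the ROCm forward pass and
    a lookup of stored CUDA outputs). Raises AssertionError naming the failing shape.
    """
    rng = np.random.default_rng(seed)
    for shape in shapes:
        x = rng.standard_normal(shape, dtype=np.float32)
        np.testing.assert_allclose(candidate(x), reference(x), atol=atol, rtol=rtol,
                                   err_msg=f"mismatch at batch x seq x dim = {shape}")
```

Seeding the generator keeps the inputs identical across runs, which matters when the two sides are produced on different machines.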
All passed after the tolerance adjustment. The gradient check required float64. gradcheck estimates the numerical Jacobian by finite differences (eps=1e-6 by default), which float32 cannot resolve reliably, and the state accumulation in an SSM compounds the rounding error.
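gradcheck itself needs PyTorch, but the reason for float64 already shows up with plain finite differences. This NumPy sketch uses a toy scalar scan as a stand-in for the SSD state, not the real layer. Its loss is linear in x, so central differences would be exact in exact arithmetic, and any error is pure rounding. The helper names are hypothetical.

```python
import numpy as np

def scan_loss(x, a):
    """Toy loss: sum over t of h_t, where h_t = a * h_{t-1} + x_t, computed in x's dtype."""
    a = x.dtype.type(a)
    h = x.dtype.type(0)
    total = x.dtype.type(0)
    for xt in x:
        h = a * h + xt
        total = total + h
    return total

def fd_grad(x, a, eps=1e-6):
    """Central-difference gradient of scan_loss, like gradcheck's numerical Jacobian."""
    g = np.empty(x.shape[0])
    for j in range(x.shape[0]):
        xp, xm = x.copy(), x.copy()
        xp[j] += eps  # in float32 this perturbation is itself rounded
        xm[j] -= eps
        g[j] = (float(scan_loss(xp, a)) - float(scan_loss(xm, a))) / (2 * eps)
    return g

def exact_grad(T, a):
    """Closed form: d loss / d x_j = sum_{t >= j} a**(t - j) = (1 - a**(T - j)) / (1 - a)."""
    j = np.arange(T)
    return (1 - a ** (T - j)) / (1 - a)
```

With T=64 and a=0.9, the float32 estimate is off by far more than gradcheck's default atol=1e-5. The float64 estimate matches the closed form closely.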
Performance
On MI250X vs. A100-SXM4 at batch=4, seq=2048, dim=1024:
- Forward pass: ROCm ~0.92× CUDA throughput
- Backward pass: ROCm ~0.87× CUDA throughput
The backward gap is larger. The chunked backward kernel does more scattered reads, and my working guess is that CDNA's memory subsystem penalizes scattered access slightly more than the A100's (SM80) does. Worth investigating further; I didn't do a roofline analysis for this port.
The one place it actually matters
The selective_scan_cuda extension (used in Mamba1, still used in some Mamba2 variants) ships as a compiled CUDA extension. It does not hipify cleanly — there are device-specific intrinsics that hipify-clang drops silently. If your code path hits selective_scan_cuda, you will get a silent wrong answer, not an error.
The fix is to use the pure-Triton SSD path that Mamba2 provides. It’s what you should be using anyway on modern hardware, but it’s worth confirming your import path doesn’t fall back to the CUDA extension.
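One way to turn the silent fallback into a loud failure needs only the standard library. It relies on documented import-system behavior: a None entry in sys.modules makes the import raise ModuleNotFoundError. The helper names are hypothetical. Whether mamba_ssm then errors out or quietly switches to another path depends on the version, so it is also worth checking cuda_extension_loaded() after a forward pass.

```python
import sys

# Module name of the compiled extension discussed above.
CUDA_EXTENSION = "selective_scan_cuda"

def cuda_extension_loaded(name=CUDA_EXTENSION):
    """True if the compiled extension has actually been imported in this process."""
    return sys.modules.get(name) is not None

def block_cuda_extension(name=CUDA_EXTENSION):
    """Make later imports of `name` raise ModuleNotFoundError instead of loading it.

    A None entry in sys.modules is the import system's documented way to refuse
    an import. Call this before importing the model code.
    """
    if cuda_extension_loaded(name):
        raise RuntimeError(f"{name} is already imported; block it before importing the model")
    sys.modules[name] = None
```

Calling block_cuda_extension() at the top of a ROCm test entry point, before the model code is imported, means any path that still reaches for the extension fails immediately rather than returning a wrong answer.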
References
@article{dao2024transformers,
title={Transformers are {SSMs}: Generalized Models and Efficient Algorithms
Through Structured State Space Duality},
author={Dao, Tri and Gu, Albert},
journal={arXiv:2405.21060},
year={2024}
}
@software{triton_rocm,
title={Triton ROCm Backend},
author={OpenAI and AMD},
url={https://github.com/ROCm/triton},
year={2024}
}