Throughput benchmark on RTX 5090

#2
by paulml - opened

849K tok/s prefill on RTX 5090 (85% MFU)

Ran some throughput benchmarks for this model on a Blackwell GPU, sharing in case it's useful.

Numbers

| seq_len | batch_size | tok/s | MFU |
|--------:|-----------:|--------:|------:|
| 128 | 64 | 848,904 | 85.1% |
| 256 | 32 | 817,424 | 84.1% |
| 512 | 16 | 780,597 | 84.2% |
| 1,024 | 8 | 715,385 | 83.4% |
| 4,096 | 16 | 506,644 | 79.1% |
| 10,240 | 4 | 327,488 | 73.2% |

MFU is computed against the RTX 5090 BF16 dense peak (209.5 TFLOPS). All numbers are measured, not estimated.
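For reference, MFU here is just achieved matmul FLOP/s divided by hardware peak. A minimal sketch of the computation from a measured tok/s figure, assuming you know the model's FLOPs per token (the usual 2 × active matmul parameters plus an attention term; this model's exact count isn't shown here, so the numbers below are illustrative only):

```python
def mfu(tok_per_s: float, flops_per_token: float, peak_flops: float) -> float:
    """Model FLOPs Utilization: achieved FLOP/s over the hardware peak."""
    return tok_per_s * flops_per_token / peak_flops

# Hypothetical numbers for illustration (not this model's actual FLOP count):
# 1M tok/s at 100 MFLOPs/token against a 200 TFLOPS peak -> 50% MFU.
print(f"{mfu(1e6, 100e6, 200e12):.0%}")  # -> 50%
```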

Code

import torch
import torch.nn.functional as F
import time
from transformers import AutoModel

torch.set_float32_matmul_precision("high")

model = AutoModel.from_pretrained(
    "microsoft/harrier-oss-v1-270m", dtype=torch.bfloat16
).cuda().eval()

embed = model.embed_tokens
rotary = model.rotary_emb
layers = model.layers
norm = model.norm

def forward(input_ids, position_ids):
    # Run the model internals directly (embedding -> rotary -> layers -> norm)
    # so torch.compile can trace one clean graph with no wrapper overhead.
    x = embed(input_ids)
    pos = rotary(x, position_ids, "full_attention")
    for layer in layers:
        x = layer(x, position_embeddings=pos, position_ids=position_ids)
    return norm(x)

compiled = torch.compile(forward, mode="max-autotune")

BS, SL, N_ITER = 64, 128, 100

ids = torch.randint(1, 262144, (BS, SL), device="cuda")
pos = torch.arange(SL, device="cuda").unsqueeze(0).expand(BS, -1)

with torch.no_grad():
    # Warmup: trigger compilation and autotuning before timing.
    for _ in range(5):
        compiled(ids, pos)
    torch.cuda.synchronize()

    t0 = time.perf_counter()
    for _ in range(N_ITER):
        out = compiled(ids, pos)
        # Consume the output so the forward pass can't be optimized away.
        F.normalize(out[:, -1, :], dim=-1)
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - t0

tps = BS * SL * N_ITER / elapsed
print(f"{tps:,.0f} tok/s  ({elapsed/N_ITER*1000:.1f} ms/batch)  bs={BS} sl={SL}")

What matters

Use PyTorch nightly, not stable. This is the single biggest win. Stable PyTorch 2.7 ships Triton 3.3, which has a bug on Blackwell (sm_120): it crashes when autotuning matmul kernels, forcing a fallback to Ampere-era CUTLASS kernels and costing ~17% throughput. The nightly (2.12+) bundles Triton 3.7, which fixes this (triton-lang/triton#9734).

pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu128
pip install transformers
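Before benchmarking, it's worth verifying the install actually picked up a new enough Triton. A small stdlib-only sketch (the 3.7 cutoff follows the version mentioned above; treat it as an assumption and adjust to whichever release actually carries the fix):

```python
def triton_ok_for_sm120(version: str, min_version=(3, 7)) -> bool:
    """Check an installed Triton version string against the first release
    assumed to be good on Blackwell (3.7 per the nightly mentioned above)."""
    # Strip local suffixes like "+git..." and compare the (major, minor) pair.
    parts = tuple(int(p) for p in version.split("+")[0].split(".")[:2])
    return parts >= min_version

# With the real install you would run:
#   import triton; assert triton_ok_for_sm120(triton.__version__)
print(triton_ok_for_sm120("3.3.0"), triton_ok_for_sm120("3.7.0"))  # False True
```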

The forward function is pulled out of the model internals so torch.compile can trace it cleanly. A couple of gotchas:

  • embed_tokens already scales by sqrt(hidden_size) internally, so don't do it again
  • skip attention_mask for uniform-length batches since SDPA handles causal masking, and you avoid a graph break
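On the second point: for a uniform-length batch, letting SDPA apply causal masking itself gives the same result as passing an explicit lower-triangular mask, so the mask tensor (and the graph break its handling can cause) is pure overhead. A quick sanity check on CPU with small shapes:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# (batch, heads, seq, head_dim) -- tiny shapes, just for the equivalence check
q, k, v = (torch.randn(1, 2, 8, 16) for _ in range(3))

# is_causal=True: no materialized mask tensor
out_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Equivalent explicit lower-triangular boolean mask
mask = torch.tril(torch.ones(8, 8, dtype=torch.bool))
out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

print(torch.allclose(out_causal, out_masked, atol=1e-6))  # True
```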

The model ends up 81% GEMM-bound after compilation. torch.compile already fuses all the pointwise stuff (RMSNorm, GeLU*gate, residuals) into Triton kernels, so there's not much left to squeeze from custom kernels.

Things I tried that didn't help

  • Flash Attention 2: built flash-attn 2.8.3 from source for sm_120. Less than 1% difference vs SDPA. The model only has 4 query heads with MQA so attention isn't the bottleneck.
  • FP8 quantization: works on nightly now (was crashing on stable). Roughly neutral at short sequences, +10% at 10K tokens. Quality drops to 0.97 cosine sim though.
  • CUDA Graphs: reduce-overhead mode gives the same throughput as max-autotune. torch.compile already eliminates kernel launch overhead.
  • Multi-stream replicas: only helps at 8K+ token sequences (+3-5%). Hurts at short sequences due to dispatch overhead.
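For anyone reproducing the FP8 quality number: the 0.97 figure is the kind of thing you get by comparing per-token output vectors from the two precisions and averaging their cosine similarity. A minimal sketch with placeholder tensors standing in for the BF16 and FP8 model outputs:

```python
import torch
import torch.nn.functional as F

def mean_cosine_sim(ref: torch.Tensor, test: torch.Tensor) -> float:
    """Average per-vector cosine similarity between two output tensors
    of shape (..., hidden)."""
    return F.cosine_similarity(ref, test, dim=-1).mean().item()

# Placeholder tensors standing in for BF16 vs FP8 outputs
torch.manual_seed(0)
ref = torch.randn(4, 128, 64)
noisy = ref + 0.01 * torch.randn_like(ref)  # small perturbation ~ quantization noise

print(f"{mean_cosine_sim(ref, ref):.2f}")    # identical outputs -> 1.00
print(f"{mean_cosine_sim(ref, noisy):.4f}")  # slightly below 1.0
```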
