DeepSeek-V3.2 on GB300: Performance Breakthrough
Summary
DeepSeek-V3.2 runs smoothly on GB300 (SM103, Blackwell Ultra). Leveraging FP4 quantization, it achieves a per-GPU throughput of 7360 TGS in a prefill-only scenario, and 2816 TGS of output throughput per GPU in a mixed (P+D) scenario with ISL=2K, OSL=1K.
However, compared to DeepSeek-R1, DeepSeek-V3.2 in vLLM still has significant room for improvement in inference performance.
Meanwhile, DeepSeek-R1 (FP4) does not rely on wide EP: using only 2 GPUs (EP=2), it achieves a prefill throughput of 22476 TGS per GPU (ISL=2K, batch=256) and 3072 TGS per GPU in a mixed (ISL=2K, OSL=1K) scenario.
The B300 series demonstrates a more than 8x improvement in prefill performance compared to the Hopper series; in P/D mixed scenarios, throughput improves by 10-20x.
Basic Recipe with FP4 Weight Quantization
One of Blackwell’s most notable features is the fifth-generation Tensor Core’s native support for FP4.
1. Download the NVFP4 Model Weights from Hugging Face
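If you prefer to pre-download the weights (vllm serve can also pull them automatically on first launch), one option is the Hugging Face CLI; the local directory below is just an example path:
huggingface-cli download nvidia/DeepSeek-V3.2-NVFP4 --local-dir /mnt/models/DeepSeek-V3.2-NVFP4
# or, for R1:
huggingface-cli download nvidia/DeepSeek-R1-0528-NVFP4 --local-dir /mnt/models/DeepSeek-R1-0528-NVFP4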
2. Use FP4 MoE Kernel provided by FlashInfer
FP4-quantized MoE models on Blackwell require you to explicitly set VLLM_USE_FLASHINFER_MOE_FP4=1 to enable the FlashInfer FP4MoE Kernel.
export VLLM_USE_FLASHINFER_MOE_FP4=1
3. Serve the Model
GB300/B300 provides 288GB of memory per GPU, so two GPUs are sufficient to hold the FP4-format weights of the DeepSeek series models; a rough memory estimate follows the launch commands below.
vllm serve nvidia/DeepSeek-V3.2-NVFP4 -tp 2
# or
vllm serve nvidia/DeepSeek-R1-0528-NVFP4 -tp 2
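As a rough back-of-the-envelope check (assuming the ~671B total parameters of the DeepSeek-V3/R1 family): 671B parameters at 4 bits each is roughly 335GB of weights plus FP4 scaling factors, which fits comfortably in 2 x 288GB = 576GB and leaves ample headroom for KV cache.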
4. Optimize Batch Configuration
Below are reference values for the maximum batched-token budget that achieves better prefill throughput with TP=2, set via the additional parameter --max-num-batched-tokens (a complete launch example follows):
# DeepSeek-R1-0528-NVFP4
--max-num-batched-tokens 32768
# DeepSeek-V3.2-NVFP4
--max-num-batched-tokens 20480
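Putting the pieces together, a complete prefill-oriented launch for V3.2 combines the environment variable and flags above (swap in the R1 model and batched-token budget as listed):
export VLLM_USE_FLASHINFER_MOE_FP4=1
vllm serve nvidia/DeepSeek-V3.2-NVFP4 -tp 2 --max-num-batched-tokens 20480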
Note
All performance benchmarks presented in this blog are conducted using vLLM v0.14.1.
To ensure reproducibility, the same software stack and dependency versions are used
throughout all experiments:
- vLLM: v0.14.1
- CUDA: 13.0
- DeepGEMM: v2.1.1.post3
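As a minimal sketch of pinning and verifying the vLLM part of this stack (assuming a pip-based install; CUDA and DeepGEMM come from your container/toolkit setup):
pip install vllm==0.14.1
python -c 'import vllm; print(vllm.__version__)'  # expect 0.14.1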
Performance Boost by Blackwell Architecture
FP8 vs. FP4 (for DeepSeek V3.2)
Compared to FP8, FP4 weights yield a 14% improvement in prefill-only scenarios, and 2x output throughput in mixed scenarios.
In prefill, most of the time is spent on attention, the Indexer, KV-cache writes, and so on. FP4 primarily improves weight-read bandwidth and the MoE/MLP GEMMs, so the overall gain is limited (14%).
However, in P/D mixed scenarios (ISL=2K, OSL=1K), the proportion of time spent on MoE GEMMs and weight loading increases, making the system more sensitive to weight-bandwidth and memory bottlenecks. Here, FP4 significantly reduces weight traffic and triggers the dedicated FlashInfer NVFP4 MoE operator path, yielding larger throughput gains.
Furthermore, FP4 frees up more KV-cache space, allowing higher batch concurrency: under FP8, KV-cache occupancy often reaches 100% and requests start to queue, while under FP4 the same requests occupy only ~25%, greatly relieving memory pressure.
FP4 allows just 2 GPUs to host DeepSeek-V3.2 with -tp 2, reaching a total prefill throughput of up to 14720 TGS (7360 per GPU; ISL=2K, OSL=1, batch=64) and a total throughput of 5632 TGS (2816 per GPU) in the P+D mixed case (ISL=2K, OSL=1K, batch=512).
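A quick way to observe the KV-cache relief in practice is to watch the cache-usage gauge that the running server exposes on its Prometheus /metrics endpoint (metric names can vary slightly across vLLM versions):
# Sample the KV-cache usage gauge (a 0.0-1.0 fraction) while load is applied.
curl -s http://localhost:8000/metrics | grep gpu_cache_usage_perc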
Tip
To use FP8, switch to FP8 model weights and set VLLM_USE_FLASHINFER_MOE_FP8=1.
Under FP8, DeepSeek-V3.2 requires 4 GPUs, so use -tp 4.
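For reference, an FP8 launch is sketched below, where <FP8_MODEL> is a placeholder for whichever FP8 checkpoint you use:
export VLLM_USE_FLASHINFER_MOE_FP8=1
vllm serve <FP8_MODEL> -tp 4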
Blackwell Ultra vs. Hopper (for DeepSeek R1)
The chart below compares per-GPU total throughput, under the same requests and vLLM setup, for GB300 (NVL72), B300 (HGX), and the last-gen H200:
- In prefill-only (ISL=2K) scenarios, GB300’s per-GPU throughput is 10% higher than B300, and 7.7x higher compared to H200.
- In PD mixed scenarios (ISL=2k,OSL=1K), GB300’s per-GPU throughput is 12% higher than B300, and 20.8x higher than H200.
The reasons are multifaceted. Besides FP4, B300's FLOPS are 7.5x those of the Hopper series (peaking at ~15 PFLOPS), and the SM's SFU units accelerate attention-layer computation, bringing efficiency gains in prefill.
Its 288GB of memory is also 2x that of the H200, with nearly double the memory bandwidth, and Blackwell Ultra's high-density NVFP4 FLOPS speed up the MoE forward pass compared to Hopper's FP8. Together these contribute to a significant performance leap in the decode phase. Reference: https://developer.nvidia.com/blog/inside-nvidia-blackwell-ultra-the-chip-powering-the-ai-factory-era/
GB300 also shows minor improvements over B300 even in small-scale intra-node configurations with TP=2.
Deployment Tuning
EP2 vs. TP2 Selection
Given that DeepSeek-R1's weights fit within the HBM of only two B300 GPUs, we explored whether it is better to scale via DP on top of TP2 or on top of EP2.
Note
The CLI parameters to switch to EP2 are -dp 2 --enable-expert-parallel.
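For example, an EP2 launch for R1 could look like the following sketch, reusing the FP4 recipe from earlier:
export VLLM_USE_FLASHINFER_MOE_FP4=1
vllm serve nvidia/DeepSeek-R1-0528-NVFP4 -dp 2 --enable-expert-parallel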
Prefill-Only Scenario (ISL=2k, OSL=1)
EP=2 (blue curve) reaches a throughput ceiling of 22K TGS, outperforming TP=2 (green curve) in both throughput and the growth slope of TTFT. This benefits from EP’s typical “large packet, low frequency” communication pattern, which better utilizes the high bandwidth of RDMA/NVLink under high concurrency.
However, the blue EP curve exhibits some fluctuations due to unbalanced expert routing, causing different batches to hit different expert distributions and resulting in variations in expert load and all-to-all communication volume.
P/D Mixed Scenario (large ISL, small OSL)
Under TP=2, each decode step introduces inter-GPU communication overhead, which degrades TPOT by 50% to 2x compared to EP=2.
However, TP improves TTFT by ~50%, accelerating the execution of each step. This improvement offsets the TPOT degradation, ultimately yielding an overall gain of 5%-20% in output-token throughput.
Conclusions
- For DeepSeek-R1 on GB300 with disaggregated prefill, EP is the better fit for the prefiller (then simply increase the DP count to scale). EP has a higher prefill throughput ceiling (peak ~10%-15% higher than TP2), and its TTFT grows more gradually with concurrency, which helps control queuing and tail latency.
- In a P/D integrated deployment, the strategy depends on the workload:
  - When ISL is large and OSL is small, the prefill phase becomes the dominant bottleneck; TP=2 is recommended to prevent excessive attention-layer latency from crowding out GPU time in the decode phase.
  - In contrast, for output-heavy cases, the TPOT advantage of EP=2 becomes dominant, and it is therefore the preferred configuration.
Benefits of MTP
MTP provides decent improvements for decode, but it is not always a silver bullet.
As configured below, the built-in draft model speculates one token at a time, balancing acceptance rate and computational load.
--speculative-config.method mtp \
--speculative-config.num_speculative_tokens 1
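A complete launch with MTP enabled is sketched below, reusing the R1 recipe from earlier (the equivalent JSON form --speculative-config '{"method": "mtp", "num_speculative_tokens": 1}' should also work on recent vLLM versions):
export VLLM_USE_FLASHINFER_MOE_FP4=1
vllm serve nvidia/DeepSeek-R1-0528-NVFP4 -tp 2 \
    --speculative-config.method mtp \
    --speculative-config.num_speculative_tokens 1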
When the context length is not long, enabling MTP (blue) for DeepSeek R1-0528 on GB300 achieves higher throughput than disabling it (green) within a certain concurrency range (<=256), with an acceptance rate above 80%. However, throughput drops sharply when MTP is enabled under high concurrency.
In scenarios with ISL=2048 and OSL=64, the decode fraction is extremely low, so MTP's multi-token prediction adds per-token compute, memory pressure, and scheduling complexity. At low concurrency this overhead cannot be amortized; at high concurrency it further squeezes prefill batching and system concurrency.
Therefore, the overall throughput is lower than that achieved with MTP disabled at both low and high concurrency levels.
DeepSeek V3.2 - Still a Way to Go
As shown in the chart below, with the same GB300 setup, DeepSeek R1’s Prefill throughput capability is ~ 3x that of DeepSeek V3.2.
- DeepSeek R1 in EP2 can reach a peak Prefill throughput of ~ 22k TGS.
- DeepSeek V3.2 in EP2 is relatively weaker, with a Prefill peak throughput of ~7.3k TGS per GPU.
- Regarding TTFT, with both models using TP2, R1 reduces latency by about 55% compared to V3.2.
However, in the P+D mixed case, the gap between the two models in terms of Output Throughput and TPOT is not significant.
Why does R1’s throughput beat V3.2 overall?
The main reason is that V3.2 introduces the Indexer/Sparse MLA (Indexer + SparseAttnIndexer) and uses DeepseekV32IndexerBackend with a dedicated cache structure. In the prefill phase, this adds extra quantization/indexing computation, which reduces throughput. Profiling also shows that the kernel execution time for a single DSA layer step is 2.7x that of MLA.
From the vLLM code perspective, apart from the Indexer path, NVFP4 MoE kernel selection is identical between V3.2 and R1, so the prefill performance difference primarily comes from the overhead of V3.2's Indexer/Sparse Attention.
DSA's advantage shows in ultra-long contexts; when the context is short enough that attention computation is not the bottleneck, the extra overhead becomes pronounced. As the context length increases, however, DSA's TPOT advantage in the decode phase emerges, surpassing MLA between 10K-20K tokens and extending its lead, with MLA's TPOT growing at a roughly 6x steeper slope.
Lastly, the DeepseekV32IndexerBackend is still relatively new and immature, with considerable optimization potential.
Therefore, we believe DeepSeek-V3.2 still has significant room for improvement.
Disaggregated Prefill (for DeepSeek-V3.2)
Below is a quick-start tutorial for disaggregated prefill with 1P+1D over an RDMA scale-out network (the next blog will cover tips for NVLink72 across GB-series trays).
# Prefill Node
export VLLM_USE_FLASHINFER_MOE_FP4=1
export UCX_NET_DEVICES=mlx5_bond_0:1 # optional, tell NIXL to use specific RDMA interface
export VLLM_NIXL_SIDE_CHANNEL_HOST=${PREFILL_NODE_IP}
vllm serve nvidia/DeepSeek-V3.2-NVFP4 -tp 2 --max-num-batched-tokens 20480 \
--kv-transfer-config \
'{"kv_connector":"NixlConnector","kv_role":"kv_both","kv_load_failure_policy":"fail","kv_buffer_device":"cuda"}' \
--port 8000
# Decode Node
export VLLM_NIXL_SIDE_CHANNEL_HOST=${DECODE_NODE_IP}
...
# Exactly the same ENV variables and vllm CLI as Prefill Node, except `VLLM_NIXL_SIDE_CHANNEL_HOST`
# Proxy Node
cd vllm # move into the vLLM source tree; you may need to install extra dependencies
python tests/v1/kv_connector/nixl_integration/toy_proxy_server.py \
--port 8000 \
--prefiller-hosts ${PREFILL_NODE_IP} --prefiller-ports 8000 \
--decoder-hosts ${DECODE_NODE_IP} --decoder-ports 8000
# If you have multiple Prefiller or Decoder:
# just append to hosts list, like: `--prefiller-hosts ${IP1} ${IP2} --prefiller-ports 8000 8000 `
# vllm bench against the proxy (using random dataset and ISL=4k,OSL=1k)
vllm bench serve --model nvidia/DeepSeek-V3.2-NVFP4 \
--seed $RANDOM --dataset-name random \
--base-url http://${PROXY_NODE_IP}:8000 \
--tokenizer /mnt/models/DeepSeek-V3.2 \
--num-prompts 500 --max-concurrency 100 \
--random-input-len 4096 --random-output-len 1024 \
--ignore-eos
Note
PD disaggregation on vLLM v0.14.1: you need to manually apply the patch from PR #32698. This fix has since been merged into the vLLM main branch, so newer versions may not need the patch.
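One way to apply it, assuming you are working from a vLLM source checkout (fetching a PR head like this is standard GitHub behavior; double-check the PR number against the note above):
# From the vLLM source tree: fetch the PR head into a local branch and merge it.
git fetch https://github.com/vllm-project/vllm.git pull/32698/head:pr-32698
git merge pr-32698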
We use the NIXL KV connector to transfer KV caches across processes/nodes. Both the P and D roles use the TP=2 strategy.
As concurrent load increases, the disaggregated setup shows a widening throughput advantage over the integrated setup while maintaining lower latency (both TTFT and TPOT), and its latency grows more gradually.
Regarding TPOT, both 1P1D and 3P1D outperform the non-disaggregated setup. At a batch of 256, the disaggregated setup keeps TPOT under 60ms, while the integrated setup exceeds 80ms.
When ISL grows further (from 2K to 8K), the 1P1D setup begins to struggle, with prefill becoming the bottleneck: requests queue at the P node and the decoder's compute sits underutilized. Adding two more P replicas (3P1D) parallelizes the prefill phase across more requests, achieving better total throughput.
Although per-GPU throughput may not be the highest under disaggregation, the additional hardware investment buys better goodput and stronger SLO guarantees.
Preview: the next blog will showcase P/D disaggregation in practice, leveraging NVL72 on GB200.
Acknowledgements
We would like to thank the many talented people in the vLLM community who worked together as part of this effort:
- Verda Team: for providing the GB300 cluster and offering support.
- DaoCloud Team: Xingyan Jiang, Nicole Li, Peter Pan, Kebe Liu
- InferAct Team: Jie Li, Kaichao You