vLLM on RTX 5090 Blackwell: A Technical Deep Dive
Hi folks, I spent most of today getting vLLM running with PyTorch 2.9.0 on an RTX 5090, and it looks like the most recent build takes care of a lot of the errors people have been hitting.
There are so many ways to get this wrong, and I'm amazed it worked at all; I think I hit every issue on this forum on the way to this point. I hope this write-up helps anyone else fighting the same problems get things running.
Working Installation Process
System Configuration
- GPU: NVIDIA GeForce RTX 5090 GB202 (Blackwell sm_120 architecture)
- Driver: NVIDIA 575.64.03
- CUDA: 12.8 (automatically supported by driver)
- OS: Ubuntu 25.04 Plucky Puffin
- RAM: 192 GB
- CPU: AMD Ryzen 9 9900X3D (12-core)
- Python: 3.12.10 (in virtual environment)
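A quick sanity check before starting: confirm the driver sees the card and reports the sm_120 (compute capability 12.0) architecture. The nvidia-smi query fields below are standard, though compute_cap needs a reasonably recent driver:
nvidia-smi --query-gpu=name,driver_version,compute_cap --format=csv
# Expected on this rig: NVIDIA GeForce RTX 5090, 575.64.03, 12.0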
The Working Solution
The successful installation used PyTorch 2.9 nightly + vLLM source build:
vllm --version
# Output: 0.10.1rc2.dev413+g5438967fb.d20250901
This is a September 1, 2025 development build compiled from source with PyTorch 2.9.0.dev20250831+cu128.
Installation Steps (Confirmed Working Method)
- Environment Setup
uv venv --python 3.12.10 --seed
source .venv/bin/activate
- PyTorch 2.9 Nightly Installation
pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu128
# Result: torch==2.9.0.dev20250831+cu128
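Before moving on, it's worth checking that the nightly actually sees the card and reports compute capability (12, 0); a one-liner is enough:
python -c "import torch; print(torch.__version__, torch.version.cuda, torch.cuda.get_device_capability(0))"
# Expected output along the lines of: 2.9.0.dev20250831+cu128 12.8 (12, 0)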
- vLLM Source Build
gh repo clone vllm-project/vllm # a plain git clone over HTTPS or SSH works just as well
cd vllm
python use_existing_torch.py # important: strips the torch pins from vLLM's requirements so the build uses the already-installed PyTorch nightly
pip install -r requirements/build.txt
export VLLM_FLASH_ATTN_VERSION=2
export TORCH_CUDA_ARCH_LIST="12.0"
MAX_JOBS=6 pip install --no-build-isolation -e .
- Verification
vllm --version
# Should show: 0.10.1rc2.dev413+g5438967fb.d20250901
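A version string is necessary but not sufficient; the real confirmation is an end-to-end smoke test against the OpenAI-compatible server. The model name and prompt below are just placeholders (swap in whatever you plan to run):
vllm serve Qwen/Qwen2.5-7B-Instruct --port 8000
# In a second shell, once the server logs that it is ready:
curl -s http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{"model": "Qwen/Qwen2.5-7B-Instruct", "messages": [{"role": "user", "content": "Say hello in five words."}], "max_tokens": 32}'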
What Doesn’t Work (Confirmed Failures)
Pre-built Wheels from PyPI
- All official releases: Lack Blackwell sm_120 support
- Dependency conflicts: Version mismatches with PyTorch requirements
- Missing wheel combinations: No PyTorch 2.8.x+cu128 wheels exist
Docker Container Approaches
- All official images: Not ready for Blackwell architecture
- NVIDIA NGC containers: Contain outdated vLLM versions incompatible with RTX 5090
Failed PyTorch Version Attempts
- PyTorch 2.7.x: Only supports CUDA 12.6, incompatible with Blackwell sm_120
- PyTorch 2.8.x: No cu128 wheels available anywhere (404 errors on GitHub releases)
- Dependency resolution conflicts: vLLM nightly wants torch==2.8.0, but this doesn’t exist with cu128
Environment Variable Requirements
The successful build required specific Blackwell-targeting variables:
VLLM_FLASH_ATTN_VERSION=2 # FA3 unsupported on Blackwell
TORCH_CUDA_ARCH_LIST="12.0" # Blackwell sm_120 architecture
MAX_JOBS=6 # Memory management during compilation
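Until vLLM auto-detects this (see the recommendations further down), a small guard like the following keeps the Blackwell flags out of builds on other machines; it's just a sketch built on torch.cuda.get_device_capability:
# Only export the Blackwell-specific build flags when an sm_120 card is present
CC=$(python -c "import torch; print('.'.join(map(str, torch.cuda.get_device_capability(0))))")
if [ "$CC" = "12.0" ]; then
  export VLLM_FLASH_ATTN_VERSION=2      # FA3 kernels not available for sm_120
  export TORCH_CUDA_ARCH_LIST="12.0"    # compile kernels for Blackwell only
fi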
Performance Analysis
Startup and Warmup Behavior
- Initial load: ~72 seconds (model download and compilation)
- First request: 19.9 tokens/s (cold start)
- Second request: 81.6 tokens/s (warming up)
- Subsequent requests: 50-300+ tokens/s (stabilized)
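For a rough cross-check of numbers like these, timing a single request against the running server is enough; the request below is only an illustration (server assumed on localhost:8000, and dividing max_tokens by wall-clock time gives an approximate decode rate):
time curl -s http://localhost:8000/v1/completions \
  -H "Content-Type: application/json" \
  -d '{"model": "Qwen/Qwen2.5-7B-Instruct", "prompt": "Explain KV caching in one paragraph.", "max_tokens": 256}' \
  > /dev/null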
Optimization Process
- CUDA Graphs: Progressive optimization with request patterns
- torch.compile: JIT compilation of hot code paths (3.36s initial compilation)
- KV Cache: Memory utilization grows from 0.0% to 0.2% as patterns stabilize
- Flash Attention: FA2 backend on the V1 engine (FA3 not supported on Blackwell)
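If CUDA graph capture itself misbehaves on bleeding-edge hardware, vLLM can be run without it; --enforce-eager is a standard flag and trades some decode speed for simpler debugging:
vllm serve Qwen/Qwen2.5-7B-Instruct --enforce-eager   # skip CUDA graph capture; slower decode, easier to debug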
Memory Usage
- Qwen2.5-7B: 31GB VRAM usage (aggressive KV cache pre-allocation)
- Available KV cache: 27.09 GiB
- Maximum concurrency: 288.98x for 1,024 token requests
- Graph capturing: Additional 0.44 GiB
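Most of that 31 GB is deliberate KV cache pre-allocation rather than model weights. If you need VRAM headroom for anything else, the standard knobs below rein it in (the values are only examples):
# --gpu-memory-utilization defaults to 0.90; lowering it leaves VRAM for other workloads
# --max-model-len caps the context length and therefore the pre-allocated KV cache
vllm serve Qwen/Qwen2.5-7B-Instruct --gpu-memory-utilization 0.80 --max-model-len 8192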
Sustained Performance
- Decode speed: 290+ tokens/second (Qwen2.5-7B)
- Response quality: Proper model behavior with appropriate stopping
- Stability: No crashes or memory issues during extended testing
Technical Architecture Details
Blackwell-Specific Challenges
- sm_120 compute capability: Newer than most software expects
- CUDA 12.8 minimum requirement: Software ecosystem lagging behind hardware
- Flash Attention limitations: FA3 unavailable, must use FA2
- Kernel availability: Many operations lack optimized kernels for sm_120
Key Build Dependencies
From the successful installation, these packages were critical:
- PyTorch: 2.9.0.dev20250831+cu128 (nightly build)
- CUDA libraries: All 12.8.x versions (cublas, cudnn, etc.)
- Build tools: cmake 4.1.0, ninja 1.13.0, setuptools-scm 9.2.0
- Ray: 2.49.0 with cgraph support (cupy-cuda12x dependency)
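To compare your own environment against this list, the usual pip queries are enough (the package names are the ones called out above):
pip list | grep -Ei "^(torch|ray|cmake|ninja|setuptools-scm|cupy)"
python -c "import torch; print(torch.__version__, torch.version.cuda)"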
vLLM Configuration
- Model loading: 0.6611 GiB, 72.403091 seconds
- Chunked prefill: enabled with max_num_batched_tokens=2048
- Compilation level: 3 (highest optimization)
- CUDA graphs: enabled with 67 capture sizes
- Backend: FlashAttention (FA2) on the V1 engine
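Most of this comes from vLLM's defaults on the V1 engine, but the chunked-prefill budget can be pinned explicitly if you want runs to be reproducible; the flags below are current vLLM CLI options, though defaults do shift between nightlies:
# chunked prefill with a 2048-token budget, matching the config above
vllm serve Qwen/Qwen2.5-7B-Instruct --enable-chunked-prefill --max-num-batched-tokens 2048
# the compilation level can also be set explicitly (e.g. -O3), though level 3 appears to be the default on the V1 engine already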
Root Cause Analysis
Why Most Installations Fail
- PyTorch Version Gap: No PyTorch 2.8.x+cu128 wheels exist, creating dependency deadlock
- vLLM Version Constraints: Nightly builds expect torch==2.8.0 but must use torch>=2.9.0 for Blackwell (No idea why)
- Architecture Support Lag: sm_120 support very recent, not in stable releases
- Build Environment Requirements: Specific environment variables needed for Blackwell compilation
Why This Installation Worked
- Source compilation: Bypassed pre-built wheel dependency conflicts
- PyTorch 2.9 nightly: Only version with functional CUDA 12.8+sm_120 support
- use_existing_torch.py: Critical script that cleaned dependency files to use existing PyTorch
- Proper environment variables: VLLM_FLASH_ATTN_VERSION=2 and TORCH_CUDA_ARCH_LIST="12.0"
- Sufficient resources: 192GB RAM prevented memory-related build failures (change MAX_JOBS env variable for tighter RAM budgets)
Comparison with Alternative Solutions
Of the alternatives below, I've only tested Ollama on this rig; the rest is drawn from documentation and other users' reports, which suggest much of the pain here is specific to vLLM.
llama.cpp
- Status: Confirmed working with pre-built binaries
- Performance: 700+ tokens/sec prefill, good decode performance
- Setup: Significantly easier, no dependency conflicts
- Use case: Better choice for most users until vLLM stabilizes
TensorRT-LLM
- Status: Working but requires complex setup
- Performance: Highest potential (FP4 optimization)
- Setup: Requires building from source, NVIDIA-specific optimizations
- Use case: Best for maximum performance, enterprise deployments
Ollama
- Status: Works reliably
- Performance: Moderate (baseline comparison)
- Setup: Trivial installation
- Use case: Good fallback option, proven stability
Recommendations
For RTX 5090 Users (September 2025)
- Try llama.cpp first: Most reliable path to working inference
- For vLLM: Use the exact source build method documented above
- Avoid pre-built wheels: All fail due to PyTorch version conflicts
- Monitor development: New versions may resolve dependency issues
For vLLM Development Team
- Update dependency constraints: Support PyTorch 2.9+ in nightly builds
- Improve Blackwell documentation: Current guides don’t address sm_120 specifics
- Pre-built wheel support: Provide wheels compiled with PyTorch 2.9+cu128
- Environment detection: Auto-set VLLM_FLASH_ATTN_VERSION=2 for Blackwell GPUs
For Enterprise Adoption
- Source build required: No stable pre-built solution exists
- Test thoroughly: Performance characteristics still stabilizing
- Monitor memory usage: Current builds are memory-aggressive
- Have fallback plans: Keep alternative inference engines available
Timeline and Implications
Current State (September 2025)
- Limited working installations: Mostly individual researchers/developers using source builds
- No enterprise adoption: Production deployments require custom compilation
- Blackwell support improving: weekly compatibility gains as more people test it
Expected Evolution
- Official support: Likely within 1-2 stable releases
- Performance optimization: Memory usage and speed improvements
- Documentation: Better guides for Blackwell-specific setup
Broader Context
This installation represents early adoption of new hardware with new software. It works, but it's a reminder of how legitimately painful running at the leading edge of AI infrastructure can be.
Conclusion
The successful vLLM installation on RTX 5090 Blackwell required:
- PyTorch 2.9 nightly (2.9.0.dev20250831+cu128)
- Source compilation from vLLM git main branch
- Specific environment variables for Blackwell compatibility
- Proper build sequence including use_existing_torch.py
The 290+ tokens/second throughput, with barely any tuning, shows that RTX 5090 + vLLM can deliver "enterprise-grade" inference (still no support for making Sam Altman cry 'AGI' every few weeks), but the installation process remains challenging. The dependency conflicts between vLLM's requirements and Blackwell support mean source builds are required until official compatibility lands. Fortunately, they're not that tricky once someone has figured out the broken bits.
Key Insight: The missing PyTorch 2.8.x+cu128 wheels create a dependency deadlock that can only be resolved by jumping to PyTorch 2.9+ and building vLLM from source.
I’ll be re-posting this on my GitHub and Substack, maybe.
Feedback and questions welcome.