Date: October 22, 2025
Environment: Nebius H200 Kubernetes Cluster
Infrastructure: Dynamo LLM Serving Platform v0.5.0
Model: Qwen/Qwen3-32B with FP8 quantization
We conducted a controlled A/B experiment comparing aggregated vs disaggregated inference architectures for Qwen3-32B (32B parameters, FP8 quantized) on identical hardware (10 H200 GPUs each). Disaggregated inference architectures separate prefill and decode operations into specialized worker pools, theoretically enabling better resource utilization and independent scaling.
Contrary to expectations, the disaggregated architecture performed worse across all key metrics:
| Metric | Concurrency 100 | Winner | Improvement |
|---|---|---|---|
| Time to First Token (Avg) | 1.1s vs 4.7s | Aggregated | 4.2x faster |
| TTFT P99 | 3.3s vs 13.5s | Aggregated | 4.1x faster |
| Output Throughput | 3,990 vs 3,158 tok/s | Aggregated | 26% higher |
| Throughput per GPU | 399 vs 316 tok/s/GPU | Aggregated | 26% more efficient |
| Throughput per User | 39.9 vs 31.6 tok/s/user | Aggregated | 26% better UX |
Pareto Efficiency Comparison:
The Pareto plot shows aggregated architecture dominates across the entire efficiency frontier - achieving both better resource utilization (throughput per GPU) and better user experience (throughput per user) at all concurrency levels. There is no tradeoff; aggregated wins on both dimensions.
The KV cache transfer overhead between prefill and decode workers appears to dominate performance. At low concurrency (C=1, no queueing), disaggregated P99 TTFT is 1.4 seconds vs 213ms for aggregated - a 6.4x degradation that cannot be explained by load or variance. This 1,153ms overhead is consistent with KV cache serialization, network transfer (likely cross-node, as our 10-GPU deployment spans 2 nodes with 8 GPUs each), and deserialization.
Confidence caveat: While KV cache transfer is the most plausible explanation, we did not profile the vLLM implementation or confirm node placement during the benchmark. The specific breakdown (serialization, network, deserialization timing) is estimated based on similar systems, not measured directly.
For workloads similar to ours (ISL=4000, OSL=500), use aggregated architecture. Disaggregation provides no benefit and significantly degrades both user experience and resource efficiency. Disaggregation may only be beneficial for:
- Workloads with extreme prefill/decode resource imbalances (not present here)
- Scenarios requiring independent scaling of prefill vs decode capacity
- All-on-one-node deployments with very short input sequences (ISL < 1000)
Both deployments used identical infrastructure and model configurations, differing only in worker architecture.
| Component | Specification |
|---|---|
| Cluster | Nebius Kubernetes (16 nodes × 8 H200 GPUs = 128 total GPUs) |
| GPU | NVIDIA H200 SXM 141GB |
| Node GPU Count | 8 GPUs per node (important: 10 GPUs spans 2 nodes) |
| Interconnect | NVLink/NVSwitch (intra-node), InfiniBand (inter-node) |
| Platform | Dynamo v0.5.0 with vLLM backend |
| Runtime | nvcr.io/nvidia/ai-dynamo/vllm-runtime:0.5.0 |
| Parameter | Value |
|---|---|
| Model | Qwen/Qwen3-32B (32B parameters) |
| Quantization | FP8 (weights and KV cache) |
| Max Model Length | 7,800-7,964 tokens (depending on config) |
| GPU Memory Utilization | 70% |
| Chunked Prefill | Disabled |
Total Resources: 10 GPUs (5 workers × 2 GPUs each)
Topology:
- 1 Frontend (CPU-only, routing/load balancing)
- 5 VllmDecodeWorker replicas
- Each worker: 2 GPUs with tensor parallelism (TP=2)
- Each worker handles both prefill AND decode
GPU Distribution:
- Node 1: 4 workers (8 GPUs)
- Node 2: 1 worker (2 GPUs)
- No inter-worker communication during inference
Namespace: bhamm-qwen32b-agg
backend: pytorch
tensor_parallel_size: 2
pipeline_parallel_size: 1
enable_attention_dp: false
enable_chunked_prefill: false
max_batch_size: 96
max_num_tokens: 7964
max_seq_len: 7964
kv_cache_config:
enable_block_reuse: false
free_gpu_memory_fraction: 0.7
dtype: fp8
cache_transceiver_config:
backend: default
cuda_graph_config:
enable_padding: true
batch_sizes: [1, 2, 4, 8, 16, 32, 64, 96]
disable_overlap_scheduler: false
print_iter_log: falseClick to expand: qwen32b-agg-deployment.yaml
apiVersion: nvidia.com/v1alpha1
kind: DynamoGraphDeployment
metadata:
name: qwen32b-agg
namespace: bhamm-qwen32b-agg
spec:
services:
Frontend:
dynamoNamespace: qwen32b-agg
componentType: frontend
replicas: 1
extraPodSpec:
mainContainer:
image: nvcr.io/nvidia/ai-dynamo/vllm-runtime:0.5.0
nodeSelector:
nvidia.com/gpu.product: NVIDIA-H200-SXM-141GB
VllmDecodeWorker:
envFromSecret: hf-token-secret
dynamoNamespace: qwen32b-agg
componentType: worker
replicas: 5
resources:
limits:
gpu: "2"
extraPodSpec:
mainContainer:
image: nvcr.io/nvidia/ai-dynamo/vllm-runtime:0.5.0
workingDir: /workspace/components/backends/vllm
command:
- python3
- -m
- dynamo.vllm
args:
- --model
- Qwen/Qwen3-32B
- --tensor-parallel-size
- "2"
- --quantization
- fp8
- --max-model-len
- "7964"
- --gpu-memory-utilization
- "0.7"
startupProbe:
httpGet:
path: /health
port: 9090
initialDelaySeconds: 120
periodSeconds: 30
timeoutSeconds: 10
failureThreshold: 60
livenessProbe:
httpGet:
path: /live
port: 9090
initialDelaySeconds: 300
periodSeconds: 30
timeoutSeconds: 10
failureThreshold: 10
readinessProbe:
httpGet:
path: /live
port: 9090
initialDelaySeconds: 300
periodSeconds: 30
timeoutSeconds: 10
failureThreshold: 10
volumes:
- name: backend-config
configMap:
name: qwen32b-agg-config
nodeSelector:
nvidia.com/gpu.product: NVIDIA-H200-SXM-141GBTotal Resources: 10 GPUs (2 prefill workers × 1 GPU + 4 decode workers × 2 GPUs)
Topology:
- 1 Frontend (CPU-only, routing between prefill and decode)
- 2 VllmPrefillWorker replicas
- Each worker: 1 GPU (TP=1)
- Handles prefill only, sends KV cache to decode workers
- 4 VllmDecodeWorker replicas
- Each worker: 2 GPUs (TP=2)
- Receives KV cache, handles decode only
GPU Distribution (likely):
- Node 1: 4 decode workers (8 GPUs)
- Node 2: 2 prefill workers (2 GPUs)
- Critical: KV cache transfers occur cross-node via InfiniBand
Namespace: bhamm-qwen32b-disagg
Prefill (disagg_prefill.yaml):
backend: pytorch
tensor_parallel_size: 1
pipeline_parallel_size: 1
enable_attention_dp: false
enable_chunked_prefill: false
max_batch_size: 1
max_num_tokens: 7800
max_seq_len: 7800
kv_cache_config:
enable_block_reuse: false
free_gpu_memory_fraction: 0.7
dtype: fp8
cache_transceiver_config:
backend: default
cuda_graph_config:
enable_padding: true
batch_sizes: [1, 2, 4, 8, 16, 32, 64, 128, 256]
disable_overlap_scheduler: true
print_iter_log: falseDecode (disagg_decode.yaml):
backend: pytorch
tensor_parallel_size: 2
pipeline_parallel_size: 1
enable_attention_dp: false
enable_chunked_prefill: false
max_batch_size: 128
max_num_tokens: 7800
max_seq_len: 7800
kv_cache_config:
enable_block_reuse: false
free_gpu_memory_fraction: 0.7
dtype: fp8
cache_transceiver_config:
backend: default
cuda_graph_config:
enable_padding: true
batch_sizes: [1, 2, 4, 8, 16, 32, 64, 96, 128]
disable_overlap_scheduler: false
print_iter_log: falseClick to expand: qwen32b-disagg-deployment.yaml
apiVersion: nvidia.com/v1alpha1
kind: DynamoGraphDeployment
metadata:
name: qwen32b-disagg
namespace: bhamm-qwen32b-disagg
spec:
services:
Frontend:
dynamoNamespace: qwen32b-disagg
componentType: frontend
replicas: 1
extraPodSpec:
mainContainer:
image: nvcr.io/nvidia/ai-dynamo/vllm-runtime:0.5.0
nodeSelector:
nvidia.com/gpu.product: NVIDIA-H200-SXM-141GB
VllmPrefillWorker:
envFromSecret: hf-token-secret
dynamoNamespace: qwen32b-disagg
componentType: worker
subComponentType: prefill
replicas: 2
resources:
limits:
gpu: "1"
extraPodSpec:
mainContainer:
image: nvcr.io/nvidia/ai-dynamo/vllm-runtime:0.5.0
workingDir: /workspace/components/backends/vllm
command:
- python3
- -m
- dynamo.vllm
args:
- --model
- Qwen/Qwen3-32B
- --tensor-parallel-size
- "1"
- --quantization
- fp8
- --max-model-len
- "7800"
- --gpu-memory-utilization
- "0.7"
- --is-prefill-worker
startupProbe:
httpGet:
path: /health
port: 9090
initialDelaySeconds: 120
periodSeconds: 30
timeoutSeconds: 10
failureThreshold: 60
livenessProbe:
httpGet:
path: /live
port: 9090
initialDelaySeconds: 300
periodSeconds: 30
timeoutSeconds: 10
failureThreshold: 10
readinessProbe:
httpGet:
path: /live
port: 9090
initialDelaySeconds: 300
periodSeconds: 30
timeoutSeconds: 10
failureThreshold: 10
volumes:
- name: backend-config
configMap:
name: qwen32b-disagg-config
nodeSelector:
nvidia.com/gpu.product: NVIDIA-H200-SXM-141GB
VllmDecodeWorker:
envFromSecret: hf-token-secret
dynamoNamespace: qwen32b-disagg
componentType: worker
subComponentType: decode
replicas: 4
resources:
limits:
gpu: "2"
extraPodSpec:
mainContainer:
image: nvcr.io/nvidia/ai-dynamo/vllm-runtime:0.5.0
workingDir: /workspace/components/backends/vllm
command:
- python3
- -m
- dynamo.vllm
args:
- --model
- Qwen/Qwen3-32B
- --tensor-parallel-size
- "2"
- --quantization
- fp8
- --max-model-len
- "7800"
- --gpu-memory-utilization
- "0.7"
startupProbe:
httpGet:
path: /health
port: 9090
initialDelaySeconds: 120
periodSeconds: 30
timeoutSeconds: 10
failureThreshold: 60
livenessProbe:
httpGet:
path: /live
port: 9090
initialDelaySeconds: 300
periodSeconds: 30
timeoutSeconds: 10
failureThreshold: 10
readinessProbe:
httpGet:
path: /live
port: 9090
initialDelaySeconds: 300
periodSeconds: 30
timeoutSeconds: 10
failureThreshold: 10
volumes:
- name: backend-config
configMap:
name: qwen32b-disagg-config
nodeSelector:
nvidia.com/gpu.product: NVIDIA-H200-SXM-141GB# Create namespaces
kubectl create namespace bhamm-qwen32b-agg
kubectl create namespace bhamm-qwen32b-disagg
# Create HuggingFace token secret (for model access)
kubectl create secret generic hf-token-secret \
--from-literal=HF_TOKEN=<your-token> \
-n bhamm-qwen32b-agg
kubectl create secret generic hf-token-secret \
--from-literal=HF_TOKEN=<your-token> \
-n bhamm-qwen32b-disagg
# Install Dynamo platform in each namespace
helm install dynamo-platform oci://nvcr.io/nvidia/dynamo/dynamo-platform-helm \
--version 0.5.0 \
--namespace bhamm-qwen32b-agg
helm install dynamo-platform oci://nvcr.io/nvidia/dynamo/dynamo-platform-helm \
--version 0.5.0 \
--namespace bhamm-qwen32b-disagg
# Create backend ConfigMaps
kubectl create configmap qwen32b-agg-config \
--from-file=backend_config.yaml=agg.yaml \
-n bhamm-qwen32b-agg
kubectl create configmap qwen32b-disagg-config \
--from-file=disagg_prefill_config.yaml=disagg_prefill.yaml \
--from-file=disagg_decode_config.yaml=disagg_decode.yaml \
-n bhamm-qwen32b-disagg# Deploy aggregated cluster
kubectl apply -f qwen32b-agg-deployment.yaml
# Deploy disaggregated cluster
kubectl apply -f qwen32b-disagg-deployment.yaml
# Wait for workers to be ready (~5-10 minutes)
kubectl wait --for=condition=ready pod \
-l component=worker \
-n bhamm-qwen32b-agg \
--timeout=600s
kubectl wait --for=condition=ready pod \
-l component=worker \
-n bhamm-qwen32b-disagg \
--timeout=600sWe used AIPerf (aiperf) for synthetic workload generation with the following parameters:
| Parameter | Value | Rationale |
|---|---|---|
| Input Sequence Length (ISL) | 4,000 tokens (mean), σ=100 | Represents long-context queries (documents, code) |
| Output Sequence Length (OSL) | 500 tokens (mean), σ=50 | Typical completion length for assistants |
| Concurrency Levels | 1, 2, 5, 10, 25, 50, 100 | Low to high load |
| Requests per Level | 200 | Statistical significance |
| Mode | Streaming | Realistic production usage |
| Execution | In-cluster | Eliminates client-side network latency |
All benchmarks ran as Kubernetes Jobs within the cluster to ensure:
- No external network latency
- Persistent result storage (PVC)
- Asynchronous execution
- Identical network conditions for both clusters
#!/bin/bash
set -e
DEPLOYMENT_NAME=$1
SERVICE_URL=$2
OUTPUT_DIR=$3
CONCURRENCY_LEVELS=(1 2 5 10 25 50 100)
REQUEST_COUNT=200
# Install aiperf
source /opt/dynamo/venv/bin/activate
pip install -q aiperf
for CONCURRENCY in "${CONCURRENCY_LEVELS[@]}"; do
CONC_DIR="$OUTPUT_DIR/$DEPLOYMENT_NAME/c$CONCURRENCY"
mkdir -p "$CONC_DIR"
aiperf profile \
--model "Qwen/Qwen3-32B" \
--url "$SERVICE_URL" \
--endpoint-type chat \
--synthetic-input-tokens-mean 4000 \
--synthetic-input-tokens-stddev 100 \
--output-tokens-mean 500 \
--output-tokens-stddev 50 \
--extra-inputs max_tokens:500 \
--extra-inputs min_tokens:500 \
--extra-inputs ignore_eos:true \
--tokenizer "Qwen/Qwen3-32B" \
--streaming \
--request-count "$REQUEST_COUNT" \
--concurrency "$CONCURRENCY" \
--output-artifact-dir "$CONC_DIR" \
2>&1 | tee "$CONC_DIR/benchmark.log"
done# Create benchmark namespace
kubectl create namespace bhamm-benchmark
# Create persistent storage for results
kubectl apply -f - <<EOF
apiVersion: v1
kind: PersistentVolumeClaim
metadata:
name: benchmark-results-pvc
namespace: bhamm-benchmark
spec:
accessModes:
- ReadWriteMany
storageClassName: nebius-shared-fs
resources:
requests:
storage: 10Gi
EOF
# Create benchmark script ConfigMap
kubectl create configmap benchmark-script \
--from-file=run_benchmark.sh \
-n bhamm-benchmark
# Deploy benchmark jobs (both run simultaneously)
kubectl apply -f benchmark-agg-job.yaml
kubectl apply -f benchmark-disagg-job.yaml
# Monitor progress
kubectl logs -f -n bhamm-benchmark job/benchmark-qwen32b-agg
kubectl logs -f -n bhamm-benchmark job/benchmark-qwen32b-disagg- Aggregated:
http://qwen32b-agg-frontend.bhamm-qwen32b-agg.svc.cluster.local:8000 - Disaggregated:
http://qwen32b-disagg-frontend.bhamm-qwen32b-disagg.svc.cluster.local:8000
For each concurrency level, AIPerf collected:
- Time to First Token (TTFT): Latency from request start to first token (avg, p50, p99)
- Inter-Token Latency (ITL): Time between consecutive tokens (avg)
- Request Latency: End-to-end request completion time (avg, p99)
- Output Token Throughput: Total tokens/sec generated
- Request Throughput: Requests/sec completed
Derived metrics:
- Throughput per GPU: Total throughput / 10 GPUs
- Throughput per User: Total throughput / Concurrency
The Pareto plot compares resource efficiency (throughput per GPU) vs user experience (throughput per user). Ideally, an architecture dominates on both axes.
Key Observations:
-
Aggregated dominates across the entire curve - it achieves higher per-GPU AND per-user throughput at all concurrency levels.
-
At Concurrency 100 (production-scale load):
- Aggregated: 399.0 tok/s/GPU, 39.9 tok/s/user
- Disaggregated: 315.8 tok/s/GPU, 31.6 tok/s/user
- Winner: Aggregated by 26% on both metrics
-
No tradeoff exists - in theory, disaggregation might sacrifice efficiency for better UX (or vice versa). Here, it loses on both.
| Concurrency | Architecture | Total (tok/s) | Per GPU (tok/s) | Per User (tok/s) |
|---|---|---|---|---|
| 1 | Aggregated | 77.2 | 7.7 | 77.2 |
| Disaggregated | 76.5 | 7.7 | 76.5 | |
| 10 | Aggregated | 721.7 | 72.2 | 72.2 |
| Disaggregated | 717.6 | 71.8 | 71.8 | |
| 25 | Aggregated | 1,591.1 | 159.1 | 63.6 |
| Disaggregated | 1,591.0 | 159.1 | 63.6 | |
| 50 | Aggregated | 2,680.8 | 268.1 | 53.6 |
| Disaggregated | 2,582.6 | 258.3 | 51.7 | |
| 100 | Aggregated | 3,990.1 | 399.0 | 39.9 |
| Disaggregated | 3,158.4 | 315.8 | 31.6 |
Analysis:
- At low concurrency (1-25), both architectures perform similarly
- At high concurrency (50-100), disaggregation's overhead compounds:
- More KV cache transfers queued up
- Cross-node network becomes bottleneck
- Coordination overhead increases
- The 26% throughput gap at C=100 translates to needing 26% more hardware with disaggregation to match aggregated performance
TTFT measures responsiveness - how quickly users see the first token. This is the most important UX metric for interactive applications.
| Concurrency | Aggregated | Disaggregated | Speedup |
|---|---|---|---|
| 1 | 177 | 367 | 2.1x |
| 2 | 180 | 310 | 1.7x |
| 5 | 170 | 338 | 2.0x |
| 10 | 185 | 427 | 2.3x |
| 25 | 259 | 728 | 2.8x |
| 50 | 449 | 1,398 | 3.1x |
| 100 | 1,102 | 4,667 | 4.2x |
At C=100, disaggregated users wait 4.7 seconds before seeing any response - an unacceptable user experience for interactive applications.
P99 represents the "worst" user experience that still affects a significant minority of users.
| Concurrency | Aggregated | Disaggregated | Speedup |
|---|---|---|---|
| 1 | 213 | 1,366 | 6.4x |
| 10 | 391 | 1,688 | 4.3x |
| 25 | 875 | 3,546 | 4.1x |
| 50 | 1,754 | 6,777 | 3.9x |
| 100 | 3,291 | 13,512 | 4.1x |
Critical Insight: Even at C=1 (no queueing), disaggregated P99 TTFT is 1.37 seconds vs 213ms for aggregated. This 6.4x overhead is pure disaggregation tax - KV cache serialization, network transfer, and deserialization.
At C=100, 1% of disaggregated users wait over 13.5 seconds for first token - completely unacceptable.
The TTFT degradation is dramatic and worsens with scale:
TTFT Average @ Concurrency 100:
Aggregated: [▓▓▓░░░░░░░] 1.1s
Disaggregated: [▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓] 4.7s (4.2x worse)
TTFT P99 @ Concurrency 100:
Aggregated: [▓▓▓▓░░░░░░░░░░░░░░░░] 3.3s
Disaggregated: [▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓▓] 13.5s (4.1x worse)
ITL measures the "smoothness" of token generation after the first token. This is the only metric where disaggregated shows any advantage:
| Concurrency | Aggregated | Disaggregated | Winner |
|---|---|---|---|
| 1 | 12.58 | 12.30 | Disagg (2.2%) |
| 10 | 13.13 | 12.73 | Disagg (3.0%) |
| 25 | 14.41 | 13.61 | Disagg (5.6%) |
| 100 | 20.69 | 16.88 | Disagg (18.4%) |
Analysis: Disaggregated decode workers are more specialized and slightly more efficient at token generation. However, this 3-18% ITL advantage is dwarfed by the 2-4x TTFT disadvantage. Users would much rather see the first token sooner, even if subsequent tokens arrive slightly slower.
Request latency combines TTFT + (ITL × output tokens). Aggregated maintains lower latency:
| Concurrency | Aggregated (Avg) | Disaggregated (Avg) | Winner |
|---|---|---|---|
| 1 | 6,391 ms | 6,302 ms | Disagg (marginal) |
| 25 | 7,422 ms | 7,455 ms | Agg |
| 100 | 11,333 ms | 13,093 ms | Agg (13% faster) |
P99 Request Latency (critical for SLAs):
| Concurrency | Aggregated (P99) | Disaggregated (P99) | Improvement |
|---|---|---|---|
| 50 | 11,457 ms | 15,258 ms | 24.9% faster |
| 100 | 15,668 ms | 23,348 ms | 32.9% faster |
At high load, disaggregated tail latency is 33% worse, making it very difficult to maintain production SLAs.
| Metric | Winner | Magnitude |
|---|---|---|
| Time to First Token | ✅ Aggregated | 2-4x faster |
| TTFT P99 | ✅ Aggregated | 4-6x faster |
| Inter-Token Latency | 3-18% faster (minor) | |
| Request Latency | ✅ Aggregated | 13-33% faster (high conc.) |
| Output Throughput | ✅ Aggregated | 26% higher |
| Throughput per GPU | ✅ Aggregated | 26% more efficient |
| Throughput per User | ✅ Aggregated | 26% better UX |
Verdict: Aggregated architecture wins decisively on 6 of 7 metrics, with the 7th (ITL) being a minor disadvantage that doesn't meaningfully impact user experience.
What we know for certain:
-
Disaggregation introduces massive overhead - Even at C=1 (no queueing), disaggregated P99 TTFT is 1,366ms vs 213ms - a 1,153ms penalty that cannot be explained by load or variance.
-
The overhead scales with concurrency - At C=100, the TTFT gap widens to 4.2x (avg) and 4.1x (P99), suggesting the overhead compounds under load.
-
Architectural difference is KV cache transfer - The only significant difference between the deployments is that disaggregated must transfer KV caches from prefill to decode workers, while aggregated keeps everything in-process.
-
Cross-node deployment likely exacerbates the issue - Our 10-GPU deployment spans 2 nodes (8 GPUs per node), meaning prefill and decode workers are likely on different nodes, requiring network transfer via InfiniBand.
We believe the KV cache transfer overhead includes:
- Serialization (prefill worker): Convert KV cache tensors to transferable format
- Network Transfer: Send data between nodes (or within node if co-located)
- Deserialization (decode worker): Parse received data back to tensors
- Coordination: Frontend scheduling and handoff logic
Estimated timing (based on similar systems, not profiled):
- Same-node transfer: ~400-600ms per request
- Cross-node addition: ~200-500ms per request
- Total: ~600-1,100ms (roughly matches our 1,153ms observed overhead)
Critical unknowns that affect confidence in our diagnosis:
-
Actual node placement - We didn't capture which nodes hosted which workers during the benchmark (workers are now scaled down). Cross-node transfer is likely given 10 GPUs on 8-GPU nodes, but not confirmed.
-
vLLM implementation efficiency - We don't know if:
- The KV cache transfer is optimized (e.g., using RDMA, compression)
- There are configuration parameters we missed
- The disaggregation code path has known performance issues
-
Network utilization - We didn't measure actual network bandwidth usage or identify bottlenecks.
-
Same-node performance - Our estimate that single-node deployment would reduce overhead to ~0.8s is speculative. It could be better (if network is the main issue) or worse (if serialization dominates).
-
Alternative explanations - Could there be other architectural differences we're not accounting for (scheduling policies, batching behavior, memory allocation patterns)?
What increases our confidence:
- ✅ The overhead appears even at C=1, ruling out queueing/load as primary cause
- ✅ KV cache transfer is a known bottleneck in disaggregated serving systems
- ✅ The magnitude (1-1.5s) is consistent with transferring ~4k tokens of FP8 KV cache over InfiniBand
- ✅ Aggregated has no such transfer requirement and performs as expected
- ✅ The 26% throughput gap at scale aligns with per-request overhead compounding
What would increase confidence further:
- 🔬 Profile vLLM disaggregated execution to measure time spent in each phase
- 🔬 Confirm actual node placement during benchmark
- 🔬 Run 8-GPU (single-node) comparison
- 🔬 Instrument network transfer sizes and latencies
- 🔬 Test with KV cache compression/optimization enabled
Disaggregation is beneficial when:
- Prefill and decode have vastly different resource needs - not true here (ISL:OSL = 8:1, but both fit comfortably in memory)
- Independent scaling is required - not exercised in this experiment
- Prefill batching opportunities exist - limited with streaming workloads
Our workload characteristics:
- ISL = 4k, OSL = 500: Moderate input, moderate output
- Streaming mode: Users want immediate response (TTFT critical)
- Fixed capacity: Not testing dynamic scaling
Result: All cost, no benefit from disaggregation.
Use aggregated architecture for Qwen3-32B in production. The disaggregated deployment should be decommissioned or reserved for research purposes only.
Disaggregation may be worth re-evaluating if:
- Extreme decode-heavy workloads: ISL < 500, OSL > 2000 (e.g., simple Q&A with long answers)
- Dynamic scaling requirements: Need to independently scale prefill/decode capacity based on real-time load
- Single-node deployment: All GPUs on one node (eliminates cross-node overhead)
- Non-streaming use cases: Batch processing where TTFT doesn't matter
To further explore disaggregation:
- Single-node deployment: Test 8 GPUs (4 agg workers vs 2 prefill + 3 decode)
- Different workload profiles: Test ISL=500/OSL=2000, ISL=8000/OSL=200
- KV cache compression: Explore FP4 or quantization for transfer
- RDMA optimizations: Tune InfiniBand parameters for lower latency
- Larger models: Test 70B+ where prefill/decode imbalance may be more pronounced
At C=100, disaggregated requires 26% more GPUs to match aggregated throughput:
- Aggregated: 10 GPUs → 3,990 tok/s
- Disaggregated equivalent: 12.6 GPUs → 3,990 tok/s
For a 100-GPU production deployment:
- Aggregated: 100 GPUs
- Disaggregated equivalent: 126 GPUs (+26 GPUs)
At $3/GPU/hour, disaggregated costs $1,872/day more for equivalent performance, plus 4x worse user experience.
All configuration files and scripts are available in the repository:
qwen32b-agg-deployment.yaml- Aggregated DynamoGraphDeploymentqwen32b-disagg-deployment.yaml- Disaggregated DynamoGraphDeploymentagg.yaml- Aggregated backend configdisagg_prefill.yaml- Disaggregated prefill configdisagg_decode.yaml- Disaggregated decode configrun_benchmark.sh- Benchmark execution scriptbenchmark-agg-job.yaml- Aggregated benchmark Kubernetes Jobbenchmark-disagg-job.yaml- Disaggregated benchmark Kubernetes Jobcreate_pareto_plot.py- Pareto analysis plotting scriptrestore_qwen32b_workers.sh- Quick restore script for scaled-down deployments
To free GPU resources while preserving deployments:
# Scale workers to 0
kubectl patch dynamographdeployment qwen32b-agg -n bhamm-qwen32b-agg \
--type='json' -p='[{"op": "replace", "path": "/spec/services/VllmDecodeWorker/replicas", "value": 0}]'
kubectl patch dynamographdeployment qwen32b-disagg -n bhamm-qwen32b-disagg \
--type='json' -p='[{"op": "replace", "path": "/spec/services/VllmPrefillWorker/replicas", "value": 0},
{"op": "replace", "path": "/spec/services/VllmDecodeWorker/replicas", "value": 0}]'
# Restore when needed
./restore_qwen32b_workers.shTo verify your reproduction:
# Check deployment status
kubectl get dynamographdeployment -n bhamm-qwen32b-agg
kubectl get dynamographdeployment -n bhamm-qwen32b-disagg
# View benchmark results
kubectl exec -n bhamm-benchmark results-viewer -- \
cat /results/qwen32b-agg/c100/profile_export_aiperf.json
# Regenerate plots
kubectl exec -n bhamm-benchmark results-viewer -- \
python3 /tmp/create_pareto_plot.pyThis controlled A/B experiment demonstrates that disaggregated inference provides no benefit for Qwen3-32B with moderate-length inputs (ISL=4k) and outputs (OSL=500). The performance degradation is severe and consistent:
- 4x worse Time to First Token at production scale (C=100)
- 26% lower throughput per GPU
- 33% worse P99 latency
We have high confidence that disaggregation introduces massive overhead (1.1-1.5s per request even at low concurrency), and medium confidence that KV cache transfer is the primary cause. The magnitude and behavior are consistent with serialization + network transfer + deserialization, but we did not profile the implementation to confirm.
To strengthen this diagnosis, we recommend:
- Profiling vLLM disaggregation internals
- Single-node (8 GPU) comparison to isolate cross-node effects
- Testing with different workload profiles (extreme decode-heavy, etc.)
Production deployments should use aggregated architecture for workloads similar to ours. While disaggregation is an appealing architectural pattern in theory, it introduces overheads that dominate performance in our configuration.
The slight ITL advantage of disaggregation (~18% at high concurrency) is vastly outweighed by the TTFT degradation, making aggregated the clear choice for user-facing applications where responsiveness matters.
Disaggregation may be worth reconsidering for extreme workload profiles (ISL<500, OSL>2000), single-node deployments, or if vLLM disaggregation receives significant optimization.
Client Version: v1.28.2
Server Version: v1.28.8
dynamo-platform-dynamo-operator-controller-managerv0.5.0dynamo-platform-etcd-0dynamo-platform-nats-0
- PVC Storage Class:
nebius-shared-fs(ReadWriteMany) - Results Volume: 10Gi
- Intra-node: NVLink/NVSwitch (900 GB/s bidirectional)
- Inter-node: InfiniBand (200 Gb/s)
Report Generated: October 22, 2025
Author: AI-Dynamo Performance Engineering
Contact: [Internal]
