vLLM 01 - P/D disaggregation

Why P/D disaggregation?

  • Initial scheduler logic in vLLM: prioritize prefill for good throughput
  • Problem: prefill may slow down other requests’ decode

How to mix P and D together?

  • Well, even their input shapes are different
  • Decode is vector*matrix (e.g., the Q projection is 1xd * dxd): not many FLOPs, but it still needs to load the model weights and KV cache -> it's memory bound
  • Prefill is matrix*matrix (nxd * dxd): the model weights are reused across the n tokens -> compute bound (see the rough arithmetic after this list)
  • Solutions: P/D disaggregation, chunked prefill
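
A back-of-the-envelope sketch of why this happens; the model dimension, token count, and FLOPs/bandwidth ratio below are made-up numbers for illustration, not measurements:

# Arithmetic intensity (FLOPs per byte of weights loaded) of one d x d projection.
# All numbers are illustrative assumptions.

def intensity(n_tokens: int, d: int, bytes_per_weight: int = 2) -> float:
    flops = 2 * n_tokens * d * d             # (n x d) @ (d x d) matmul
    weight_bytes = d * d * bytes_per_weight  # the weight matrix must be read once
    return flops / weight_bytes

d = 4096
print(f"decode  (n=1):    {intensity(1, d):.0f} FLOPs/byte")     # ~1    -> memory bound
print(f"prefill (n=4096): {intensity(4096, d):.0f} FLOPs/byte")  # ~4096 -> compute bound
# A GPU that can do a few hundred FLOPs per byte of HBM bandwidth is starved by
# decode (it waits on weight/KV loads) but kept busy by prefill.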

Chunked prefill

Motivation of chunked prefill

  • Unify the prefill and decode procedure: starting from some KV cache, both P and D do some attention & linear computation to generate new tokens
  • The compute flow of prefill and decode is the same; they only differ in their input and output shapes
  • Chunked prefill becomes possible if the kernel can accept different shapes.
  • Now the scheduler can make a simpler decision: it only cares about how many tokens to schedule in the current batch (a toy scheduling sketch follows this list)
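
A toy sketch of that decision, assuming a hypothetical token budget and request lists (this is not vLLM's actual scheduler):

# Hypothetical token-budget scheduler: each decode request costs one token;
# one waiting prefill fills the rest of the budget with a chunk.

def build_batch(decode_reqs, prefill_reqs, token_budget=512):
    batch = []
    for req_id in decode_reqs:                # each running request decodes 1 token
        if token_budget == 0:
            break
        batch.append((req_id, 1))
        token_budget -= 1
    for req_id, remaining_prompt in prefill_reqs:
        if token_budget == 0:
            break
        chunk = min(remaining_prompt, token_budget)  # chunked prefill
        batch.append((req_id, chunk))
        token_budget -= chunk
    return batch

print(build_batch(decode_reqs=["d0", "d1", "d2"],
                  prefill_reqs=[("p0", 10_000)]))
# [('d0', 1), ('d1', 1), ('d2', 1), ('p0', 509)]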

Chunk size

  • chunk size is very important
  • if the chunk size is too large, a decode can be slow when batched with a prefill -> decode is slowed down by prefill
  • if the chunk size is too small, then
    1. GPU utilization is bad, and the achieved FLOPs are low
    2. it takes many batches to finish the prefill of a long prompt -> prefill is slowed down by decode (see the toy calculation after this list)
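
A toy calculation of the trade-off; the prompt length and the per-batch latency model are made-up assumptions:

# Toy numbers only: a 32k-token prompt and a crude linear latency model.
prompt_len = 32_000
for chunk_size in (256, 2048, 8192):
    num_batches = -(-prompt_len // chunk_size)   # batches needed to finish the prefill
    est_batch_ms = 10 + 0.02 * chunk_size        # made-up per-batch latency model
    print(f"chunk={chunk_size:5d}: {num_batches:3d} batches, ~{est_batch_ms:.0f} ms each")
# chunk=  256: 125 batches, ~15 ms each   -> prefill finishes slowly, GPU underutilized
# chunk= 2048:  16 batches, ~51 ms each
# chunk= 8192:   4 batches, ~174 ms each  -> co-scheduled decodes stall for ~174 ms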

When to use chunked prefill

  • prompts are extremely long (e.g., 10k or 100k tokens)

  • why? during the attention computation there are temporary buffers holding Q/K/V, whose size is proportional to the number of tokens processed in one forward pass. chunked prefill caps that number at the chunk size, and thus bounds the buffer size (a rough sizing example follows this list)

  • want smooth generation, e.g., SLO for p99 inter-token latency
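
A rough sizing of those temporary Q/K/V buffers, assuming a hypothetical hidden size and fp16 activations:

# Per-layer size of the Q/K/V projection buffers for the tokens in one forward pass.
def qkv_buffer_bytes(tokens_in_batch: int, hidden_size: int = 8192,
                     dtype_bytes: int = 2) -> int:
    return 3 * tokens_in_batch * hidden_size * dtype_bytes

full = qkv_buffer_bytes(100_000)   # prefill a 100k-token prompt in one shot
chunked = qkv_buffer_bytes(2_048)  # same prompt, processed 2k tokens at a time
print(f"{full / 2**30:.1f} GiB vs {chunked / 2**20:.0f} MiB per layer")
# 4.6 GiB vs 96 MiB per layer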

Key problems in P/D disaggregation

Connector: how to transfer KV cache?

  • pooling mode: a shared memory pool; both the sender and the receiver need a high-bandwidth connection to the memory pool
  • p2p mode: the sender communicates with the receiver directly; better performance, but much harder to implement (a hypothetical sketch of both modes follows this list)
  • Frameworks that support KV cache transfer: LMCache, Mooncake, NIXL
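
A hypothetical sketch of the two modes behind a common put/get interface; the class and method names are invented and do not match any real framework's API:

# Hypothetical sketch of the two transfer modes; names are made up.
import torch

class PoolingConnector:
    """Sender and receiver both talk to a shared KV pool (e.g., a remote store)."""
    def __init__(self, pool: dict):
        self.pool = pool                 # stands in for a memory-pool service client
    def put(self, key: str, kv: torch.Tensor) -> None:
        self.pool[key] = kv              # both sides need high bandwidth to the pool
    def get(self, key: str):
        return self.pool.get(key)

class P2PConnector:
    """Sender pushes KV straight to the receiver (e.g., over RDMA/NCCL)."""
    def __init__(self, send_fn, recv_fn):
        self.send_fn, self.recv_fn = send_fn, recv_fn
    def put(self, key: str, kv: torch.Tensor) -> None:
        self.send_fn(key, kv)            # direct transfer; needs a rendezvous protocol
    def get(self, key: str):
        return self.recv_fn(key)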

LMCache can do KV extraction and transfer

  • supports both pooling and p2p modes
  • current target use cases: prefill-decode disaggregation, KV cache offloading

Mooncake

  • KV cache storage: replicas, RDMA support, etc.
  • pooling mode

NIXL: p2p mode

  • it does not support p2p semantics directly; instead, it exposes some lower-level data transfer features
  • its backend is UCX, which is a more general data transfer library than NCCL

How to extract & inject KV cache in vLLM?

connector API is called in model_runner

path: vllm/worker/model_runner.py

  • the model runner wraps the model forward pass

  • it prepares the inputs for the forward pass

  • it post-processes the forward outputs

  • one major part of the model runner is receiving & sending KV cache

steps

  • before forward, try receiving KV cache and injecting into vLLM’s paged memory

# Receive KV cache in distributed KV cache transfer setting
# In disagg prefill setting, it will also recv hidden states and bypass
# model forwarding
# In KV cache database setting, it will change the model input so that
# we can skip prefilling on tokens that successfully received KV caches
# NOTE: The receive operation is blocking
bypass_model_exec = False
if self.need_recv_kv(model_input, kv_caches):
    hidden_or_intermediate_states, bypass_model_exec, model_input = \
        get_kv_transfer_group().recv_kv_caches_and_hidden_states(
            # model is used to know which layer the current worker
            # is working on, so that we can receive KV for only those
            # layers.
            model_executable,
            model_input,
            kv_caches=kv_caches
        )

  • after forward, extract KV cache from vLLM’s paged memory, and send it outside

# Sending KV cache in distributed KV cache transfer setting
# NOTE: the send operation is non-blocking
if self.need_send_kv(model_input, kv_caches):
    get_kv_transfer_group().send_kv_caches_and_hidden_states(
        # model_executable is used to know which layer the current
        # worker is working on, so that we can send KV for only those
        # layers.
        model_executable,
        model_input,
        kv_caches,
        hidden_or_intermediate_states,
    )

How are connector functions implemented?

check path vllm/distributed/kv_transfer/kv_connector/

There are many possible connectors to use; let's take the SimpleConnector code as an example:

receive KV cache

  • check if model_input's tokens exist in the outside world
  • if they do exist, compute where the KV cache should be inserted into vLLM's paged memory (parse the page table, use the page indices to find the right place)
  • additionally, it should rebuild the model input to tell the scheduler that the KV cache is already there, so no prefill is done again
  • note: the snippet below is actually the send side, send_kv_caches_and_hidden_states; the receive side mirrors it (a toy illustration of the slot-mapping indexing follows the snippet)

def send_kv_caches_and_hidden_states(
    self,
    model_executable: torch.nn.Module,
    model_input: "ModelInputForGPUWithSamplingMetadata",
    kv_caches: List[torch.Tensor],
    hidden_or_intermediate_states: Union[torch.Tensor,
                                         IntermediateTensors],
) -> None:

    # some initial setup code
    # ...

    # query_lens contains new KV caches that are added to vLLM.
    # so we will send them to decode instance
    # FIXME(Kuntai): This assume that all requests are prefill.
    for idx, slen in enumerate(seq_lens):
        start_pos = sum(seq_lens[:idx])
        end_pos = start_pos + slen

        if start_pos >= num_prefill_tokens:
            # vllm/worker/model_runner.py::_prepare_model_input_tensors:
            # - input_tokens[:num_prefill_tokens] contains prefill tokens.
            # - input_tokens[num_prefill_tokens:] contains decode tokens.
            logger.warning("You have some decode requests while using "
                           "SimpleConnector. Their KVCache won't be sent.")
            break

        current_tokens = input_tokens_tensor[start_pos:end_pos]

        keys, values = [], []

        for layer_id in range(start_layer, end_layer):
            kv_cache = kv_caches[layer_id - start_layer]

            key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
            value_cache = kv_cache[1].reshape(-1, num_heads, head_size)

            current_slot_mapping = slot_mapping_flat[start_pos:end_pos]

            keys.append(key_cache[current_slot_mapping].unsqueeze(0))
            values.append(value_cache[current_slot_mapping].unsqueeze(0))

        keys = torch.cat(keys, dim=0)
        values = torch.cat(values, dim=0)

        self.insert(current_tokens,
                    torch.ones_like(current_tokens,
                                    dtype=bool), keys, values,
                    hidden_or_intermediate_states[start_pos:end_pos])

    logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank())
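
The key_cache[current_slot_mapping] gather above is where the connector meets vLLM's paged memory. A toy illustration of that indexing, with made-up shapes and slot numbers:

# Toy illustration of the slot-mapping gather; shapes and slots are invented.
import torch

num_blocks, block_size, num_heads, head_size = 4, 16, 8, 64
# Paged K cache for one layer, flattened so that slot = block_id * block_size + offset.
key_cache = torch.randn(num_blocks * block_size, num_heads, head_size)

# slot_mapping: for each token of the request, where its KV lives in paged memory.
slot_mapping = torch.tensor([32, 33, 34, 35, 5])  # four slots in block 2, one in block 0

# Gather the per-token KV out of the paged layout into a contiguous tensor to send.
keys_to_send = key_cache[slot_mapping]
print(keys_to_send.shape)  # torch.Size([5, 8, 64])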

The receive side (recv_kv_caches_and_hidden_states) is similar, just in the opposite direction: it looks up the KV cache for the request's tokens and writes it into vLLM's paged memory via the slot mapping.

When to send requests to P and D?

  • First P then D: when P finishes, it notifies the router, and the router then tells D to start
  • First D then P: because D is the process that generates the response, let D be responsible for asking P for the KV cache (a minimal sketch of the first flow follows this list)
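
A minimal sketch of the "first P then D" flow from the router's point of view; the endpoints and the max_tokens=1 trick are assumptions about the deployment, not any framework's documented API:

# Hypothetical router for "first P then D": run prefill on the P instance so the
# KV cache is produced and handed to the connector, then let the D instance decode.
import requests

PREFILL_URL = "http://prefill-instance:8000/v1/completions"  # hypothetical endpoint
DECODE_URL = "http://decode-instance:8000/v1/completions"    # hypothetical endpoint

def route(prompt: str) -> str:
    # 1) Ask P to generate just one token, i.e., effectively run prefill only.
    requests.post(PREFILL_URL, json={"prompt": prompt, "max_tokens": 1})
    # 2) Once P has finished, forward the request to D, which receives the KV
    #    cache through the connector instead of recomputing the prompt.
    resp = requests.post(DECODE_URL, json={"prompt": prompt, "max_tokens": 256})
    return resp.json()["choices"][0]["text"]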

KV offloading

  • The connector can also be used for KV cache offloading
  • In such cases, model sharding can be very useful: in GPU-to-CPU KV cache offloading with TP=8, each GPU only offloads its own shard of the KV cache, so the total bandwidth is 8 * single_gpu_bandwidth (a quick estimate follows this list)
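
A quick estimate for the TP=8 example; the per-GPU PCIe bandwidth and the KV cache size are assumed numbers:

# With TP=8 the KV cache is sharded across 8 GPUs (e.g., by attention head),
# so each GPU offloads only its shard and the host links work in parallel.
single_gpu_gbps = 25        # assumed effective PCIe bandwidth per GPU, GB/s
tp = 8
kv_cache_gb = 40            # assumed total KV cache to offload to CPU memory

aggregate_gbps = tp * single_gpu_gbps
print(f"{kv_cache_gb / aggregate_gbps:.1f} s with TP={tp} "
      f"vs {kv_cache_gb / single_gpu_gbps:.1f} s over a single link")
# 0.2 s with TP=8 vs 1.6 s over a single link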

Source:

https://www.youtube.com/watch?v=ih6fcJnhoJI

