vLLM 03 - prefix caching

KV-cache-aware routing in multi-host serving

https://github.com/vllm-project/production-stack/issues/59

https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/498

Solution 1

  • Use string matching instead of token-ID-based matching
    • Tokenization itself is fairly slow (it takes several microseconds), so running it for every request at the router creates significant overhead.
  • Implement the string server (e.g., using Redis) as the storage backend

Solution 2

  • The router sends a query to the KV cache management system: which server has the longest matched prefix?
  • The vLLM production-stack team wants to adopt this solution: it decouples the logic

https://github.com/vllm-project/production-stack/issues/59

KV-cache-aware routing vs. load balancing: there is a tradeoff. KV-cache-aware routing tends to send requests with shared prefixes to the same node, which is bad for load balancing.
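
As an illustration of that tradeoff, here is a minimal routing sketch (not the production-stack implementation) that blends prefix affinity with a load penalty; the match lengths and load values are assumed to come from some external KV cache index and load reporter:

# Minimal sketch (not the production-stack implementation): score each server
# by blending prefix-cache affinity with a load penalty.
def pick_server(match_lens: dict[str, int],
                loads: dict[str, float],
                prompt_len: int,
                alpha: float = 0.5) -> str:
    """match_lens: longest matched prefix (in tokens) per server.
    loads: normalized load per server (0 = idle, 1 = saturated).
    alpha: 1.0 = pure cache affinity, 0.0 = pure load balancing."""
    def score(server: str) -> float:
        hit_ratio = match_lens[server] / max(prompt_len, 1)
        return alpha * hit_ratio - (1 - alpha) * loads[server]
    return max(match_lens, key=score)

# Example: server "b" has a longer cached prefix but is much busier.
print(pick_server({"a": 0, "b": 96}, {"a": 0.2, "b": 0.9}, prompt_len=128))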

KV cache store interface

Back in 2023, before vLLM was out, let's take a look at Hugging Face's LLM interface (simplified):

llm.inference(
    input_tokens: list[int],          # N tokens
    previous_kv_cache: list[Tensor],  # M tokens' KV cache, where M < N
) -> output_tokens, new_kv_cache

# output_tokens: N' new tokens
# new_kv_cache:  (N + N') tokens' KV cache

Let’s not worry about PagedAttention or other complicated things in vLLM.

How do you design a KV cache?

KV cache design is essentially similar to traditional key-value store system design (as in Redis, object stores, databases, etc.)

Our key-value pair in this context is:

  • key: tokens
  • value: KV cache tensors

Then, naturally, the interface looks like this:

class KVCacheStore:
    def store(self, tokens, kv_cache_tensors):
        pass

    def retrieve(self, tokens) -> kv_cache_tensors:
        pass

Traditional key-value stores usually support exact match (given the full key, return the value).

KVCacheStore needs another kind of query, prefix matching (actually, some traditional KV stores also support this).

Prefix matching

Tokens1: ABCDE → [KV1, KV2, KV3, KV4, KV5]

Tokens2: ABCDF → [KV1, KV2, KV3, KV4, KV6]

kv_cache_store.store("ABCDE", [KV1, KV2, KV3, KV4, KV5])

When we call kv_cache_store.retrieve("ABCDF"),

we expect the matched prefix part [KV1, KV2, KV3, KV4] to be returned.

Trie is a good data structure that supports prefix matching.
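
For intuition, here is a minimal token-level trie sketch (not vLLM's implementation) that returns the KV entries of the longest cached prefix:

# Minimal trie sketch (not vLLM's implementation): each node stores the KV
# entry for the prefix that ends at that node.
class TrieNode:
    def __init__(self):
        self.children = {}  # token -> TrieNode
        self.kv = None      # KV entry for the prefix ending here

class PrefixTrie:
    def __init__(self):
        self.root = TrieNode()

    def insert(self, tokens, kv_list):
        node = self.root
        for tok, kv in zip(tokens, kv_list):
            node = node.children.setdefault(tok, TrieNode())
            node.kv = kv

    def longest_prefix(self, tokens):
        node, matched = self.root, []
        for tok in tokens:
            if tok not in node.children:
                break
            node = node.children[tok]
            matched.append(node.kv)
        return matched

trie = PrefixTrie()
trie.insert("ABCDE", ["KV1", "KV2", "KV3", "KV4", "KV5"])
print(trie.longest_prefix("ABCDF"))  # ['KV1', 'KV2', 'KV3', 'KV4']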

vLLM implements a simplified “hash of chunks” structure to simulate a Trie:

# given tokens
"ABCDEF"
# assume chunk size = 2, chunking:
"AB", "CD", "EF"
# chunk hashes:
h1 = hash("AB")
h2 = hash(h1 + "CD")
h3 = hash(h2 + "EF")

With prefix matching, store and retrieve become:

# store
for chunk_hash, chunk_kv in zip(...):
    redis.put(chunk_hash, chunk_kv)

# retrieve
for chunk_hash in ...:
    kv_chunk = redis.get(chunk_hash)
    if kv_chunk is None:
        break
    ...
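
Putting store and retrieve together, here is a small self-contained sketch, assuming an in-memory dict in place of Redis and the toy chunk size of 2 from the example above:

# Self-contained sketch of the chunked store above; a plain dict stands in for
# Redis, and the toy chunk size of 2 matches the example.
import hashlib

CHUNK_SIZE = 2  # toy value; inside vLLM the chunk size is the block size (16)

def chunk_hashes(tokens: str) -> list[str]:
    """Chain-hash full chunks so each hash covers the whole prefix so far."""
    hashes, prev = [], ""
    for i in range(0, len(tokens) - len(tokens) % CHUNK_SIZE, CHUNK_SIZE):
        prev = hashlib.sha256((prev + tokens[i:i + CHUNK_SIZE]).encode()).hexdigest()
        hashes.append(prev)
    return hashes

class ChunkedKVCacheStore:
    def __init__(self):
        self._store = {}  # chunk hash -> KV chunk

    def store(self, tokens: str, kv_chunks: list) -> None:
        for h, kv in zip(chunk_hashes(tokens), kv_chunks):
            self._store[h] = kv

    def retrieve(self, tokens: str) -> list:
        """Return the KV chunks of the longest matched prefix (full chunks only)."""
        matched = []
        for h in chunk_hashes(tokens):
            if h not in self._store:
                break
            matched.append(self._store[h])
        return matched

store = ChunkedKVCacheStore()
store.store("ABCDE", ["KV(AB)", "KV(CD)"])  # only full chunks are stored
print(store.retrieve("ABCDF"))              # ['KV(AB)', 'KV(CD)']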

Chunk size: tradeoff between matching granularity (hit rate) and management overhead

  • inside serving engine:

    • in vLLM, block size = 16, so chunk size is also 16
    • SGLang can support chunk size = 1
  • outside serving engine (on disk):

    • performance is determined by the number of I/O operations
    • small objects give poor I/O performance, so we need a larger chunk size (see the rough size estimate after this list)
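
As a rough size estimate (assuming, for illustration only, a Llama-3-8B-like configuration: 32 layers, 8 KV heads, head dim 128, fp16):

# Rough per-chunk object size; the model numbers are illustrative assumptions.
num_layers, num_kv_heads, head_dim, dtype_bytes = 32, 8, 128, 2
kv_bytes_per_token = 2 * num_layers * num_kv_heads * head_dim * dtype_bytes  # K and V

for chunk_size in (1, 16, 256):
    print(f"chunk_size={chunk_size:4d}: "
          f"{chunk_size * kv_bytes_per_token / 2**20:.3f} MiB per object")
# chunk_size=1   -> 0.125 MiB objects: many tiny I/Os for a long prompt
# chunk_size=256 -> 32 MiB objects: few I/Os, but coarser matching granularity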

Semantic caching

How to deal with semantically similar prefixes: when receiving a request, before sending it to the LLM, do a vector similarity search over previous requests. If a sufficiently similar request exists, we can bypass the LLM and return the previous response.
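
A minimal sketch of this idea, using a brute-force cosine-similarity scan in place of a real vector database (the embedding source and the 0.95 threshold are assumptions):

# Semantic cache sketch: brute-force cosine similarity over stored embeddings.
# A real system would use a vector database; the threshold is an assumption.
import math

SIMILARITY_THRESHOLD = 0.95

def cosine(a: list[float], b: list[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return dot / (na * nb)

class SemanticCache:
    def __init__(self):
        self._entries: list[tuple[list[float], str]] = []  # (embedding, response)

    def lookup(self, query_embedding: list[float]) -> str | None:
        """Return a cached response if a similar-enough request was seen."""
        for emb, response in self._entries:
            if cosine(query_embedding, emb) >= SIMILARITY_THRESHOLD:
                return response  # bypass the LLM
        return None

    def insert(self, query_embedding: list[float], response: str) -> None:
        self._entries.append((query_embedding, response))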

Eviction policy

Which KV cache tensors to evict?

Inside each request:

  • “ABCDEF” → [“AB”, KV1], [“CD”, KV2], [“EF”, KV3]
  • Evict from tail to head: KV3, KV2, KV1

Among different requests: LRU, LFU, …
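
A toy sketch combining the two levels (not vLLM's implementation): an LRU order across requests, and tail-to-head eviction within each request's chunk chain, so whatever remains is still a valid prefix:

from collections import OrderedDict

class ChainedKVCache:
    """Toy eviction sketch: LRU across requests; within a request, evict the
    chunk chain from tail to head so the remaining chunks stay a valid prefix."""

    def __init__(self, capacity_chunks: int):
        self.capacity = capacity_chunks
        self.chains: "OrderedDict[str, list]" = OrderedDict()  # request id -> chunks
        self.num_chunks = 0

    def touch(self, request_id: str) -> None:
        self.chains.move_to_end(request_id)  # mark as most recently used

    def add(self, request_id: str, chunks: list) -> None:
        self.chains[request_id] = list(chunks)
        self.num_chunks += len(chunks)
        self._evict_if_needed()

    def _evict_if_needed(self) -> None:
        while self.num_chunks > self.capacity and self.chains:
            lru_id, chain = next(iter(self.chains.items()))  # least recently used
            chain.pop()  # drop from the tail first: KV3, then KV2, then KV1
            self.num_chunks -= 1
            if not chain:
                del self.chains[lru_id]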

vLLM implementation

Retrieve

Pass the request to the kv_cache_manager:

computed_blocks, num_computed_tokens = \
    self.kv_cache_manager.get_computed_blocks(request)

get_computed_blocks() in V1

  1. Chunk the tokens into blocks of 16 tokens
  2. Use block_pool.get_cached_block() to get the computed blocks
def get_computed_blocks(
        self, request: Request) -> tuple[list[KVCacheBlock], int]:
    """Get the computed (cached) blocks for the request.
    Note that the computed blocks must be full.

    Args:
        request: The request to get the computed blocks.

    Returns:
        A tuple containing:
            - A list of blocks that are computed for the request.
            - The number of computed tokens.
    """
    if not self.enable_caching:
        # Prefix caching is disabled.
        return [], 0

    # The block hashes for the request may already be computed
    # if the scheduler has tried to schedule the request before.
    block_hashes = self.req_to_block_hashes[request.request_id]
    if not block_hashes:
        block_hashes = hash_request_tokens(self.block_size, request)
        self.req_to_block_hashes[request.request_id] = block_hashes

    self.prefix_cache_stats.requests += 1
    if request.sampling_params.prompt_logprobs is None:
        # Check for cache hits
        computed_blocks = []
        for block_hash in block_hashes:
            # block_hashes is a chain of block hashes. If a block hash
            # is not in the cached_block_hash_to_id, the following
            # block hashes are not computed yet for sure.
            if cached_block := self.block_pool.get_cached_block(
                    block_hash):
                computed_blocks.append(cached_block)
            else:
                break

        self.prefix_cache_stats.queries += len(block_hashes)
        self.prefix_cache_stats.hits += len(computed_blocks)

        # NOTE(woosuk): Since incomplete blocks are not eligible for
        # sharing, `num_computed_tokens` is always a multiple of
        # `block_size`.
        num_computed_tokens = len(computed_blocks) * self.block_size
        return computed_blocks, num_computed_tokens
    else:
        # Skip cache hits for prompt logprobs
        return [], 0


Store

Defined in /vllm/v1/core/block_pool.py

class BlockPool:
    """BlockPool that manages KVCacheBlocks.
    It provides methods to allocate, free and cache the KV cache blocks. The
    free_block_queue stores the free blocks in eviction order to enable
    allocation, free, and cache eviction. The cached_block_hash_to_block
    maps between block hash and cached block to support finding cached blocks
    by their block hash.

    Args:
        num_gpu_blocks: The number of blocks in the pool.
        enable_caching: Whether to enable prefix caching.
    """

The function is called cache_full_blocks

It caches a list of full blocks for prefix caching.

This function takes a list of blocks whose block hash metadata will be updated and cached.

Given a request, it computes the block hashes for the blocks from num_cached_blocks to num_full_blocks, updates the metadata of each block, and caches them in cached_block_hash_to_block.

def cache_full_blocks(
    self,
    request: Request,
    blocks: list[KVCacheBlock],
    block_hashes: list[BlockHashType],
    num_cached_blocks: int,
    num_full_blocks: int,
    block_size: int,
) -> None:
    """Cache a list of full blocks for prefix caching.
    This function takes a list of blocks that will have their block hash
    metadata to be updated and cached. Given a request, it computes the
    block hashes for the blocks starting from `num_cached_blocks` to
    `num_full_blocks`, updating the metadata for each block
    and caching them in the `cached_block_hash_to_block`.

    Args:
        request: The request to cache the blocks.
        blocks: All blocks in the request.
        block_hashes: Block hashes of the blocks in the request. Note that
            this list may be shorter than the blocks list. In this case the
            missed block hash will be computed in this function.
        num_cached_blocks: The number of blocks that are already cached.
        num_full_blocks: The number of blocks that are full and should
            be cached after this function.
        block_size: Number of tokens in each block.
    """
    if num_cached_blocks == num_full_blocks:
        return
    new_full_blocks = blocks[num_cached_blocks:num_full_blocks]
    assert len(block_hashes) >= num_cached_blocks
    new_block_hashes = block_hashes[num_cached_blocks:]

    # Update the new blocks with the block hashes through the chain.
    if num_cached_blocks == 0:
        prev_block_hash_value = None
    else:
        prev_block = blocks[num_cached_blocks - 1]
        assert prev_block.block_hash is not None
        prev_block_hash_value = prev_block.block_hash.hash_value

    for i, blk in enumerate(new_full_blocks):
        assert blk.block_hash is None

        if i < len(new_block_hashes):
            # The block hash may already be computed in
            # "get_computed_blocks" if the tokens are not generated by
            # this request (either the prompt tokens or the previously
            # generated tokens with preemption). In this case we simply
            # reuse the block hash.
            block_hash = new_block_hashes[i]
        else:
            # Otherwise compute the block hash and cache it in the request
            # in case it will be preempted in the future.
            blk_idx = num_cached_blocks + i
            start_token_idx = blk_idx * block_size
            end_token_idx = (blk_idx + 1) * block_size
            block_tokens = request.all_token_ids[
                start_token_idx:end_token_idx]
            assert len(block_tokens) == block_size, (
                f"Expected {block_size} tokens, got "
                f"{len(block_tokens)} at {blk_idx}th block for request "
                f"{request.request_id}({request})")

            # Generate extra keys for multi-modal inputs. Note that since
            # we reach to this branch only when the block is completed with
            # generated tokens, we only need to consider the last mm input.
            extra_keys, _ = generate_block_hash_extra_keys(
                request, start_token_idx, end_token_idx, -1)

            # Compute the hash of the current block.
            block_hash = hash_block_tokens(prev_block_hash_value,
                                           block_tokens, extra_keys)
            block_hashes.append(block_hash)

        # Update and add the full block to the cache.
        blk.block_hash = block_hash
        self.cached_block_hash_to_block[block_hash][blk.block_id] = blk
        prev_block_hash_value = block_hash.hash_value

Eviction

Eviction is also handled in BlockPool.

If a block is cached in cached_block_hash_to_block, we reset its hash metadata and evict it from the cache.

def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool:
    """
    If a block is cached in `cached_block_hash_to_block`, we reset its hash
    metadata and evict it from the cache.

    Args:
        block: The block to evict.

    Returns:
        True if the block is evicted, False otherwise.
    """
    block_hash = block.block_hash
    if block_hash and block_hash in self.cached_block_hash_to_block:
        block.reset_hash()
        del self.cached_block_hash_to_block[block_hash][block.block_id]

        if len(self.cached_block_hash_to_block[block_hash]) == 0:
            del self.cached_block_hash_to_block[block_hash]

        return True
    return False


The evictor class FreeKVCacheBlockQueue uses a doubly linked list for LRU eviction:

class FreeKVCacheBlockQueue:
    """This class organizes a list of KVCacheBlock objects to a doubly linked
    list of free blocks. We implement this class instead of using Python
    builtin deque to support removing a block in the middle of the queue
    in O(1) time. To close the performance gap to the builtin deque which is
    implemented in C++, this class does not allocate any Python objects when
    manipulating the linked list. Instead, this class manipulates the
    prev_free_block and next_free_block attributes of the given blocks.

    The queue is ordered by block ID in the beginning. When a block is allocated
    and then freed, it will be appended back with the eviction order:
    1. The least recently used block is at the front (LRU).
    2. If two blocks have the same last accessed time (allocated by the
       same sequence), the one with more hash tokens (the tail of a block
       chain) is at the front.
    Note that we maintain this order by reversing the block order when free
    blocks of a request. This operation is outside of this class.
    """

Source:

https://www.youtube.com/watch?v=mWvqA_BNtsU&list=PLJj_urhaf2_qxpg8A5-6xoMvMLBKQMTX1&index=18

