The Key-Value (KV) Cache is both the hero and the villain of LLM serving. It's absolutely essential for achieving reasonable generation speed, yet it's also the primary memory bottleneck that limits throughput. Understanding how the KV Cache works, how it consumes memory, and how modern serving frameworks manage it efficiently is fundamental to building production LLM systems.
In this chapter, we'll dive deep into the mechanics of KV caching, explore why memory management is so challenging, and dissect the PagedAttention algorithm that makes high-throughput serving possible.
How LLMs Generate Text: Token by Token
To understand why the KV Cache exists, we first need to understand the fundamental architecture of text generation in Transformer-based LLMs.
The Autoregressive Process
Large Language Models don't generate a complete response in one step. They generate it autoregressively—one token at a time, where each new token depends on all previously generated tokens.
- Input: "The capital of France is" → Output: "Paris"
- Input: "The capital of France is Paris" → Output: ","
- Input: "The capital of France is Paris," → Output: "known"
- ...and so on, token by token
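A minimal sketch of this loop in code (greedy decoding), assuming a hypothetical model(tokens) function that returns next-token logits and a hypothetical tokenizer with the usual encode/decode interface:

def generate(model, tokenizer, prompt, max_new_tokens=32, eos_id=2):
    # Hypothetical helpers: tokenizer.encode/decode and model(tokens) -> logits
    tokens = tokenizer.encode(prompt)
    for _ in range(max_new_tokens):
        logits = model(tokens)                  # logits for every position
        next_token = int(logits[-1].argmax())   # greedily pick the most likely next token
        tokens.append(next_token)               # the new token becomes part of the input
        if next_token == eos_id:                # stop at end-of-sequence
            break
    return tokenizer.decode(tokens)

Each iteration feeds the entire sequence (prompt plus everything generated so far) back through the model—exactly the redundancy the KV Cache removes.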
The Computational Challenge
The core operation in a Transformer is self-attention. Each token must "attend to" (compute relationships with) all previous tokens in the sequence. Without optimization, this creates a catastrophic scaling problem:
# Naive approach (DO NOT DO THIS!)
def generate_next_token_naive(prompt_tokens, generated_tokens):
    # Concatenate all tokens
    all_tokens = prompt_tokens + generated_tokens
    # Run the ENTIRE sequence through the model again
    # This recomputes attention for all previous tokens
    hidden_states = model.forward(all_tokens)
    # Get the last hidden state to predict the next token
    next_token_logits = hidden_states[-1]
    return sample_token(next_token_logits)

# For a 1000-token prompt, generating 100 tokens:
# Token 1:   Process 1000 tokens
# Token 2:   Process 1001 tokens (recompute everything!)
# Token 3:   Process 1002 tokens (recompute everything again!)
# ...
# Token 100: Process 1099 tokens
#
# Total computation: 1000 + 1001 + 1002 + ... + 1099
# ≈ 105,000 token processing operations
# With optimization: 1000 + 100 = 1,100 token processing operations
#
# Savings: ~95x fewer operations!
This naive approach scales quadratically (O(n²)) with sequence length. For long sequences, this becomes completely impractical. The solution? The KV Cache.
The Key-Value Cache: Making Generation Practical
How Attention Works (Refresher)
In self-attention, each token is transformed into three vectors:
- Query (Q): "What am I looking for?"
- Key (K): "What information do I have?"
- Value (V): "What information do I actually provide?"
Attention is computed by comparing each Query to all Keys, then using those similarity scores to weight the Values:
import math

import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V):
    """
    Q: Query vectors  (..., seq_len_q, d_head)
    K: Key vectors    (..., seq_len_k, d_head)
    V: Value vectors  (..., seq_len_k, d_head)
    """
    d_k = Q.size(-1)
    # Compute attention scores, scaled by sqrt(d_k) for numerical stability
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    # Apply softmax to get attention weights
    attention_weights = F.softmax(scores, dim=-1)
    # Weighted sum of values
    output = torch.matmul(attention_weights, V)
    return output
# Example dimensions for a Llama-2-70B-scale model
# (assuming standard multi-head attention, i.e. one K/V pair per attention head):
# num_heads = 64
# d_head = 128
# seq_len = 2048
#
# For each layer, the cached K and V tensors have shape:
# (batch=1, heads=64, seq_len=2048, d_head=128)
#
# Memory for the K and V cache per layer:
# 2 (K and V) × 64 heads × 2048 tokens × 128 dim × 2 bytes (float16)
# = 67,108,864 bytes ≈ 67 MB per layer
#
# For 80 layers: ≈ 67 MB × 80 ≈ 5.37 GB per request!
🔍 Key Insight
The critical observation is that the Key (K) and Value (V) vectors for past tokens never change. Only the Query (Q) for the new token is different. This means we can cache all previous K and V vectors and reuse them!
How KV Caching Works
import torch.nn as nn

class TransformerLayerWithKVCache(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_head = d_model // num_heads
        # Projections that produce Q, K, V from the input embeddings
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        # Empty cache: filled during the first forward pass
        self.k_cache = None  # (batch, cached_len, d_model)
        self.v_cache = None

    def forward(self, x, use_cache=True):
        """
        x: embeddings of the NEW tokens only (batch, seq_len, d_model)
        use_cache: whether to reuse and update the cached K, V
        """
        # Compute Q, K, V for the new tokens only
        Q = self.q_proj(x)
        K = self.k_proj(x)
        V = self.v_proj(x)
        if use_cache and self.k_cache is not None:
            # Prepend the cached K, V so attention sees the full history
            K_full = torch.cat([self.k_cache, K], dim=1)
            V_full = torch.cat([self.v_cache, V], dim=1)
        else:
            # First pass (prefill): no cache yet
            K_full = K
            V_full = V
        # Attention uses the full K and V, but only the new tokens' Q
        # (uses the scaled_dot_product_attention function defined above)
        output = scaled_dot_product_attention(Q, K_full, V_full)
        if use_cache:
            # Store the concatenated K, V for the next decode step
            self.k_cache = K_full
            self.v_cache = V_full
        return output

    def clear_cache(self):
        """Reset the cache before starting a new sequence"""
        self.k_cache = None
        self.v_cache = None
# Usage example:
layer = TransformerLayerWithKVCache(d_model=4096, num_heads=32)
# Prefill: Process entire prompt
prompt_embeddings = tokenize_and_embed("The capital of France is")
output1 = layer.forward(prompt_embeddings, use_cache=True)
# Cache now contains K, V for all 5 prompt tokens
# Generation: Process one token at a time
token1_embedding = sample_and_embed(output1) # "Paris"
output2 = layer.forward(token1_embedding, use_cache=True)
# Cache now contains K, V for 6 tokens (prompt + "Paris")
token2_embedding = sample_and_embed(output2) # ","
output3 = layer.forward(token2_embedding, use_cache=True)
# Cache now contains K, V for 7 tokens
# etc...
⚡ Performance Impact
Without KV Cache:
- Generate 100 tokens for a 1000-token prompt
- Total operations: 1000 + 1001 + 1002 + ... + 1099 ≈ 105,000 token processes
- Time: ~350 seconds on a single GPU
With KV Cache:
- Prefill: Process 1000 tokens once
- Decode: Process 1 new token 100 times
- Total operations: 1000 + 100 = 1,100 token processes
- Time: ~3.5 seconds on the same GPU
Speedup: roughly 95-100x faster! ⚡
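A quick check of the operation counts above, using the same illustrative prompt and output lengths:

# Verify the operation counts quoted above
prompt_len, new_tokens = 1000, 100
naive_ops = sum(prompt_len + i for i in range(new_tokens))   # 1000 + 1001 + ... + 1099 = 104,950
cached_ops = prompt_len + new_tokens                         # 1,100
print(naive_ops, cached_ops, round(naive_ops / cached_ops))  # 104950 1100 95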
The Throughput Problem: Memory as the Bottleneck
While the KV Cache solves the computation problem, it creates a massive memory problem. Let's quantify exactly how much memory these caches consume.
Memory Consumption Formula
def calculate_kv_cache_memory(
    batch_size,
    sequence_length,
    num_layers,
    num_heads,
    head_dimension,
    precision_bytes=2  # float16
):
    """
    Calculate the total memory required for the KV Cache
    """
    # Memory per token per layer (the 2 accounts for K and V)
    memory_per_token_per_layer = (
        2 * num_heads * head_dimension * precision_bytes
    )
    # Total memory for the entire cache
    total_memory = (
        batch_size *
        sequence_length *
        num_layers *
        memory_per_token_per_layer
    )
    return total_memory

# Example: a Llama-2-70B-scale model
# (assumes standard multi-head attention, i.e. one K/V pair per attention head)
cache_size = calculate_kv_cache_memory(
    batch_size=1,
    sequence_length=2048,
    num_layers=80,
    num_heads=64,
    head_dimension=128,
    precision_bytes=2
)
print(f"KV Cache size for 1 user: {cache_size / 1e9:.2f} GB")
# Output: KV Cache size for 1 user: 5.37 GB

# For 16 concurrent users:
cache_size_16 = cache_size * 16
print(f"KV Cache size for 16 users: {cache_size_16 / 1e9:.2f} GB")
# Output: KV Cache size for 16 users: 85.90 GB

# But an A100 GPU only has 80 GB!
# And the model weights themselves take ~140 GB in float16 (for a 70B model).
# This doesn't fit!
Real-World Example: Llama 2 70B
| Component | Memory Usage | Notes |
|---|---|---|
| Model Weights (int8) | 70 GB | Quantized from float16 |
| KV Cache (1 user, 2K tokens) | 5.4 GB | Per concurrent user |
| Available for KV Caches (A100 80GB) | 10 GB | 80 GB total - 70 GB model |
| Max Concurrent Users (Naive) | ~1-2 users | 10 GB ÷ 5.4 GB/user |
⚠️ The Throughput Wall
This calculation reveals the fundamental problem: even with an expensive, top-tier A100 GPU, naive KV Cache management limits you to serving only 1-2 concurrent users with a 70B model. This is catastrophically low throughput for a production system.
The GPU has plenty of spare compute for more concurrent requests; memory capacity, and how it is managed, is the bottleneck. This is exactly the problem PagedAttention solves.
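As a sanity check on the table above, here is a back-of-the-envelope helper (the GPU size, weight footprint, and per-user cache size are the chapter's illustrative numbers, not measurements):

def max_concurrent_users(gpu_memory_gb, model_weights_gb, kv_cache_per_user_gb):
    """How many full-length KV Caches fit after the model weights are loaded?"""
    available_gb = gpu_memory_gb - model_weights_gb
    return int(available_gb // kv_cache_per_user_gb)

# A100 80 GB, int8-quantized 70B weights (~70 GB), 5.4 GB KV Cache per user
print(max_concurrent_users(80, 70, 5.4))  # 1 (the table's "1-2 users" allows for shorter sequences)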
Memory Fragmentation: The Hidden Waste
Beyond the raw memory consumption, traditional serving systems suffer from two types of fragmentation: internal fragmentation (memory reserved for a request but never filled with tokens) and external fragmentation (free memory split into pieces too scattered to satisfy a new contiguous allocation):
# Visualizing memory fragmentation (conceptual)
# Memory layout with pre-allocated blocks (traditional approach):
#
# GPU Memory (80 GB):
# [Model: 70GB]
# [Request 1: ████████░░░░] (5.4 GB allocated, 2 GB used - 63% waste)
# [Request 2: ██████░░░░░░] (5.4 GB allocated, 1.5 GB used - 72% waste)
# [Request 3: COMPLETED] (5.4 GB freed - but one contiguous block)
# [Cannot fit Request 4 - need 5.4 GB contiguous]
#
# Available memory: 5.4 GB + scattered free space
# Utilization: ~35% (massive waste!)
#
# This is why naive serving achieves only 20-40% memory efficiency
Measurements reported by the vLLM authors show that traditional serving systems use only 20-40% of allocated KV Cache memory for actual token states; the remaining 60-80% is lost to these fragmentation issues. This is the problem PagedAttention is designed to solve.
PagedAttention: Virtual Memory for LLMs
PagedAttention is inspired by a brilliant observation: operating systems solved this exact problem decades ago with virtual memory and paging. Can we apply the same concept to GPU memory for KV Caches?
The Core Concept: Fixed-Size Pages
Instead of allocating one large, contiguous block for each request's KV Cache, PagedAttention divides the cache into small, fixed-size blocks called pages (or "blocks").
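To make the paging idea concrete, here is the index arithmetic for mapping a token's position to a block, assuming the 16-token block size used in the examples below:

BLOCK_SIZE = 16  # tokens per block (illustrative; the examples below use the same value)

def token_to_block(position):
    """Map a token's logical position to (logical block index, offset within block)."""
    return position // BLOCK_SIZE, position % BLOCK_SIZE

print(token_to_block(0))   # (0, 0)  -> first slot of block 0
print(token_to_block(50))  # (3, 2)  -> third slot of block 3
print(token_to_block(95))  # (5, 15) -> last slot of block 5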
Block Tables: Mapping Logical to Physical
Each request has a block table—a small data structure that maps logical sequence positions to physical memory locations.
class BlockTable:
    """
    Maps logical blocks (sequence positions) to physical blocks (memory locations)
    """
    def __init__(self, block_size=16):
        self.block_size = block_size
        # Maps logical block index -> physical block index
        self.table = []
        # Number of tokens currently stored in this sequence
        self.num_tokens = 0

    def allocate_block(self, physical_block_id):
        """Allocate a new physical block for the next logical position"""
        self.table.append(physical_block_id)

    def get_physical_block(self, logical_block_idx):
        """Get the physical location of a logical block"""
        return self.table[logical_block_idx]

    def num_blocks(self):
        """Number of blocks allocated"""
        return len(self.table)


class PagedKVCacheManager:
    """
    Manages a pool of physical memory blocks
    """
    def __init__(self, total_blocks, block_size=16):
        self.block_size = block_size
        self.total_blocks = total_blocks
        # Free list of available physical blocks
        self.free_blocks = list(range(total_blocks))
        # Physical memory storage (simplified; in reality this is GPU memory)
        self.physical_memory = [None] * total_blocks
        # Track which requests are using which blocks
        self.active_requests = {}  # request_id -> BlockTable

    def allocate_blocks_for_request(self, request_id, num_tokens):
        """Allocate blocks for a new request"""
        num_blocks_needed = (num_tokens + self.block_size - 1) // self.block_size
        if len(self.free_blocks) < num_blocks_needed:
            raise MemoryError("Not enough free blocks!")
        # Create a block table for this request
        block_table = BlockTable(block_size=self.block_size)
        # Allocate (possibly non-contiguous) physical blocks
        for _ in range(num_blocks_needed):
            physical_block = self.free_blocks.pop(0)
            block_table.allocate_block(physical_block)
        block_table.num_tokens = num_tokens
        self.active_requests[request_id] = block_table
        return block_table

    def free_request(self, request_id):
        """Free all blocks used by a request"""
        block_table = self.active_requests.pop(request_id)
        # Return physical blocks to the free list
        for logical_idx in range(block_table.num_blocks()):
            physical_block = block_table.get_physical_block(logical_idx)
            self.free_blocks.append(physical_block)
            self.physical_memory[physical_block] = None

    def extend_request(self, request_id, additional_tokens):
        """Allocate additional blocks only when the sequence outgrows its last block"""
        block_table = self.active_requests[request_id]
        new_total = block_table.num_tokens + additional_tokens
        blocks_needed = (new_total + self.block_size - 1) // self.block_size
        while block_table.num_blocks() < blocks_needed:
            if not self.free_blocks:
                raise MemoryError("No free blocks for extension!")
            block_table.allocate_block(self.free_blocks.pop(0))
        block_table.num_tokens = new_total
# Usage example:
cache_manager = PagedKVCacheManager(total_blocks=1000, block_size=16)
# Request 1: 64 tokens (needs 4 blocks)
table1 = cache_manager.allocate_blocks_for_request("req1", num_tokens=64)
print(f"Request 1 block table: {table1.table}")
# Output: Request 1 block table: [0, 1, 2, 3]
# Request 2: 48 tokens (needs 3 blocks)
table2 = cache_manager.allocate_blocks_for_request("req2", num_tokens=48)
print(f"Request 2 block table: {table2.table}")
# Output: Request 2 block table: [4, 5, 6]
# Request 1 grows by 32 tokens (needs 2 more blocks)
cache_manager.extend_request("req1", additional_tokens=32)
print(f"Request 1 extended: {table1.table}")
# Output: Request 1 extended: [0, 1, 2, 3, 7, 8]
# Note: Blocks [0,1,2,3,7,8] are NOT contiguous in physical memory!
# But logically, they represent a continuous 96-token sequence
How PagedAttention Eliminates Fragmentation
🔑 Key Benefits
- Minimal internal fragmentation: allocate only as many pages as needed, so waste is bounded by less than one block per sequence. If a sequence is 50 tokens with 16-token blocks, allocate 4 blocks (64 slots) and waste at most 14 slots, instead of the thousands of slots a pre-allocated maximum-length buffer would waste (see the short calculation after this list).
- No external fragmentation: Pages can be scattered anywhere in memory. No need for large contiguous blocks.
- On-demand allocation: Pages are allocated as the sequence grows, not pre-allocated for maximum length.
- Easy reuse: When a request completes, its pages immediately return to the free pool and can be used by any new request.
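A quick comparison of per-request waste under the two schemes (the 2048-token maximum and 350-token actual length are assumptions chosen for illustration):

# Per-request waste: pre-allocating for the maximum length vs. paging
max_seq_len, actual_tokens, block_size = 2048, 350, 16

naive_waste = max_seq_len - actual_tokens                       # 1698 unused slots
paged_blocks = (actual_tokens + block_size - 1) // block_size   # 22 blocks of 16 slots
paged_waste = paged_blocks * block_size - actual_tokens         # 2 unused slots

print(f"naive: {naive_waste / max_seq_len:.0%} of the allocation wasted")      # 83%
print(f"paged: {paged_waste / (paged_blocks * block_size):.0%} of it wasted")  # 1%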
Memory Efficiency Comparison
| Approach | Fragmentation | Memory Waste | Concurrent Users (A100) |
|---|---|---|---|
| Naive Pre-allocation | Very High | 60-80% | 1-2 users |
| Dynamic Allocation | Medium | 30-50% | 3-5 users |
| PagedAttention (vLLM) | Minimal | <4% | 8-12 users |
By reducing memory waste from 60-80% to less than 4%, PagedAttention enables 4-6x more concurrent users on the same hardware—a transformational improvement.
Continuous Batching: Keeping the GPU Busy
PagedAttention solves the memory problem, but there's another critical technique that maximizes GPU utilization: continuous batching (also called "in-flight batching").
The Problem with Static Batching
Traditional batching collects a fixed number of requests, processes them together, and waits for the entire batch to complete before starting a new batch. This creates inefficiency:
- Some requests finish quickly (short outputs)
- Some requests take much longer (long outputs)
- The GPU sits idle waiting for the slowest request in the batch
# Static batching timeline (INEFFICIENT)
# Batch 1: [Req1, Req2, Req3, Req4]
# Time:
# 0-2s: GPU at 100% (all 4 requests generating)
# 2-3s: GPU at 75%  (Req1 finished, 3 remaining)
# 3-5s: GPU at 50%  (Req2 finished, 2 remaining)
# 5-7s: GPU at 25%  (Req3 finished, only Req4 still generating)
#
# Average GPU utilization: ~60%
# Roughly 5 seconds spent below full utilization
# Batch 2: [Req5, Req6, Req7, Req8]
# Cannot start until Req4 completes at t=7s!
Continuous Batching: Dynamic Scheduling
Continuous batching solves this by allowing requests to be added to the active batch as soon as any request finishes:
# Continuous batching timeline (EFFICIENT)
# Active batch starts with: [Req1, Req2, Req3, Req4]
# Time:
# 0-2s: GPU at 100% [Req1, Req2, Req3, Req4]
# 2s: Req1 finishes -> immediately add Req5
# 2-3s: GPU at 100% [Req2, Req3, Req4, Req5]
# 3s: Req2 finishes -> immediately add Req6
# 3-5s: GPU at 100% [Req3, Req4, Req5, Req6]
# 5s: Req3 finishes -> immediately add Req7
# 5-7s: GPU at 100% [Req4, Req5, Req6, Req7]
# 7s: Req4 finishes -> immediately add Req8
# 7-10s: GPU at 100% [Req5, Req6, Req7, Req8]
#
# Average GPU utilization: 95-100%
# No wasted time!
import time
from queue import Queue, Empty

class ContinuousBatchScheduler:
    def __init__(self, max_batch_size=32):
        self.max_batch_size = max_batch_size
        self.active_batch = []
        self.waiting_queue = Queue()

    def add_request(self, request):
        """Add a new request to the scheduler"""
        if len(self.active_batch) < self.max_batch_size:
            # Space available, add to the active batch immediately
            self.active_batch.append(request)
        else:
            # Batch full, add to the waiting queue
            self.waiting_queue.put(request)

    def generation_step(self):
        """Execute one generation step for all active requests"""
        finished_indices = []
        # Generate the next token for every active request
        # (a real system does this as a single batched forward pass;
        #  shown per-request here for clarity)
        for i, req in enumerate(self.active_batch):
            next_token = self.model_generate_next_token(req)
            req.append_token(next_token)
            # Check if the request is complete
            if req.is_finished():
                finished_indices.append(i)
                req.callback(req.get_result())
        # Remove finished requests (in reverse so indices stay valid)
        for idx in reversed(finished_indices):
            self.active_batch.pop(idx)
        # Refill the batch from the waiting queue
        while len(self.active_batch) < self.max_batch_size:
            try:
                new_req = self.waiting_queue.get_nowait()
                self.active_batch.append(new_req)
            except Empty:
                break  # No more waiting requests

    def run(self):
        """Main serving loop"""
        while True:
            if len(self.active_batch) > 0:
                self.generation_step()
            else:
                time.sleep(0.001)  # Wait for new requests
Performance Impact
| Metric | Static Batching | Continuous Batching | Improvement |
|---|---|---|---|
| GPU Utilization | 50-60% | 85-95% | +50% |
| Throughput (req/sec) | 5-8 | 15-25 | +200% |
| Average Latency | 8-12 sec | 4-6 sec | -50% |
| Cost per 1K Tokens | $0.015 | $0.005 | -67% |
Summary: Unlocking High-Throughput Serving
Let's recap the key techniques that enable production-grade LLM serving:
🔑 Key Takeaways
- KV Cache is essential: Reduces token generation from O(n²) to O(n), enabling practical generation speeds
- Memory is the bottleneck: KV Cache memory consumption limits throughput more than compute capacity
- PagedAttention eliminates waste: Reduces memory fragmentation from 60-80% to <4%, enabling 4-6x more concurrent users
- Continuous batching maximizes utilization: Keeps GPU at 85-95% utilization by dynamically scheduling requests
- Combined impact: Together, these techniques provide 20-50x throughput improvement over naive serving
While we've explained the concepts, implementing them correctly requires thousands of lines of highly optimized CUDA kernels and sophisticated scheduling algorithms. In the next chapter, we'll explore vLLM—the production-ready framework that implements PagedAttention and continuous batching, making these techniques accessible to practitioners.