MODULE 5 - CHAPTER 2 ⏱️ 35 min read 📖 2,800 words

KV Caching and Batching Deep Dive

Understanding the PagedAttention algorithm and how continuous batching unlocks massive throughput gains

The Key-Value (KV) Cache is both the hero and the villain of LLM serving. It's absolutely essential for achieving reasonable generation speed, yet it's also the primary memory bottleneck that limits throughput. Understanding how the KV Cache works, how it consumes memory, and how modern serving frameworks manage it efficiently is fundamental to building production LLM systems.

In this chapter, we'll dive deep into the mechanics of KV caching, explore why memory management is so challenging, and dissect the PagedAttention algorithm that makes high-throughput serving possible.

How LLMs Generate Text: Token by Token

To understand why the KV Cache exists, we first need to understand the fundamental architecture of text generation in Transformer-based LLMs.

The Autoregressive Process

Large Language Models don't generate a complete response in one step. They generate it autoregressively—one token at a time, where each new token depends on all previously generated tokens.

Prompt Processing (Prefill Phase)
When you send a prompt to an LLM, the model first processes all the prompt tokens in a single parallel forward pass through the Transformer. This is called the "prefill" phase. At the end of this phase, the model predicts the probability distribution for the very first output token.
Example: Given the prompt "The capital of France is", the model processes all 5 tokens in parallel and outputs probabilities like: "Paris" (85%), "located" (8%), "a" (3%), etc.
Decoding Phase (Autoregressive Generation)
After the first token is sampled, the model must generate the second token. To do this, it needs to consider the entire history: the original prompt plus the first generated token. Then for the third token, it considers the prompt + first two tokens, and so on.
Example:
  • Input: "The capital of France is" → Output: "Paris"
  • Input: "The capital of France is Paris" → Output: ","
  • Input: "The capital of France is Paris," → Output: "known"
  • ...and so on, token by token
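
Conceptually, the decoding loop looks like the minimal sketch below. It assumes a hypothetical `model` that returns next-token logits and a `sample_token` helper (the same stand-ins used in the snippets later in this chapter); it illustrates the control flow, not a real inference engine.

# A minimal sketch of the autoregressive loop (hypothetical model / sample_token helpers)
def generate(model, prompt_tokens, max_new_tokens, eos_token_id):
    tokens = list(prompt_tokens)                # start from the prompt
    for _ in range(max_new_tokens):
        logits = model.forward(tokens)          # returns logits for every position (naively re-runs the whole sequence)
        next_token = sample_token(logits[-1])   # sample from the LAST position's distribution
        tokens.append(next_token)               # the new token becomes part of the input
        if next_token == eos_token_id:
            break                               # stop at end-of-sequence
    return tokens[len(prompt_tokens):]          # return just the generated continuation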

The Computational Challenge

The core operation in a Transformer is self-attention. Each token must "attend to" (compute relationships with) all previous tokens in the sequence. Without optimization, this creates a catastrophic scaling problem:

# Naive approach (DO NOT DO THIS!)
def generate_next_token_naive(prompt_tokens, generated_tokens):
    # Concatenate all tokens
    all_tokens = prompt_tokens + generated_tokens

    # Run the ENTIRE sequence through the model again
    # This recomputes attention for all previous tokens
    hidden_states = model.forward(all_tokens)

    # Get the last hidden state to predict next token
    next_token_logits = hidden_states[-1]

    return sample_token(next_token_logits)

# For a 1000-token prompt generating 100 tokens:
# Token 1:   Process 1000 tokens
# Token 2:   Process 1001 tokens (recompute everything!)
# Token 3:   Process 1002 tokens (recompute everything again!)
# ...
# Token 100: Process 1099 tokens
#
# Total computation: 1000 + 1001 + 1002 + ... + 1099
#                  ≈ 105,000 token processing operations
# With optimization: 1000 + 100 = 1,100 token processing operations
#
# Savings: ~95x fewer token operations!

This naive approach scales quadratically (O(n²)) with sequence length. For long sequences, this becomes completely impractical. The solution? The KV Cache.

The Key-Value Cache: Making Generation Practical

How Attention Works (Refresher)

In self-attention, each token is transformed into three vectors:

  • Query (Q): "What am I looking for?"
  • Key (K): "What information do I have?"
  • Value (V): "What information do I actually provide?"

Attention is computed by comparing each Query to all Keys, then using those similarity scores to weight the Values:

import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V):
    """
    Q: Query vectors (batch, seq_len, d_model)
    K: Key vectors (batch, seq_len, d_model)
    V: Value vectors (batch, seq_len, d_model)
    """
    d_k = Q.size(-1)

    # Compute attention scores
    scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_k ** 0.5)

    # Apply softmax to get attention weights
    attention_weights = F.softmax(scores, dim=-1)

    # Weighted sum of values
    output = torch.matmul(attention_weights, V)

    return output


# Example dimensions for Llama 2 (70B):
# num_heads = 64
# d_head = 128
# seq_len = 2048
#
# For each layer, per token:
# Q, K, V each are: (batch=1, heads=64, seq_len=2048, d_head=128)
#
# Memory for K and V cache per layer:
# 2 (K and V) × 64 heads × 2048 tokens × 128 dim × 2 bytes (float16)
# = 67,108,864 bytes ≈ 67 MB per layer
#
# For 80 layers: 67 MB × 80 ≈ 5.37 GB per request!

🔍 Key Insight

The critical observation is that the Key (K) and Value (V) vectors for past tokens never change. Only the Query (Q) for the new token is different. This means we can cache all previous K and V vectors and reuse them!

How KV Caching Works

import torch.nn as nn


class TransformerLayerWithKVCache(nn.Module):
    """Single attention layer with a KV Cache.
    (Causal masking and per-head splitting are omitted for clarity.)
    """
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_head = d_model // num_heads

        # Projections for Q, K, V (output projection omitted for brevity)
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)

        # Cache of Key/Value tensors for all tokens seen so far
        self.k_cache = None
        self.v_cache = None

    def forward(self, x, use_cache=True):
        """
        x: embeddings of the NEW tokens only (batch, seq_len, d_model)
        use_cache: whether to use/update the cache
        """
        # Compute Q, K, V for the new tokens only
        Q = self.q_proj(x)
        K = self.k_proj(x)
        V = self.v_proj(x)

        if use_cache and self.k_cache is not None:
            # Prepend cached K, V so attention covers the full history
            K_full = torch.cat([self.k_cache, K], dim=1)
            V_full = torch.cat([self.v_cache, V], dim=1)
        else:
            # First pass (prefill): no cache yet
            K_full = K
            V_full = V

        # Compute attention over the full K and V (reuses the function defined above)
        output = scaled_dot_product_attention(Q, K_full, V_full)

        if use_cache:
            # Store the concatenated K, V for the next decoding step
            self.k_cache = K_full
            self.v_cache = V_full

        return output

    def clear_cache(self):
        """Clear the cache before starting a new sequence"""
        self.k_cache = None
        self.v_cache = None

# Usage example (tokenize_and_embed / sample_and_embed are illustrative helpers):
layer = TransformerLayerWithKVCache(d_model=4096, num_heads=32)

# Prefill: Process entire prompt
prompt_embeddings = tokenize_and_embed("The capital of France is")
output1 = layer.forward(prompt_embeddings, use_cache=True)
# Cache now contains K, V for all 5 prompt tokens

# Generation: Process one token at a time
token1_embedding = sample_and_embed(output1)  # "Paris"
output2 = layer.forward(token1_embedding, use_cache=True)
# Cache now contains K, V for 6 tokens (prompt + "Paris")

token2_embedding = sample_and_embed(output2)  # ","
output3 = layer.forward(token2_embedding, use_cache=True)
# Cache now contains K, V for 7 tokens

# etc...

⚡ Performance Impact

Without KV Cache:

  • Generate 100 tokens for a 1000-token prompt
  • Total operations: 1000 + 1001 + ... + 1099 ≈ 105,000 token processes
  • Time: ~350 seconds on a single GPU

With KV Cache:

  • Prefill: Process 1000 tokens once
  • Decode: Process 1 new token 100 times
  • Total operations: 1000 + 100 = 1,100 token processes
  • Time: ~3.5 seconds on the same GPU

Speedup: 100x faster! ⚡
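
If you want to observe this effect yourself, the rough benchmark below is one way to do it. It assumes the Hugging Face transformers library and a small model such as gpt2 are available; the `use_cache` flag on `generate()` toggles KV caching. The speedup you measure will be far smaller than 100x for a tiny model and short sequences—it depends on model size, prompt length, and hardware.

import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

inputs = tokenizer("The capital of France is", return_tensors="pt")

def timed_generate(use_cache):
    start = time.perf_counter()
    with torch.no_grad():
        model.generate(
            **inputs,
            max_new_tokens=200,
            do_sample=False,
            use_cache=use_cache,  # True = reuse cached K/V, False = recompute every step
            pad_token_id=tokenizer.eos_token_id,
        )
    return time.perf_counter() - start

print(f"Without KV cache: {timed_generate(False):.2f} s")
print(f"With KV cache:    {timed_generate(True):.2f} s")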

The Throughput Problem: Memory as the Bottleneck

While the KV Cache solves the computation problem, it creates a massive memory problem. Let's quantify exactly how much memory these caches consume.

Memory Consumption Formula

def calculate_kv_cache_memory(
    batch_size,
    sequence_length,
    num_layers,
    num_heads,
    head_dimension,
    precision_bytes=2  # float16
):
    """
    Calculate total memory required for KV Cache
    """
    # Memory per token per layer
    memory_per_token_per_layer = (
        2 * num_heads * head_dimension * precision_bytes  # 2 for K and V
    )

    # Total memory for entire cache
    total_memory = (
        batch_size *
        sequence_length *
        num_layers *
        memory_per_token_per_layer
    )

    return total_memory


# Example: Llama 2 70B
cache_size = calculate_kv_cache_memory(
    batch_size=1,
    sequence_length=2048,
    num_layers=80,
    num_heads=64,
    head_dimension=128,
    precision_bytes=2
)

print(f"KV Cache size for 1 user: {cache_size / 1e9:.2f} GB")
# Output: KV Cache size for 1 user: 5.37 GB

# For 16 concurrent users:
cache_size_16 = cache_size * 16
print(f"KV Cache size for 16 users: {cache_size_16 / 1e9:.2f} GB")
# Output: KV Cache size for 16 users: 85.90 GB

# But an A100 GPU only has 80 GB!
# And the model weights themselves take ~140 GB in float16 (~70 GB quantized to int8)
# This doesn't fit!

Real-World Example: Llama 2 70B

Component                            | Memory Usage | Notes
Model Weights (int8)                 | 70 GB        | Quantized from float16
KV Cache (1 user, 2K tokens)         | 5.4 GB       | Per concurrent user
Available for KV Caches (A100 80 GB) | 10 GB        | 80 GB total - 70 GB model
Max Concurrent Users (Naive)         | ~1-2 users   | 10 GB ÷ 5.4 GB/user

⚠️ The Throughput Wall

This calculation reveals the fundamental problem: even with an expensive, top-tier A100 GPU, naive KV Cache management limits you to serving only 1-2 concurrent users with a 70B model. This is catastrophically low throughput for a production system.

The GPU itself can process much higher workloads, but memory management is the bottleneck. This is exactly the problem PagedAttention solves.
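
As a back-of-envelope check, the sketch below reuses the calculate_kv_cache_memory helper from above to estimate how many full-length 2K-token caches fit in whatever memory is left after the weights. The 80 GB and 70 GB figures are the same illustrative numbers as in the table; max_concurrent_users is a hypothetical helper for this example.

def max_concurrent_users(gpu_memory_gb, weights_gb, per_user_cache_bytes):
    # Memory left over for KV Caches after loading the (quantized) weights
    available_bytes = (gpu_memory_gb - weights_gb) * 1e9
    return int(available_bytes // per_user_cache_bytes)

per_user = calculate_kv_cache_memory(
    batch_size=1, sequence_length=2048,
    num_layers=80, num_heads=64, head_dimension=128,
)

print(max_concurrent_users(gpu_memory_gb=80, weights_gb=70, per_user_cache_bytes=per_user))
# -> 1 (only a single full-length 2K cache fits alongside the weights)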

Memory Fragmentation: The Hidden Waste

Beyond the raw memory consumption, traditional serving systems suffer from two types of fragmentation:

Internal Fragmentation
Systems pre-allocate a fixed-size block for each request's maximum possible sequence length. If a request completes with a shorter sequence, the unused memory in that block is wasted.
Example: You allocate 5.4 GB for a 2048-token sequence. A user's request completes at 512 tokens, using only 1.35 GB. The remaining 4.05 GB (75%!) is wasted until the entire request slot is freed.
External Fragmentation
As requests complete at different times, memory becomes fragmented into non-contiguous free chunks. Even if total free memory is sufficient for a new request, you can't use it because you need one large contiguous block.
Example: After several requests complete, you have 10 GB free memory split across 4 scattered chunks (3GB, 2GB, 3GB, 2GB). A new request needs 5.4 GB contiguous—it can't fit, even though total free memory is 10 GB!
# Visualizing memory fragmentation (conceptual)

# Memory layout with pre-allocated per-request slots (traditional approach):
#
# KV Cache region of GPU memory:
# [Request 1: ████░░░░░░░░]  (5.4 GB reserved, 2.0 GB used - 63% internal waste)
# [Free chunk: 3.0 GB]       (left behind by a completed request with a smaller reservation)
# [Request 2: ███░░░░░░░░░]  (5.4 GB reserved, 1.5 GB used - 72% internal waste)
# [Free chunk: 2.5 GB]       (tail of the region)
#
# A new request needs 5.4 GB of CONTIGUOUS memory.
# Total free memory = 3.0 + 2.5 = 5.5 GB, but the largest contiguous chunk
# is only 3.0 GB, so the request cannot be admitted (external fragmentation).
#
# Effective utilization of reserved memory: roughly a third (massive waste!)
# This is why naive serving achieves only 20-40% memory efficiency

Studies show that traditional serving systems waste 60-80% of allocated KV Cache memory due to these fragmentation issues. This is the problem PagedAttention is designed to solve.

PagedAttention: Virtual Memory for LLMs

PagedAttention is inspired by a brilliant observation: operating systems solved this exact problem decades ago with virtual memory and paging. Can we apply the same concept to GPU memory for KV Caches?

The Core Concept: Fixed-Size Pages

Instead of allocating one large, contiguous block for each request's KV Cache, PagedAttention divides the cache into small, fixed-size blocks called pages (or "blocks").

Page/Block Structure
A typical page size is 16 tokens. This means each page stores the K and V vectors for 16 consecutive tokens in the sequence. Pages are allocated on-demand as the sequence grows.
Example: For a 64-token sequence, you need 4 pages (64 ÷ 16 = 4). These 4 pages can be stored anywhere in GPU memory—they don't need to be adjacent.
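
A quick sketch of the page arithmetic from the example above (a block_size of 16 is assumed; the physical page numbers are made up purely to show that pages need not be adjacent):

block_size = 16
seq_len = 64

# Ceiling division: a 64-token sequence needs exactly 4 pages
num_pages = (seq_len + block_size - 1) // block_size
print(num_pages)  # -> 4

# The 4 logical pages can live anywhere in GPU memory, e.g.:
logical_to_physical = {0: 112, 1: 7, 2: 530, 3: 64}  # illustrative, non-contiguous mapping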

Block Tables: Mapping Logical to Physical

Each request has a block table—a small data structure that maps logical sequence positions to physical memory locations.

class BlockTable:
    """
    Maps logical blocks (sequence positions) to physical blocks (memory locations)
    """
    def __init__(self, block_size=16):
        self.block_size = block_size
        # Maps logical block index -> physical block index
        self.table = []

    def allocate_block(self, physical_block_id):
        """Allocate a new physical block for the next logical position"""
        self.table.append(physical_block_id)

    def get_physical_block(self, logical_block_idx):
        """Get the physical location of a logical block"""
        return self.table[logical_block_idx]

    def num_blocks(self):
        """Number of blocks allocated"""
        return len(self.table)


class PagedKVCacheManager:
    """
    Manages a pool of physical memory blocks
    """
    def __init__(self, total_blocks, block_size=16):
        self.block_size = block_size
        self.total_blocks = total_blocks

        # Free list of available physical blocks
        self.free_blocks = list(range(total_blocks))

        # Physical memory storage (simplified)
        # In reality, this is GPU memory
        self.physical_memory = [None] * total_blocks

        # Track which requests are using which blocks
        self.active_requests = {}  # request_id -> BlockTable

    def allocate_blocks_for_request(self, request_id, num_tokens):
        """Allocate blocks for a new request"""
        num_blocks_needed = (num_tokens + self.block_size - 1) // self.block_size

        if len(self.free_blocks) < num_blocks_needed:
            raise MemoryError("Not enough free blocks!")

        # Create block table for this request
        block_table = BlockTable(block_size=self.block_size)

        # Allocate non-contiguous blocks
        for _ in range(num_blocks_needed):
            physical_block = self.free_blocks.pop(0)
            block_table.allocate_block(physical_block)

        self.active_requests[request_id] = block_table

        return block_table

    def free_request(self, request_id):
        """Free all blocks used by a request"""
        block_table = self.active_requests.pop(request_id)

        # Return physical blocks to free list
        for logical_idx in range(block_table.num_blocks()):
            physical_block = block_table.get_physical_block(logical_idx)
            self.free_blocks.append(physical_block)
            self.physical_memory[physical_block] = None

    def extend_request(self, request_id, additional_tokens):
        """Allocate additional blocks as the sequence grows.

        Simplified: assumes the request's last block is already full;
        a real implementation first fills any partially used block.
        """
        block_table = self.active_requests[request_id]
        additional_blocks = (additional_tokens + self.block_size - 1) // self.block_size

        for _ in range(additional_blocks):
            if not self.free_blocks:
                raise MemoryError("No free blocks for extension!")

            physical_block = self.free_blocks.pop(0)
            block_table.allocate_block(physical_block)


# Usage example:
cache_manager = PagedKVCacheManager(total_blocks=1000, block_size=16)

# Request 1: 64 tokens (needs 4 blocks)
table1 = cache_manager.allocate_blocks_for_request("req1", num_tokens=64)
print(f"Request 1 block table: {table1.table}")
# Output: Request 1 block table: [0, 1, 2, 3]

# Request 2: 48 tokens (needs 3 blocks)
table2 = cache_manager.allocate_blocks_for_request("req2", num_tokens=48)
print(f"Request 2 block table: {table2.table}")
# Output: Request 2 block table: [4, 5, 6]

# Request 1 grows by 32 tokens (needs 2 more blocks)
cache_manager.extend_request("req1", additional_tokens=32)
print(f"Request 1 extended: {table1.table}")
# Output: Request 1 extended: [0, 1, 2, 3, 7, 8]

# Note: Blocks [0,1,2,3,7,8] are NOT contiguous in physical memory!
# But logically, they represent a continuous 96-token sequence

How PagedAttention Eliminates Fragmentation

🔑 Key Benefits

  • No internal fragmentation: Allocate only as many pages as needed. If a sequence is 50 tokens, allocate 4 pages (64 tokens), wasting only 14 tokens (22%) instead of potentially thousands.
  • No external fragmentation: Pages can be scattered anywhere in memory. No need for large contiguous blocks.
  • On-demand allocation: Pages are allocated as the sequence grows, not pre-allocated for maximum length.
  • Easy reuse: When a request completes, its pages immediately return to the free pool and can be used by any new request.

Memory Efficiency Comparison

Approach              | Fragmentation | Memory Waste | Concurrent Users (A100)
Naive Pre-allocation  | Very High     | 60-80%       | 1-2 users
Dynamic Allocation    | Medium        | 30-50%       | 3-5 users
PagedAttention (vLLM) | Minimal       | <4%          | 8-12 users

By reducing memory waste from 60-80% to less than 4%, PagedAttention enables 4-6x more concurrent users on the same hardware—a transformational improvement.
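
A minimal sketch of where the "<4%" figure comes from: with on-demand 16-token pages, the only waste is the unused tail of the last page, which is tiny compared to a pre-reserved maximum-length slot. The per-token cost below reuses the Llama 2 70B numbers from earlier; the 500-token request length is illustrative.

def kv_bytes(num_tokens, bytes_per_token=2 * 64 * 128 * 2 * 80):
    # 2 (K and V) × 64 heads × 128 dim × 2 bytes (fp16) × 80 layers, per token
    return num_tokens * bytes_per_token

block_size = 16
actual_len = 500   # the request actually generated 500 tokens
max_len = 2048     # the maximum length pre-reserved by the naive scheme

# Naive: reserve for the maximum length up front
naive_waste = 1 - kv_bytes(actual_len) / kv_bytes(max_len)

# Paged: reserve only full pages, so waste is at most block_size - 1 tokens
pages = (actual_len + block_size - 1) // block_size
paged_waste = 1 - kv_bytes(actual_len) / kv_bytes(pages * block_size)

print(f"Naive pre-allocation waste: {naive_waste:.0%}")  # ~76%
print(f"Paged allocation waste:     {paged_waste:.0%}")  # ~2%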

Continuous Batching: Keeping the GPU Busy

PagedAttention solves the memory problem, but there's another critical technique that maximizes GPU utilization: continuous batching (also called "in-flight batching").

The Problem with Static Batching

Traditional batching collects a fixed number of requests, processes them together, and waits for the entire batch to complete before starting a new batch. This creates inefficiency:

  • Some requests finish quickly (short outputs)
  • Some requests take much longer (long outputs)
  • The GPU sits idle waiting for the slowest request in the batch
# Static batching timeline (INEFFICIENT)

# Batch 1: [Req1, Req2, Req3, Req4]
# Time:
# 0-2s:   GPU at 100% (all 4 requests generating)
# 2-3s:   GPU at 75% (Req1 finished, 3 remaining)
# 3-5s:   GPU at 50% (Req2 finished, 2 remaining)
# 5-10s:  GPU at 25% (Req3 finished, only Req4 still generating)
#
# Average GPU utilization: 50-60%
# Wasted time: 5-6 seconds of low utilization

# Batch 2: [Req5, Req6, Req7, Req8]
# Cannot start until Req4 completes at t=10s!

Continuous Batching: Dynamic Scheduling

Continuous batching solves this by allowing requests to be added to the active batch as soon as any request finishes:

# Continuous batching timeline (EFFICIENT)

# Active batch starts with: [Req1, Req2, Req3, Req4]
# Time:
# 0-2s:   GPU at 100% [Req1, Req2, Req3, Req4]
# 2s:     Req1 finishes -> immediately add Req5
# 2-3s:   GPU at 100% [Req2, Req3, Req4, Req5]
# 3s:     Req2 finishes -> immediately add Req6
# 3-5s:   GPU at 100% [Req3, Req4, Req5, Req6]
# 5s:     Req3 finishes -> immediately add Req7
# 5-7s:   GPU at 100% [Req4, Req5, Req6, Req7]
# 7s:     Req4 finishes -> immediately add Req8
# 7-10s:  GPU at 100% [Req5, Req6, Req7, Req8]
#
# Average GPU utilization: 95-100%
# No wasted time!

import time
from queue import Queue, Empty


class ContinuousBatchScheduler:
    def __init__(self, max_batch_size=32):
        self.max_batch_size = max_batch_size
        self.active_batch = []
        self.waiting_queue = Queue()

    def add_request(self, request):
        """Add a new request to the queue"""
        if len(self.active_batch) < self.max_batch_size:
            # Space available, add immediately
            self.active_batch.append(request)
        else:
            # Batch full, add to waiting queue
            self.waiting_queue.put(request)

    def generation_step(self):
        """Execute one generation step for all active requests"""
        finished_indices = []

        # Generate next token for all active requests
        # (shown per-request for clarity; a real engine runs one batched forward pass)
        for i, req in enumerate(self.active_batch):
            next_token = self.model_generate_next_token(req)
            req.append_token(next_token)

            # Check if request is complete
            if req.is_finished():
                finished_indices.append(i)
                req.callback(req.get_result())

        # Remove finished requests
        for idx in reversed(finished_indices):
            self.active_batch.pop(idx)

        # Add new requests from queue to fill the batch
        while len(self.active_batch) < self.max_batch_size:
            try:
                new_req = self.waiting_queue.get_nowait()
                self.active_batch.append(new_req)
            except Empty:
                break  # No more waiting requests

    def run(self):
        """Main serving loop"""
        while True:
            if len(self.active_batch) > 0:
                self.generation_step()
            else:
                time.sleep(0.001)  # Wait for new requests

Performance Impact

Metric               | Static Batching | Continuous Batching | Improvement
GPU Utilization      | 50-60%          | 85-95%              | +50%
Throughput (req/sec) | 5-8             | 15-25               | +200%
Average Latency      | 8-12 sec        | 4-6 sec             | -50%
Cost per Token       | $0.015          | $0.005              | -67%

Summary: Unlocking High-Throughput Serving

Let's recap the key techniques that enable production-grade LLM serving:

🔑 Key Takeaways

  • KV Cache is essential: Reduces token generation from O(n²) to O(n), enabling practical generation speeds
  • Memory is the bottleneck: KV Cache memory consumption limits throughput more than compute capacity
  • PagedAttention eliminates waste: Reduces memory fragmentation from 60-80% to <4%, enabling 4-6x more concurrent users
  • Continuous batching maximizes utilization: Keeps GPU at 85-95% utilization by dynamically scheduling requests
  • Combined impact: Together, these techniques provide 20-50x throughput improvement over naive serving

While we've explained the concepts, implementing them correctly requires thousands of lines of highly optimized CUDA kernels and sophisticated scheduling algorithms. In the next chapter, we'll explore vLLM—the production-ready framework that implements PagedAttention and continuous batching, making these techniques accessible to practitioners.
