If you've ever wondered why a chatbot's first word takes a bit longer and then the words flow quickly one by one, this piece is for you. We'll break down how continuous batching works, starting from attention and the KV cache and building up to throughput optimization. I promise the technical parts will be clear and useful for engineers and curious readers alike.
Attention, tokens, and why tensor shapes matter
Remember that language models are basically next-token predictors. Internally, each token is represented as a vector of dimension d. If you have a sequence of S tokens and a batch of B sequences, the typical shape is B x S x d.
Attention is where tokens interact. From x (the input representation) you compute three projections Q, K and V. For a single head, Q, K and V have shape B x S x h, and the resulting attention matrix has shape B x S x S (an S x S matrix per sequence). That's why we say attention is quadratic in sequence length: computing the matrix costs O(S^2).
The attention mask decides who can look at whom: a causal mask prevents future tokens from influencing past tokens. That mask is the lever that later lets us mix sequences without them "contaminating" each other.
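To make those shapes concrete, here is a minimal single-head attention sketch in PyTorch (toy sizes, random tensors, no trained weights), showing the B x S x S score matrix and the causal mask:

```python
import torch

B, S, h = 2, 5, 64          # batch, sequence length, head dimension (toy sizes)
q = torch.randn(B, S, h)
k = torch.randn(B, S, h)
v = torch.randn(B, S, h)

scores = q @ k.transpose(-2, -1) / h**0.5             # shape B x S x S: quadratic in S
causal = torch.tril(torch.ones(S, S, dtype=torch.bool))
scores = scores.masked_fill(~causal, float("-inf"))   # future tokens can't be attended to
out = torch.softmax(scores, dim=-1) @ v               # back to B x S x h
print(scores.shape, out.shape)  # torch.Size([2, 5, 5]) torch.Size([2, 5, 64])
```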
Prefill and decode: why the first pass costs more
When you send a long prompt, the model does a prefill: it processes the whole sequence to produce the first output token. That means computing Q,K,V for every token and passing them through all layers.
Then comes decode: to generate the next token you don't need to recompute K and V for previous tokens if you saved them. That's where the KV cache enters. By storing K and V for every token already processed, the recomputation cost per generated token drops from O(S) to O(1), at the expense of memory.
If you don't store the cache, every new token forces you to reprocess the entire context. Can you imagine that waste with thousands of concurrent users?
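As a rough illustration of a decode step, here's a sketch for a single sequence and a single head; decode_step, w_q, w_k and w_v are hypothetical names for this post, not a real library API. Only the new token's K and V get computed; everything older comes from the cache:

```python
import torch

def decode_step(new_x, w_q, w_k, w_v, k_cache, v_cache):
    """One decode step for one sequence, one head (illustrative only).

    new_x: 1 x d embedding of the newly generated token.
    w_q, w_k, w_v: d x h projection matrices.
    k_cache, v_cache: S x h tensors with K and V of all previous tokens.
    """
    q = new_x @ w_q                                     # 1 x h, only for the new token
    k_cache = torch.cat([k_cache, new_x @ w_k], dim=0)  # append the new token's K
    v_cache = torch.cat([v_cache, new_x @ w_v], dim=0)  # append the new token's V
    scores = (q @ k_cache.T) / q.shape[-1] ** 0.5       # 1 x (S+1): attend over the cache
    return torch.softmax(scores, dim=-1) @ v_cache, k_cache, v_cache
```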
KV cache size (formula and rough example)
You need to store K and V per layer. If the model dimension is d and there are L layers, the size per token (in bytes) is: cache_per_token = 2 * L * d * bytes_per_value (the factor 2 covers K and V).
With float16 (2 bytes per value) and a model with L = 32 and d = 4096, the order of magnitude is: cache_per_token ≈ 2 * 32 * 4096 * 2 bytes ≈ 512 KB per token. It's a rough example, but it shows that the cache consumes memory fast, which is why batching strategy matters.
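The same arithmetic as a tiny helper (the function and argument names are just for illustration):

```python
def kv_cache_bytes_per_token(num_layers, d_model, bytes_per_value=2):
    # factor 2: one K vector and one V vector per layer
    return 2 * num_layers * d_model * bytes_per_value

per_token = kv_cache_bytes_per_token(num_layers=32, d_model=4096)  # float16
print(per_token / 1024)              # 512.0 KB per token
print(per_token * 8192 / 2**30)      # 4.0 GiB for a single 8k-token context
```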
Chunked prefill: split so the GPU doesn't explode
When a prompt doesn't fit in memory, there's no magic: we split it into chunks. Each chunk generates its K,V and you concatenate them into the cache. Chunked prefill lets you process long prompts incrementally, using less memory per pass while preserving context integrity.
This is crucial for services that accept huge contexts (repositories, documents, etc.).
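A minimal sketch of the idea, assuming a hypothetical attend_fn that runs one chunk through attention given the cache built so far (not any particular framework's API):

```python
import torch

def chunked_prefill(x, attend_fn, chunk_size, k_cache, v_cache):
    """Process a long prompt in chunks, growing the KV cache incrementally.

    x: S x d prompt embeddings. attend_fn(chunk, k_cache, v_cache) returns the
    chunk's outputs plus its new K and V (hypothetical signature for illustration).
    """
    outputs = []
    for start in range(0, x.shape[0], chunk_size):
        chunk = x[start:start + chunk_size]
        out, k_new, v_new = attend_fn(chunk, k_cache, v_cache)
        k_cache = torch.cat([k_cache, k_new], dim=0)   # cache still holds the full context
        v_cache = torch.cat([v_cache, v_new], dim=0)
        outputs.append(out)
    return torch.cat(outputs, dim=0), k_cache, v_cache
```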
Traditional batching and its problem: padding
The simple way to parallelize is to add a batch dimension: B x S x d. But tensors must be rectangular, so we pad shorter sequences up to the longest one in the batch.
What happens when lengths vary or when one sequence finishes before another? Padding creates useless work: GPU cycles that don't contribute to real outputs. And if you use optimizations that require static shapes (CUDA graphs, torch.compile), you end up padding everything to the maximum, which multiplies the waste.
Also, if you insert a new long request while others are decoding, every ongoing sequence must be padded up to the new maximum, so the wasted work grows with both the batch size and the prompt length.
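A quick back-of-the-envelope example of how much of a padded batch can be dead weight when lengths are mixed (made-up numbers):

```python
lengths = [3, 17, 120]                    # prompt lengths in one traditional batch
S_max = max(lengths)
total = len(lengths) * S_max              # tokens actually pushed through the model
useful = sum(lengths)                     # tokens that contribute to real outputs
print(f"padding waste: {1 - useful / total:.0%}")   # ~61% of the work is padding
```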
Dynamic batching: swapping conversations on the fly
A natural improvement is dynamic scheduling: when a sequence ends, replace it with a pending one. That keeps the GPU busy with useful work, but with the traditional batch scheme you still need to pad to align the new prompt's length with ongoing sequences.
It's better, but still suffers from padding and static-shape constraints.
Ragged batching: remove the batch axis and concatenate tokens
What if we remove the B dimension entirely and just concatenate tokens from different conversations into one big sequence? We don't want tokens from different conversations to interact, but the attention mask gives us exactly that control.
Ragged batching concatenates all sequences (prefills and decode tokens) into a single sequence of length T and builds an attention mask that only allows legitimate interactions.
Advantage: zero padding. Every token you process in a forward pass is real and useful. Limitation: you must carefully manage the mask and the cache.
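A sketch of the kind of mask ragged batching needs, block-diagonal and causal, built here with a hypothetical helper that takes the packed sequence lengths:

```python
import torch

def ragged_causal_mask(seq_lens):
    """Boolean T x T mask for tokens of several sequences packed into one row.

    True = attention allowed. A token may only look at earlier tokens of its
    own sequence, so different conversations never interact.
    """
    T = sum(seq_lens)
    seq_id = torch.cat([torch.full((n,), i) for i, n in enumerate(seq_lens)])
    pos = torch.arange(T)
    same_seq = seq_id[:, None] == seq_id[None, :]
    causal = pos[None, :] <= pos[:, None]   # key position <= query position
    return same_seq & causal

print(ragged_causal_mask([2, 3]).int())
# Within each block the mask is lower-triangular; across blocks it is all zeros.
```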
Continuous batching: combine everything to maximize throughput
Continuous batching combines three ingredients:
KV caching to avoid recomputing K and V for prior tokens.
Chunked prefill to slice long prompts and fit limited memory.
Ragged batching with dynamic scheduling to remove padding and keep the GPU always full.
Simplified algorithm to maximize tokens per second:
Keep a target token quota per batch, T_max, based on available memory.
Add all sequences currently in decode first (each contributes 1 token to the total).
Fill the remaining space with prefill chunks (each chunk can add multiple tokens depending on chunk size).
When a conversation ends, remove it and fill the gap with incoming new chunks.
This way you mix prefill and decode in the same forward pass without adding padding: every token counts.
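A minimal sketch of that scheduling loop, with hypothetical names (build_batch, t_max, chunk_size) and none of the bookkeeping a real engine needs:

```python
from collections import deque

def build_batch(decoding, prefill_queue, t_max, chunk_size):
    """Pick the tokens for the next forward pass (simplified scheduler).

    decoding: ids of sequences currently generating (1 token each).
    prefill_queue: deque of (seq_id, remaining_prompt_tokens) waiting for prefill.
    Returns (seq_id, num_tokens) pairs whose token counts sum to at most t_max.
    """
    batch = [(seq_id, 1) for seq_id in decoding]      # decode tokens go first
    budget = t_max - len(batch)
    while prefill_queue and budget > 0:
        seq_id, remaining = prefill_queue.popleft()
        take = min(chunk_size, remaining, budget)     # one prefill chunk
        batch.append((seq_id, take))
        budget -= take
        if remaining > take:                          # not finished: back of the queue
            prefill_queue.append((seq_id, remaining - take))
    return batch

# Example: 3 sequences decoding, 2 new prompts waiting, budget of 16 tokens per pass.
print(build_batch(["a", "b", "c"], deque([("d", 40), ("e", 5)]), t_max=16, chunk_size=8))
# [('a', 1), ('b', 1), ('c', 1), ('d', 8), ('e', 5)]
```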
Costs, benefits, and practical considerations
Throughput: Continuous batching maximizes tokens per second because each forward pass produces useful tokens and avoids recomputation and padding.
Memory: The KV cache and the total token budget T_max limit how much you can concatenate. Size T_max according to L, d, bytes_per_value and available GPU memory (a rough sizing sketch follows this list).
Latency for the first token: Prefill still costs, since it's a full pass over the prompt. Continuous batching amortizes that cost when you serve many concurrent users.
Implementation complexity: handling ragged masks, updating offsets in the KV cache, and keeping scheduling efficient is more complex than traditional batching. But modern inference tools are already incorporating these ideas.
Compilation constraints: techniques like CUDA graphs and compilers that require static shapes complicate the design. One strategy is to fix T_max and use ragged packing internally, or build graphs for ranges of T_max.
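To close the loop on sizing T_max, here is a back-of-the-envelope helper built on the cache_per_token formula above; the GPU size, model size and overhead fraction are assumptions for illustration, not measurements:

```python
def max_cached_tokens(gpu_bytes, model_bytes, num_layers, d_model,
                      bytes_per_value=2, overhead=0.1):
    """Rough upper bound on how many tokens of KV cache fit on one GPU.

    Reserves a fraction `overhead` for activations and workspace; everything
    here is a back-of-the-envelope assumption, not a vendor formula.
    """
    cache_per_token = 2 * num_layers * d_model * bytes_per_value
    free = gpu_bytes * (1 - overhead) - model_bytes
    return int(free // cache_per_token)

# 80 GB GPU, 14 GB of float16 weights for a 7B-parameter model, L=32, d=4096:
print(max_cached_tokens(80e9, 14e9, 32, 4096))   # on the order of 110k cached tokens
```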
When should you use continuous batching?
If you serve many concurrent users with variable-length prompts, you'll almost certainly see gains.
If your workloads are small and homogeneous, traditional batching may be sufficient and simpler.
If you have strict first-token latency requirements (e.g., real-time chat), mix continuous batching with policies that prioritize interactive responses.
Final reflection
Continuous batching isn't a magic trick; it's the logical consequence of understanding how models use attention and cache. By removing the batch dimension and controlling interactions with the mask, you make every GPU cycle produce tokens that matter. Is the added complexity worth it? For large-scale services, almost always yes: more useful tokens per second, less waste, and better memory use.
In the next installment we'll explore efficient KV cache management with paged attention and how to prevent cache blow-up when context grows huge. If you work on inference infra, this will interest you.