Si alguna vez te has preguntado por qué la primera palabra de la respuesta de un chatbot tarda un poco y luego las palabras aparecen rápidamente una a una, esta pieza es para ti. Vamos a desmenuzar cómo funciona el batching continuo partiendo de la atención, el KV cache y optimizando el throughput. Te prometo que lo técnico será claro y útil para ingenieros y curiosos por igual.
Atención, tokens y por qué importan las formas de los tensores
¿Recuerdas que los modelos de lenguaje son básicamente predictores del siguiente token? Internamente, cada token se representa como un vector de dimensión d. Si tienes una secuencia de S tokens y un lote de B secuencias, las formas típicas son B x S x d.
La atención es el lugar donde los tokens interactúan. A partir de x (la representación de entrada) se calculan tres proyecciones Q, K y V. Para una cabeza, Q, K y V tienen forma B x S x h, y la matriz de atención resultante tiene forma B x S x S. Por eso decimos que la atención es cuadrática en la longitud de la secuencia: calcular la matriz S x S cuesta O(S^2).
La máscara de atención (attention mask) es la que decide quién puede mirar a quién: una máscara causal evita que tokens futuros influyan en tokens pasados. Esta máscara es la palanca que nos permitirá después mezclar secuencias sin que se «contaminen" entre sí.
Prefill y decode: por qué el primer pase cuesta más
Cuando envías un prompt largo, el modelo hace un prefill: procesa toda la secuencia para producir el primer token de salida. Eso implica calcular Q,K,V para todos los tokens y pasarlos por todas las capas.
Después viene el decode: para generar el token siguiente no necesitas volver a calcular K y V de los tokens previos si los guardaste. Ahí entra el KV cache. Al guardar K y V por cada token, el costo de generar un token pasa de O(S) a O(1) por token (en términos de recomputación de K y V), a cambio de consumo de memoria.
Si no guardas la cache, cada token nuevo exige volver a reprocesar todo el contexto. ¿Te imaginas ese desperdicio si hay miles de usuarios concurrentes?
Tamaño del KV cache (fórmula y ejemplo aproximado)
Por capa necesitas almacenar K y V. Si la dimensión del modelo es d y hay L capas, el tamaño por token (en bytes) es: cache_per_token = 2 * L * d * bytes_per_value.
Con float16 (2 bytes por valor) y un modelo con L = 32 y d = 4096, el orden de magnitud por token sería: cache_per_token ≈ 2 * 32 * 4096 * 2 ≈ 512 KB por token. Es un ejemplo aproximado para que intuyas que la cache consume memoria rápido y por eso hay que diseñar bien la estrategia de batching.
Chunked prefill: dividir para no explotar la GPU
Cuando un prompt no cabe en memoria, no hay magia: lo dividimos en trozos (chunks). Cada chunk genera su K,V y se concatena en la cache. Chunked prefill permite procesar prompts largos incrementalmente, usando menos memoria por pase y manteniendo la integridad del contexto.
Esto es crucial para servicios que aceptan contextos enormes (repositorios, documentos, etc.).
Batching tradicional y su problema: el padding
La forma sencilla de paralelo es añadir una dimensión de batch: B x S x d. Pero las tensores deben ser rectangulares, así que rellenamos (padding) las secuencias más cortas hasta la longitud máxima del batch.
¿Qué pasa cuando las longitudes varían o cuando una secuencia termina antes que otra? El padding introduce trabajo inútil: ciclos de GPU que no contribuyen a generar respuestas reales. Y si usas optimizaciones que requieren formas estáticas (CUDA graphs, torch.compile), terminas paddeando todo al máximo—mucho desperdicio.
Además, si insertas en medio una nueva petición larga mientras otras están en decode, el padding que necesitas puede crecer de forma cuadrática con el batch size y la longitud de los prompts.
Dynamic batching: intercambiar conversaciones en caliente
Una mejora natural es la programación dinámica: cuando una secuencia termina, la reemplazamos por otra pendiente. Eso mantiene la GPU ocupada con trabajo útil, pero con el esquema tradicional de batch necesitas introducir padding para alinear la longitud del nuevo prompt con las secuencias en curso.
Es mejor, pero todavía sufre por el padding y por las restricciones de formas estáticas.
Ragged batching: quitar el eje de batch y concatenar tokens
¿Qué pasaría si eliminamos por completo la dimensión B y simplemente concatenamos tokens de distintas conversaciones en una gran secuencia? No queremos que tokens de distintas conversaciones interactúen, pero la máscara de atención nos da exactamente ese control.
Ragged batching consiste en concatenar todas las secuencias (prefills y tokens en decode) en una sola secuencia de largo T y construir una attention mask que permita únicamente las interacciones legítimas.
Ventaja: cero padding. Todas las tokens que proceses en un forward pass son reales y útiles. Limitación: tienes que gestionar cuidadosamente la máscara y la cache.
Continuous batching: juntar todo para maximizar throughput
Continuous batching combina tres ingredientes:
KV cachingpara evitar recomputarKyVde tokens previos.Chunked prefillpara trocear prompts largos y adaptarse a memoria limitada.Ragged batchingcondynamic schedulingpara eliminar padding y mantener el GPU siempre lleno.
Algoritmo simplificado para maximizar tokens por segundo:
- Mantén una cuota de tokens objetivo por batch,
T_max, según la memoria disponible. - Añade primero todos los prompts en fase de decoding (cada uno contribuye con 1 token al total).
- Completa el resto del espacio con chunks de prefill (cada chunk puede aportar múltiples tokens según chunk size).
- Cuando una conversación termina, retírala y rellena el hueco con nuevos chunks entrantes.
Así mezclas prefill y decode en el mismo forward pass sin introducir padding: cada token cuenta.
Costos, beneficios y consideraciones prácticas
-
Throughput: Continuous batching maximiza tokens por segundo porque cada forward pass produce tokens útiles y evita recomputación y padding.
-
Memoria: El KV cache y el total de tokens
T_maxlimitan cuánto puedes concatenar. Hay que dimensionarT_maxsegúnL,d,bytes_per_valuey memoria GPU disponible. -
Latencia por primer token: El prefill sigue costando por ser un pase completo. Continuous batching amortiza este costo cuando atiendes muchos usuarios concurrentes.
-
Complejidad de implementación: manejar máscaras ragged, actualizar offsets en el KV cache, y mantener scheduling eficiente es más complicado que la versión batched tradicional. Pero herramientas modernas de inferencia ya están incorporando estas ideas.
-
Restricciones de compilación: técnicas como CUDA graphs y compiladores que requieren formas estáticas complican el diseño. Una estrategia es fijar
T_maxy usar ragged packing internamente, o diseñar graph builds por rangos deT_max.
¿Cuándo conviene usar continuous batching?
-
Si sirves muchos usuarios concurrentes con prompts de longitud variable, es casi seguro que verás ganancias.
-
Si tus cargas son pequeñas y homogéneas, el batching tradicional puede ser suficiente y más simple.
-
Si tienes requisitos estrictos de latencia por primer token (por ejemplo, chat en tiempo real), mezcla continuous batching con políticas que prioricen respuestas interactivas.
Reflexión final
Continuous batching no es un truco mágico; es la consecuencia lógica de entender cómo los modelos usan atención y cache. Al quitar la dimensión de batch y controlar las interacciones con la máscara, logras usar cada ciclo de GPU para producir tokens que importan. ¿Vale el aumento de complejidad? Para servicios a escala, casi siempre sí: más tokens útiles por segundo, menos desperdicio y mejor uso de memoria.
En la próxima entrega exploraremos la gestión eficiente del KV cache con paged attention y cómo evitar que la cache explote la memoria cuando el contexto crece mucho. Si trabajas en infra de inferencia, esto te va a interesar.
