DIFF Transformer V2: atención diferencial para LLMs | Keryc
DIFF Transformer V2 llega como una versión más práctica y estable de la idea diferencial en la atención. ¿Qué cambia respecto a DIFF V1 y por qué importa si estás entrenando o desplegando un gran modelo de lenguaje? Aquí te lo explico con técnica pero sin perder la claridad.
Qué es DIFF V2 y por qué lo hicieron
DIFF V2 implementa la operación diferencial directamente en la atención: duplica las cabezas de consulta (query) a 2h, mantiene las cabezas de key-value (KV) en h_kv y luego resta pares de cabezas (head 0 menos head 1, head 2 menos head 3, etc.). La resta se escala por un factor proyectado lambda por token y por cabeza, y luego se reduce de vuelta a la dimensión original antes de W_O, así W_O permanece igual que en el Transformer base.
¿Por qué esta estructura? Porque permite lograr la capacidad expresiva de una atención diferencial sin pagar el coste en caché de valores ni necesitar kernels de atención personalizados. En otras palabras: mantiene velocidad de decodificación comparable al Transformer estándar y facilita uso práctico en LLMs.
Diseño técnico y piezas clave
Duplicado de query heads: DIFF V2 usa 2h cabezas de query pero sigue usando h_kv cabezas para key y value. Después de la operación diferencial, la dimensión vuelve a h * d para que W_O sea compatible.
Lambda proyectada por token y por cabeza: lambda viene de proyectar X (las representaciones de token) y se aplica vía sigmoid para mantener la escala acotada. Esto da control fino del Context RMS por cabeza y por token.
No hay per-head RMSNorm en el contexto: a diferencia de DIFF V1, DIFF V2 elimina la normalización RMS por cabeza en el contexto, porque la proyección de lambda y la modificación de las dimensiones resuelven los problemas numéricos que motivaron esa normalización.
Compatibilidad con kernels existentes: al alinear las dimensiones de cabeza entre Q, K y V, DIFF V2 evita la necesidad de kernels especiales y se beneficia de FlashAttention moderno en GPUs H-series y B-series.
Código conceptual
Un esqueleto simplificado de las dos versiones ayuda a ver la diferencia:
Fíjate en dos puntos: en V2 las parejas de cabezas están entrelazadas (0::2 y 1::2) para compartir el mismo KV, y lambda es por token-cabeza con sigmoid.
Context RMS y estabilidad numérica
En la atención softmax estándar, si los v_j se asumen con RMS 1 y no correlacionados, el Context RMS queda en el rango [1/sqrt(n), 1). Esto impone una limitación: si la atención se rebaja hacia distribuciones uniformes en secuencias largas, la magnitud cae a 1/sqrt(n). En DIFF V1 intentaron corregir eso con una RMSNorm por cabeza, pero eso obliga a multiplicar por factores enormes cuando n es grande (por ejemplo sqrt(8192) alrededor de 90), lo que genera gradientes enormes y explosión numérica.
DIFF V2 soluciona esto de otra forma: al proyectar lambda por token y cabeza y eliminar la RMSNorm por cabeza, la escala de los gradientes vuelve a ser comparable a la de un Transformer estándar. En la práctica esto reduce picos de gradiente y outliers de activación en entrenamientos con learning rates grandes.
Resultados empíricos y comportamiento en entrenamiento
Los autores corren preentrenamientos a escala productiva (trillones de tokens, modelos densos y MoE de 30A3) con learning rates grandes (6e-4 a 1e-3). Observaciones preliminares:
Reducción notable de la pérdida de modelado de lenguaje frente al Transformer: gap de 0.02 a 0.03.
Menos picos de gradiente y activaciones outlier, especialmente con learning rates grandes donde el Transformer puede volverse inestable.
El overhead de throughput en preentrenamiento es despreciable si usas FlashAttention en GPUs H-series y B-series.
También recomiendan, para prefilling de secuencias largas, combinar DIFF V2 con técnicas como YOCO que bajan la complejidad de prefilling a tiempo lineal.
Costos, parámetros y comparación teórica
Si comparas DIFF V2 con un Transformer que simplemente tenga 2h cabezas reales, ambos tienen el mismo coste de kernel de atención pero DIFF V2 requiere menos parámetros en la proyección de salida W_O. Con la configuración GQA actual (grouped query attention), aproximadamente 25% de los parámetros del módulo de atención se pueden ahorrar, porque la memoria y parámetros de W_Q y W_O dominan.
Además, si el objetivo es simplemente igualar la pérdida del Transformer pero ganar estabilidad de entrenamiento o mejor control de outliers, DIFF V2 ya tiene sentido práctico. No todo se resume a la mínima pérdida: estabilidad y eficiencia operativa también cuentan.
Ablaciones importantes y errores comunes
Los autores reportan varias pruebas de ablasión que muestran qué no hacer:
Error de emparejamiento de cabezas: dividir las mitades de atención como attn[:, :nh//2] y attn[:, nh//2:] es incorrecto. Las cabezas diferenciales deben ser pares intercalados attn[:, 0::2] y attn[:, 1::2] para compartir el mismo KV. La implementación equivocada resulta en inestabilidad y pérdida más alta.
Omitir lambda en la resta: usar attn1 - attn2 sin escalado lleva a un Context RMS demasiado pequeño al inicio.
No aplicar sigmoid a lambda: usar la lambda proyectada sin regularización puede dejar el Context RMS sin límite superior y causar inestabilidad.
Estas tres ablasiones degradan el entrenamiento; la primera es la más dañina para estabilidad.
Compatibilidad con sparse attention y consideraciones prácticas
DIFF V2 es compatible con esquemas de atención dispersa. El desafío práctico está en la selección de bloques de KV cuando las cabezas diferenciadas forman una gran GQA group. Estrategias posibles:
Seleccionar bloques basados en promedio de logits entre las cabezas diferenciadas, o
manejar por separado las dos clases de cabezas durante selección.
Conceptualmente no hay un obstáculo fundamental, solo un ajuste en la heurística de selección de bloques para mantener la aceleración.
Recomendaciones para equipo de ML y ML infra
Si usas FlashAttention moderno en GPUs H-series o B-series, DIFF V2 tiene overhead pequeño en preentrenamiento.
En decodificación, DIFF V2 evita duplicar la carga de la KV-cache, así mantiene latencias y uso de memoria comparables al Transformer base.
Si planeas entrenar con learning rates altos o a escala productiva, DIFF V2 puede ofrecer ganancias en estabilidad y menor magnitud de outliers.
Ten cuidado al implementar la indexación de cabezas y con la proyección de lambda (usar sigmoid).
DIFF V2 no es solo una curiosidad teórica: es una reformulación práctica de la operación diferencial en atención que prioriza estabilidad, compatibilidad con kernels existentes y eficiencia en decodificación.