Volvamos al laboratorio. En la segunda entrega de la serie PRX se deja claro algo que quizá ya sospechas: las decisiones de entrenamiento mueven la aguja tanto como la arquitectura. Aquí te cuento, de forma práctica y técnica, qué realmente aceleró la convergencia, qué mejoró la calidad y qué fue puro ruido en nuestros experimentos con PRX.
Punto de partida: referencia limpia y reproducible
Antes de tocar nada, los autores fijan un baseline intencionalmente simple para poder atribuir mejoras a cambios concretos. Resumen rápido del baseline que usaron:
- Modelo: PRX-1.2B (stream único, atención global sobre tokens de imagen y texto)
- Objetivo: Flow Matching puro (
Flow Matching) en espacio latente FLUX VAE - Datos: 1M imágenes sintéticas (MidJourney V6), 256x256, batch global 256
- Optimización: AdamW, lr 1e-4, betas (0.9, 0.95), eps 1e-15, weight_decay 0
- Text encoder: GemmaT5, RoPE pos encoding, padding mask, EMA desactivada
Métricas de control (no perfectas, pero útiles): FID, CMMD (CLIP-MMD), DINO-MMD y throughput (batches/s). Pregunta guía: ¿esta modificación mejora convergencia o eficiencia respecto al baseline?
1) Alineamiento de representaciones: atajo poderoso al inicio
La idea central: separar lo que el modelo debe aprender como representación interna de la tarea de denoising. En vez de que una sola pérdida haga todo, añades una pérdida auxiliar que empuja las características intermedias del denoiser hacia el espacio de un encoder de visión fuerte y congelado.
-
REPA: añades un teacher congelado (p. ej. DINOv2 o DINOv3), obtienes embeddings por parche del teacher y fuerzas a los tokens intermedios del estudiante a coincidir en ese espacio (cosine similarity por parche). Resultado: arranque mucho más rápido y estructuras globales más limpias. En sus números: FID baja de 18.2 a 14.64 con DINOv3, pero el throughput cae de 3.95 a 3.46 batches/s por el costo del forward del teacher.
-
iREPA: ajustes puntuales para preservar estructura espacial (3x3 conv en el proyector y normalizado espacial). En su experiencia, mejora con DINOv2 pero no siempre con DINOv3; es un tweak barato que puede ayudar, pero ojo a las interacciones.
-
REPA Works Until It Does Not: buen acelerador temprano, pero puede limitar al modelo en etapas tardías. Solución práctica: usar REPA como burn-in y apagarla después (stage-wise schedule).
-
REPA-E y diseño de latentes: en lugar de solo alinear features, puedes diseñar la latencia para que sea intrínsecamente más aprendible. REPA-E actualiza el VAE con una pérdida de alineamiento mientras evita que la pérdida de flujo tome atajos dañinos. Flux2-AE apunta lo mismo desde otro ángulo: latentes con estructura semántica son más fáciles de modelar.
Resultados de latentes: pasar a Flux2-AE o REPA-E reduce FID dramáticamente (18.20 -> ~12.08). Flux2-AE logra mejores CMMD/DINO-MMD pero con costoso throughput (3.95 -> 1.79). REPA-E es más balanceado (throughput 3.39) y mantiene ganancias fuertes.
Lección práctica
Si tienes límite de compute, usa alineamiento al inicio y considera invertir en un tokenizer/AE diseñado para learnability: da el salto más grande por paso de entrenamiento.
2) Objetivos de entrenamiento: no subestimes la formulación de la pérdida
Pequeños cambios en la pérdida generan efectos grandes.
-
Contrastive Flow Matching: añade un término que empuja las trayectorias condicionales a ser distintas entre sí (usar negativos del batch). En su setup texto-a-imagen el beneficio fue modesto: mejora CMMD/DINO-MMD levemente, FID no mejora en ese experimento, y el throughput cae muy poco. Sigue siendo un regularizador barato.
-
JiT / X-prediction (Back to Basics): en vez de predecir ruido o velocidad, el modelo predice una estimación de la imagen limpia
xy luego la convierte a velocidad. Esto respeta la suposición del manifold y hace el problema de aprendizaje más fácil.
En latentes 256x256 la mejora es ambigua: FID mejora ligeramente pero CMMD/DINO empeoran. Lo interesante: JiT permite entrenar directamente en pixeles a 1024x1024 con parches grandes (32x32) y estabilidad. Entrenar en 1024x1024 sin VAE fue solo ~3x más lento que 256x256 latente y alcanzó FID 17.42. Con esto, JiT abre la puerta a entrenamiento alto-res sin tokenizers costosos.
Lección práctica
X-pred es la mejor apuesta para entrenar sin tokenizer a alta resolución; en latente puede no ser dominante.
3) Token routing y sparsificación: gana cuando los tokens son muchos
Atención: estas técnicas apuntan al costo de computación de la atención sobre secuencias largas.
-
TREAD: en vez de borrar tokens, rutas un subconjunto para que salte un bloque contiguo de capas (no pierdes información, solo profundidad). Usuarios reportan que hasta 50% de routing funciona bien.
-
SPRINT: denso en capas tempranas, sparse en las capas medias costosas, y re-expande antes de la salida. Puede dropear hasta 75% en la parte cara.
En 256x256 (pocos tokens) ambas aportan throughput marginal (3.95 -> 4.11/4.20) pero degradan calidad (FID sube). En 1024x1024 cambian el juego: TREAD y SPRINT aumentan speed y mejoran calidad. Ejemplo: baseline 1024x1024 FID 17.42 -> TREAD 14.10 con más throughput.
Lección práctica
Si trabajas a alta resolución o con muchos tokens, routing/sparsificación puede ser el mayor multiplicador de eficiencia; en resoluciones bajas puede penalizar calidad.
4) Datos y captions: el supervisor que gobierna la tarea
- Long captions vs short captions: las captions largas y descriptivas aceleran y mejoran el aprendizaje. ¿Por qué? Porque más tokens = más señal. Para entrenamiento, prompts ricos reducen ambiguedad y evitan que el modelo promedie soluciones. Resultado claro: pasar a captions cortas empeora mucho la convergencia (FID 18.2 -> 36.84 en su experimento).
Tip práctico: entrena con captions largas y termina con un fine-tuning mixto de captions largas y cortas para que el modelo también funcione bien en prompts breves.
-
Sintético vs real: datos sintéticos (MidJourney) facilitaron estructura y coherencia compositiva temprano, mientras que imágenes reales mejoraron FID respecto a referencia real (textura y alto-frecuente). Conclusión: usa sintético para bootstrap de estructura y real para pulir texturas.
-
Supervised Fine-Tuning con datasets pequeños y curados (Alchemist, 3.3k pares) puede aportar un "capa de estilo" con polish fotográfico en una pasada corta.
Detalles operativos que duelen si los ignoras
-
Optimizador: cambiar de AdamW a Muon mostró una mejora tangible en calidad temprana (FID 18.2 -> 15.55). El optimizador importa: no es sólo estabilidad, es tiempo a calidad.
-
BF16: usar autocast BF16 está bien, pero guardar parámetros en BF16 fue un bug caro. Pesos almacenados en BF16 degradaron FID a 21.87. Regla: autocast en compute, pero mantén pesos y estado del optimizador en FP32.
Resumen técnico y recomendaciones para tu próxima corrida
- Alineamiento representacional (REPA) = mejor acelerador temprano. Úsalo como burn-in y luego apágalo.
- Diseñar latentes (REPA-E, Flux2-AE) ofrece la mayor ganancia en calidad por paso, con trade-offs claros en throughput; REPA-E es más balanceado, Flux2-AE rinde mejor pero es costoso.
- Objetivos: contrastive FM es un low-cost win para conditioning; x-pred permite entrenar alto-res sin tokenizer y es clave si quieres saltarte el VAE.
- Token routing: marginal en baja resolución, crítico y beneficioso en alta resolución con JiT/x-pred.
- Datos y captions: captions largas aceleran el aprendizaje; mezcla sintético para estructura y real para textura; un SFT pequeño y curado puede pulir el resultado.
- Operación: prueba optimizadores modernos (Muon) y vigila la precisión de los pesos (no los dejes en BF16).
Si estás pensando en replicar algo de esto o en ajustar una pipeline, empieza por controlar bien el baseline y añade una palanca a la vez: alineamiento temprano, luego latentes, luego x-pred si quieres 1024x1024, y finalmente token routing cuando ya tengas muchos tokens.
¿Y ahora qué sigue?
Los autores anuncian que publicarán el código completo del framework PRX y harán una speedrun pública de 24 horas combinando lo mejor de estas ideas. Si te interesa experimentar rápido, ese será un gran punto de partida para reproducir los ablations y ver cuáles funcionan en tu propio setup.
