Entrena un modelo text-to-image en 24h con PRX | Keryc
Volviste justo a tiempo para la parte práctica. ¿Qué pasa si juntas todas las ideas que sí funcionan y las entrenas en 24 horas con un presupuesto ajustado? Eso es exactamente lo que hace este experimento PRX: apilar trucos de arquitectura, pérdidas perceptuales y optimizaciones para obtener un modelo text-to-image útil en un día de cómputo.
El reto: una speedrun de 24 horas
Objetivo claro y realista: entrenar un modelo competible en 24 horas usando 32 H200 con un presupuesto aproximado de 1500 dólares (2 $/hora por GPU). No es investigación puramente teórica: es ingeniería para maximizar rendimiento bajo restricciones fuertes. ¿Qué tan lejos puedes llegar combinando lo que ya funciona? Mucho más de lo que pensarías.
Arquitectura y formulación: pixel-space con x-prediction
En lugar de entrenar en latentes y depender de un VAE, usan la formulación x-prediction (Back to Basics: Let Denoising Generative Models Denoise). Eso permite entrenar directamente en píxeles y reaprovechar todo el toolbox clásico de visión computacional.
Decisiones clave:
Patch size 32 y proyección inicial a un bottleneck de 256 dimensiones para controlar la longitud de secuencia. Esto hace viable entrenar en pixel-space incluso a resoluciones altas.
Arrancan directo a 512px y luego afinan a 1024px (no siguen la escalera 256 -> 512 -> 1024). Así se concentra la mayor parte del entrenamiento en la resolución que importa.
Ventaja práctica: al predecir píxeles directamente puedes usar pérdidas perceptuales tal cual estaban diseñadas originalmente, sin el rodeo de decodificar latentes.
Pérdidas perceptuales: LPIPS y DINOv2
Siguen la inspiración de PixelGen y añaden dos pérdidas auxiliares:
LPIPS para similitud perceptual baja-nivel
Pérdida perceptual basada en DINOv2 para señal semántica más fuerte
Detalles de implementación que marcaron la diferencia:
Aplicarlas sobre la imagen completa poolada en vez de características por parches
Aplicarlas a todos los niveles de ruido durante el entrenamiento
Pesos usados (valores empíricos que funcionaron bien): LPIPS 0.1 y DINO perceptual 0.01. Añade poco overhead comparado con el paso principal del transformer, pero acelera la convergencia y mejora la calidad final.
Eficiencia: routing con TREAD y guía para tokens ruteados
Para abaratar cada paso usan token routing con TREAD, que selecciona aleatoriamente una fracción de tokens y los deja bypassear un bloque contiguo del transformer para reinyectarlos después. Elección práctica frente a otras alternativas:
TREAD por simplicidad y buen balance entre ahorro y complejidad (ejemplo: secuencia 64 vs 128 en su setting)
Ruta aplicada: 50% de tokens desde el bloque 2 hasta el penúltimo bloque
Problema conocido: modelos ruteados pueden verse peor bajo CFG convencional si están poco entrenados. Solución práctica: implementación de una auto-guía (self-guidance) inspirada en Guiding Token-Sparse Diffusion Models que guía usando una predicción condicional densa vs. ruteada en lugar de una rama incondicional.
Alineamiento de representaciones con REPA y maestro DINOv3
Usan REPA para alinear representaciones con un teacher DINOv3 (el que mejor rindió en experimentos previos). Concretamente:
Alineamiento aplicado una vez en el bloque 8
Peso de pérdida REPA = 0.5
Como combinan REPA con TREAD, la pérdida se calcula solo sobre tokens no ruteados (los que pasan por los bloques donde se aplica la pérdida). Esto evita señales inconsistentes comparando tokens que saltaron la ruta.
Optimización: Muon + Adam (FSDP)
Optimizador principal para matrices 2D: Muon con FSDP (muon_fsdp_2). Resto de parámetros (biases, norm, embeddings) con Adam. Dos grupos de parámetros pragmáticos:
Grupo
Aplicación
Parámetros clave
Muon
Parámetros 2D (matrices)
lr=1e-4, momentum=0.95, nesterov=true, ns_steps=5
Adam
Parámetros no 2D
lr=1e-4, betas=(0.9, 0.95), eps=1e-8
Resultado: Muon mostró mejora clara sobre Adam puro en runs previos, por eso lo aplicaron selectivamente.
Datos y agenda de entrenamiento
Conjuntos públicos sintéticos usados:
Flux generated (1.7M) - lehduong/flux_generated
FLUX-Reason-6M (6M) - LucasFang/FLUX-Reason-6M
midjourney-v6-llava (1M) - brivangl/midjourney-v6-llava, re-captionado con Gemini 2.5 Flash para uniformar prompts y reducir ruido en captions
Horario de entrenamiento (la receta práctica):
512px: 100k pasos, batch size 1024
1024px: 20k pasos, batch size 512 (sin REPA en esta etapa)
EMA para muestreo y evaluación:
smoothing = 0.999
update_interval = 10ba
ema_start = 0ba
El pipeline fue diseñado para ser configurable: puedes sustituir datasets, ajustar routing, REPA, pérdidas perceptuales o la configuración de Muon.
Resultados y lecciones prácticas
¿Funciona en 24 horas? Sí, y de forma útil. Observaciones principales:
Calidad: prompt following fuerte, estética consistente, la etapa a 1024px afina detalles sin romper la composición.
Fallos: texturas glitch, anatomía ocasionalmente rara y problemas en prompts muy difíciles. Estos parecen más artefactos de subentrenamiento o falta de diversidad de datos que fallas estructurales del recipe.
Lección clave: combinando pixel-space, routing eficiente, alineamiento de representaciones y pérdidas perceptuales ligeras, puedes obtener un modelo significativo en un solo día con presupuesto moderado. No es magia: es ingeniería cuidadosa y elección de componentes probados.
Qué sigue y cómo reproducirlo
Esto es un punto de partida. El equipo planea escalar la receta, iterar en la mezcla de datos y mejorar la captioning. Si quieres reproducirlo o probar variantes, el código y configuraciones están abiertos para la comunidad.