Training text-to-image models: lessons from PRX | Keryc
Back to the lab. The second installment of the PRX series makes clear something you might already suspect: training decisions move the needle as much as architecture does. Here is a practical, technical rundown of what actually sped up convergence, what improved quality, and what was just noise in the PRX experiments.
Starting point: clean, reproducible reference
Before touching anything, the authors set an intentionally simple baseline so you can attribute improvements to concrete changes. Quick summary of the baseline they used:
Model: PRX-1.2B (single stream, global attention over image and text tokens)
Objective: pure Flow Matching in the FLUX VAE latent space
Data: 1M synthetic images (MidJourney V6), 256x256, global batch 256
Text encoder: GemmaT5; other settings: RoPE positional encoding, padding mask, EMA disabled
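For reference, the baseline objective can be sketched in a few lines. This is a minimal NumPy sketch, not the PRX code. It assumes the common rectified-flow convention x_t = (1 - t)·x0 + t·ε with velocity target ε - x0 and uniform timestep sampling (the actual PRX time sampling is not stated here). `model` is a hypothetical callable standing in for the denoiser.

```python
import numpy as np

def flow_matching_loss(model, x0, cond, rng):
    """Conditional flow-matching loss under an assumed rectified-flow convention.

    x0:    clean latents, shape (B, ...)
    cond:  conditioning passed through to the model (e.g. text embeddings)
    model: hypothetical callable model(x_t, t, cond) -> velocity shaped like x0
    """
    b = x0.shape[0]
    t = rng.uniform(0.0, 1.0, size=b)              # one timestep per sample
    t_b = t.reshape((b,) + (1,) * (x0.ndim - 1))   # broadcast t over feature dims
    eps = rng.standard_normal(x0.shape)            # Gaussian noise endpoint
    x_t = (1.0 - t_b) * x0 + t_b * eps             # straight path: data at t=0, noise at t=1
    v_target = eps - x0                            # d(x_t)/dt along that path
    v_pred = model(x_t, t, cond)
    return float(np.mean((v_pred - v_target) ** 2))


# Usage with a dummy model on a batch shaped like 16-channel 32x32 latents
rng = np.random.default_rng(0)
x0 = rng.standard_normal((4, 16, 32, 32))
zero_model = lambda x_t, t, cond: np.zeros_like(x_t)
loss = flow_matching_loss(zero_model, x0, None, rng)
```

The 16x32x32 shape is what an 8x-downsampling, 16-channel VAE such as FLUX's produces for a 256x256 image. Treat it as illustrative.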
Control metrics (not perfect, but useful): FID, CMMD (CLIP-MMD), DINO-MMD (lower is better for all three) and throughput (batches/s, higher is better). Guiding question: does this change improve convergence or efficiency relative to the baseline?
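CMMD and DINO-MMD compare the distribution of embeddings of generated images with that of reference images using maximum mean discrepancy (MMD). The sketch below shows the idea only. It is a biased estimate of squared MMD with a Gaussian RBF kernel on precomputed embeddings, and the bandwidth `sigma` is an assumption. The published CMMD implementation fixes its own kernel bandwidth and scaling, and those are not reproduced here.

```python
import numpy as np

def rbf_kernel(a, b, sigma):
    """Gaussian RBF kernel matrix between the rows of a (n, d) and b (m, d)."""
    sq = (np.sum(a**2, axis=1)[:, None] + np.sum(b**2, axis=1)[None, :]
          - 2.0 * a @ b.T)                       # pairwise squared distances
    return np.exp(-np.maximum(sq, 0.0) / (2.0 * sigma**2))

def mmd2(x, y, sigma=1.0):
    """Biased squared-MMD estimate between embedding sets x and y (0 = identical)."""
    kxx = rbf_kernel(x, x, sigma).mean()
    kyy = rbf_kernel(y, y, sigma).mean()
    kxy = rbf_kernel(x, y, sigma).mean()
    return float(kxx + kyy - 2.0 * kxy)
```

In practice `x` would hold CLIP (for CMMD) or DINO (for DINO-MMD) embeddings of generated images and `y` those of the reference set.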
1) Representation alignment: a powerful shortcut early on
The core idea: decouple learning a good internal representation from the denoising task itself. Instead of a single loss doing everything, you add an auxiliary loss that pushes the denoiser’s intermediate features toward the feature space of a strong, frozen vision encoder.
REPA: add a frozen teacher (e.g. DINOv2 or DINOv3), extract patch embeddings from it, and train the student’s intermediate tokens, passed through a small projection head, to match those embeddings (cosine similarity per patch). Result: much faster startup and cleaner global structure. Their numbers: FID drops from 18.20 to 14.64 with DINOv3, but throughput falls from 3.95 to 3.46 batches/s because of the extra teacher forward pass.
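A minimal sketch of the REPA alignment term follows. It assumes the student tokens and teacher patches already sit on the same grid. In practice the teacher features may need resizing to match the denoiser's token grid. The two-layer SiLU projector and its weights `w1`, `w2` are hypothetical stand-ins for whatever head PRX actually uses.

```python
import numpy as np

def silu(x):
    return x / (1.0 + np.exp(-x))

def project(student_tokens, w1, w2):
    """Hypothetical 2-layer MLP projector mapping student width D_s to teacher width D_t."""
    return silu(student_tokens @ w1) @ w2

def repa_loss(student_tokens, teacher_tokens, w1, w2, eps=1e-8):
    """Negative mean per-patch cosine similarity (-1 = perfectly aligned).

    student_tokens: (B, N, D_s) intermediate denoiser features for the N image tokens
    teacher_tokens: (B, N, D_t) frozen-encoder patch embeddings (treated as constants)
    """
    z = project(student_tokens, w1, w2)
    z = z / (np.linalg.norm(z, axis=-1, keepdims=True) + eps)
    t = teacher_tokens / (np.linalg.norm(teacher_tokens, axis=-1, keepdims=True) + eps)
    cos = np.sum(z * t, axis=-1)   # (B, N): one cosine similarity per patch
    return float(-cos.mean())
```

During training this term is added to the flow-matching loss, e.g. `total = fm_loss + lambda_repa * repa_loss(...)`. Here `lambda_repa` is a hypothetical weight, and only the student and projector receive gradients.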
iREPA: small tweaks to preserve spatial structure (a 3x3 conv in the projector and spatial normalization). In their experiments it helped with DINOv2 but not consistently with DINOv3. It is a cheap tweak worth trying, but its benefit depends on which teacher you pair it with.
REPA Works Until It Does Not: great early accelerator, but it can limit the model later. Practical fix: use REPA as a burn-in and switch it off afterwards (stage-wise schedule).
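The burn-in fix boils down to a weight schedule on the REPA term. The sketch below is hypothetical: the source gives no step counts, so `burn_in_steps`, `decay_steps` and `base_weight` are placeholders. It supports either a hard switch-off or a linear fade-out.

```python
def repa_weight(step, base_weight=0.5, burn_in_steps=100_000, decay_steps=0):
    """Weight on the REPA loss at a given training step (hypothetical schedule).

    Full weight during burn-in, then either switched off (decay_steps=0)
    or linearly decayed to zero over decay_steps.
    """
    if step < burn_in_steps:
        return base_weight
    if decay_steps <= 0:
        return 0.0
    frac = (step - burn_in_steps) / decay_steps
    return base_weight * max(0.0, 1.0 - frac)
```

A side benefit: once the weight reaches zero, the teacher forward pass can be skipped entirely, which removes its throughput cost for the rest of the run.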
REPA-E and latent design: instead of only aligning features, you can design the latent space to be intrinsically more learnable. REPA-E updates the VAE with an alignment loss while preventing the flow loss from taking harmful shortcuts. Flux2-AE approaches the same idea from another angle: latents with semantic structure are easier to model.
Results for latents: moving to Flux2-AE or REPA-E reduces FID dramatically (18.20 -> ~12.08). Flux2-AE achieves better CMMD/DINO-MMD but at a steep throughput cost (3.95 -> 1.79 batches/s). REPA-E is more balanced (3.39 batches/s) and keeps strong gains.
Practical takeaway
If you have compute limits, use alignment early and consider investing in a tokenizer/AE designed for learnability: it gives the biggest jump per training step.
2) Training objectives: don't underestimate loss formulation
Small changes to the loss produce large effects.
Contrastive Flow Matching: add a term that pushes conditional trajectories apart from each other, using other samples in the batch as negatives. In their text-to-image setup the benefit was modest: CMMD and DINO-MMD improved slightly, FID did not, and throughput dropped only a little. Still, it is a cheap regularizer.
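A minimal sketch of the contrastive term, following the general form described for Contrastive Flow Matching: the usual flow-matching error minus a weighted error against another sample's target. The weight `lam=0.05` is an illustrative default, not PRX's value. Pairing negatives by rolling the batch is one simple choice; a random permutation works too.

```python
import numpy as np

def contrastive_fm_loss(v_pred, v_target, lam=0.05):
    """Flow-matching MSE minus lam times the MSE against another sample's target.

    Negatives come from the same batch: rolling the batch by one pairs every
    sample i with sample i-1, so (for batch size > 1) no sample is its own negative.
    """
    neg_target = np.roll(v_target, shift=1, axis=0)
    positive = np.mean((v_pred - v_target) ** 2)    # usual flow-matching term
    negative = np.mean((v_pred - neg_target) ** 2)  # distance to a mismatched trajectory
    return float(positive - lam * negative)
```

Because the negatives are already in the batch, the extra cost is one more elementwise difference, consistent with the small throughput drop reported.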
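JiT / X-prediction (Back to Basics): instead of predicting noise or velocity, the model predicts an estimate of the clean image x and then converts it to velocity. This respects the manifold assumption. Clean images lie near a low-dimensional manifold, whereas noise and velocity fill the whole space, so predicting x is an easier learning problem, especially when each token is high-dimensional.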
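The conversion from an x-prediction to a velocity is a one-liner once a convention is fixed. This sketch assumes x_t = (1 - t)·x0 + t·ε and v = ε - x0, which gives v = (x_t - x0)/t. Other conventions (e.g. t = 1 at the clean image) change the formula accordingly. Clamping t away from zero is a choice made here to avoid division blow-up.

```python
import numpy as np

def velocity_from_x_pred(x_t, x_pred, t, t_min=1e-3):
    """Convert a clean-sample prediction into a velocity.

    Assumed convention: x_t = (1 - t) * x0 + t * eps and v = eps - x0,
    so v = (x_t - x0) / t. t is clamped to t_min to avoid dividing by ~0.
    """
    t = np.maximum(np.asarray(t, dtype=x_t.dtype), t_min)
    t = t.reshape(t.shape + (1,) * (x_t.ndim - t.ndim))  # per-sample t -> broadcastable
    return (x_t - x_pred) / t
```

With this conversion the network's output head predicts x, but the training loss can stay the same velocity MSE used for a velocity-predicting model.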
On 256x256 latents the results are mixed: FID improves slightly but CMMD/DINO-MMD worsen. The interesting part: JiT lets you train directly in pixel space at 1024x1024 with large patches (32x32) and remain stable. Training at 1024x1024 without a VAE was only ~3x slower than 256x256 latent training and reached FID 17.42. JiT opens the door to high-resolution training without expensive tokenizers.
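To see what "large patches in pixel space" means for sequence length, here is a plain NumPy patchify sketch (not PRX's implementation). A 1024x1024 RGB image cut into 32x32 patches yields (1024/32)² = 1024 tokens, each holding 32·32·3 = 3072 raw pixel values.

```python
import numpy as np

def patchify(img, patch):
    """Split an (H, W, C) image into non-overlapping tokens of size patch*patch*C."""
    h, w, c = img.shape
    assert h % patch == 0 and w % patch == 0, "image size must be divisible by patch size"
    gh, gw = h // patch, w // patch
    x = img.reshape(gh, patch, gw, patch, c)
    x = x.transpose(0, 2, 1, 3, 4)           # (gh, gw, patch, patch, c): group pixels by patch
    return x.reshape(gh * gw, patch * patch * c)


tokens = patchify(np.zeros((1024, 1024, 3), dtype=np.float32), 32)
```

Each token is much higher-dimensional than a latent token, which is exactly the regime where predicting x rather than noise is argued to matter.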
Practical takeaway
X-pred is your best bet if you want to train without a tokenizer at high resolution; in latent space it may not dominate.
3) Token routing and sparsification: wins when tokens are many
Heads up: these techniques target the compute cost of attention over long sequences.
TREAD: instead of dropping tokens, you route a subset of them around a contiguous block of layers (no information is lost; the routed tokens only skip depth). Routing up to 50% of tokens is reported to work well.
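A single-sequence sketch of TREAD-style routing follows. It assumes random token selection and treats the routed span of layers as a list of callables (`blocks` is a hypothetical stand-in for transformer blocks). Tokens that are not selected skip the span and re-enter with the values they had before it. Nothing is dropped; only the selected tokens pay for the routed layers.

```python
import numpy as np

def tread_route(tokens, blocks, keep_ratio, rng):
    """Run `blocks` on a random subset of tokens; the rest bypass them unchanged.

    tokens:     (N, D) token sequence for one sample
    blocks:     list of callables mapping (M, D) -> (M, D), the routed span of layers
    keep_ratio: fraction of tokens processed by the span (0.5 = route 50% around it)
    """
    n = tokens.shape[0]
    n_keep = max(1, int(round(n * keep_ratio)))
    keep_idx = np.sort(rng.permutation(n)[:n_keep])
    h = tokens[keep_idx]
    for block in blocks:
        h = block(h)              # attention here runs over n_keep tokens, not n
    out = tokens.copy()           # bypassing tokens keep their pre-span values
    out[keep_idx] = h
    return out
```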
SPRINT: dense in the early layers, sparse in the expensive middle layers, and re-expanded to the full sequence before the output. It can drop up to 75% of tokens in the costly middle section.
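A back-of-the-envelope calculation shows why dropping tokens pays off most on long sequences. It assumes standard dense self-attention (cost quadratic in the number of tokens N) and token-wise projections and MLPs (cost linear in N). It ignores routing overhead and layers outside the sparse section.

```latex
% N = tokens entering the sparse section, d = hidden width, r = fraction of tokens kept
C_{\mathrm{attn}}(N) \propto N^2 d, \qquad C_{\mathrm{lin}}(N) \propto N d^2
\;\Longrightarrow\;
C_{\mathrm{attn}}(rN) = r^2 \, C_{\mathrm{attn}}(N), \qquad C_{\mathrm{lin}}(rN) = r \, C_{\mathrm{lin}}(N)
% A 75% drop means r = 1/4: attention cost falls to 1/16, projection/MLP cost to 1/4
```

Since attention's share of the total grows with N, the same keep ratio saves proportionally more when sequences are long. This is consistent with routing paying off mainly at high resolution.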
At 256x256 (few tokens) both give marginal throughput gains (3.95 -> 4.11/4.20) but degrade quality (FID rises). At 1024x1024 they change the game: TREAD and SPRINT increase speed and improve quality. Example: baseline 1024x1024 FID 17.42 -> TREAD 14.10 with higher throughput.
Practical takeaway
If you work at high resolution or with many tokens, routing/sparsification can be the biggest efficiency multiplier; at low resolution it can hurt quality.
4) Data and captions: the supervisor that governs the task
Long captions vs short captions: long, descriptive captions speed up and improve learning. Why? More tokens = more signal. For training, rich prompts reduce ambiguity and prevent the model from averaging solutions. Clear result: switching to short captions hurts convergence badly (FID 18.2 -> 36.84 in their experiment).
Tip: train with long captions, then finish with mixed fine-tuning of long and short captions so the model also works well with brief prompts.
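The caption tip can be expressed as a tiny sampler. Everything here is hypothetical: the function names, the switch-over step and the 50/50 mix are placeholders, since the source does not give the actual ratio.

```python
import random

def pick_caption(long_caption, short_caption, p_short, rng):
    """Return the short caption with probability p_short, otherwise the long one."""
    return short_caption if rng.random() < p_short else long_caption

def p_short_schedule(step, finetune_start, p_short_final=0.5):
    """Long captions only during main training; mix in short ones for the final fine-tune."""
    return 0.0 if step < finetune_start else p_short_final


# Usage inside a data loader (hypothetical fields on each sample)
rng = random.Random(0)
caption = pick_caption("a long descriptive caption ...", "short prompt",
                       p_short_schedule(step=10, finetune_start=100), rng)
```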
Synthetic vs real: synthetic data (MidJourney) helped structure and compositional coherence early, while real images improved FID, which is measured against real reference images (texture and high-frequency detail). Conclusion: use synthetic data to bootstrap structure and real data to polish textures.
Supervised fine-tuning with small, curated datasets (Alchemist, 3.3k pairs) can add a "style layer" with photographic polish in a short pass.
Operational details that hurt if you ignore them
Optimizer: switching from AdamW to Muon showed a tangible improvement in early quality (FID 18.2 -> 15.55). Optimizer choice matters: it’s not just stability, it’s time-to-quality.
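Muon's core idea, as described in its public reference implementation, is to replace the raw momentum update of each 2D weight matrix with an approximately orthogonalized version computed by a few Newton-Schulz iterations. The sketch below is a simplified NumPy version for illustration only. The quintic coefficients (3.4445, -4.7750, 2.0315) follow that reference implementation. The learning rate, the momentum value, the absence of Nesterov momentum and the absence of shape-dependent scaling are simplifications, not the PRX configuration. In practice Muon is applied to hidden-layer matrices, while embeddings, norms and biases are typically left to AdamW.

```python
import numpy as np

def newton_schulz_orthogonalize(g, steps=5, eps=1e-7):
    """Push all singular values of g toward ~1 (approximating U V^T of its SVD)."""
    a, b, c = 3.4445, -4.7750, 2.0315
    x = g / (np.linalg.norm(g) + eps)      # Frobenius norm bounds singular values by 1
    transposed = x.shape[0] > x.shape[1]
    if transposed:                         # iterate on the smaller Gram matrix
        x = x.T
    for _ in range(steps):
        aa = x @ x.T
        bb = b * aa + c * aa @ aa
        x = a * x + bb @ x                 # quintic polynomial in the singular values
    return x.T if transposed else x

def muon_step(w, grad, momentum, lr=0.02, beta=0.95):
    """One simplified Muon update for a 2D weight matrix."""
    momentum = beta * momentum + grad
    w = w - lr * newton_schulz_orthogonalize(momentum)
    return w, momentum
```

The iteration does not converge exactly to 1. It deliberately trades precision for speed, leaving singular values roughly in the 0.7 to 1.2 range, which is enough to equalize the update across directions.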
BF16: using BF16 autocast for compute is fine, but storing parameters in BF16 was an expensive bug. BF16 has only 8 bits of significand precision, so small optimizer updates get rounded away; weights stored in BF16 degraded FID to 21.87. Rule of thumb: autocast for compute, but keep model weights and optimizer state in FP32.
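The failure mode is easy to reproduce numerically. Near 1.0 the gap between adjacent bfloat16 values is 2^-7 ≈ 0.0078, so any update smaller than half that gap rounds back to the old value. NumPy has no native bfloat16 dtype, so this sketch emulates BF16 storage by rounding float32 bit patterns to their top 16 bits (round-to-nearest-even). NaN/Inf handling is ignored.

```python
import numpy as np

def to_bf16(x):
    """Round float32 values to the nearest bfloat16 (ties to even), returned as float32."""
    bits = np.asarray(x, dtype=np.float32).view(np.uint32)
    lsb = (bits >> 16) & 1                          # lowest bit that bf16 keeps
    rounded = (bits + 0x7FFF + lsb) & 0xFFFF0000    # round, then drop the low 16 bits
    return rounded.view(np.float32)

def simulate(steps, update, store_bf16):
    """Apply `steps` identical small updates to a weight of 1.0."""
    w = np.array([1.0], dtype=np.float32)
    for _ in range(steps):
        w = w + np.float32(update)                  # the update itself is computed in fp32
        if store_bf16:
            w = to_bf16(w)                          # ...but the weight is stored in bf16
    return float(w[0])


fp32_result = simulate(100, 1e-4, store_bf16=False)  # accumulates to about 1.01
bf16_result = simulate(100, 1e-4, store_bf16=True)   # every update is rounded away
```

Keeping an FP32 copy of the weights (and the optimizer state) lets these small updates accumulate, even when the forward and backward passes run in BF16 under autocast.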
Technical summary and recommendations for your next run
Representation alignment (REPA) = best early accelerator. Use it as a burn-in and then turn it off.
Designing latents (REPA-E, Flux2-AE) gives the largest quality gain per step, with clear throughput trade-offs; REPA-E is more balanced, Flux2-AE performs better but is costly.
Objectives: contrastive FM is a cheap regularizer with modest gains; x-pred enables high-res training without a tokenizer and is key if you want to skip the VAE.
Token routing: marginal at low resolution, critical and beneficial at high resolution with JiT/x-pred.
Data and captions: long captions speed learning; mix synthetic for structure and real for texture; a small curated SFT can polish outputs.
Ops: try modern optimizers (Muon) and watch weight precision (don’t store in BF16).
If you’re thinking of replicating any of this or tuning a pipeline, start by controlling the baseline and add one lever at a time: alignment early, then latents, then x-pred if you want 1024x1024, and finally token routing once you already have many tokens.
What's next?
The authors announce they will publish the full PRX framework code and run a public 24-hour speedrun combining the best of these ideas. Interested in experimenting quickly? That will be a great starting point to reproduce the ablations and see what works in your own setup.