Back to the lab. In the second installment of the PRX series, something becomes clear that you might already suspect: training decisions move the needle as much as architecture does. Here I tell you, in a practical and technical way, what really sped up convergence, what improved quality, and what was just noise in the PRX experiments.
Starting point: clean, reproducible reference
Before touching anything, the authors set an intentionally simple baseline so you can attribute improvements to concrete changes. Quick summary of the baseline they used:
- Model: PRX-1.2B (single stream, global attention over image and text tokens)
- Objective: pure Flow Matching in the FLUX VAE latent space
- Data: 1M synthetic images (MidJourney V6), 256x256, global batch 256
- Optimization: AdamW, lr 1e-4, betas (0.9, 0.95), eps 1e-15, weight_decay 0
- Text encoder: GemmaT5, RoPE positional encoding, padding mask, EMA disabled
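To make the baseline concrete, here is a minimal sketch of what "pure Flow Matching with those AdamW settings" could look like in PyTorch. This is my reconstruction from the list above, not the authors' code: `model` stands for the PRX-1.2B denoiser, `latents` for FLUX VAE latents, and `cond` for the text-encoder output.

```python
import torch
import torch.nn.functional as F

# Optimizer exactly as listed above; `model` is assumed to be the denoiser.
optimizer = torch.optim.AdamW(
    model.parameters(), lr=1e-4, betas=(0.9, 0.95), eps=1e-15, weight_decay=0.0
)

def flow_matching_loss(model, latents, cond):
    """Velocity-prediction flow matching on VAE latents, assuming the linear
    interpolation x_t = (1 - t) * x0 + t * noise, so the regression target
    is the constant velocity noise - x0."""
    noise = torch.randn_like(latents)
    t = torch.rand(latents.shape[0], device=latents.device)
    t_ = t.view(-1, 1, 1, 1)                    # broadcast over (C, H, W)
    x_t = (1.0 - t_) * latents + t_ * noise     # point on the straight path
    v_pred = model(x_t, t, cond)                # the network predicts velocity
    return F.mse_loss(v_pred, noise - latents)
```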
Control metrics (not perfect, but useful): FID, CMMD (CLIP-MMD), DINO-MMD and throughput (batches/s). Guiding question: does this change improve convergence or efficiency relative to the baseline?
1) Representation alignment: a powerful shortcut early on
The core idea: separate what the model must learn as an internal representation from the denoising task itself. Instead of a single loss doing everything, you add an auxiliary loss that pushes the denoiser's intermediate features toward the space of a strong, frozen vision encoder.
- REPA: add a frozen teacher (e.g. DINOv2 or DINOv3), get patch embeddings from the teacher, and force the student's intermediate tokens to match in that space (cosine similarity per patch). Result: much faster startup and cleaner global structure. Their numbers: FID drops from 18.2 to 14.64 with DINOv3, but throughput falls from 3.95 to 3.46 batches/s because of the teacher forward cost. (A minimal sketch of the alignment loss follows the results below.)
- iREPA: small tweaks to preserve spatial structure (a 3x3 conv in the projector and spatial normalization). In their experience it helps with DINOv2 but not always with DINOv3; it's a cheap tweak that can help, but watch out for interactions.
- REPA Works Until It Does Not: a great early accelerator, but it can limit the model later. Practical fix: use REPA as a burn-in and switch it off afterwards (stage-wise schedule).
- REPA-E and latent design: instead of only aligning features, you can design the latent space to be intrinsically more learnable. REPA-E updates the VAE with an alignment loss while preventing the flow loss from taking harmful shortcuts. Flux2-AE approaches the same idea from another angle: latents with semantic structure are easier to model.
Results for latents: moving to Flux2-AE or REPA-E reduces FID dramatically (18.20 -> ~12.08). Flux2-AE achieves better CMMD/DINO-MMD but at a steep throughput cost (3.95 -> 1.79 batches/s). REPA-E is more balanced (throughput 3.39 batches/s) and still keeps strong gains.
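For reference, this is roughly what the REPA-style auxiliary loss looks like: a minimal sketch assuming a frozen teacher (DINOv2/DINOv3) that exposes per-patch embeddings and a denoiser whose intermediate image tokens line up patch-for-patch with them. The projector and the loss weight are illustrative choices, not the paper's exact ones.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class REPALoss(nn.Module):
    """REPA-style alignment: project the denoiser's intermediate tokens into
    the frozen teacher's embedding space and maximize per-patch cosine
    similarity. A sketch of the idea, not the reference implementation."""

    def __init__(self, d_student: int, d_teacher: int):
        super().__init__()
        # Small MLP projector from denoiser width to teacher width.
        self.proj = nn.Sequential(
            nn.Linear(d_student, d_teacher),
            nn.SiLU(),
            nn.Linear(d_teacher, d_teacher),
        )

    def forward(self, student_tokens, teacher_patches):
        # student_tokens:  (B, N, d_student) intermediate tokens of the denoiser
        # teacher_patches: (B, N, d_teacher) frozen DINO patch embeddings
        pred = F.normalize(self.proj(student_tokens), dim=-1)
        target = F.normalize(teacher_patches.detach(), dim=-1)
        # Negative mean cosine similarity per patch (more aligned = lower loss).
        return -(pred * target).sum(dim=-1).mean()

# In the training step (lambda_repa is an illustrative weight):
# loss = flow_matching_loss(...) + lambda_repa * repa(student_tokens, dino_patches)
```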
Practical takeaway
If you have compute limits, use alignment early and consider investing in a tokenizer/AE designed for learnability: it gives the biggest jump per training step.
2) Training objectives: don't underestimate loss formulation
Small changes to the loss produce large effects.
- Contrastive Flow Matching: add a term that pushes conditional trajectories to be distinct from each other (using batch negatives). In their text-to-image setup the benefit was modest: CMMD/DINO-MMD improve slightly, FID did not in that experiment, and throughput drops only a bit. Still a cheap regularizer (a combined sketch with x-prediction follows this list).
- JiT / X-prediction (Back to Basics): instead of predicting noise or velocity, the model predicts an estimate of the clean image x and then converts that estimate into a velocity. This respects the manifold assumption and makes the learning problem easier.
On 256x256 latents the improvement is mixed: FID improves slightly but CMMD/DINO-MMD worsen. The interesting part: JiT lets you train directly in pixel space at 1024x1024 with large patches (32x32) and remain stable. Training at 1024x1024 without a VAE was only ~3x slower than 256x256 latent training and reached FID 17.42. JiT opens the door to high-resolution training without expensive tokenizers.
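Both tweaks are small enough to show in one place. The sketch below is my reconstruction under the same interpolation as before (x_t = (1 - t) * x0 + t * noise, target velocity noise - x0); the clamp on t and the contrastive weight are illustrative, not values from the post.

```python
import torch
import torch.nn.functional as F

def xpred_contrastive_loss(model, x0, cond, lam: float = 0.05):
    """x-prediction plus a contrastive flow-matching term (a sketch).

    - The network outputs an estimate of the clean sample; it is converted back
      to a velocity via v_hat = (x_t - x0_hat) / t before computing the loss.
    - The contrastive term pushes that velocity away from the target velocities
      of other samples in the batch (a rolled pairing provides the negatives)."""
    noise = torch.randn_like(x0)
    t = torch.rand(x0.shape[0], device=x0.device)
    t_ = t.view(-1, 1, 1, 1)
    x_t = (1.0 - t_) * x0 + t_ * noise
    v_target = noise - x0

    x0_hat = model(x_t, t, cond)                   # x-prediction head
    v_hat = (x_t - x0_hat) / t_.clamp(min=1e-4)    # x-estimate -> velocity estimate

    fm = F.mse_loss(v_hat, v_target)               # standard flow-matching term
    neg = torch.roll(v_target, shifts=1, dims=0)   # mismatched batch negatives
    contrast = F.mse_loss(v_hat, neg)              # distance to negative trajectories

    # Match the true velocity, stay away from the negatives.
    return fm - lam * contrast
```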
Practical takeaway
X-pred is your best bet if you want to train without a tokenizer at high resolution; in latent space it may not dominate.
3) Token routing and sparsification: wins when tokens are many
Heads up: these techniques target the compute cost of attention over long sequences.
- TREAD: instead of dropping tokens, you route a subset so they skip a contiguous block of layers (you don't lose information, only depth). Users report that routing up to 50% of tokens works well (a sketch of the routing idea follows the results below).
- SPRINT: dense in early layers, sparse in the expensive middle layers, then re-expand before the output. It can drop up to 75% of tokens in the costly section.
At 256x256 (few tokens) both give marginal throughput gains (3.95 -> 4.11/4.20 batches/s) but degrade quality (FID rises). At 1024x1024 they change the game: TREAD and SPRINT increase speed and improve quality. Example: baseline 1024x1024 FID 17.42 -> 14.10 with TREAD, at higher throughput.
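To make the routing idea concrete, here is a minimal sketch of TREAD-style routing as I read it: a random subset of tokens bypasses a contiguous block of layers and is merged back afterwards, so no information is dropped, only depth. Nothing here is the reference implementation, and the 50% ratio is just the figure quoted above.

```python
import torch
import torch.nn as nn

class TokenRouter(nn.Module):
    """Route a fraction of tokens around a contiguous block of layers.

    Tokens that are routed away skip the block entirely (identity over that
    depth); the rest go through it and are scattered back into place."""

    def __init__(self, layers: nn.ModuleList, skip_ratio: float = 0.5):
        super().__init__()
        self.layers = layers          # the contiguous block that can be skipped
        self.skip_ratio = skip_ratio  # fraction of tokens bypassing the block

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        B, N, D = tokens.shape
        n_keep = max(1, int(N * (1.0 - self.skip_ratio)))

        # Per-sample random choice of which tokens get processed.
        order = torch.rand(B, N, device=tokens.device).argsort(dim=1)
        keep_idx = order[:, :n_keep].unsqueeze(-1).expand(-1, -1, D)
        kept = torch.gather(tokens, 1, keep_idx)

        for layer in self.layers:
            kept = layer(kept)

        # Merge: processed tokens return to their positions, the rest pass through.
        out = tokens.clone()
        out.scatter_(1, keep_idx, kept)
        return out
```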
Practical takeaway
If you work at high resolution or with many tokens, routing/sparsification can be the biggest efficiency multiplier; at low resolution it can hurt quality.
4) Data and captions: the supervisor that governs the task
- Long captions vs short captions: long, descriptive captions speed up and improve learning. Why? More tokens = more signal. For training, rich prompts reduce ambiguity and keep the model from averaging over many plausible solutions. Clear result: switching to short captions hurts convergence badly (FID 18.2 -> 36.84 in their experiment).
Tip: train with long captions, then finish with mixed fine-tuning on long and short captions so the model also works well with brief prompts (a small sketch of that mixing rule follows this list).
- Synthetic vs real: synthetic data (MidJourney) helped structure and compositional coherence early, while real images improved FID relative to real references (texture and high-frequency detail). Conclusion: use synthetic data to bootstrap structure and real data to polish textures.
- Supervised fine-tuning with small, curated datasets (Alchemist, 3.3k pairs) can add a "style layer" of photographic polish in a short pass.
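If you want to operationalize the long-then-mixed caption tip from the first bullet, a trivial sampling rule is enough for the fine-tuning phase; the 30% short-caption share below is purely illustrative, not a number from the post.

```python
import random

def pick_caption(long_caption: str, short_caption: str, p_short: float = 0.3) -> str:
    """Mostly feed the long, descriptive caption, but sample the short one a
    fraction of the time so the model stays usable with brief prompts."""
    return short_caption if random.random() < p_short else long_caption
```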
Operational details that hurt if you ignore them
- Optimizer: switching from AdamW to Muon showed a tangible improvement in early quality (FID 18.2 -> 15.55). Optimizer choice matters: it's not just stability, it's time-to-quality.
- BF16: using autocast BF16 is fine, but saving parameters in BF16 was an expensive bug. Weights stored in BF16 degraded FID to 21.87. Rule of thumb: autocast for compute, but keep model weights and optimizer state in FP32.
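A minimal sketch of that rule of thumb in PyTorch: parameters (and therefore optimizer state) stay in FP32, and autocast handles the BF16 compute. `loader` and `compute_loss` are placeholders, not names from the post.

```python
import torch

model = model.float()                      # weights stored in FP32 (model assumed to exist)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

for batch in loader:                       # placeholder dataloader
    optimizer.zero_grad(set_to_none=True)
    # BF16 only for the forward/backward compute, never for parameter storage.
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        loss = compute_loss(model, batch)  # hypothetical loss function
    loss.backward()                        # gradients land on the FP32 parameters
    optimizer.step()
```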
Technical summary and recommendations for your next run
- Representation alignment (REPA) = best early accelerator. Use it as a burn-in and then turn it off.
- Designing latents (REPA-E, Flux2-AE) gives the largest quality gain per step, with clear throughput trade-offs; REPA-E is more balanced, Flux2-AE performs better but is costly.
- Objectives: contrastive FM is a low-cost win for conditioning; x-pred enables training high-res without a tokenizer and is key if you want to skip the VAE.
- Token routing: marginal at low resolution, critical and beneficial at high resolution with JiT/x-pred.
- Data and captions: long captions speed learning; mix synthetic for structure and real for texture; a small curated SFT can polish outputs.
- Ops: try modern optimizers (Muon) and watch weight precision (don’t store in BF16).
If you’re thinking of replicating any of this or tuning a pipeline, start by controlling the baseline and add one lever at a time: alignment early, then latents, then x-pred if you want 1024x1024, and finally token routing once you already have many tokens.
What's next?
The authors announce they will publish the full PRX framework code and run a public 24-hour speedrun combining the best of these ideas. Interested in experimenting quickly? That will be a great starting point to reproduce the ablations and see what works in your own setup.
