LinkedIn comparte un recorrido práctico para lograr que GPT-OSS funcione como backbone en entrenamiento agentic RL: no es solo ajustar pesos, sino resolver una pila de incompatibilidades entre MoE, attention kernels, paralelismo y la infraestructura de inferencia.
¿Qué es agentic RL y por qué importa?
Agentic reinforcement learning extiende el RL clásico de respuestas en una sola vuelta a políticas que planifican y actúan a lo largo de trayectorias multi-paso. En vez de optimizar una sola respuesta, entrenas al modelo para que seleccione consultas, llame herramientas, observe resultados y ajuste su conducta en función de recompensas que dependen de decisiones a largo plazo.
¿Y por qué nos importa? Porque aplicaciones reales (reclutamiento, asistentes con herramientas, flujos multi-step) requieren esta capacidad de adaptación: recuperar información, refinar peticiones y coordinar herramientas en serie, no solo producir un texto bonito.
Problemas que surgieron al intentar agentic RL con GPT-OSS
Durante los experimentos con verl y GPT-OSS-20B aparecieron fallos graves:
- KL y entropía explotaban, y las recompensas no subían.
- Valores de clipping por importancia muestral no nulos aun siendo on-policy (PPO requiere ratio = 1).
- Gradientes explotando y training-inference mismatch entre FSDP y motores de inferencia (vLLM / SGLang).
- Inestabilidad ligada al manejo de attention sinks y kernels de FlashAttention.
- OOM inesperados por materialización de estados en la ruta de MoE durante cómputo de log-probs.
Estos síntomas no eran ruido: indicaban interacciones sutiles entre la arquitectura MoE, la implementación del attention sink, y las diferencias entre caminos de inferencia y de entrenamiento.
Diagnóstico técnico (resumen con lo esencial)
1) Dual forward pass en MoE y ratio de importancia
El cómputo del old_log_prob se hacía en una pasada separada del log_prob actual. En modelos MoE la red de gating puede enrutar a expertos distintos entre ambas pasadas por pequeñas diferencias numéricas o estocasticidad. Eso cambia log_prob y genera un ratio distinto de 1, activando el clipping de PPO y rompiendo la premisa on-policy.
Solución lógica aplicada:
if on_policy:
old_log_prob = log_prob.detach()
else:
old_log_prob = model_inputs["old_log_probs"]
Al forzar old_log_prob igual al log_prob recién calculado (con detach() para evitar flujo de gradiente) se restaura la integridad on-policy.
2) Attention sinks: forward/backward y mismatch entrenamiento-inferencia
Los attention sinks son parámetros escalares por cabeza que absorben parte de la masa del softmax como "token virtual". En la implementación usada:
- El sink participaba en la normalización softmax pero no contribuía al output (no multiplicaba V).
- El backward no estaba soportado en FlashAttention v2/v3 upstream, por lo que el gradiente del sink quedaba mal definido y la fase de entrenamiento divergía.
Al implementar el backward del sink (reutilizando el forward adaptado de la fork de vLLM y añadiendo el gradiente del sink) se corrigió un gran desajuste token-level entre inferencia y entrenamiento, mejorando la estabilidad.
3) Materialización de MoE y OOM
Al calcular log-probs bajo FSDP, se activó la ruta de inferencia que duplica hidden_states con repeat(num_experts, 1) y realiza operaciones batched que materializan tensores gigantes. Resultado: OOM incluso para el GPT-OSS-20B en grandes ventanas de contexto.
La diferencia es que la ruta de entrenamiento procesa expertos en un for-loop (más lenta, pero mucho más eficiente en memoria). La solución fue parchear la materialización y forzar una ruta más eficiente que no duplique estados innecesariamente. Este issue ya ha sido reportado en Hugging Face: https://github.com/huggingface/transformers/issues/40073
4) Rollout correction y training-inference mismatch
Las optimizaciones agresivas en motores de inferencia (p. ej. SGLang con Triton) producen diferencias numéricas frente al stack de entrenamiento (FSDP + FlashAttention-v2). Eso puede convertir una actualización aparentemente on-policy en off-policy y desestabilizar el aprendizaje. Aplicar correcciones de rollout (sequence-level importance sampling) y alinear enrutamiento reduce estas discrepancias y estabiliza gradientes.
Soluciones implementadas por el equipo
- Restaurar en PPO la integridad on-policy: sustituir
old_log_probpor ellog_prob.detach()cuando el minibatch es on-policy. - Implementar backward del attention sink en FlashAttention v3 (basado en fork de vLLM), corrigiendo el cómputo de gradientes y el mismatch token-level.
- Parchear la ruta de MoE en Transformers para evitar la materialización de todos los expertos y reducir el uso de memoria durante
compute_log_prob. - Integrar sequence parallelism compatible con attention sinks y FlashAttention v3 para escalar ventanas largas de contexto (max response 16k, prompt 8k) sin explotar memoria.
- Mantener rollout correction (sequence-level IS) y pruebas de aislamiento (ej. congelar capas de atención) para localizar cuellos de botella.
Resultados experimentales resumidos
Tras los arreglos, los entrenamientos con GPT-OSS-20B mostraron:
- Normas de gradiente estables y sin explosiones.
- Mejor y más rápido aumento de recompensa en GSM8K (single-turn) y en tareas agentic como ReTool.
- Mejora en tareas de instruction following verificable (VerifyIf) y en evaluaciones fuera de dominio.
- Validación de que la corrección del sink y las optimizaciones de memoria son necesarias para que GPT-OSS converja a ritmos comparables a variantes densas y modelos benchmark como Qwen-2.5-32B en tendencias métricas durante RL.
En resumen: después de los fixes, el entrenamiento agentic RL sobre GPT-OSS es estable y utilizable como backbone para agentes multi-step.
Consejos prácticos para equipos que quieran reproducir o extender esto
- Si trabajas con MoE y PPO, comprueba que
old_log_probylog_probprovengan exactamente de la misma pasada o sustituyeold_log_proben modo on-policy. - Verifica compatibilidad entre kernels de atención en entrenamiento e inferencia: mismatches en
log-pplindican problemas serios. - Activa rollout correction (sequence-level IS) para mitigar training-inference mismatch.
- Implementa sequence parallelism y evita rutas que materialicen todos los expertos simultáneamente si tienes ventanas largas.
- Añade pruebas unitarias para attention sinks (forward+backward) y registra rutas de enrutamiento en MoE para detectar no determinismo.
Reflexión final
La lección es clara: llevar un LLM open-source al mundo agentic no es solo "más datos y más GPU". Requiere inspeccionar en detalle la interacción entre arquitectura (MoE y sinks), kernels de atención, y la topología de paralelismo. Un cambio pequeño en la ruta de inferencia o en la implementación de un sink puede convertir aprendizajes estables en colapsos rápidos.
Si estás construyendo agentes con LLMs abiertos, planifica tiempo para engineering profundo: asegurarte de comportamiento determinista en MoE, soporte backward completo para mecanismos especiales como sinks, y rutas de ejecución memoria-eficientes serán siempre inversiones que pagan con estabilidad y escalabilidad.
