MedQA: afina IA clínica en AMD ROCm sin CUDA | Keryc
MedQA demuestra que puedes entrenar y desplegar una IA clínica capaz de responder preguntas de examen con explicaciones clínicas, usando solamente hardware AMD y ROCm. ¿La sorpresa? No necesitas CUDA ni trucos mágicos de cuantización cuando tienes una GPU como la MI300X.
Qué es MedQA y por qué importa
MedQA es un adaptador LoRA finamente ajustado sobre Qwen3-1.7B para responder preguntas médicas de elección múltiple y además justificar la respuesta con razonamiento clínico. El objetivo no es reemplazar un diagnóstico médico, sino ofrecer respuestas con explicación que sean más útiles y verificables que una letra sin contexto.
Aquí hay tres razones clave por las que esto importa para equipos técnicos y clínicos:
La salida incluye tanto la letra correcta como una explicación clínica, lo que ayuda a auditoría y verificación.
Se entrenó y exportó el adaptador completo en hardware AMD usando ROCm, con cero dependencias CUDA.
El uso de LoRA mantiene el ajuste eficiente: solo ~2.2 millones de parámetros entrenables frente a 1.5B del modelo base.
Hardware: por qué la AMD Instinct MI300X cambia el juego
La MI300X ofrece 192 GB de HBM3 en una sola tarjeta. Para fine-tuning de LLMs, la memoria es a menudo la limitante: determina batch size, longitud de secuencia y si necesitas cuantizar.
Con 192 GB no fue necesario usar cuantización 4-bit ni 8-bit. Eso se traduce en un pipeline más limpio y menos riesgo de artefactos por cuantización. En este proyecto entrenamos Qwen3-1.7B en fp16 con LoRA y tardamos aproximadamente 5 minutos en la MI300X para 2,000 ejemplos.
Si quieres replicarlo en tu máquina ROCm: estos tres variables de entorno fueron suficientes para que el mismo código que corre en CUDA funcione en ROCm:
No se requirieron cambios en el código, kernels personalizados o shims de compatibilidad.
Pipeline técnico: modelo base, LoRA y parámetros de entrenamiento
Resumen rápido del stack:
Base model: Qwen3-1.7B (capaz y relativamente compacto a 1.7B parámetros).
Adaptación: LoRA vía PEFT para inyectar matrices de bajo rango en las capas de atención.
Frameworks: Transformers, PEFT, TRL, Accelerate sobre PyTorch + ROCm 6.1.
Ejemplo de configuración LoRA (concepto):
from peft import LoraConfig, get_peft_model, TaskType
lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
r=8,
lora_alpha=16,
lora_dropout=0.05,
target_modules=['q_proj', 'v_proj'],
bias='none',
)
model = get_peft_model(model, lora_config)
# trainable params ~2.2M of ~1.5B
Parámetros de entrenamiento relevantes:
fp16=True (bfloat16 produjo NaN en pruebas iniciales)
gradient_checkpointing=True para ahorrar memoria
per_device_train_batch_size=4 con gradient_accumulation_steps=4 => batch efectivo 16
optim='adamw_torch', lr=2e-4, scheduler cosine con warmup_ratio=0.05
El dataset usado fue una porción de MedMCQA: 2,000 ejemplos (pregunta, opciones A-D, etiqueta correcta y explicación opcional). La idea fue demostrar que un slice pequeño puede producir mejoras prácticas y explicables en minutos.
Inference y despliegue
Flujo de inferencia resumido:
Cargar tokenizer y modelo base.
Adjuntar el adaptador LoRA con PeftModel.from_pretrained.
Generación con greedy decoding y repetition_penalty para evitar loops.
Puedes descargar el adaptador desde HuggingFace Hub y fusionarlo con el modelo base si quieres un único checkpoint ligero. El adaptador ocupa unos pocos megabytes, no gigabytes.
Resultados, métricas y lecciones aprendidas
Trainable params: ~2.2M (0.15% del total).
Training time en MI300X: ~5 minutos para 2,000 ejemplos.
Dataset usado: 2,000 ejemplos de MedMCQA.
Baseline MedMCQA accuracy reportada: ~45% (referencia de dataset).
Nota: la ausencia de soporte bitsandbytes en ROCm es real, pero con 192 GB de HBM3 no fue un problema para este experimento. Eso simplifica el pipeline.
Qué sigue: escalado y robustez
Los autores sugieren pasos naturales para llevar esto más lejos:
Entrenar en el corpus completo MedMCQA (~180k preguntas) y añadir PubMedQA.
Añadir calibración de confianza para reportar estimaciones de certeza junto a la respuesta.
Integrar RAG para anclar respuestas a literatura médica en tiempo real.
Construir un harness de evaluación con splits de test reales para medir ganancia fuera de muestra.
Reflexión final
MedQA muestra que la barrera técnica de 'solo CUDA' puede romperse. Si tienes hardware AMD ROCm, el ecosistema HuggingFace funciona con pocas adaptaciones, y la gran memoria de la MI300X elimina muchos compromisos de ingeniería. Para proyectos médicos, la prioridad en explicaciones por encima de solo etiquetas es un buen recordatorio: la transparencia importa tanto como la precisión.