PyTorch profiling: de nn.Linear a un MLP fusionado | Keryc
En la segunda entrega de la serie sobre profiling en PyTorch subimos un peldaño: pasamos del par matmul+add manual al uso de nn.Linear, y luego apilamos tres de esas capas con una activación para formar un bloque MLP. ¿Qué cambia en la traza del profiler cuando usamos torch.compile o un kernel manual optimizado? Vamos a verlo con ejemplos concretos y trazas reales.
Qué pasa con un nn.Linear (repaso práctico)
Si recuerdas la primera parte, ya habíamos perfilado torch.add(torch.matmul(x, w), b). Con nn.Linear(in_dim, out_dim, bias=True) la operación es la misma algebraicamente:
y = x @ w.T + b
En la traza del profiler verás algo que parece extraño: aparece aten::t antes de aten::addmm. ¿Es eso una copia costosa en GPU? No. aten::t solo reescribe la metadata del tensor (shape y stride) y no lanza ningún kernel en la GPU. Lo puedes verificar mirando la columna de CUDA en la traza: muestra 0.000us de tiempo en CUDA.
aten::t
Otro punto clave: la suma con el bias no aparece como kernel separado. La razón es que la cuBLAS GEMM tiene una variante con epilogue que incluye el bias, y aten::addmm elige precisamente esa variante. Resultado: un solo kernel cuBLAS hace la multiplicación y el writeback con bias incorporado.
¿Qué hace torch.compile con una sola nn.Linear?
Curioso: al compilar la forward de una sola nn.Linear con torch.compile, la GPU ejecuta exactamente el mismo kernel cuBLAS que en modo eager. Lo que cambia es una pequeña porción de overhead en CPU: Inductor resolvió las vistas (strides) en tiempo de compilación, y elimina la bookkeeping de aten::t en la cadena de dispatch. Esos microsegundos en CPU desaparecen, pero el trabajo pesado en la GPU permanece igual.
Moraleja: para una única GEMM con bias, no esperes que torch.compile te dé un milagro dramático en la GPU. Su ganancia principal aquí es reducir la sobrecarga en CPU.
Subimos la apuesta: perfilando un MLP simple (GeGLU)
El ejemplo usa tres nn.Linear y una variante GeGLU en medio. El forward hace esencialmente:
dos proyecciones paralelas gate_proj y up_proj (para la puerta y la parte lineal)
gelu sobre la puerta
multiplicación punto a punto gelu(g) * u
down_proj para volver a dimensión dim
Si antes de abrir la traza te preguntas "¿cuántos kernels GPU veré?" la respuesta razonable es: 3 GEMMs (una por cada linear) más 2 kernels punto a punto (GeLU y mul) = 5 kernels por forward.
En las trazas reales eso es exactamente lo que vemos. Además notarás que cada GEMM hace una llamada extra de sizing: una consulta de occupancy (cudaOccupancyMaxActiveBlocksPerMultiprocessor) antes del cudaLaunchKernel. Los kernels punto a punto no hacen esa consulta.
¿Qué cambia al compilar el MLP?
Cuando corres --compile, torch.compile (Inductor + Triton backend) hace dos cosas importantes:
El chain de dispatch de cada nn.Linear (transpose, reshape, mm, etc.) se pliega: aparecen llamadas limpias a aten::mm y los kernels cuBLAS siguen siendo byte-por-byte los mismos que en eager.
La gran ganancia viene de fusionar las operaciones punto a punto: reshape + gelu + mul colapsan en un único kernel Triton. En la traza aparece un kernel con nombre tipo triton_poi_fused__unsafe_view_gelu_mul_0. ¿Por qué importa? Porque el tensor intermedio h = gelu(g) deja de escribirse y leerse desde la memoria global (HBM). En lugar de eso, la fusión mantiene datos en registros o en memoria rápida del chip, eliminando una ronda completa de tráfico a HBM.
Esa eliminación del viaje intermedio a memoria es la razón principal de la mejora en tiempos, aunque los GEMMs sigan iguales.
Kernels escritos a mano: Liger y sus compromisos
Otra ruta es usar kernels precompilados y tunearlos manualmente. El ejemplo usa LigerGEGLUMLP desde la librería kernels del Hub de Hugging Face. Ventajas claras:
La fusión ya viene baked-in: un kernel Triton calcula gelu(g) * u en una sola pasada.
Los parámetros de lanzamiento (block size, etc.) están elegidos para la arquitectura, evitando la consulta de occupancy y la recompilación dinámica.
El paquete kernels entrega binarios precompilados versionados, así el build no depende de tu toolchain local.
Pero hay un trade-off real: en la traza comparativa el kernel Inductor (generado por torch.compile y especializado en la forma exacta [8192, 3072]) puede ser ligeramente más rápido que el kernel manual Liger. Sin embargo, Inductor gana esos microsegundos porque está especializado para la forma exacta a través de re-tracing y recompilación. Si tus formas cambian (batch, seq, hidden), Inductor re-traza y vuelve a pagar costo de compilación. Liger, en cambio, es robusto a formas variables sin recompilación.
Entonces la elección no es simplemente "humano lento vs compilador rápido". Es: ¿quieres un kernel extremadamente rápido para una forma fija (con coste de compilación por cambio), o un kernel robusto y ya compilado para muchas formas?
Resumen práctico: tabla mental de cambios
Eager nn.Linear: bias foldado en la GEMM (epilogue), una sola llamada cuBLAS por linear.
Compiled nn.Linear: desaparece bookkeeping de vistas en CPU; GEMM en GPU igual.
Eager MLP: 5 kernels (3 GEMM + GeLU + mul) y un gran tensor intermedio via HBM.
Compiled MLP: los punto a punto se fusionan (Triton), el intermedio se queda en registros; GEMMs iguales.
Liger MLP: misma fusión que Inductor pero sin guards ni recompilación; parámetros de lanzamiento afinados y binarios preconstruidos.
Buenas prácticas cuando perfiles en PyTorch
Antes de abrir la traza, adivina lo que esperas ver. ¿Cuántos kernels? ¿Hay operaciones que deberían fusionarse? Eso te obliga a leer la traza para confirmar o para aprender por qué algo falla.
Aprende a leer los nombres de kernel en la traza. Un sufijo como _tn_ o _nn_ te dice el layout (transposed o no). Nombres byte-for-byte iguales significan trabajo GPU idéntico.
Diferencia tiempos CPU vs CUDA. Muchas optimizaciones reducen overhead en CPU (dispatch, vistas), otras reducen tráfico a HBM (fusión). Ambos importan dependiendo del cuello de botella.
Si necesitas robustez a formas variables, considera kernels tunados y precompilados (como Liger). Si tus formas son estables y buscas cada microsegundo, torch.compile con Inductor puede ganar por especialización.
Reflexión final
Profiling deja de ser un acto esotérico cuando lo conviertes en un hábito: hacer una hipótesis, mirar la traza y cerrar la brecha entre intuición y evidencia. Aquí aprendimos que nn.Linear ya es eficiente por diseño (bias en epilogue), que torch.compile brilla al fusionar punto a punto, y que los kernels manuales ofrecen robustez sin depender de re-tracing. ¿Cuál usar? Depende de tus formas, tolerancia a recompilaciones y cuánto valoras la portabilidad del binario.