PyTorch profiling: from nn.Linear to a fused MLP | Keryc
In this second installment of the series on profiling in PyTorch, we go one step further: we move from the manual matmul + add pair to nn.Linear, and then stack three of those layers with an activation to form an MLP block. What changes in the profiler trace when we use torch.compile or a hand-optimized kernel? Let’s find out with concrete examples and real traces.
What happens with an nn.Linear (a practical recap)
In the first part, we profiled torch.add(torch.matmul(x, w), b). With nn.Linear(in_dim, out_dim, bias=True) the algebra is the same; the only difference is that the layer stores its weight with shape (out_dim, in_dim), so it multiplies by the transpose:
y = x @ w.T + b
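A quick numerical check of that identity, as a minimal sketch on CPU with arbitrary, hypothetical sizes. torch.addmm(b, x, w.t()) is the fused form that appears as aten::addmm in the trace; the tolerance allows for rounding differences between one fused call and a separate matmul followed by an add.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes, for illustration only.
in_dim, out_dim, batch = 64, 32, 8
linear = nn.Linear(in_dim, out_dim, bias=True)
x = torch.randn(batch, in_dim)

w, b = linear.weight, linear.bias  # w has shape (out_dim, in_dim)

with torch.no_grad():
    y_module = linear(x)
    y_manual = x @ w.T + b              # the algebra from the text
    y_addmm = torch.addmm(b, x, w.t())  # the fused op seen as aten::addmm

print(torch.allclose(y_module, y_manual, atol=1e-5))
print(torch.allclose(y_module, y_addmm, atol=1e-5))
```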
In the profiler trace you’ll see something that looks odd: aten::t appears before aten::addmm. Is that an expensive copy on the GPU? No. aten::t only rewrites the tensor’s metadata (shape and strides) and launches no GPU kernel. You can verify this in the CUDA column of the trace: it shows 0.000 us of CUDA time.
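The metadata-only nature of the transpose is easy to confirm outside the profiler. This sketch uses an arbitrary (32, 64) tensor: the transposed view shares storage with the original and only swaps the strides.

```python
import torch

w = torch.randn(32, 64)
wt = w.t()  # what aten::t does: build a new view, no data movement

print(w.shape, w.stride())    # torch.Size([32, 64]) (64, 1)
print(wt.shape, wt.stride())  # torch.Size([64, 32]) (1, 64)

# Same underlying storage: no copy was made.
print(wt.data_ptr() == w.data_ptr())  # True
# The view is merely non-contiguous; a copy only happens if something requests one.
print(wt.is_contiguous())             # False
```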
Another key point: the bias addition does not show up as a separate kernel. cuBLAS offers GEMM variants with an epilogue that adds the bias, and aten::addmm selects one of them. The result is a single cuBLAS kernel that performs the multiplication and writes back the output with the bias already added.
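To see the aten::t and aten::addmm chain yourself, here is a minimal sketch using torch.profiler with hypothetical layer sizes. It always records CPU activity and adds CUDA activity only when a GPU is available. On CPU you will see the same aten::t and aten::addmm operators, but the 0.000 us CUDA column and the single cuBLAS kernel can only be checked on a GPU run.

```python
import torch
import torch.nn as nn
from torch.profiler import profile, ProfilerActivity

linear = nn.Linear(64, 32, bias=True)
x = torch.randn(8, 64)

activities = [ProfilerActivity.CPU]
if torch.cuda.is_available():
    activities.append(ProfilerActivity.CUDA)
    linear, x = linear.cuda(), x.cuda()

with torch.no_grad(), profile(activities=activities) as prof:
    linear(x)

# Operator names recorded during the forward pass.
op_names = [evt.key for evt in prof.key_averages()]
print(op_names)
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
```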
What does torch.compile do with a single nn.Linear?
Curiously, when you compile the forward pass of a single nn.Linear with torch.compile, the GPU runs exactly the same cuBLAS kernel as in eager mode. What changes is a small amount of CPU overhead: Inductor resolves the views (strides) at compile time, which removes the aten::t bookkeeping from the dispatch chain. Those microseconds of CPU time disappear, but the heavy GPU work stays the same.
Moral: for a single GEMM with bias, don’t expect torch.compile to work miracles on the GPU. Its main gain here is lower CPU overhead.
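One way to check the CPU-overhead claim is to time per-call latency on a deliberately small layer, where dispatch cost is a large share of the total. This sketch uses torch.utils.benchmark.Timer, which synchronizes CUDA before reading the clock when a GPU is present; the helper name mean_latency_us, the layer sizes, and the 0.2 s minimum run time are arbitrary choices.

```python
import torch
import torch.nn as nn
import torch.utils.benchmark as benchmark

def mean_latency_us(fn, x, min_run_time=0.2):
    """Mean wall-clock time per call of fn(x), in microseconds."""
    with torch.no_grad():
        fn(x)  # warm-up; for a torch.compile'd fn this first call also compiles
        timer = benchmark.Timer(stmt="fn(x)", globals={"fn": fn, "x": x})
        return timer.blocked_autorange(min_run_time=min_run_time).mean * 1e6

# Hypothetical tiny layer: the GEMM is cheap, so dispatch overhead dominates.
linear = nn.Linear(64, 32)
x = torch.randn(8, 64)
eager_us = mean_latency_us(linear, x)
print(f"eager nn.Linear: {eager_us:.1f} us/call")
```

To compare, call mean_latency_us(torch.compile(linear), x) in the same process. Keep expectations modest: according to the trace analysis in this section, only the CPU-side dispatch should shrink, while the GPU kernel time stays the same.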
Upping the ante: profiling a simple MLP (GeGLU)
The example uses three nn.Linear layers with a GeGLU variant in the middle. The forward pass essentially does the following:
two parallel projections, gate_proj and up_proj (for the gate and the linear part)
gelu on the gate
pointwise multiplication gelu(g) * u
down_proj to go back to dimension dim
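The forward pass described above can be sketched as a small nn.Module. The class name, the dimension names dim and hidden_dim, bias=False, and the exact (erf-based) GeLU are assumptions for illustration; F.gelu(g, approximate="tanh") is the variant to use if your model relies on the tanh approximation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GeGLUMLP(nn.Module):
    """Minimal GeGLU MLP sketch (hypothetical; not the article's exact module)."""

    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        g = self.gate_proj(x)          # GEMM 1
        u = self.up_proj(x)            # GEMM 2
        h = F.gelu(g)                  # pointwise kernel 1 (GeLU)
        return self.down_proj(h * u)   # pointwise kernel 2 (mul), then GEMM 3

mlp = GeGLUMLP(dim=64, hidden_dim=256)
y = mlp(torch.randn(8, 64))
print(y.shape)  # torch.Size([8, 64])
```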
If, before opening the trace, you ask yourself how many GPU kernels you will see, a reasonable answer is 3 GEMMs (one per linear layer) plus 2 pointwise kernels (GeLU and mul), for a total of 5 kernels per forward pass.
The real traces show exactly that. You’ll also notice that each GEMM makes an extra sizing call, an occupancy query (cudaOccupancyMaxActiveBlocksPerMultiprocessor), before its cudaLaunchKernel. The pointwise kernels don’t make that query.
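You can test the 5-kernel hypothesis by counting operator calls with torch.profiler. This minimal sketch runs on CPU with hypothetical sizes and bias=False. With that setting, each nn.Linear on a 2-D input should go through aten::matmul to aten::mm rather than aten::addmm. On a GPU run, the equivalent check is to count the kernel rows in the CUDA part of the trace.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.profiler import profile, ProfilerActivity

dim, hidden = 64, 256  # hypothetical sizes
gate_proj = nn.Linear(dim, hidden, bias=False)
up_proj = nn.Linear(dim, hidden, bias=False)
down_proj = nn.Linear(hidden, dim, bias=False)

def mlp(x):
    return down_proj(F.gelu(gate_proj(x)) * up_proj(x))

x = torch.randn(8, dim)
with torch.no_grad(), profile(activities=[ProfilerActivity.CPU]) as prof:
    mlp(x)

# How many times each operator was called in one forward pass.
counts = {evt.key: evt.count for evt in prof.key_averages()}
# Hypothesis from the text: 3 GEMMs + 1 GeLU + 1 mul.
print(counts.get("aten::mm"), counts.get("aten::gelu"), counts.get("aten::mul"))
```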
What changes when compiling the MLP?
When you run with --compile, torch.compile (Inductor + Triton backend) does two important things:
The dispatch chain of each nn.Linear (transpose, reshape, mm, etc.) is folded: clean calls to aten::mm appear and the cuBLAS kernels are byte-for-byte the same as in eager.
The big win comes from fusing the pointwise operations: reshape + gelu + mul collapse into a single Triton kernel, which shows up in the trace with a name like triton_poi_fused__unsafe_view_gelu_mul_0. Why does that matter? Because the intermediate tensor h = gelu(g) is no longer written to and read back from global memory (HBM). The fused kernel keeps each value in registers or on-chip memory between the GeLU and the multiplication, eliminating a full round trip to HBM.
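The saving can be estimated with simple arithmetic. This sketch counts only the HBM traffic of the pointwise stage, reuses the [8192, 3072] shape quoted in this article’s Liger comparison, and assumes 2-byte (bf16) activations. It ignores caches and the GEMMs, which are identical in both cases.

```python
def pointwise_hbm_bytes(rows, cols, bytes_per_elem, fused):
    """Rough HBM traffic for gelu(g) * u over [rows, cols] tensors, ignoring caches."""
    tensor_bytes = rows * cols * bytes_per_elem
    if fused:
        # One kernel: read g, read u, write the product.
        trips = 3
    else:
        # GeLU kernel: read g, write h. Mul kernel: read h, read u, write the product.
        trips = 5
    return trips * tensor_bytes

rows, cols = 8192, 3072  # shape from the comparative trace
bf16 = 2                 # assumption: bf16 activations
eager = pointwise_hbm_bytes(rows, cols, bf16, fused=False)
fused = pointwise_hbm_bytes(rows, cols, bf16, fused=True)
MiB = 2**20
print(f"eager: {eager / MiB:.0f} MiB, fused: {fused / MiB:.0f} MiB, saved: {(eager - fused) / MiB:.0f} MiB")
# eager: 240 MiB, fused: 144 MiB, saved: 96 MiB
```

The 96 MiB difference is exactly the write and the read of h, which is the round trip that fusion removes.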
That elimination of the intermediate memory trip is the main reason for the timing improvement, even though the GEMMs remain the same.
Hand-written kernels: Liger and its trade-offs
Another route is to use precompiled, hand-tuned kernels. The example uses LigerGEGLUMLP from the kernels library on the Hugging Face Hub. Its advantages are clear:
Fusion is already baked-in: a Triton kernel computes gelu(g) * u in one pass.
Launch parameters (block size, etc.) are chosen for the architecture, avoiding the occupancy query and dynamic recompilation.
The kernels package delivers versioned precompiled binaries, so the build doesn’t depend on your local toolchain.
But there’s a real trade-off. In the comparative trace, the Inductor kernel generated by torch.compile, specialized to the exact shape [8192, 3072], can be slightly faster than the hand-written Liger kernel. Inductor earns those microseconds through specialization: it traces and compiles for that exact shape, so if your shapes change (batch, sequence length, hidden size), it re-traces and you pay the compilation cost again. Liger, by contrast, handles variable shapes without recompiling.
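The recompilation cost can be made visible with a tiny custom Dynamo backend that only counts how often it is invoked. This is a sketch: counting_backend and geglu are hypothetical names, the backend returns the captured FX graph unchanged instead of running Inductor, and dynamic=False forces one specialization per input shape, mirroring the shape-specialized kernel described above. With the default dynamic=None, PyTorch may instead recompile once with dynamic shapes after the first shape change and reuse that graph afterwards. Setting TORCH_LOGS=recompiles logs the reason for each recompilation.

```python
import torch
import torch._dynamo
import torch.nn.functional as F

compile_count = 0

def counting_backend(gm: torch.fx.GraphModule, example_inputs):
    """Custom Dynamo backend: count compilations, then run the captured graph as-is."""
    global compile_count
    compile_count += 1
    return gm.forward

def geglu(g, u):
    return F.gelu(g) * u

torch._dynamo.reset()  # start from an empty compile cache
fn = torch.compile(geglu, backend=counting_backend, dynamic=False)

for rows in (8, 16, 32, 8):  # three distinct shapes, then a repeat
    g, u = torch.randn(rows, 64), torch.randn(rows, 64)
    fn(g, u)

print(compile_count)  # one compilation per distinct shape; the repeated 8 hits the cache
```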
So the choice isn’t simply “human slow vs compiler fast.” It’s: do you want an extremely fast kernel for a fixed shape (with a compilation cost per change), or a robust prebuilt kernel for many shapes?
Practical summary: a mental table of changes
Eager nn.Linear: bias folded into the GEMM (epilogue), a single cuBLAS call per linear.
Compiled nn.Linear: view bookkeeping on CPU disappears; GEMM on GPU unchanged.
Eager MLP: 5 kernels (3 GEMM + GeLU + mul) and a large intermediate tensor via HBM.
Compiled MLP: pointwise ops fuse (Triton), the intermediate stays in registers; GEMMs unchanged.
Liger MLP: same fusion as Inductor but without guards or recompilation; tuned launch params and prebuilt binaries.
Good practices for profiling in PyTorch
Before opening the trace, predict what you expect to see. How many kernels? Which operations should fuse? That forces you to read the trace actively, either to confirm your prediction or to learn why it was wrong.
Learn to read kernel names in the trace. A suffix like _tn_ or _nn_ tells you the layout (transposed or not). Byte-for-byte identical names mean identical GPU work.
Differentiate CPU vs CUDA times. Many optimizations cut CPU overhead (dispatch, views), others reduce HBM traffic (fusion). Both matter depending on the bottleneck.
If you need robustness to variable shapes, consider tuned precompiled kernels (like Liger). If your shapes are stable and you want every microsecond, torch.compile with Inductor can win by specialization.
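The CPU-versus-CUDA habit can be packaged into a small helper. This is a sketch: profile_forward is a hypothetical name, it profiles a single forward pass after one warm-up call, and it adds CUDA activity only when the input lives on a GPU. The returned table shows CPU time columns and, on GPU, device time columns side by side; passing trace_path writes a Chrome-format trace that Perfetto can open.

```python
import torch
import torch.nn as nn
from torch.profiler import profile, ProfilerActivity

def profile_forward(model, x, sort_by="self_cpu_time_total", row_limit=10, trace_path=None):
    """Profile one forward pass of model(x) and return the operator summary table."""
    activities = [ProfilerActivity.CPU]
    if x.is_cuda:
        activities.append(ProfilerActivity.CUDA)
    with torch.no_grad():
        model(x)  # warm-up so one-time setup doesn't pollute the trace
        with profile(activities=activities) as prof:
            model(x)
            if x.is_cuda:
                torch.cuda.synchronize()  # make sure GPU work is inside the window
    if trace_path is not None:
        prof.export_chrome_trace(trace_path)
    return prof.key_averages().table(sort_by=sort_by, row_limit=row_limit)

print(profile_forward(nn.Linear(64, 32), torch.randn(8, 64)))
```

On a GPU, sorting by a device-time column instead of self CPU time answers the opposite question; check which column names your PyTorch version exposes, since they have changed across releases.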
Final thoughts
Profiling stops being an esoteric act when you make it a habit: form a hypothesis, look at the trace, and close the gap between intuition and evidence. Here we learned that nn.Linear is already efficient by design (bias in the epilogue), that torch.compile shines at fusing pointwise ops, and that manual kernels provide robustness without relying on re-tracing. Which should you use? It depends on your shapes, tolerance for recompilation, and how much you value binary portability.