Making OpenAI Triton easier π± π
Published:
I find writing triton kernels involves many repetitive tasks, that can be cleanly abstracted away. This allows to write triton code much more in line with how I actually think. Itβs way more fun, and less mentally draining.
So I made triton_util, which aims to be a fastcore-like plug & play addon to OpenAI triton.
With triton util you can e.g. write
tu.load_2d(ptr, sz0, sz1, n0, n1, max0, max1, stride0)
instead of
offs0 = n0 * sz0 + tl.arange(0, sz0)
offs1 = n1 * sz1 + tl.arange(0, sz1)
offs = offs0[:,None] * stride0 + offs1[None,:] * stride1
mask = (offs0[:,None] < max0) & (offs1[None,:] < max1)
tl.load(ptr + offs, mask)
Also, I find the breakpoint_once()
and print_once()
functions incredibly handy when debugging triton kernels. They breakpoint / print only in one kernel (i.e., pid=(0,0,0)
)
You can install it via pip install triton-util
. I suggest importing it like import triton_util as tu
.
I also like to abbreviate the debugging functions like so: breakpoint_once, print_once, breakpoint_if, breakpoint_if = tu.breakpoint_once, tu.print_once, tu.breakpoint_if, tu.breakpoint_if
.
Check triton_util out!
And let me know what you think via Twitter or umer.hayat.adil@gmail.com.
- Umer