Demystify OpenAI Triton
"We're releasing Triton 1.0, an open-source Python-like programming language which enables researchers with no CUDA experience to write highly efficient GPU code — most of the time on par with what an expert would be able to produce." (from OpenAI's Triton announcement)
Triton has two main goals:
Program GPUs with Python: minimize the effort needed to program a GPU device;
Encapsulate optimizations within an SM: the user focuses on partitioning the job and scheduling it across SMs, while Triton is responsible for automatically optimizing the code within each SM.
Triton turns Python code into Triton IR on the fly via the Python AST, then optimizes and lowers it to LLVM IR / MLIR, generates PTX directly through libLLVM, and finally compiles it to a cubin with ptxas.
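For reference, those intermediate artifacts can be inspected from the handle returned by a kernel launch. This is a sketch assuming a Triton 2.x-style compiled-kernel object with an .asm dict (the exact keys vary across versions); my_kernel, grid and the arguments are placeholders:
handle = my_kernel[(grid,)](x, y, n, BLOCK_SIZE_N=128)  # the first launch also compiles the kernel
print(handle.asm.keys())          # typically: 'ttir', 'llir', 'ptx', 'cubin'
print(handle.asm['ptx'][:400])    # the PTX that ptxas assembles into the cubin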
How well does Triton deliver on these goals?
Program GPUs with Python: pretty successful; we can write a functional GPU kernel with Triton quite fast.
Encapsulate optimizations within an SM: I would say Triton only achieves about 60% of this goal. Users usually can't generate an efficient kernel directly. As one concrete example, the very first version of a GroupNorm kernel we wrote in Triton was functionally correct, but not efficient unless you know what Triton did to the code.
Easy to use: writing a kernel takes far less time than with CUDA;
Performant: Triton can generate kernels with performance comparable to those of a skilled CUDA programmer, provided the user knows how to tune the Triton code;
Hard to debug: the whole optimization pipeline is a black box, so the user has to go through the PTX/IR to understand issues;
Limitations: some kernels can't be implemented in Triton due to its restrictions, e.g. tile sizes must be powers of 2 and slicing isn't supported (see the sketch after this list);
Fast prototyping: we can experiment with various ideas in Triton and iterate very fast, then re-implement the idea in CUDA once it's nailed down;
New kernels/operators: we can use Triton to generate kernels/operators that haven't been implemented in cuBLAS/cuDNN/frameworks.
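To make the tile-size limitation above concrete, here is a minimal sketch of my own (not from the original post): tl.arange() requires a power-of-two extent, so a 100-wide tile is rejected when the kernel is first compiled, while 128 is fine. The exact error message depends on the Triton version.
import triton
import triton.language as tl

@triton.jit
def kernel_bad_tile(x_ptr, y_ptr):
    offsets = tl.arange(0, 100)  # 100 is not a power of 2 -> Triton rejects this at compile time
    tl.store(y_ptr + offsets, tl.load(x_ptr + offsets))

@triton.jit
def kernel_ok_tile(x_ptr, y_ptr):
    offsets = tl.arange(0, 128)  # power of 2 -> accepted
    tl.store(y_ptr + offsets, tl.load(x_ptr + offsets))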
Compared to hand-written CUDA kernels, where do Triton's performance benefits mainly come from?
Triton automates a bunch of optimizations that a CUDA developer may not be aware of;
Kernels in libraries like cuBLAS have generality requirements; Triton can drop them on demand and generate simpler code, e.g. the bias addition in cuBLAS GEMM;
Triton kernels can be auto-tuned (see the sketch after this list);
LLVM sometimes generates better PTX than NVCC, e.g. for loop unrolling;
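On the auto-tuning point, Triton provides a decorator that benchmarks a list of candidate configs on the first launch and caches the best one per key. A minimal sketch (the config values here are arbitrary examples, not tuned numbers):
import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_N': 128}, num_warps=4),
        triton.Config({'BLOCK_SIZE_N': 256}, num_warps=8),
        triton.Config({'BLOCK_SIZE_N': 1024}, num_warps=8),
    ],
    key=['n'],  # re-benchmark whenever n changes
)
@triton.jit
def copy_autotuned(x_ptr, y_ptr, n, BLOCK_SIZE_N: tl.constexpr):
    idx = tl.program_id(0)
    offsets = idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    mask = offsets < n
    tl.store(y_ptr + offsets, tl.load(x_ptr + offsets, mask=mask), mask=mask)

# The grid is a lambda so it can depend on the BLOCK_SIZE_N the autotuner picks:
# copy_autotuned[lambda meta: (triton.cdiv(n, meta['BLOCK_SIZE_N']),)](x, y, n)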
Optimizations can be abstracted/encapsulated, or even automated: this accelerates the CUDA kernel development flow and saves a lot of developers' time;
Rapid iteration is important: a developer with minimal GPU skills can write a fairly good kernel after a couple of iterations.
Setup:
As a start, we use a simple copy kernel to show the basic Triton-to-CUDA mapping.
import torch
import triton
import triton.language as tl
torch.manual_seed(0)
@triton.jit
def kernel_230423_01(
x_ptr,
y_ptr,
n,
BLOCK_SIZE_N: tl.constexpr,
):
idx = tl.program_id(0)
offsets = idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
mask = offsets < n
x = tl.load(x_ptr + offsets, mask=mask)
tl.store(y_ptr + offsets, x, mask=mask)
def launch_kernel_230423_01():
n = 1000
BLOCK_SIZE_N = 128
x = torch.randn((n, ), device='cuda')
y = torch.zeros((n, ), device='cuda')
grid = triton.cdiv(n, BLOCK_SIZE_N)
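The launch itself follows the same pattern as the later examples in this post; a minimal completion (argument passing assumed from the kernel signature above):
    kernel_230423_01[(grid,)](x, y, n, BLOCK_SIZE_N=BLOCK_SIZE_N)
    assert torch.allclose(x, y)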
Reverse-engineering the generated PTX, we get its corresponding CUDA code:
if (gid < n) {
y_ptr[gid] = x_ptr[gid];
}
}
BLOCK_SIZE_N, whose type is tl.constexpr, is embedded in the generated code through constant folding and propagation;
The threads are assigned to tl.arange(0, BLOCK_SIZE_N); it maps directly to threadIdx.x in this case, since the default number of warps is 4 and BLOCK_SIZE_N is 128;
Memory loads and stores with a block of pointers, i.e. tl.load() and tl.store(), are distributed and scheduled across threads. In this example, each thread loads and stores one 32-bit value:
def launch_kernel_230423_02():
n = 1000
BLOCK_SIZE_N = 128
x = torch.randn((n, ), device='cuda')
y = torch.zeros((n, ), device='cuda')
grid = triton.cdiv(n, BLOCK_SIZE_N)
Observations:
The compiled kernel kernel_230423_01 is still launched with 128 threads: passing a different num_warps to an already-compiled function neither triggers re-compilation nor has any effect, so num_warps is a constant captured at the first compilation (see the sketch below). This is different from CUDA, where we can change the number of blocks and threads at launch time;
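To make this concrete, a minimal sketch of the behavior described above (this reflects what is observed with the Triton version used here; newer releases may treat num_warps as part of the compilation cache key):
kernel_230423_01[(grid,)](x, y, n, BLOCK_SIZE_N=128, num_warps=4)  # first launch: compiled with 4 warps (128 threads)
kernel_230423_01[(grid,)](x, y, n, BLOCK_SIZE_N=128, num_warps=8)  # reuses the cached binary; still 128 threads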
Copy kernel_230423_01 to kernel_230423_03a and kernel_230423_03b, then launch kernel_230423_03a with 1 warp and kernel_230423_03b with 16 warps to see what happens when we have fewer or more threads:
@triton.jit
def kernel_230423_03a( # same as kernel_230423_01
x_ptr,
y_ptr,
n,
BLOCK_SIZE_N: tl.constexpr,
):
idx = tl.program_id(0)
offsets = idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
mask = offsets < n
x = tl.load(x_ptr + offsets, mask=mask)
tl.store(y_ptr + offsets, x, mask=mask)
@triton.jit
def kernel_230423_03b( # same as kernel_230423_01
x_ptr,
y_ptr,
n,
BLOCK_SIZE_N: tl.constexpr,
):
idx = tl.program_id(0)
offsets = idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
mask = offsets < n
x = tl.load(x_ptr + offsets, mask=mask)
tl.store(y_ptr + offsets, x, mask=mask)
def launch_kernel_230423_03():
n = 1000
BLOCK_SIZE_N = 128
x = torch.randn((n, ), device='cuda')
y = torch.zeros((n, ), device='cuda')
grid = triton.cdiv(n, BLOCK_SIZE_N)
Reverse-engineering the PTX gives the corresponding CUDA code for them. For kernel_230423_03a (1 warp, 32 threads, so each thread copies 4 elements):
if (gid_base < n) {
#pragma unroll(4)
for (int i = 0; i < 4; i++) {
if (gid_base + i < n) {
y_ptr[gid_base + i] = x_ptr[gid_base + i];
}
}
}
}
And for kernel_230423_03b (16 warps, 512 threads):
if (gid < n) {
y_ptr[gid] = x_ptr[gid];
}
}
Observations:
If the tile size is larger than the thread count, Triton lets each thread process multiple elements; in this specific case Triton unrolls the code;
If the tile size is smaller than the thread count, the redundant threads process the same data and are wasted. The user therefore has to take care of the relation between the number of warps and the tile size (a rule-of-thumb sketch follows after this list);
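A simple rule of thumb that falls out of this (my heuristic, not an official Triton rule): pick num_warps so that the thread count matches the 1-D tile size, capped at the 32-warp CTA limit.
BLOCK_SIZE_N = 128
num_warps = max(1, min(BLOCK_SIZE_N // 32, 32))  # 4 warps (128 threads) for a 128-wide tile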
However, it's worth noticing that Triton uses LLVM, which generates different PTX than NVCC does, and here the PTX generated by LLVM is clearly better than NVCC's (the result may change with different NVCC versions).
Consider an example where we use a for loop to iterate over the tiles, launched with more threads than the tile size:
@triton.jit
def kernel_230424_01(
x_ptr,
y_ptr,
N: tl.constexpr, # static shape
BLOCK_SIZE_N: tl.constexpr,
):
tiles = (N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N
for i in range(tiles):
offsets = i * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
mask = offsets < N
x = tl.load(x_ptr + offsets, mask=mask)
tl.store(y_ptr + offsets, x, mask=mask)
def launch_kernel_230424_01():
n = 1000
BLOCK_SIZE_N = 128
x = torch.randn((n, ), device='cuda')
y = torch.zeros((n, ), device='cuda')
Reverse-engineering the generated PTX code, we get the following CUDA kernel:
#pragma unroll(8)
for (int i = 0; i < 7; ++i) {
output[lane + i * 128] = input[lane + i * 128];
}
We launched 1 CTA with 16 warps (512 threads), and the tile size is 128. However, as the equivalent CUDA code above shows, Triton doesn't schedule threads to parallelize the for loop; it unrolls the loop instead. The extra threads are wasted and only 128 threads are utilized. This is counterintuitive and should be handled carefully. To keep those extra threads busy, we can re-organize the data into a 2-D tile instead:
@triton.jit
def kernel_230424_02(
x_ptr,
y_ptr,
M: tl.constexpr, # static shape
N: tl.constexpr, # static shape
BLOCK_SIZE_M: tl.constexpr,
):
rows = tl.arange(0, BLOCK_SIZE_M)
cols = tl.arange(0, N)
tiles = (M + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M
for i in range(tiles):
# 2D tile, mask
offsets = i * BLOCK_SIZE_M * N + N * rows[:, None] + cols[None, :]
mask = (i * BLOCK_SIZE_M + rows < M)[:, None]
x = tl.load(x_ptr + offsets, mask=mask)
tl.store(y_ptr + offsets, x, mask=mask)
def launch_kernel_230424_02():
m, n = 1000, 32
BLOCK_SIZE_M = 128
x = torch.randn((m, n), device='cuda')
y = torch.zeros((m, n), device='cuda')
assert torch.allclose(x, y)
Reverse-engineering the generated PTX again, the main loop of the equivalent CUDA looks like (excerpt):
x_ptr += offset;
y_ptr += offset;
#pragma unroll(7)
for (int i = 0; i < 7; ++i) {
int *ptr1 = x_ptr + i * 4096;
int *ptr2 = y_ptr + i * 4096;
uint4 var1 = *((uint4 *)ptr1);
uint4 var2 = *((uint4 *)(ptr1 + 2048));
As you can see, the problem size is 1000x32 and the tile size is 128x32, and we launched 16 warps (512 threads). Triton deduces that each thread needs to process 128*32/512 = 8 elements, so it vectorizes the accesses with ld.global.v4.b32 and unrolls twice. Triton also handles the last tile correctly in this case.
Observations:
Triton doesn't schedule threads to parallelize a for loop: every iteration of a for loop written by the user is executed by all threads;
Triton deduces the number of elements to process per thread, and may vectorize and/or unroll if the tile size is larger than the number of threads per CTA;
If the tile size is smaller than the number of threads per CTA, Triton won't spread the redundant threads across the for loop. The user has to re-organize the tile with a higher dimension so that the redundant threads can be fully utilized.
Next, let's look at vectorization. kernel_230423_04 is the same copy kernel, but with a static shape N:
@triton.jit
def kernel_230423_04(
x_ptr,
y_ptr,
N: tl.constexpr, # static shape
BLOCK_SIZE_N: tl.constexpr,
):
idx = tl.program_id(0)
offsets = idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
mask = offsets < N
x = tl.load(x_ptr + offsets, mask=mask)
tl.store(y_ptr + offsets, x, mask=mask)
def launch_kernel_230423_04():
n = 1000
BLOCK_SIZE_N = 128
x = torch.randn((n, ), device='cuda')
y = torch.zeros((n, ), device='cuda')
grid = triton.cdiv(n, BLOCK_SIZE_N)
Observations:
Triton can generate vectorized code with a static shape. The vectorization width depends on what N is a multiple of: for fp32, if N is a multiple of 4 Triton generates ldg.128, and if it is only a multiple of 2 it generates ldg.64. But if N is not a multiple of 2, e.g. 1001, Triton can't generate vectorized code.
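A quick way to check which width Triton actually picked is to search the PTX from the compiled-kernel handle (same .asm assumption as earlier; kernel_230423_04 takes N as a constexpr, so n is passed at launch):
handle = kernel_230423_04[(grid,)](x, y, n, BLOCK_SIZE_N=BLOCK_SIZE_N)
print('ld.global.v4' in handle.asm['ptx'], 'ld.global.v2' in handle.asm['ptx'])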
What about dynamic shapes? Let's try kernel_230423_05, which uses a single CTA to iterate over all the tiles:
@triton.jit
def kernel_230423_05(
x_ptr,
y_ptr,
n,
BLOCK_SIZE_N: tl.constexpr,
):
tiles = (n + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N
for i in range(tiles):
offsets = i * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
mask = offsets < n
x = tl.load(x_ptr + offsets, mask=mask)
tl.store(y_ptr + offsets, x, mask=mask)
def launch_kernel_230423_05():
n = 1000
BLOCK_SIZE_N = 128
x = torch.randn((n, ), device='cuda')
y = torch.zeros((n, ), device='cuda')
assert torch.allclose(x, y)
Reverse-engineering the PTX shows that no vectorization happens; every element is copied with a scalar load/store:
if (gid < n) {
y_ptr[gid] = x_ptr[gid];
}
if (gid1 < n) {
y_ptr[gid1] = x_ptr[gid1];
}
if (gid2 < n) {
y_ptr[gid2] = x_ptr[gid2];
}
if (gid3 < n) {
y_ptr[gid3] = x_ptr[gid3];
}
i += 1;
blockIdx += 128;
} while (i < ((m + ((m >> 31) >> 25)) >> 7));
}
}
Even if we give the compiler the hint n = tl.multiple_of(n, 4), it still can't generate vectorized code. Let's try making the number of elements a multiple of 16:
@triton.jit
def kernel_230423_06( # same as kernel_230423_05
x_ptr,
y_ptr,
n,
BLOCK_SIZE_N: tl.constexpr,
):
tiles = (n + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N
for i in range(tiles):
offsets = i * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
mask = offsets < n
x = tl.load(x_ptr + offsets, mask=mask)
tl.store(y_ptr + offsets, x, mask=mask)
def launch_kernel_230423_06():
n = 1008 # make it multiple of 16
BLOCK_SIZE_N = 128
x = torch.randn((n, ), device='cuda')
y = torch.zeros((n, ), device='cuda')
Another way is to specialize the last tile: since we know for sure that every tile except the last one stays in bounds, we can drop the mask for those tiles, and then Triton doesn't have to consider whether the boundary might fall inside a vectorized instruction:
@triton.jit
def kernel_230423_07(
x_ptr,
y_ptr,
n,
BLOCK_SIZE_N: tl.constexpr,
):
tiles = (n + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N
# We know for sure tiles except for the last tile won't out-of-bounds.
# Drop mask so that Triton can vectorize it, i.e. use ldg.128.
for i in range(tiles - 1):
offsets = i * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
x = tl.load(x_ptr + offsets)
tl.store(y_ptr + offsets, x)
# Last tile, aware of mask, no vectorization.
offsets = (tiles - 1) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
mask = offsets < n
x = tl.load(x_ptr + offsets, mask=mask)
tl.store(y_ptr + offsets, x, mask=mask)
def launch_kernel_230423_07():
n = 1000
BLOCK_SIZE_N = 128
x = torch.randn((n, ), device='cuda')
y = torch.zeros((n, ), device='cuda')
Triton generates vectorized code for all the leading tiles, and unrolled scalar code for the last tile:
@%p7 ld.global.b32 { %r31 }, [ %rd8 + 0 ];
@%p8 ld.global.b32 { %r32 }, [ %rd9 + 0 ];
add.s64 %rd10, %rd2, %rd14;
add.s64 %rd11, %rd10, 4;
add.s64 %rd12, %rd10, 8;
add.s64 %rd13, %rd10, 12;
@%p5 st.global.b32 [ %rd10 + 0 ], { %r29 };
@%p6 st.global.b32 [ %rd11 + 0 ], { %r30 };
@%p7 st.global.b32 [ %rd12 + 0 ], { %r31 };
@%p8 st.global.b32 [ %rd13 + 0 ], { %r32 };
if (m >= 256) {
int i = 0, blockIdx = 0;
while (i < num_blocks - 1) {
int gid = blockIdx | lane;
int idx = gid >> 2;
int4 temp = reinterpret_cast<int4 *>(x_ptr)[idx];
reinterpret_cast<int4 *>(y_ptr)[idx] = temp;
i += 1;
blockIdx += 128;
}
}
if (gid_last_block < n) {
y_ptr[gid_last_block] = x_ptr[gid_last_block];
}
if (gid1 < n) {
y_ptr[gid1] = x_ptr[gid1];
}
if (gid2 < n) {
y_ptr[gid2] = x_ptr[gid2];
}
if (gid3 < n) {
y_ptr[gid3] = x_ptr[gid3];
}
}
Conclusions: Triton can generate vectorized code only when it knows for sure that a vectorized memory access per thread (e.g. ldg.128, ldg.64) won't cross a legal memory boundary. As the examples above show, Triton can use vectorization in the following situations:
The shape is static (tl.constexpr) and is a suitable multiple of the vector width, e.g. a multiple of 4 for fp32 ldg.128 (kernel_230423_04);
The dynamic size is a multiple of 16, so Triton's divisibility-based argument specialization can apply (the motivation for kernel_230423_06);
The loads/stores carry no mask, or the mask provably cannot split a vectorized access, as in the interior tiles of kernel_230423_07.
Now let's move on to reduction. kernel_230424_03 sums a vector within a single CTA using tl.sum():
@triton.jit
def kernel_230424_03(
x_ptr,
y_ptr,
N: tl.constexpr, # static shape
):
cols = tl.arange(0, N)
mask = cols < N
x = tl.load(x_ptr + cols, mask=mask, other=0.0)
total = tl.sum(x, axis=0)
tl.store(y_ptr, total)
def launch_kernel_230424_03():
n = 128
x = torch.randn((n, ), device='cuda')
y = torch.zeros((1, ), device='cuda')
Reverse-engineering the PTX, the reduction part of the generated code looks like:
#pragma unroll
for (unsigned N = 16; N > 0; N >>= 1) {
res += __shfl_xor_sync(0xFFFFFFFF, res, N);
}
if (lane == 0) {
global_smem[bank] = res;
}
__syncthreads();
output[0] = global_smem[0];
}
Observations:
tl.sum() is lowered to an intra-warp reduction using warp shuffle instructions, followed by an inter-warp reduction through shared memory;
tl.sum() performs a butterfly reduction instead of a tree reduction, so every thread ends up holding the reduced value;
However, the second __syncthreads() is not necessary here, and only thread 0 needs to write out the result.
Now let's see what happens with inter-CTA synchronization. Here we use multiple CTAs to reduce a tensor to a single float; each CTA has only one warp:
@triton.jit
def kernel_230424_04(
x_ptr,
y_ptr,
N: tl.constexpr, # static shape
BLOCK_SIZE_N: tl.constexpr,
):
pid = tl.program_id(0)
cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
mask = cols < N
x = tl.load(x_ptr + cols, mask=mask, other=0.0)
total = tl.sum(x, axis=0)
tl.atomic_add(y_ptr, total)
def launch_kernel_230424_04():
n = 128
x = torch.randn((n, ), device='cuda')
y = torch.zeros((1, ), device='cuda') # has to be zeros
BLOCK_SIZE_N = 32
grid = triton.cdiv(n, BLOCK_SIZE_N)
Reverse-engineering the PTX again, the tail of the generated CUDA looks like:
if (tid == 0) {
global_smem[0] = temp;
}
__syncthreads();
if (tid < 1) {
global_smem[tid] = global_smem[0];
}
__syncthreads();
if (tid == 0) {
atomicAdd(output, global_smem[0]);
}
}
}
Observations:
tl.atomic_add() is lowered to a plain atomicAdd() issued by a single thread after the per-CTA reduction; as the code above shows, the extra shared-memory round trip and the two __syncthreads() look redundant when each CTA has only one warp.
Let's write a simple transpose kernel with only 1 CTA and 1 warp:
@triton.jit
def kernel_230424_05(
x_ptr,
y_ptr,
N: tl.constexpr, # static shape
):
off = tl.arange(0, N)
x_ptrs = x_ptr + off[:, None] * N + off[None, :]
y_ptrs = y_ptr + off[None, :] * N + off[:, None]
x = tl.load(x_ptrs)
tl.store(y_ptrs, x)
def launch_kernel_230424_05():
n = 16
x = torch.randint(10, (n, n), device='cuda').float()
y = torch.empty((n, n), device='cuda')
kernel_230424_05[(1,)](x, y, n, num_warps=1)
assert torch.allclose(x.t(), y)
Reverse-engineering the PTX code, we get its corresponding CUDA (excerpt):
#pragma unroll(4)
for (int i = 0; i < 4; i++) {
*(global_smem + smemOffset + i * 17) = *((float *)&inData1 + i);
}
#pragma unroll(4)
for (int i = 0; i < 4; i++) {
*(global_smem + smemOffset + 8 + i * 17) = *((float *)&inData2 + i);
}
}
__syncthreads();
As we can see from the code above, to transpose a 16x16 tile with 32 threads, Triton infers that each thread can use ldg.128, so the CTA splits the data into two tiles and the two tiles are unrolled. Triton stages the intermediate data in shared memory and reads it back transposed.
Observations:
The transpose goes through shared memory rather than register shuffles; the row stride of 17 floats (instead of 16) in the generated code pads the shared-memory layout, presumably to avoid bank conflicts on the transposed reads.
Next, let's see how Triton lowers math functions such as tl.sigmoid():
@triton.jit
def kernel_230424_06(
x_ptr,
y_ptr,
N: tl.constexpr, # static shape
):
off = tl.arange(0, N)
x = tl.load(x_ptr + off)
x = tl.sigmoid(x)
tl.store(y_ptr + off, x)
def launch_kernel_230424_06():
n = 32
x = torch.randn((n,), device='cuda')
y = torch.empty((n,), device='cuda')
kernel_230424_06[(1,)](x, y, n, num_warps=1)
assert torch.allclose(x.sigmoid(), y)
The core of the generated PTX for the sigmoid computation:
fma.rn.f32 %f15, %f2, %f14, %f13;
mov.b32 %r3, %f10;
shl.b32 %r4, %r3, 23;
mov.b32 %f16, %r4;
ex2.approx.ftz.f32 %f17, %f15;
fma.rn.f32 %f18, %f17, %f16, 0f3F800000;
rcp.rn.f32 %f19, %f18;
cvta.to.global.u64 %rd7, %rd2;
add.s64 %rd8, %rd7, %rd5;
st.global.f32 [%rd8], %f19;
Observations:
Triton uses fast math for the exp() (the ex2.approx.ftz above), but not for the reciprocal (rcp.rn);
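Finally, a small fp16 GEMM through tl.dot. The kernel definition itself is not shown; here is a minimal single-tile sketch consistent with the launch below (the layout, names and the fp16 cast are my assumptions, not necessarily the original kernel):
import triton
import triton.language as tl

@triton.jit
def kernel_230429_01(
    a_ptr, b_ptr, d_ptr,
    M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
):
    rm = tl.arange(0, M)
    rn = tl.arange(0, N)
    rk = tl.arange(0, K)
    a = tl.load(a_ptr + rm[:, None] * K + rk[None, :])  # (M, K) tile, row-major
    b = tl.load(b_ptr + rk[:, None] * N + rn[None, :])  # (K, N) tile, row-major
    d = tl.dot(a, b)                                     # accumulates in fp32
    tl.store(d_ptr + rm[:, None] * N + rn[None, :], d.to(tl.float16))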
def launch_kernel_230429_01():
m, n, k = 16, 16, 16
a = torch.randn((m, k), device='cuda', dtype=torch.float16)
b = torch.randn((k, n), device='cuda', dtype=torch.float16)
d = torch.empty((m, n), device='cuda', dtype=torch.float16)
kernel_230429_01[(1,)](a, b, d, m, n, k, num_warps=1)
assert torch.allclose(a @ b, d, atol=1e-2, rtol=0)