
fkong' tech blog


Demystify OpenAI Triton

April 23, 2023 24-minute read


Fei Kong
High Performance Computing
OpenAI Triton • CUDA • PTX

Introduction


The original claim of OpenAI Triton is:

We’re releasing Triton 1.0, an open-source Python-like programming language which enables researchers with no CUDA
experience to write highly efficient GPU code — most of the time on par with what an expert would be able to produce.

The core ideas of Triton:

Program the GPU with Python: so that the effort to program the GPU is minimal;
Encapsulate optimizations within an SM: the user focuses on partitioning the job and scheduling it across SMs, while Triton is
responsible for automatically optimizing the code within an SM;

How did Triton implement it?

Triton turns Python code into Triton IR on the fly by walking the Python AST, then optimizes and lowers it to LLVM-IR / MLIR, generates
PTX directly through libLLVM, and finally compiles it to a cubin with ptxas.

Did Triton manage to achieve its goal?

Program the GPU with Python: pretty successful; we can write a functional kernel with Triton quite fast;
Encapsulate optimizations within an SM: I would say Triton only achieved 60% of this goal. Users usually can't generate an efficient
kernel directly. As one concrete example, our very first GroupNorm kernel written in Triton was functionally correct,
but not efficient, because we didn't know what Triton did to the code;

What’re the advantages of Triton?

Easy to use: requires far less time to write a kernel compared to CUDA;
Performant: it can generate kernels with performance comparable to a skilled CUDA programmer's, provided the user knows how to
tune the Triton code;

What’re the disadvantages of Triton?

Hard to debug: the whole optimization pipeline is a black box; the user has to go through the PTX/IR to understand issues;
Limitations: some kernels can't be implemented with Triton, e.g. tile sizes must be powers of 2, and slicing is not supported;

What can we do with Triton?

Fast prototyping: we can experiment with various ideas in Triton and iterate very fast, then re-implement the idea in CUDA once it is
nailed down;
New kernels/operators: we can use Triton to generate kernels/operators that haven't been implemented in
cuBLAS/cuDNN/frameworks;

Compared to hand-written CUDA kernels, where do the performance benefits of Triton mainly come from?

Triton automates a bunch of optimizations that a CUDA developer may not be aware of;
Kernels from libraries like cuBLAS have generality requirements; Triton can drop them on demand and generate simpler code, e.g. the
bias addition in cuBLAS GEMM;
Triton kernels can be auto-tuned (see the sketch after this list);
LLVM sometimes generates better PTX than NVCC, e.g. for loop unrolling;
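
On the auto-tuning point, here is a minimal sketch of what that looks like. The kernel body is the same copy kernel used throughout this post; the decorator follows the Triton 2.x @triton.autotune API, and the config values are purely illustrative:

import torch
import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_N': 128}, num_warps=4),
        triton.Config({'BLOCK_SIZE_N': 256}, num_warps=8),
        triton.Config({'BLOCK_SIZE_N': 512}, num_warps=8),
    ],
    key=['n'],  # re-run the search whenever n changes
)
@triton.jit
def copy_kernel_tuned(x_ptr, y_ptr, n, BLOCK_SIZE_N: tl.constexpr):
    idx = tl.program_id(0)
    offsets = idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    mask = offsets < n
    tl.store(y_ptr + offsets, tl.load(x_ptr + offsets, mask=mask), mask=mask)

def launch_copy_kernel_tuned(n=1 << 20):
    x = torch.randn((n, ), device='cuda')
    y = torch.zeros((n, ), device='cuda')
    # The tuned meta-parameter is not passed at launch; the grid reads it back.
    grid = lambda META: (triton.cdiv(n, META['BLOCK_SIZE_N']), )
    copy_kernel_tuned[grid](x, y, n)
    assert torch.allclose(x, y)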

What inspiration can we draw from Triton?

Optimizations can be abstracted/encapsulated, or even automated: this can accelerate the CUDA kernel development flow and save
a lot of developers' time;
Rapid iteration is important: a developer with minimal skills can write a fairly good kernel after a couple of iterations;

Basic


The goal of this experiment is to build a basic mapping from Triton to CUDA, so that a Triton novice can get a sense of what kind of
underlying CUDA code they are manipulating. This matters for using Triton for fast prototyping while keeping performance reasonable,
since it's not easy to debug the performance of Triton-generated kernels. Due to the complexity of Triton's automatic optimization
pipeline, we don't expect these findings to apply to every case; to keep things simple, we only expect them to cover 80% of cases.

Setup:

NVIDIA 23.03 PyTorch container;


Triton d54c04a;
GH100-700W;
Measured on 04/23/2023;

As a start, we use a simple copy kernel to show the basic Triton-to-CUDA mapping.

import torch

import triton
import triton.language as tl

torch.manual_seed(0)

@triton.jit
def kernel_230423_01(
x_ptr,
y_ptr,
n,
BLOCK_SIZE_N: tl.constexpr,
):
idx = tl.program_id(0)
offsets = idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
mask = offsets < n
x = tl.load(x_ptr + offsets, mask=mask)
tl.store(y_ptr + offsets, x, mask=mask)

def launch_kernel_230423_01():
n = 1000
BLOCK_SIZE_N = 128
x = torch.randn((n, ), device='cuda')
y = torch.zeros((n, ), device='cuda')
grid = triton.cdiv(n, BLOCK_SIZE_N)

kernel_230423_01[(grid, )](x, y, n, BLOCK_SIZE_N)


assert torch.allclose(x, y)
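
The PTX that we reverse engineer below (and throughout this post) can be dumped from the compiled kernel itself. This is a minimal sketch, assuming the Triton 2.x-era behavior used in this post, where launching a jitted kernel returns a compiled handle whose asm dict holds the intermediate artifacts (newer versions expose this differently):

def dump_kernel_230423_01_artifacts():
    n, BLOCK_SIZE_N = 1000, 128
    x = torch.randn((n, ), device='cuda')
    y = torch.zeros((n, ), device='cuda')
    grid = triton.cdiv(n, BLOCK_SIZE_N)
    handle = kernel_230423_01[(grid, )](x, y, n, BLOCK_SIZE_N)
    # Each stage of the pipeline described in the Introduction is kept around:
    # Triton IR, Triton-GPU IR, LLVM IR, PTX and the final cubin.
    print(list(handle.asm.keys()))  # e.g. ['ttir', 'ttgir', 'llir', 'ptx', 'cubin']
    print(handle.asm['ptx'])        # the PTX reverse engineered below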

Reverse engineering the generated PTX, we get its corresponding CUDA code:

__global__ void kernel_230423_01(float *x_ptr, float *y_ptr, int n) {


const int BLOCK_SIZE_N = 128;
int lane = threadIdx.x & (BLOCK_SIZE_N - 1);
int gid = (blockIdx.x << 7) | lane; // LOG(128) = 7

if (gid < n) {
y_ptr[gid] = x_ptr[gid];
}
}

It’s a very simple example, but we can derive following conclusions:

(grid,) maps to (gridDim.x,);

Tensor x casted to pointer x_ptr implicitly;

BLOCK_SIZE_N, whose type is tl.constexpr, embedded in the generated code through constant folding and propagation;

tl.program_id(0) maps to blockIdx.x;

Threads assigned to tl.arange(0, BLOCK_SIZE_N), it directly maps to threadIdx.x in this case, since default number of warps is 4,
and BLOCK_SIZE_N is 128;

mask maps to the if condition, and finally turned to predicate;

Memory load and store with a block of pointers, i.e. tl.load() and tl.store(), distributed and scheduled to different threads. In
this example, each thread load and store a 32-bits value:

@%p1 ld.global.b32 { %r2 }, [ %rd1 + 0 ];


@%p1 st.global.b32 [ %rd2 + 0 ], { %r2 };

Scheduling


Re-launch the kernel with fewer warps:

def launch_kernel_230423_02():
n = 1000
BLOCK_SIZE_N = 128
x = torch.randn((n, ), device='cuda')
y = torch.zeros((n, ), device='cuda')
grid = triton.cdiv(n, BLOCK_SIZE_N)

kernel_230423_01[(grid, )](x, y, n, BLOCK_SIZE_N, num_warps=2)


assert torch.allclose(x, y)

Observations:

The compiled kernel kernel_230423_01 is still launched with 128 threads; passing a different num_warps to an already-compiled function
doesn't trigger re-compilation and has no effect. So num_warps is a constant captured during the first compilation. This is different from
CUDA, where we can change the number of blocks and threads at launch time;

Copy kernel_230423_01 to kernel_230423_03a and kernel_230423_03b, then launch kernel_230423_03a with 1 warp and kernel_230423_03b
with 16 warps to see what happens when we have fewer or more threads:

@triton.jit
def kernel_230423_03a( # same as kernel_230423_01
x_ptr,
y_ptr,
n,
BLOCK_SIZE_N: tl.constexpr,

):
idx = tl.program_id(0)
offsets = idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
mask = offsets < n
x = tl.load(x_ptr + offsets, mask=mask)
tl.store(y_ptr + offsets, x, mask=mask)

@triton.jit
def kernel_230423_03b( # same as kernel_230423_01
x_ptr,
y_ptr,
n,
BLOCK_SIZE_N: tl.constexpr,
):
idx = tl.program_id(0)
offsets = idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
mask = offsets < n
x = tl.load(x_ptr + offsets, mask=mask)
tl.store(y_ptr + offsets, x, mask=mask)

def launch_kernel_230423_03():
n = 1000
BLOCK_SIZE_N = 128
x = torch.randn((n, ), device='cuda')
y = torch.zeros((n, ), device='cuda')
grid = triton.cdiv(n, BLOCK_SIZE_N)

kernel_230423_03a[(grid, )](x, y, n, BLOCK_SIZE_N, num_warps=1)


assert torch.allclose(x, y)
y.fill_(0)
kernel_230423_03b[(grid, )](x, y, n, BLOCK_SIZE_N, num_warps=16)
assert torch.allclose(x, y)

Reverse engineering the PTX, we get the CUDA code for them:

__global__ void kernel_230423_03a(int *x_ptr, int *y_ptr, int n) {


int lane = (threadIdx.x << 2) & 124;
int gid_base = (blockIdx.x << 7) | lane;

if (gid_base < n) {
#pragma unroll(4)
for (int i = 0; i < 4; i++) {
if (gid_base + i < n) {
y_ptr[gid_base + i] = x_ptr[gid_base + i];
}
}
}
}

__global__ void kernel_230423_03b(int *x_ptr, int *y_ptr, int n) {


int lane = threadIdx.x & 127;
int gid = (blockIdx.x << 7) | lane;

if (gid < n) {
y_ptr[gid] = x_ptr[gid];
}
}

Observations:

If the tile size is larger than the thread count, Triton lets each thread process multiple elements; in this specific case Triton unrolls the code;
If the tile size is smaller than the thread count, the redundant threads process the same data and are wasted. The user therefore has to mind
the relation between the number of warps and the tile size (see the helper sketch below);
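
Since the warp count is baked in at first compilation, a small hypothetical helper (not part of the original post) can keep the two in sync, assuming the one-element-per-thread layout used above:

def pick_num_warps(block_size_n, max_warps=16):
    # block_size_n is the Triton tile size (a power of 2); 32 is the warp size.
    # Returns the num_warps that gives exactly one element per thread.
    return max(1, min(max_warps, block_size_n // 32))

# e.g. BLOCK_SIZE_N = 128 -> 4 warps (the Triton default), 32 -> 1 warp, 1024 -> 16 warps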

However, it’s worth to notice that Triton uses LLVM which generates different PTX than what NVCC generates, and it’s obviously the

https://fkong.tech/posts/2023-04-23-triton-cuda/ 3/7/25, 3 24 AM
Page 4 of 17
:
PTX generated by LLVM is better than NVCC (the result is subject to change depending on different version of NVCC):

// Generated by Triton and LLVM | // Generated by NVCC


.visible .entry kernel_03a_0d1d2( | .visible .entry _Z10kernel_03aPiS_i(
.param .u64 kernel_03a_0d1d2_param_0, | .param .u64 _Z10kernel_03aPiS_i_param_0,
.param .u64 kernel_03a_0d1d2_param_1, | .param .u64 _Z10kernel_03aPiS_i_param_1,
.param .u32 kernel_03a_0d1d2_param_2 | .param .u32 _Z10kernel_03aPiS_i_param_2
) | )
.maxntid 32, 1, 1 | {
{ | .reg .pred %p<5>;
.reg .pred %p<9>; | .reg .b32 %r<15>;
.reg .b32 %r<19>; | .reg .b64 %rd<8>;
.reg .b64 %rd<12>; |
| ld.param.u64 %rd3, [_Z10kernel_03aPiS_i_param_0];
ld.param.u64 %rd9, [kernel_03a_0d1d2_param_0]; | ld.param.u64 %rd4, [_Z10kernel_03aPiS_i_param_1];
ld.param.u64 %rd10, [kernel_03a_0d1d2_param_1]; | ld.param.u32 %r2, [_Z10kernel_03aPiS_i_param_2];
mov.u32 %r9, %tid.x; | mov.u32 %r3, %tid.x;
shl.b32 %r10, %r9, 2; | shl.b32 %r4, %r3, 2;
ld.param.u32 %r11, [kernel_03a_0d1d2_param_2]; | and.b32 %r5, %r4, 124;
and.b32 %r12, %r10, 124; | mov.u32 %r6, %ctaid.x;
mov.u32 %r13, %ctaid.x; | shl.b32 %r7, %r6, 7;
shl.b32 %r14, %r13, 7; | or.b32 %r1, %r5, %r7;
or.b32 %r15, %r12, %r14; | setp.ge.s32 %p1, %r1, %r2;
or.b32 %r16, %r15, 1; | @%p1 bra $L__BB0_7;
or.b32 %r17, %r15, 2; |
or.b32 %r18, %r15, 3; | cvta.to.global.u64 %rd5, %rd3;
setp.lt.s32 %p1, %r15, %r11; | mul.wide.s32 %rd6, %r1, 4;
setp.lt.s32 %p2, %r16, %r11; | add.s64 %rd1, %rd5, %rd6;
setp.lt.s32 %p3, %r17, %r11; | ld.global.u32 %r8, [%rd1];
setp.lt.s32 %p4, %r18, %r11; | cvta.to.global.u64 %rd7, %rd4;
mul.wide.s32 %rd11, %r15, 4; | add.s64 %rd2, %rd7, %rd6;
add.s64 %rd1, %rd9, %rd11; | st.global.u32 [%rd2], %r8;
add.s64 %rd2, %rd1, 4; | add.s32 %r9, %r1, 1;
add.s64 %rd3, %rd1, 8; | setp.ge.s32 %p2, %r9, %r2;
add.s64 %rd4, %rd1, 12; | @%p2 bra $L__BB0_3;
@%p1 ld.global.b32 { %r5 }, [ %rd1 + 0 ]; |
@%p2 ld.global.b32 { %r6 }, [ %rd2 + 0 ]; | ld.global.u32 %r10, [%rd1+4];
@%p3 ld.global.b32 { %r7 }, [ %rd3 + 0 ]; | st.global.u32 [%rd2+4], %r10;
@%p4 ld.global.b32 { %r8 }, [ %rd4 + 0 ]; |
add.s64 %rd5, %rd10, %rd11; | $L__BB0_3:
add.s64 %rd6, %rd5, 4; | add.s32 %r11, %r1, 2;
add.s64 %rd7, %rd5, 8; | setp.ge.s32 %p3, %r11, %r2;
add.s64 %rd8, %rd5, 12; | @%p3 bra $L__BB0_5;
@%p1 st.global.b32 [ %rd5 + 0 ], { %r5 }; |
@%p2 st.global.b32 [ %rd6 + 0 ], { %r6 }; | ld.global.u32 %r12, [%rd1+8];
@%p3 st.global.b32 [ %rd7 + 0 ], { %r7 }; | st.global.u32 [%rd2+8], %r12;
@%p4 st.global.b32 [ %rd8 + 0 ], { %r8 }; |
ret; | $L__BB0_5:
| add.s32 %r13, %r1, 3;
} | setp.ge.s32 %p4, %r13, %r2;
| @%p4 bra $L__BB0_7;
|
| ld.global.u32 %r14, [%rd1+12];
| st.global.u32 [%rd2+12], %r14;
|
| $L__BB0_7:
| ret;
| }

Consider an example where we use a for loop to iterate over the tiles, launched with more threads than the tile size:

@triton.jit
def kernel_230424_01(
x_ptr,
y_ptr,
N: tl.constexpr, # static shape

BLOCK_SIZE_N: tl.constexpr,
):
tiles = (N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N
for i in range(tiles):
offsets = i * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
mask = offsets < N
x = tl.load(x_ptr + offsets, mask=mask)
tl.store(y_ptr + offsets, x, mask=mask)

def launch_kernel_230424_01():
n = 1000
BLOCK_SIZE_N = 128
x = torch.randn((n, ), device='cuda')
y = torch.zeros((n, ), device='cuda')

kernel_230424_01[(1, )](x, y, n, BLOCK_SIZE_N, num_warps=16)


assert torch.allclose(x, y)

Reverse engineering the generated PTX, we get the following CUDA kernel:

__global__ void kernel_230424_01(int *input, int *output) {


int tid = threadIdx.x;
int lane = tid & 127;

#pragma unroll(8)
for (int i = 0; i < 7; ++i) {
output[lane + i * 128] = input[lane + i * 128];
}

if ((lane | (7 * 128)) < 1000) {


output[lane + 7 * 128] = input[lane + 7 * 128];
}
}

We launched 1 CTA with 16 warps (512 threads), and the tile size is 128. However, as we can see from the equivalent CUDA code above, Triton
doesn't schedule threads to parallelize the for loop; it unrolls the loop instead. The extra threads are wasted and only 128 threads are
utilized. This is counterintuitive and should be handled carefully.

What happens if we make it a 2D tile?

@triton.jit
def kernel_230424_02(
x_ptr,
y_ptr,
M: tl.constexpr, # static shape
N: tl.constexpr, # static shape
BLOCK_SIZE_M: tl.constexpr,
):
rows = tl.arange(0, BLOCK_SIZE_M)
cols = tl.arange(0, N)
tiles = (M + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M
for i in range(tiles):
# 2D tile, mask
offsets = i * BLOCK_SIZE_M * N + N * rows[:, None] + cols[None, :]
mask = (i * BLOCK_SIZE_M + rows < M)[:, None]
x = tl.load(x_ptr + offsets, mask=mask)
tl.store(y_ptr + offsets, x, mask=mask)

def launch_kernel_230424_02():
m, n = 1000, 32
BLOCK_SIZE_M = 128
x = torch.randn((m, n), device='cuda')
y = torch.zeros((m, n), device='cuda')

kernel_230424_02[(1, )](x, y, m, n, BLOCK_SIZE_M, num_warps=16)

assert torch.allclose(x, y)

Reverse engineering the PTX, we get the following CUDA code:

__global__ void kernel_230424_02(int *x_ptr, int *y_ptr) {


int tid_x = threadIdx.x;
int lane_row = (tid_x >> 3) & 127;
int lane_col = (tid_x << 2) & 28;
int offset = (lane_row << 5) | lane_col;

x_ptr += offset;
y_ptr += offset;

#pragma unroll(7)
for (int i = 0; i < 7; ++i) {
int *ptr1 = x_ptr + i * 4096;
int *ptr2 = y_ptr + i * 4096;
uint4 var1 = *((uint4 *)ptr1);
uint4 var2 = *((uint4 *)(ptr1 + 2048));

*((uint4 *)ptr2) = var1;


*((uint4 *)(ptr2 + 2048)) = var2;
}

if ((lane_row | 896) < 1000) {


int *ptr1 = x_ptr + 28672;
int *ptr2 = y_ptr + 28672;
uint4 temp = *((uint4 *)ptr1);
*((uint4 *)ptr2) = temp;
}

if (lane_row < 40) {


int *ptr1 = x_ptr + 30720;
int *ptr2 = y_ptr + 30720;
uint4 temp = *((uint4 *)ptr1);
*((uint4 *)ptr2) = temp;
}
}

As you can see, the problem size is 1000x32, the tile size is 128x32, and we launched 16 warps (512 threads). Triton deduces that each thread
needs to process 128x32/512 = 8 elements, so it vectorizes with ld.global.v4.b32 and unrolls twice. Triton also handles the last tile
correctly in this case.

Observations:

Triton doesn’t schedule threads to parallelize the for loop, every iteration of a for loop wrote by user will be executed by all threads;
Triton deduces elements need to be processed per thread, and may potentially vectorize and/or unroll if tile size is larger than the
number of threads per CTA;
If tile size is smaller than the number of threads per CTA, Triton won’t parallelize the redundant threads across the for loop. User has
to re-organize the tile with higher dimension so that the redundant threads can be fully utilized;
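
In other words, extra warps only help if the tile grows with them; otherwise it is usually simpler to drop the in-kernel loop and let the launch grid provide the parallelism, as the very first example already does. A minimal sketch reusing kernel_230423_01:

def launch_copy_grid_parallel(n=1000, BLOCK_SIZE_N=128):
    x = torch.randn((n, ), device='cuda')
    y = torch.zeros((n, ), device='cuda')
    # One program (CTA) per tile: the loop over tiles becomes gridDim.x, so every
    # launched warp gets its own data instead of redundantly replaying a loop.
    grid = (triton.cdiv(n, BLOCK_SIZE_N), )
    kernel_230423_01[grid](x, y, n, BLOCK_SIZE_N)
    assert torch.allclose(x, y)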

Vectorization


As mentioned earlier, Triton generates unrolled code when each thread processes multiple elements. When will Triton vectorize the code?

Let’s start by making kernel_230423_04 as static shape:

@triton.jit
def kernel_230423_04(
x_ptr,
y_ptr,
N: tl.constexpr, # static shape
BLOCK_SIZE_N: tl.constexpr,

):
idx = tl.program_id(0)
offsets = idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
mask = offsets < N
x = tl.load(x_ptr + offsets, mask=mask)
tl.store(y_ptr + offsets, x, mask=mask)

def launch_kernel_230423_04():
n = 1000
BLOCK_SIZE_N = 128
x = torch.randn((n, ), device='cuda')
y = torch.zeros((n, ), device='cuda')
grid = triton.cdiv(n, BLOCK_SIZE_N)

kernel_230423_04[(grid, )](x, y, n, BLOCK_SIZE_N, num_warps=1)


assert torch.allclose(x, y)

The generated PTX contains vectorized load/store:

add.s64 %rd1, %rd3, %rd5;


@%p1 ld.global.v4.b32 { %r5, %r6, %r7, %r8 }, [ %rd1 + 0 ];
add.s64 %rd2, %rd4, %rd5;
@%p1 st.global.v4.b32 [ %rd2 + 0 ], { %r5, %r6, %r7, %r8 };

Its corresponding CUDA code is:

__global__ void kernel_230423_04(int4 *x_ptr, int4 *y_ptr) {


int lane = (threadIdx.x << 2) & 124;
int gid = (blockIdx.x << 7) | lane;
int idx = gid >> 2;

if (gid < 1000) {


y_ptr[idx] = x_ptr[idx];
}
}

Observations:

Triton can generate vectorized code with a static shape. The vectorization width depends on what N is a multiple of: for fp32,
if N is a multiple of 4 or 2, Triton generates ldg.128 or ldg.64 respectively. But if N is not a multiple of 2, e.g. 1001, Triton can't generate
vectorized code (see the probe sketch below).
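
To see the cutoff for yourself, the generated PTX can be checked per problem size. This is a small probe under the same Triton 2.x assumption as the sketch in the Basic section (launching returns a compiled handle whose asm dict contains the PTX):

def probe_vector_width(n):
    BLOCK_SIZE_N = 128
    x = torch.randn((n, ), device='cuda')
    y = torch.zeros((n, ), device='cuda')
    grid = triton.cdiv(n, BLOCK_SIZE_N)
    # N is a tl.constexpr here, so each distinct n compiles a fresh kernel.
    h = kernel_230423_04[(grid, )](x, y, n, BLOCK_SIZE_N, num_warps=1)
    ptx = h.asm['ptx']
    if 'ld.global.v4' in ptx:
        return 'ldg.128'
    if 'ld.global.v2' in ptx:
        return 'ldg.64'
    return 'scalar'

# Expected per the observation above:
#   probe_vector_width(1000) -> 'ldg.128'   (multiple of 4)
#   probe_vector_width(1002) -> 'ldg.64'    (multiple of 2 only)
#   probe_vector_width(1001) -> 'scalar'    (odd)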

What if it’s dynamic shape? Let’s try with kernel_230423_05, who has a single CTA to iterate all tiles:

@triton.jit
def kernel_230423_05(
x_ptr,
y_ptr,
n,
BLOCK_SIZE_N: tl.constexpr,
):
tiles = (n + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N
for i in range(tiles):
offsets = i * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
mask = offsets < n
x = tl.load(x_ptr + offsets, mask=mask)
tl.store(y_ptr + offsets, x, mask=mask)

def launch_kernel_230423_05():
n = 1000
BLOCK_SIZE_N = 128
x = torch.randn((n, ), device='cuda')
y = torch.zeros((n, ), device='cuda')

kernel_230423_05[(1, )](x, y, n, BLOCK_SIZE_N, num_warps=1)

assert torch.allclose(x, y)

The generated PTX contains unrolled load/store:

add.s64 %rd3, %rd1, %rd11;


add.s64 %rd4, %rd3, 4;
add.s64 %rd5, %rd3, 8;
add.s64 %rd6, %rd3, 12;
@%p2 ld.global.b32 { %r15 }, [ %rd3 + 0 ];
@%p3 ld.global.b32 { %r16 }, [ %rd4 + 0 ];
@%p4 ld.global.b32 { %r17 }, [ %rd5 + 0 ];
@%p5 ld.global.b32 { %r18 }, [ %rd6 + 0 ];
add.s64 %rd7, %rd2, %rd11;
add.s64 %rd8, %rd7, 4;
add.s64 %rd9, %rd7, 8;
add.s64 %rd10, %rd7, 12;
@%p2 st.global.b32 [ %rd7 + 0 ], { %r15 };
@%p3 st.global.b32 [ %rd8 + 0 ], { %r16 };
@%p4 st.global.b32 [ %rd9 + 0 ], { %r17 };
@%p5 st.global.b32 [ %rd10 + 0 ], { %r18 };

Its corresponding CUDA code is:

__global__ void kernel_230423_05(int *x_ptr, int *y_ptr, int n) {


int lane = (threadIdx.x << 2) & 124;
int m = n + 127;
if (m >= 128) {
int i = 0, blockIdx = 0;
do {
int gid = blockIdx | lane;
int gid1 = gid + 1;
int gid2 = gid + 2;
int gid3 = gid + 3;

if (gid < n) {
y_ptr[gid] = x_ptr[gid];
}
if (gid1 < n) {
y_ptr[gid1] = x_ptr[gid1];
}
if (gid2 < n) {
y_ptr[gid2] = x_ptr[gid2];
}
if (gid3 < n) {
y_ptr[gid3] = x_ptr[gid3];
}

i += 1;
blockIdx += 128;
} while (i < ((m + ((m >> 31) >> 25)) >> 7));
}
}

Even if we give the compiler hint n = tl.multiple_of(n, 4), Triton still can't generate vectorized code. Let's try making the number of elements
a multiple of 16:

@triton.jit
def kernel_230423_06( # same as kernel_230423_05
x_ptr,
y_ptr,
n,
BLOCK_SIZE_N: tl.constexpr,
):
tiles = (n + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N
for i in range(tiles):
offsets = i * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)

mask = offsets < n
x = tl.load(x_ptr + offsets, mask=mask)
tl.store(y_ptr + offsets, x, mask=mask)

def launch_kernel_230423_06():
n = 1008 # make it multiple of 16
BLOCK_SIZE_N = 128
x = torch.randn((n, ), device='cuda')
y = torch.zeros((n, ), device='cuda')

kernel_230423_06[(1, )](x, y, n, BLOCK_SIZE_N, num_warps=1)


assert torch.allclose(x, y)

Then the generated PTX is vectorized:

add.s64 %rd3, %rd1, %rd5;


@%p2 ld.global.v4.b32 { %r15, %r16, %r17, %r18 }, [ %rd3 + 0 ];
add.s64 %rd4, %rd2, %rd5;
@%p2 st.global.v4.b32 [ %rd4 + 0 ], { %r15, %r16, %r17, %r18 };

Another way is to specialize the last tile: since we know for sure that no tile except the last one goes out of bounds, we can drop the mask
for them, and Triton then doesn't have to worry about a vectorized access crossing the boundary:

@triton.jit
def kernel_230423_07(
x_ptr,
y_ptr,
n,
BLOCK_SIZE_N: tl.constexpr,
):
tiles = (n + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N
    # We know for sure that no tile except the last one goes out of bounds.
# Drop mask so that Triton can vectorize it, i.e. use ldg.128.
for i in range(tiles - 1):
offsets = i * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
x = tl.load(x_ptr + offsets)
tl.store(y_ptr + offsets, x)
# Last tile, aware of mask, no vectorization.
offsets = (tiles - 1) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
mask = offsets < n
x = tl.load(x_ptr + offsets, mask=mask)
tl.store(y_ptr + offsets, x, mask=mask)

def launch_kernel_230423_07():
n = 1000
BLOCK_SIZE_N = 128
x = torch.randn((n, ), device='cuda')
y = torch.zeros((n, ), device='cuda')

kernel_230423_07[(1, )](x, y, n, BLOCK_SIZE_N, num_warps=1)


assert torch.allclose(x, y)

Triton generates vectorized code for all the leading tiles, and unrolled code for the last tile:

add.s64 %rd3, %rd1, %rd5;


@%p2 ld.global.v4.b32 { %r16, %r17, %r18, %r19 }, [ %rd3 + 0 ];
add.s64 %rd4, %rd2, %rd5;
@%p2 st.global.v4.b32 [ %rd4 + 0 ], { %r16, %r17, %r18, %r19 };
...
...
add.s64 %rd6, %rd1, %rd14;
add.s64 %rd7, %rd6, 4;
add.s64 %rd8, %rd6, 8;
add.s64 %rd9, %rd6, 12;
@%p5 ld.global.b32 { %r29 }, [ %rd6 + 0 ];
@%p6 ld.global.b32 { %r30 }, [ %rd7 + 0 ];

@%p7 ld.global.b32 { %r31 }, [ %rd8 + 0 ];
@%p8 ld.global.b32 { %r32 }, [ %rd9 + 0 ];
add.s64 %rd10, %rd2, %rd14;
add.s64 %rd11, %rd10, 4;
add.s64 %rd12, %rd10, 8;
add.s64 %rd13, %rd10, 12;
@%p5 st.global.b32 [ %rd10 + 0 ], { %r29 };
@%p6 st.global.b32 [ %rd11 + 0 ], { %r30 };
@%p7 st.global.b32 [ %rd12 + 0 ], { %r31 };
@%p8 st.global.b32 [ %rd13 + 0 ], { %r32 };

Its corresponding CUDA code is:

__global__ void kernel_230423_07(int *x_ptr, int *y_ptr, int n) {


int lane = (threadIdx.x << 2) & 124;
int m = n + 127;
int num_blocks = (m + ((m >> 31) >> 25)) >> 7;

if (m >= 256) {
int i = 0, blockIdx = 0;
while (i < num_blocks - 1) {
int gid = blockIdx | lane;
int idx = gid >> 2;
int4 temp = reinterpret_cast<int4 *>(x_ptr)[idx];
reinterpret_cast<int4 *>(y_ptr)[idx] = temp;

i += 1;
blockIdx += 128;
}
}

int gid_last_block = ((num_blocks - 1) << 7) | lane;


int gid1 = gid_last_block + 1;
int gid2 = gid_last_block + 2;
int gid3 = gid_last_block + 3;

if (gid_last_block < n) {
y_ptr[gid_last_block] = x_ptr[gid_last_block];
}
if (gid1 < n) {
y_ptr[gid1] = x_ptr[gid1];
}
if (gid2 < n) {
y_ptr[gid2] = x_ptr[gid2];
}
if (gid3 < n) {
y_ptr[gid3] = x_ptr[gid3];
}
}

Conclusions: Triton can generate vectorized code only when it knows for sure that the vectorized memory access per thread (e.g.
ldg.128, ldg.64) won't cross a legal memory boundary. Triton can use vectorization in the following situations:

The input has a static shape that is at least a multiple of 2;
The input has a dynamic shape that is a multiple of 16 (see the padding sketch below);
The input has a dynamic shape and the user specializes the last tile;
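
For the dynamic-shape case, one way to land in the "multiple of 16" bucket is simply to pad the allocation. This is a sketch (not from the original post) reusing kernel_230423_06:

def launch_copy_padded(n=1000, BLOCK_SIZE_N=128):
    # Round the buffer length up to a multiple of 16 so the masked loads/stores can
    # still be vectorized, then only check the first n elements.
    n_pad = (n + 15) // 16 * 16          # 1000 -> 1008
    x = torch.randn((n_pad, ), device='cuda')
    y = torch.zeros((n_pad, ), device='cuda')
    kernel_230423_06[(1, )](x, y, n_pad, BLOCK_SIZE_N, num_warps=1)
    assert torch.allclose(x[:n], y[:n])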

Reduction

Next, let's see how Triton lowers tl.sum(), starting with a single-CTA reduction kernel:

@triton.jit
def kernel_230424_03(
x_ptr,
y_ptr,
N: tl.constexpr, # static shape

):
cols = tl.arange(0, N)
mask = cols < N
x = tl.load(x_ptr + cols, mask=mask, other=0.0)
total = tl.sum(x, axis=0)
tl.store(y_ptr, total)

def launch_kernel_230424_03():
n = 128
x = torch.randn((n, ), device='cuda')
y = torch.zeros((1, ), device='cuda')

kernel_230424_03[(1, )](x, y, n, num_warps=4)


assert torch.allclose(x.sum(), y)

Its corresponding CUDA code is:

__global__ void kernel_230424_03(float *input, float *output) {


__shared__ float global_smem[128];

int tid = threadIdx.x;


int idx = tid & 127;
int lane = tid & 31;
int bank = (tid >> 3) & 0x1FFFFFFC;

float res = input[idx];

#pragma unroll
for (unsigned N = 16; N > 0; N >>= 1) {
res += __shfl_xor_sync(0xFFFFFFFF, res, N);
}
if (lane == 0) {
global_smem[bank] = res;
}
__syncthreads();

float shared_val = global_smem[tid];


#pragma unroll
for (unsigned N = 2; N > 0; N >>= 1) {
shared_val += __shfl_xor_sync(0xFFFFFFFF, shared_val, N);
}
if ((tid < 4) && ((tid & 3) == 0)) {
global_smem[tid] = shared_val;
}
__syncthreads();

output[0] = global_smem[0];
}

Observations:

tl.sum() is lowered to an intra-warp reduction using warp shuffle instructions first, followed by an inter-warp reduction through
shared memory;
tl.sum() performs a butterfly reduction instead of a tree reduction, so every thread ends up with the reduced value;

However, the second __syncthreads() is not necessary here, and we only need thread 0 to write out the data.
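
Since tl.sum() simply reduces along one axis of the tile, the same lowering also covers row-wise reductions. This is a hedged sketch (not from the original post) using a 2D tile and a single CTA:

@triton.jit
def kernel_rowsum(x_ptr, y_ptr, M: tl.constexpr, N: tl.constexpr):
    rows = tl.arange(0, M)
    cols = tl.arange(0, N)
    # One (M, N) tile; reduce along the column axis so each row gets one value.
    x = tl.load(x_ptr + rows[:, None] * N + cols[None, :])
    y = tl.sum(x, axis=1)
    tl.store(y_ptr + rows, y)

def launch_kernel_rowsum():
    m, n = 64, 128  # both must be powers of 2 for tl.arange
    x = torch.randn((m, n), device='cuda')
    y = torch.zeros((m, ), device='cuda')
    kernel_rowsum[(1, )](x, y, m, n, num_warps=4)
    # Small tolerance: the reduction order differs from torch's.
    assert torch.allclose(x.sum(dim=1), y, atol=1e-4)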

Now let’s see what happens if with inter-CTA synchronization. Here we use multi CTAs to reduce a tensor to a float, each CTA only has one
warp:

@triton.jit
def kernel_230424_04(
x_ptr,
y_ptr,
N: tl.constexpr, # static shape

BLOCK_SIZE_N: tl.constexpr,
):
pid = tl.program_id(0)
cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
mask = cols < N
x = tl.load(x_ptr + cols, mask=mask, other=0.0)
total = tl.sum(x, axis=0)
tl.atomic_add(y_ptr, total)

def launch_kernel_230424_04():
n = 128
x = torch.randn((n, ), device='cuda')
y = torch.zeros((1, ), device='cuda') # has to be zeros
BLOCK_SIZE_N = 32
grid = triton.cdiv(n, BLOCK_SIZE_N)

kernel_230424_04[(grid, )](x, y, n, BLOCK_SIZE_N, num_warps=1)


assert torch.allclose(x.sum(), y)

After reverse engineering the PTX, we get the following CUDA code:

__global__ void kernel_230424_04(float *input, float *output) {


__shared__ float global_smem[32];

int tid = threadIdx.x;


int ctaid = blockIdx.x;
int globalIdx = (ctaid << 5) | tid;

if (globalIdx < 128) {


float temp = input[globalIdx];

// Perform butterfly shuffle operations and sum the values


#pragma unroll
for (unsigned N = 16; N > 0; N >>= 1) {
temp += __shfl_xor_sync(0xFFFFFFFF, temp, N, 0x1F);
}

if (tid == 0) {
global_smem[0] = temp;
}
__syncthreads();

if (tid < 1) {
global_smem[tid] = global_smem[0];
}
__syncthreads();

if (tid == 0) {
atomicAdd(output, global_smem[0]);
}
}
}

Observations:

Atomic operations like tl.atomic_add() are translated directly to CUDA atomicAdd();

Triton could be better optimized here, since shared memory and __syncthreads() are not required in this case;

Memory


User can’t control whether the loaded data stored, shared memory or register file. However, there are rules that we infer where the data is
stored.

Let write a simple transpose kernel with only 1 CTA and 1 warp:

@triton.jit
def kernel_230424_05(
x_ptr,
y_ptr,
N: tl.constexpr, # static shape
):
off = tl.arange(0, N)
x_ptrs = x_ptr + off[:, None] * N + off[None, :]
y_ptrs = y_ptr + off[None, :] * N + off[:, None]
x = tl.load(x_ptrs)
tl.store(y_ptrs, x)

def launch_kernel_230424_05():
n = 16
x = torch.randint(10, (n, n), device='cuda').float()
    y = torch.empty((n, n), device='cuda')

kernel_230424_05[(1,)](x, y, n, num_warps=1)
assert torch.allclose(x.t(), y)

Reverse engineering the PTX, we get its corresponding CUDA code:

__global__ void kernel_230424_05(float *input, float *output) {


__shared__ float global_smem[17 * 17];

int tid = threadIdx.x;


int row = (tid >> 2) & 0xF;
int col = (tid << 2) & 0xC;
int off1 = (row << 4) + col;
int off2 = (row << 4) + 128 + col;

if (tid < 32) {


int smemOffset = col * 17 + row;
float4 inData1 = *((float4 *)(input + off1));
float4 inData2 = *((float4 *)(input + off2));

#pragma unroll(4)
for (int i = 0; i < 4; i++) {
*(global_smem + smemOffset + i * 17) = *((float *)&inData1 + i);
}
#pragma unroll(4)
for (int i = 0; i < 4; i++) {
*(global_smem + smemOffset + 8 + i * 17) = *((float *)&inData2 + i);
}
}
__syncthreads();

if (tid < 32) {


int smemOffset = row * 17 + col;
float4 outData1, outData2;
#pragma unroll(4)
for (int i = 0; i < 4; i++) {
*((float *)&outData1 + i) = *(global_smem + smemOffset + i);
}
#pragma unroll(4)
for (int i = 0; i < 4; i++) {
*((float *)&outData2 + i) = *(global_smem + smemOffset + 136 + i);
}

*((float4 *)(output + off1)) = outData1;


*((float4 *)(output + off2)) = outData2;
}
}

As we can see from the code above, for the example of transposing a 16x16 tile with 32 threads, Triton infers that each thread can use ldg.128,
so the CTA has to split the data into two tiles, which are unrolled. Triton holds the intermediate data in shared memory and uses it to
perform the transpose.

Observations:

Triton puts the loaded tile in the register file by default;

Triton puts the tile into shared memory in three situations (see the probe sketch below):

1. The operator requires shared memory: e.g. tl.sum() for reduction;
2. Layout transformation is required: e.g. a transpose operation;
3. The operator requires a shared memory operand: e.g. tl.dot() for GEMM;
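
A quick way to check this from Python is to compare the static shared-memory size of the compiled kernels. This is a sketch under the same Triton 2.x assumption as the earlier sketches (launching returns a compiled handle); on that line the handle exposed the byte count as a shared attribute, which should be treated as a version-specific detail:

def compare_shared_memory():
    # Plain copy kernel from the Basic section: tiles live in registers, no smem.
    n, BLOCK_SIZE_N = 1000, 128
    x1 = torch.randn((n, ), device='cuda')
    y1 = torch.zeros((n, ), device='cuda')
    h_copy = kernel_230423_01[(triton.cdiv(n, BLOCK_SIZE_N), )](x1, y1, n, BLOCK_SIZE_N)

    # Transpose kernel from this section: the layout change goes through smem.
    m = 16
    x2 = torch.randint(10, (m, m), device='cuda').float()
    y2 = torch.empty((m, m), device='cuda')
    h_tr = kernel_230424_05[(1, )](x2, y2, m, num_warps=1)

    # Expect 0 for the copy kernel, and a non-zero value for the transpose kernel,
    # roughly the padded 17*17 float buffer (about 1.1 KB) seen in the CUDA above.
    print(h_copy.shared, h_tr.shared)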

Instruction


Now let’s check if Triton generates fast math instruction with a sigmoid function:

@triton.jit
def kernel_230424_06(
x_ptr,
y_ptr,
N: tl.constexpr, # static shape
):
off = tl.arange(0, N)
x = tl.load(x_ptr + off)
x = tl.sigmoid(x)
tl.store(y_ptr + off, x)

def launch_kernel_230424_06():
n = 32
x = torch.randn((n,), device='cuda')
y = torch.empty((n,), device='cuda')

kernel_230424_06[(1,)](x, y, n, num_warps=1)
assert torch.allclose(x.sigmoid(), y)

Its corresponding CUDA kernel is:

__global__ void kernel_230424_06(float *input, float *output) {


int tid = threadIdx.x;
int idx = tid & 31;

if (tid < 32) {


float inValue = input[idx];
output[idx] = 1.0f / (1.0f + expf(-inValue));
}
}

However, the generated PTX has some differences:

// nvcc, no fast math | triton | nvcc, fast math


ld.global.f32 %f1, [%rd6]; | @%p1 ld.global.b32 {%r1}, [%rd1+0]; | ld.global.f32 %f1, [%rd6];
neg.f32 %f2, %f1; | mov.b32 %f3, %r1; | mul.f32 %f2, %f1, 0fBFB8AA3B;
mov.f32 %f3, 0f3F000000; | mov.f32 %f4, 0f00000000; | ex2.approx.f32 %f3, %f2;
mov.f32 %f4, 0f3BBB989D; | sub.f32 %f5, %f4, %f3; | add.f32 %f4, %f3, 0f3F800000;
fma.rn.f32 %f5, %f2, %f4, %f3; | mul.f32 %f2, %f5, 0f3FB8AA3B; | rcp.approx.f32 %f5, %f4;
mov.f32 %f6, 0f3FB8AA3B; | ex2.approx.f32 %f1, %f2; | cvta.to.global.u64 %rd7, %rd2;
mov.f32 %f7, 0f437C0000; | add.f32 %f6, %f1, 0f3F800000; | add.s64 %rd8, %rd7, %rd5;
cvt.sat.f32.f32 %f8, %f5; | mov.b32 %r4, %f6; | st.global.f32 [%rd8], %f5;
mov.f32 %f9, 0f4B400001; | mov.u32 %r3, 1065353216; |
fma.rm.f32 %f10, %f8, %f7, %f9; | div.full.f32 %r5, %r3, %r4; |
add.f32 %f11, %f10, 0fCB40007F; | add.s64 %rd2, %rd4, %rd5; |
neg.f32 %f12, %f11; | @%p1 st.global.b32 [%rd2+0], {%r5}; |
fma.rn.f32 %f13, %f2, %f6, %f12; |
mov.f32 %f14, 0f32A57060; |

fma.rn.f32 %f15, %f2, %f14, %f13; |
mov.b32 %r3, %f10; |
shl.b32 %r4, %r3, 23; |
mov.b32 %f16, %r4; |
ex2.approx.ftz.f32 %f17, %f15; |
fma.rn.f32 %f18, %f17, %f16, 0f3F800000; |
rcp.rn.f32 %f19, %f18; |
cvta.to.global.u64 %rd7, %rd2; |
add.s64 %rd8, %rd7, %rd5; |
st.global.f32 [%rd8], %f19; |

Observations:

Triton uses fast math for exp() (ex2.approx), but doesn't use fast math for the reciprocal (div.full.f32 instead of rcp.approx);

Tensor Core

Finally, let's look at a minimal GEMM written with tl.dot, which Triton lowers to Tensor Core instructions for fp16 inputs:

@triton.jit
def kernel_230429_01(
a_ptr,
b_ptr,
d_ptr,
M: tl.constexpr,
N: tl.constexpr,
K: tl.constexpr,
):
offs_m = tl.arange(0, M)
offs_n = tl.arange(0, N)
offs_k = tl.arange(0, K)
a_ptrs = a_ptr + (offs_m[:, None] * K + offs_k[None, :])
b_ptrs = b_ptr + (offs_k[:, None] * N + offs_n[None, :])
d_ptrs = d_ptr + (offs_m[:, None] * N + offs_n[None, :])
a = tl.load(a_ptrs)
b = tl.load(b_ptrs)
d = tl.dot(a, b)
tl.store(d_ptrs, d)

def launch_kernel_230429_01():
m, n, k = 16, 16, 16
a = torch.randn((m, k), device='cuda', dtype=torch.float16)
b = torch.randn((k, n), device='cuda', dtype=torch.float16)
d = torch.empty((m, n), device='cuda', dtype=torch.float16)

kernel_230429_01[(1,)](a, b, d, m, n, k, num_warps=1)
assert torch.allclose(a @ b, d, atol=1e-2, rtol=0)
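
To verify that tl.dot is indeed lowered to Tensor Core MMA instructions rather than plain FFMA, the PTX can be inspected the same way as in the earlier sketches (again assuming the Triton 2.x behavior where the launch returns the compiled handle):

def check_tensor_core_usage():
    m, n, k = 16, 16, 16
    a = torch.randn((m, k), device='cuda', dtype=torch.float16)
    b = torch.randn((k, n), device='cuda', dtype=torch.float16)
    d = torch.empty((m, n), device='cuda', dtype=torch.float16)
    handle = kernel_230429_01[(1,)](a, b, d, m, n, k, num_warps=1)
    # For fp16 inputs we expect MMA instructions, e.g. mma.sync.aligned.m16n8k16,
    # in the generated PTX.
    assert 'mma' in handle.asm['ptx']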


Comments

郭大瘦: Hi, could you share how you reverse engineered the PTX?

Wil Kong: I give the PTX to ChatGPT and let it translate it into CUDA. There are a few small mistakes, which I then correct by hand. For PTX that isn't too long, this generally works fine.

郭大瘦: Got it, I'll give GPT-4 a try, thanks!


© 2019 - 2024 Fei Kong · Powered by Hugo & Coder.

