
General Guide of AMD Triton Performance Optimization

Vinayak Gokhale edited this page Dec 18, 2024 · 27 revisions

This document introduces the general steps for Triton kernel optimization. Overall, Triton kernel optimization is similar to CUDA/HIP kernel optimization. It includes the following aspects:

Hardware resource utilization

On AMD GPUs, a kernel is launched as a grid of workgroups. Each workgroup is made up of waves (wavefronts), and each wave is made up of threads. Each GPU has many Compute Units (CUs), and different CUs compute in parallel. One or more workgroups can be assigned to a CU. A larger grid is typically better, provided the subproblem processed by each workgroup is not too small. To increase hardware utilization, more parallelism generally needs to be exposed in the algorithm.

Hardware resources can be queried with the rocminfo command (located in /opt/rocm/bin). For instance, one can query the number of compute units, the number of SIMDs, and the wavefront size as follows:

  • rocminfo | grep "Compute Unit"
  • rocminfo | grep "SIMD"
  • rocminfo | grep "Wavefront Size"

MI300X has 304 CUs, 4 SIMDs per CU, and a wavefront (warp) size of 64.
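If you want these values programmatically, for example to plug the CU count into the block-size analysis below, one option is to parse the rocminfo output in Python. This is a minimal sketch: the "Key:   value" layout and the field names are assumed from typical rocminfo output, and parse_rocminfo and gpu_agents are hypothetical helpers, not part of ROCm.

```python
import re
import shutil
import subprocess

# Fields of interest. The "Key:   value" layout is assumed from typical
# rocminfo output, e.g. "  Compute Unit:            304".
_KEYS = ("Device Type", "Compute Unit", "SIMDs per CU", "Wavefront Size")
_LINE = re.compile(r"^\s*(" + "|".join(map(re.escape, _KEYS)) + r"):\s*(\S+)")


def parse_rocminfo(text):
    """Return one dict per agent (CPU or GPU) holding the fields in _KEYS.

    A new agent starts at each "Device Type:" line; the CU, SIMD and
    wavefront fields printed after it are attributed to that agent.
    """
    agents = []
    for line in text.splitlines():
        m = _LINE.match(line)
        if not m:
            continue
        key, value = m.groups()
        if key == "Device Type":
            agents.append({key: value})
        elif agents:
            # Wavefront Size is printed like "64(0x40)": keep the decimal part.
            num = re.match(r"\d+", value)
            if num:
                agents[-1][key] = int(num.group())
    return agents


def gpu_agents(text):
    """Keep only GPU agents; rocminfo also lists CPU agents."""
    return [a for a in parse_rocminfo(text) if a["Device Type"] == "GPU"]


if __name__ == "__main__" and shutil.which("rocminfo"):
    out = subprocess.run(["rocminfo"], capture_output=True, text=True).stdout
    for agent in gpu_agents(out):
        print(agent)
```

Splitting on "Device Type:" rather than taking the first "Compute Unit" match matters because the CPU agent is usually listed first and reports its own core count under the same key.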

Autotunable kernel configurations

This section goes over the kernel arguments that can be controlled by the user to improve performance of their kernel.

Software Pipelining

Software pipelining can improve performance on a variety of workloads by overlapping memory access with compute.

You can control pipelining behavior using the num_stages kernel argument:

  1. In tl.range of a for loop, which applies to that specific loop:
for k in tl.range(start, end, step, num_stages=2):
    ...
  2. Or via triton.Config. This sets num_stages for the whole function and applies to every for loop in the function that contains a tl.load -> tl.dot pattern.
triton.Config(..., num_stages=2),

When targeting AMD GPUs, Triton employs a stream pipelining approach that streams data through register buffers (instead of directly to shared memory). Depending on the workload, stream pipelining may also allocate one or more buffers in shared memory.

Guidelines

The usage of num_stages depends on the type of kernel. When num_stages=1 no pipelining will occur, irrespective of the type of the kernel.

GEMM kernels

  • GEMM kernels that have direct loads will benefit from using num_stages=2. This can be specified as a kernel argument, or provided in the triton.Config list as part of autotuning. Direct loads are those whose input pointers are based on offsets from a base pointer that is provided as a kernel argument. For example, this is a direct load.

  • GEMM kernels with indirect loads will benefit from using num_stages=3. This is an example of an indirect load, because the offsets cannot be derived from the kernel arguments directly and must first be loaded from global memory.

One may use larger values for num_stages. However, because loads target registers (instead of shared memory), larger values can cause register spills. Typically, values of 2 and 3, chosen according to the rules above, work well to hide global load latency.

Fused GEMM kernels

  • Currently, it is recommended to set num_stages=1 for fused GEMM kernels that are compute bound (for example, flash attention kernels used in prefill). Using more pipeline stages typically results in register spills as fused GEMM kernels have at least three inputs.
  • For memory bound fused GEMM kernels, num_stages=2 will help with pipelining. Kernels like paged attention that perform indirect loads will need to use num_stages=3 similar to GEMM kernels.

If performance is suboptimal when using num_stages > 1, please inspect the assembly for register spills. If spills occur, please do not use pipelining.
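The GEMM and fused GEMM rules above can be condensed into a starting-point helper. suggested_num_stages is a hypothetical function written for illustration, and its kind labels are made up for this sketch; the value it returns is only a first guess that still needs to be checked against the assembly for register spills.

```python
def suggested_num_stages(kind, indirect_loads=False, compute_bound=False):
    """First num_stages value to try, following the guidelines above.

    kind: "gemm" or "fused_gemm" (illustrative labels for this sketch).
    If the assembly shows register spills, fall back to num_stages=1.
    """
    if kind == "gemm":
        # Direct loads -> 2, indirect loads -> 3.
        return 3 if indirect_loads else 2
    if kind == "fused_gemm":
        if compute_bound:
            # e.g. prefill flash attention: extra stages tend to spill.
            return 1
        # Memory bound: 2, or 3 with indirect loads (e.g. paged attention).
        return 3 if indirect_loads else 2
    raise ValueError(f"unknown kernel kind: {kind!r}")
```

Elementwise kernels are deliberately left out, since the guidelines do not give a default value for them.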

Elementwise kernels

A third category of kernels is elementwise kernels, such as layernorm or softmax, which contain no tl.dot. Software pipelining works for these kernels, but Triton's rules for pipelining non-GEMM kernels are slightly different.

First, the loop to be pipelined must be the innermost loop. Persistent kernels may have outer for or while loops, and pipelining does not work across these.

Second, the loop must use tl.range. Pipelining can then be specified as:

for i in tl.range(start, end, step, num_stages=<n>):
    ...

The value passed to the num_stages argument of tl.range can come from a kernel argument or from autotuning. However, unlike in GEMM kernels, a num_stages value set in triton.Config is not applied automatically; it takes effect only when it is passed to tl.range.

Waves per EU

This can be specified by the kernel argument waves_per_eu=n. This is a hint to the compiler to reduce vector general purpose register (VGPR) utilization to a level such that occupancy = n could be achieved. This only helps if both of the following are satisfied:

  1. The occupancy of the kernel is limited by VGPR usage. See the Appendix section "Understand/Compute the occupancy of the kernel" for how to compute occupancy.

  2. The current VGPR usage is only slightly above a boundary in Table 1 of the AMD lab notes.

For example, according to Table 1 in the AMD lab notes, 512 VGPRs are available per Execution Unit (EU), and VGPRs are allocated in units of 16. If the current VGPR usage is 170, the actual allocation will be 176 VGPRs. This limits occupancy to at most 2 waves per EU, because 176 x 2 = 352 is less than 512, but 176 x 3 = 528 is larger than 512. If we set waves_per_eu to 3, the LLVM backend will try to bring VGPR usage down so that 3 waves can fit per EU. However, this does not always work, because some registers need long live ranges and cannot be reduced.
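The VGPR arithmetic in this example can be sketched in Python. The 512 VGPRs per EU and the allocation granularity of 16 come from the example above; the cap of 8 waves per EU is an assumption, so check Table 1 for your GPU. The function names are hypothetical.

```python
import math

VGPRS_PER_EU = 512     # from the example above
VGPR_GRANULE = 16      # allocation granularity used in the example above
MAX_WAVES_PER_EU = 8   # assumed hardware cap; verify against Table 1


def allocated_vgprs(used):
    """Round VGPR usage up to the allocation granularity."""
    return math.ceil(max(1, used) / VGPR_GRANULE) * VGPR_GRANULE


def waves_per_eu_from_vgprs(used):
    """Maximum waves per EU allowed by VGPR usage alone."""
    return min(MAX_WAVES_PER_EU, VGPRS_PER_EU // allocated_vgprs(used))


def vgpr_budget_for(waves):
    """Largest VGPR count that still allows `waves` waves per EU."""
    return (VGPRS_PER_EU // waves) // VGPR_GRANULE * VGPR_GRANULE


print(allocated_vgprs(170), waves_per_eu_from_vgprs(170), vgpr_budget_for(3))
```

With these constants, 170 VGPRs round up to 176 and allow 2 waves per EU, and vgpr_budget_for(3) returns 160, so waves_per_eu=3 asks the compiler to bring usage from 170 down to at most 160 VGPRs.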

Block sizes

This is also often referred to as tile sizes, although the BLOCK nomenclature is more common in Triton. We assume the reader is familiar with how GEMM computation is tiled on GPUs. BLOCK_M and BLOCK_N are the block sizes in the M and N dimensions. They control the size of the output tile computed by a workgroup.

We want tiles large enough to maximize the ratio of computation to memory traffic, but small enough to produce enough workgroups (WGs) to occupy the whole GPU. Consider a GEMM of shape 4096 x 4096. If we use a block size of 256 x 256, we will have only 16 blocks in each of the M and N dimensions, for a total of 256 blocks. This is less than 304, the number of CUs on an MI300X, so this tile size results in only 84% utilization. A block size of 128 x 128 results in 1024 blocks. However, with 304 CUs, the utilization is (1024/304) / ceil(1024/304), which is also 84%. 128 x 256 and 256 x 128 result in the same utilization. Block sizes of 64 x 64 and 128 x 64 give higher utilization: about 96% in both cases, since (4096/304) / 14 and (2048/304) / 7 both evaluate to 0.962. This kind of analysis should be done when setting values for BLOCK_M and BLOCK_N in the autotune configs list.
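The utilization figures above follow from a short calculation. This sketch uses the same formula as the text and, like that analysis, assumes one workgroup per CU at a time and 304 CUs; grid_utilization is a hypothetical helper.

```python
import math


def grid_utilization(M, N, block_m, block_n, num_cus=304):
    """(tiles / CUs) / ceil(tiles / CUs), as in the analysis above.

    Assumes one workgroup per CU at a time, so every round of tiles except
    possibly the last keeps all CUs busy.
    """
    tiles = math.ceil(M / block_m) * math.ceil(N / block_n)
    rounds = tiles / num_cus
    return rounds / math.ceil(rounds)


for bm, bn in [(256, 256), (128, 128), (128, 256), (64, 64), (128, 64)]:
    print(f"{bm}x{bn}: {grid_utilization(4096, 4096, bm, bn):.1%}")
```

Running it prints 84.2% for the first three block sizes and 96.2% for the last two. Other shapes or GPUs only need different M, N, or num_cus arguments.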

The block size along the K dimension does not directly affect utilization, as K is the reduction dimension. However, an optimal K block size is one whose loads cover 512 contiguous bytes of data. For fp16 or bf16 data types, this is a BLOCK_K of 256. However, this typically makes the tile too large, so sizes of 64 or 128 are commonly used.

Elementwise kernels are typically memory bound. For these, we want at least 16-32 KiB of global loads in flight. With 2-byte data types, this implies a BLOCK_SIZE of 8192 or 16384. As an aside, it is recommended that elementwise kernels perform their computation in the tl.float32 datatype.
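The BLOCK_K and BLOCK_SIZE figures above are simple byte-count divisions. A minimal sketch with hypothetical helper names makes the dependence on the element size explicit:

```python
KiB = 1024


def block_k_for_contiguous_bytes(dtype_bytes, target_bytes=512):
    """BLOCK_K whose K elements span `target_bytes` contiguous bytes."""
    return target_bytes // dtype_bytes


def elementwise_block_size(bytes_in_flight, dtype_bytes):
    """BLOCK_SIZE needed to keep `bytes_in_flight` of global loads outstanding."""
    return bytes_in_flight // dtype_bytes


print(block_k_for_contiguous_bytes(2))  # fp16 / bf16
print(elementwise_block_size(16 * KiB, 2), elementwise_block_size(32 * KiB, 2))
```

For 1-byte types such as fp8, the same 512-byte target would correspond to a BLOCK_K of 512.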

MFMA instruction type

Triton supports two MFMA instructions for the bf16 and fp16 datatypes, and two for the fp8 and bf8 datatypes. The instruction is selected with the matrix_instr_nonkdim kernel argument. For GEMM kernels on MI300X, we found that setting this parameter to 16 gives better performance, while for fused GEMM kernels, 32 gives better performance. The instructions are:

  1. v_mfma_f32_32x32x8_{f16,bf16} and v_mfma_f32_32x32x16_{fp8,bf8}
  2. v_mfma_f32_16x16x16_{f16,bf16} and v_mfma_f32_16x16x32_{fp8,bf8}

The dimensions are listed as MxNxK, so "nonkdim" is 32 for (1) and 16 for (2).

kpack

This is a kernel argument that optimizes shared memory accesses. It should be set to 2 for GEMMs and left unset for other kernels.

Memory access efficiency

A GPU has global memory, the local data share (LDS, i.e., shared memory), and registers. Global memory is large but has high access latency. LDS has much lower latency but is small. Registers are the fastest and smallest of the three. Generally, we want data in global memory to be loaded and stored as few times as possible.

IR analysis

In Triton, there are different layouts, including the blocked, shared, sliced, and MFMA layouts. From the Triton GPU IR, we can tell in which memory each computation is performed. Here is a snippet of IR from the Flash Attention (FA) decode int4 KV program. It dequantizes the int4 KV data to fp16.

%190 = tt.load %189 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1x64xi32, #blocked6> loc(#loc159)
%266 = arith.andi %190, %cst_28 : tensor<1x64xi32, #blocked6> loc(#loc250)
%267 = arith.trunci %266 : tensor<1x64xi32, #blocked6> to tensor<1x64xi16, #blocked6> loc(#loc251)
%268 = tt.bitcast %267 : tensor<1x64xi16, #blocked6> -> tensor<1x64xf16, #blocked6> loc(#loc252)
%269 = triton_gpu.convert_layout %268 : (tensor<1x64xf16, #blocked6>) -> tensor<1x64xf16, #shared1> loc(#loc252)
%270 = tt.trans %269 : (tensor<1x64xf16, #shared1>) -> tensor<64x1xf16, #shared2> loc(#loc194)
%276 = triton_gpu.convert_layout %270 : (tensor<64x1xf16, #shared2>) -> tensor<64x1xf16, #blocked5> loc(#loc254)
%293 = arith.mulf %276, %cst_30 : tensor<64x1xf16, #blocked5> loc(#loc254)
%295 = arith.mulf %292, %294 : tensor<64x32xf16, #blocked5> loc(#loc264)
%297 = arith.addf %295, %296 : tensor<64x32xf16, #blocked5> loc(#loc255)
%298 = triton_gpu.convert_layout %297 : (tensor<64x32xf16, #blocked5>) -> tensor<64x32xf16, #shared1> loc(#loc255)
%299 = tt.trans %298 : (tensor<64x32xf16, #shared1>) -> tensor<32x64xf16, #shared2> loc(#loc196)
%300 = triton_gpu.convert_layout %299 : (tensor<32x64xf16, #shared2>) -> tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> loc(#loc197)

From this IR, we can see that i32 data is loaded from global memory into registers. After a few elementwise operations in registers, the data is stored to shared memory for the transpose operation, which requires data movement across threads. Once the transpose is done, the data is loaded from LDS back into registers, and after a few more elementwise operations, it is stored to LDS again. In the last step, it is loaded from LDS into registers and converted to the dot operand layout. The IR therefore uses LDS twice: once for the transpose, and once to convert the blocked layout to the dot operand layout. However, the transposes do not add LDS accesses of their own. The conversion from shared1 to shared2 is a logical conversion and does not generate additional LDS accesses. Instead, it lowers to LDS address calculations so that the subsequent LDS reads read the block transposed instead of in its normal order. These address calculations would be needed even for a normal read, just with different addresses.

A note on transpose operations: sometimes transposes can be absorbed into existing layout conversion operations. In general, however, avoid explicit transposes in user-level code where possible.

Assembly analysis

  • In the ISA, make sure global_load_dwordx4 is used, especially when the load happens in the loop.
  • In most cases, the LDS read and write should use _b128 as well to minimize the number of LDS access instructions. _b64 instructions are also okay.
  • The AMD ISA has the s_waitcnt instruction to synchronize memory accesses with the computations that depend on them. In the Triton context, s_waitcnt typically uses two counters:
  1. lgkmcnt(n): lgkm stands for LDS, GDS, Constant and Message. In our context, it usually relates to LDS accesses. n is the number of such accesses that may still be outstanding when execution continues. For example, 0 means all lgkm accesses must finish before continuing, and 1 means at most one lgkm access may still be in flight.
  2. vmcnt(n): vm stands for vector memory. This counter tracks vector memory accesses, e.g., global loads from global memory into vector registers. n has the same meaning as above.

The general guideline is:

  1. Vectorize memory access as much as possible - loads should be dwordx4 and ds reads/writes should be b64 or b128.
  2. Ensure synchronization is done efficiently. We don't want to be waiting for LDS or global loads if other instructions could be executed before the barrier.
  3. Overlap instructions to hide latency; this requires careful analysis of the algorithm.

Issues here typically cannot be solved by the user and must be fixed in the compiler or the LLVM backend. If you spot such issues, please contact the Triton team so they can be investigated. Because of how Triton works, a fix for your issue may also improve many other kernels for many other users.
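A quick way to apply these checks to a dumped .amdgcn file (see "Generating assembly" in the Appendix) is to count the relevant instruction families. This is a rough text scan, not a disassembler: it assumes gfx9-style mnemonics (global_load_*, ds_read_*/ds_write_*) and ";" comments, and scan_amdgcn is a hypothetical helper.

```python
import re
from collections import Counter

# Instruction families called out in the guidelines above (gfx9-style names).
_PATTERNS = {
    "global_load_dwordx4": re.compile(r"\bglobal_load_dwordx4\b"),
    "global_load (other widths)": re.compile(r"\bglobal_load_(?!dwordx4\b)\w+"),
    "ds_read/write_b128": re.compile(r"\bds_(?:read|write)_b128\b"),
    "ds_read/write_b64": re.compile(r"\bds_(?:read|write)_b64\b"),
    "ds_read/write (narrower)": re.compile(r"\bds_(?:read|write)_(?!b128\b|b64\b)\w+"),
    "s_waitcnt vmcnt(0)": re.compile(r"\bs_waitcnt\b.*\bvmcnt\(0\)"),
    "s_waitcnt lgkmcnt(0)": re.compile(r"\bs_waitcnt\b.*\blgkmcnt\(0\)"),
}


def scan_amdgcn(asm_text):
    """Count lines matching each instruction family in AMDGCN assembly text."""
    counts = Counter()
    for line in asm_text.splitlines():
        code = line.split(";", 1)[0]  # drop trailing assembly comments
        for name, pattern in _PATTERNS.items():
            if pattern.search(code):
                counts[name] += 1
    return counts
```

For example, scan_amdgcn(open("asm.amdgcn").read()) on a GEMM kernel would ideally show mostly global_load_dwordx4 and b128/b64 LDS accesses; a large number of narrower accesses, or many full waits inside the main loop, is a hint to look at the assembly more closely.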

Tools (rocprof, omniperf)

PyTorch inductor Triton tuning knobs

  • Lowering GEMMs/convolutions to Triton requires inductor’s max_autotune mode. This benchmarks a static list of Triton configs (conv configs for max autotune + matmul configs for max autotune) and uses the fastest for each shape. (Note: if regular MIOpen/rocBLAS is faster for a specific operation, Triton will not be used.)

    1. torch._inductor.config.max_autotune = True or TORCHINDUCTOR_MAX_AUTOTUNE=1
    2. Or, for more fine-grained control:
        1. torch._inductor.config.max_autotune_pointwise = True - to enable tuning for pointwise/reduction ops
        2. torch._inductor.config.max_autotune_gemm = True - to enable tuning/lowering of mm/convs
        3. torch._inductor.config.max_autotune_gemm_backends/TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS - selects the candidate backends for mm autotuning. Defaults to “TRITON,ATEN”; on NVIDIA GPUs, CUTLASS is also an option. Limiting this to “TRITON” may improve performance by enabling more fused mm kernels instead of going to rocBLAS.
  • For further mm tuning, coordinate descent tuning may improve performance:
    1. torch._inductor.config.coordinate_descent_tuning=True/TORCHINDUCTOR_COORDINATE_DESCENT_TUNING=1
  • Inference can see large improvements on AMD by utilising torch._inductor.config.freezing=True/TORCHINDUCTOR_FREEZING=1, which inlines weights as constants and enables constant folding optimisations.
  • Enabling inductor’s cpp_wrapper may reduce launch overhead. It generates C++ code that launches Triton binaries directly with hipModuleLaunchKernel and relies on hipification. (Note: we are still failing a few tests related to this feature.) - torch._inductor.config.cpp_wrapper=True/TORCHINDUCTOR_CPP_WRAPPER=1
  • For NHWC convolution workloads, torch._inductor.config.layout_optimization=True/TORCHINDUCTOR_LAYOUT_OPTIMIZATION=1 can help by enforcing the channels_last format throughout the graph, avoiding additional transposes added by inductor. (Note: PYTORCH_MIOPEN_SUGGEST_NHWC=1 is recommended if using this.)
  • To extract a Triton kernel, set TORCH_COMPILE_DEBUG=1. This creates a torch_compile_debug/ directory in the current path, and the code strings for the Triton kernels are defined in output_code.py. Manual work is then required to strip out the kernel and set up its compilation/launch via Triton.
  • For advanced matmul/conv config tuning, the inductor-gemm-tuner can help. It implements the Triton conv/mm kernels used upstream and allows you to specify the inputs and the config search space; any new tunings found can be added to the autotune list. More work is needed on parsing the results of this tuning.
    1. Example used for resnet152: HIP_FORCE_DEV_KERNARG=1 MIOPEN_FIND_ENFORCE=4 MIOPEN_FIND_MODE=1 TORCHINDUCTOR_COMPILE_THREADS=1 python bench.py --fp16 --kernel conv --input_file=input/models/resnet152/conv_perf_drop.json --config_file=configs/models/resnet152/tuned.json
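When these knobs are set from a Python script rather than the shell, a minimal sketch is to export the environment variables before torch is imported, so they are visible no matter when inductor reads them. Which knobs help is workload dependent; the selection below is only illustrative, and model in the trailing comment is a placeholder.

```python
import os

# Set inductor knobs before importing torch so that torch._inductor.config
# sees them regardless of when it reads its TORCHINDUCTOR_* variables.
os.environ.update({
    "TORCHINDUCTOR_MAX_AUTOTUNE": "1",               # Triton lowering of mm/conv
    "TORCHINDUCTOR_COORDINATE_DESCENT_TUNING": "1",  # extra mm tuning
    "TORCHINDUCTOR_FREEZING": "1",                   # inference: inline weights as constants
    # Optional: consider only Triton candidates for mm autotuning.
    # "TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS": "TRITON",
})

# import torch
# compiled = torch.compile(model, mode="max-autotune")
```

Setting the torch._inductor.config attributes after import is the in-code alternative listed above; the environment-variable route is convenient when the same script is launched with different settings.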

Debugging Memory Access Faults

Identifying the faulting kernel is often enough to triage a memory access fault. To that end, the ROCm debug agent can trap a memory access fault and provide a dump of all active wavefronts that caused the error, as well as the name of the kernel. The README provides full instructions, but in summary:

  1. Compiling with -ggdb -O0 is recommended but not required.
  2. HSA_TOOLS_LIB=/opt/rocm/lib/librocm-debug-agent.so.2 HSA_ENABLE_DEBUG=1 ./my_program

When the debug agent traps the fault, it will produce extremely verbose output of all wavefront registers and memory content. Importantly, it also prints something like:

Disassembly for function vector_add_assert_trap(int*, int*, int*):
code object: file:////rocm-debug-agent/build/test/rocm-debug-agent-test#offset=14309&size=31336
loaded at: [0x7fd4f100c000-0x7fd4f100e070]

The kernel name and the code object file should be listed. In the example above, the kernel name is vector_add_assert_trap, but this might also look like:

Disassembly for function memory:///path/to/codeobject#offset=1234&size=567:

In this case, it is an in-memory kernel that was generated at runtime. Using the env var ROCM_DEBUG_AGENT_OPTIONS="--all --save-code-objects" will have the debug agent save all code objects to the current directory (use --save-code-objects=[DIR] to place them in another location). The code objects will be renamed from the URI format with special characters replaced by ‘_’. Use llvm-objdump to disassemble the indicated in-memory code object that has now been saved to disk. The name of the kernel is often found inside the disassembled code object.

llvm-objdump --disassemble-all path/to/code-object.co
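Because the debug agent output is very verbose, it can help to extract just the function names and code-object URIs. This sketch assumes the line formats shown in the examples above; faulting_kernels is a hypothetical helper, and other debug-agent versions may format these lines differently.

```python
import re

_FUNC = re.compile(r"^Disassembly for function (.+):\s*$")
_CODE_OBJ = re.compile(r"^code object:\s*(\S+)")


def faulting_kernels(agent_output):
    """Yield (function, code_object_uri) pairs; the URI is None if not printed."""
    func = None
    for raw in agent_output.splitlines():
        line = raw.strip()
        m = _FUNC.match(line)
        if m:
            if func is not None:
                yield func, None  # previous function had no code-object line
            func = m.group(1)
            continue
        m = _CODE_OBJ.match(line)
        if m and func is not None:
            yield func, m.group(1)
            func = None
    if func is not None:
        yield func, None


def is_in_memory(function_name):
    """True for runtime-generated kernels reported by a memory:// URI."""
    return function_name.startswith("memory://")
```

An entry for which is_in_memory returns True is the case where saving code objects with --save-code-objects and disassembling them with llvm-objdump is needed to recover the kernel name.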

It is recommended to disable various memory caching strategies both within the ROCm stack and PyTorch, where possible. This will give the debug agent the best chance at finding the memory fault where it originates, otherwise it could be masked by writing past the end of a cached block within a larger allocation.

PYTORCH_NO_HIP_MEMORY_CACHING=1
HSA_DISABLE_FRAGMENT_ALLOCATOR=1

Miscellaneous

a. For performance-critical workloads, HIP provides the environment variable HIP_FORCE_DEV_KERNARG (export HIP_FORCE_DEV_KERNARG=1), which places the arguments of HIP kernels directly in device memory to reduce the latency of accessing them. This can save 2 to 3 µs for some kernels.

b. Set the clock for determinism. Use the command rocm-smi --setperfdeterminism 1900 to cap the maximum clock speed at 1900 MHz instead of the default 2100 MHz. A lower cap reduces the chance that the clock speed drops because of high chip temperature. This setting can be restored to default with rocm-smi -r.

c. Disable NUMA auto-balancing. Run cat /proc/sys/kernel/numa_balancing to check the current setting. An output of 0 indicates that auto-balancing is already disabled. If there is no output or the output is 1, run sudo sh -c 'echo 0 > /proc/sys/kernel/numa_balancing' to disable it.

Usage: ./env_check.sh [set/reset/check] (use ./env_check.sh -h for help info)

Script contents (you can also download it as env_check.sh):

#!/bin/bash

function print_usage {
	echo "    Usage: env_check.sh set/reset/check"
	echo "                      set: configure the settings in this script"
	echo "                      reset: reset to default settings"
	echo "                      check: check the current settings"
}

function set_env {
	# export only lasts for this script's process; also export it in your shell
	export HIP_FORCE_DEV_KERNARG=1
	rocm-smi --setperfdeterminism 1900
	# quote the command so the redirection itself runs as root
	sudo sh -c 'echo 0 > /proc/sys/kernel/numa_balancing'
	sudo cpupower frequency-set -r -g performance
	sudo cpupower idle-set -d 1
}

function reset_env {
	unset HIP_FORCE_DEV_KERNARG
	rocm-smi -r
	sudo sh -c 'echo 1 > /proc/sys/kernel/numa_balancing'
}

function check_env {
	echo ""
	echo "---------------------------------------------------------------"
	echo ""

	# check the flag that forces kernel args into device memory
	echo "1. Check forcing kernel args on device memory"
	dev_kernarg=$(env | grep HIP_FORCE_DEV_KERNARG)
	if [ -z "$dev_kernarg" ]
	then
		echo "  no setting for forcing kernel args on device memory"
		echo "  run the command \"export HIP_FORCE_DEV_KERNARG=1\" to force it"
	else
		echo "  env var \"HIP_FORCE_DEV_KERNARG\" for forcing kernel args on device"
		echo "  memory is set, we have HIP_FORCE_DEV_KERNARG=$HIP_FORCE_DEV_KERNARG"
		if [ "$HIP_FORCE_DEV_KERNARG" -eq 0 ]
		then
			echo "  env var HIP_FORCE_DEV_KERNARG is 0, set it to 1 by:"
			echo "  command \"export HIP_FORCE_DEV_KERNARG=1\""
		fi
	fi

	echo ""
	echo ""
	echo "2. Set perfdeterminism, highest frequency"
	echo "  run the command \"rocm-smi -a | grep sclk\" to check highest frequency."
	echo "  you can run the command \"rocm-smi --setperfdeterminism # (e.g. 1900)\" to"
	echo "  cap the clock frequency; performance is lower but more reproducible"
	echo "  you can restore the setting by running \"rocm-smi --resetperfdeterminism\""

	echo ""
	echo ""
	echo "3. Check numa autobalance"
	# a missing or empty value is treated as "not disabled"
	autobal=$(cat /proc/sys/kernel/numa_balancing 2>/dev/null)
	if [ "${autobal:-1}" -ne 0 ]
	then
		echo "  run the command \"sudo sh -c 'echo 0 > /proc/sys/kernel/numa_balancing'\""
		echo "  to disable numa autobalance."
		echo "  you can re-enable it with \"sudo sh -c 'echo 1 > /proc/sys/kernel/numa_balancing'\""
	else
		echo "  numa autobalance is disabled:"
		echo "  (cat /proc/sys/kernel/numa_balancing)=0"
	fi

	echo ""
	echo "---------------------------------------------------------------"
	echo ""
}


if [ $# -eq 0 ]
then
	echo "   \"env_check.sh -h\" for help info"
	print_usage
	exit 1
fi

case "$1" in
	set) set_env ;;
	reset) reset_env ;;
	check) check_env ;;
	*) print_usage ;;
esac

Appendix

Generating assembly

If you wish to inspect the static assembly for your kernel, you can use the AMDGCN_ENABLE_DUMP=1 environment variable. You may want to redirect the output to a text file like so:

AMDGCN_ENABLE_DUMP=1 python <kernel_file>.py > asm.amdgcn 2>&1

Otherwise, printing to the terminal can take a long time, and the output is hard to inspect.

One kernel is compiled per autotune config per run. Typically, we only care about the assembly of the kernel that was actually used, i.e., the autotuning winner. To identify it, first run

TRITON_PRINT_AUTOTUNING=1 python <kernel_file>.py

This will print the winning config. You may then comment out all other configs and generate the assembly as above.

Assembly is also available in the Triton cache, which defaults to ~/.triton/cache (/root/.triton/cache when running as root). Each compiled kernel is stored in a folder with an alphanumeric name, so you have to navigate to the folder generated for your kernel; deleting the cache before each run makes this easier.
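Instead of navigating the cache by hand, a small script can collect, for every cached kernel, the numbers used in the occupancy calculation further down this Appendix. This is a sketch under assumptions: it expects each cache folder to hold <kernel>.amdgcn next to <kernel>.ttgir, reads .vgpr_count from the assembly and the shared/num-warps module attributes from the TTGIR (accepting both the ttg. prefix and the older triton_gpu. prefix), and kernel_stats is a hypothetical helper.

```python
import re
from pathlib import Path

# Module attributes in the TTGIR header; older Triton uses "triton_gpu.".
_IR_ATTRS = {
    "shared": re.compile(r'"?(?:ttg|triton_gpu)\.shared"?\s*=\s*(\d+)'),
    "num_warps": re.compile(r'"?(?:ttg|triton_gpu)\.num-warps"?\s*=\s*(\d+)'),
}
_VGPR = re.compile(r"\.vgpr_count:\s+(\d+)")


def kernel_stats(cache_dir=None):
    """Map "<hash folder>/<kernel>" to its VGPR count, LDS bytes and num-warps."""
    root = Path(cache_dir) if cache_dir else Path.home() / ".triton" / "cache"
    stats = {}
    for asm in sorted(root.rglob("*.amdgcn")):
        entry = {}
        m = _VGPR.search(asm.read_text())
        if m:
            entry["vgpr_count"] = int(m.group(1))
        ttgir = asm.with_suffix(".ttgir")
        if ttgir.exists():
            ir = ttgir.read_text()
            for key, pattern in _IR_ATTRS.items():
                m = pattern.search(ir)
                if m:
                    entry[key] = int(m.group(1))
        # Key by folder too: autotune configs share a kernel name.
        stats[asm.relative_to(root).with_suffix("").as_posix()] = entry
    return stats


if __name__ == "__main__":
    for name, entry in kernel_stats().items():
        print(name, entry)
```

Keying results by hash folder as well as kernel name keeps the per-config variants produced by autotuning apart, which helps when comparing the winner against the other configs.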

Generating the Triton IRs

This can be done using the MLIR_ENABLE_DUMP=1 environment variable. Similar to assembly (see above), this will also be generated per compiled kernel, and can also be found in the Triton cache.

Understand/Compute the occupancy of the kernel

  • Get the VGPR count, search for .vgpr_count in the ISA (see Appendix on how to generate the assembly). Let's say the VGPR count is N.
  • Get the allocated LDS, search for ttg.shared in the IR (see Appendix on how to generate IR). Let's say this value is L bytes.
  • Get number of waves per workgroup, search for ttg.num-warps in the IR. Let's say this value is nW waves per workgroup.
  • Compute the occupancy limited by VGPR usage from N according to Table 1 in the AMD lab notes. Call the resulting waves per EU occ_vgpr.
  • Compute occupancy limited by LDS based on L by: occ_lds = floor(65536 / L).

Then the occupancy is occ = min(floor(occ_vgpr * 4 / nW), occ_lds) * nW / 4

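  1. occ_vgpr * 4 gives the total number of waves on all 4 EUs (SIMDs) per CU.
  2. floor(occ_vgpr * 4 / nW) gives the number of workgroups per CU allowed by VGPR usage.
  3. occ_lds gives the number of workgroups per CU allowed by LDS usage.

The true workgroup occupancy is the minimum of the two, and multiplying it by nW / 4 converts it back to waves per EU. The above logic is available in occ.sh.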
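The occupancy formula can also be written as a small function. It takes occ_vgpr as an input (read it from Table 1 for your VGPR count) and assumes 64 KiB of LDS and 4 EUs per CU, matching the formula; occupancy is a hypothetical helper name.

```python
import math

LDS_BYTES_PER_CU = 65536  # 64 KiB, as in occ_lds = floor(65536 / L)
EUS_PER_CU = 4            # SIMDs per CU


def occupancy(occ_vgpr, lds_bytes, num_warps):
    """Waves per EU: min(floor(occ_vgpr * 4 / nW), occ_lds) * nW / 4."""
    wg_by_vgpr = math.floor(occ_vgpr * EUS_PER_CU / num_warps)
    # A kernel that allocates no LDS is not limited by it.
    wg_by_lds = LDS_BYTES_PER_CU // lds_bytes if lds_bytes else math.inf
    return min(wg_by_vgpr, wg_by_lds) * num_warps / EUS_PER_CU


print(occupancy(occ_vgpr=4, lds_bytes=32768, num_warps=4))
```

In this example, a kernel with occ_vgpr = 4, 32 KiB of LDS and 4 warps is LDS-limited to 2 workgroups per CU, i.e., 2 waves per EU, even though VGPR usage alone would allow 4.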

Occupancy is critical to understand how many waves can be resident on a SIMD within a CU at once. Generally, compute bound large GEMM kernels have an occupancy of 1 or 2 waves per SIMD while memory bound GEMMs can have larger occupancy on account of being skinny GEMMs or GEMVs, and thus using fewer VGPRs and shared memory.

Note (names are in alphabetical order): Jason Furmanek, Vinayak Gokhale, Jack Taylor, Peng Sun, Simon Waters, Shucai Xiao, Lei Zhang, and Lixun Zhang contributed to the contents together.