Skip to content

Commit

Permalink
#0: LLM tech report performance analysis
Browse files Browse the repository at this point in the history
  • Loading branch information
yieldthought committed Nov 27, 2024
1 parent b4bcf3a commit d60e8de
Show file tree
Hide file tree
Showing 3 changed files with 227 additions and 2 deletions.
Binary file added tech_reports/LLMs/images/4.6-op-to-op-gap.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tech_reports/LLMs/images/4.6-overview.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
229 changes: 227 additions & 2 deletions tech_reports/LLMs/llms.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# LLMs in TT-NN
Authors:
Authors: Mark O'Connor,

## Contents
- [LLMs in TT-NN](#llms-in-tt-nn)
- [Contents](#contents)
Expand Down Expand Up @@ -86,7 +87,231 @@ Authors:
- Accuracy tests
- Debugging PCC issues
### 4.6 Performance Analysis
- Performance tooling, tracy

Think of ttnn performance as having five components:

![Performance components overview](images/4.6-overview.png)

1. **Main python thread** - this is your code that executes the ttnn calls and does other bits of logic in between. The speed of this thread determines the speed at which python calls are dispatched to the API. You are in control of any overheads here. When counting in microseconds python is probably slower than you think.
2. **Host API** - most of your ttnn calls will be immediately dispatched onto multiple C++ threads for further processing before anything is done to the hardware. You are generally not in control of any overheads in this part of the stack.
3. **Host-device communications** - data is heavy and moving it is to be avoided. PCIe bandwidth and latency is not negligible at the kinds of speeds we want to run models at. In addition, Tenstorrent converts most data into tiles of 32x32 elements for faster processing. Tilizing and untilizing data takes time and should be performed on-device wherever possible, but you will have to specify that as we’ll see in this section.
4. **Device dispatch** - we can measure the gap between one op finishing and the next starting. At time of writing the lower limit of this is single-digit microseconds and there is work underway to reduce it to zero. However, for various reasons you might see much higher dispatch times, most notably if there are a lot of runtime arguments to a function or if something else is happening in between calls.
5. **Device op performance** - how long it takes the hardware to run a given operation. Ideally we want this to be limited either by DRAM bandwidth or math throughput and for larger ops both of these are generally achievable. Doing so is mostly about how the data is placed (DRAM vs L1, sharded vs interleaved) and how the compute kernels are configured (process more than one tile at once and use smaller data formats).

We will dive into all of these in detail - how we like to measure each section and tips and tricks to optimize it.
However, the high-order bit we should consider first is this: should you use tracing?

#### What is tracing and when (not) to use it

Check out the [Metal Trace guide](https://github.com/tenstorrent/tt-metal/blob/main/tech_reports/AdvancedPerformanceOptimizationsForModels/AdvancedPerformanceOptimizationsForModels.md) for background on this. Essentially it lets you record a single pass of your model and stores the list of commands and buffers used on-device. You can then execute that trace in a single command with no additional work performed on the host. This eliminates all the overhead in stages 1-4 (you are still responsible for transferring any data needed to and from device, but host-device transfer of commands is eliminated).

We typically use tracing for the decode pass of LLMs but not the prefill pass. The main reasons for this are linked to tracing’s key limitation:

* You cannot allocate or deallocate tensors during a trace. When executing a trace every buffer will be the same size every time.

This doesn’t fit well with prefill, in which the sequence length and so matmul row counts will likely change each time. For decode it’s no problem (see the sections on kv-cache and paging for how we handle those with tracing). Conveniently, in prefill we have big operations in the millisecond plus range which the host can usually dispatch fast enough that it doesn’t matter. Decode, with a comparatively small batch size, is another story. There we iterate through the entire model in like 10ms with microsecond-length op times where we cannot afford to wait for a CPU, the whims of linux process scheduling or anything else but the speed at which electrons coruscate from DRAM and the NoC through our cores and out again.
**TL;DR: for decode mode you won’t have to worry about 1-3 but for prefill mode you will.** We will cover everything anyway.

#### What is async mode and when (not) to use it

```python
mesh_device.enable_async(True)
```

Without async mode each python call to ttnn will block until the device has finished and the results are available. This is good for debugging, because any crash or error will show you the correct python line of code that caused it. With async mode enabled your python thread keeps on running whilst the host and device handle the calls in the background, only blocking when data needs to be read back from device.

Async mode is obviously much faster, but if something asserts or crashes then your python stack will be several lines further on than the call that caused the problem.
For performance work async mode should always be enabled. For debugging it can be useful to disable it from time to time.

#### 1. Main python thread

If you’re tracing this doesn’t matter, but if not then it matters a lot. The Metal Profiler/Tracy can also show python performance but for pure python analysis it's hard to beat [viztracer](https://github.com/gaogaotiantian/viztracer):

```bash
pip install viztracer
```

Is enough to install it, then find the part of the code you want to profile (normally the part of your code that calls your model’s forward function) and wrap it, e.g.:

```python
# ...
# setup code above

from viztracer import Viztracer
with Viztracer(output_file='trace.json') as tracer:
tt_out = tt_model(decode_input, current_pos, rot_mat=current_rot_mat)
```

You can view this file with `vizviewer trace.json` - it’s entirely self-sufficient so if you’re working on a remote machine you can copy it back to your laptop and run it there (remember to `pip install viztracer` locally as well). Use WASD to navigate the UI and use the mouse to expand processes to see the call stacks. Look for any non-ttnn code that takes a significant amount of time in between the ttnn calls in your functions and find a way to remove or optimize it.

What to look for:

* You should be able to see your model forward pass running quickly and then waiting in a ttnn.to_torch or similar call reading data back from device.
* Measure the time from the start to the end of the forward pass of your model. If this is shorter than the target latency of your device then it is Fast Enough™ and you are done with this section.

Top tips:

* Torch modules add a surprising amount of overhead to every function call and member access. We don’t subclass `torch.nn.Module` for anything that might have to run quickly.
* Generate shard spec and compute kernel config objects once (e.g. in a constructor) instead of recreating them every time you run the forward pass. Keep the forward pass extremely clean.
* Make sure Metal is compiled in Release mode (default) and you are using ttnn’s async mode (see above)

#### 2. Host API

Any overhead here is mostly outside your control and in our experience is pretty minimal. You can use a C++ profiler or [Metal Profiler/Tracy](https://github.com/tenstorrent/tt-metal/blob/main/tech_reports/MetalProfiler/metal-profiler.md) with host stack traces enabled to see this time but it’s not really worth focusing on unless you’re a Metal developer.

#### 3. Host-device communications

As a rule of thumb you want as little communication as possible between the host and the device. For LLMs this means:

* Perform embeddings on-device (tokens ids are much smaller than embeddings)
* Return only the last token from prefill, not all the tokens
* Perform sampling (argmax etc) on-device if you can (at time of writing only argmax is implemented)
* Avoid pushing attention masks, rotation matrices and so on if they can be generated on-device or re-used between iterations

In addition pay attention to where data is tilized and untilized. You don’t want to do this on the host! At the time of writing `to_torch` will by default do this on the host. You can untilize on-device like this:

```python
tt_out_tiled = tt_model(decode_input, current_pos, rot_mat=current_rot_mat)
tt_out_row_major = ttnn.untilize(tt_out_tiled, use_multicore=True)
tt_tok = ttnn.argmax(tt_out_row_major, dim=3, use_multicore=True)
torch_tok = ttnn.to_torch(tt_tok)
```

Looking at host-device communications in a python profiler like `viztracer` is possible but be careful - when async-mode is on then any time spent in a communication call like `to_torch` can be comprised of up to three things:

1. Time spent waiting for the device
2. Time spent transferring data
3. Time spent untilizing data

If you want to measure the calls this way, turn async mode off. In this way the time your main python thread spends in `to_torch` will not include any of (1) and will be a closer approximation to what you probably wanted to measure.

#### 4+5. Device dispatch and op performance

This is the fun bit, but we need to do a little prep to get started. First, metal must be compiled with `-p` to enable device profiling:

```bash
./build_metal -p
```

Then we can record an op performance csv file with tracy. For the pytests, run it like this:

```bash
python -m tracy -r -p -v -m pytest path/to/test.py
```

This produces a file named something like `ops_perf_results_2024_11_01_15_33_18.csv` - that file is all we need from the profiler for now, but to learn more see the [Metal Profiler tech report](https://github.com/tenstorrent/tt-metal/blob/main/tech_reports/MetalProfiler/metal-profiler.md).

> **Warning:** Only use one single trace execution step when profiling. At time of writing profiler support with tracing is still a work-in-progress and more iterations will result in a `AssertionError: Device data mismatch error`.
> **Note:** If you see errors whilst running tracy, try this device-only profiling process instead: run with `TT_METAL_DEVICE_PROFILER=1 pytest path/to/test.py` and after the run completes run `tt_metal/tools/profiler/process_ops_logs.py --date` to generate the CSV file.
This CSV file contains a wealth of information recorded from all the device during program execution. To summarize it into a more human-readable format we run the `perf_report.py` tool:

```bash
python models/perf/perf_report.py OPS_CSV_FILE
```

The [documentation for this tool](https://github.com/tenstorrent/tt-metal/tree/main/models/perf) describes how to use it to select specific ranges of operations. For device performance we strongly recommend looking at a single layer. You can do this by using `--id-range` or by changing your test to run only a single layer of the model (recommended).

##### What makes a good performance test?

Ideally you should run your model in as close to end-user form as possible, whilst simplifying as much as possible. In practice this means:

* Use tracing (if you are using tracing in production)
* Skip the first compilation iteration - this adds a lot of one-time host overhead between ops
* Run a single layer of the model - but be aware of which ops will be run for every single layer and which ones are only run at the start and end (e.g. embedding, final norm and LM head)
* Add a tracy signpost e.g. `tracy.signpost("Performance pass")` just before the part of your run you want to record - this will be focused on by default by `perf_report.py`, saving you some work

##### What does such a report look like?

Here is an example with no op-to-op tracing. You can instantly see that more time (756us) is spent in between ops (op-to-op gap) than running ops on device (362us)!

Reducing op-to-op gap

![op-to-op gap](images/4.6-op-to-op-gap.png)

There are two main contributors to op-to-op gap: **host time** and **dispatch time**.

* **Host time** is optimized in steps 1-3. If you are already using tracing or are using async mode and have ensured that your python thread is dispatching faster than the device is generating outputs then this has already been minimized.
* **Dispatch time** is mostly out of your hands, but as an example it is influenced by the number of runtime args a kernel uses.
* You can examine the source code for any kernel with high op-to-op latency and see if you can convert some of the runtime args into compile-time args for your use case.
* You can fuse multiple ops into a single kernel. Examples where this was well worthwhile in the past include `LayerNorm` and `ScaledDotProductAttentionDecode`.

Typically tracing reduces the op-to-op gap below 6us and as of November 2024 there are roadmap plans to reduce this to zero, so as long as your ops are below this level your opportunities for optimization here are limited.

##### Overall op performance advice

There are a lot of individual tips here but let’s start with overall advice:

1. Use as many cores as possible
2. Move data as little as possible

Essentially the perfect op runs on the entire core grid using sharded inputs from L1. Let’s look more at data movement first, then specific tips.

##### Data movement

Ops can read data from:

1. DRAM interleaved - each tile (32x32 datums) is read from a different DRAM bank. This is the ttnn default and is the slowest way to read data. A matmul can expect to read around 190 GB/s on a Wormhole like this.
2. DRAM sharded - specifically used for DRAM-bound matmuls and nothing else, this splits the data across DRAM banks and uses the closest core to each bank on the chip to read from that bank. This achieves around 240 GB/s on a Wormhole.
3. L1 interleaved - tiles are interleaved across the L1 of all the cores and are read across the NoC (network-on-chip)
4. L1 sharded - tiles are sharded across a particular grid of cores

Note that the term **sharding** is used in two ways in the metal stack. Here we are talking about **sharding across cores** within a single chip. It is also used to refer to sharding a dimension across multiple devices - an analogous operation but confusing in this context.

L1 sharded is particularly fast when the data an op requires is already placed in L1 of the correct core, avoiding the NoC entirely and reading at maximum speed.

Typically activations should be placed in L1 and weights placed in DRAM.

See the [op config section](#44-op-configs) for more details on writing shard specs in your code.

##### Specific tips

Ok so your ops are reading from the fastest memory they can, sharded if possible. What might still make things slow?

* **Unnecessary `ShardedToInterleaved` and `InterleavedToSharded` calls**. The fastest work is work that you don’t have to do. These calls are pure data movement and it is often better to have some ops using fewer cores if it means they can use the same sharding of their input data as the previous and subsequent ops. Always avoid data movement!
* **Always use `ScaledDotProductAttention` (SDPA) ops if possible**. These implement FlashAttention / FlashDecode and are much faster than writing attention using individual operations.
* **Cross-device communication ops**. `AllGather`, `ReduceScatter` etc. Avoid these where possible, try using `bfp8` inputs instead of `bf16` if you can. There is an `AllGatherMatmul` op that overlaps `AllGather` with a `Matmul` that you can investigate further too - see `ttnn.experimental.all_gather_matmul` with an [example of its use](https://github.com/tenstorrent/tt-metal/blob/79ff70b0e115ac50e70a72391dde3c4a4a6fab7f/models/demos/llama3/tt/llama_attention.py#L329) looking like this:

```python
_, dense_out_sharded, _ = ttnn.experimental.all_gather_matmul(
input_tensor,
weights,
dim=3,
all_gather_core_grid_offset=(0, 4),
num_links=1,
memory_config_ag=all_gather_memcfg,
memory_config_mm=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG,
program_config=all_gather_matmul_progcfg,
compute_kernel_config=compute_kernel_config_hifi2,
)
```

**Matmuls** are usually most significant workload. They should be memory-bound, compute-bound or too small to matter. `perf_report.py` gives good advice for your matmuls and you should follow it, which usually involves specifying a [program config](#44-op-configs):

* Output subblock size should be at least 2x1 or 1x2
* DRAM-sharded matmuls should be used for any DRAM-bound cases, e.g. most decode matmuls.
* The inner dim number of tiles (`in0_block_w`) should be at least 2 if possible.
* Use the lowest precision you can for weights and inputs - we find BFP8 weights always work and BFP4 weights work for some matmuls particularly in the MLP.
* Use an appropriate math fidelity in the compute kernel config. This controls the number of bits multiplied together and is especially important for compute-bound matmuls as the Tensix core’s math throughput is 2x higher with HiFi2 and 3.6x faster with LoFi.
* Use HiFi4 for BF16 weights or if accuracy is very important (you often see this in attention ops)
* Use HiFi2 for BFP8 weights - this drops the least-significant bit of a BF16 @ BFP8 matmul but this is usually not an issue. You may find that LoFi works as well.
* Use LoFi for BFP4 weights.

You can specify a compute kernel like this:

```python
self.compute_kernel_config_hifi2 = ttnn.WormholeComputeKernelConfig(
math_fidelity=ttnn.MathFidelity.HiFi2,
math_approx_mode=False,
fp32_dest_acc_en=False,
packer_l1_acc=True,
)
```

As always, do not recreate these every single forward pass if you want your python thread to be fast (which you do).


### 4.7 Misc. Performance Optimizations
- Which dim to shard matmuls on
- DRAM-sharding
Expand Down

0 comments on commit d60e8de

Please sign in to comment.