Add conversion for more OPs (#11)
* Use TILE_LAYOUT during data move-in

  * Insert a ttnn.to_layout(ttnn.TILE_LAYOUT) call between ttnn.from_torch and
    ttnn.to_device when adding data move-in ops (see the sketch below)
  * ttnn.reshape will skip inserting ttnn.to_layout
  * Update the tests to reflect the newly inserted function
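A minimal sketch of the resulting move-in sequence, matching the op order the updated tests assert (function and variable names are illustrative, not the pass's actual code):

```python
import torch
import ttnn

def move_in(torch_tensor: torch.Tensor, device: "ttnn.Device"):
    t = ttnn.from_torch(torch_tensor)        # host tensor, ROW_MAJOR layout by default
    t = ttnn.to_layout(t, ttnn.TILE_LAYOUT)  # the newly inserted layout conversion
    return ttnn.to_device(t, device)         # then move onto the device
```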

* Fix reshape

* Add conversion from torch.relu and torch.addmm to ttnn

* Add conversion from torch.div, torch.bmm, and torch.gelu to ttnn
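These conversions all follow the same shape: an FX pass that swaps aten targets for their ttnn counterparts. A minimal sketch with an illustrative subset of the mapping (the real pass covers many more ops and also rewrites arguments):

```python
import torch
import ttnn

# Illustrative subset of the aten -> ttnn mapping.
ATEN_TO_TTNN = {
    torch.ops.aten.relu.default: ttnn.relu,
    torch.ops.aten.gelu.default: ttnn.gelu,
    torch.ops.aten.bmm.default: ttnn.matmul,
}

def to_tt(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target in ATEN_TO_TTNN:
            node.target = ATEN_TO_TTNN[node.target]  # swap the op in place
    gm.graph.lint()
    gm.recompile()
    return gm
```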

* Add workaround to handle input aliasing

* Add conversion from aten.rsub and aten.embedding

* Add conversion from aten.split

* Move GraphCleanup method to a new file

* Move Dummy string repr to separate utils file

* Fix rsub elif

* Add torch.clone conversion to ttnn.clone

ttnn.clone requires extra arguments compared to torch.clone: a
MemoryConfig and an output type (see the sketch below).

  * Construct a ttnn.MemoryConfig for DRAM
  * Retrieve metadata from the original torch op and translate it to the ttnn type
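A hedged sketch of the resulting call; the exact MemoryConfig constructor and ttnn.clone keyword names here are assumptions, not this commit's verbatim code:

```python
import ttnn

def convert_clone(tt_tensor: "ttnn.Tensor") -> "ttnn.Tensor":
    # Interleaved DRAM memory config (constructor arguments assumed)
    mem_config = ttnn.MemoryConfig(
        ttnn.TensorMemoryLayout.INTERLEAVED,
        ttnn.BufferType.DRAM,
    )
    # Output dtype translated from the torch op's metadata (bfloat16 assumed)
    return ttnn.clone(tt_tensor, memory_config=mem_config, dtype=ttnn.bfloat16)
```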

* Add support for kwargs

* Add conversion from torch.nn.functional.layer_norm

torch.nn.LayerNorm does not take custom weights and bias as arguments
and produces values that differ quite a bit from ttnn.layer_norm, so it
is not supported yet.

However, torch.nn.functional.layer_norm can produce values that are very
close to ttnn.layer_norm, and this commit tests against the aten op that
this higher-level torch op is lowered to.

aten.native_layer_norm returns 3 outputs: layer norm, mean(?), rstd(?).
However, torch.nn.functional.layer_norm only cares about the layer norm
output. Currently, this transformation replaces the mean and rstd with
the layer norm output. This should be fixed later.
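For reference, the functional form being converted; tracing this call yields aten.native_layer_norm and its (output, mean, rstd) triple:

```python
import torch
import torch.nn.functional as F

x = torch.rand(2, 32, 64)
weight = torch.rand(64)
bias = torch.rand(64)
# Lowers to aten.native_layer_norm, which returns (output, mean, rstd);
# only `output` is consumed here.
out = F.layer_norm(x, normalized_shape=(64,), weight=weight, bias=bias)
```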

* Add conversion from torch.neg, torch.ones, and torch.tril to ttnn counterparts

  * ttnn.ones requires passing the device object manually (see the sketch below)
  * A default device has to be set up for AutoFormat, since ttnn.tril uses it
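A minimal sketch of the device-passing difference (the `device` keyword name is an assumption):

```python
import ttnn

device = ttnn.open_device(device_id=0)
# Unlike torch.ones, ttnn.ones must be told where to allocate the tensor.
tt_ones = ttnn.ones((32, 32), device=device)
ttnn.close_device(device)
```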

* Use a custom class for the kwarg object instead of a generic tuple.

* Add transformations for aten.{eq.Tensor, eq.Scalar, logical_not, zeros_like, mean.dim}

* Fix torch.compile options for other tests

* Move transformations for ttnn.add and ttnn.mul into ToTtPass

This requires a patch to ttnn.decorators.Operation

* Fix test_fall_back

* Add transformations for several more ops

  * Pow (Tensor base, scalar exponent)
  * Rsqrt
  * Silu
  * Adaptive Avg Pool
  * Clamp
  * Squeeze (dim argument)

* Fix transformations for torch.eq and add transformation for torch.full

torch.eq (scalar) -> ttnn.full + ttnn.eq (tensor), as sketched below.

Previously ttnn.eq supported a scalar argument, but it now raises an error.
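A hedged sketch of that decomposition (ttnn.full's signature here is an assumption):

```python
import ttnn

def eq_scalar(tt_tensor: "ttnn.Tensor", scalar: float, device: "ttnn.Device"):
    # Materialize the scalar as a full tensor, then compare elementwise.
    full = ttnn.full(tt_tensor.shape, fill_value=scalar, device=device)
    return ttnn.eq(tt_tensor, full)
```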

* Disable torch to ttnn.split test since fallback is disabled and op is not implemented yet

* Update torch to ttnn.reshape tests to match some limitations

* Implement transformation for torch.lq.{scalar,tensor} and generalize relational ops for a cleaner implementation

* Implement aten.baddbmm transformation to ttnn
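baddbmm computes beta * input + alpha * (batch1 @ batch2). A sketch of the decomposition into ttnn ops, assuming ttnn.mul accepts a scalar operand and skipping the scalar multiplies when they are 1:

```python
import ttnn

def baddbmm(inp, batch1, batch2, *, beta: float = 1.0, alpha: float = 1.0):
    mm = ttnn.matmul(batch1, batch2)      # batch1 @ batch2
    if alpha != 1.0:
        mm = ttnn.mul(mm, alpha)          # alpha * (batch1 @ batch2)
    if beta != 1.0:
        inp = ttnn.mul(inp, beta)         # beta * input
    return ttnn.add(inp, mm)
```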

* Add transformation from torch.cos

* Remove conversion for split because ttnn.split is removed entirely

See: #5389

* Add transformation for torch.sigmoid

* Cast all model input arguments to bfloat16
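Roughly, as a sketch (the floating-point guard is an assumption; integer inputs such as token ids should stay integral):

```python
import torch

def cast_inputs(inputs):
    return [
        x.to(torch.bfloat16) if torch.is_tensor(x) and x.is_floating_point() else x
        for x in inputs
    ]
```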

* Set aten.view to fallback. Need to handle restrictions from ttnn.reshape

* Fix layer_norm conversion to handle cases where ttnn ops follow

* Handle case where aten.full has an empty shape

* Remove split conversion from to_tt pass

* Add fallback to squeeze conversion since ttnn.squeeze only supports dim 0.

* Add aten.rsub.Scalar conversion

* Match restrictions from ttnn.arange

* Add workaround for relational op conversion for certain input sizes

* Restrict embedding conversion to only support TILE_LAYOUT for now

* Handle case where the denominator is a scalar for div op
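A hedged sketch of the div lowering; the scalar branch assumes ttnn.mul accepts a scalar, and the tensor branch matches the ttnn.reciprocal + ttnn.mul sequence the updated test_fall_back expects:

```python
import ttnn

def div(numerator, denominator):
    if isinstance(denominator, (int, float)):
        # Scalar denominator: fold it into a scalar multiply.
        return ttnn.mul(numerator, 1.0 / denominator)
    # Tensor denominator: reciprocal, then elementwise multiply.
    return ttnn.mul(numerator, ttnn.reciprocal(denominator))
```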

* Add workaround for when the model output takes the output from argmax (see the sketch below)

aten.argmax outputs integer values, but ttnn.argmax outputs floating point.
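A minimal sketch of the workaround, assuming it amounts to casting back on the torch side during move-out:

```python
import torch
import ttnn

def move_out_argmax(tt_out: "ttnn.Tensor") -> torch.Tensor:
    # ttnn.argmax yields floating-point indices; restore the integer dtype
    # that aten.argmax guarantees (the cast location is an assumption).
    return ttnn.to_torch(tt_out).to(torch.int64)
```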

* Remove extraneous prints

* Add bert and falcon-7b models for testing with torch_stat backend

* These model names contain "/", which needed a small fix in the torch_stat backend
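Presumably the fix sanitizes the name before using it as a path component; a sketch under that assumption:

```python
# "/" in names like "bigscience/bloom-1b1" breaks output-file paths.
safe_name = "bigscience/bloom-1b1".replace("/", "_")  # -> "bigscience_bloom-1b1"
```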

* Update AutoModelForCausalLM models and add bigscience/bloom-1b1 model

* Add mamba, llama, gpt2, and yolos models

* Fix e2e run with torch_ttnn backend

* Add option to select a model
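Per the README addition below, the selector is a --model flag; a sketch of the argument parsing (argparse assumed):

```python
import argparse

parser = argparse.ArgumentParser(description="Run a transformer model")
parser.add_argument("--model", default="phiyodr/bert-large-finetuned-squad2",
                    help="Hugging Face model name")
parser.add_argument("--backend", choices=["torch_ttnn", "torch_stat"],
                    default="torch_ttnn")
args = parser.parse_args()
```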

* Remove dependency on tests module from tt-metal and copy relevant test utility functions to this repo

* Update README with instructions on running a transformer model with the ttnn backend

* Run black formatter on files

* Fix formatting for run_transformers
kevinwuTT authored Jun 26, 2024 · 1 parent f8087ae · commit 4850320
Showing 20 changed files with 3,017 additions and 236 deletions.
8 changes: 8 additions & 0 deletions README.md
````diff
@@ -51,3 +51,11 @@ The `*_total_*_size_dist/` statistics the `op_type`'s input/output_size distribution
 - Notice: the [aten ir interface is in there](https://pytorch.org/docs/stable/torch.compiler_ir.html)
 
 [The `profile/` is the tools provided by pytorch](https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html), you can open it by the url: chrome://tracing
+
+# Run transformer models
+To run a transformer model with the ttnn backend, run:
+```
+PYTHONPATH=${TT_METAL_HOME}:$(pwd) python3 tools/run_transformers.py --model "phiyodr/bert-large-finetuned-squad2" --backend torch_ttnn
+```
+
+You can also substitute the backend with `torch_stat` to run a reference comparison.
````
24 changes: 13 additions & 11 deletions tests/test_cse.py
```diff
@@ -19,11 +19,11 @@ def input_shapes(self):
 class TestModules(unittest.TestCase):
     def setUp(self):
         # Open device 0
-        self.device: ttnn.Device = ttnn.open(0)
+        self.device: ttnn.Device = ttnn.open_device(device_id=0)
 
     def tearDown(self):
         # Close the device
-        ttnn.close(self.device)
+        ttnn.close_device(self.device)
 
     def test_add(self):
         m = AddModule()
@@ -32,19 +32,21 @@ def test_add(self):
         result_before = m.forward(*inputs)
         option = torch_ttnn.TorchTtnnOption(device=self.device)
         # The compilation is lazy, so we need to run forward once to trigger the compilation
-        m = torch.compile(m, backend=torch_ttnn.backend(option))
+        m = torch.compile(m, backend=torch_ttnn.backend, options=option)
         result_after = m.forward(*inputs)
         self.assertEqual(1, len(option._out_fx_graphs))
         option._out_fx_graphs[0].print_tabular()
         # Check the graph has been rewritten and contains ttnn ops
         nodes = list(option._out_fx_graphs[0].nodes)
-        self.assertEqual(nodes[3].target, ttnn.add)
-        self.assertEqual(nodes[3].args[0].target, ttnn.to_device)
-        self.assertEqual(nodes[3].args[0].args[0].target, ttnn.from_torch)
-        self.assertEqual(nodes[3].args[1].target, ttnn.to_device)
-        self.assertEqual(nodes[3].args[1].args[0].target, ttnn.from_torch)
-        self.assertEqual(nodes[4].target, ttnn.from_device)
-        self.assertEqual(nodes[5].target, ttnn.to_layout)
-        self.assertEqual(nodes[6].target, ttnn.to_torch)
+        self.assertEqual(nodes[4].target, ttnn.add)
+        self.assertEqual(nodes[4].args[0].target, ttnn.to_device)
+        self.assertEqual(nodes[4].args[0].args[0].target, ttnn.to_layout)
+        self.assertEqual(nodes[4].args[0].args[0].args[0].target, ttnn.from_torch)
+        self.assertEqual(nodes[4].args[1].target, ttnn.to_device)
+        self.assertEqual(nodes[4].args[1].args[0].target, ttnn.to_layout)
+        self.assertEqual(nodes[4].args[1].args[0].args[0].target, ttnn.from_torch)
+        self.assertEqual(nodes[5].target, ttnn.from_device)
+        self.assertEqual(nodes[6].target, ttnn.to_layout)
+        self.assertEqual(nodes[7].target, ttnn.to_torch)
         # Check inference result
         self.assertTrue(torch.allclose(result_before, result_after))
```
32 changes: 22 additions & 10 deletions tests/test_fall_back.py
```diff
@@ -3,6 +3,8 @@
 import unittest
 from torch_ttnn import ttnn
 
+from torch_ttnn.utils import check_with_pcc
+
 
 class MixModule(torch.nn.Module):
     def __init__(self):
@@ -23,11 +25,11 @@ def input_shapes(self):
 class TestModules(unittest.TestCase):
     def setUp(self):
         # Open device 0
-        self.device: ttnn.Device = ttnn.open(0)
+        self.device: ttnn.Device = ttnn.open_device(device_id=0)
 
     def tearDown(self):
         # Close the device
-        ttnn.close(self.device)
+        ttnn.close_device(self.device)
 
     def test_fall_back(self):
         m = MixModule()
@@ -37,19 +39,29 @@ def test_fall_back(self):
         option = torch_ttnn.TorchTtnnOption(device=self.device)
         option.gen_graphviz = True
         # The compilation is lazy, so we need to run forward once to trigger the compilation
-        m = torch.compile(m, backend=torch_ttnn.backend(option))
+        m = torch.compile(m, backend=torch_ttnn.backend, options=option)
         result_after = m.forward(*inputs)
         self.assertEqual(1, len(option._out_fx_graphs))
         option._out_fx_graphs[0].print_tabular()
 
         # Check the graph has been rewritten and contains ttnn ops
         nodes = list(option._out_fx_graphs[0].nodes)
-        self.assertEqual(nodes[3].target, ttnn.from_torch)
+        self.assertEqual(nodes[2].target, ttnn.from_torch)
+        self.assertEqual(nodes[3].target, ttnn.to_layout)
         self.assertEqual(nodes[4].target, ttnn.to_device)
-        self.assertEqual(nodes[5].target, ttnn.add)
-        self.assertEqual(nodes[6].target, ttnn.matmul)
-        self.assertEqual(nodes[7].target, ttnn.from_device)
-        self.assertEqual(nodes[8].target, ttnn.to_layout)
-        self.assertEqual(nodes[9].target, ttnn.to_torch)
+        self.assertEqual(nodes[5].target, ttnn.reciprocal)
+        self.assertEqual(nodes[6].target, ttnn.from_torch)
+        self.assertEqual(nodes[7].target, ttnn.to_layout)
+        self.assertEqual(nodes[8].target, ttnn.to_device)
+        self.assertEqual(nodes[9].target, ttnn.mul)
+        self.assertEqual(nodes[10].target, ttnn.add)
+        self.assertEqual(nodes[11].target, ttnn.matmul)
+        self.assertEqual(nodes[12].target, ttnn.reciprocal)
+        self.assertEqual(nodes[13].target, ttnn.mul)
+        self.assertEqual(nodes[14].target, ttnn.reciprocal)
+        self.assertEqual(nodes[15].target, ttnn.mul)
+        self.assertEqual(nodes[16].target, ttnn.from_device)
+        self.assertEqual(nodes[17].target, ttnn.to_layout)
+        self.assertEqual(nodes[18].target, ttnn.to_torch)
         # Check inference result
-        self.assertTrue(torch.allclose(result_before, result_after))
+        self.assertTrue(check_with_pcc(result_before, result_after))
```
32 changes: 17 additions & 15 deletions tests/test_if.py
```diff
@@ -22,11 +22,11 @@ def input_shapes(self):
 class TestModules(unittest.TestCase):
     def setUp(self):
         # Open device 0
-        self.device: ttnn.Device = ttnn.open(0)
+        self.device: ttnn.Device = ttnn.open_device(device_id=0)
 
     def tearDown(self):
         # Close the device
-        ttnn.close(self.device)
+        ttnn.close_device(self.device)
 
     def test_if(self):
         m = IfModule()
@@ -36,7 +36,7 @@ def test_if(self):
         result_before_else = m.forward(*inputs_else)
         option = torch_ttnn.TorchTtnnOption(device=self.device)
         # The compilation is lazy, so we need to run forward once to trigger the compilation
-        m = torch.compile(m, backend=torch_ttnn.backend(option))
+        m = torch.compile(m, backend=torch_ttnn.backend, options=option)
         result_after_then = m.forward(*inputs_then)
         result_after_else = m.forward(*inputs_else)
 
@@ -49,21 +49,23 @@
         self.assertEqual(nodes_0[1].target, torch.ops.aten.sum.default)
         self.assertEqual(nodes_0[2].target, torch.ops.aten.gt.Scalar)
         nodes_1 = list(option._out_fx_graphs[1].nodes)
-        self.assertEqual(len(nodes_1), 8)
+        self.assertEqual(len(nodes_1), 9)
         self.assertEqual(nodes_1[1].target, ttnn.from_torch)
-        self.assertEqual(nodes_1[2].target, ttnn.to_device)
-        self.assertEqual(nodes_1[3].target, ttnn.add)
-        self.assertEqual(nodes_1[4].target, ttnn.from_device)
-        self.assertEqual(nodes_1[5].target, ttnn.to_layout)
-        self.assertEqual(nodes_1[6].target, ttnn.to_torch)
+        self.assertEqual(nodes_1[2].target, ttnn.to_layout)
+        self.assertEqual(nodes_1[3].target, ttnn.to_device)
+        self.assertEqual(nodes_1[4].target, ttnn.add)
+        self.assertEqual(nodes_1[5].target, ttnn.from_device)
+        self.assertEqual(nodes_1[6].target, ttnn.to_layout)
+        self.assertEqual(nodes_1[7].target, ttnn.to_torch)
         nodes_2 = list(option._out_fx_graphs[2].nodes)
-        self.assertEqual(len(nodes_2), 8)
+        self.assertEqual(len(nodes_2), 9)
         self.assertEqual(nodes_2[1].target, ttnn.from_torch)
-        self.assertEqual(nodes_2[2].target, ttnn.to_device)
-        self.assertEqual(nodes_2[3].target, ttnn.matmul)
-        self.assertEqual(nodes_2[4].target, ttnn.from_device)
-        self.assertEqual(nodes_2[5].target, ttnn.to_layout)
-        self.assertEqual(nodes_2[6].target, ttnn.to_torch)
+        self.assertEqual(nodes_2[2].target, ttnn.to_layout)
+        self.assertEqual(nodes_2[3].target, ttnn.to_device)
+        self.assertEqual(nodes_2[4].target, ttnn.matmul)
+        self.assertEqual(nodes_2[5].target, ttnn.from_device)
+        self.assertEqual(nodes_2[6].target, ttnn.to_layout)
+        self.assertEqual(nodes_2[7].target, ttnn.to_torch)
 
         # Check inference result
         self.assertTrue(torch.allclose(result_before_then, result_after_then))
```