Skip to content

Commit

Permalink
Merge branch 'main' into ay/run_integration_test
Browse files Browse the repository at this point in the history
  • Loading branch information
ayerofieiev-tt authored Jun 28, 2024
2 parents 354d492 + 4f8daaf commit f1e4057
Show file tree
Hide file tree
Showing 22 changed files with 3,025 additions and 240 deletions.
4 changes: 1 addition & 3 deletions .github/workflows/before_merge.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@ jobs:
validate-pr:
env:
ARCH_NAME: wormhole_b0
TT_METAL_HOME: ${pwd}
PYTHONPATH: ${pwd}
runs-on: ["in-service", "n150"]
steps:
- name: Checkout Repo
Expand All @@ -36,4 +34,4 @@ jobs:
RUN_INTEGRATION_TESTS: 1
run: |
source venv/bin/activate
python3 -m pytest --github-report tests/*.py
python3 -m pytest --github-report tests/*.py -s
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,11 @@ The `*_total_*_size_dist/` statistics the `op_type`'s input/output_size distribu
- Notice: the [aten ir interface is in there](https://pytorch.org/docs/stable/torch.compiler_ir.html)

[The `profile/` is the tools provided by pytorch](https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html), you can open it by the url: chrome://tracing

# Run transformer models
To run a transformer model with the ttnn backend, run:
```
PYTHONPATH=${TT_METAL_HOME}:$(pwd) python3 tools/run_transformers.py --model "phiyodr/bert-large-finetuned-squad2" --backend torch_ttnn
```

You can also substitute the backend with `torch_stat` to run a reference comparison.
5 changes: 4 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
--find-links https://download.pytorch.org/whl/torch_stable.html

torch==2.2.1.0+cpu
torchvision==0.17.1+cpu
tabulate==0.9.0
networkx==3.1
graphviz
matplotlib
matplotlib==3.7.1
https://github.com/tenstorrent/tt-metal/releases/download/v0.50.0-rc18/metal_libs-0.50.0rc18+wormhole.b0-cp38-cp38-linux_x86_64.whl
24 changes: 13 additions & 11 deletions tests/test_cse.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@ def input_shapes(self):
class TestModules(unittest.TestCase):
def setUp(self):
# Open device 0
self.device: ttnn.Device = ttnn.open(0)
self.device: ttnn.Device = ttnn.open_device(device_id=0)

def tearDown(self):
# Close the device
ttnn.close(self.device)
ttnn.close_device(self.device)

def test_add(self):
m = AddModule()
Expand All @@ -32,19 +32,21 @@ def test_add(self):
result_before = m.forward(*inputs)
option = torch_ttnn.TorchTtnnOption(device=self.device)
# The compilation is lazy, so we need to run forward once to trigger the compilation
m = torch.compile(m, backend=torch_ttnn.backend(option))
m = torch.compile(m, backend=torch_ttnn.backend, options=option)
result_after = m.forward(*inputs)
self.assertEqual(1, len(option._out_fx_graphs))
option._out_fx_graphs[0].print_tabular()
# Check the graph has been rewritten and contains ttnn ops
nodes = list(option._out_fx_graphs[0].nodes)
self.assertEqual(nodes[3].target, ttnn.add)
self.assertEqual(nodes[3].args[0].target, ttnn.to_device)
self.assertEqual(nodes[3].args[0].args[0].target, ttnn.from_torch)
self.assertEqual(nodes[3].args[1].target, ttnn.to_device)
self.assertEqual(nodes[3].args[1].args[0].target, ttnn.from_torch)
self.assertEqual(nodes[4].target, ttnn.from_device)
self.assertEqual(nodes[5].target, ttnn.to_layout)
self.assertEqual(nodes[6].target, ttnn.to_torch)
self.assertEqual(nodes[4].target, ttnn.add)
self.assertEqual(nodes[4].args[0].target, ttnn.to_device)
self.assertEqual(nodes[4].args[0].args[0].target, ttnn.to_layout)
self.assertEqual(nodes[4].args[0].args[0].args[0].target, ttnn.from_torch)
self.assertEqual(nodes[4].args[1].target, ttnn.to_device)
self.assertEqual(nodes[4].args[1].args[0].target, ttnn.to_layout)
self.assertEqual(nodes[4].args[1].args[0].args[0].target, ttnn.from_torch)
self.assertEqual(nodes[5].target, ttnn.from_device)
self.assertEqual(nodes[6].target, ttnn.to_layout)
self.assertEqual(nodes[7].target, ttnn.to_torch)
# Check inference result
self.assertTrue(torch.allclose(result_before, result_after))
32 changes: 22 additions & 10 deletions tests/test_fall_back.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import unittest
from torch_ttnn import ttnn

from torch_ttnn.utils import check_with_pcc


class MixModule(torch.nn.Module):
def __init__(self):
Expand All @@ -23,11 +25,11 @@ def input_shapes(self):
class TestModules(unittest.TestCase):
def setUp(self):
# Open device 0
self.device: ttnn.Device = ttnn.open(0)
self.device: ttnn.Device = ttnn.open_device(device_id=0)

def tearDown(self):
# Close the device
ttnn.close(self.device)
ttnn.close_device(self.device)

def test_fall_back(self):
m = MixModule()
Expand All @@ -37,19 +39,29 @@ def test_fall_back(self):
option = torch_ttnn.TorchTtnnOption(device=self.device)
option.gen_graphviz = True
# The compilation is lazy, so we need to run forward once to trigger the compilation
m = torch.compile(m, backend=torch_ttnn.backend(option))
m = torch.compile(m, backend=torch_ttnn.backend, options=option)
result_after = m.forward(*inputs)
self.assertEqual(1, len(option._out_fx_graphs))
option._out_fx_graphs[0].print_tabular()

# Check the graph has been rewritten and contains ttnn ops
nodes = list(option._out_fx_graphs[0].nodes)
self.assertEqual(nodes[3].target, ttnn.from_torch)
self.assertEqual(nodes[2].target, ttnn.from_torch)
self.assertEqual(nodes[3].target, ttnn.to_layout)
self.assertEqual(nodes[4].target, ttnn.to_device)
self.assertEqual(nodes[5].target, ttnn.add)
self.assertEqual(nodes[6].target, ttnn.matmul)
self.assertEqual(nodes[7].target, ttnn.from_device)
self.assertEqual(nodes[8].target, ttnn.to_layout)
self.assertEqual(nodes[9].target, ttnn.to_torch)
self.assertEqual(nodes[5].target, ttnn.reciprocal)
self.assertEqual(nodes[6].target, ttnn.from_torch)
self.assertEqual(nodes[7].target, ttnn.to_layout)
self.assertEqual(nodes[8].target, ttnn.to_device)
self.assertEqual(nodes[9].target, ttnn.mul)
self.assertEqual(nodes[10].target, ttnn.add)
self.assertEqual(nodes[11].target, ttnn.matmul)
self.assertEqual(nodes[12].target, ttnn.reciprocal)
self.assertEqual(nodes[13].target, ttnn.mul)
self.assertEqual(nodes[14].target, ttnn.reciprocal)
self.assertEqual(nodes[15].target, ttnn.mul)
self.assertEqual(nodes[16].target, ttnn.from_device)
self.assertEqual(nodes[17].target, ttnn.to_layout)
self.assertEqual(nodes[18].target, ttnn.to_torch)
# Check inference result
self.assertTrue(torch.allclose(result_before, result_after))
self.assertTrue(check_with_pcc(result_before, result_after))
32 changes: 17 additions & 15 deletions tests/test_if.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@ def input_shapes(self):
class TestModules(unittest.TestCase):
def setUp(self):
# Open device 0
self.device: ttnn.Device = ttnn.open(0)
self.device: ttnn.Device = ttnn.open_device(device_id=0)

def tearDown(self):
# Close the device
ttnn.close(self.device)
ttnn.close_device(self.device)

def test_if(self):
m = IfModule()
Expand All @@ -36,7 +36,7 @@ def test_if(self):
result_before_else = m.forward(*inputs_else)
option = torch_ttnn.TorchTtnnOption(device=self.device)
# The compilation is lazy, so we need to run forward once to trigger the compilation
m = torch.compile(m, backend=torch_ttnn.backend(option))
m = torch.compile(m, backend=torch_ttnn.backend, options=option)
result_after_then = m.forward(*inputs_then)
result_after_else = m.forward(*inputs_else)

Expand All @@ -49,21 +49,23 @@ def test_if(self):
self.assertEqual(nodes_0[1].target, torch.ops.aten.sum.default)
self.assertEqual(nodes_0[2].target, torch.ops.aten.gt.Scalar)
nodes_1 = list(option._out_fx_graphs[1].nodes)
self.assertEqual(len(nodes_1), 8)
self.assertEqual(len(nodes_1), 9)
self.assertEqual(nodes_1[1].target, ttnn.from_torch)
self.assertEqual(nodes_1[2].target, ttnn.to_device)
self.assertEqual(nodes_1[3].target, ttnn.add)
self.assertEqual(nodes_1[4].target, ttnn.from_device)
self.assertEqual(nodes_1[5].target, ttnn.to_layout)
self.assertEqual(nodes_1[6].target, ttnn.to_torch)
self.assertEqual(nodes_1[2].target, ttnn.to_layout)
self.assertEqual(nodes_1[3].target, ttnn.to_device)
self.assertEqual(nodes_1[4].target, ttnn.add)
self.assertEqual(nodes_1[5].target, ttnn.from_device)
self.assertEqual(nodes_1[6].target, ttnn.to_layout)
self.assertEqual(nodes_1[7].target, ttnn.to_torch)
nodes_2 = list(option._out_fx_graphs[2].nodes)
self.assertEqual(len(nodes_2), 8)
self.assertEqual(len(nodes_2), 9)
self.assertEqual(nodes_2[1].target, ttnn.from_torch)
self.assertEqual(nodes_2[2].target, ttnn.to_device)
self.assertEqual(nodes_2[3].target, ttnn.matmul)
self.assertEqual(nodes_2[4].target, ttnn.from_device)
self.assertEqual(nodes_2[5].target, ttnn.to_layout)
self.assertEqual(nodes_2[6].target, ttnn.to_torch)
self.assertEqual(nodes_2[2].target, ttnn.to_layout)
self.assertEqual(nodes_2[3].target, ttnn.to_device)
self.assertEqual(nodes_2[4].target, ttnn.matmul)
self.assertEqual(nodes_2[5].target, ttnn.from_device)
self.assertEqual(nodes_2[6].target, ttnn.to_layout)
self.assertEqual(nodes_2[7].target, ttnn.to_torch)

# Check inference result
self.assertTrue(torch.allclose(result_before_then, result_after_then))
Expand Down
Loading

0 comments on commit f1e4057

Please sign in to comment.