Add tests for basic ops (#8)
* Tests added for abs, constant, typecast, div, exp, maximum, multiply, negate,
  reduce::max, reduce::sum, subtract, transpose::2d, and transpose::3d.
* Create inputs with torch.int32 unless all input types are floating point.
mmanzoorTT authored and uazizTT committed Oct 31, 2024
1 parent 2216553 commit d93f4ff
Showing 4 changed files with 174 additions and 12 deletions.
14 changes: 7 additions & 7 deletions .github/workflows/nightly-uplift.yml
@@ -8,8 +8,8 @@ on:
    - cron: '0 8 * * *' # Runs at 08:00 UTC every day
  workflow_dispatch: # Manual trigger

jobs:
  uplift-pr:
    runs-on: ubuntu-latest

    env:
@@ -47,20 +47,20 @@ jobs:
          base: main
          commit-message: "Uplift ${{ env.SUBMODULE_PATH }} to ${{ env.SUBMODULE_VERSION }} ${{ env.TODAY }}"
          title: "Uplift ${{ env.SUBMODULE_PATH }} to ${{ env.SUBMODULE_VERSION }} ${{ env.TODAY }}"
          body: "This PR uplifts the ${{ env.SUBMODULE_PATH }} to the ${{ env.SUBMODULE_VERSION }}"
          labels: uplift
          delete-branch: true
          token: ${{ secrets.GH_TOKEN }}

      - name: Approve Pull Request
        if: ${{ steps.create-pr.outputs.pull-request-number }}
        env:
          GITHUB_TOKEN: ${{ secrets.GH_APPROVE_TOKEN }}
        run: |
          echo "Pull Request Number - ${{ steps.create-pr.outputs.pull-request-number }}"
          echo "Pull Request URL - ${{ steps.create-pr.outputs.pull-request-url }}"
          gh pr review ${{ steps.create-pr.outputs.pull-request-number }} --approve

      - name: Enable Pull Request Automerge
        if: ${{ steps.create-pr.outputs.pull-request-number }}
        run: gh pr merge --squash --auto "${{ steps.create-pr.outputs.pull-request-number }}"
2 changes: 1 addition & 1 deletion .github/workflows/on-pr.yml
@@ -11,4 +11,4 @@ jobs:
    secrets: inherit
  build-and-test:
    uses: ./.github/workflows/build-and-test.yml
    secrets: inherit
161 changes: 161 additions & 0 deletions test/test_basic.py
@@ -9,6 +9,17 @@
from tt_torch.tools.verify import verify_module


def test_abs():
    class Basic(nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x):
            return torch.abs(x)

    verify_module(Basic(), [(256, 256)])


def test_add():
    class Basic(nn.Module):
        def __init__(self):
@@ -64,6 +75,65 @@ def forward(self, x, y):
    verify_module(Basic(), [(32, 32, 32, 32), (32, 32, 32, 64)])


@pytest.mark.skip(
    "Torch keeps the 'value' as a dialect resource, which is not processed."
)
def test_constant_ones():
    class Basic(nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x):
            return torch.tensor([1.0, 1.0, 1.0, 1.0])

    verify_module(Basic(), [(1, 1)])
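Editor's note: a possible workaround, sketched purely as an untested assumption, is to derive the constant from the input with torch.ones_like so the value is produced by an op instead of being embedded as a dense constant resource. BasicOnesLike and the commented-out call below are hypothetical and not part of this commit.

# Hypothetical, untested sketch -- not in this commit. Assumes the frontend
# lowers ones_like to a fill-style op rather than a dense resource blob.
class BasicOnesLike(nn.Module):
    def forward(self, x):
        # Ones tensor derived from the input, matching its shape and dtype.
        return torch.ones_like(x)

# verify_module(BasicOnesLike(), [(1, 4)])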


def test_convert():
    class Basic_toFloat(nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x):
            return x.to(torch.float32)

    class Basic_toInt(nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x):
            return x.to(torch.int32)

    verify_module(Basic_toFloat(), [(4, 4)], input_data_types=[torch.int32])
    verify_module(Basic_toFloat(), [(4, 4)], input_data_types=[torch.float32])
    verify_module(Basic_toInt(), [(4, 4)], input_data_types=[torch.int32])
    verify_module(
        Basic_toInt(), [(4, 4)], input_data_types=[torch.float32], input_range=(0, 60)
    )


def test_div():
    class Basic(nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x, y):
            return x / y

    verify_module(Basic(), [(2, 2), (2, 2)], required_atol=3e-2)


def test_exp():
    class Basic(nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x):
            return torch.exp(x)

    verify_module(Basic(), [(2, 2)], required_atol=3e-2)


def test_linear():
    class Basic(nn.Module):
        def __init__(self):
@@ -98,6 +168,63 @@ def forward(self, x):
    verify_module(Basic(), [(32, 32)])


def test_maximum():
    class Basic(nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x, y):
            return torch.maximum(x, y)

    verify_module(Basic(), [(32, 32), (32, 32)], input_range=(-6, 6))


def test_multiply():
    class Basic(nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x, y):
            return x * y

    verify_module(Basic(), [(32, 32), (32, 32)])


def test_negate():
    class Basic(nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x):
            return -x

    verify_module(Basic(), [(32, 32)], input_range=(-6, 6))


@pytest.mark.skip("keepdim=False is not supported")
def test_reduce_max():
class Basic(nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
return torch.max(x)

verify_module(Basic(), [(32, 32)], input_range=(-6, 6))


@pytest.mark.skip("keepdim=False is not supported")
def test_reduce_sum():
class Basic(nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
return torch.sum(x)

verify_module(Basic(), [(32, 32)], input_range=(-6, 6))
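Editor's note: both reduce tests above are skipped because keepdim=False is unsupported. The sketch below restates them with an explicit dim and keepdim=True, which the skip messages suggest would avoid the limitation. These classes and the commented-out calls are illustrative assumptions, not part of this commit.

# Sketch (assumption, not in this commit): reduce along an explicit dim with
# keepdim=True instead of a full reduction with keepdim=False.
class ReduceMaxKeepdim(nn.Module):
    def forward(self, x):
        # torch.max over a dim returns (values, indices); keep only the values.
        return torch.max(x, dim=1, keepdim=True).values

class ReduceSumKeepdim(nn.Module):
    def forward(self, x):
        return torch.sum(x, dim=1, keepdim=True)

# verify_module(ReduceMaxKeepdim(), [(32, 32)], input_range=(-6, 6))
# verify_module(ReduceSumKeepdim(), [(32, 32)], input_range=(-6, 6))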


def test_relu():
    pytest.xfail()

@@ -177,6 +304,40 @@ def forward(self, a):
    verify_module(Basic(), [shape])


def test_subtract():
    class Basic(nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x, y):
            return x - y

    verify_module(Basic(), [(32, 32), (32, 32)], input_range=(-6, 6))


def test_transpose_2d():
    class Basic(nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x):
            return torch.transpose(x, 0, 1)

    verify_module(Basic(), [(4, 8)], input_range=(-6, 6))


@pytest.mark.skip("TTNN does not support transpose for higher ranks/dimensions.")
def test_transpose_3d():
class Basic(nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
return torch.transpose(x, 0, 1)

verify_module(Basic(), [(4, 8, 4)], input_range=(-6, 6))
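Editor's note: on a rank-3 tensor, torch.transpose(x, 0, 1) is equivalent to x.permute(1, 0, 2). The hypothetical variant below (not part of this commit) expresses the same axis swap as a permute, in case that lowers through a supported path.

# Hypothetical variant (not in this commit): the same rank-3 axis swap as a
# permute; equivalent to torch.transpose(x, 0, 1) on a (4, 8, 4) input.
class TransposeAsPermute(nn.Module):
    def forward(self, x):
        return x.permute(1, 0, 2)

# verify_module(TransposeAsPermute(), [(4, 8, 4)], input_range=(-6, 6))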


def test_bert():
    pytest.xfail()
    from torch_mlir import fx
9 changes: 5 additions & 4 deletions tt_torch/tools/verify.py
@@ -18,11 +18,12 @@ def verify_module(

     if all([dtype.is_floating_point for dtype in input_data_types]):
         low, high = input_range
-        inputs = [
-            (low - high) * torch.rand(shape) + high for shape in input_shapes
-        ]  # uniformly distribute random numbers within the input_range
+        # Uniformly distribute random numbers within the input_range
+        inputs = [(low - high) * torch.rand(shape) + high for shape in input_shapes]
     else:
-        inputs = [torch.randint(0, 1000, shape) for shape in input_shapes]
+        inputs = [
+            torch.randint(0, 1000, shape, dtype=torch.int32) for shape in input_shapes
+        ]
     ret = tt_mod(*inputs)
     golden = mod(*inputs)

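Editor's note: the hunk above changes input generation so that non-floating inputs are created as torch.int32. The standalone sketch below restates that rule for clarity; make_inputs is a hypothetical helper, not the actual verify_module API, and mirrors the committed logic.

import torch

# Hypothetical helper mirroring the committed logic (illustration only):
# uniform floats within input_range when every dtype is floating point,
# otherwise int32 tensors drawn from [0, 1000).
def make_inputs(input_shapes, input_data_types, input_range):
    if all(dtype.is_floating_point for dtype in input_data_types):
        low, high = input_range
        # Uniformly distribute random numbers within input_range.
        return [(low - high) * torch.rand(shape) + high for shape in input_shapes]
    # Any non-floating dtype forces the integer path.
    return [torch.randint(0, 1000, shape, dtype=torch.int32) for shape in input_shapes]

inputs = make_inputs([(4, 4)], [torch.int32, torch.float32], (-0.5, 0.5))
assert inputs[0].dtype == torch.int32  # mixed dtypes fall back to int32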
