#10815: Update the composite ops with ttnn support
mouliraj-mcw committed Jul 31, 2024
1 parent b920d75 commit 552a790
Showing 17 changed files with 161 additions and 526 deletions.
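Across these files the commit applies one pattern: tt_lib's scalar "immediate" variants (logical_andi, logical_ori, logical_xori, logical_noti) give way to the ttnn in-place ops (logical_and_, logical_or_, logical_xor_, logical_not_), which take a second tensor instead of a scalar and write their result into the first operand. A minimal torch-level sketch of the semantic shift, with illustrative values not taken from the commit:

import torch

x = torch.tensor([1.0, 0.0, 3.0])

# Old immediate style: AND against a broadcast scalar, out-of-place.
old_ref = torch.logical_and(x, torch.tensor(7, dtype=torch.int32))

# New in-place style: AND against a full tensor; x itself is overwritten.
y = torch.tensor([1.0, 1.0, 0.0])
new_ref = x.logical_and_(y)  # returns x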
docs/source/ttnn/ttnn/dependencies/tt_lib.rst: 22 changes (0 additions, 22 deletions)
@@ -280,28 +280,12 @@ Tensor elementwise operations
 
 .. autofunction:: tt_lib.tensor.mac
 
-.. autofunction:: tt_lib.tensor.softshrink
-
-.. autofunction:: tt_lib.tensor.hardshrink
-
 .. autofunction:: tt_lib.tensor.remainder
 
 .. autofunction:: tt_lib.tensor.fmod
 
-.. autofunction:: tt_lib.tensor.logical_xori
-
-.. autofunction:: tt_lib.tensor.celu
-
-.. autofunction:: tt_lib.tensor.logit
-
-.. autofunction:: tt_lib.tensor.logical_andi
-
 .. autofunction:: tt_lib.tensor.assign
 
-.. autofunction:: tt_lib.tensor.logical_ori
-
-.. autofunction:: tt_lib.tensor.frac
-
 .. autofunction:: tt_lib.tensor.floor_div
 
 .. autofunction:: tt_lib.tensor.rfloor_div
@@ -363,8 +347,6 @@ Broadcast and Reduce
 
 .. autofunction:: tt_lib.tensor.global_mean
 
-.. autofunction:: tt_lib.tensor.rpow
-
 
 Fallback Operations
 *******************
@@ -521,10 +503,6 @@ Other Operations
 
 .. autofunction:: tt_lib.tensor.mean_hw
 
-.. autofunction:: tt_lib.tensor.logical_noti
-
-.. autofunction:: tt_lib.tensor.normalize_global
-
 .. autofunction:: tt_lib.tensor.lamb_optimizer
 
 .. autofunction:: tt_lib.tensor.repeat
@@ -7,11 +7,11 @@
 import pytest
 import torch
 import tt_lib as ttl
 
 import ttnn
 
 from tests.tt_eager.python_api_testing.sweep_tests import pytorch_ops
 from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_pcc
-from tests.tt_eager.python_api_testing.sweep_tests.tt_lib_ops import eltwise_logical_andi as tt_eltwise_logical_andi
+from tests.tt_eager.python_api_testing.sweep_tests.tt_lib_ops import eltwise_logical_and_ as tt_eltwise_logical_and_
 
 
 def run_eltwise_logical_andi_tests(
@@ -20,7 +20,6 @@ def run_eltwise_logical_andi_tests(
     dlayout,
     in_mem_config,
     out_mem_config,
-    immediate,
     data_seed,
     device,
 ):
@@ -31,47 +30,50 @@ def run_eltwise_logical_andi_tests(
         in_mem_config = None
 
     x = torch.Tensor(size=input_shape).uniform_(-100, 100).to(torch.bfloat16)
+    y = torch.Tensor(size=input_shape).uniform_(-100, 100).to(torch.bfloat16)
     x_ref = x.detach().clone()
+    y_ref = y.detach().clone()
 
     # get referent value
-    ref_value = pytorch_ops.logical_andi(x_ref, immediate=immediate)
+    golden_function = ttnn.get_golden_function(ttnn.logical_and_)
+    ref_value = golden_function(x_ref, y_ref)
 
     # calculate tt output
-    logger.info("Running eltwise_andi test")
-    tt_result = tt_eltwise_logical_andi(
+    logger.info("Running eltwise_and_ test")
+    tt_result = tt_eltwise_logical_and_(
         x=x,
-        immediate=immediate,
+        y=y,
         device=device,
-        dtype=[dtype],
-        layout=[dlayout],
-        input_mem_config=[in_mem_config],
+        dtype=dtype,
+        layout=dlayout,
+        input_mem_config=in_mem_config,
         output_mem_config=out_mem_config,
     )
 
     # compare tt and golden outputs
-    success, pcc_value = comp_pcc(ref_value, tt_result)
+    success, pcc_value = comp_pcc(ref_value, x)
     logger.debug(pcc_value)
 
     assert success
 
 
-# eltwise-logical_andi,"[[6, 9, 192, 128]]","{'dtype': [<DataType.BFLOAT8_B: 3>], 'layout': [<Layout.TILE: 1>], 'input_mem_config': [None], 'output_mem_config': tt::tt_metal::MemoryConfig(memory_layout=TensorMemoryLayout::INTERLEAVED,buffer_type=BufferType::L1,shard_spec=std::nullopt), 'immediate': 0}",19790443,(),error,"TT_FATAL @ /home/ubuntu/tt-metal/tt_eager/tt_dnn/op_library/bcast/bcast_op.cpp:94: input_tensor_a.get_dtype() == input_tensor_b.get_dtype()
-
 test_sweep_args = [
     (
         (6, 9, 192, 128),
-        ttl.tensor.DataType.BFLOAT8_B,
-        ttl.tensor.Layout.TILE,
-        "SYSTEM_MEMORY",
-        ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1),
-        0,
-        19790443,
+        [ttl.tensor.DataType.BFLOAT16, ttl.tensor.DataType.BFLOAT16],
+        [ttl.tensor.Layout.TILE, ttl.tensor.Layout.TILE],
+        [
+            ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.DRAM),
+            ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.DRAM),
+        ],
+        ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.DRAM),
+        14854324,
     ),
 ]
 
 
 @pytest.mark.parametrize(
-    "input_shape, dtype, dlayout, in_mem_config, out_mem_config, immediate, data_seed",
+    "input_shape, dtype, dlayout, in_mem_config, out_mem_config, data_seed",
     (test_sweep_args),
 )
 def test_eltwise_logical_andi_test(
@@ -80,7 +82,6 @@ def test_eltwise_logical_andi_test(
     dlayout,
     in_mem_config,
     out_mem_config,
-    immediate,
     data_seed,
     device,
 ):
@@ -90,7 +91,6 @@ def test_eltwise_logical_andi_test(
         dlayout,
         in_mem_config,
         out_mem_config,
-        immediate,
         data_seed,
         device,
     )
tests/tt_eager/python_api_testing/sweep_tests/op_map.py: 24 changes (12 additions, 12 deletions)
@@ -140,9 +140,9 @@
         "tt_op": tt_lib_ops.eltwise_i0,
         "pytorch_op": pytorch_ops.i0,
     },
-    "eltwise-logical_noti": {
-        "tt_op": tt_lib_ops.eltwise_logical_noti,
-        "pytorch_op": pytorch_ops.logical_noti,
+    "eltwise-logical_not_": {
+        "tt_op": tt_lib_ops.eltwise_logical_not_,
+        "pytorch_op": pytorch_ops.logical_not_,
     },
     "eltwise-bitwise_complement": {
         "tt_op": None,  # tt_lib_ops.eltwise_bitwise_complement,
@@ -304,9 +304,9 @@
         "tt_op": tt_lib_ops.eltwise_logical_and,
         "pytorch_op": pytorch_ops.logical_and,
     },
-    "eltwise-logical_andi": {
-        "tt_op": tt_lib_ops.eltwise_logical_andi,
-        "pytorch_op": pytorch_ops.logical_andi,
+    "eltwise-logical_and_": {
+        "tt_op": tt_lib_ops.eltwise_logical_and_,
+        "pytorch_op": pytorch_ops.logical_and_,
     },
     "eltwise-leaky_relu": {
         "tt_op": tt_lib_ops.eltwise_leaky_relu,
@@ -420,9 +420,9 @@
         "tt_op": tt_lib_ops.lamb_optimizer,
         "pytorch_op": pytorch_ops.lamb_optimizer,
     },
-    "eltwise-logical_xori": {
-        "tt_op": tt_lib_ops.eltwise_logical_xori,
-        "pytorch_op": pytorch_ops.logical_xori,
+    "eltwise-logical_xor_": {
+        "tt_op": tt_lib_ops.eltwise_logical_xor_,
+        "pytorch_op": pytorch_ops.logical_xor_,
     },
     "eltwise-log": {
         "tt_op": tt_lib_ops.eltwise_log,
@@ -609,9 +609,9 @@
         "tt_op": tt_lib_ops.eltwise_logical_or,
         "pytorch_op": pytorch_ops.logical_or,
     },
-    "eltwise-logical_ori": {
-        "tt_op": tt_lib_ops.eltwise_logical_ori,
-        "pytorch_op": pytorch_ops.logical_ori,
+    "eltwise-logical_or_": {
+        "tt_op": tt_lib_ops.eltwise_logical_or_,
+        "pytorch_op": pytorch_ops.logical_or_,
     },
     # Eltwise binary with optional output
     "eltwise-ne-optional": {
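For orientation, each op_map entry pairs the device-side wrapper with its torch golden, keyed by sweep name, so renaming an op means updating both callables. A rough sketch of how such a registry gets consumed; run_sweep_case here is a hypothetical stand-in, not the repository's actual runner:

import torch

def run_sweep_case(op_name, op_map, inputs, device):
    # Hypothetical dispatcher: look up the pair and run both sides.
    entry = op_map[op_name]
    ref = entry["pytorch_op"](*[t.clone() for t in inputs])  # torch golden; clone to survive in-place ops
    actual = entry["tt_op"](*inputs, device=device)  # device implementation
    return ref, actual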
@@ -31,12 +31,12 @@ def custom_compare(*args, **kwargs):
     function = kwargs.pop("function")
     if function in [
         "logical_xor",
-        "logical_ori",
+        "logical_or_",
         "logical_or",
-        "logical_xori",
-        "logical_noti",
+        "logical_xor_",
+        "logical_not_",
         "logical_not",
-        "logical_andi",
+        "logical_and_",
         "is_close",
     ]:
         comparison_func = comparison_funcs.comp_equal
@@ -97,11 +97,11 @@ def custom_compare(*args, **kwargs):
     # "bias_gelu_unary",
     "addalpha",
     "logit",
-    # "logical_ori",
+    "logical_or_",
     "logical_xor",
-    # "logical_xori",
-    # "logical_noti",
-    # "logical_andi",
+    "logical_xor_",
+    "logical_not_",
+    "logical_and_",
     "isclose",
     "digamma",
     "lgamma",
@@ -145,9 +145,9 @@ def test_run_eltwise_composite_test(fn, input_shapes, device, function_level_defaults):
     options["asinh"] = (-100, 100)
     options["isclose"] = (-100, 100)
     options["acosh"] = (1, 100)
-    options["logical_ori"] = (-100, 100)
-    options["logical_andi"] = (-100, 100)
-    options["logical_xori"] = (-100, 100)
+    options["logical_or_"] = (-100, 100)
+    options["logical_and_"] = (-100, 100)
+    options["logical_xor_"] = (-100, 100)
 
     generator = generation_funcs.gen_rand
 
@@ -157,7 +157,12 @@ def test_run_eltwise_composite_test(fn, input_shapes, device, function_level_defaults):
     if is_grayskull():
         if fn in ["mish"]:
             pytest.skip("does not work for Grayskull -skipping")
-        if fn in ["logical_xor", "logical_xori", "logical_ori", "logical_andi"]:
+        if fn in [
+            "logical_xor",
+            "logical_and_",
+            "logical_or_",
+            "logical_xor_",
+        ]:
             datagen_func = [
                 generation_funcs.gen_func_with_cast(
                     partial(generator, low=options[fn][0], high=options[fn][1]),
@@ -191,7 +196,9 @@ def test_run_eltwise_composite_test(fn, input_shapes, device, function_level_defaults):
         "isclose",
         "assign_binary",
         "nextafter",
-        # "scatter",
+        "logical_and_",
+        "logical_or_",
+        "logical_xor_",
     ]:
         num_inputs = 2
 
@@ -221,8 +228,6 @@ def test_run_eltwise_composite_test(fn, input_shapes, device, function_level_defaults):
         test_args.update({"eps": np.random.randint(-10, 0.99)})
     elif fn in ["polygamma"]:
         test_args.update({"k": np.random.randint(1, 10)})
-    elif fn in ["logical_ori", "logical_andi", "logical_xori", "logical_noti"]:
-        test_args.update({"immediate": np.random.randint(0, 100)})
     elif fn in ["isclose"]:
         test_args.update(
             {
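On the comparator choice in custom_compare above: logical ops emit only zeros and ones, and a Pearson correlation check (comp_pcc) degenerates on constant or near-constant tensors, so exact equality is the meaningful test. A stand-in sketch of that idea; the real implementation lives in comparison_funcs.comp_equal:

import torch

def comp_equal_sketch(golden: torch.Tensor, calculated: torch.Tensor) -> bool:
    # Exact match is well-defined for {0, 1}-valued outputs, unlike PCC,
    # which is undefined when either input is constant.
    return torch.equal(golden.to(calculated.dtype), calculated)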
tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py: 20 changes (8 additions, 12 deletions)
@@ -502,9 +502,8 @@ def polygamma(x, *args, k, **kwargs):
     return torch.special.polygamma(n=k, input=x)
 
 
-def logical_xori(x, *args, **kwargs):
-    value = kwargs.pop("immediate")
-    result = torch.logical_xor(x, torch.tensor(value, dtype=torch.int32))
+def logical_xor_(x, y, *args, **kwargs):
+    result = x.logical_xor_(y)
     return result
 
 
@@ -744,9 +743,8 @@ def multigammaln(x, *args, **kwargs):
     return torch.special.multigammaln(x, 4)
 
 
-def logical_andi(x, *args, **kwargs):
-    value = kwargs.pop("immediate")
-    result = torch.logical_and(x, torch.tensor(value, dtype=torch.int32))
+def logical_and_(x, y, *args, **kwargs):
+    result = x.logical_and_(y)
     return result
 
 
@@ -960,9 +958,8 @@ def logical_and(x, y, *args, **kwargs):
     return result
 
 
-def logical_noti(x, *args, **kwargs):
-    immediate = kwargs.pop("immediate")
-    result = torch.logical_not(torch.full_like(x, immediate)).to(torch.int32)
+def logical_not_(x, *args, **kwargs):
+    result = x.logical_not_()
     return result
 
 
@@ -1162,9 +1159,8 @@ def logical_or(x, y, *args, **kwargs):
     return torch.logical_or(x, y)
 
 
-def logical_ori(x, *args, **kwargs):
-    value = kwargs.pop("immediate")
-    result = torch.logical_or(x, torch.tensor(value, dtype=torch.int32))
+def logical_or_(x, y, *args, **kwargs):
+    result = x.logical_or_(y)
     return result
 
 
