From 9dd512f3142cd4641e1d48492dc05ee554105146 Mon Sep 17 00:00:00 2001
From: Vladica Obojevic
Date: Fri, 7 Feb 2025 11:04:36 +0000
Subject: [PATCH] Add gelu and leaky_relu operators tests

- Add PytorchUtils class
- Add gelu and leaky_relu operators tests
---
 .../pytorch/eltwise_unary/test_unary.py | 86 ++++++++++++++++++-
 forge/test/operators/utils/__init__.py  |  2 +
 forge/test/operators/utils/utils.py     | 17 ++++
 3 files changed, 102 insertions(+), 3 deletions(-)

diff --git a/forge/test/operators/pytorch/eltwise_unary/test_unary.py b/forge/test/operators/pytorch/eltwise_unary/test_unary.py
index 32b26ee7b..3a69b7dc4 100644
--- a/forge/test/operators/pytorch/eltwise_unary/test_unary.py
+++ b/forge/test/operators/pytorch/eltwise_unary/test_unary.py
@@ -49,14 +49,13 @@
 # (/) Reuse inputs for selected operators
 
-import torch
-
 from typing import List, Dict
 from loguru import logger
 
 from forge import MathFidelity
 
 from test.operators.utils import InputSourceFlags, VerifyUtils
 from test.operators.utils import InputSource
+from test.operators.utils import PytorchUtils
 from test.operators.utils import TestVector
 from test.operators.utils import TestPlan
 from test.operators.utils import FailingReasons
@@ -90,7 +89,8 @@ def verify(
         if test_vector.input_source in (InputSource.FROM_DRAM_QUEUE,):
             input_source_flag = InputSourceFlags.FROM_DRAM
 
-        operator = getattr(torch, test_vector.operator)
+        module = PytorchUtils.get_pytorch_module(test_vector.operator)
+        operator = getattr(module, test_vector.operator)
 
         kwargs = test_vector.kwargs if test_vector.kwargs else {}
 
@@ -125,6 +125,7 @@ class TestParamsData:
     __test__ = False
 
     test_plan_implemented: TestPlan = None
+    test_plan_implemented_float: TestPlan = None
     test_plan_not_implemented: TestPlan = None
 
     no_kwargs = [
@@ -144,12 +145,27 @@
         {"exponent": 10.0},
     ]
 
+    kwargs_gelu = [
+        {"approximate": "tanh"},
+        {},
+    ]
+
+    kwargs_leaky_relu = [
+        {"negative_slope": 0.01, "inplace": True},
+        {"negative_slope": 0.1, "inplace": False},
+        {},
+    ]
+
     @classmethod
     def generate_kwargs(cls, test_vector: TestVector):
         if test_vector.operator in ("clamp",):
             return cls.kwargs_clamp
         if test_vector.operator in ("pow",):
             return cls.kwargs_pow
+        if test_vector.operator in ("gelu",):
+            return cls.kwargs_gelu
+        if test_vector.operator in ("leaky_relu",):
+            return cls.kwargs_leaky_relu
         return cls.no_kwargs
 
 
@@ -179,6 +195,12 @@ class TestCollectionData:
             "log1p",
         ],
     )
+    implemented_float = TestCollection(
+        operators=[
+            "gelu",
+            "leaky_relu",
+        ],
+    )
     not_implemented = TestCollection(
         operators=[
             "acos",
@@ -691,6 +713,63 @@
 )
 
 
+TestParamsData.test_plan_implemented_float = TestPlan(
+    verify=lambda test_device, test_vector: TestVerification.verify(
+        test_device,
+        test_vector,
+    ),
+    collections=[
+        # Test gelu, leaky_relu operators collection:
+        TestCollection(
+            operators=TestCollectionData.implemented_float.operators,
+            input_sources=TestCollectionCommon.all.input_sources,
+            input_shapes=TestCollectionCommon.all.input_shapes,
+            kwargs=lambda test_vector: TestParamsData.generate_kwargs(test_vector),
+        ),
+        # Test gelu, leaky_relu data formats collection:
+        TestCollection(
+            operators=TestCollectionData.implemented_float.operators,
+            input_sources=TestCollectionCommon.single.input_sources,
+            input_shapes=TestCollectionCommon.single.input_shapes,
+            kwargs=lambda test_vector: TestParamsData.generate_kwargs(test_vector),
+            dev_data_formats=[
+                item
+                for item in TestCollectionCommon.float.dev_data_formats
+                if item not in TestCollectionCommon.single.dev_data_formats
+            ],
+            math_fidelities=TestCollectionCommon.single.math_fidelities,
+        ),
+        # Test gelu, leaky_relu math fidelities collection:
+        TestCollection(
+            operators=TestCollectionData.implemented_float.operators,
+            input_sources=TestCollectionCommon.single.input_sources,
+            input_shapes=TestCollectionCommon.single.input_shapes,
+            kwargs=lambda test_vector: TestParamsData.generate_kwargs(test_vector),
+            dev_data_formats=TestCollectionCommon.single.dev_data_formats,
+            math_fidelities=TestCollectionCommon.all.math_fidelities,
+        ),
+    ],
+    failing_rules=[
+        TestCollection(
+            operators=["gelu"],
+            input_shapes=[(1, 1)],
+            kwargs=[
+                {"approximate": "tanh"},
+                {},
+            ],
+            failing_reason=FailingReasons.DATA_MISMATCH,
+        ),
+        TestCollection(
+            operators=["leaky_relu"],
+            input_sources=[InputSource.CONST_EVAL_PASS],
+            input_shapes=[(1, 1)],
+            kwargs=[{"negative_slope": 0.01, "inplace": True}],
+            failing_reason=FailingReasons.DATA_MISMATCH,
+        ),
+    ],
+)
+
+
 TestParamsData.test_plan_not_implemented = TestPlan(
     verify=lambda test_device, test_vector: TestVerification.verify(
         test_device,
@@ -718,5 +797,6 @@ class TestCollectionData:
 def get_test_plans() -> List[TestPlan]:
     return [
         TestParamsData.test_plan_implemented,
+        TestParamsData.test_plan_implemented_float,
         TestParamsData.test_plan_not_implemented,
     ]
diff --git a/forge/test/operators/utils/__init__.py b/forge/test/operators/utils/__init__.py
index a09136d0b..f08356dc1 100644
--- a/forge/test/operators/utils/__init__.py
+++ b/forge/test/operators/utils/__init__.py
@@ -14,6 +14,7 @@
 from .utils import LoggerUtils
 from .utils import RateLimiter
 from .utils import FrameworkModelType
+from .utils import PytorchUtils
 from .features import TestFeaturesConfiguration
 from .plan import InputSource
 from .plan import TestVector
@@ -47,6 +48,7 @@
     "RateLimiter",
     "TestFeaturesConfiguration",
     "FrameworkModelType",
+    "PytorchUtils",
     "InputSource",
     "TestVector",
     "TestCollection",
diff --git a/forge/test/operators/utils/utils.py b/forge/test/operators/utils/utils.py
index 5ab08aebd..704cde8ab 100644
--- a/forge/test/operators/utils/utils.py
+++ b/forge/test/operators/utils/utils.py
@@ -19,6 +19,7 @@
 from forge import ForgeModule, Module, DepricatedVerifyConfig
 from forge.op_repo import TensorShape
+from forge.op_repo.pytorch_operators import pytorch_operator_repository
 from forge.verify import TestKind  # , verify_module
 from forge._C import MathFidelity
 
 
@@ -315,3 +316,19 @@ def limit_info(self) -> str:
             return f"{self.current_value} <= {self.current_limit}"
         else:
             return f"{self.current_value} > {self.current_limit}"
+
+
+class PytorchUtils:
+    """Utility functions for PyTorch operators"""
+
+    @staticmethod
+    def get_pytorch_module(module_name: str):
+        """Retrieving the module that contains a given operator, based on its full name.\n
+        For example, for "torch.nn.functional.gelu", the function returns module torch.nn.functional."""
+        repo_operator = pytorch_operator_repository.get_by_name(module_name).full_name
+        module_name = repo_operator.rsplit(".", 1)[0]
+        # module = importlib.import_module(module_name)  # bad performance
+        module = torch
+        if module_name == "torch.nn.functional":
+            module = torch.nn.functional
+        return module
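
Usage sketch (illustrative only, not part of the patch): the snippet below mimics what PytorchUtils.get_pytorch_module does for the test's verify path, assuming PyTorch is installed and that the operator repository registers "gelu" under the full name "torch.nn.functional.gelu". resolve_module is a hypothetical stand-in for the helper, shown only to make the module lookup explicit.

import torch
import torch.nn.functional


def resolve_module(full_name: str):
    # Strip the operator name from its registered full name and return the
    # module that owns it, mirroring the helper's two handled cases.
    owner = full_name.rsplit(".", 1)[0]
    return torch.nn.functional if owner == "torch.nn.functional" else torch


# The test plan then fetches the callable with getattr(module, operator_name):
gelu = getattr(resolve_module("torch.nn.functional.gelu"), "gelu")
print(gelu(torch.ones(2, 2), approximate="tanh"))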