Skip to content

Commit

Permalink
Reshape op - failing rules
Browse files Browse the repository at this point in the history
  • Loading branch information
vobojevicTT committed Jan 28, 2025
1 parent e544f03 commit c918d7e
Show file tree
Hide file tree
Showing 3 changed files with 310 additions and 258 deletions.
42 changes: 8 additions & 34 deletions forge/test/operators/pytorch/tm/test_reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from forge.verify.config import VerifyConfig

from forge.verify.value_checkers import AllCloseValueChecker
from forge.verify.value_checkers import AutomaticValueChecker
from forge.verify.verify import verify as forge_verify

from test.operators.utils import InputSourceFlags, VerifyUtils
Expand Down Expand Up @@ -129,6 +130,11 @@ def verify(

logger.trace(f"***input_shapes: {input_shapes}")

# We use AllCloseValueChecker in all cases except for integer data formats:
verify_config = VerifyConfig(value_checker=AllCloseValueChecker())
if test_vector.dev_data_format in TestCollectionCommon.int.dev_data_formats:
verify_config = VerifyConfig(value_checker=AutomaticValueChecker())

VerifyUtils.verify(
model=pytorch_model,
test_device=test_device,
Expand All @@ -140,7 +146,7 @@ def verify(
warm_reset=warm_reset,
value_range=ValueRanges.SMALL,
deprecated_verification=False,
verify_config=VerifyConfig(value_checker=AllCloseValueChecker()),
verify_config=verify_config,
)


Expand Down Expand Up @@ -287,39 +293,7 @@ def generate_specific_kwargs(cls, test_vector: TestVector):
failing_reason=FailingReasons.DATA_MISMATCH,
),
TestCollection(
input_shapes=[(1, 10000)],
failing_reason=FailingReasons.INFERENCE_FAILED,
),
TestCollection(
input_shapes=[
(100, 100),
(1000, 100),
(89, 3),
(1, 64, 1),
(1, 100, 100),
(11, 17, 41),
(1, 2, 3, 4),
(1, 11, 45, 17),
(1, 11, 17, 41),
(1, 13, 89, 3),
(8, 1, 10, 1000),
(11, 32, 32, 64),
(8, 8, 8),
],
failing_reason=FailingReasons.INFERENCE_FAILED,
),
TestCollection(
input_shapes=[(1, 2, 2, 2)],
criteria=lambda test_vector: test_vector.kwargs is not None
and "shape" in test_vector.kwargs
and test_vector.kwargs["shape"] == (8,),
failing_reason=FailingReasons.INFERENCE_FAILED,
),
TestCollection(
input_shapes=[(1, 49, 2304)],
criteria=lambda test_vector: test_vector.kwargs is not None
and "shape" in test_vector.kwargs
and test_vector.kwargs["shape"] == (-1,),
input_shapes=[(1, 10000), (7, 10, 1000, 100)],
failing_reason=FailingReasons.INFERENCE_FAILED,
),
TestCollection(
Expand Down
Loading

0 comments on commit c918d7e

Please sign in to comment.