support torch dtype as data formats in test plan
vbrkicTT authored and kmilanovicTT committed Feb 3, 2025
1 parent abdcb6b commit 75d0c23
Showing 8 changed files with 127 additions and 18 deletions.
4 changes: 1 addition & 3 deletions forge/test/operators/pytorch/conftest.py
@@ -87,9 +87,7 @@ def log_test_vector_properties(
item.user_properties.append(
("input_source", test_vector.input_source.name if test_vector.input_source is not None else None)
)
item.user_properties.append(
("dev_data_format", test_vector.dev_data_format.name if test_vector.dev_data_format is not None else None)
)
item.user_properties.append(("dev_data_format", TestPlanUtils.dev_data_format_to_str(test_vector.dev_data_format)))
item.user_properties.append(
("math_fidelity", test_vector.math_fidelity.name if test_vector.math_fidelity is not None else None)
)
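Usage sketch (illustrative, not part of the diff): the "dev_data_format" user property is now recorded through TestPlanUtils.dev_data_format_to_str, so Forge and torch formats end up in the same string form. The import path below is assumed from this repository layout.

    import torch
    from test.operators.utils import TestPlanUtils  # assumed import path

    TestPlanUtils.dev_data_format_to_str(torch.int32)  # -> "int32"
    TestPlanUtils.dev_data_format_to_str(None)          # -> None, property stays empty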
4 changes: 3 additions & 1 deletion forge/test/operators/pytorch/test_all.py
@@ -101,7 +101,9 @@ def build_filtered_collection(cls) -> TestCollection:
dev_data_formats = os.getenv("DEV_DATA_FORMATS", None)
if dev_data_formats:
dev_data_formats = dev_data_formats.split(",")
dev_data_formats = [getattr(forge.DataFormat, dev_data_format) for dev_data_format in dev_data_formats]
dev_data_formats = [
TestPlanUtils.dev_data_format_from_str(dev_data_format) for dev_data_format in dev_data_formats
]

math_fidelities = os.getenv("MATH_FIDELITIES", None)
if math_fidelities:
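Usage sketch (illustrative, not part of the diff): DEV_DATA_FORMATS can now mix Forge enum names and torch dtype names, since each entry is resolved through TestPlanUtils.dev_data_format_from_str instead of a direct getattr on forge.DataFormat. The import path and the Float16_b member name are assumptions.

    import os
    import torch
    from test.operators.utils import TestPlanUtils  # assumed import path

    os.environ["DEV_DATA_FORMATS"] = "Float16_b,bfloat16"  # forge name + torch name
    names = os.environ["DEV_DATA_FORMATS"].split(",")
    formats = [TestPlanUtils.dev_data_format_from_str(name) for name in names]
    assert isinstance(formats[1], torch.dtype)  # "bfloat16" falls back to getattr(torch, ...)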
4 changes: 4 additions & 0 deletions forge/test/operators/utils/__init__.py
@@ -5,6 +5,7 @@
from .datatypes import OperatorParameterTypes
from .datatypes import ValueRange
from .datatypes import ValueRanges
from .datatypes import FrameworkDataFormat
from .utils import ShapeUtils
from .utils import InputSourceFlag, InputSourceFlags
from .utils import CompilerUtils
@@ -24,6 +25,7 @@
from .plan import FailingRulesConverter
from .plan import TestPlanScanner
from .test_data import TestCollectionCommon
from .test_data import TestCollectionTorch
from .failing_reasons import FailingReasons
from .failing_reasons import FailingReasonsValidation
from .pytest import PyTestUtils
@@ -33,6 +35,7 @@
"OperatorParameterTypes",
"ValueRange",
"ValueRanges",
"FrameworkDataFormat",
"ShapeUtils",
"InputSourceFlag",
"InputSourceFlags",
@@ -53,6 +56,7 @@
"FailingRulesConverter",
"TestPlanScanner",
"TestCollectionCommon",
"TestCollectionTorch",
"FailingReasons",
"FailingReasonsValidation",
"PyTestUtils",
39 changes: 31 additions & 8 deletions forge/test/operators/utils/compat.py
@@ -17,6 +17,7 @@
from forge.verify.config import VerifyConfig

from .datatypes import OperatorParameterTypes, ValueRanges, ValueRange
from .datatypes import FrameworkDataFormat


# TODO - Remove this class once TestDevice is available in Forge
@@ -91,10 +92,12 @@ class TestTensorsUtils:
torch.bfloat16: (-10000, 10000),
torch.float16: (-10000, 10000),
torch.float32: (-10000, 10000),
torch.float64: (-10000, 10000),
torch.uint8: (0, 2**8 - 1),
torch.int8: (-(2**7), 2**7 - 1),
torch.int16: (-(2**15), 2**15 - 1),
torch.int32: (-(2**31), 2**31 - 1),
torch.int64: (-(2**63), 2**63 - 1),
}

class DTypes:
@@ -104,12 +107,14 @@ class DTypes:
torch.float16,
torch.bfloat16,
torch.float32,
torch.float64,
)
integers = (
torch.uint8,
torch.int8,
torch.int16,
torch.int32,
torch.int64,
)
booleans = (torch.bool,)

@@ -128,8 +133,13 @@ def get_value_range(
if dtype is None:
dtype = torch.float32

if dev_data_format in cls.data_format_ranges:
data_format_ranges = cls.data_format_ranges[dev_data_format]
if isinstance(dev_data_format, forge.DataFormat):
forge_data_format_ranges = cls.data_format_ranges
elif isinstance(dev_data_format, torch.dtype):
forge_data_format_ranges = cls.dtype_ranges

if dev_data_format in forge_data_format_ranges:
data_format_ranges = forge_data_format_ranges[dev_data_format]
else:
raise ValueError(f"Unsupported range for dev data format: {dev_data_format}")
if dtype in cls.dtype_ranges:
@@ -178,6 +188,8 @@ def get_dtype_for_df(cls, dev_data_format: forge.DataFormat = None) -> torch.dty

if dev_data_format is None:
dtype = None
elif isinstance(dev_data_format, torch.dtype):
dtype = dev_data_format
else:
# dtype = torch.float32
if dev_data_format in cls.dev_data_format_to_dtype:
@@ -235,14 +247,16 @@ def extract_value_range(
return value_range


# TODO remove this method, used only in RGG
# Compatibility method for verifying models
def verify_module(
def verify_module_old(
model: Module,
input_shapes: List[TensorShape],
pcc: Optional[float] = None,
dev_data_format: forge.DataFormat = None,
dev_data_format: FrameworkDataFormat = None,
value_range: Optional[Union[ValueRanges, ValueRange, OperatorParameterTypes.RangeValue]] = None,
random_seed: int = 42,
convert_to_forge: bool = True, # explicit conversion to forge data format
):

logger.debug(
@@ -251,13 +265,13 @@ def verify_module(

inputs = create_torch_inputs(input_shapes, dev_data_format, value_range, random_seed)

verify_module_for_inputs(model, inputs, pcc, dev_data_format)
verify_module_for_inputs(model, inputs, pcc, dev_data_format, convert_to_forge)


# TODO move to class TestTensorsUtils
def create_torch_inputs(
input_shapes: List[TensorShape],
dev_data_format: forge.DataFormat = None,
dev_data_format: FrameworkDataFormat = None,
value_range: Optional[Union[ValueRanges, ValueRange, OperatorParameterTypes.RangeValue]] = None,
random_seed: Optional[int] = None,
) -> List[torch.Tensor]:
@@ -290,11 +304,15 @@ def verify_module_for_inputs_deprecated(
inputs: List[torch.Tensor],
pcc: Optional[float] = None,
dev_data_format: forge.DataFormat = None,
convert_to_forge: bool = True, # explicit conversion to forge data format
):

fw_out = model(*inputs)

forge_inputs = [forge.Tensor.create_from_torch(input, dev_data_format=dev_data_format) for input in inputs]
if convert_to_forge:
forge_inputs = [forge.Tensor.create_from_torch(input, dev_data_format=dev_data_format) for input in inputs]
else:
forge_inputs = inputs

compiled_model = forge.compile(model, sample_inputs=forge_inputs)
co_out = compiled_model(*forge_inputs)
@@ -321,8 +339,13 @@ def verify_module_for_inputs(
inputs: List[torch.Tensor],
verify_config: Optional[VerifyConfig] = VerifyConfig(),
dev_data_format: forge.DataFormat = None,
convert_to_forge: bool = True, # explicit conversion to forge data format
):

forge_inputs = [forge.Tensor.create_from_torch(input, dev_data_format=dev_data_format) for input in inputs]
if convert_to_forge:
forge_inputs = [forge.Tensor.create_from_torch(input, dev_data_format=dev_data_format) for input in inputs]
else:
forge_inputs = inputs

compiled_model = forge.compile(model, sample_inputs=forge_inputs)
verify(inputs, model, compiled_model, verify_config)
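Usage sketch (illustrative, not part of the diff): convert_to_forge controls whether inputs are wrapped with forge.Tensor.create_from_torch before forge.compile or handed over as plain torch tensors. The model, input shape, and import path below are stand-in assumptions.

    import torch
    from test.operators.utils.compat import create_torch_inputs, verify_module_for_inputs  # assumed path

    model = torch.nn.ReLU()  # hypothetical single-operand model
    inputs = create_torch_inputs(input_shapes=[(1, 32, 32)], dev_data_format=torch.bfloat16)

    # previous behaviour (still the default): wrap inputs as forge tensors before compiling
    verify_module_for_inputs(model, inputs, convert_to_forge=True)

    # new: pass the torch tensors to forge.compile unchanged
    verify_module_for_inputs(model, inputs, dev_data_format=torch.bfloat16, convert_to_forge=False)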
6 changes: 6 additions & 0 deletions forge/test/operators/utils/datatypes.py
@@ -8,6 +8,12 @@
from enum import Enum
from typing import Optional, Dict, Union, Tuple, TypeAlias

import torch
import forge


FrameworkDataFormat = Union[forge.DataFormat, torch.dtype]


class OperatorParameterTypes:
SingleValue: TypeAlias = Union[int, float]
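Usage sketch (illustrative, not part of the diff): FrameworkDataFormat is a plain Union, so helpers that accept it can branch on isinstance, in the same style compat.py and plan.py use. The import path is assumed.

    from typing import Optional
    import torch
    import forge
    from test.operators.utils import FrameworkDataFormat  # assumed import path

    def describe_format(fmt: Optional[FrameworkDataFormat]) -> Optional[str]:
        # mirror the dispatch used by the new helpers
        if fmt is None:
            return None
        if isinstance(fmt, forge.DataFormat):
            return f"forge:{fmt.name}"
        if isinstance(fmt, torch.dtype):
            return f"torch:{str(fmt).split('.')[-1]}"
        raise ValueError(f"Unsupported data format: {fmt}")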
39 changes: 33 additions & 6 deletions forge/test/operators/utils/plan.py
@@ -28,6 +28,7 @@
from forge.op_repo import TensorShape

from .datatypes import OperatorParameterTypes
from .datatypes import FrameworkDataFormat
from .pytest import PytestParamsUtils
from .compat import TestDevice
from .utils import RateLimiter
@@ -86,7 +87,7 @@ class TestVector:
input_source: InputSource
input_shape: TensorShape # TODO - Support multiple input shapes
number_of_operands: Optional[int] = None
dev_data_format: Optional[DataFormat] = None
dev_data_format: Optional[FrameworkDataFormat] = None
math_fidelity: Optional[MathFidelity] = None
kwargs: Optional[OperatorParameterTypes.Kwargs] = None
pcc: Optional[float] = None
@@ -96,7 +97,7 @@
def get_id(self, fields: Optional[List[str]] = None) -> str:
"""Get test vector id"""
if fields is None:
return f"{self.operator}-{self.input_source.name}-{self.kwargs}-{self.input_shape}{'-' + str(self.number_of_operands) + '-' if self.number_of_operands else '-'}{self.dev_data_format.name if self.dev_data_format else None}-{self.math_fidelity.name if self.math_fidelity else None}"
return f"{self.operator}-{self.input_source.name}-{self.kwargs}-{self.input_shape}{'-' + str(self.number_of_operands) + '-' if self.number_of_operands else '-'}{TestPlanUtils.dev_data_format_to_str(self.dev_data_format)}-{self.math_fidelity.name if self.math_fidelity else None}"
else:
attr = [
(getattr(self, field).name if getattr(self, field) is not None else None)
@@ -144,7 +145,7 @@ class TestCollection:
input_sources: Optional[List[InputSource]] = None
input_shapes: Optional[List[TensorShape]] = None # TODO - Support multiple input shapes
numbers_of_operands: Optional[List[int]] = None
dev_data_formats: Optional[List[DataFormat]] = None
dev_data_formats: Optional[List[FrameworkDataFormat]] = None
math_fidelities: Optional[List[MathFidelity]] = None
kwargs: Optional[
Union[List[OperatorParameterTypes.Kwargs], Callable[["TestVector"], List[OperatorParameterTypes.Kwargs]]]
@@ -509,6 +510,32 @@ class TestPlanUtils:
Utility functions for test vectors
"""

@classmethod
def dev_data_format_to_str(cls, dev_data_format: FrameworkDataFormat) -> Optional[str]:
"""Convert data format to string"""
if dev_data_format is None:
return None
if isinstance(dev_data_format, DataFormat):
return dev_data_format.name
if isinstance(dev_data_format, torch.dtype):
# Remove torch. prefix
return str(dev_data_format).split(".")[-1]
else:
raise ValueError(f"Unsupported data format: {dev_data_format}")

@classmethod
def dev_data_format_from_str(cls, dev_data_format_str: str) -> FrameworkDataFormat:
"""Convert string to data format"""
if dev_data_format_str is None:
return None
if hasattr(forge.DataFormat, dev_data_format_str):
dev_data_format = getattr(forge.DataFormat, dev_data_format_str)
elif hasattr(torch, dev_data_format_str):
dev_data_format = getattr(torch, dev_data_format_str)
else:
raise ValueError(f"Unsupported data format: {dev_data_format_str} in Forge and PyTorch")
return dev_data_format

@classmethod
def _match(cls, rule_collection: Optional[List], vector_value):
"""
@@ -653,7 +680,7 @@ def test_id_to_test_vector(cls, test_id: str) -> TestVector:
dev_data_format_part = parts[dev_data_format_index]
if dev_data_format_part == "None":
dev_data_format_part = None
dev_data_format = eval(f"forge._C.{dev_data_format_part}") if dev_data_format_part is not None else None
dev_data_format = cls.dev_data_format_from_str(dev_data_format_part)

math_fidelity_part = parts[math_fidelity_index]
if math_fidelity_part == "None":
@@ -689,7 +716,7 @@ def build_rules(
Union[Optional[InputSource], List[InputSource]],
Union[Optional[TensorShape], List[TensorShape]],
Union[Optional[OperatorParameterTypes.Kwargs], List[OperatorParameterTypes.Kwargs]],
Union[Optional[forge.DataFormat], List[forge.DataFormat]],
Union[Optional[FrameworkDataFormat], List[FrameworkDataFormat]],
Union[Optional[forge.MathFidelity], List[forge.MathFidelity]],
Optional[TestResultFailing],
],
@@ -734,7 +761,7 @@ def build_rule(
input_source: Optional[Union[InputSource, List[InputSource]]],
input_shape: Optional[Union[TensorShape, List[TensorShape]]],
kwargs: Optional[Union[OperatorParameterTypes.Kwargs, List[OperatorParameterTypes.Kwargs]]],
dev_data_format: Optional[Union[forge.DataFormat, List[forge.DataFormat]]],
dev_data_format: Optional[Union[FrameworkDataFormat, List[FrameworkDataFormat]]],
math_fidelity: Optional[Union[forge.MathFidelity, List[forge.MathFidelity]]],
result_failing: Optional[TestResultFailing],
) -> TestCollection:
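Usage sketch (illustrative, not part of the diff): dev_data_format_to_str and dev_data_format_from_str are symmetric, so a data format embedded in a test id round-trips; the Float16_b member name and the import path are assumptions.

    import torch
    from test.operators.utils import TestPlanUtils  # assumed import path

    assert TestPlanUtils.dev_data_format_to_str(torch.bfloat16) == "bfloat16"
    assert TestPlanUtils.dev_data_format_from_str("bfloat16") is torch.bfloat16

    fmt = TestPlanUtils.dev_data_format_from_str("Float16_b")  # resolves via forge.DataFormat
    assert TestPlanUtils.dev_data_format_to_str(fmt) == "Float16_b"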
42 changes: 42 additions & 0 deletions forge/test/operators/utils/test_data.py
@@ -6,6 +6,7 @@

import pytest
import forge
import torch

from forge import MathFidelity, DataFormat

@@ -309,3 +310,44 @@ class TestCollectionCommon:
(14, 13, 89, 3), # 4.2 Prime numbers
]
)


class TestCollectionTorch:
"""
Shared test collection for torch data types.
"""

__test__ = False # Avoid collecting TestCollectionTorch as a pytest test

float = TestCollection(
dev_data_formats=[
torch.float16,
torch.float32,
# torch.float64,
torch.bfloat16,
],
)

int = TestCollection(
dev_data_formats=[
torch.int8,
# torch.int16,
torch.int32,
torch.int64,
# torch.uint8,
],
)

bool = TestCollection(
dev_data_formats=[
torch.bool,
],
)

all = TestCollection(dev_data_formats=float.dev_data_formats + int.dev_data_formats)

single = TestCollection(
dev_data_formats=[
torch.float16,
],
)
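Usage sketch (illustrative, not part of the diff): TestCollectionTorch mirrors TestCollectionCommon but carries torch dtypes, so an operator test plan can take its dev_data_formats straight from the shared collection. The import path is assumed.

    import torch
    from test.operators.utils import TestCollectionTorch  # assumed import path

    # shared float dtypes for a test plan's dev_data_formats
    floats = TestCollectionTorch.float.dev_data_formats  # [torch.float16, torch.float32, torch.bfloat16]

    # one-format collection for quick coverage
    assert TestCollectionTorch.single.dev_data_formats == [torch.float16]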
7 changes: 7 additions & 0 deletions forge/test/operators/utils/utils.py
@@ -124,6 +124,7 @@ def verify(
pcc: Optional[float] = None,
input_source_flag: InputSourceFlags = None,
dev_data_format: forge.DataFormat = None,
convert_to_forge: bool = True, # explicit conversion to forge data format
math_fidelity: forge.MathFidelity = None,
value_range: Optional[ValueRanges] = None,
random_seed: Optional[int] = None,
@@ -167,13 +168,15 @@
inputs=inputs,
pcc=pcc,
dev_data_format=dev_data_format,
convert_to_forge=convert_to_forge,
)
else:
cls.verify_module_for_inputs(
model=model,
inputs=inputs,
verify_config=verify_config,
dev_data_format=dev_data_format,
convert_to_forge=convert_to_forge,
)

@classmethod
@@ -220,13 +223,15 @@ def verify_module_for_inputs_deprecated(
inputs: List[torch.Tensor],
pcc: Optional[float] = None,
dev_data_format: forge.DataFormat = None,
convert_to_forge: bool = True, # explicit conversion to forge data format
):

verify_module_for_inputs_deprecated(
model=model,
inputs=inputs,
pcc=pcc,
dev_data_format=dev_data_format,
convert_to_forge=convert_to_forge,
)

@classmethod
@@ -236,13 +241,15 @@ def verify_module_for_inputs(
inputs: List[torch.Tensor],
verify_config: Optional[VerifyConfig] = VerifyConfig(),
dev_data_format: forge.DataFormat = None,
convert_to_forge: bool = True, # explicit conversion to forge data format
):

verify_module_for_inputs(
model=model,
inputs=inputs,
verify_config=verify_config,
dev_data_format=dev_data_format,
convert_to_forge=convert_to_forge,
)

