Skip to content

Commit

Permalink
feat: Add assertions to support MetricWriter (#129)
Browse files Browse the repository at this point in the history
This PR introduces several of the assertion methods used when
constructing the `MetricWriter`. (See #107 for how they are used in
practice).
  • Loading branch information
msto authored Oct 17, 2024
1 parent 78f2246 commit d25ff12
Show file tree
Hide file tree
Showing 2 changed files with 262 additions and 0 deletions.
106 changes: 106 additions & 0 deletions fgpyo/util/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@
from typing import Generic
from typing import Iterator
from typing import List
from typing import Type
from typing import TypeVar

if sys.version_info >= (3, 10):
Expand Down Expand Up @@ -406,6 +407,16 @@ def _read_header(
return MetricFileHeader(preamble=preamble, fieldnames=fieldnames)


def _is_metric_class(cls: Any) -> TypeGuard[Type[Metric]]:
    """
    Test if the given object is a Metric class.

    Args:
        cls: Any object, typically a class object.

    Returns:
        True if `cls` is a class, is a subclass of `Metric`, and is decorated as either a
        `dataclass` or an `attr` class. False otherwise.
    """
    # NB: `isclass()` must be checked before `issubclass()`, because `issubclass()` raises a
    # `TypeError` when its first argument is not a class.
    if not isclass(cls) or not issubclass(cls, Metric):
        return False

    return dataclasses.is_dataclass(cls) or attr.has(cls)


def _is_dataclass_instance(metric: Metric) -> TypeGuard[DataclassInstance]:
"""
Test if the given metric is a dataclass instance.
Expand Down Expand Up @@ -466,3 +477,98 @@ def asdict(metric: Metric) -> Dict[str, Any]:
"The provided metric is not an instance of a `dataclass` or `attr.s`-decorated Metric "
f"class: {metric.__class__}"
)


def _get_fieldnames(metric_class: Type[Metric]) -> List[str]:
    """
    Get the fieldnames of the specified metric class.

    Args:
        metric_class: A Metric class.

    Returns:
        A list of fieldnames.

    Raises:
        TypeError: If the given class is not a dataclass or attr decorated Metric.
    """
    _assert_is_metric_class(metric_class)

    if dataclasses.is_dataclass(metric_class):
        return [f.name for f in dataclasses.fields(metric_class)]

    if attr.has(metric_class):
        return [f.name for f in attr.fields(metric_class)]

    # NB: raise explicitly instead of using `assert False`. Assert statements are removed when
    # Python runs with `-O`, which would let this function silently return `None`.
    raise AssertionError("Unreachable")


def _assert_file_header_matches_metric(
    path: Path,
    metric_class: Type[MetricType],
    delimiter: str,
) -> None:
    """
    Check that the specified file has a header and its fields match those of the provided Metric.

    The header fieldnames must be equal to the Metric's fieldnames, including their order.

    Args:
        path: A path to a `Metric` file.
        metric_class: The `Metric` class to validate against.
        delimiter: The delimiter to use when reading the header.

    Raises:
        ValueError: If the provided file does not include a header.
        ValueError: If the header of the provided file does not match the provided Metric.
    """
    # NB: _get_fieldnames() will validate that `metric_class` is a valid Metric class.
    fieldnames: List[str] = _get_fieldnames(metric_class)

    header: MetricFileHeader
    with path.open("r") as fin:
        try:
            header = metric_class._read_header(fin, delimiter=delimiter)
        except ValueError as err:
            # Chain explicitly so the underlying parsing error is kept as the cause.
            raise ValueError(f"Could not find a header in the provided file: {path}") from err

    if header.fieldnames != fieldnames:
        raise ValueError(
            "The provided file does not have the same field names as the provided Metric:\n"
            f"\tMetric: {metric_class.__name__}\n"
            f"\tFile: {path}\n"
            f"\tExpected fields: {', '.join(fieldnames)}\n"
            f"\tActual fields: {', '.join(header.fieldnames)}\n"
        )


def _assert_fieldnames_are_metric_attributes(
    specified_fieldnames: List[str],
    metric_class: Type[MetricType],
) -> None:
    """
    Check that all of the specified fields are attributes on the given Metric.

    The specified fieldnames may be in any order and may be a subset of the Metric's attributes.

    Args:
        specified_fieldnames: The fieldnames to validate.
        metric_class: The `Metric` class to validate against.

    Raises:
        TypeError: If the given class is not a dataclass or attr decorated Metric.
        ValueError: if any of the specified fieldnames are not an attribute on the given Metric.
    """
    # NB: _get_fieldnames() will validate that `metric_class` is a valid Metric class. It is
    # called once, outside the comprehension, so the lookup is not repeated per specified field.
    metric_fieldnames = set(_get_fieldnames(metric_class))

    # NB: `dict.fromkeys()` removes duplicates while preserving the order in which the invalid
    # fieldnames were specified, so the error message is deterministic.
    invalid_fieldnames = list(
        dict.fromkeys(f for f in specified_fieldnames if f not in metric_fieldnames)
    )

    if invalid_fieldnames:
        raise ValueError(
            "One or more of the specified fields are not attributes on the Metric "
            + f"{metric_class.__name__}: "
            + ", ".join(invalid_fieldnames)
        )


def _assert_is_metric_class(cls: Type[Metric]) -> None:
    """
    Validate that the given class is a Metric class.

    Args:
        cls: A class object.

    Raises:
        TypeError: If the given class is not a Metric.
    """
    if _is_metric_class(cls):
        return

    raise TypeError(f"Not a dataclass or attr decorated Metric: {cls}")
156 changes: 156 additions & 0 deletions fgpyo/util/tests/test_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import enum
import gzip
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from typing import Callable
Expand All @@ -29,6 +30,10 @@
from fgpyo.util.inspect import is_attr_class
from fgpyo.util.inspect import is_dataclasses_class
from fgpyo.util.metric import Metric
from fgpyo.util.metric import _assert_fieldnames_are_metric_attributes
from fgpyo.util.metric import _assert_file_header_matches_metric
from fgpyo.util.metric import _assert_is_metric_class
from fgpyo.util.metric import _get_fieldnames
from fgpyo.util.metric import _is_attrs_instance
from fgpyo.util.metric import _is_dataclass_instance
from fgpyo.util.metric import asdict
Expand Down Expand Up @@ -590,3 +595,154 @@ def test_read_header_can_read_picard(tmp_path: Path) -> None:
header = Metric._read_header(metrics_file, comment_prefix="#")

assert header.fieldnames == ["SAMPLE", "FOO", "BAR"]


@pytest.mark.parametrize("data_and_classes", (attr_data_and_classes, dataclasses_data_and_classes))
def test_get_fieldnames(data_and_classes: DataBuilder) -> None:
    """Test that the expected fieldnames are returned for the `Person` metric."""
    fieldnames = _get_fieldnames(data_and_classes.Person)

    assert fieldnames == ["name", "age"]


def test_fieldnames_raises_if_not_a_metric() -> None:
    """Test that `_get_fieldnames()` raises a TypeError when given a non-Metric dataclass."""

    @dataclass
    class BadMetric:
        foo: str
        bar: int

    expected_message = "Not a dataclass or attr decorated Metric"

    with pytest.raises(TypeError, match=expected_message):
        _get_fieldnames(BadMetric)  # type: ignore[arg-type]


@pytest.mark.parametrize("data_and_classes", (attr_data_and_classes, dataclasses_data_and_classes))
def test_assert_is_metric_class(data_and_classes: DataBuilder) -> None:
    """
    Should not raise an error when the provided class is a valid Metric class.
    """
    failure_message = "Failed to validate a valid Metric"

    try:
        _assert_is_metric_class(data_and_classes.DummyMetric)
    except TypeError:
        # Convert to an AssertionError so an unexpected rejection is reported as a test failure.
        raise AssertionError(failure_message) from None


def test_assert_is_metric_class_raises_if_not_decorated() -> None:
    """
    Should raise a TypeError when the provided type subclasses Metric but is not decorated as a
    dataclass or attr class.
    """

    class BadMetric(Metric["BadMetric"]):
        foo: str
        bar: int

    expected_message = "Not a dataclass or attr decorated Metric"

    with pytest.raises(TypeError, match=expected_message):
        _assert_is_metric_class(BadMetric)


def test_assert_is_metric_class_raises_if_not_a_metric() -> None:
    """
    Test that we raise an error if the provided type is decorated as a
    dataclass or attr but does not subclass Metric.
    """

    # NB: the two classes have distinct names so the second definition does not shadow the first
    # within this test's scope.
    @dataclass
    class BadDataclassMetric:
        foo: str
        bar: int

    with pytest.raises(TypeError, match="Not a dataclass or attr decorated Metric"):
        _assert_is_metric_class(BadDataclassMetric)

    @attr.s
    class BadAttrMetric:
        foo: str
        bar: int

    with pytest.raises(TypeError, match="Not a dataclass or attr decorated Metric"):
        _assert_is_metric_class(BadAttrMetric)


# fmt: off
@pytest.mark.parametrize("data_and_classes", (attr_data_and_classes, dataclasses_data_and_classes))
@pytest.mark.parametrize(
    "fieldnames",
    [
        ["name", "age"],  # The fieldnames are all the attributes of the provided metric
        ["age", "name"],  # The fieldnames are out of order
        ["name"],         # The fieldnames are a subset of the attributes of the provided metric
    ],
)
# fmt: on
def test_assert_fieldnames_are_metric_attributes(
    data_and_classes: DataBuilder,
    fieldnames: List[str],
) -> None:
    """
    Should accept fieldnames that are all attributes of the provided metric, regardless of their
    order or whether they cover every attribute.
    """
    failure_message = "Fieldnames should be valid"

    try:
        _assert_fieldnames_are_metric_attributes(fieldnames, data_and_classes.Person)
    except Exception:
        # Any exception here means valid fieldnames were rejected; report it as a test failure.
        raise AssertionError(failure_message) from None


@pytest.mark.parametrize("data_and_classes", (attr_data_and_classes, dataclasses_data_and_classes))
@pytest.mark.parametrize(
    "fieldnames",
    [
        ["name", "age", "foo"],
        ["name", "foo"],
        ["foo", "name", "age"],
        ["foo"],
    ],
)
def test_assert_fieldnames_are_metric_attributes_raises(
    data_and_classes: DataBuilder,
    fieldnames: List[str],
) -> None:
    """
    Should raise a ValueError when at least one provided fieldname is not an attribute on the
    metric.
    """
    expected_message = "One or more of the specified fields are not "

    with pytest.raises(ValueError, match=expected_message):
        _assert_fieldnames_are_metric_attributes(fieldnames, data_and_classes.Person)


@pytest.mark.parametrize("data_and_classes", (attr_data_and_classes, dataclasses_data_and_classes))
def test_assert_file_header_matches_metric(tmp_path: Path, data_and_classes: DataBuilder) -> None:
    """
    Should accept a file whose header matches the fieldnames of the provided metric.
    """
    metric_path = tmp_path / "metrics.tsv"
    metric_path.write_text("name\tage\n")

    try:
        _assert_file_header_matches_metric(metric_path, data_and_classes.Person, delimiter="\t")
    except Exception:
        # Any exception here means a matching header was rejected; report it as a test failure.
        raise AssertionError("File header should be valid") from None


@pytest.mark.parametrize("data_and_classes", (attr_data_and_classes, dataclasses_data_and_classes))
@pytest.mark.parametrize(
    "header",
    [
        ["name"],
        ["age"],
        ["name", "age", "foo"],
        ["foo", "name", "age"],
    ],
)
def test_assert_file_header_matches_metric_raises(
    tmp_path: Path, data_and_classes: DataBuilder, header: List[str]
) -> None:
    """
    Should raise a ValueError when the file header differs from the fieldnames of the provided
    metric.
    """
    metric_path = tmp_path / "metrics.tsv"
    header_line = "\t".join(header) + "\n"
    metric_path.write_text(header_line)

    expected_message = "The provided file does not have the same field names"

    with pytest.raises(ValueError, match=expected_message):
        _assert_file_header_matches_metric(metric_path, data_and_classes.Person, delimiter="\t")

0 comments on commit d25ff12

Please sign in to comment.