From a1df5047e68309079840535eb961355d6ea07bd4 Mon Sep 17 00:00:00 2001 From: Ted Brookings Date: Tue, 21 Nov 2023 12:03:51 -0500 Subject: [PATCH] Allow Metric to use dataclasses or attr * Update util.metric and related util.inspect modules to work with dataclasses or attr * Update test_metric to test both dataclasses and attr classes Closes #45 --- docs/api.rst | 3 + fgpyo/read_structure.py | 8 +- fgpyo/sam/__init__.py | 35 ++- fgpyo/util/inspect.py | 111 +++++++- fgpyo/util/metric.py | 30 ++- fgpyo/util/tests/test_metric.py | 436 +++++++++++++++++++++----------- fgpyo/util/types.py | 1 + 7 files changed, 426 insertions(+), 198 deletions(-) diff --git a/docs/api.rst b/docs/api.rst index 63da3ba7..641fa627 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -34,6 +34,9 @@ Metric files .. seealso:: + https://docs.python.org/3/library/dataclasses.html + Documentation for the dataclasses standard module + https://www.attrs.org/en/stable/examples.html The attrs website for bringing back the joy to writing classes. diff --git a/fgpyo/read_structure.py b/fgpyo/read_structure.py index 91d7622f..cd3f0360 100644 --- a/fgpyo/read_structure.py +++ b/fgpyo/read_structure.py @@ -87,7 +87,7 @@ def __str__(self) -> str: return self.value -@attr.s(frozen=True, auto_attribs=True, kw_only=True) +@attr.s(frozen=True, kw_only=True, auto_attribs=True) class SubReadWithoutQuals: """Contains the bases that correspond to the given read segment""" @@ -99,7 +99,7 @@ def kind(self) -> SegmentType: return self.segment.kind -@attr.s(frozen=True, auto_attribs=True, kw_only=True) +@attr.s(frozen=True, kw_only=True, auto_attribs=True) class SubReadWithQuals: """Contains the bases and qualities that correspond to the given read segment""" @@ -112,7 +112,7 @@ def kind(self) -> SegmentType: return self.segment.kind -@attr.s(frozen=True, auto_attribs=True, kw_only=True) +@attr.s(frozen=True, kw_only=True, auto_attribs=True) class ReadSegment: """Encapsulates all the information about a segment within a read structure. A segment can either have a definite length, in which case length must be Some(Int), or an indefinite length @@ -187,7 +187,7 @@ def __str__(self) -> str: return f"{ANY_LENGTH_CHAR}{self.kind.value}" -@attr.s(frozen=True, auto_attribs=True, kw_only=True) +@attr.s(frozen=True, kw_only=True, auto_attribs=True) class ReadStructure(Iterable[ReadSegment]): """Describes the structure of a give read. A read contains one or more read segments. A read segment describes a contiguous stretch of bases of the same type (ex. 
template bases) of some
diff --git a/fgpyo/sam/__init__.py b/fgpyo/sam/__init__.py
index caceadac..8de26f2e 100644
--- a/fgpyo/sam/__init__.py
+++ b/fgpyo/sam/__init__.py
@@ -385,7 +385,7 @@ def is_clipping(self) -> bool:
         return self == CigarOp.S or self == CigarOp.H
 
 
-@attr.s(frozen=True, slots=True)
+@attr.s(frozen=True, slots=True, auto_attribs=True)
 class CigarElement:
     """Represents an element in a Cigar
 
@@ -394,14 +394,13 @@ class CigarElement:
     - operator (CigarOp): the operator of the element
     """
 
-    length: int = attr.ib()
-    operator: CigarOp = attr.ib()
+    length: int
+    operator: CigarOp
 
-    @length.validator
-    def _validate_length(self, attribute: Any, value: int) -> None:
+    def __attrs_post_init__(self) -> None:
         """Validates the length attribute is greater than zero."""
-        if value <= 0:
-            raise ValueError(f"Cigar element must have a length > 0, found {value}")
+        if self.length <= 0:
+            raise ValueError(f"Cigar element must have a length > 0, found {self.length}")
 
     @property
     def length_on_query(self) -> int:
@@ -423,7 +422,7 @@ class CigarParsingException(Exception):
     pass
 
 
-@attr.s(frozen=True, slots=True)
+@attr.s(frozen=True, slots=True, auto_attribs=True)
 class Cigar:
     """Class representing a cigar string.
 
@@ -431,7 +430,7 @@ class Cigar:
     - elements (Tuple[CigarElement, ...]): zero or more cigar elements
     """
 
-    elements: Tuple[CigarElement, ...] = attr.ib(default=())
+    elements: Tuple[CigarElement, ...] = ()
 
     @classmethod
     def from_cigartuples(cls, cigartuples: Optional[List[Tuple[int, int]]]) -> "Cigar":
@@ -518,7 +517,7 @@ def length_on_target(self) -> int:
         return sum([elem.length_on_target for elem in self.elements])
 
 
-@attr.s(auto_attribs=True, frozen=True)
+@attr.s(frozen=True, auto_attribs=True)
 class SupplementaryAlignment:
     """Stores a supplementary alignment record produced by BWA and stored in the SA SAM tag.
 
@@ -531,12 +530,12 @@ class SupplementaryAlignment:
         nm: the number of edits
     """
 
-    reference_name: str = attr.ib()
-    start: int = attr.ib()
-    is_forward: bool = attr.ib()
-    cigar: Cigar = attr.ib()
-    mapq: int = attr.ib()
-    nm: int = attr.ib()
+    reference_name: str
+    start: int
+    is_forward: bool
+    cigar: Cigar
+    mapq: int
+    nm: int
 
     def __str__(self) -> str:
         return ",".join(
@@ -620,7 +619,7 @@ def set_pair_info(r1: AlignedSegment, r2: AlignedSegment, proper_pair: bool = Tr
     r2.template_length = -insert_size
 
 
-@attr.s(auto_attribs=True, frozen=True)
+@attr.s(frozen=True, auto_attribs=True)
 class ReadEditInfo:
     """
     Counts various stats about how a read compares to a reference sequence.
@@ -709,7 +708,7 @@ def calculate_edit_info(
     )
 
 
-@attr.s(auto_attribs=True, frozen=True)
+@attr.s(frozen=True, auto_attribs=True)
 class Template:
     """A container for alignment records corresponding to a single sequenced template or insert.
diff --git a/fgpyo/util/inspect.py b/fgpyo/util/inspect.py
index 5ab8a4e8..f1d3f440 100644
--- a/fgpyo/util/inspect.py
+++ b/fgpyo/util/inspect.py
@@ -9,24 +9,85 @@
 if sys.version_info >= (3, 8):
     from typing import Literal
+    from typing import Protocol
 else:
     from typing_extensions import Literal
+    from typing_extensions import Protocol
 if sys.version_info >= (3, 12):
     from typing import TypeAlias
 else:
     from typing_extensions import TypeAlias
 
+import dataclasses
 import functools
+from dataclasses import MISSING as DATACLASSES_MISSING
+from dataclasses import fields as get_dataclasses_fields
+from dataclasses import is_dataclass as is_dataclasses_class
 from enum import Enum
 from functools import partial
 from pathlib import PurePath
+from typing import TYPE_CHECKING
 from typing import Callable
 from typing import Optional
-
-import attr
+from typing import TypeVar
 
 import fgpyo.util.types as types
 
+try:
+    import attr
+
+    _use_attr = True
+    from attr import NOTHING as ATTR_NOTHING
+    from attr import fields as get_attr_fields
+    from attr import fields_dict as get_attr_fields_dict
+
+    Attribute = attr.Attribute
+
+    MISSING = {DATACLASSES_MISSING, ATTR_NOTHING}
+except ImportError:
+    _use_attr = False
+    attr = None
+    ATTR_NOTHING = None
+    Attribute = TypeVar("Attribute", bound=object)  # type: ignore
+
+    def get_attr_fields(cls: type) -> Tuple[dataclasses.Field, ...]:  # type: ignore
+        return ()
+
+    def get_attr_fields_dict(cls: type) -> Dict[str, dataclasses.Field]:  # type: ignore
+        return {}
+
+    MISSING = {DATACLASSES_MISSING}
+
+if TYPE_CHECKING:
+    from _typeshed import DataclassInstance as DataclassesProtocol
+else:
+
+    class DataclassesProtocol(Protocol):
+        __dataclass_fields__: Dict[str, dataclasses.Field]
+
+
+if TYPE_CHECKING and _use_attr:
+    from attr import AttrsInstance
+else:
+
+    class AttrsInstance(Protocol):  # type: ignore
+        __attrs_attrs__: Dict[str, Any]
+
+
+def is_attr_class(cls: type) -> bool:  # type: ignore
+    return hasattr(cls, "__attrs_attrs__")
+
+
+MISSING_OR_NONE = {*MISSING, None}
+DataclassesOrAttrClass = Union[DataclassesProtocol, AttrsInstance]
+FieldType: TypeAlias = Union[dataclasses.Field, Attribute]
+
+
+def get_dataclasses_fields_dict(
+    class_or_instance: Union[DataclassesProtocol, Type[DataclassesProtocol]],
+) -> Dict[str, dataclasses.Field]:
+    return {field.name: field for field in get_dataclasses_fields(class_or_instance)}
+
 
 class ParserNotFoundException(Exception):
     pass
@@ -254,9 +315,36 @@ def dict_parse(dict_string: str) -> Dict[Any, Any]:
     return parser
 
 
+def get_fields_dict(cls: Type[DataclassesOrAttrClass]) -> Dict[str, FieldType]:
+    """
+    Get the fields dict from either a dataclasses or attr dataclass.
+
+    Combine results in case someone chooses to mix them through inheritance.
+ """ + if not (is_dataclasses_class(cls) or is_attr_class(cls)): + raise ValueError("cls must a dataclasses or attr class") + return { + **(get_dataclasses_fields_dict(cls) if is_dataclasses_class(cls) else {}), + **(get_attr_fields_dict(cls) if is_attr_class(cls) else {}), # type: ignore + } + + +def get_fields(cls: Type[DataclassesOrAttrClass]) -> Tuple[FieldType, ...]: + if not (is_dataclasses_class(cls) or is_attr_class(cls)): + raise ValueError("cls must a dataclasses or attr class") + return (get_dataclasses_fields(cls) if is_dataclasses_class(cls) else ()) + ( + get_attr_fields(cls) if is_attr_class(cls) else () # type: ignore + ) + + +AttrFromType = TypeVar("AttrFromType") + + def attr_from( - cls: Type, kwargs: Dict[str, str], parsers: Optional[Dict[type, Callable[[str], Any]]] = None -) -> Any: + cls: Type[AttrFromType], + kwargs: Dict[str, str], + parsers: Optional[Dict[type, Callable[[str], Any]]] = None, +) -> AttrFromType: """Builds an attr class from key-word arguments Args: @@ -265,15 +353,16 @@ def attr_from( parsers: a dictionary of parser functions to apply to specific types """ return_values: Dict[str, Any] = {} - for attribute in attr.fields(cls): + for attribute in get_fields(cls): # type: ignore return_value: Any if attribute.name in kwargs: str_value: str = kwargs[attribute.name] set_value: bool = False # Use the converter if provided - if attribute.converter is not None: - return_value = attribute.converter(str_value) + converter = getattr(attribute, "converter", None) + if converter is not None: + return_value = converter(str_value) set_value = True # try getting a known parser @@ -305,7 +394,7 @@ def attr_from( ), f"No value given and no default for attribute `{attribute.name}`" return_value = attribute.default # when the default is attr.NOTHING, just use None - if return_value is attr.NOTHING: + if return_value in MISSING: return_value = None return_values[attribute.name] = return_value @@ -313,13 +402,13 @@ def attr_from( return cls(**return_values) -def attribute_is_optional(attribute: attr.Attribute) -> bool: +def attribute_is_optional(attribute: FieldType) -> bool: """Returns True if the attribute is optional, False otherwise""" return types.get_origin_type(attribute.type) is Union and isinstance( None, types.get_arg_types(attribute.type) ) -def attribute_has_default(attribute: attr.Attribute) -> bool: +def attribute_has_default(attribute: FieldType) -> bool: """Returns True if the attribute has a default value, False otherwise""" - return attribute.default != attr.NOTHING or attribute_is_optional(attribute) + return attribute.default not in MISSING_OR_NONE or attribute_is_optional(attribute) diff --git a/fgpyo/util/metric.py b/fgpyo/util/metric.py index 9cda23ee..5668607e 100644 --- a/fgpyo/util/metric.py +++ b/fgpyo/util/metric.py @@ -10,14 +10,25 @@ The :class:`~fgpyo.util.metric.Metric` class makes it easy to read, write, and store one or metrics of the same type, all the while preserving types for each value in a metric. It is an abstract -base class decorated by `attr `_, with attributes -storing one or more typed values. +base class decorated by `dataclassees `_, or +`attr `_, with attributes storing one or more typed +values. Examples ~~~~~~~~ Defining a new metric class: +.. code-block:: python + + >>> from fgpyo.util.metric import Metric + >>> import dataclasses + >>> @dataclasses.dataclass(frozen=True) + ... class Person(Metric["Person"]): + ... name: str + ... age: int + +or using attr: .. 
code-block:: python >>> from fgpyo.util.metric import Metric @@ -75,7 +86,7 @@ .. code-block:: python - >>> @attr.s(auto_attribs=True, frozen=True) + >>> @dataclasses.dataclass(frozen=True) ... class Name: ... first: str ... last: str @@ -83,7 +94,7 @@ ... def parse(cls, value: str) -> "Name": ... fields = value.split(" ") ... return Name(first=fields[0], last=fields[1]) - >>> @attr.s(auto_attribs=True, frozen=True) + >>> @dataclasses.dataclass(frozen=True) ... class Person(Metric["Person"]): ... name: Name ... age: int @@ -101,7 +112,6 @@ ["first last", "42"] """ - from abc import ABC from enum import Enum from pathlib import Path @@ -113,15 +123,12 @@ from typing import List from typing import TypeVar -import attr - from fgpyo import io from fgpyo.util import inspect MetricType = TypeVar("MetricType", bound="Metric") -@attr.s class Metric(ABC, Generic[MetricType]): """Abstract base class for all metric-like tab-delimited files @@ -135,7 +142,8 @@ class Metric(ABC, Generic[MetricType]): def values(self) -> Iterator[Any]: """An iterator over attribute values in the same order as the header.""" - return iter(attr.astuple(self, recurse=False)) + for field in inspect.get_fields(self.__class__): # type: ignore + yield getattr(self, field.name) def formatted_values(self) -> List[str]: """An iterator over formatted attribute values in the same order as the header.""" @@ -170,7 +178,7 @@ def read(cls, path: Path, ignore_extra_fields: bool = True) -> Iterator[Any]: missing_from_class = file_fields.difference(class_fields) missing_from_file = class_fields.difference(file_fields) - field_name_to_attribute = attr.fields_dict(cls) + field_name_to_attribute = inspect.get_fields_dict(cls) # type: ignore # ignore class fields that are missing from the file (via header) if they're optional # or have a default @@ -247,7 +255,7 @@ def write(cls, path: Path, *values: MetricType) -> None: @classmethod def header(cls) -> List[str]: """The list of header values for the metric.""" - return [a.name for a in attr.fields(cls)] + return [a.name for a in inspect.get_fields(cls)] # type: ignore @classmethod def format_value(cls, value: Any) -> str: diff --git a/fgpyo/util/tests/test_metric.py b/fgpyo/util/tests/test_metric.py index d84ed4b1..f4bdc696 100644 --- a/fgpyo/util/tests/test_metric.py +++ b/fgpyo/util/tests/test_metric.py @@ -1,5 +1,9 @@ +# attr and dataclasses are both nightmares for type-checking, and trying to combine them both in +# an if-statement is a level of Hell that Dante never conceived of. 
Turning off mypy for this file: +# mypy: ignore-errors import enum import gzip +import sys from pathlib import Path from typing import Any from typing import Callable @@ -8,10 +12,22 @@ from typing import Optional from typing import Set from typing import Tuple +from typing import Type +from typing import TypeVar +from typing import Union + +if sys.version_info >= (3, 12): + from typing import TypeAlias +else: + from typing_extensions import TypeAlias + +import dataclasses import attr import pytest +from fgpyo.util.inspect import is_attr_class +from fgpyo.util.inspect import is_dataclasses_class from fgpyo.util.metric import Metric @@ -21,128 +37,208 @@ class EnumTest(enum.Enum): EnumVal3 = "val3" -@attr.s(auto_attribs=True, frozen=True) -class DummyMetric(Metric["DummyMetric"]): - int_value: int - str_value: str - bool_val: bool - enum_val: EnumTest = attr.ib() - optional_str_value: Optional[str] = attr.ib() - optional_int_value: Optional[int] = attr.ib() - optional_bool_value: Optional[bool] = attr.ib() - optional_enum_value: Optional[EnumTest] = attr.ib() - dict_value: Dict[int, str] = attr.ib() - tuple_value: Tuple[int, str] = attr.ib() - list_value: List[str] = attr.ib() - complex_value: Dict[ - int, - Dict[ - Tuple[int, int], - Set[str], - ], - ] = attr.ib() - - -DUMMY_METRICS: List[DummyMetric] = [ - DummyMetric( - int_value=1, - str_value="2", - bool_val=True, - enum_val=EnumTest.EnumVal1, - optional_str_value="test4", - optional_int_value=-5, - optional_bool_value=True, - optional_enum_value=EnumTest.EnumVal3, - dict_value={ - 1: "test1", - }, - tuple_value=(0, "test1"), - list_value=[], - complex_value={1: {(5, 1): set({"mapped_test_val1", "setval2"})}}, - ), - DummyMetric( - int_value=1, - str_value="2", - bool_val=False, - enum_val=EnumTest.EnumVal2, - optional_str_value="test", - optional_int_value=1, - optional_bool_value=False, - optional_enum_value=EnumTest.EnumVal1, - dict_value={2: "test2", 7: "test4"}, - tuple_value=(1, "test2"), - list_value=["1"], - complex_value={2: {(-5, 1): set({"mapped_test_val2", "setval2"})}}, - ), - DummyMetric( - int_value=1, - str_value="2", - bool_val=False, - enum_val=EnumTest.EnumVal3, - optional_str_value=None, - optional_int_value=None, - optional_bool_value=None, - optional_enum_value=None, - dict_value={}, - tuple_value=(2, "test3"), - list_value=["1", "2", "3"], - complex_value={3: {(8, 1): set({"mapped_test_val3", "setval2"})}}, - ), -] - - -@attr.s(auto_attribs=True, frozen=True) -class Person(Metric["Person"]): - name: Optional[str] - age: Optional[int] - - -@attr.s(auto_attribs=True, frozen=True) -class Name: - first: str - last: str - - @classmethod - def parse(cls, value: str) -> "Name": - fields = value.split(" ") - return Name(first=fields[0], last=fields[1]) - - -@attr.s(auto_attribs=True, frozen=True) -class NamedPerson(Metric["NamedPerson"]): - name: Name - age: int - - @classmethod - def _parsers(cls) -> Dict[type, Callable[[str], Any]]: - return {Name: lambda value: Name.parse(value=value)} - - @classmethod - def format_value(cls, value: Any) -> str: - if isinstance(value, (Name)): - return f"{value.first} {value.last}" - else: - return super().format_value(value=value) - - -@attr.s(auto_attribs=True, frozen=True) -class PersonMaybeAge(Metric["PersonMaybeAge"]): - name: str - age: Optional[int] - - -@attr.s(auto_attribs=True, frozen=True) -class PersonDefault(Metric["PersonDefault"]): - name: str - age: int = 0 - - -@pytest.mark.parametrize("metric", DUMMY_METRICS) +T = TypeVar("T", bound=Type) + + +def 
make_dataclass(use_attr: bool = False) -> Callable[[T], T]: + """Decorator to make a attr- or dataclasses-style dataclass""" + sys.stderr.write(f"use_attr = {use_attr}\n") + if use_attr: + + def make_attr(cls: T) -> T: + return attr.s(auto_attribs=True, frozen=True)(cls) + + return make_attr + else: + + def make_dataclasses(cls: T) -> T: + return dataclasses.dataclass(frozen=True)(cls) + + return make_dataclasses + + +class DataBuilder: + """Holds classes and data for testing, either usting attr- or dataclasses-style dataclass""" + + def __init__(self, use_attr: bool) -> None: + self.use_attr = use_attr + + @make_dataclass(use_attr=use_attr) + class DummyMetric(Metric["DummyMetric"]): + int_value: int + str_value: str + bool_val: bool + enum_val: EnumTest + optional_str_value: Optional[str] + optional_int_value: Optional[int] + optional_bool_value: Optional[bool] + optional_enum_value: Optional[EnumTest] + dict_value: Dict[int, str] + tuple_value: Tuple[int, str] + list_value: List[str] + complex_value: Dict[ + int, + Dict[ + Tuple[int, int], + Set[str], + ], + ] + + @make_dataclass(use_attr=use_attr) + class Person(Metric["Person"]): + name: Optional[str] + age: Optional[int] + + @make_dataclass(use_attr=use_attr) + class Name: + first: str + last: str + + @classmethod + def parse(cls, value: str) -> "Name": + fields = value.split(" ") + return Name(first=fields[0], last=fields[1]) + + @make_dataclass(use_attr=use_attr) + class NamedPerson(Metric["NamedPerson"]): + name: Name + age: int + + @classmethod + def _parsers(cls) -> Dict[type, Callable[[str], Any]]: + return {Name: lambda value: Name.parse(value=value)} + + @classmethod + def format_value(cls, value: Any) -> str: + if isinstance(value, (Name)): + return f"{value.first} {value.last}" + else: + return super().format_value(value=value) + + @make_dataclass(use_attr=use_attr) + class PersonMaybeAge(Metric["PersonMaybeAge"]): + name: str + age: Optional[int] + + @make_dataclass(use_attr=use_attr) + class PersonDefault(Metric["PersonDefault"]): + name: str + age: int = 0 + + @make_dataclass(use_attr=use_attr) + class ListPerson(Metric["ListPerson"]): + name: List[Optional[str]] + age: List[Optional[int]] + + self.DummyMetric = DummyMetric + self.Person = Person + self.Name = Name + self.NamedPerson = NamedPerson + self.PersonMaybeAge = PersonMaybeAge + self.PersonDefault = PersonDefault + self.ListPerson = ListPerson + + self.DUMMY_METRICS: List[DummyMetric] = [ + DummyMetric( + int_value=1, + str_value="2", + bool_val=True, + enum_val=EnumTest.EnumVal1, + optional_str_value="test4", + optional_int_value=-5, + optional_bool_value=True, + optional_enum_value=EnumTest.EnumVal3, + dict_value={ + 1: "test1", + }, + tuple_value=(0, "test1"), + list_value=[], + complex_value={1: {(5, 1): {"mapped_test_val1", "setval2"}}}, + ), + DummyMetric( + int_value=1, + str_value="2", + bool_val=False, + enum_val=EnumTest.EnumVal2, + optional_str_value="test", + optional_int_value=1, + optional_bool_value=False, + optional_enum_value=EnumTest.EnumVal1, + dict_value={2: "test2", 7: "test4"}, + tuple_value=(1, "test2"), + list_value=["1"], + complex_value={2: {(-5, 1): {"mapped_test_val2", "setval2"}}}, + ), + DummyMetric( + int_value=1, + str_value="2", + bool_val=False, + enum_val=EnumTest.EnumVal3, + optional_str_value=None, + optional_int_value=None, + optional_bool_value=None, + optional_enum_value=None, + dict_value={}, + tuple_value=(2, "test3"), + list_value=["1", "2", "3"], + complex_value={3: {(8, 1): {"mapped_test_val3", "setval2"}}}, + ), + ] 
+ + +attr_data_and_classes = DataBuilder(use_attr=True) +dataclasses_data_and_classes = DataBuilder(use_attr=False) + +AnyDummyMetric = Union[attr_data_and_classes.DummyMetric, dataclasses_data_and_classes.DummyMetric] +num_metrics = len(attr_data_and_classes.DUMMY_METRICS) + + +@pytest.mark.parametrize("use_attr", [False, True]) +def test_is_correct_dataclass_type(use_attr: bool) -> None: + data_and_classes = DataBuilder(use_attr=use_attr) + assert use_attr == data_and_classes.use_attr + assert is_attr_class(data_and_classes.DummyMetric) is use_attr + assert is_dataclasses_class(data_and_classes.DummyMetric) is not use_attr + assert is_attr_class(data_and_classes.Person) is use_attr + assert is_dataclasses_class(data_and_classes.Person) is not use_attr + assert is_attr_class(data_and_classes.Name) is use_attr + assert is_dataclasses_class(data_and_classes.Name) is not use_attr + assert is_attr_class(data_and_classes.NamedPerson) is use_attr + assert is_dataclasses_class(data_and_classes.NamedPerson) is not use_attr + assert is_attr_class(data_and_classes.PersonMaybeAge) is use_attr + assert is_dataclasses_class(data_and_classes.PersonMaybeAge) is not use_attr + assert is_attr_class(data_and_classes.PersonDefault) is use_attr + assert is_dataclasses_class(data_and_classes.PersonDefault) is not use_attr + assert len(data_and_classes.DUMMY_METRICS) == num_metrics + + +def pytest_generate_tests(metafunc: Any) -> None: + if "DummyMetric" in metafunc.fixturenames: + metafunc.parametrize( + "DummyMetric", + [attr_data_and_classes.DummyMetric, dataclasses_data_and_classes.DummyMetric], + ) + if "DUMMY_METRICS" in metafunc.fixturenames: + metafunc.parametrize( + "DUMMY_METRICS", + [attr_data_and_classes.DUMMY_METRICS, dataclasses_data_and_classes.DUMMY_METRICS], + ) + + +@pytest.fixture(scope="function") +def metric(request, data_and_classes: DataBuilder) -> AnyDummyMetric: + yield data_and_classes.DUMMY_METRICS[request.param] + + +@pytest.mark.parametrize("data_and_classes", (attr_data_and_classes, dataclasses_data_and_classes)) +@pytest.mark.parametrize("metric", range(num_metrics), indirect=True) def test_metric_roundtrip( tmp_path: Path, - metric: DummyMetric, + data_and_classes: DataBuilder, + metric: AnyDummyMetric, ) -> None: path: Path = tmp_path / "metrics.txt" + DummyMetric: TypeAlias = data_and_classes.DummyMetric DummyMetric.write(path, metric) metrics: List[DummyMetric] = list(DummyMetric.read(path=path)) @@ -151,31 +247,37 @@ def test_metric_roundtrip( assert metrics[0] == metric -def test_metrics_roundtrip(tmp_path: Path) -> None: +@pytest.mark.parametrize("data_and_classes", (attr_data_and_classes, dataclasses_data_and_classes)) +def test_metrics_roundtrip(tmp_path: Path, data_and_classes: DataBuilder) -> None: path: Path = tmp_path / "metrics.txt" + DummyMetric: TypeAlias = data_and_classes.DummyMetric - DummyMetric.write(path, *DUMMY_METRICS) + DummyMetric.write(path, *data_and_classes.DUMMY_METRICS) metrics: List[DummyMetric] = list(DummyMetric.read(path=path)) - assert len(metrics) == len(DUMMY_METRICS) - assert metrics == DUMMY_METRICS + assert len(metrics) == len(data_and_classes.DUMMY_METRICS) + assert metrics == data_and_classes.DUMMY_METRICS -def test_metrics_roundtrip_gzip(tmp_path: Path) -> None: +@pytest.mark.parametrize("data_and_classes", (attr_data_and_classes, dataclasses_data_and_classes)) +def test_metrics_roundtrip_gzip(tmp_path: Path, data_and_classes: DataBuilder) -> None: path: Path = Path(tmp_path) / "metrics.txt.gz" + DummyMetric: Type[Metric] = 
data_and_classes.DummyMetric - DummyMetric.write(path, *DUMMY_METRICS) + DummyMetric.write(path, *data_and_classes.DUMMY_METRICS) with gzip.open(path, "r") as handle: handle.read(1) # Will raise an exception if not a GZIP file. metrics: List[DummyMetric] = list(DummyMetric.read(path=path)) - assert len(metrics) == len(DUMMY_METRICS) - assert metrics == DUMMY_METRICS + assert len(metrics) == len(data_and_classes.DUMMY_METRICS) + assert metrics == data_and_classes.DUMMY_METRICS -def test_metrics_read_extra_columns(tmp_path: Path) -> None: +@pytest.mark.parametrize("data_and_classes", (attr_data_and_classes, dataclasses_data_and_classes)) +def test_metrics_read_extra_columns(tmp_path: Path, data_and_classes: DataBuilder) -> None: + Person: TypeAlias = data_and_classes.Person person = Person(name="Max", age=42) path = tmp_path / "metrics.txt" with path.open("w") as writer: @@ -190,7 +292,11 @@ def test_metrics_read_extra_columns(tmp_path: Path) -> None: list(Person.read(path=path, ignore_extra_fields=False)) -def test_metrics_read_missing_optional_columns(tmp_path: Path) -> None: +@pytest.mark.parametrize("data_and_classes", (attr_data_and_classes, dataclasses_data_and_classes)) +def test_metrics_read_missing_optional_columns( + tmp_path: Path, data_and_classes: DataBuilder +) -> None: + PersonMaybeAge: TypeAlias = data_and_classes.PersonMaybeAge person = PersonMaybeAge(name="Max", age=None) path = tmp_path / "metrics.txt" @@ -206,7 +312,11 @@ def test_metrics_read_missing_optional_columns(tmp_path: Path) -> None: list(PersonMaybeAge.read(path=path)) -def test_metric_read_missing_column_with_default(tmp_path: Path) -> None: +@pytest.mark.parametrize("data_and_classes", (attr_data_and_classes, dataclasses_data_and_classes)) +def test_metric_read_missing_column_with_default( + tmp_path: Path, data_and_classes: DataBuilder +) -> None: + PersonDefault: TypeAlias = data_and_classes.PersonDefault person = PersonDefault(name="Max") path = tmp_path / "metrics.txt" @@ -227,8 +337,9 @@ def test_metric_read_missing_column_with_default(tmp_path: Path) -> None: list(PersonDefault.read(path=path)) -def test_metric_header() -> None: - assert DummyMetric.header() == [ +@pytest.mark.parametrize("data_and_classes", (attr_data_and_classes, dataclasses_data_and_classes)) +def test_metric_header(data_and_classes: DataBuilder) -> None: + assert data_and_classes.DummyMetric.header() == [ "int_value", "str_value", "bool_val", @@ -244,60 +355,72 @@ def test_metric_header() -> None: ] -def test_metric_values() -> None: - assert list(Person(name="name", age=42).values()) == ["name", 42] +@pytest.mark.parametrize("data_and_classes", (attr_data_and_classes, dataclasses_data_and_classes)) +def test_metric_values(data_and_classes: DataBuilder) -> None: + assert list(data_and_classes.Person(name="name", age=42).values()) == ["name", 42] -def test_metric_parse() -> None: +@pytest.mark.parametrize("data_and_classes", (attr_data_and_classes, dataclasses_data_and_classes)) +def test_metric_parse(data_and_classes: DataBuilder) -> None: + Person: TypeAlias = data_and_classes.Person assert Person.parse(fields=["name", "42"]) == Person(name="name", age=42) -def test_metric_formatted_values() -> None: - assert Person(name="name", age=42).formatted_values() == (["name", "42"]) +@pytest.mark.parametrize("data_and_classes", (attr_data_and_classes, dataclasses_data_and_classes)) +def test_metric_formatted_values(data_and_classes: DataBuilder) -> None: + assert data_and_classes.Person(name="name", age=42).formatted_values() == (["name", 
"42"]) -def test_metric_custom_parser() -> None: +@pytest.mark.parametrize("data_and_classes", (attr_data_and_classes, dataclasses_data_and_classes)) +def test_metric_custom_parser(data_and_classes: DataBuilder) -> None: + NamedPerson: TypeAlias = data_and_classes.NamedPerson assert NamedPerson.parse(fields=["john doe", "42"]) == ( - NamedPerson(name=Name(first="john", last="doe"), age=42) + NamedPerson(name=data_and_classes.Name(first="john", last="doe"), age=42) ) -def test_metric_custom_formatter() -> None: - person = NamedPerson(name=Name(first="john", last="doe"), age=42) +@pytest.mark.parametrize("data_and_classes", (attr_data_and_classes, dataclasses_data_and_classes)) +def test_metric_custom_formatter(data_and_classes: DataBuilder) -> None: + person = data_and_classes.NamedPerson( + name=data_and_classes.Name(first="john", last="doe"), age=42 + ) assert list(person.formatted_values()) == ["john doe", "42"] -def test_metric_parse_with_none() -> None: +@pytest.mark.parametrize("data_and_classes", (attr_data_and_classes, dataclasses_data_and_classes)) +def test_metric_parse_with_none(data_and_classes: DataBuilder) -> None: + Person: TypeAlias = data_and_classes.Person assert Person.parse(fields=["", "40"]) == Person(name=None, age=40) assert Person.parse(fields=["Sally", ""]) == Person(name="Sally", age=None) assert Person.parse(fields=["", ""]) == Person(name=None, age=None) -def test_metric_formatted_values_with_empty_string() -> None: +@pytest.mark.parametrize("data_and_classes", (attr_data_and_classes, dataclasses_data_and_classes)) +def test_metric_formatted_values_with_empty_string(data_and_classes: DataBuilder) -> None: + Person: TypeAlias = data_and_classes.Person assert Person(name=None, age=42).formatted_values() == (["", "42"]) assert Person(name="Sally", age=None).formatted_values() == (["Sally", ""]) assert Person(name=None, age=None).formatted_values() == (["", ""]) -@attr.s(auto_attribs=True, frozen=True) -class ListPerson(Metric["ListPerson"]): - name: List[Optional[str]] - age: List[Optional[int]] - - -def test_metric_list_format() -> None: - assert ListPerson(name=["Max", "Sally"], age=[43, 55]).formatted_values() == ( +@pytest.mark.parametrize("data_and_classes", (attr_data_and_classes, dataclasses_data_and_classes)) +def test_metric_list_format(data_and_classes: DataBuilder) -> None: + assert data_and_classes.ListPerson(name=["Max", "Sally"], age=[43, 55]).formatted_values() == ( ["Max,Sally", "43,55"] ) -def test_metric_list_parse() -> None: +@pytest.mark.parametrize("data_and_classes", (attr_data_and_classes, dataclasses_data_and_classes)) +def test_metric_list_parse(data_and_classes: DataBuilder) -> None: + ListPerson: TypeAlias = data_and_classes.ListPerson assert ListPerson.parse(fields=["Max,Sally", "43, 55"]) == ListPerson( name=["Max", "Sally"], age=[43, 55] ) -def test_metric_list_format_with_empty_string() -> None: +@pytest.mark.parametrize("data_and_classes", (attr_data_and_classes, dataclasses_data_and_classes)) +def test_metric_list_format_with_empty_string(data_and_classes: DataBuilder) -> None: + ListPerson: TypeAlias = data_and_classes.ListPerson assert ListPerson(name=[None, "Sally"], age=[43, 55]).formatted_values() == ( [",Sally", "43,55"] ) @@ -309,7 +432,9 @@ def test_metric_list_format_with_empty_string() -> None: ) -def test_metric_list_parse_with_none() -> None: +@pytest.mark.parametrize("data_and_classes", (attr_data_and_classes, dataclasses_data_and_classes)) +def test_metric_list_parse_with_none(data_and_classes: DataBuilder) -> None: + 
ListPerson: TypeAlias = data_and_classes.ListPerson assert ListPerson.parse(fields=[",Sally", "40, 30"]) == ListPerson( name=[None, "Sally"], age=[40, 30] ) @@ -321,13 +446,16 @@ def test_metric_list_parse_with_none() -> None: ) -def test_metrics_fast_concat(tmp_path: Path) -> None: +@pytest.mark.parametrize("data_and_classes", (attr_data_and_classes, dataclasses_data_and_classes)) +def test_metrics_fast_concat(tmp_path: Path, data_and_classes: DataBuilder) -> None: path_input = [ tmp_path / "metrics_1.txt", tmp_path / "metrics_2.txt", tmp_path / "metrics_3.txt", ] path_output: Path = tmp_path / "metrics_concat.txt" + DummyMetric: TypeAlias = data_and_classes.DummyMetric + DUMMY_METRICS: list[DummyMetric] = data_and_classes.DUMMY_METRICS DummyMetric.write(path_input[0], DUMMY_METRICS[0]) DummyMetric.write(path_input[1], DUMMY_METRICS[1]) diff --git a/fgpyo/util/types.py b/fgpyo/util/types.py index 822c4dc2..f8891756 100644 --- a/fgpyo/util/types.py +++ b/fgpyo/util/types.py @@ -129,6 +129,7 @@ def _make_union_parser_worker( return None except (ValueError, InspectException): pass + for p in parsers: try: return p(value)
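
As an illustration of the intended usage after this change (a sketch, not part of the patch: the SampleMetric class and its fields are invented here, and it assumes a build of fgpyo that includes this patch), a Metric declared with the standard-library dataclasses module round-trips through Metric.write() and Metric.read() just like an attr-decorated one:

# Illustrative sketch only; SampleMetric and its fields are invented for this example.
import dataclasses
from pathlib import Path
from typing import Optional

from fgpyo.util.metric import Metric


@dataclasses.dataclass(frozen=True)
class SampleMetric(Metric["SampleMetric"]):
    name: str
    reads: int
    mean_quality: Optional[float] = None


metrics = [
    SampleMetric(name="sample_1", reads=1000, mean_quality=35.5),
    SampleMetric(name="sample_2", reads=500),  # optional field falls back to its default
]
path = Path("sample_metrics.txt")
SampleMetric.write(path, *metrics)  # tab-delimited: one header line, then one row per metric
assert list(SampleMetric.read(path=path)) == metrics

Decorating the same class with @attr.s(auto_attribs=True, frozen=True) instead should behave identically, which is what the updated tests exercise for both flavors.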
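
The helpers added to fgpyo.util.inspect are what the Metric machinery dispatches through. Below is a minimal sketch of how they treat the two class flavors; the DcPoint/AttrPoint classes are invented for the example, and the snippet assumes this patched module rather than documenting a stable API:

# Hypothetical classes, used only to exercise the new helpers.
import dataclasses

import attr

from fgpyo.util.inspect import attr_from
from fgpyo.util.inspect import get_fields_dict
from fgpyo.util.inspect import is_attr_class
from fgpyo.util.inspect import is_dataclasses_class


@dataclasses.dataclass(frozen=True)
class DcPoint:
    x: int
    y: int


@attr.s(auto_attribs=True, frozen=True)
class AttrPoint:
    x: int
    y: int


# Both flavors are recognized and expose the same field view.
assert is_dataclasses_class(DcPoint) and not is_attr_class(DcPoint)
assert is_attr_class(AttrPoint) and not is_dataclasses_class(AttrPoint)
assert set(get_fields_dict(DcPoint)) == set(get_fields_dict(AttrPoint)) == {"x", "y"}

# attr_from() builds an instance from string keyword arguments, parsing each value
# according to the annotated field type (int here), for either flavor.
assert attr_from(cls=DcPoint, kwargs={"x": "1", "y": "2"}) == DcPoint(x=1, y=2)
assert attr_from(cls=AttrPoint, kwargs={"x": "1", "y": "2"}) == AttrPoint(x=1, y=2)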