Skip to content

Commit

Permalink
Allow Metric to use dataclasses or attr
Browse files Browse the repository at this point in the history
* Update util.metric and related util.inspect modules to work with
  dataclasses or attr
* Update test_metric to test both dataclasses and attr classes
* Remove attr from non-test requirements

Closes Issue #45
  • Loading branch information
TedBrookings committed Nov 27, 2023
1 parent d841c37 commit 593c01b
Show file tree
Hide file tree
Showing 9 changed files with 432 additions and 206 deletions.
3 changes: 3 additions & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ Metric files

.. seealso::

https://docs.python.org/3/library/dataclasses.html
Documentation for the dataclasses standard module

https://www.attrs.org/en/stable/examples.html

The attrs website for bringing back the joy to writing classes.
Expand Down
17 changes: 8 additions & 9 deletions fgpyo/read_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,15 +61,14 @@
- :class:`~fgpyo.read_structure.SubReadWithQuals` -- Contains the bases and qualities that
correspond to the given read segment
"""
import dataclasses
import enum
from typing import Iterable
from typing import Iterator
from typing import List
from typing import Optional
from typing import Tuple

import attr

# A character that can be put in place of a number in a read structure to mean "0 or more bases".
ANY_LENGTH_CHAR: str = "+"

Expand All @@ -87,7 +86,7 @@ def __str__(self) -> str:
return self.value


@attr.s(frozen=True, auto_attribs=True, kw_only=True)
@dataclasses.dataclass(frozen=True, kw_only=True)
class SubReadWithoutQuals:
"""Contains the bases that correspond to the given read segment"""

Expand All @@ -99,7 +98,7 @@ def kind(self) -> SegmentType:
return self.segment.kind


@attr.s(frozen=True, auto_attribs=True, kw_only=True)
@dataclasses.dataclass(frozen=True, kw_only=True)
class SubReadWithQuals:
"""Contains the bases and qualities that correspond to the given read segment"""

Expand All @@ -112,7 +111,7 @@ def kind(self) -> SegmentType:
return self.segment.kind


@attr.s(frozen=True, auto_attribs=True, kw_only=True)
@dataclasses.dataclass(frozen=True, kw_only=True)
class ReadSegment:
"""Encapsulates all the information about a segment within a read structure. A segment can
either have a definite length, in which case length must be Some(Int), or an indefinite length
Expand Down Expand Up @@ -178,7 +177,7 @@ def _resized(self, end: int) -> "ReadSegment":
if self.has_fixed_length and self.fixed_length == new_length:
return self
else:
return attr.evolve(self, length=new_length)
return dataclasses.replace(self, length=new_length)

def __str__(self) -> str:
if self.has_fixed_length:
Expand All @@ -187,7 +186,7 @@ def __str__(self) -> str:
return f"{ANY_LENGTH_CHAR}{self.kind.value}"


@attr.s(frozen=True, auto_attribs=True, kw_only=True)
@dataclasses.dataclass(frozen=True, kw_only=True)
class ReadStructure(Iterable[ReadSegment]):
"""Describes the structure of a give read. A read contains one or more read segments. A read
segment describes a contiguous stretch of bases of the same type (ex. template bases) of some
Expand Down Expand Up @@ -231,7 +230,7 @@ def with_variable_last_segment(self) -> "ReadStructure":
if not last_segment.has_fixed_length:
return self
else:
last_segment = attr.evolve(last_segment, length=None)
last_segment = dataclasses.replace(last_segment, length=None)
return ReadStructure(segments=self.segments[:-1] + (last_segment,))

def extract(self, bases: str) -> Tuple[SubReadWithoutQuals, ...]:
Expand Down Expand Up @@ -285,7 +284,7 @@ def from_segments(
off = 0
segs = []
for seg in segments:
seg = attr.evolve(seg, offset=off)
seg = dataclasses.replace(seg, offset=off)
off += seg.length if seg.has_fixed_length else 0
segs.append(seg)
segments = tuple(segs)
Expand Down
37 changes: 18 additions & 19 deletions fgpyo/sam/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@
- :func:`~fgpyo.sam.calc_edit_info` -- calculates how a read differs from the reference
"""

import dataclasses
import enum
import io
from pathlib import Path
Expand All @@ -164,7 +165,6 @@
from typing import Tuple
from typing import Union

import attr
import pysam
from pysam import AlignedSegment
from pysam import AlignmentFile as SamFile
Expand Down Expand Up @@ -385,7 +385,7 @@ def is_clipping(self) -> bool:
return self == CigarOp.S or self == CigarOp.H


@attr.s(frozen=True, slots=True)
@dataclasses.dataclass(frozen=True, slots=True)
class CigarElement:
"""Represents an element in a Cigar
Expand All @@ -394,14 +394,13 @@ class CigarElement:
- operator (CigarOp): the operator of the element
"""

length: int = attr.ib()
operator: CigarOp = attr.ib()
length: int
operator: CigarOp

@length.validator
def _validate_length(self, attribute: Any, value: int) -> None:
def __post_init__(self) -> None:
"""Validates the length attribute is greater than zero."""
if value <= 0:
raise ValueError(f"Cigar element must have a length > 0, found {value}")
if self.length <= 0:
raise ValueError(f"Cigar element must have a length > 0, found {self.length}")

@property
def length_on_query(self) -> int:
Expand All @@ -423,15 +422,15 @@ class CigarParsingException(Exception):
pass


@attr.s(frozen=True, slots=True)
@dataclasses.dataclass(frozen=True, slots=True)
class Cigar:
"""Class representing a cigar string.
Attributes:
- elements (Tuple[CigarElement, ...]): zero or more cigar elements
"""

elements: Tuple[CigarElement, ...] = attr.ib(default=())
elements: Tuple[CigarElement, ...] = ()

@classmethod
def from_cigartuples(cls, cigartuples: Optional[List[Tuple[int, int]]]) -> "Cigar":
Expand Down Expand Up @@ -518,7 +517,7 @@ def length_on_target(self) -> int:
return sum([elem.length_on_target for elem in self.elements])


@attr.s(auto_attribs=True, frozen=True)
@dataclasses.dataclass(frozen=True)
class SupplementaryAlignment:
"""Stores a supplementary alignment record produced by BWA and stored in the SA SAM tag.
Expand All @@ -531,12 +530,12 @@ class SupplementaryAlignment:
nm: the number of edits
"""

reference_name: str = attr.ib()
start: int = attr.ib()
is_forward: bool = attr.ib()
cigar: Cigar = attr.ib()
mapq: int = attr.ib()
nm: int = attr.ib()
reference_name: str
start: int
is_forward: bool
cigar: Cigar
mapq: int
nm: int

def __str__(self) -> str:
return ",".join(
Expand Down Expand Up @@ -620,7 +619,7 @@ def set_pair_info(r1: AlignedSegment, r2: AlignedSegment, proper_pair: bool = Tr
r2.template_length = -insert_size


@attr.s(auto_attribs=True, frozen=True)
@dataclasses.dataclass(frozen=True)
class ReadEditInfo:
"""
Counts various stats about how a read compares to a reference sequence.
Expand Down Expand Up @@ -709,7 +708,7 @@ def calculate_edit_info(
)


@attr.s(auto_attribs=True, frozen=True)
@dataclasses.dataclass(frozen=True)
class Template:
"""A container for alignment records corresponding to a single sequenced template
or insert.
Expand Down
110 changes: 99 additions & 11 deletions fgpyo/util/inspect.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Dict
from typing import Iterable
from typing import List
from typing import Protocol
from typing import Tuple
from typing import Type
from typing import Union
Expand All @@ -16,17 +17,76 @@
else:
from typing_extensions import TypeAlias

import dataclasses
import functools
from dataclasses import MISSING as DATACLASSES_MISSING
from dataclasses import fields as get_dataclasses_fields
from dataclasses import is_dataclass as is_dataclasses_class
from enum import Enum
from functools import partial
from pathlib import PurePath
from typing import TYPE_CHECKING
from typing import Callable
from typing import Optional

import attr
from typing import TypeVar

import fgpyo.util.types as types

# attrs is an optional dependency: import it when available, otherwise bind inert
# stand-ins under the same names so the rest of this module still imports.
try:
    import attr
    from attr import NOTHING as ATTR_NOTHING
    from attr import fields as get_attr_fields
    from attr import fields_dict as get_attr_fields_dict
except ImportError:
    _use_attr = False
    attr = None
    ATTR_NOTHING = None
    Attribute = TypeVar("Attribute", bound=object)  # type: ignore

    # Only the dataclasses sentinel can mark a missing default without attrs.
    MISSING = {DATACLASSES_MISSING}

    def get_attr_fields(cls: type) -> Tuple[dataclasses.Field, ...]:  # type: ignore
        """Stand-in for ``attr.fields`` when attrs is unavailable: no fields."""
        return ()

    def get_attr_fields_dict(cls: type) -> Dict[str, dataclasses.Field]:  # type: ignore
        """Stand-in for ``attr.fields_dict`` when attrs is unavailable: no fields."""
        return {}

else:
    _use_attr = True

    Attribute = attr.Attribute

    # Either library's sentinel marks a field without a default value.
    MISSING = {DATACLASSES_MISSING, ATTR_NOTHING}

if TYPE_CHECKING:
    from _typeshed import DataclassInstance as DataclassesProtocol
else:

    class DataclassesProtocol(Protocol):
        """Runtime stand-in for typeshed's ``DataclassInstance`` protocol.

        ``_typeshed`` is only importable by type checkers, so outside of type checking
        this structural type describes a dataclass through the marker attribute that
        the ``dataclasses`` module sets on the classes it processes.
        """

        # The marker attribute is spelled ``__dataclass_fields__`` (singular "dataclass").
        __dataclass_fields__: Dict[str, dataclasses.Field]


if TYPE_CHECKING and _use_attr:
    from attr import AttrsInstance
else:

    class AttrsInstance(Protocol):  # type: ignore
        """Runtime stand-in for ``attr.AttrsInstance``.

        Describes an attr class structurally through its ``__attrs_attrs__`` marker,
        which is the same attribute that ``is_attr_class`` looks for.
        """

        # attrs stores the class's attributes as a tuple, not as a dict.
        __attrs_attrs__: Tuple[Any, ...]


def is_attr_class(cls: type) -> bool:  # type: ignore
    """Return True if ``cls`` exposes the ``__attrs_attrs__`` marker, False otherwise."""
    try:
        cls.__attrs_attrs__
    except AttributeError:
        return False
    return True


# Values that all mean "no usable default": the missing-value sentinels plus None.
MISSING_OR_NONE = {*MISSING, None}
DataclassesOrAttrClass = Union[DataclassesProtocol, AttrsInstance]
# Use the module-level `Attribute` alias rather than `attr.Attribute`: `attr` is bound to
# None when attrs is not installed, so dereferencing it here would fail at import time.
FieldType: TypeAlias = Union[dataclasses.Field, Attribute]


def get_dataclasses_fields_dict(
    class_or_instance: Union[DataclassesProtocol, Type[DataclassesProtocol]],
) -> Dict[str, dataclasses.Field]:
    """Map each dataclasses field name to its ``dataclasses.Field`` object.

    Args:
        class_or_instance: a dataclasses class, or an instance of one

    Returns:
        a dictionary keyed by field name, built from ``dataclasses.fields``
    """
    fields_by_name: Dict[str, dataclasses.Field] = {}
    for dc_field in get_dataclasses_fields(class_or_instance):
        fields_by_name[dc_field.name] = dc_field
    return fields_by_name


class ParserNotFoundException(Exception):
pass
Expand Down Expand Up @@ -254,9 +314,36 @@ def dict_parse(dict_string: str) -> Dict[Any, Any]:
return parser


def get_fields_dict(cls: Type[DataclassesOrAttrClass]) -> Dict[str, FieldType]:
    """
    Get the fields dict from either a dataclasses or attr dataclass.

    Combine results in case someone chooses to mix them through inheritance.

    Args:
        cls: the dataclasses or attr class to inspect

    Returns:
        a dictionary mapping each field name to its field object

    Raises:
        ValueError: if cls is neither a dataclasses class nor an attr class
    """
    is_dataclasses = is_dataclasses_class(cls)
    is_attr = is_attr_class(cls)
    if not (is_dataclasses or is_attr):
        raise ValueError("cls must be a dataclasses or attr class")

    fields_dict: Dict[str, FieldType] = {}
    if is_dataclasses:
        fields_dict.update(get_dataclasses_fields_dict(cls))
    if is_attr:
        # Merged last, so the attr entry wins when a name is present in both.
        fields_dict.update(get_attr_fields_dict(cls))  # type: ignore
    return fields_dict


def get_fields(cls: Type[DataclassesOrAttrClass]) -> Tuple[FieldType, ...]:
    """
    Get the fields tuple from either a dataclasses or attr dataclass.

    Args:
        cls: the dataclasses or attr class to inspect

    Returns:
        the dataclasses fields followed by the attr fields; either part is empty when
        cls is not that kind of class

    Raises:
        ValueError: if cls is neither a dataclasses class nor an attr class
    """
    is_dataclasses = is_dataclasses_class(cls)
    is_attr = is_attr_class(cls)
    if not (is_dataclasses or is_attr):
        raise ValueError("cls must be a dataclasses or attr class")

    dataclasses_fields = get_dataclasses_fields(cls) if is_dataclasses else ()
    attr_fields = get_attr_fields(cls) if is_attr else ()  # type: ignore
    return dataclasses_fields + attr_fields


AttrFromType = TypeVar("AttrFromType")


def attr_from(
cls: Type, kwargs: Dict[str, str], parsers: Optional[Dict[type, Callable[[str], Any]]] = None
) -> Any:
cls: Type[AttrFromType],
kwargs: Dict[str, str],
parsers: Optional[Dict[type, Callable[[str], Any]]] = None,
) -> AttrFromType:
"""Builds an attr class from key-word arguments
Args:
Expand All @@ -265,15 +352,16 @@ def attr_from(
parsers: a dictionary of parser functions to apply to specific types
"""
return_values: Dict[str, Any] = {}
for attribute in attr.fields(cls):
for attribute in get_fields(cls): # type: ignore
return_value: Any
if attribute.name in kwargs:
str_value: str = kwargs[attribute.name]
set_value: bool = False

# Use the converter if provided
if attribute.converter is not None:
return_value = attribute.converter(str_value)
converter = getattr(attribute, "converter", None)
if converter is not None:
return_value = converter(str_value)
set_value = True

# try getting a known parser
Expand Down Expand Up @@ -305,21 +393,21 @@ def attr_from(
), f"No value given and no default for attribute `{attribute.name}`"
return_value = attribute.default
# when the default is a missing-value sentinel (attr.NOTHING or dataclasses.MISSING), just use None
if return_value is attr.NOTHING:
if return_value in MISSING:
return_value = None

return_values[attribute.name] = return_value

return cls(**return_values)


def attribute_is_optional(attribute: FieldType) -> bool:
    """Returns True if the attribute is optional, False otherwise

    An attribute counts as optional when the origin of its declared type is ``Union``
    and ``None`` is an instance of one of that type's arguments.
    """
    declared_type = attribute.type
    if types.get_origin_type(declared_type) is not Union:
        return False
    return isinstance(None, types.get_arg_types(declared_type))


def attribute_has_default(attribute: FieldType) -> bool:
    """Returns True if the attribute has a default value, False otherwise

    An attribute has a default when its ``default`` is neither ``None`` nor one of the
    missing-value sentinels, or when the attribute is optional.
    """
    default = attribute.default
    # Compare by identity instead of `default in MISSING_OR_NONE`: set membership hashes
    # the default, which raises TypeError for an unhashable default such as a list.
    has_explicit_default = all(default is not sentinel for sentinel in MISSING_OR_NONE)
    return has_explicit_default or attribute_is_optional(attribute)
Loading

0 comments on commit 593c01b

Please sign in to comment.