Skip to content

Commit

Permalink
Fix type problems found by latest mypy
Browse files Browse the repository at this point in the history
Closes #73
* Move to current latest version of pytest to avoid deprecation warning
* Move to current latest version of mypy
* Correct multiple problems with type and import flagged by newer mypy
  • Loading branch information
TedBrookings committed Nov 22, 2023
1 parent ca4a572 commit 63c109e
Show file tree
Hide file tree
Showing 8 changed files with 130 additions and 84 deletions.
10 changes: 5 additions & 5 deletions fgpyo/io/tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,14 +80,14 @@ def test_assert_path_is_writeable_pass() -> None:
"""Should return the correct writeable path"""
with NamedTemp(suffix=".txt", mode="w", delete=True) as read_file:
path = Path(read_file.name)
assert fio.assert_path_is_writeable(path=path) is None
fio.assert_path_is_writeable(path=path)


@pytest.mark.parametrize(
"suffix, expected",
[
(".gz", io._io.TextIOWrapper),
(".fa", io._io.TextIOWrapper),
(".gz", io.TextIOWrapper),
(".fa", io.TextIOWrapper),
],
)
def test_reader(
Expand All @@ -103,8 +103,8 @@ def test_reader(
@pytest.mark.parametrize(
"suffix, expected",
[
(".gz", io._io.TextIOWrapper),
(".fa", io._io.TextIOWrapper),
(".gz", io.TextIOWrapper),
(".fa", io.TextIOWrapper),
],
)
def test_writer(
Expand Down
29 changes: 16 additions & 13 deletions fgpyo/util/inspect.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
import sys
from typing import Any
from typing import Dict
from typing import Iterable
from typing import List
from typing import Tuple
from typing import Type
from typing import Union

try: # py>=38
if sys.version_info >= (3, 8):
from typing import Literal
except ImportError: # py<38
else:
from typing_extensions import Literal
if sys.version_info >= (3, 12):
from typing import TypeAlias
else:
from typing_extensions import TypeAlias

import functools
from enum import Enum
Expand All @@ -29,8 +35,8 @@ class ParserNotFoundException(Exception):
def split_at_given_level(
field: str,
split_delim: str = ",",
increase_depth_chars: List[str] = ["{", "(", "["],
decrease_depth_chars: List[str] = ["}", ")", "]"],
increase_depth_chars: Iterable[str] = ("{", "(", "["),
decrease_depth_chars: Iterable[str] = ("}", ")", "]"),
) -> List[str]:
"""
Splits a nested field by its outer-most level
Expand Down Expand Up @@ -65,7 +71,7 @@ def split_at_given_level(


def _get_parser(
cls: Type, type_: Type, parsers: Optional[Dict[type, Callable[[str], Any]]] = None
cls: Type, type_: TypeAlias, parsers: Optional[Dict[type, Callable[[str], Any]]] = None
) -> partial:
"""Attempts to find a parser for a provided type.
Expand Down Expand Up @@ -114,14 +120,13 @@ def get_parser() -> partial:
subtypes[0],
parsers,
)
origin_type = types.get_origin_type(type_)
return functools.partial(
lambda s: origin_type(
lambda s: list(
[]
if s == ""
else [
subtype_parser(item)
for item in origin_type(split_at_given_level(s, split_delim=","))
for item in list(split_at_given_level(s, split_delim=","))
]
)
)
Expand All @@ -135,14 +140,13 @@ def get_parser() -> partial:
subtypes[0],
parsers,
)
origin_type = types.get_origin_type(type_)
return functools.partial(
lambda s: origin_type(
lambda s: set(
set({})
if s == "{}"
else [
subtype_parser(item)
for item in origin_type(split_at_given_level(s[1:-1], split_delim=","))
for item in set(split_at_given_level(s[1:-1], split_delim=","))
]
)
)
Expand All @@ -155,7 +159,6 @@ def get_parser() -> partial:
)
for subtype in types.get_arg_types(type_)
]
origin_type = types.get_origin_type(type_)

def tuple_parse(tuple_string: str) -> Tuple[Any, ...]:
"""
Expand Down Expand Up @@ -247,7 +250,7 @@ def dict_parse(dict_string: str) -> Dict[Any, Any]:
# Set the name that the user expects to see in error messages (we always
# return a temporary partial object so it's safe to set its __name__).
# Unions and Literals don't have a __name__, but their str is fine.
parser.__name__ = getattr(type_, "__name__", str(type_))
setattr(parser, "__name__", getattr(type_, "__name__", str(type_)))
return parser


Expand Down
6 changes: 4 additions & 2 deletions fgpyo/util/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,13 @@
"""

import logging
import sys

try: # py>=38
if sys.version_info >= (3, 8):
from typing import Literal
except ImportError: # py<38
else:
from typing_extensions import Literal

import socket
from contextlib import AbstractContextManager
from logging import Logger
Expand Down
4 changes: 2 additions & 2 deletions fgpyo/util/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def values(self) -> Iterator[Any]:
"""An iterator over attribute values in the same order as the header."""
return iter(attr.astuple(self, recurse=False))

def formatted_values(self) -> Iterator[str]:
def formatted_values(self) -> List[str]:
"""An iterator over formatted attribute values in the same order as the header."""
return [self.format_value(value) for value in self.values()]

Expand Down Expand Up @@ -226,7 +226,7 @@ def parse(cls, fields: List[str]) -> Any:
return inspect.attr_from(cls=cls, kwargs=dict(zip(header, fields)), parsers=parsers)

@classmethod
def write(cls, path: Path, *values: MetricType) -> None:
def write(cls, path: Path, *values: "Metric") -> None:
"""Writes zero or more metrics to the given path.
The header will always be written.
Expand Down
36 changes: 16 additions & 20 deletions fgpyo/util/types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import collections
import inspect
import sys
import typing
from enum import Enum
from functools import partial
Expand All @@ -9,20 +10,19 @@
from typing import TypeVar
from typing import Union

try:
# `get_origin_type` is a method that gets the outer type (ex list in a List[str])
# `get_arg_types` is a method that gets the inner type (ex str in a List[str])
if sys.version_info >= (3, 8):
from typing import Literal
except ImportError:
from typing_extensions import Literal


# `get_origin_type` is a method that gets the outer type (ex list in a List[str])
if hasattr(typing, "get_origin"): # py>=38
get_origin_type = typing.get_origin
else: # py<38
get_arg_types = typing.get_args
else:
import typing_inspect
from typing_extensions import Literal

def get_origin_type(tp: Type) -> Type:
"""Returns the outer type of a Typing object (ex list in a List[T])"""
import typing_inspect

if type(tp) is type(Literal): # Py<=3.6.
return Literal
Expand All @@ -37,15 +37,8 @@ def get_origin_type(tp: Type) -> Type:
typing.Dict: dict,
}.get(origin, origin)


# `get_origin_type` is a method that gets the inner type (ex str in a List[str])
if hasattr(typing, "get_args"): # py>=38
get_arg_types = typing.get_args
else: # py<38

def get_arg_types(tp: Type) -> Type:
"""Gets the inner types of a Typing object (ex T in a List[T])"""
import typing_inspect

if type(tp) is type(Literal): # Py<=3.6.
return tp.__values__
Expand All @@ -55,7 +48,9 @@ def get_arg_types(tp: Type) -> Type:
T = TypeVar("T")
UnionType = TypeVar("UnionType", bound="Union")
EnumType = TypeVar("EnumType", bound="Enum")
LiteralType = TypeVar("LiteralType", bound="Literal")
# conceptually bound to "Literal" but that's not valid in the spec
# see: https://peps.python.org/pep-0586/#illegal-parameters-for-literal-at-type-check-time
LiteralType = TypeVar("LiteralType")


class InspectException(Exception):
Expand Down Expand Up @@ -92,7 +87,7 @@ def make_enum_parser(enum: Type[EnumType]) -> partial:
return partial(_make_enum_parser_worker, enum)


def is_constructible_from_str(type_: T) -> bool:
def is_constructible_from_str(type_: type) -> bool:
"""Returns true if the provided type can be constructed from a string"""
try:
sig = inspect.signature(type_)
Expand Down Expand Up @@ -122,15 +117,16 @@ def _make_union_parser_worker(
union: Type[UnionType],
parsers: Iterable[Callable[[str], UnionType]],
value: str,
) -> T:
) -> UnionType:
"""Worker function behind union parsing. Iterates through possible parsers for the union and
returns the value produced by the first parser that works. Otherwise raises an error if none
work"""
# Need to do this in the case of type Optional[str], because otherwise it'll return the string
# 'None' instead of the object None
if _is_optional(union):
try:
return none_parser(value)
none_parser(value)
return None
except (ValueError, InspectException):
pass
for p in parsers:
Expand Down Expand Up @@ -181,7 +177,7 @@ def is_list_like(type_: T) -> bool:
return get_origin_type(type_) in [list, collections.abc.Iterable, collections.abc.Sequence]


def none_parser(value: str) -> None:
def none_parser(value: str) -> Literal[None]:
"""Returns None if the value is 'None', else raises an error"""
if value == "":
return None
Expand Down
13 changes: 8 additions & 5 deletions fgpyo/vcf/tests/test_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Any
from typing import Dict
from typing import Iterable
from typing import List
from typing import Mapping
from typing import Tuple

Expand Down Expand Up @@ -33,7 +34,7 @@ def sequence_dict() -> Dict[str, Dict[str, Any]]:

def _get_random_contig(
random_generator: random.Random, sequence_dict: Dict[str, Dict[str, Any]]
) -> (str, int):
) -> Tuple[str, int]:
"""Randomly select a contig from the sequence dictionary and return its name and length."""
contig = random_generator.choice(list(sequence_dict.values()))
return contig["ID"], contig["length"]
Expand Down Expand Up @@ -102,7 +103,7 @@ def _get_random_variant_inputs(
@pytest.fixture(scope="function")
def zero_sample_record_inputs(
random_generator: random.Random, sequence_dict: Dict[str, Dict[str, Any]]
) -> Tuple[Mapping[str, Any]]:
) -> Tuple[Mapping[str, Any], ...]:
"""
Fixture with inputs to create test Variant records for zero-sample VCFs (no genotypes).
Make them MappingProxyType so that they are immutable.
Expand Down Expand Up @@ -174,7 +175,7 @@ def test_minimal_inputs() -> None:

def test_sort_order(random_generator: random.Random) -> None:
"""Test if the VariantBuilder sorts the Variant records in the correct order."""
sorted_inputs = [
sorted_inputs: List[Dict[str, Any]] = [
{"contig": "chr1", "pos": 100},
{"contig": "chr1", "pos": 500},
{"contig": "chr2", "pos": 1000},
Expand All @@ -183,7 +184,9 @@ def test_sort_order(random_generator: random.Random) -> None:
{"contig": "chr10", "pos": 20},
{"contig": "chr11", "pos": 5},
]
scrambled_inputs = random_generator.sample(sorted_inputs, k=len(sorted_inputs))
scrambled_inputs: List[Dict[str, Any]] = random_generator.sample(
sorted_inputs, k=len(sorted_inputs)
)
assert scrambled_inputs != sorted_inputs # there should be something to actually sort
variant_builder = VariantBuilder()
for record_input in scrambled_inputs:
Expand Down Expand Up @@ -223,7 +226,7 @@ def _get_is_compressed(input_file: Path) -> bool:
@pytest.mark.parametrize("compress", (True, False))
def test_zero_sample_vcf_round_trip(
temp_path: Path,
zero_sample_record_inputs,
zero_sample_record_inputs: Tuple[Mapping[str, Any], ...],
compress: bool,
) -> None:
"""
Expand Down
Loading

0 comments on commit 63c109e

Please sign in to comment.