Skip to content

Commit

Permalink
Fix type problems found by latest mypy (#75)
Browse files Browse the repository at this point in the history
Closes #73
* Move to current latest version of pytest to avoid deprecation warning
* Move to current latest version of mypy
* Correct multiple problems with type and import flagged by newer mypy
  • Loading branch information
TedBrookings authored Nov 22, 2023
1 parent ca4a572 commit b8ac915
Show file tree
Hide file tree
Showing 8 changed files with 134 additions and 88 deletions.
10 changes: 5 additions & 5 deletions fgpyo/io/tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,14 +80,14 @@ def test_assert_path_is_writeable_pass() -> None:
"""Should return the correct writeable path"""
with NamedTemp(suffix=".txt", mode="w", delete=True) as read_file:
path = Path(read_file.name)
assert fio.assert_path_is_writeable(path=path) is None
fio.assert_path_is_writeable(path=path)


@pytest.mark.parametrize(
"suffix, expected",
[
(".gz", io._io.TextIOWrapper),
(".fa", io._io.TextIOWrapper),
(".gz", io.TextIOWrapper),
(".fa", io.TextIOWrapper),
],
)
def test_reader(
Expand All @@ -103,8 +103,8 @@ def test_reader(
@pytest.mark.parametrize(
"suffix, expected",
[
(".gz", io._io.TextIOWrapper),
(".fa", io._io.TextIOWrapper),
(".gz", io.TextIOWrapper),
(".fa", io.TextIOWrapper),
],
)
def test_writer(
Expand Down
29 changes: 16 additions & 13 deletions fgpyo/util/inspect.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
import sys
from typing import Any
from typing import Dict
from typing import Iterable
from typing import List
from typing import Tuple
from typing import Type
from typing import Union

try: # py>=38
if sys.version_info >= (3, 8):
from typing import Literal
except ImportError: # py<38
else:
from typing_extensions import Literal
if sys.version_info >= (3, 12):
from typing import TypeAlias
else:
from typing_extensions import TypeAlias

import functools
from enum import Enum
Expand All @@ -29,8 +35,8 @@ class ParserNotFoundException(Exception):
def split_at_given_level(
field: str,
split_delim: str = ",",
increase_depth_chars: List[str] = ["{", "(", "["],
decrease_depth_chars: List[str] = ["}", ")", "]"],
increase_depth_chars: Iterable[str] = ("{", "(", "["),
decrease_depth_chars: Iterable[str] = ("}", ")", "]"),
) -> List[str]:
"""
Splits a nested field by its outer-most level
Expand Down Expand Up @@ -65,7 +71,7 @@ def split_at_given_level(


def _get_parser(
cls: Type, type_: Type, parsers: Optional[Dict[type, Callable[[str], Any]]] = None
cls: Type, type_: TypeAlias, parsers: Optional[Dict[type, Callable[[str], Any]]] = None
) -> partial:
"""Attempts to find a parser for a provided type.
Expand Down Expand Up @@ -114,14 +120,13 @@ def get_parser() -> partial:
subtypes[0],
parsers,
)
origin_type = types.get_origin_type(type_)
return functools.partial(
lambda s: origin_type(
lambda s: list(
[]
if s == ""
else [
subtype_parser(item)
for item in origin_type(split_at_given_level(s, split_delim=","))
for item in list(split_at_given_level(s, split_delim=","))
]
)
)
Expand All @@ -135,14 +140,13 @@ def get_parser() -> partial:
subtypes[0],
parsers,
)
origin_type = types.get_origin_type(type_)
return functools.partial(
lambda s: origin_type(
lambda s: set(
set({})
if s == "{}"
else [
subtype_parser(item)
for item in origin_type(split_at_given_level(s[1:-1], split_delim=","))
for item in set(split_at_given_level(s[1:-1], split_delim=","))
]
)
)
Expand All @@ -155,7 +159,6 @@ def get_parser() -> partial:
)
for subtype in types.get_arg_types(type_)
]
origin_type = types.get_origin_type(type_)

def tuple_parse(tuple_string: str) -> Tuple[Any, ...]:
"""
Expand Down Expand Up @@ -247,7 +250,7 @@ def dict_parse(dict_string: str) -> Dict[Any, Any]:
# Set the name that the user expects to see in error messages (we always
# return a temporary partial object so it's safe to set its __name__).
# Unions and Literals don't have a __name__, but their str is fine.
parser.__name__ = getattr(type_, "__name__", str(type_))
setattr(parser, "__name__", getattr(type_, "__name__", str(type_)))
return parser


Expand Down
6 changes: 4 additions & 2 deletions fgpyo/util/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,13 @@
"""

import logging
import sys

try: # py>=38
if sys.version_info >= (3, 8):
from typing import Literal
except ImportError: # py<38
else:
from typing_extensions import Literal

import socket
from contextlib import AbstractContextManager
from logging import Logger
Expand Down
4 changes: 2 additions & 2 deletions fgpyo/util/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@
from fgpyo import io
from fgpyo.util import inspect

MetricType = TypeVar("MetricType")
MetricType = TypeVar("MetricType", bound="Metric")


@attr.s
Expand All @@ -137,7 +137,7 @@ def values(self) -> Iterator[Any]:
"""An iterator over attribute values in the same order as the header."""
return iter(attr.astuple(self, recurse=False))

def formatted_values(self) -> Iterator[str]:
def formatted_values(self) -> List[str]:
"""An iterator over formatted attribute values in the same order as the header."""
return [self.format_value(value) for value in self.values()]

Expand Down
44 changes: 20 additions & 24 deletions fgpyo/util/types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import collections
import inspect
import sys
import typing
from enum import Enum
from functools import partial
Expand All @@ -9,20 +10,19 @@
from typing import TypeVar
from typing import Union

try:
# `get_origin_type` is a method that gets the outer type (ex list in a List[str])
# `get_arg_types` is a method that gets the inner type (ex str in a List[str])
if sys.version_info >= (3, 8):
from typing import Literal
except ImportError:
from typing_extensions import Literal


# `get_origin_type` is a method that gets the outer type (ex list in a List[str])
if hasattr(typing, "get_origin"): # py>=38
get_origin_type = typing.get_origin
else: # py<38
get_arg_types = typing.get_args
else:
import typing_inspect
from typing_extensions import Literal

def get_origin_type(tp: Type) -> Type:
"""Returns the outer type of a Typing object (ex list in a List[T])"""
import typing_inspect

if type(tp) is type(Literal): # Py<=3.6.
return Literal
Expand All @@ -37,25 +37,19 @@ def get_origin_type(tp: Type) -> Type:
typing.Dict: dict,
}.get(origin, origin)


# `get_origin_type` is a method that gets the inner type (ex str in a List[str])
if hasattr(typing, "get_args"): # py>=38
get_arg_types = typing.get_args
else: # py<38

def get_arg_types(tp: Type) -> Type:
"""Gets the inner types of a Typing object (ex T in a List[T])"""
import typing_inspect

if type(tp) is type(Literal): # Py<=3.6.
return tp.__values__
return typing_inspect.get_args(tp, evaluate=True) # evaluate=True default on Py>=3.7.


T = TypeVar("T")
UnionType = TypeVar("UnionType", bound="Union")
EnumType = TypeVar("EnumType", bound="Enum")
LiteralType = TypeVar("LiteralType", bound="Literal")
# conceptually bound to "Literal" but that's not valid in the spec
# see: https://peps.python.org/pep-0586/#illegal-parameters-for-literal-at-type-check-time
LiteralType = TypeVar("LiteralType")


class InspectException(Exception):
Expand Down Expand Up @@ -92,7 +86,7 @@ def make_enum_parser(enum: Type[EnumType]) -> partial:
return partial(_make_enum_parser_worker, enum)


def is_constructible_from_str(type_: T) -> bool:
def is_constructible_from_str(type_: type) -> bool:
"""Returns true if the provided type can be constructed from a string"""
try:
sig = inspect.signature(type_)
Expand All @@ -113,7 +107,7 @@ def is_constructible_from_str(type_: T) -> bool:
return False


def _is_optional(type_: T) -> bool:
def _is_optional(type_: type) -> bool:
"""Returns true if type_ is optional"""
return get_origin_type(type_) is Union and type(None) in get_arg_types(type_)

Expand All @@ -122,15 +116,17 @@ def _make_union_parser_worker(
union: Type[UnionType],
parsers: Iterable[Callable[[str], UnionType]],
value: str,
) -> T:
) -> UnionType:
"""Worker function behind union parsing. Iterates through possible parsers for the union and
returns the value produced by the first parser that works. Otherwise raises an error if none
work"""
# Need to do this in the case of type Optional[str], because otherwise it'll return the string
# 'None' instead of the object None
if _is_optional(union):
try:
return none_parser(value)
# mypy doesn't like functions that return None always, so return separately
none_parser(value)
return None
except (ValueError, InspectException):
pass
for p in parsers:
Expand All @@ -141,7 +137,7 @@ def _make_union_parser_worker(
raise ValueError(f"{value} could not be parsed as any of {union}")


def make_union_parser(union: Type[UnionType], parsers: Iterable[Callable[[str], T]]) -> partial:
def make_union_parser(union: Type[UnionType], parsers: Iterable[Callable[[str], type]]) -> partial:
"""Generates a parser function for a union type object and set of parsers for the possible
parsers to that union type object
"""
Expand Down Expand Up @@ -176,12 +172,12 @@ def make_literal_parser(
return partial(_make_literal_parser_worker, literal, parsers)


def is_list_like(type_: T) -> bool:
def is_list_like(type_: type) -> bool:
"""Returns true if the value is a list or list like object"""
return get_origin_type(type_) in [list, collections.abc.Iterable, collections.abc.Sequence]


def none_parser(value: str) -> None:
def none_parser(value: str) -> Literal[None]:
"""Returns None if the value is 'None', else raises an error"""
if value == "":
return None
Expand Down
13 changes: 8 additions & 5 deletions fgpyo/vcf/tests/test_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Any
from typing import Dict
from typing import Iterable
from typing import List
from typing import Mapping
from typing import Tuple

Expand Down Expand Up @@ -33,7 +34,7 @@ def sequence_dict() -> Dict[str, Dict[str, Any]]:

def _get_random_contig(
random_generator: random.Random, sequence_dict: Dict[str, Dict[str, Any]]
) -> (str, int):
) -> Tuple[str, int]:
"""Randomly select a contig from the sequence dictionary and return its name and length."""
contig = random_generator.choice(list(sequence_dict.values()))
return contig["ID"], contig["length"]
Expand Down Expand Up @@ -102,7 +103,7 @@ def _get_random_variant_inputs(
@pytest.fixture(scope="function")
def zero_sample_record_inputs(
random_generator: random.Random, sequence_dict: Dict[str, Dict[str, Any]]
) -> Tuple[Mapping[str, Any]]:
) -> Tuple[Mapping[str, Any], ...]:
"""
Fixture with inputs to create test Variant records for zero-sample VCFs (no genotypes).
Make them MappingProxyType so that they are immutable.
Expand Down Expand Up @@ -174,7 +175,7 @@ def test_minimal_inputs() -> None:

def test_sort_order(random_generator: random.Random) -> None:
"""Test if the VariantBuilder sorts the Variant records in the correct order."""
sorted_inputs = [
sorted_inputs: List[Dict[str, Any]] = [
{"contig": "chr1", "pos": 100},
{"contig": "chr1", "pos": 500},
{"contig": "chr2", "pos": 1000},
Expand All @@ -183,7 +184,9 @@ def test_sort_order(random_generator: random.Random) -> None:
{"contig": "chr10", "pos": 20},
{"contig": "chr11", "pos": 5},
]
scrambled_inputs = random_generator.sample(sorted_inputs, k=len(sorted_inputs))
scrambled_inputs: List[Dict[str, Any]] = random_generator.sample(
sorted_inputs, k=len(sorted_inputs)
)
assert scrambled_inputs != sorted_inputs # there should be something to actually sort
variant_builder = VariantBuilder()
for record_input in scrambled_inputs:
Expand Down Expand Up @@ -223,7 +226,7 @@ def _get_is_compressed(input_file: Path) -> bool:
@pytest.mark.parametrize("compress", (True, False))
def test_zero_sample_vcf_round_trip(
temp_path: Path,
zero_sample_record_inputs,
zero_sample_record_inputs: Tuple[Mapping[str, Any], ...],
compress: bool,
) -> None:
"""
Expand Down
Loading

0 comments on commit b8ac915

Please sign in to comment.