Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Misc. typing fixes #84

Merged
merged 6 commits into from
Dec 9, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions src/generate_spherely_vfunc_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@
import string
from pathlib import Path

from spherely import EARTH_RADIUS_METERS


VFUNC_TYPE_SPECS = {
"_VFunc_Nin1_Nout1": {"n_in": 1},
"_VFunc_Nin2_Nout1": {"n_in": 2},
"_VFunc_Nin2optradius_Nout1": {"n_in": 2, "radius": "float"},
"_VFunc_Nin1optradius_Nout1": {"n_in": 1, "radius": "float"},
"_VFunc_Nin2optradius_Nout1": {"n_in": 2, "radius": ("float", EARTH_RADIUS_METERS)},
"_VFunc_Nin1optradius_Nout1": {"n_in": 1, "radius": ("float", EARTH_RADIUS_METERS)},
"_VFunc_Nin1optprecision_Nout1": {"n_in": 1, "precision": ("int", 6)},
}

STUB_FILE_PATH = Path(__file__).parent / "spherely.pyi"
Expand Down Expand Up @@ -51,10 +55,11 @@ def _vfunctype_factory(class_name, n_in, **optargs):
"",
]
optarg_str = ", ".join(
f"{arg_name}: {arg_type} = ..." for arg_name, arg_type in optargs.items()
f"{arg_name}: {arg_type} = {arg_value}"
for arg_name, (arg_type, arg_value) in optargs.items()
)

geog_types = ["Geography", "npt.ArrayLike"]
geog_types = ["Geography", "Iterable[Geography]"]
for arg_types in itertools.product(geog_types, repeat=n_in):
arg_str = ", ".join(
f"{arg_name}: {arg_type}"
Expand Down
94 changes: 66 additions & 28 deletions src/spherely.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ from typing import (
Literal,
Protocol,
Sequence,
Tuple,
TypeVar,
overload,
)
Expand Down Expand Up @@ -66,7 +65,7 @@ class Projection:
@staticmethod
def lnglat() -> Projection: ...
@staticmethod
def speudo_mercator() -> Projection: ...
def pseudo_mercator() -> Projection: ...
@staticmethod
def orthographic(longitude: float, latitude: float) -> Projection: ...

Expand All @@ -76,6 +75,11 @@ _NameType = TypeVar("_NameType", bound=str)
_ScalarReturnType = TypeVar("_ScalarReturnType", bound=Any)
_ArrayReturnDType = TypeVar("_ArrayReturnDType", bound=Any)

# TODO: npt.NDArray[Geography] not supported yet
# (see https://github.com/numpy/numpy/issues/24738)
# (unless Geography is passed via Generic[...], see VFunc below)
T_NDArray_Geography = npt.NDArray[Any]

# The following types are auto-generated. Please don't edit them by hand.
# Instead, update the generate_spherely_vfunc_types.py script and run it
# to update the types.
Expand All @@ -87,7 +91,9 @@ class _VFunc_Nin1_Nout1(Generic[_NameType, _ScalarReturnType, _ArrayReturnDType]
@overload
def __call__(self, geography: Geography) -> _ScalarReturnType: ...
@overload
def __call__(self, geography: npt.ArrayLike) -> npt.NDArray[_ArrayReturnDType]: ...
def __call__(
self, geography: Iterable[Geography]
) -> npt.NDArray[_ArrayReturnDType]: ...

class _VFunc_Nin2_Nout1(Generic[_NameType, _ScalarReturnType, _ArrayReturnDType]):
@property
Expand All @@ -96,15 +102,15 @@ class _VFunc_Nin2_Nout1(Generic[_NameType, _ScalarReturnType, _ArrayReturnDType]
def __call__(self, a: Geography, b: Geography) -> _ScalarReturnType: ...
@overload
def __call__(
self, a: Geography, b: npt.ArrayLike
self, a: Geography, b: Iterable[Geography]
) -> npt.NDArray[_ArrayReturnDType]: ...
@overload
def __call__(
self, a: npt.ArrayLike, b: Geography
self, a: Iterable[Geography], b: Geography
) -> npt.NDArray[_ArrayReturnDType]: ...
@overload
def __call__(
self, a: npt.ArrayLike, b: npt.ArrayLike
self, a: Iterable[Geography], b: Iterable[Geography]
) -> npt.NDArray[_ArrayReturnDType]: ...

class _VFunc_Nin2optradius_Nout1(
Expand All @@ -114,19 +120,19 @@ class _VFunc_Nin2optradius_Nout1(
def __name__(self) -> _NameType: ...
@overload
def __call__(
self, a: Geography, b: Geography, radius: float = ...
self, a: Geography, b: Geography, radius: float = 6371010.0
) -> _ScalarReturnType: ...
@overload
def __call__(
self, a: Geography, b: npt.ArrayLike, radius: float = ...
self, a: Geography, b: Iterable[Geography], radius: float = 6371010.0
) -> npt.NDArray[_ArrayReturnDType]: ...
@overload
def __call__(
self, a: npt.ArrayLike, b: Geography, radius: float = ...
self, a: Iterable[Geography], b: Geography, radius: float = 6371010.0
) -> npt.NDArray[_ArrayReturnDType]: ...
@overload
def __call__(
self, a: npt.ArrayLike, b: npt.ArrayLike, radius: float = ...
self, a: Iterable[Geography], b: Iterable[Geography], radius: float = 6371010.0
) -> npt.NDArray[_ArrayReturnDType]: ...

class _VFunc_Nin1optradius_Nout1(
Expand All @@ -135,10 +141,24 @@ class _VFunc_Nin1optradius_Nout1(
@property
def __name__(self) -> _NameType: ...
@overload
def __call__(self, a: Geography, radius: float = ...) -> _ScalarReturnType: ...
def __call__(
self, a: Geography, radius: float = 6371010.0
) -> _ScalarReturnType: ...
@overload
def __call__(
self, a: npt.ArrayLike, radius: float = ...
self, a: Iterable[Geography], radius: float = 6371010.0
) -> npt.NDArray[_ArrayReturnDType]: ...

class _VFunc_Nin1optprecision_Nout1(
Generic[_NameType, _ScalarReturnType, _ArrayReturnDType]
):
@property
def __name__(self) -> _NameType: ...
@overload
def __call__(self, a: Geography, precision: int = 6) -> _ScalarReturnType: ...
@overload
def __call__(
self, a: Iterable[Geography], precision: int = 6
) -> npt.NDArray[_ArrayReturnDType]: ...

# /// End types
Expand Down Expand Up @@ -188,12 +208,9 @@ def create_collection(geographies: Iterable[Geography]) -> GeometryCollection: .

# Geography creation (vectorized)

@overload
def points(
longitude: npt.ArrayLike, latitude: npt.ArrayLike
) -> npt.NDArray[np.object_]: ...
@overload
def points(longitude: float, latitude: float) -> PointGeography: ... # type: ignore[misc]
) -> PointGeography | T_NDArray_Geography: ...

# Geography utils

Expand Down Expand Up @@ -234,42 +251,63 @@ boundary: _VFunc_Nin1_Nout1[Literal["boundary"], Geography, Geography]
convex_hull: _VFunc_Nin1_Nout1[
Literal["convex_hull"], PolygonGeography, PolygonGeography
]
distance: _VFunc_Nin2optradius_Nout1[Literal["distance"], float, float]
area: _VFunc_Nin1optradius_Nout1[Literal["area"], float, float]
length: _VFunc_Nin1optradius_Nout1[Literal["length"], float, float]
perimeter: _VFunc_Nin1optradius_Nout1[Literal["perimeter"], float, float]
distance: _VFunc_Nin2optradius_Nout1[Literal["distance"], float, np.float64]
area: _VFunc_Nin1optradius_Nout1[Literal["area"], float, np.float64]
length: _VFunc_Nin1optradius_Nout1[Literal["length"], float, np.float64]
perimeter: _VFunc_Nin1optradius_Nout1[Literal["perimeter"], float, np.float64]

# io functions

to_wkt: _VFunc_Nin1_Nout1[Literal["to_wkt"], str, object]
to_wkt: _VFunc_Nin1optprecision_Nout1[Literal["to_wkt"], str, object]
to_wkb: _VFunc_Nin1_Nout1[Literal["to_wkb"], bytes, object]

@overload
def from_wkt(
a: Iterable[str],
a: str,
oriented: bool = False,
planar: bool = False,
tessellate_tolerance: float = 100.0,
) -> npt.NDArray[Any]: ...
) -> Geography: ...
@overload
def from_wkt(
a: list[str] | npt.NDArray[np.str_],
oriented: bool = False,
planar: bool = False,
tessellate_tolerance: float = 100.0,
) -> T_NDArray_Geography: ...
@overload
def from_wkb(
a: bytes,
oriented: bool = False,
planar: bool = False,
tessellate_tolerance: float = 100.0,
) -> Geography: ...
@overload
def from_wkb(
a: Iterable[bytes],
oriented: bool = False,
planar: bool = False,
tessellate_tolerance: float = 100.0,
) -> npt.NDArray[Any]: ...
) -> T_NDArray_Geography: ...

class ArrowSchemaExportable(Protocol):
def __arrow_c_schema__(self) -> object: ...

class ArrowArrayExportable(Protocol):
def __arrow_c_array__(
self, requested_schema: object | None = None
) -> Tuple[object, object]: ...
) -> tuple[object, object]: ...

class ArrowArrayHolder:
benbovy marked this conversation as resolved.
Show resolved Hide resolved
def __arrow_c_array__(
self, requested_schema: object | None = None
) -> tuple[object, object]: ...

def to_geoarrow(
input: npt.ArrayLike,
input: Geography | T_NDArray_Geography,
/,
*,
output_schema: ArrowSchemaExportable | None = None,
output_schema: ArrowSchemaExportable | str | None = None,
projection: Projection = Projection.lnglat(),
planar: bool = False,
tessellate_tolerance: float = 100.0,
Expand All @@ -284,4 +322,4 @@ def from_geoarrow(
tessellate_tolerance: float = 100.0,
projection: Projection = Projection.lnglat(),
geometry_encoding: str | None = None,
) -> npt.NDArray[Any]: ...
) -> T_NDArray_Geography: ...
12 changes: 6 additions & 6 deletions tests/test_accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def test_distance_with_custom_radius() -> None:
assert actual == pytest.approx(np.pi / 2)


def test_area():
def test_area() -> None:
# scalar
geog = spherely.create_polygon([(0, 0), (90, 0), (0, 90), (0, 0)])
result = spherely.area(geog, radius=1)
Expand Down Expand Up @@ -191,11 +191,11 @@ def test_area():
"POLYGON EMPTY",
],
)
def test_area_empty(geog):
def test_area_empty(geog) -> None:
assert spherely.area(spherely.from_wkt(geog)) == 0


def test_length():
def test_length() -> None:
geog = spherely.create_linestring([(0, 0), (1, 0)])
result = spherely.length(geog, radius=1)
assert isinstance(result, float)
Expand All @@ -218,11 +218,11 @@ def test_length():
"POLYGON ((0 0, 0 1, 1 0, 0 0))",
],
)
def test_length_invalid(geog):
def test_length_invalid(geog) -> None:
assert spherely.length(spherely.from_wkt(geog)) == 0.0


def test_perimeter():
def test_perimeter() -> None:
geog = spherely.create_polygon([(0, 0), (0, 90), (90, 90), (90, 0), (0, 0)])
result = spherely.perimeter(geog, radius=1)
assert isinstance(result, float)
Expand All @@ -239,5 +239,5 @@ def test_perimeter():
@pytest.mark.parametrize(
"geog", ["POINT (0 0)", "POINT EMPTY", "LINESTRING (0 0, 1 0)", "POLYGON EMPTY"]
)
def test_perimeter_invalid(geog):
def test_perimeter_invalid(geog) -> None:
assert spherely.perimeter(spherely.from_wkt(geog)) == 0.0
20 changes: 10 additions & 10 deletions tests/test_boolean_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
),
],
)
def test_union(geog1, geog2, expected):
def test_union(geog1, geog2, expected) -> None:
result = spherely.union(spherely.from_wkt(geog1), spherely.from_wkt(geog2))
assert str(result) == expected

Expand All @@ -47,12 +47,12 @@ def test_union_polygon():
),
],
)
def test_intersection(geog1, geog2, expected):
def test_intersection(geog1, geog2, expected) -> None:
result = spherely.intersection(spherely.from_wkt(geog1), spherely.from_wkt(geog2))
assert str(result) == expected


def test_intersection_empty():
def test_intersection_empty() -> None:
result = spherely.intersection(poly1, spherely.from_wkt("POLYGON EMPTY"))
# assert spherely.is_empty(result)
assert str(result) == "GEOMETRYCOLLECTION EMPTY"
Expand All @@ -66,7 +66,7 @@ def test_intersection_empty():
assert str(result) == "GEOMETRYCOLLECTION EMPTY"


def test_intersection_lines():
def test_intersection_lines() -> None:
result = spherely.intersection(
spherely.from_wkt("LINESTRING (-45 0, 45 0)"),
spherely.from_wkt("LINESTRING (0 -10, 0 10)"),
Expand All @@ -75,7 +75,7 @@ def test_intersection_lines():
assert spherely.distance(result, spherely.from_wkt("POINT (0 0)")) == 0


def test_intersection_polygons():
def test_intersection_polygons() -> None:
result = spherely.intersection(poly1, poly2)
# TODO precision could be higher with snap level
precision = 2 if Version(spherely.__s2geography_version__) < Version("0.2.0") else 1
Expand All @@ -85,7 +85,7 @@ def test_intersection_polygons():
)


def test_intersection_polygon_model():
def test_intersection_polygon_model() -> None:
poly = spherely.from_wkt("POLYGON ((0 0, 10 0, 10 10, 0 10, 0 0))")
point = spherely.from_wkt("POINT (0 0)")

Expand All @@ -107,12 +107,12 @@ def test_intersection_polygon_model():
),
],
)
def test_difference(geog1, geog2, expected):
def test_difference(geog1, geog2, expected) -> None:
result = spherely.difference(spherely.from_wkt(geog1), spherely.from_wkt(geog2))
assert spherely.equals(result, spherely.from_wkt(expected))


def test_difference_polygons():
def test_difference_polygons() -> None:
result = spherely.difference(poly1, poly2)
expected_near = spherely.area(poly1) - spherely.area(
spherely.from_wkt("POLYGON ((5 5, 10 5, 10 10, 5 10, 5 5))")
Expand All @@ -133,14 +133,14 @@ def test_difference_polygons():
),
],
)
def test_symmetric_difference(geog1, geog2, expected):
def test_symmetric_difference(geog1, geog2, expected) -> None:
result = spherely.symmetric_difference(
spherely.from_wkt(geog1), spherely.from_wkt(geog2)
)
assert spherely.equals(result, spherely.from_wkt(expected))


def test_symmetric_difference_polygons():
def test_symmetric_difference_polygons() -> None:
result = spherely.symmetric_difference(poly1, poly2)
expected_near = 2 * (
spherely.area(poly1)
Expand Down
Loading
Loading