Skip to content

Commit

Permalink
Black
Browse files Browse the repository at this point in the history
  • Loading branch information
eberrigan committed Apr 23, 2024
1 parent dd4efaf commit 5033de5
Showing 1 changed file with 17 additions and 3 deletions.
20 changes: 17 additions & 3 deletions tests/test_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@
from shapely.geometry import Point, MultiPoint, LineString, GeometryCollection
from sleap_roots import Series
from sleap_roots.lengths import get_max_length_pts
from sleap_roots.points import extract_points_from_geometry, filter_plants_with_unexpected_ct, get_count, join_pts
from sleap_roots.points import (
extract_points_from_geometry,
filter_plants_with_unexpected_ct,
get_count,
join_pts,
)
from sleap_roots.points import (
get_all_pts_array,
get_nodes,
Expand Down Expand Up @@ -745,39 +750,48 @@ def test_extract_from_point():
expected = [np.array([1, 2])]
assert np.array_equal(extract_points_from_geometry(point), expected)


def test_extract_from_multipoint():
multipoint = MultiPoint([(1, 2), (3, 4)])
expected = [np.array([1, 2]), np.array([3, 4])]
results = extract_points_from_geometry(multipoint)
assert all(np.array_equal(result, exp) for result, exp in zip(results, expected))


def test_extract_from_linestring():
linestring = LineString([(0, 0), (1, 1), (2, 2)])
expected = [np.array([0, 0]), np.array([1, 1]), np.array([2, 2])]
results = extract_points_from_geometry(linestring)
assert all(np.array_equal(result, exp) for result, exp in zip(results, expected))


def test_extract_from_geometrycollection():
geom_collection = GeometryCollection([Point(1, 2), LineString([(0, 0), (1, 1)])])
expected = [np.array([1, 2]), np.array([0, 0]), np.array([1, 1])]
results = extract_points_from_geometry(geom_collection)
assert all(np.array_equal(result, exp) for result, exp in zip(results, expected))


def test_extract_from_empty_multipoint():
empty_multipoint = MultiPoint()
expected = []
assert extract_points_from_geometry(empty_multipoint) == expected


def test_extract_from_empty_linestring():
empty_linestring = LineString()
expected = []
assert extract_points_from_geometry(empty_linestring) == expected


def test_extract_from_unsupported_type():
with pytest.raises(NameError):
extract_points_from_geometry(Polygon([(0, 0), (1, 1), (1, 0)])) # Polygon is unsupported
extract_points_from_geometry(
Polygon([(0, 0), (1, 1), (1, 0)])
) # Polygon is unsupported


def test_extract_from_empty_geometrycollection():
empty_geom_collection = GeometryCollection()
expected = []
assert extract_points_from_geometry(empty_geom_collection) == expected
assert extract_points_from_geometry(empty_geom_collection) == expected

0 comments on commit 5033de5

Please sign in to comment.