Feature/more misc issues #91

Draft · wants to merge 6 commits into base: staging
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -5,6 +5,6 @@ repos:
       - id: black
         language_version: python3.10
   - repo: https://github.com/pycqa/pylint
-    rev: pylint-2.6.0
+    rev: v2.14.4
     hooks:
       - id: pylint
122 changes: 58 additions & 64 deletions geograph/geograph.py
@@ -422,48 +422,57 @@ def _load_from_dataframe(

         # Reset index to ensure consistent indices
         df = df.reset_index(drop=True)
-        # Using this list and iterating through it is slightly faster than
-        # iterating through df due to the dataframe overhead
-        geom: List[shapely.Polygon] = df["geometry"].tolist()
-        # this dict maps polygon row numbers in df to a list
-        # of neighbouring polygon row numbers
-        graph_dict = {}
-
-        if tolerance > 0:
-            # Expand the borders of the polygons by `tolerance`
-            new_polygons: List[shapely.Polygon] = (
-                df["geometry"].buffer(tolerance).tolist()
-            )
+        # pylint: disable=protected-access
+        if gpd._compat.USE_PYGEOS:
+            if tolerance > 0:
+                neighbour_arr = df.sindex.query_bulk(
+                    df["geometry"].buffer(tolerance), predicate="intersects"
+                ).transpose()
+            else:
+                neighbour_arr = df.sindex.query_bulk(
+                    df["geometry"], predicate="intersects"
+                ).transpose()
+        else:
+            # this dict maps polygon row numbers in df to a list
+            # of neighbouring polygon row numbers
+            graph_dict = {}
+            if tolerance > 0:
+                # Expand the borders of the polygons by `tolerance`
+                new_polygons: gpd.GeoSeries = df["geometry"].buffer(tolerance)
         # Creating nodes (=vertices) and finding neighbors
         for index, polygon in tqdm(
-            enumerate(geom),
+            enumerate(df["geometry"]),
             desc="Step 1 of 2: Creating nodes and finding neighbours",
-            total=len(geom),
+            total=len(df),
         ):
-            if tolerance > 0:
-                # find the indexes of all polygons which intersect with this one
-                neighbours = df.sindex.query(
-                    new_polygons[index], predicate="intersects"
-                )
-            else:
-                neighbours = df.sindex.query(polygon, predicate="intersects")
-
-            graph_dict[index] = neighbours
-            # add each polygon as a node to the graph with useful attributes
-            self.graph.add_node(
-                index,
-                rep_point=polygon.representative_point(),
-                area=polygon.area,
-                perimeter=polygon.length,
-                bounds=polygon.bounds,
-            )
+            # pylint: disable=protected-access
+            if not gpd._compat.USE_PYGEOS:
+                if tolerance > 0:
+                    # find the indexes of all polygons which intersect with this one
+                    neighbours = df.sindex.query(
+                        new_polygons[index], predicate="intersects"
+                    )
+                else:
+                    neighbours = df.sindex.query(polygon, predicate="intersects")
+
+                graph_dict[index] = neighbours
+            # TODO: factor out for use_pygeos
+            self.graph.add_node(index)
         # iterate through the dict and add edges between neighbouring polygons
-        for polygon_id, neighbours in tqdm(
-            graph_dict.items(), desc="Step 2 of 2: Adding edges"
-        ):
-            for neighbour_id in neighbours:
-                if polygon_id != neighbour_id:
-                    self.graph.add_edge(polygon_id, neighbour_id)
+        # pylint: disable=protected-access
+        if gpd._compat.USE_PYGEOS:
+            for index, neighbour in tqdm(
+                neighbour_arr, desc="Step 2 of 2: Adding edges"
+            ):
+                if index != neighbour:
+                    self.graph.add_edge(index, neighbour)
+        else:
+            for polygon_id, neighbours in tqdm(
+                graph_dict.items(), desc="Step 2 of 2: Adding edges"
+            ):
+                for neighbour_id in neighbours:
+                    if polygon_id != neighbour_id:
+                        self.graph.add_edge(polygon_id, neighbour_id)

         # add index name
         df.index.name = "node_index"
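The heart of the new fast path is `sindex.query_bulk`, which finds all intersecting pairs in one vectorised call instead of one `query` per polygon. A minimal standalone sketch of that primitive on hypothetical toy squares (in geopandas 0.12 and later the same operation is `sindex.query` with an array argument):

    import geopandas as gpd
    from shapely.geometry import box

    # Three unit squares: the first two share an edge, the third is far away.
    df = gpd.GeoDataFrame(geometry=[box(0, 0, 1, 1), box(1, 0, 2, 1), box(5, 5, 6, 6)])

    # query_bulk returns a 2 x N array of (query index, tree index) pairs;
    # transposed, each row is one intersecting (i, j) pair, self-pairs included.
    pairs = df.sindex.query_bulk(df["geometry"], predicate="intersects").transpose()

    for i, j in pairs:
        if i != j:  # drop self-intersections, as the edge-adding loop does
            print(f"polygon {i} intersects polygon {j}")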
@@ -802,7 +811,9 @@ def get_graph_components(
         """
         components: List[set] = list(nx.connected_components(self.graph))
         if calc_polygons:
-            geom = [self.df["geometry"].loc[comp].unary_union for comp in components]
+            geom = [
+                self.df["geometry"].loc[list(comp)].unary_union for comp in components
+            ]
             gdf = gpd.GeoDataFrame(
                 {"geometry": geom, "class_label": -1}, crs=self.df.crs
             )
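The `list(comp)` conversion matters because `nx.connected_components` yields plain sets, and recent pandas rejects a set as a `.loc` indexer. A tiny illustration with made-up data:

    import pandas as pd

    s = pd.Series(["a", "b", "c"])
    comp = {0, 2}  # a component, in the set form networkx yields
    print(s.loc[list(comp)])  # works
    # s.loc[comp] raises TypeError on current pandas: sets are not valid indexers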
@@ -850,7 +861,7 @@ def get_metric(
             result = metrics._get_metric(
                 name=name, geo_graph=self, class_value=class_value, **metric_kwargs
             )
-            if name in self.class_metrics.keys():
+            if name in self.class_metrics:
                 self.class_metrics[name][class_value] = result
             else:
                 self.class_metrics[name] = {class_value: result}
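Dropping `.keys()` here is pylint's `consider-iterating-dictionary` idiom; `in` on a dict already tests keys, so behaviour is unchanged:

    class_metrics = {"num_patches": {}}
    assert ("num_patches" in class_metrics) == ("num_patches" in class_metrics.keys())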
@@ -968,14 +979,7 @@ def _add_node(
         node_data = dict(data.items())

         # Add node to graph
-        self.graph.add_node(
-            node_id,
-            rep_point=node_data["geometry"].representative_point(),
-            area=node_data["geometry"].area,
-            perimeter=node_data["geometry"].length,
-            class_label=node_data["class_label"],
-            bounds=node_data["geometry"].bounds,
-        )
+        self.graph.add_node(node_id)

         # Add node data to dataframe
         missing_cols = {
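With the slimmed-down `add_node`, per-node attributes are no longer duplicated on the graph, so callers would derive them from the dataframe's geometry column instead. A standalone sketch of that pattern on a toy frame (column and index names assumed to mirror the class's conventions):

    import geopandas as gpd
    from shapely.geometry import box

    df = gpd.GeoDataFrame({"class_label": [1], "geometry": [box(0, 0, 2, 1)]})
    df.index.name = "node_index"

    polygon = df.loc[0, "geometry"]  # node 0's geometry, looked up in the frame
    print(polygon.area, polygon.length, polygon.bounds)  # 2.0 6.0 (0.0, 0.0, 2.0, 1.0)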
@@ -1121,25 +1125,25 @@ class label in `valid_classes`, as long as they are less than
             f"and {self.graph.number_of_edges()} edges.",
         )

-    def _load_from_graph_path(self, load_path: pathlib.Path) -> None:
+    def _load_from_graph_path(self, graph_path: pathlib.Path) -> None:
         """
         Load networkx graph and dataframe objects from a pickle file.

         Args:
-            load_path (pathlib.Path): Path to a pickle file. Can be compressed
+            graph_path (pathlib.Path): Path to a pickle file. Can be compressed
                 with gzip or bz2.

         Returns:
             gpd.GeoDataFrame: The dataframe containing polygon objects.
         """
-        if load_path.suffix == ".bz2":
-            with bz2.BZ2File(load_path, "rb") as bz2_file:
+        if graph_path.suffix == ".bz2":
+            with bz2.BZ2File(graph_path, "rb") as bz2_file:
                 data = pickle.load(bz2_file)
-        elif load_path.suffix == ".gz":
-            with gzip.GzipFile(load_path, "rb") as gz_file:
+        elif graph_path.suffix == ".gz":
+            with gzip.GzipFile(graph_path, "rb") as gz_file:
                 data = pickle.loads(gz_file.read())
         else:
-            with open(load_path, "rb") as file:
+            with open(graph_path, "rb") as file:
                 data = pickle.load(file)
         self.df = data["dataframe"]
         self.name = data["name"]
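For reference, a minimal round trip that matches the suffix dispatch above, with an illustrative payload only (the real pickles also carry the graph and further metadata):

    import bz2
    import pathlib
    import pickle

    payload = {"dataframe": None, "name": "demo"}  # stand-in for the saved dict
    path = pathlib.Path("graph_demo.bz2")
    with bz2.BZ2File(path, "wb") as bz2_file:
        pickle.dump(payload, bz2_file)
    with bz2.BZ2File(path, "rb") as bz2_file:
        assert pickle.load(bz2_file)["name"] == "demo"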
@@ -1263,16 +1267,6 @@ def _load_from_dataframe(
             self.graph = nx.complete_graph(len(df))
         else:
             self.graph = nx.empty_graph(len(df))
-        # Add node attributes
-        for node in tqdm(
-            self.graph.nodes, desc="Constructing graph", total=len(self.graph)
-        ):
-            polygon = geom[node]
-            self.graph.nodes[node]["rep_point"] = polygon.representative_point()
-            self.graph.nodes[node]["area"] = polygon.area
-            self.graph.nodes[node]["perimeter"] = polygon.length
-            self.graph.nodes[node]["bounds"] = polygon.bounds
-
         # Add edge attributes if necessary
         if self.has_distance_edges:
             for u, v, attrs in tqdm(
51 changes: 48 additions & 3 deletions geograph/metrics.py
@@ -2,15 +2,18 @@
 from __future__ import annotations

 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Optional, Union
+from itertools import combinations
+from typing import TYPE_CHECKING, Any, Dict, Optional, Union

 import networkx as nx
 import numpy as np

 if TYPE_CHECKING:
     import geograph


+# TODO: refactor this file to return a tuple with the values to put inside the
+# metric such that the metric is only created once by the calling GeoGraph
+# since dataclass creation is slow
 # define a metric dataclass with < <= => > == comparisons that work as you would
 # expect intuitively
 @dataclass()
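The premise of the new TODO, that building a `Metric` dataclass per call costs more than returning a bare tuple, is easy to sanity-check with a rough micro-benchmark (numbers vary by machine and interpreter):

    import timeit
    from dataclasses import dataclass

    @dataclass
    class Metric:  # trimmed stand-in for the module's Metric
        value: float
        name: str

    t_dataclass = timeit.timeit(lambda: Metric(1.0, "num_patches"), number=100_000)
    t_tuple = timeit.timeit(lambda: (1.0, "num_patches"), number=100_000)
    print(f"dataclass: {t_dataclass:.3f}s, tuple: {t_tuple:.3f}s")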
@@ -52,6 +55,8 @@ def __ge__(self, o: object) -> bool:
 ########################################################################################
 # 1. Landscape level metrics
 ########################################################################################
+
+
 def _num_patches(geo_graph: geograph.GeoGraph) -> Metric:
     """
     Calculate number of patches.
@@ -562,7 +567,7 @@ def _class_effective_mesh_size(
 }

 ########################################################################################
-# 3. Habitat componment level metrics
+# 3. Habitat component level metrics
 ########################################################################################

@@ -655,17 +660,57 @@ def _avg_component_isolation(geo_graph: geograph.GeoGraph) -> Metric:
     )


+def _habitat_iic(
+    geo_graph: geograph.GeoGraph,
+    get_total_area: bool = False,
+    shortest_path_cutoff: Optional[int] = None,
+) -> Metric:
+    if get_total_area:
+        # Most efficient way to get the area of the convex hull of the GeoGraph
+        total_area = geo_graph.components.df.dissolve().convex_hull.values[0].area
+    iic = 0.0
+    idx_dict = dict(zip(geo_graph.df.index.values, range(len(geo_graph.df))))
+    path_lengths: Dict = dict(
+        nx.all_pairs_shortest_path_length(geo_graph.graph, cutoff=shortest_path_cutoff)
+    )
+    for x in combinations(geo_graph.df.index.values, 2):
+        if x[1] not in path_lengths[x[0]]:
+            continue
+        iic += (
+            geo_graph.graph.nodes[x[0]]["area"] * geo_graph.graph.nodes[x[1]]["area"]
+        ) / (1 + path_lengths[x[0]][x[1]])
+    if get_total_area:
+        # for node in geo_graph.graph.nodes:
+        #     iic += geo_graph.graph.nodes[node]["area"] ** 2
+        return Metric(
+            value=iic / total_area,
+            name="habitat_iic",
+            description="The habitat IIC metric",
+            variant="component",
+            unit="dimensionless",
+        )
+    return Metric(
+        value=iic,
+        name="habitat_iic",
+        description="The habitat IIC metric",
+        variant="component",
+        unit="dimensionless",
+    )
+
+
 COMPONENT_METRICS_DICT = {
     "num_components": _num_components,
     "avg_component_area": _avg_component_area,
     "avg_component_isolation": _avg_component_isolation,
+    "habitat_iic": _habitat_iic,
 }


 ########################################################################################
 # 4. Define access interface for GeoGraph
 ########################################################################################


 STANDARD_METRICS = ["num_components", "avg_patch_area", "total_area"]
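For context, `_habitat_iic` follows the shape of the Integral Index of Connectivity (IIC) of Pascual-Hortal and Saura (2006), usually written as

\[
\mathrm{IIC} = \frac{1}{A_L^{2}} \sum_{i=1}^{n} \sum_{j=1}^{n} \frac{a_i\, a_j}{1 + nl_{ij}}
\]

where \(a_i\) is the area of patch \(i\), \(nl_{ij}\) is the number of links on the shortest path between patches \(i\) and \(j\), and \(A_L\) is the total landscape area. Two points in the code above are worth flagging when comparing values against other tools: the sum runs over unordered pairs only, with the \(i = j\) terms omitted (cf. the commented-out block), and with `get_total_area` the result is normalised by the convex-hull area once rather than squared. Pairs in different components, or separated by more than `shortest_path_cutoff` links, contribute nothing because of the `continue`.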
12 changes: 0 additions & 12 deletions pylintrc
@@ -160,12 +160,6 @@ disable=abstract-method,
 # mypackage.mymodule.MyReporterClass.
 output-format=text

-# Put messages in a separate file for each module / package specified on the
-# command line instead of printing them on stdout. Reports (if any) will be
-# written in a file name "pylint_global.[txt|html]". This option is deprecated
-# and it will be removed in Pylint 2.0.
-files-output=no
-
 # Tells whether to display a full report or only the messages
 reports=no

@@ -284,12 +278,6 @@ ignore-long-lines=(?x)(
 # else.
 single-line-if-stmt=yes

-# List of optional constructs for which whitespace checking is disabled. `dict-
-# separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}.
-# `trailing-comma` allows a space between comma and closing bracket: (a, ).
-# `empty-line` allows space-only lines.
-no-space-check=
-
 # Maximum number of lines in a module
 max-module-lines=99999