Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Add support for spatial aggregation arguments to exactextract #159

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cropclassification/preprocess/_timeseries_calc_openeo.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@
id_column=conf.columns["id"],
rasters_bands=images_bands,
output_dir=timeseries_periodic_dir,
stats=["count", "mean", "median", "std", "min", "max"], # type: ignore[arg-type]
stats=["count", "mean", "median", "std", "min", "max"],

Check warning on line 82 in cropclassification/preprocess/_timeseries_calc_openeo.py

View check run for this annotation

Codecov / codecov/patch

cropclassification/preprocess/_timeseries_calc_openeo.py#L82

Added line #L82 was not covered by tests
engine="pyqgis",
nb_parallel=nb_parallel,
)
Original file line number Diff line number Diff line change
Expand Up @@ -275,10 +275,15 @@ def zonal_stats_band_tofile(
for band in bands:
index = raster_info.bands[band].band_index
band_columns = include_cols.copy()
band_columns.extend([f"band_{index}_{stat}" for stat in stats])
band_columns.extend(
[f"band_{index}_{stat}" for stat in [stat.split("(")[0] for stat in stats]]
)
band_stats_df = stats_df[band_columns].copy()
band_stats_df.rename(
columns={f"band_{index}_{stat}": stat for stat in stats},
columns={
f"band_{index}_{stat}": stat
for stat in [stat.split("(")[0] for stat in stats]
},
inplace=True,
)
# Add fid column to the beginning of the dataframe
Expand All @@ -289,6 +294,19 @@ def zonal_stats_band_tofile(
f"Write data for {len(band_stats_df.index)} parcels found to {output_paths[band]}" # noqa: E501
)
if not output_paths[band].exists():
pdh.to_file(band_stats_df, output_paths[band], index=False)
# Write the info table to the output file
pdh.to_file(df=band_stats_df, path=output_paths[band], index=False)

# Write the parameters table to the output file
spatial_aggregation_args_df = pd.DataFrame(
data=stats, columns=["stats"]
)
pdh.to_file(
df=spatial_aggregation_args_df,
path=output_paths[band],
table_name="params",
index=False,
append=True,
)

return output_paths
53 changes: 50 additions & 3 deletions tests/test_zonal_stats_bulk.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,28 @@
import geofileops as gfo
import pytest

from cropclassification.helpers import config_helper as conf
from cropclassification.helpers import pandas_helper as pdh
from cropclassification.util import zonal_stats_bulk
from cropclassification.util.zonal_stats_bulk._zonal_stats_bulk_pyqgis import HAS_QGIS
from tests.test_helper import SampleData


@pytest.mark.parametrize("engine", ["pyqgis", "rasterstats", "exactextract"])
def test_zonal_stats_bulk(tmp_path, engine):
@pytest.mark.parametrize(
"engine, stats",
[
("pyqgis", ["mean", "count"]),
("rasterstats", ["mean", "count"]),
(
"exactextract",
[
"mean(min_coverage_frac=0.5,coverage_weight=none)",
"count(min_coverage_frac=0.5,coverage_weight=none)",
],
),
],
)
def test_zonal_stats_bulk(tmp_path, engine, stats):
if engine == "pyqgis" and not HAS_QGIS:
pytest.skip("QGIS is not available on this system.")

Expand All @@ -19,6 +33,16 @@ def test_zonal_stats_bulk(tmp_path, engine):
test_dir = tmp_path / sample_dir.name
shutil.copytree(sample_dir, test_dir)

# Read the configuration
config_paths = [
SampleData.config_dir / "cropgroup.ini",
SampleData.tasks_dir / "local_overrule.ini",
]
conf.read_config(
config_paths=config_paths,
default_basedir=SampleData.marker_basedir,
)

# Make sure the s2-agri input file was copied
test_image_roi_dir = test_dir / SampleData.image_dir.name / SampleData.roi_name
test_s1_asc_dir = test_image_roi_dir / SampleData.image_s1_asc_path.parent.name
Expand All @@ -34,7 +58,7 @@ def test_zonal_stats_bulk(tmp_path, engine):
id_column="UID",
rasters_bands=images_bands,
output_dir=tmp_path,
stats=["mean", "count"],
stats=stats,
engine=engine,
)

Expand All @@ -46,3 +70,26 @@ def test_zonal_stats_bulk(tmp_path, engine):
assert len(result_df) == vector_info.featurecount
# The calculates stats should not be nan for any row.
assert not any(result_df["mean"].isna())


# def test_spatial_aggregations():
# # Read the configuration
# config_paths = [
# SampleData.config_dir / "cropgroup.ini",
# SampleData.tasks_dir / "local_overrule.ini",
# ]
# conf.read_config(
# config_paths=config_paths,
# default_basedir=SampleData.marker_basedir,
# )
# spatial_aggregations = conf.timeseries.getlist("spatial_aggregations")
# spatial_aggregation_args = conf.timeseries.getdict("spatial_aggregation_args")

# ops = get_ops(
# {
# "spatial_aggregations": spatial_aggregations,
# "spatial_aggregation_args": spatial_aggregation_args,
# }
# )

# assert len(ops) == 6