diff --git a/src/unified_graphics/diag.py b/src/unified_graphics/diag.py index 4133e14b..1701bba6 100644 --- a/src/unified_graphics/diag.py +++ b/src/unified_graphics/diag.py @@ -2,17 +2,11 @@ from collections import namedtuple from datetime import datetime, timedelta from enum import Enum -from typing import Union -from urllib.parse import urlparse import numpy as np import pandas as pd import sqlalchemy as sa -import xarray as xr -import zarr # type: ignore -from s3fs import S3FileSystem, S3Map # type: ignore from werkzeug.datastructures import MultiDict -from xarray.core.dataset import Dataset from .models import Analysis, WeatherModel @@ -88,242 +82,75 @@ def get_model_metadata(session) -> ModelMetadata: ) -def get_store(url: str) -> Union[str, S3Map]: - result = urlparse(url) - if result.scheme in ["", "file"]: - return result.path - - if result.scheme != "s3": - raise ValueError(f"Unsupported protocol '{result.scheme}' for URI: '{url}'") - - region = os.environ.get("AWS_REGION", "us-east-1") - s3 = S3FileSystem( - key=os.environ.get("AWS_ACCESS_KEY_ID"), - secret=os.environ.get("AWS_SECRET_ACCESS_KEY"), - token=os.environ.get("AWS_SESSION_TOKEN"), - client_kwargs={"region_name": region}, - ) - - return S3Map(root=f"{result.netloc}{result.path}", s3=s3, check=False) - - -def open_diagnostic( - diag_zarr: str, - model: str, - system: str, - domain: str, - background: str, - frequency: str, - variable: Variable, - initialization_time: str, - loop: MinimLoop, -) -> xr.Dataset: - store = get_store(diag_zarr) - group = ( - f"/{model}/{system}/{domain}/{background}/{frequency}" - f"/{variable.value}/{initialization_time}/{loop.value}" - ) - return xr.open_zarr(store, group=group, consolidated=False) - - -def parse_filter_value(value): - if value == "true": - return 1 - - if value == "false": - return 0 - - try: - return float(value) - except ValueError: - return value - - -# TODO: Refactor to a class -# I think this might belong in a different module. It could be a class or set of classes -# that represent different filters that can be added together into a filtering pipeline -def get_bounds(filters: MultiDict): - for coord, value in filters.items(): - extent = np.array( - [ - [parse_filter_value(digit) for digit in pair.split(",")] - for pair in value.split("::") - ] - ) - yield coord, extent.min(axis=0), extent.max(axis=0) - - -def apply_filters(dataset: xr.Dataset, filters: MultiDict) -> Dataset: - for coord, lower, upper in get_bounds(filters): - data_array = dataset[coord] - dataset = dataset.where((data_array >= lower) & (data_array <= upper)).dropna( - dim="nobs" - ) - - # If the is_used filter is not passed, our default behavior is to include only used - # observations. - if "is_used" not in filters: - dataset = dataset.where(dataset["is_used"]).dropna(dim="nobs") - - return dataset - - -def scalar( - diag_zarr: str, - model: str, - system: str, - domain: str, - background: str, - frequency: str, - variable: Variable, - initialization_time: str, - loop: MinimLoop, - filters: MultiDict, -) -> pd.DataFrame: - data = open_diagnostic( - diag_zarr, - model, - system, - domain, - background, - frequency, - variable, - initialization_time, - loop, +def magnitude(dataset: pd.DataFrame) -> pd.DataFrame: + return dataset.groupby(level=0).aggregate( + { + "obs_minus_forecast_adjusted": np.linalg.norm, + "obs_minus_forecast_unadjusted": np.linalg.norm, + "observation": np.linalg.norm, + "longitude": "first", + "latitude": "first", + } ) - data = apply_filters(data, filters) - - return data.to_dataframe() -def temperature( - diag_zarr: str, +def diag_observations( model: str, system: str, domain: str, background: str, frequency: str, - initialization_time: str, - loop: MinimLoop, - filters: MultiDict, + variable: str, + init_time: datetime, + loop: str, + uri: str, + filters: dict = {}, ) -> pd.DataFrame: - return scalar( - diag_zarr, - model, - system, - domain, - background, - frequency, - Variable.TEMPERATURE, - initialization_time, - loop, - filters, - ) + def matches(df: pd.DataFrame, filters: dict) -> bool: + result = True + for col_name, filter_value in filters.items(): + arr = np.array(filter_value) + result &= (df[col_name] >= arr.min(axis=0)).all() + result &= (df[col_name] <= arr.max(axis=0)).all() -def moisture( - diag_zarr: str, - model: str, - system: str, - domain: str, - background: str, - frequency: str, - initialization_time: str, - loop: MinimLoop, - filters: MultiDict, -) -> pd.DataFrame: - return scalar( - diag_zarr, - model, - system, - domain, - background, - frequency, - Variable.MOISTURE, - initialization_time, - loop, - filters, - ) + if "is_used" not in filters: + result &= df["is_used"].any() + return result -def pressure( - diag_zarr: str, - model: str, - system: str, - domain: str, - background: str, - frequency: str, - initialization_time: str, - loop: MinimLoop, - filters: MultiDict, -) -> pd.DataFrame: - return scalar( - diag_zarr, - model, - system, - domain, - background, - frequency, - Variable.PRESSURE, - initialization_time, - loop, - filters, - ) - + model_config = "_".join((model, background, system, domain, frequency)) -def wind( - diag_zarr: str, - model: str, - system: str, - domain: str, - background: str, - frequency: str, - initialization_time: str, - loop: MinimLoop, - filters: MultiDict, -) -> pd.DataFrame | pd.Series: - data = open_diagnostic( - diag_zarr, - model, - system, - domain, - background, - frequency, - Variable.WIND, - initialization_time, - loop, + df = pd.read_parquet( + "/".join((uri, model_config, variable)), + columns=[ + "obs_minus_forecast_adjusted", + "obs_minus_forecast_unadjusted", + "observation", + "latitude", + "longitude", + "is_used", + ], + filters=( + ("loop", "=", loop), + ("initialization_time", "=", init_time), + ), ) - data = apply_filters(data, filters) - - return data.to_dataframe() - - -def magnitude(dataset: pd.DataFrame) -> pd.DataFrame: - return dataset.groupby(level=0).aggregate( - { - "obs_minus_forecast_adjusted": np.linalg.norm, - "obs_minus_forecast_unadjusted": np.linalg.norm, - "observation": np.linalg.norm, - "longitude": "first", - "latitude": "first", - } - ) + # Group the rows of the DataFrame by nobs (effectively the observation ID) and test + # each group against our filters. This is necessary because we use a MultiIndex for + # vectors where the second level of the index is the vector component. If we don't + # group the components like this, we run the risk that one component matches the + # filters and the other doesn't, leaving us with a partial observation. + matching_obs = [obs for _, obs in df.groupby("nobs") if matches(obs, filters)] + # If no observations match the filters, return an empty DataFrame by masking out all + # the values in the DataFrame using a list of repeated False values + if len(matching_obs) < 1: + return df[[False] * len(df)] -def get_model_run_list( - diag_zarr: str, - model: str, - system: str, - domain: str, - background: str, - frequency: str, - variable: Variable, -): - store = get_store(diag_zarr) - path = "/".join([model, system, domain, background, frequency, variable.value]) - with zarr.open_group(store, mode="r", path=path) as group: - return group.group_keys() + # Otherwise concatenate the matching DataFrames back into a single DataFrame + return pd.concat(matching_obs) def history( diff --git a/src/unified_graphics/routes.py b/src/unified_graphics/routes.py index c4033820..eccbe2fb 100644 --- a/src/unified_graphics/routes.py +++ b/src/unified_graphics/routes.py @@ -10,6 +10,7 @@ stream_template, url_for, ) +from werkzeug.datastructures import MultiDict from zarr.errors import FSPathExistNotDir, GroupNotFoundError # type: ignore from unified_graphics import diag @@ -18,6 +19,27 @@ bp = Blueprint("api", __name__) +def parse_filters(query: MultiDict) -> dict: + def parse_value(value): + if "::" in value: + return [parse_value(tok) for tok in value.split("::")] + + if value in ["true", "false"]: + return value == "true" + + try: + return float(value) + except ValueError: + return value + + filters = {} + for col, value_list in query.lists(): + val = [parse_value(val) for val in value_list] + filters[col] = val if len(val) > 1 else val[0] + + return filters + + @bp.errorhandler(GroupNotFoundError) def handle_diag_group_not_found(e): current_app.logger.exception("Unable to read diagnostic group") @@ -167,22 +189,19 @@ def history(model, system, domain, background, frequency, variable, loop): def diagnostics( model, system, domain, background, frequency, variable, initialization_time, loop ): - try: - v = diag.Variable(variable) - except ValueError: - return jsonify(msg=f"Variable not found: '{variable}'"), 404 - - variable_diagnostics = getattr(diag, v.name.lower()) - data = variable_diagnostics( - current_app.config["DIAG_ZARR"], + filters = parse_filters(request.args) + + data = diag.diag_observations( model, system, domain, background, frequency, - initialization_time, - diag.MinimLoop(loop), - request.args, + variable, + datetime.fromisoformat(initialization_time), + loop, + current_app.config["DIAG_PARQUET"], + filters, )[ [ "obs_minus_forecast_adjusted", @@ -207,22 +226,19 @@ def diagnostics( def magnitude( model, system, domain, background, frequency, variable, initialization_time, loop ): - try: - v = diag.Variable(variable) - except ValueError: - return jsonify(msg=f"Variable not found: '{variable}'"), 404 - - variable_diagnostics = getattr(diag, v.name.lower()) - data = variable_diagnostics( - current_app.config["DIAG_ZARR"], + filters = parse_filters(request.args) + + data = diag.diag_observations( model, system, domain, background, frequency, - initialization_time, - diag.MinimLoop(loop), - request.args, + variable, + datetime.fromisoformat(initialization_time), + loop, + current_app.config["DIAG_PARQUET"], + filters, )[ [ "obs_minus_forecast_adjusted", diff --git a/tests/test_diag.py b/tests/test_diag.py index 86c041cd..e1acd84f 100644 --- a/tests/test_diag.py +++ b/tests/test_diag.py @@ -1,22 +1,16 @@ -import uuid from datetime import datetime from functools import partial import numpy as np import pandas as pd import pytest -import xarray as xr from botocore.session import Session from moto.server import ThreadedMotoServer -from s3fs import S3FileSystem, S3Map from werkzeug.datastructures import MultiDict from unified_graphics import diag from unified_graphics.models import Analysis, WeatherModel -# Global resources for s3 -test_bucket_name = "osti-modeling-dev-rtma-vis" - @pytest.fixture def aws_credentials(monkeypatch): @@ -48,11 +42,6 @@ def s3_client(aws_credentials, moto_server): return session.create_client("s3", endpoint_url=moto_server) -@pytest.fixture -def test_key_prefix(): - return f"/test/{uuid.uuid4()}/" - - def test_get_model_metadata(session): model_run_list = [ ("RTMA", "WCOSS", "CONUS", "REALTIME", "HRRR", "2023-03-17T14:00"), @@ -85,181 +74,6 @@ def test_get_model_metadata(session): ) -@pytest.mark.parametrize( - "uri,expected", - [ - ("file:///tmp/diag.zarr", "/tmp/diag.zarr"), - ("/tmp/diag.zarr", "/tmp/diag.zarr"), - ], -) -def test_get_store_file(uri, expected): - result = diag.get_store(uri) - - assert result == expected - - -def test_get_store_s3(moto_server, s3_client, monkeypatch): - client = {"region_name": "us-east-1"} - uri = "s3://bucket/prefix/diag.zarr" - s3_client.create_bucket(Bucket="bucket") - s3_client.put_object(Bucket="bucket", Body=b"Test object", Key="prefix/diag.zarr") - - monkeypatch.setattr( - diag, - "S3FileSystem", - partial(diag.S3FileSystem, endpoint_url=moto_server), - ) - - result = diag.get_store(uri) - - assert result == S3Map( - root=uri, - s3=S3FileSystem( - client_kwargs=client, - endpoint_url=moto_server, - ), - check=False, - ) - - -def test_open_diagnostic(tmp_path, test_dataset): - diag_zarr_file = str(tmp_path / "test_diag.zarr") - expected = test_dataset() - group = "/".join( - ( - expected.model, - expected.system, - expected.domain, - expected.background, - expected.frequency, - expected.name, - expected.initialization_time, - expected.loop, - ) - ) - - expected.to_zarr(diag_zarr_file, group=group, consolidated=False) - - result = diag.open_diagnostic( - diag_zarr_file, - expected.model, - expected.system, - expected.domain, - expected.background, - expected.frequency, - diag.Variable(expected.name), - expected.initialization_time, - diag.MinimLoop(expected.loop), - ) - - xr.testing.assert_equal(result, expected) - - -@pytest.mark.parametrize( - "uri,expected", - [ - ( - "foo://an/unknown/uri.zarr", - "Unsupported protocol 'foo' for URI: 'foo://an/unknown/uri.zarr'", - ), - ( - "ftp://an/unsupported/uri.zarr", - "Unsupported protocol 'ftp' for URI: 'ftp://an/unsupported/uri.zarr'", - ), - ], -) -def test_open_diagnostic_unknown_uri(uri, expected): - model = "RTMA" - system = "WCOSS" - domain = "CONUS" - background = "HRRR" - frequency = "REALTIME" - init_time = "2022-05-16T04:00" - - with pytest.raises(ValueError, match=expected): - diag.open_diagnostic( - uri, - model, - system, - domain, - background, - frequency, - init_time, - diag.Variable.WIND, - diag.MinimLoop.GUESS, - ) - - -@pytest.mark.usefixtures("aws_credentials") -def test_open_diagnostic_s3(moto_server, test_dataset, monkeypatch): - store = "s3://test_open_diagnostic_s3/test_diag.zarr" - expected = test_dataset() - group = "/".join( - ( - expected.model, - expected.system, - expected.domain, - expected.background, - expected.frequency, - expected.name, - expected.initialization_time, - expected.loop, - ) - ) - - monkeypatch.setattr( - diag, - "S3FileSystem", - partial(diag.S3FileSystem, endpoint_url=moto_server), - ) - - expected.to_zarr( - store, - group=group, - consolidated=False, - storage_options={"endpoint_url": moto_server}, - ) - - result = diag.open_diagnostic( - store, - expected.model, - expected.system, - expected.domain, - expected.background, - expected.frequency, - diag.Variable(expected.name), - expected.initialization_time, - diag.MinimLoop(expected.loop), - ) - - xr.testing.assert_equal(result, expected) - - -@pytest.mark.parametrize( - "mapping,expected", - [ - ([("a", "1")], [("a", np.array([1.0]), np.array([1.0]))]), - ([("a", "1::2")], [("a", np.array([1.0]), np.array([2.0]))]), - ([("a", "2,4::3,1")], [("a", np.array([2.0, 1.0]), np.array([3.0, 4.0]))]), - ], - scope="class", -) -class TestGetBounds: - @pytest.fixture(scope="class") - def result(self, mapping): - filters = MultiDict(mapping) - return list(diag.get_bounds(filters)) - - def test_coord(self, result, expected): - assert result[0][0] == expected[0][0] - - def test_lower_bounds(self, result, expected): - assert (result[0][1] == expected[0][1]).all() - - def test_upper_bounds(self, result, expected): - assert (result[0][2] == expected[0][2]).all() - - def test_history(tmp_path, test_dataset, diag_parquet): run_list = [ {