
Replace use of Zarr with Parquet
Eliminate all of the code that read diagnostic data from Zarr files and
switch over to reading it from the Parquet files we already keep for the
time series data.
esheehan-gsl committed Nov 17, 2023
1 parent 31d4d08 commit 36775e9
Showing 3 changed files with 89 additions and 432 deletions.
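The substance of the change, visible in diag.py below, is replacing xarray's Zarr reads with a single pandas.read_parquet call against a partitioned Parquet store. As a minimal sketch of that read pattern (the store URI, the variable name "t", and the loop value "ges" are hypothetical stand-ins; the path layout, column list, and pushed-down filters mirror the new diag_observations function):

from datetime import datetime

import pandas as pd

# Hypothetical store root; the app reads the real one from its DIAG_PARQUET config.
uri = "/tmp/diag_parquet"

# Datasets are laid out as <uri>/<model_config>/<variable>, where
# model_config is model_background_system_domain_frequency.
model_config = "_".join(("RTMA", "HRRR", "WCOSS", "CONUS", "REALTIME"))

df = pd.read_parquet(
    "/".join((uri, model_config, "t")),
    columns=[
        "obs_minus_forecast_adjusted",
        "obs_minus_forecast_unadjusted",
        "observation",
        "latitude",
        "longitude",
        "is_used",
    ],
    # These filters are pushed down to the pyarrow engine, so only rows for
    # the requested loop and initialization time are read from disk.
    filters=(
        ("loop", "=", "ges"),
        ("initialization_time", "=", datetime.fromisoformat("2023-03-17T14:00")),
    ),
)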
275 changes: 51 additions & 224 deletions src/unified_graphics/diag.py
@@ -2,17 +2,11 @@
from collections import namedtuple
from datetime import datetime, timedelta
from enum import Enum
from typing import Union
from urllib.parse import urlparse

import numpy as np
import pandas as pd
import sqlalchemy as sa
import xarray as xr
import zarr # type: ignore
from s3fs import S3FileSystem, S3Map # type: ignore
from werkzeug.datastructures import MultiDict
from xarray.core.dataset import Dataset

from .models import Analysis, WeatherModel

@@ -88,242 +82,75 @@ def get_model_metadata(session) -> ModelMetadata:
)


def get_store(url: str) -> Union[str, S3Map]:
result = urlparse(url)
if result.scheme in ["", "file"]:
return result.path

if result.scheme != "s3":
raise ValueError(f"Unsupported protocol '{result.scheme}' for URI: '{url}'")

region = os.environ.get("AWS_REGION", "us-east-1")
s3 = S3FileSystem(
key=os.environ.get("AWS_ACCESS_KEY_ID"),
secret=os.environ.get("AWS_SECRET_ACCESS_KEY"),
token=os.environ.get("AWS_SESSION_TOKEN"),
client_kwargs={"region_name": region},
)

return S3Map(root=f"{result.netloc}{result.path}", s3=s3, check=False)


def open_diagnostic(
diag_zarr: str,
model: str,
system: str,
domain: str,
background: str,
frequency: str,
variable: Variable,
initialization_time: str,
loop: MinimLoop,
) -> xr.Dataset:
store = get_store(diag_zarr)
group = (
f"/{model}/{system}/{domain}/{background}/{frequency}"
f"/{variable.value}/{initialization_time}/{loop.value}"
)
return xr.open_zarr(store, group=group, consolidated=False)


def parse_filter_value(value):
if value == "true":
return 1

if value == "false":
return 0

try:
return float(value)
except ValueError:
return value


# TODO: Refactor to a class
# I think this might belong in a different module. It could be a class or set of classes
# that represent different filters that can be added together into a filtering pipeline
def get_bounds(filters: MultiDict):
for coord, value in filters.items():
extent = np.array(
[
[parse_filter_value(digit) for digit in pair.split(",")]
for pair in value.split("::")
]
)
yield coord, extent.min(axis=0), extent.max(axis=0)


def apply_filters(dataset: xr.Dataset, filters: MultiDict) -> Dataset:
for coord, lower, upper in get_bounds(filters):
data_array = dataset[coord]
dataset = dataset.where((data_array >= lower) & (data_array <= upper)).dropna(
dim="nobs"
)

# If the is_used filter is not passed, our default behavior is to include only used
# observations.
if "is_used" not in filters:
dataset = dataset.where(dataset["is_used"]).dropna(dim="nobs")

return dataset


def scalar(
diag_zarr: str,
model: str,
system: str,
domain: str,
background: str,
frequency: str,
variable: Variable,
initialization_time: str,
loop: MinimLoop,
filters: MultiDict,
) -> pd.DataFrame:
data = open_diagnostic(
diag_zarr,
model,
system,
domain,
background,
frequency,
variable,
initialization_time,
loop,
def magnitude(dataset: pd.DataFrame) -> pd.DataFrame:
return dataset.groupby(level=0).aggregate(
{
"obs_minus_forecast_adjusted": np.linalg.norm,
"obs_minus_forecast_unadjusted": np.linalg.norm,
"observation": np.linalg.norm,
"longitude": "first",
"latitude": "first",
}
)
data = apply_filters(data, filters)

return data.to_dataframe()


def temperature(
diag_zarr: str,
def diag_observations(
model: str,
system: str,
domain: str,
background: str,
frequency: str,
initialization_time: str,
loop: MinimLoop,
filters: MultiDict,
variable: str,
init_time: datetime,
loop: str,
uri: str,
filters: dict = {},
) -> pd.DataFrame:
return scalar(
diag_zarr,
model,
system,
domain,
background,
frequency,
Variable.TEMPERATURE,
initialization_time,
loop,
filters,
)
def matches(df: pd.DataFrame, filters: dict) -> bool:
result = True

for col_name, filter_value in filters.items():
arr = np.array(filter_value)
result &= (df[col_name] >= arr.min(axis=0)).all()
result &= (df[col_name] <= arr.max(axis=0)).all()

def moisture(
diag_zarr: str,
model: str,
system: str,
domain: str,
background: str,
frequency: str,
initialization_time: str,
loop: MinimLoop,
filters: MultiDict,
) -> pd.DataFrame:
return scalar(
diag_zarr,
model,
system,
domain,
background,
frequency,
Variable.MOISTURE,
initialization_time,
loop,
filters,
)
if "is_used" not in filters:
result &= df["is_used"].any()

return result

def pressure(
diag_zarr: str,
model: str,
system: str,
domain: str,
background: str,
frequency: str,
initialization_time: str,
loop: MinimLoop,
filters: MultiDict,
) -> pd.DataFrame:
return scalar(
diag_zarr,
model,
system,
domain,
background,
frequency,
Variable.PRESSURE,
initialization_time,
loop,
filters,
)

model_config = "_".join((model, background, system, domain, frequency))

def wind(
diag_zarr: str,
model: str,
system: str,
domain: str,
background: str,
frequency: str,
initialization_time: str,
loop: MinimLoop,
filters: MultiDict,
) -> pd.DataFrame | pd.Series:
data = open_diagnostic(
diag_zarr,
model,
system,
domain,
background,
frequency,
Variable.WIND,
initialization_time,
loop,
df = pd.read_parquet(
"/".join((uri, model_config, variable)),
columns=[
"obs_minus_forecast_adjusted",
"obs_minus_forecast_unadjusted",
"observation",
"latitude",
"longitude",
"is_used",
],
filters=(
("loop", "=", loop),
("initialization_time", "=", init_time),
),
)

data = apply_filters(data, filters)

return data.to_dataframe()


def magnitude(dataset: pd.DataFrame) -> pd.DataFrame:
return dataset.groupby(level=0).aggregate(
{
"obs_minus_forecast_adjusted": np.linalg.norm,
"obs_minus_forecast_unadjusted": np.linalg.norm,
"observation": np.linalg.norm,
"longitude": "first",
"latitude": "first",
}
)
# Group the rows of the DataFrame by nobs (effectively the observation ID) and test
# each group against our filters. This is necessary because we use a MultiIndex for
# vectors where the second level of the index is the vector component. If we don't
# group the components like this, we run the risk that one component matches the
# filters and the other doesn't, leaving us with a partial observation.
matching_obs = [obs for _, obs in df.groupby("nobs") if matches(obs, filters)]

# If no observations match the filters, return an empty DataFrame by masking out all
# the values in the DataFrame using a list of repeated False values
if len(matching_obs) < 1:
return df[[False] * len(df)]

def get_model_run_list(
diag_zarr: str,
model: str,
system: str,
domain: str,
background: str,
frequency: str,
variable: Variable,
):
store = get_store(diag_zarr)
path = "/".join([model, system, domain, background, frequency, variable.value])
with zarr.open_group(store, mode="r", path=path) as group:
return group.group_keys()
# Otherwise concatenate the matching DataFrames back into a single DataFrame
return pd.concat(matching_obs)


def history(
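With the per-variable wrappers (temperature, moisture, pressure, wind) collapsed into the single diag_observations function above, a call site looks roughly like this; every argument value here is a hypothetical stand-in, and the filters dict has the shape produced by parse_filters in routes.py below:

from datetime import datetime

from unified_graphics import diag

obs = diag.diag_observations(
    "RTMA",               # model
    "WCOSS",              # system
    "CONUS",              # domain
    "HRRR",               # background
    "REALTIME",           # frequency
    "t",                  # variable (hypothetical name)
    datetime.fromisoformat("2023-03-17T14:00"),
    "ges",                # loop (hypothetical value)
    "/tmp/diag_parquet",  # store URI (hypothetical)
    # Range filters keep only observations whose values fall within the
    # min/max of the supplied list; see matches() above.
    filters={"longitude": [-120.0, -110.0]},
)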
60 changes: 38 additions & 22 deletions src/unified_graphics/routes.py
@@ -10,6 +10,7 @@
stream_template,
url_for,
)
from werkzeug.datastructures import MultiDict
from zarr.errors import FSPathExistNotDir, GroupNotFoundError # type: ignore

from unified_graphics import diag
@@ -18,6 +19,27 @@
bp = Blueprint("api", __name__)


def parse_filters(query: MultiDict) -> dict:
def parse_value(value):
if "::" in value:
return [parse_value(tok) for tok in value.split("::")]

if value in ["true", "false"]:
return value == "true"

try:
return float(value)
except ValueError:
return value

filters = {}
for col, value_list in query.lists():
val = [parse_value(val) for val in value_list]
filters[col] = val if len(val) > 1 else val[0]

return filters


@bp.errorhandler(GroupNotFoundError)
def handle_diag_group_not_found(e):
current_app.logger.exception("Unable to read diagnostic group")
@@ -167,22 +189,19 @@ def history(model, system, domain, background, frequency, variable, loop):
def diagnostics(
model, system, domain, background, frequency, variable, initialization_time, loop
):
try:
v = diag.Variable(variable)
except ValueError:
return jsonify(msg=f"Variable not found: '{variable}'"), 404

variable_diagnostics = getattr(diag, v.name.lower())
data = variable_diagnostics(
current_app.config["DIAG_ZARR"],
filters = parse_filters(request.args)

data = diag.diag_observations(
model,
system,
domain,
background,
frequency,
initialization_time,
diag.MinimLoop(loop),
request.args,
variable,
datetime.fromisoformat(initialization_time),
loop,
current_app.config["DIAG_PARQUET"],
filters,
)[
[
"obs_minus_forecast_adjusted",
@@ -207,22 +226,19 @@ def diagnostics(
def magnitude(
model, system, domain, background, frequency, variable, initialization_time, loop
):
try:
v = diag.Variable(variable)
except ValueError:
return jsonify(msg=f"Variable not found: '{variable}'"), 404

variable_diagnostics = getattr(diag, v.name.lower())
data = variable_diagnostics(
current_app.config["DIAG_ZARR"],
filters = parse_filters(request.args)

data = diag.diag_observations(
model,
system,
domain,
background,
frequency,
initialization_time,
diag.MinimLoop(loop),
request.args,
variable,
datetime.fromisoformat(initialization_time),
loop,
current_app.config["DIAG_PARQUET"],
filters,
)[
[
"obs_minus_forecast_adjusted",
186 changes: 0 additions & 186 deletions tests/test_diag.py
@@ -1,22 +1,16 @@
import uuid
from datetime import datetime
from functools import partial

import numpy as np
import pandas as pd
import pytest
import xarray as xr
from botocore.session import Session
from moto.server import ThreadedMotoServer
from s3fs import S3FileSystem, S3Map
from werkzeug.datastructures import MultiDict

from unified_graphics import diag
from unified_graphics.models import Analysis, WeatherModel

# Global resources for s3
test_bucket_name = "osti-modeling-dev-rtma-vis"


@pytest.fixture
def aws_credentials(monkeypatch):
@@ -48,11 +42,6 @@ def s3_client(aws_credentials, moto_server):
return session.create_client("s3", endpoint_url=moto_server)


@pytest.fixture
def test_key_prefix():
return f"/test/{uuid.uuid4()}/"


def test_get_model_metadata(session):
model_run_list = [
("RTMA", "WCOSS", "CONUS", "REALTIME", "HRRR", "2023-03-17T14:00"),
@@ -85,181 +74,6 @@ def test_get_model_metadata(session):
)


@pytest.mark.parametrize(
"uri,expected",
[
("file:///tmp/diag.zarr", "/tmp/diag.zarr"),
("/tmp/diag.zarr", "/tmp/diag.zarr"),
],
)
def test_get_store_file(uri, expected):
result = diag.get_store(uri)

assert result == expected


def test_get_store_s3(moto_server, s3_client, monkeypatch):
client = {"region_name": "us-east-1"}
uri = "s3://bucket/prefix/diag.zarr"
s3_client.create_bucket(Bucket="bucket")
s3_client.put_object(Bucket="bucket", Body=b"Test object", Key="prefix/diag.zarr")

monkeypatch.setattr(
diag,
"S3FileSystem",
partial(diag.S3FileSystem, endpoint_url=moto_server),
)

result = diag.get_store(uri)

assert result == S3Map(
root=uri,
s3=S3FileSystem(
client_kwargs=client,
endpoint_url=moto_server,
),
check=False,
)


def test_open_diagnostic(tmp_path, test_dataset):
diag_zarr_file = str(tmp_path / "test_diag.zarr")
expected = test_dataset()
group = "/".join(
(
expected.model,
expected.system,
expected.domain,
expected.background,
expected.frequency,
expected.name,
expected.initialization_time,
expected.loop,
)
)

expected.to_zarr(diag_zarr_file, group=group, consolidated=False)

result = diag.open_diagnostic(
diag_zarr_file,
expected.model,
expected.system,
expected.domain,
expected.background,
expected.frequency,
diag.Variable(expected.name),
expected.initialization_time,
diag.MinimLoop(expected.loop),
)

xr.testing.assert_equal(result, expected)


@pytest.mark.parametrize(
"uri,expected",
[
(
"foo://an/unknown/uri.zarr",
"Unsupported protocol 'foo' for URI: 'foo://an/unknown/uri.zarr'",
),
(
"ftp://an/unsupported/uri.zarr",
"Unsupported protocol 'ftp' for URI: 'ftp://an/unsupported/uri.zarr'",
),
],
)
def test_open_diagnostic_unknown_uri(uri, expected):
model = "RTMA"
system = "WCOSS"
domain = "CONUS"
background = "HRRR"
frequency = "REALTIME"
init_time = "2022-05-16T04:00"

with pytest.raises(ValueError, match=expected):
diag.open_diagnostic(
uri,
model,
system,
domain,
background,
frequency,
init_time,
diag.Variable.WIND,
diag.MinimLoop.GUESS,
)


@pytest.mark.usefixtures("aws_credentials")
def test_open_diagnostic_s3(moto_server, test_dataset, monkeypatch):
store = "s3://test_open_diagnostic_s3/test_diag.zarr"
expected = test_dataset()
group = "/".join(
(
expected.model,
expected.system,
expected.domain,
expected.background,
expected.frequency,
expected.name,
expected.initialization_time,
expected.loop,
)
)

monkeypatch.setattr(
diag,
"S3FileSystem",
partial(diag.S3FileSystem, endpoint_url=moto_server),
)

expected.to_zarr(
store,
group=group,
consolidated=False,
storage_options={"endpoint_url": moto_server},
)

result = diag.open_diagnostic(
store,
expected.model,
expected.system,
expected.domain,
expected.background,
expected.frequency,
diag.Variable(expected.name),
expected.initialization_time,
diag.MinimLoop(expected.loop),
)

xr.testing.assert_equal(result, expected)


@pytest.mark.parametrize(
"mapping,expected",
[
([("a", "1")], [("a", np.array([1.0]), np.array([1.0]))]),
([("a", "1::2")], [("a", np.array([1.0]), np.array([2.0]))]),
([("a", "2,4::3,1")], [("a", np.array([2.0, 1.0]), np.array([3.0, 4.0]))]),
],
scope="class",
)
class TestGetBounds:
@pytest.fixture(scope="class")
def result(self, mapping):
filters = MultiDict(mapping)
return list(diag.get_bounds(filters))

def test_coord(self, result, expected):
assert result[0][0] == expected[0][0]

def test_lower_bounds(self, result, expected):
assert (result[0][1] == expected[0][1]).all()

def test_upper_bounds(self, result, expected):
assert (result[0][2] == expected[0][2]).all()


def test_history(tmp_path, test_dataset, diag_parquet):
run_list = [
{

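The deleted Zarr and S3 fixtures give way to Parquet fixtures such as the diag_parquet fixture used by test_history above. A minimal sketch of building one (all names and values are hypothetical, and the vector MultiIndex described in the diag.py comments is omitted for brevity):

from datetime import datetime

import pandas as pd

df = pd.DataFrame(
    {
        "obs_minus_forecast_adjusted": [1.2, -0.4],
        "obs_minus_forecast_unadjusted": [1.3, -0.5],
        "observation": [290.1, 289.7],
        "longitude": [-105.0, -104.5],
        "latitude": [40.0, 40.1],
        "is_used": [True, True],
        "loop": ["ges", "ges"],
        "initialization_time": [datetime(2023, 3, 17, 14)] * 2,
    }
)

# Partitioning on loop matches one of the filters that diag_observations
# pushes down when it reads the store back.
df.to_parquet(
    "/tmp/diag_parquet/RTMA_HRRR_WCOSS_CONUS_REALTIME/t",
    partition_cols=["loop"],
)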