Skip to content
This repository has been archived by the owner on Sep 30, 2024. It is now read-only.

Replace zarr with parquet #432

Merged
merged 9 commits into from
Nov 29, 2023
285 changes: 60 additions & 225 deletions src/unified_graphics/diag.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,11 @@
from collections import namedtuple
from datetime import datetime, timedelta
from enum import Enum
from typing import Union
from urllib.parse import urlparse

import numpy as np
import pandas as pd
import sqlalchemy as sa
import xarray as xr
import zarr # type: ignore
from s3fs import S3FileSystem, S3Map # type: ignore
from werkzeug.datastructures import MultiDict
from xarray.core.dataset import Dataset

from .models import Analysis, WeatherModel

Expand Down Expand Up @@ -88,217 +82,6 @@ def get_model_metadata(session) -> ModelMetadata:
)


def get_store(url: str) -> Union[str, S3Map]:
    """Resolve a storage URL into something a Zarr reader can open.

    URLs with no scheme or with the ``file`` scheme yield the path component
    of the URL. ``s3`` URLs yield an ``S3Map`` rooted at the bucket and key
    prefix, using credentials and region read from the ``AWS_*`` environment
    variables (the region falls back to ``us-east-1``).

    Raises:
        ValueError: if the URL uses any other scheme.
    """
    parsed = urlparse(url)
    scheme = parsed.scheme

    # Scheme-less and file:// URLs are returned as a plain path.
    if scheme in ("", "file"):
        return parsed.path

    if scheme != "s3":
        raise ValueError(f"Unsupported protocol '{parsed.scheme}' for URI: '{url}'")

    env = os.environ
    filesystem = S3FileSystem(
        key=env.get("AWS_ACCESS_KEY_ID"),
        secret=env.get("AWS_SECRET_ACCESS_KEY"),
        token=env.get("AWS_SESSION_TOKEN"),
        client_kwargs={"region_name": env.get("AWS_REGION", "us-east-1")},
    )
    bucket_path = f"{parsed.netloc}{parsed.path}"

    return S3Map(root=bucket_path, s3=filesystem, check=False)


def open_diagnostic(
    diag_zarr: str,
    model: str,
    system: str,
    domain: str,
    background: str,
    frequency: str,
    variable: Variable,
    initialization_time: str,
    loop: MinimLoop,
) -> xr.Dataset:
    """Open one diagnostic group from the Zarr store at ``diag_zarr``.

    The group path is built from the arguments in the order
    ``/model/system/domain/background/frequency/variable/initialization_time/loop``,
    using the ``.value`` of ``variable`` and ``loop``.
    """
    segments = (
        model,
        system,
        domain,
        background,
        frequency,
        variable.value,
        initialization_time,
        loop.value,
    )
    # Every segment is prefixed with "/", so the path starts with a slash.
    group_path = "".join(f"/{segment}" for segment in segments)

    return xr.open_zarr(get_store(diag_zarr), group=group_path, consolidated=False)


def parse_filter_value(value):
    """Convert a raw filter token into a comparable value.

    ``"true"`` and ``"false"`` become ``1`` and ``0``. Anything ``float()``
    accepts becomes a float. A value that ``float()`` rejects with
    ``ValueError`` is returned unchanged.
    """
    for literal, flag in (("true", 1), ("false", 0)):
        if value == literal:
            return flag

    try:
        number = float(value)
    except ValueError:
        return value

    return number


# TODO: Refactor to a class
# This may belong in a different module: a class, or a set of classes, each
# representing one filter, that can be combined into a filtering pipeline.
def get_bounds(filters: MultiDict):
    """Yield ``(coord, lower, upper)`` for every entry in ``filters``.

    Each filter value is split on ``"::"`` into groups, and each group is split
    on ``","`` into tokens that are parsed with ``parse_filter_value``. The
    lower and upper bounds are the minimum and maximum of the resulting array
    along axis 0.
    """
    for coord, raw_value in filters.items():
        rows = []
        for group in raw_value.split("::"):
            rows.append([parse_filter_value(token) for token in group.split(",")])

        extent = np.array(rows)
        lower = extent.min(axis=0)
        upper = extent.max(axis=0)
        yield coord, lower, upper


def apply_filters(dataset: xr.Dataset, filters: MultiDict) -> Dataset:
    """Restrict ``dataset`` to the observations that satisfy ``filters``.

    For every bound produced by ``get_bounds``, values outside the inclusive
    ``[lower, upper]`` range are masked and the affected entries are dropped
    along the ``nobs`` dimension. When ``filters`` has no ``is_used`` key,
    the dataset is additionally masked by its ``is_used`` variable.
    """
    for coord, lower, upper in get_bounds(filters):
        values = dataset[coord]
        in_range = (values >= lower) & (values <= upper)
        dataset = dataset.where(in_range)
        dataset = dataset.dropna(dim="nobs")

    # An explicit is_used filter has already been handled by the loop above.
    if "is_used" in filters:
        return dataset

    # Default behavior: keep only the observations that were used.
    used_only = dataset.where(dataset["is_used"])
    return used_only.dropna(dim="nobs")


def scalar(
    diag_zarr: str,
    model: str,
    system: str,
    domain: str,
    background: str,
    frequency: str,
    variable: Variable,
    initialization_time: str,
    loop: MinimLoop,
    filters: MultiDict,
) -> pd.DataFrame:
    """Load one diagnostic variable, filter it, and return it as a DataFrame.

    The dataset is opened with ``open_diagnostic``, narrowed with
    ``apply_filters``, and converted with ``to_dataframe()``.
    """
    dataset = open_diagnostic(
        diag_zarr=diag_zarr,
        model=model,
        system=system,
        domain=domain,
        background=background,
        frequency=frequency,
        variable=variable,
        initialization_time=initialization_time,
        loop=loop,
    )

    return apply_filters(dataset, filters).to_dataframe()


def temperature(
    diag_zarr: str,
    model: str,
    system: str,
    domain: str,
    background: str,
    frequency: str,
    initialization_time: str,
    loop: MinimLoop,
    filters: MultiDict,
) -> pd.DataFrame:
    """Return the filtered temperature diagnostics as a DataFrame.

    Thin wrapper around ``scalar`` with ``Variable.TEMPERATURE``.
    """
    return scalar(
        diag_zarr=diag_zarr,
        model=model,
        system=system,
        domain=domain,
        background=background,
        frequency=frequency,
        variable=Variable.TEMPERATURE,
        initialization_time=initialization_time,
        loop=loop,
        filters=filters,
    )


def moisture(
    diag_zarr: str,
    model: str,
    system: str,
    domain: str,
    background: str,
    frequency: str,
    initialization_time: str,
    loop: MinimLoop,
    filters: MultiDict,
) -> pd.DataFrame:
    """Return the filtered moisture diagnostics as a DataFrame.

    Thin wrapper around ``scalar`` with ``Variable.MOISTURE``.
    """
    return scalar(
        diag_zarr=diag_zarr,
        model=model,
        system=system,
        domain=domain,
        background=background,
        frequency=frequency,
        variable=Variable.MOISTURE,
        initialization_time=initialization_time,
        loop=loop,
        filters=filters,
    )


def pressure(
    diag_zarr: str,
    model: str,
    system: str,
    domain: str,
    background: str,
    frequency: str,
    initialization_time: str,
    loop: MinimLoop,
    filters: MultiDict,
) -> pd.DataFrame:
    """Return the filtered pressure diagnostics as a DataFrame.

    Thin wrapper around ``scalar`` with ``Variable.PRESSURE``.
    """
    return scalar(
        diag_zarr=diag_zarr,
        model=model,
        system=system,
        domain=domain,
        background=background,
        frequency=frequency,
        variable=Variable.PRESSURE,
        initialization_time=initialization_time,
        loop=loop,
        filters=filters,
    )


def wind(
    diag_zarr: str,
    model: str,
    system: str,
    domain: str,
    background: str,
    frequency: str,
    initialization_time: str,
    loop: MinimLoop,
    filters: MultiDict,
) -> pd.DataFrame | pd.Series:
    """Return the filtered wind diagnostics as a DataFrame.

    Runs the same open / filter / ``to_dataframe()`` pipeline as ``scalar``,
    with ``Variable.WIND``, so it delegates to ``scalar`` instead of repeating
    those steps.
    """
    return scalar(
        diag_zarr=diag_zarr,
        model=model,
        system=system,
        domain=domain,
        background=background,
        frequency=frequency,
        variable=Variable.WIND,
        initialization_time=initialization_time,
        loop=loop,
        filters=filters,
    )


def magnitude(dataset: pd.DataFrame) -> pd.DataFrame:
return dataset.groupby(level=0).aggregate(
{
Expand All @@ -311,19 +94,71 @@ def magnitude(dataset: pd.DataFrame) -> pd.DataFrame:
)


def get_model_run_list(
    diag_zarr: str,
    model: str,
    system: str,
    domain: str,
    background: str,
    frequency: str,
    variable: Variable,
):
    """Return the group keys stored under one variable of a model configuration.

    The Zarr group at ``model/system/domain/background/frequency/variable`` is
    opened read-only from the store resolved by ``get_store(diag_zarr)``.
    """
    store = get_store(diag_zarr)
    path = "/".join([model, system, domain, background, frequency, variable.value])
    with zarr.open_group(store, mode="r", path=path) as group:
        return group.group_keys()


def diag_observations(
    model: str,
    system: str,
    domain: str,
    background: str,
    frequency: str,
    variable: str,
    init_time: datetime,
    loop: str,
    uri: str,
    filters: dict | None = None,
) -> pd.DataFrame | pd.Series:
    """Read the observations for one model run from Parquet and filter them.

    The dataset is read from ``{uri}/{model}_{background}_{system}_{domain}_
    {frequency}/{variable}`` and restricted at read time to the rows matching
    ``loop`` and ``init_time``.

    Args:
        model, system, domain, background, frequency: Parts of the model
            configuration name used to locate the dataset.
        variable: Name of the variable directory under the model configuration.
        init_time: Value matched against the ``initialization_time`` column.
        loop: Value matched against the ``loop`` column.
        uri: Root location of the Parquet data.
        filters: Optional mapping of column name to range values. The special
            key ``is_used`` defaults to ``True``; when it is a ``bool`` it is
            pushed down to the Parquet reader, otherwise no ``is_used``
            filtering is applied. For every other key, only rows whose column
            value lies between the minimum and maximum of the filter value
            (inclusive, taken along axis 0) are kept. The mapping is never
            modified.

    Returns:
        The filtered observations. Vector variables (those with a
        ``component`` index level) have that level unstacked into columns.
    """
    # Work on a private copy: ``is_used`` is popped below, and doing that on
    # the caller's mapping (or on a shared default) would leak state between
    # calls.
    filters = dict(filters) if filters is not None else {}

    model_config = "_".join((model, background, system, domain, frequency))

    is_used = filters.pop("is_used", True)
    parquet_filters = [
        ("loop", "=", loop),
        ("initialization_time", "=", init_time),
    ]

    if isinstance(is_used, bool):
        parquet_filters.append(("is_used", "=", is_used))

    df = pd.read_parquet(
        "/".join((uri, model_config, variable)),
        columns=[
            "obs_minus_forecast_adjusted",
            "obs_minus_forecast_unadjusted",
            "observation",
            "latitude",
            "longitude",
            "is_used",
        ],
        filters=parquet_filters,
    )

    # To apply the filters, we need the vector components in the columns, not
    # the rows.
    # FIXME: We should consider changing how we store the vector data so we
    # don't have to unstack it every time.
    if "component" in df.index.names:
        # FIXME: Specifically unstack the component level of the index because
        # I'm seeing some data where the index is (component, nobs) instead of
        # (nobs, component)
        df = df.unstack("component")  # type: ignore

    # Iterate over each filter and apply it
    for col_name, filter_value in filters.items():
        arr = np.array(filter_value)

        # Boolean mask for the rows in the data that are within the range
        # specified by the filter
        mask = (df[col_name] >= arr.min(axis=0)) & (df[col_name] <= arr.max(axis=0))

        # In the event of a vector variable, we will have a DataFrame mask
        # instead of a Series, which we need to flatten to a series which
        # evaluates to True only when every column in the frame is True. If any
        # column is False, this row should be excluded from the data
        if len(mask.shape) > 1:
            mask = mask.all(axis="columns")

        # Apply the mask
        df = df[mask]

    return df


def history(
Expand Down
Loading
Loading