Skip to content

Commit

Permalink
added some more Type Hints
Browse files Browse the repository at this point in the history
  • Loading branch information
blaylockbk committed Jul 27, 2024
1 parent 35faff1 commit 0e63523
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 78 deletions.
46 changes: 24 additions & 22 deletions herbie/accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import re
import warnings
from pathlib import Path
from typing import Literal, Optional, Union

import numpy as np
import pandas as pd
Expand All @@ -23,7 +24,6 @@

import herbie


_level_units = dict(
adiabaticCondensation="adiabatic condensation",
atmosphere="atmosphere",
Expand Down Expand Up @@ -54,7 +54,7 @@
)


def add_proj_info(ds):
def add_proj_info(ds: xr.Dataset):
"""Add projection info to a Dataset."""
match = re.search(r'"source": "(.*?)"', ds.history)
FILE = Path(match.group(1))
Expand Down Expand Up @@ -97,7 +97,7 @@ def __init__(self, xarray_obj):
self._center = None

@property
def center(self):
def center(self) -> tuple[float, float]:
"""Return the geographic center point of this dataset."""
if self._center is None:
# we can use a cache on our accessor objects, because accessors
Expand All @@ -107,19 +107,18 @@ def center(self):
self._center = (float(lon.mean()), float(lat.mean()))
return self._center

def to_180(self):
def to_180(self) -> xr.Dataset:
"""Wrap longitude coordinates as range [-180,180]."""
ds = self._obj
ds["longitude"] = (ds["longitude"] + 180) % 360 - 180
return ds

def to_360(self):
def to_360(self) -> xr.Dataset:
"""Wrap longitude coordinates as range [0,360]."""
ds = self._obj
ds["longitude"] = (ds["longitude"] - 360) % 360
return ds


@functools.cached_property
def crs(self):
"""
Expand Down Expand Up @@ -196,7 +195,9 @@ def polygon(self):

return domain_polygon, domain_polygon_latlon

def with_wind(self, which="both"):
def with_wind(
self, which: Literal["both", "speed", "direction"] = "both"
) -> xr.Dataset:
"""Return Dataset with calculated wind speed and/or direction.
Consistent with the eccodes GRIB parameter database, variables
Expand Down Expand Up @@ -228,7 +229,7 @@ def with_wind(self, which="both"):
ds["si10"].attrs["standard_name"] = "wind_speed"
ds["si10"].attrs["grid_mapping"] = ds.u10.attrs.get("grid_mapping")
n_computed += 1

if {"u100", "v100"}.issubset(ds):
ds["si100"] = np.sqrt(ds.u100**2 + ds.v100**2)
ds["si100"].attrs["GRIB_paramId"] = 228249
Expand All @@ -237,7 +238,7 @@ def with_wind(self, which="both"):
ds["si100"].attrs["standard_name"] = "wind_speed"
ds["si100"].attrs["grid_mapping"] = ds.u100.attrs.get("grid_mapping")
n_computed += 1

if {"u80", "v80"}.issubset(ds):
ds["si80"] = np.sqrt(ds.u80**2 + ds.v80**2)
ds["si80"].attrs["long_name"] = "80 metre wind speed"
Expand Down Expand Up @@ -266,7 +267,7 @@ def with_wind(self, which="both"):
ds["wdir10"].attrs["standard_name"] = "wind_from_direction"
ds["wdir10"].attrs["grid_mapping"] = ds.u10.attrs.get("grid_mapping")
n_computed += 1

if {"u100", "v100"}.issubset(ds):
ds["wdir100"] = (
(270 - np.rad2deg(np.arctan2(ds.v100, ds.u100))) % 360
Expand All @@ -276,7 +277,7 @@ def with_wind(self, which="both"):
ds["wdir100"].attrs["standard_name"] = "wind_from_direction"
ds["wdir100"].attrs["grid_mapping"] = ds.u100.attrs.get("grid_mapping")
n_computed += 1

if {"u80", "v80"}.issubset(ds):
ds["wdir80"] = (
(270 - np.rad2deg(np.arctan2(ds.v80, ds.u80))) % 360
Expand Down Expand Up @@ -305,15 +306,15 @@ def with_wind(self, which="both"):

def pick_points(
self,
points,
method="nearest",
points: pd.DataFrame,
method: Literal["nearest", "weighted"] = "nearest",
*,
k=None,
max_distance=500,
use_cached_tree=True,
tree_name=None,
verbose=False,
):
k: Optional[int] = None,
max_distance: Union[int, float] = 500,
use_cached_tree: Union[bool, Literal["replant"]] = True,
tree_name: Optional[str] = None,
verbose: bool = False,
) -> xr.Dataset:
"""Pick nearest neighbor grid values at selected points.
Parameters
Expand Down Expand Up @@ -384,7 +385,7 @@ def pick_points(
"`pip install 'herbie-data[extras]'` for the full functionality."
)

def plant_tree(save_pickle=None):
def plant_tree(save_pickle: Optional[Union[Path, str]] = None):
"""Grow a new BallTree object from seedling."""
timer = pd.Timestamp("now")
print("INFO: 🌱 Growing new BallTree...", end="")
Expand Down Expand Up @@ -719,9 +720,10 @@ def plot(self, ax=None, common_features_kw={}, vars=None, **kwargs):
raise NotImplementedError("Plotting functionality is not working right now.")

try:
from herbie.toolbox import EasyMap, pc
from herbie import paint
import matplotlib.pyplot as plt

from herbie import paint
from herbie.toolbox import EasyMap, pc
except ModuleNotFoundError:
raise ModuleNotFoundError(
"cartopy is an 'extra' requirement. Please use "
Expand Down
24 changes: 14 additions & 10 deletions herbie/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
from herbie.help import _search_help
from herbie.misc import ANSI

Datetime = Union[datetime, pd.Timestamp, str]

# NOTE: The config dict values are retrieved from __init__ and read
# from the file ${HOME}/.config/herbie/config.toml
# Path is imported from __init__ because it has my custom methods.
Expand Down Expand Up @@ -82,7 +84,7 @@ def wgrib2_idx(grib2filepath: Union[Path, str]) -> str:
raise RuntimeError("wgrib2 command was not found.")


def create_index_files(path: Union[Path, str], overwrite: bool = False):
def create_index_files(path: Union[Path, str], overwrite: bool = False) -> None:
"""Create an index file for all GRIB2 files in a directory.
Parameters
Expand Down Expand Up @@ -169,9 +171,9 @@ class Herbie:

def __init__(
self,
date: Optional[Union[datetime, pd.Timestamp, str]] = None,
date: Optional[Datetime] = None,
*,
valid_date: Optional[Union[datetime, pd.Timestamp, str]] = None,
valid_date: Optional[Datetime] = None,
model: str = config["default"].get("model"),
fxx: int = config["default"].get("fxx"),
product: str = config["default"].get("product"),
Expand Down Expand Up @@ -303,7 +305,7 @@ def __bool__(self) -> bool:
"""Herbie evaluates to True if the GRIB file exists."""
return bool(self.grib)

def help(self):
def help(self) -> None:
"""Print help message if available."""
if hasattr(self, "HELP"):
HELP = self.HELP.strip().replace("\n", "\n│ ")
Expand All @@ -320,7 +322,7 @@ def help(self):
print("│")
print("╰─────────────────────────────────────────")

def tell_me_everything(self):
def tell_me_everything(self) -> None:
"""Print all the attributes of the Herbie object."""
msg = []
for i in dir(self):
Expand All @@ -330,11 +332,11 @@ def tell_me_everything(self):
msg = "\n".join(msg)
print(msg)

def __logo__(self):
def __logo__(self) -> None:
"""For Fun, show the Herbie Logo."""
print(ANSI.ascii)

def _validate(self):
def _validate(self) -> None:
"""Validate the Herbie class input arguments."""
# Accept model alias
if self.model.lower() == "alaska":
Expand Down Expand Up @@ -368,7 +370,7 @@ def _validate(self):
if self.date < expired:
self.priority.remove("nomads")

def _ping_pando(self):
def _ping_pando(self) -> None:
"""Pinging the Pando server before downloading can prevent a bad handshake."""
try:
requests.head("https://pando-rgw01.chpc.utah.edu/")
Expand Down Expand Up @@ -523,9 +525,11 @@ def get_localFileName(self) -> str:
"""Predict the local file name."""
return self.LOCALFILE

def get_localFilePath(self, search: Optional[str] = None, *, searchString=None):
def get_localFilePath(
self, search: Optional[str] = None, *, searchString=None
) -> Path:
"""Get full path to the local file."""
# TODO: Remove this eventually
# TODO: Remove this check for searchString eventually
if searchString is not None:
warnings.warn(
"The argument `searchString` was renamed `search`. Please update your scripts.",
Expand Down
48 changes: 31 additions & 17 deletions herbie/fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@
"""

import logging

# Multithreading :)
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime
from typing import Union, Optional
from pathlib import Path

import pandas as pd
import xarray as xr
Expand All @@ -19,6 +20,7 @@

log = logging.getLogger(__name__)

Datetime = Union[datetime, pd.Timestamp, str]

"""
🧵🤹🏻‍♂️ Notice! Multithreading and Multiprocessing is use
Expand All @@ -30,8 +32,8 @@
"""


def _validate_fxx(fxx):
"""Fast Herbie requires fxx as a list-like"""
def _validate_fxx(fxx: Union[int, Union[list[int], range]]) -> Union[list[int], range]:
"""Fast Herbie requires fxx as a list-like."""
if isinstance(fxx, int):
fxx = [fxx]

Expand All @@ -41,8 +43,8 @@ def _validate_fxx(fxx):
return fxx


def _validate_DATES(DATES):
"""Fast Herbie requires DATES as a list-like"""
def _validate_DATES(DATES: Union[Datetime, list[Datetime]]) -> list[Datetime]:
"""Fast Herbie requires DATES as a list-like."""
if isinstance(DATES, str):
DATES = [pd.to_datetime(DATES)]
elif not hasattr(DATES, "__len__"):
Expand All @@ -56,7 +58,7 @@ def _validate_DATES(DATES):
return DATES


def Herbie_latest(n=6, freq="1h", **kwargs):
def Herbie_latest(n: int = 6, freq: str = "1h", **kwargs) -> Herbie:
"""Search for the most recent GRIB2 file (using multithreading).
Parameters
Expand Down Expand Up @@ -85,7 +87,16 @@ def Herbie_latest(n=6, freq="1h", **kwargs):


class FastHerbie:
def __init__(self, DATES, fxx=[0], *, max_threads=50, **kwargs):
"""Create many Herbie objects quickly."""

def __init__(
self,
DATES: Union[Datetime, list[Datetime]],
fxx: Union[int, list[int]] = [0],
*,
max_threads: int = 50,
**kwargs,
):
"""Create many Herbie objects with methods to download or read with xarray.
Uses multithreading.
Expand Down Expand Up @@ -156,10 +167,11 @@ def __init__(self, DATES, fxx=[0], *, max_threads=50, **kwargs):
f"Could not find {len(self.file_not_exists)}/{len(self.file_exists)} GRIB files."
)

def __len__(self):
def __len__(self) -> int:
"""Return the number of Herbie objects."""
return len(self.objects)

def df(self):
def df(self) -> pd.DataFrame:
"""Organize Herbie objects into a DataFrame.
#? Why is this inefficient? Takes several seconds to display because the __str__ does a lot.
Expand All @@ -172,7 +184,7 @@ def df(self):
ds_list, index=self.DATES, columns=[f"F{i:02d}" for i in self.fxx]
)

def inventory(self, search=None):
def inventory(self, search: Optional[str] = None):
"""Get combined inventory DataFrame.
Useful for data discovery and checking your search before
Expand All @@ -186,8 +198,10 @@ def inventory(self, search=None):
dfs.append(df)
return pd.concat(dfs, ignore_index=True)

def download(self, search=None, *, max_threads=20, **download_kwargs):
r"""Download many Herbie objects
def download(
self, search: Optional[str] = None, *, max_threads: int = 20, **download_kwargs
) -> list[Path]:
r"""Download many Herbie objects.
Uses multithreading.
Expand Down Expand Up @@ -231,11 +245,11 @@ def download(self, search=None, *, max_threads=20, **download_kwargs):

def xarray(
self,
search,
search: Optional[str],
*,
max_threads=None,
max_threads: Optional[int] = None,
**xarray_kwargs,
):
) -> xr.Dataset:
"""Read many Herbie objects into an xarray Dataset.
# TODO: Sometimes the Jupyter Cell always crashes when I run this.
Expand Down Expand Up @@ -302,7 +316,7 @@ def xarray(
concat_dim=["time", "step"],
combine_attrs="drop_conflicts",
)
except:
except Exception:
# TODO: I'm not sure why some cases don't like the combine_attrs argument
ds = xr.combine_nested(
ds_list,
Expand Down
Loading

0 comments on commit 0e63523

Please sign in to comment.