## Brian Blaylock
## May 3, 2021
"""
============
Herbie Tools
============
"""
import logging
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime
from typing import Union, Optional
from pathlib import Path
import pandas as pd
import xarray as xr
from herbie.core import Herbie
log = logging.getLogger(__name__)
Datetime = Union[datetime, pd.Timestamp, str]
"""
🧵🤹🏻♂️ Notice! Multithreading and Multiprocessing is use
This is my first implementation of multithreading to create, download,
and read many Herbie objects. This drastically reduces the time it takes
to create a Herbie object (which is just looking for if and where a
GRIB2 file exists on the internet) and to download a file.
"""
def _validate_fxx(fxx: Union[int, list[int], range]) -> Union[list[int], range]:
"""Fast Herbie requires fxx as a list-like."""
if isinstance(fxx, int):
fxx = [fxx]
if not isinstance(fxx, (list, range)):
raise ValueError(f"fxx must be an int, list, or range. Gave {fxx}")
return fxx
def _validate_DATES(DATES: Union[Datetime, list[Datetime]]) -> list[Datetime]:
"""Fast Herbie requires DATES as a list-like."""
if isinstance(DATES, str):
DATES = [pd.to_datetime(DATES)]
elif not hasattr(DATES, "__len__"):
DATES = [pd.to_datetime(DATES)]
if not isinstance(DATES, (list, pd.DatetimeIndex)):
raise ValueError(
f"DATES must be a pandas-parsable datetime string or a list. Gave {DATES}"
)
return DATES
def Herbie_latest(n: int = 6, freq: str = "1h", **kwargs) -> Herbie:
"""Search for the most recent GRIB2 file (using multithreading).
Parameters
----------
n : int
Number of attempts to try.
freq : pandas-parsable timedelta string
Time interval between each attempt.
Examples
--------
    When ``n=6`` and ``freq="1h"``, Herbie looks for the latest file
    within the last six hours (suitable for the HRRR model).

    When ``n=3`` and ``freq="6h"``, Herbie looks for the latest file
    within the last 18 hours (suitable for the GFS model).
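
    A hypothetical call for the HRRR model (the ``model`` keyword is
    illustrative; requires network access, so the doctest is skipped):

    >>> H = Herbie_latest(n=6, freq="1h", model="hrrr")  # doctest: +SKIP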
"""
current = pd.Timestamp.now("utc").tz_localize(None).floor(freq)
DATES = pd.date_range(
start=current - (pd.Timedelta(freq) * n),
end=current,
freq=freq,
)
FH = FastHerbie(DATES, **kwargs)
return FH.file_exists[-1]
class FastHerbie:
"""Create many Herbie objects quickly."""
def __init__(
self,
DATES: Union[Datetime, list[Datetime]],
        fxx: Union[int, list[int], range] = [0],
*,
max_threads: int = 50,
**kwargs,
):
"""Create many Herbie objects with methods to download or read with xarray.
Uses multithreading.
.. note::
            Currently, Herbie objects are looped over by run datetime
            (date) and forecast lead time (fxx).
Parameters
----------
DATES : pandas-parsable datetime string or list of datetimes
fxx : int or list of forecast lead times
max_threads : int
Maximum number of threads to use.
kwargs :
Remaining keywords for Herbie object
(e.g., model, product, priority, verbose, etc.)
Benchmark
---------
Creating 48 Herbie objects
- 1 thread took 16 s
- 2 threads took 8 s
- 5 threads took 3.3 s
- 10 threads took 1.7 s
- 50 threads took 0.5 s
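
        Examples
        --------
        A hypothetical construction (the ``model`` choice is
        illustrative; requires network access, so the doctest is
        skipped):

        >>> DATES = pd.date_range("2022-01-01 00:00", periods=2, freq="6h")
        >>> FH = FastHerbie(DATES, model="hrrr", fxx=range(0, 3))  # doctest: +SKIP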
"""
self.DATES = _validate_DATES(DATES)
self.fxx = _validate_fxx(fxx)
kwargs.setdefault("verbose", False)
################
# Multithreading
        self.tasks = len(self.DATES) * len(self.fxx)
threads = min(self.tasks, max_threads)
log.info(f"🧵 Working on {self.tasks} tasks with {threads} threads.")
self.objects = []
with ThreadPoolExecutor(threads) as exe:
            futures = [
                exe.submit(Herbie, date=DATE, fxx=f, **kwargs)
                for DATE in self.DATES
                for f in self.fxx
            ]
            # Collect the Herbie objects in the order completed
for future in as_completed(futures):
if future.exception() is None:
self.objects.append(future.result())
else:
log.error(f"Exception has occured : {future.exception()}")
log.info(f"Number of Herbie objects: {len(self.objects)}")
# Sort the list of Herbie objects by lead time then by date
self.objects.sort(key=lambda H: H.fxx)
self.objects.sort(key=lambda H: H.date)
# Which files exist?
self.file_exists = [H for H in self.objects if H.grib is not None]
self.file_not_exists = [H for H in self.objects if H.grib is None]
if len(self.file_not_exists) > 0:
            log.warning(
                f"Could not find {len(self.file_not_exists)}/{len(self.objects)} GRIB files."
            )
def __len__(self) -> int:
"""Return the number of Herbie objects."""
return len(self.objects)
def df(self) -> pd.DataFrame:
"""Organize Herbie objects into a DataFrame.
#? Why is this inefficient? Takes several seconds to display because the __str__ does a lot.
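
        Examples
        --------
        Hypothetical usage, given a constructed ``FastHerbie`` object:

        >>> FH.df()  # doctest: +SKIP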
"""
        rows = [
            self.objects[x : x + len(self.fxx)]
            for x in range(0, len(self.objects), len(self.fxx))
        ]
        return pd.DataFrame(
            rows, index=self.DATES, columns=[f"F{i:02d}" for i in self.fxx]
        )
def inventory(self, search: Optional[str] = None):
"""Get combined inventory DataFrame.
Useful for data discovery and checking your search before
doing a download.
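
        Examples
        --------
        A hypothetical search for 2-m temperature (the pattern is
        illustrative):

        >>> FH.inventory(search=":TMP:2 m")  # doctest: +SKIP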
"""
# NOTE: In my quick test, you don't gain much speed using multithreading here.
dfs = []
        for H in self.file_exists:
            df = H.inventory(search)
            df = df.assign(FILE=H.grib)
            dfs.append(df)
return pd.concat(dfs, ignore_index=True)
def download(
self, search: Optional[str] = None, *, max_threads: int = 20, **download_kwargs
) -> list[Path]:
r"""Download many Herbie objects.
Uses multithreading.
Parameters
----------
search : string
Regular expression string to specify which GRIB messages to
download.
**download_kwargs :
Any kwarg for Herbie's download method.
Benchmark
---------
Downloading 48 files with 1 variable (TMP:2 m)
- 1 thread took 1 min 17 s
- 2 threads took 36 s
- 5 threads took 28 s
- 10 threads took 25 s
- 50 threads took 23 s
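
        Examples
        --------
        Download only 2-m temperature (the search pattern is
        illustrative; requires network access):

        >>> FH.download(search=":TMP:2 m")  # doctest: +SKIP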
"""
###########################
# Multithread the downloads
        threads = min(len(self.file_exists), max_threads)
        log.info(f"🧵 Working on {len(self.file_exists)} tasks with {threads} threads.")
outFiles = []
with ThreadPoolExecutor(threads) as exe:
futures = [
exe.submit(H.download, search, **download_kwargs)
for H in self.file_exists
]
            # Collect the downloaded file paths in the order completed
for future in as_completed(futures):
if future.exception() is None:
outFiles.append(future.result())
else:
log.error(f"Exception has occured : {future.exception()}")
return outFiles
def xarray(
self,
        search: Optional[str] = None,
*,
max_threads: Optional[int] = None,
**xarray_kwargs,
) -> xr.Dataset:
"""Read many Herbie objects into an xarray Dataset.
# TODO: Sometimes the Jupyter Cell always crashes when I run this.
# TODO: "fatal flex scanner internal error--end of buffer missed"
Uses multithreading (or multiprocessing).
This would likely benefit from multiprocessing instead.
Parameters
----------
max_threads : int
Control the maximum number of threads to use.
If you use too many threads, you may run into memory limits.
Benchmark
---------
Opening 48 files with 1 variable (TMP:2 m)
- 1 thread took 1 min 45 s
- 2 threads took 55 s
- 5 threads took 39 s
- 10 threads took 39 s
- 50 threads took 37 s
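
        Examples
        --------
        Open 2-m temperature from every found file as a single Dataset
        (the search pattern is illustrative; requires network access):

        >>> ds = FH.xarray(":TMP:2 m", max_threads=10)  # doctest: +SKIP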
"""
xarray_kwargs = dict(search=search, **xarray_kwargs)
# NOTE: Multiprocessing does not seem to work because it looks
# NOTE: like xarray objects are not pickleable.
# NOTE: ``Reason: 'TypeError("cannot pickle '_thread.lock' object"``
if max_threads:
            #######################
            # Multithread the reads
            # ! Only works sometimes.
            # ! I get this error: "'EntryPoint' object has no attribute '_key'"
            threads = min(len(self.file_exists), max_threads)
            log.info(
                f"🧵 Working on {len(self.file_exists)} tasks with {threads} threads."
            )
            with ThreadPoolExecutor(threads) as exe:
futures = [
exe.submit(H.xarray, **xarray_kwargs) for H in self.file_exists
]
                # Collect the Datasets in the order completed
ds_list = [future.result() for future in as_completed(futures)]
else:
ds_list = [H.xarray(**xarray_kwargs) for H in self.file_exists]
        # Sort the Datasets, first by lead time (step), then by run time (time)
ds_list.sort(key=lambda x: x.step.data.max())
ds_list.sort(key=lambda x: x.time.data.max())
# Reshape list with dimensions (len(DATES), len(fxx))
ds_list = [
ds_list[x : x + len(self.fxx)]
for x in range(0, len(ds_list), len(self.fxx))
]
        # Concatenate the Datasets
try:
ds = xr.combine_nested(
ds_list,
concat_dim=["time", "step"],
combine_attrs="drop_conflicts",
)
except Exception:
            # TODO: I'm not sure why some cases don't like the combine_attrs argument.
ds = xr.combine_nested(
ds_list,
concat_dim=["time", "step"],
)
ds = ds.squeeze()
return ds