Cleaning
lbesnard committed Jan 13, 2025
1 parent adcb70c commit f9f77ea
Showing 2 changed files with 102 additions and 107 deletions.
191 changes: 84 additions & 107 deletions aodn_cloud_optimised/lib/GenericZarrHandler.py
@@ -22,6 +22,53 @@
)


def check_variable_values_dask(
file_path, reference_values, variable_name, dataset_config, uuid_log
):
"""
Check if the values of a specified variable in a single file are consistent with the reference.
Args:
file_path (str): File path to check.
reference_values (np.ndarray): Reference values for the variable.
variable_name (str): Name of the variable to check.
dataset_config (dict): Dataset configuration; used to retrieve the logger name.
uuid_log (str): UUID string prefixed to log messages for traceability.
Returns:
tuple: (file_path, bool) where bool is True if the file is problematic.
Note:
This function cannot be a method of the class below; otherwise `self` would have to be
serialized when the function is submitted to the Dask client as a future.
"""
logger_name = dataset_config.get("logger_name", "generic")
logger = get_logger(logger_name)
try:
ds = xr.open_dataset(file_path)
variable_values = ds[variable_name].values
ds.close()

logger.debug(
f"{uuid_log}: {file_path} checking {variable_name} is consistent with the reference values."
)

logger.debug(f"{uuid_log}: reference values\n{reference_values}.")

res = not np.array_equal(variable_values, reference_values)

if res:
logger.error(
f"{uuid_log}: {file_path} - {variable_name} is NOT consistent with the reference values."
)
logger.error(f"{uuid_log}: variable values\n{variable_values}.")
else:
logger.debug(f"{uuid_log}: variable values\n{variable_values}.")

# res is True when the file's values differ from the reference values
return file_path, res
except Exception as e:
logger.error(f"{uuid_log}: Failed to open {file_path}: {e}")
return file_path, True
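
The note in the docstring above is the reason this helper sits at module level rather than on the handler class: submitting a bound method to a Dask client forces the whole instance to be pickled along with it, whereas a plain function only ships its explicit arguments. A minimal, self-contained sketch of the pattern (the Handler class and check_values names are illustrative, not part of this repository):

import numpy as np
import xarray as xr
from dask.distributed import Client


def check_values(file_path, reference_values, variable_name):
    # Module-level function: only the explicit arguments travel to the worker.
    ds = xr.open_dataset(file_path)
    try:
        return file_path, not np.array_equal(ds[variable_name].values, reference_values)
    finally:
        ds.close()


class Handler:
    def __init__(self, client: Client):
        # The instance holds a client (and typically loggers, filesystems, ...)
        # that Dask should not have to serialize.
        self.client = client

    def run(self, file_paths, reference_values, variable_name):
        # Mapping the plain function avoids pickling `self` altogether;
        # client.map(self.some_method, ...) would drag the whole instance along.
        futures = self.client.map(
            check_values,
            file_paths,
            reference_values=reference_values,
            variable_name=variable_name,
        )
        return self.client.gather(futures)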


def preprocess_xarray(ds, dataset_config):
"""
Perform preprocessing on the input dataset (`ds`) and return an xarray Dataset.
@@ -306,7 +353,7 @@ def publish_cloud_optimised_fileset_batch(self, s3_file_uri_list):
ds = self.try_open_dataset(
batch_files,
partial_preprocess,
drop_vars_list, # , engine="h5netcdf"
drop_vars_list,
)
partial_preprocess_already_run = True

@@ -337,62 +384,6 @@ def publish_cloud_optimised_fileset_batch(self, s3_file_uri_list):
if "ds" in locals():
self.postprocess(ds)

# def try_open_dataset(self, batch_files, partial_preprocess, drop_vars_list, engine="h5netcdf"):
# try:
# # First attempt: Try using the default engine (h5netcdf)
# return self._open_mfds(partial_preprocess, drop_vars_list, batch_files, engine=engine)
#
# except (ValueError, TypeError):
# # if issue, first we check that it's an issue with the data, and not the wrong engine.
# tb = traceback.format_exc()
# match_grid_not_consistent = re.search(
# r"Coordinate variable (\w+) is neither monotonically increasing nor monotonically decreasing on all datasets",
# tb,
# )
# match_not_netcdf4_signature = re.search(r"is not the signature of a valid netCDF4 file", tb)
# match_not_netcdf3_signature = re.search(r"is not a valid NetCDF 3 file", tb)
#
# if match_grid_not_consistent:
# variable_name = match_grid_not_consistent.group(1)
#
# # Handle coordinate variable issue and retry with a clean batch
# return self.handle_coordinate_variable_issue(batch_files, variable_name, partial_preprocess, drop_vars_list,
# engine)
# elif match_not_netcdf4_signature:
# # we recall the function but with scipy!
# try:
# self.try_open_dataset_scipy_fallback(batch_files, partial_preprocess, drop_vars_list)
#
# # if the above succeeds, it returns a ds even if there was an issue with grid inconsistency.
# # otherwise, we're doing a fallback
# except Exception:
# return self.handle_multi_engine_fallback(
# batch_files, partial_preprocess, drop_vars_list
# )
#
# def try_open_dataset_scipy_fallback(self, batch_files, partial_preprocess, drop_vars_list, engine="scipy"):
# try:
# return self._open_mfds(partial_preprocess, drop_vars_list, batch_files, engine=engine)
#
# except (ValueError, TypeError):
# # if issue, first we check that it's an issue with the data, and not the wrong engine.
# tb = traceback.format_exc()
# match_grid_not_consistent = re.search(
# r"Coordinate variable (\w+) is neither monotonically increasing nor monotonically decreasing on all datasets",
# tb,
# )
# match_not_netcdf4_signature = re.search(r"is not the signature of a valid netCDF4 file", tb)
# match_not_netcdf3_signature = re.search(r"is not a valid NetCDF 3 file", tb)
#
# if match_grid_not_consistent:
# variable_name = match_grid_not_consistent.group(1)
#
# # Handle coordinate variable issue and retry with a clean batch
# return self.handle_coordinate_variable_issue(batch_files, variable_name, partial_preprocess, drop_vars_list,
# engine)
# elif match_not_netcdf3_signature:
# raise ValueError

def try_open_dataset(
self,
batch_files,
@@ -473,6 +464,10 @@ def handle_coordinate_variable_issue(
self.logger.warning(
f"{self.uuid_log}: Processing batch without problematic files"
)
self.logger.info(
f"{self.uuid_log}: Processing the following files:\n{clean_batch_files}"
)

return self._open_mfds(
partial_preprocess, drop_vars_list, clean_batch_files, engine
)
@@ -511,29 +506,6 @@ def fallback_to_individual_processing(
f"{self.uuid_log}: An unexpected error occurred during fallback processing: {e}.\n {traceback.format_exc()}"
)

def check_variable_values_dask(self, file_path, reference_values, variable_name):
"""
Check if the values of a specified variable in a single file are consistent with the reference.
Args:
file_path (str): File path to check.
reference_values (np.ndarray): Reference values for the variable.
variable_name (str): Name of the variable to check.
Returns:
tuple: (file_path, bool) where bool indicates if the file is problematic.
"""
try:
ds = xr.open_dataset(file_path)
variable_values = ds[variable_name].values
ds.close()

# Check if the values are identical
return file_path, not np.array_equal(variable_values, reference_values)
except Exception as e:
self.logger.error(f"{self.uuid_log}: Failed to open {file_path}: {e}")
return file_path, True

def check_variable_values_parallel(self, file_paths, variable_name):
"""
Check the values of a specified variable in all files using a Coiled cluster.
@@ -556,27 +528,34 @@ def check_variable_values_parallel(self, file_paths, variable_name):
)
return file_paths # If the first file fails, consider all files problematic

# import ipdb; ipdb.set_trace()

# Use Dask to process files in parallel
# futures = self.client.map(
# self.check_variable_values_dask, file_paths[1:], reference_values=reference_values, variable_name=variable_name
# )
#
# results = self.client.gather(futures)
# future = self.client.submit(
# self.check_variable_values_dask, file_paths[1], reference_values=reference_values,
# variable_name=variable_name
# )
# results = future.result()
results = [
self.check_variable_values_dask(
file_path,
if self.cluster_mode:
# Use Dask to process files in parallel
futures = self.client.map(
check_variable_values_dask,
file_paths[1:],
reference_values=reference_values,
variable_name=variable_name,
dataset_config=self.dataset_config,
uuid_log=self.uuid_log,
)
for file_path in file_paths[1:]
]
results = self.client.gather(futures)
else:
# TODO: not running on a cluster. In that case files are most likely processed one by one, so this
# check will break anyway because there are no reference values. Hopefully this won't happen. An
# alternative would be to take the reference values directly from the existing Zarr store, but that
# only works if the store already exists; too annoying/complicated to implement for a case that
# shouldn't happen, and easier to debug if it does (a commented sketch of that alternative follows
# the else branch below).
results = [
check_variable_values_dask(
file_path,
reference_values,
variable_name,
self.dataset_config,
self.uuid_log,
)
for file_path in file_paths[1:]
]
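
# Hypothetical sketch (not implemented) of the alternative mentioned in the TODO above:
# read the reference values back from the already-written Zarr store rather than from the
# first file of the batch. `existing_zarr_path` is a placeholder, and this only works once
# the store exists.
#
#   reference_ds = xr.open_zarr(existing_zarr_path, consolidated=True)
#   reference_values = reference_ds[variable_name].values
#   reference_ds.close()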

# Collect problematic files
problematic_files = [
file_path for file_path, is_problematic in results if is_problematic
Expand Down Expand Up @@ -665,7 +644,6 @@ def _handle_duplicate_regions(
# TODO:
# compute() was added as unittests failed on github, but not locally. related to
# https://github.com/pydata/xarray/issues/5219
# import ipdb;ipdb.set_trace()
ds.isel(**{time_dimension_name: indexes}).drop_vars(
self.vars_to_drop_no_common_dimension, errors="ignore"
).pad(**{time_dimension_name: (0, amount_to_pad)}).to_zarr(
@@ -722,23 +700,22 @@ def _open_file_with_fallback(self, file, partial_preprocess, drop_vars_list):
Exception: Propagates any exceptions raised by the dataset opening operations.
"""
try:
engine = "scipy"
with self.s3_fs.open(file, "rb") as f: # Open the file-like object
ds = self._open_ds(
f, partial_preprocess, drop_vars_list, engine="scipy"
)
ds = self._open_ds(f, partial_preprocess, drop_vars_list, engine=engine)
self.logger.info(
f"{self.uuid_log}: Success opening {file} with scipy engine."
f"{self.uuid_log}: Success opening {file} with {engine} engine."
)
return ds
except (ValueError, TypeError) as e:
self.logger.info(
f"{self.uuid_log}: Error opening {file}: {e} with scipy engine. Defaulting to h5netcdf"
)
ds = self._open_ds(
file, partial_preprocess, drop_vars_list, engine="h5netcdf"
)
engine = "h5netcdf"

ds = self._open_ds(file, partial_preprocess, drop_vars_list, engine=engine)
self.logger.info(
f"{self.uuid_log}: Success opening {file} with h5netcdf engine."
f"{self.uuid_log}: Success opening {file} with {engine} engine."
)
return ds

18 changes: 18 additions & 0 deletions test_aodn_cloud_optimised/test_generic_zarr_handler.py
@@ -297,6 +297,24 @@ def test_zarr_nc_acorn_nwa_handler(self):
any("Writing data to a new Zarr dataset" in log for log in captured_logs)
)

# read zarr
dataset_config = load_dataset_config(DATASET_CONFIG_NC_ACORN_NWA_JSON)
dataset_name = dataset_config["dataset_name"]
dname = f"s3://{self.BUCKET_OPTIMISED_NAME}/{self.ROOT_PREFIX_CLOUD_OPTIMISED_PATH}/{dataset_name}.zarr/"

ds = xr.open_zarr(self.s3_fs.get_mapper(dname), consolidated=True)
self.assertEqual(ds.UCUR.standard_name, "eastward_sea_water_velocity")

expected = np.array(
["2022-03-12T00:29:59.999993088", "2022-03-12T01:30:00.000000000"],
dtype="datetime64[ns]",
)
assert np.allclose(
ds.TIME.values.astype("datetime64[ns]").astype(float),
expected.astype(float),
atol=1e-9,
), f"TIME values are not as expected: {expected}"


if __name__ == "__main__":
unittest.main()
