From ba8255a87eae48adcd3e1d6060aabae35b3521b8 Mon Sep 17 00:00:00 2001 From: Deepyaman Datta Date: Thu, 27 Jun 2024 11:55:07 -0600 Subject: [PATCH] Auto-convert from `_load`/`_save` to `load`/`save` Signed-off-by: Deepyaman Datta --- kedro/io/cached_dataset.py | 8 +-- kedro/io/core.py | 109 ++++++++++++------------------ kedro/io/lambda_dataset.py | 8 +-- kedro/io/memory_dataset.py | 10 +-- kedro/io/shared_memory_dataset.py | 8 +-- tests/io/test_core.py | 20 ------ 6 files changed, 52 insertions(+), 111 deletions(-) diff --git a/kedro/io/cached_dataset.py b/kedro/io/cached_dataset.py index 1ef7450c28..2dccba40fe 100644 --- a/kedro/io/cached_dataset.py +++ b/kedro/io/cached_dataset.py @@ -98,7 +98,7 @@ def _describe(self) -> dict[str, Any]: "cache": self._cache._describe(), } - def _load(self) -> Any: + def load(self) -> Any: data = self._cache.load() if self._cache.exists() else self._dataset.load() if not self._cache.exists(): @@ -106,14 +106,10 @@ def _load(self) -> Any: return data - load = _load - - def _save(self, data: Any) -> None: + def save(self, data: Any) -> None: self._dataset.save(data) self._cache.save(data) - save = _save - def _exists(self) -> bool: return self._cache.exists() or self._dataset.exists() diff --git a/kedro/io/core.py b/kedro/io/core.py index 7d11cece19..c83edd569a 100644 --- a/kedro/io/core.py +++ b/kedro/io/core.py @@ -178,6 +178,33 @@ def from_config( def _logger(self) -> logging.Logger: return logging.getLogger(__name__) + def __str__(self) -> str: + def _to_str(obj: Any, is_root: bool = False) -> str: + """Returns a string representation where + 1. The root level (i.e. the Dataset.__init__ arguments) are + formatted like Dataset(key=value). + 2. Dictionaries have the keys alphabetically sorted recursively. + 3. None values are not shown. + """ + + fmt = "{}={}" if is_root else "'{}': {}" # 1 + + if isinstance(obj, dict): + sorted_dict = sorted(obj.items(), key=lambda pair: str(pair[0])) # 2 + + text = ", ".join( + fmt.format(key, _to_str(value)) # 2 + for key, value in sorted_dict + if value is not None # 3 + ) + + return text if is_root else "{" + text + "}" # 1 + + # not a dictionary + return str(obj) + + return f"{type(self).__name__}({_to_str(self._describe(), True)})" + @classmethod def _load_wrapper(cls, load_func: Callable[[Self], _DO]) -> Callable[[Self], _DO]: @wraps(load_func) @@ -228,6 +255,12 @@ def save(self: Self, data: _DI) -> None: def __init_subclass__(cls, **kwargs: Any) -> None: super().__init_subclass__(**kwargs) + if hasattr(cls, "_load") and not cls._load.__qualname__.startswith("Abstract"): + cls.load = cls._load # type: ignore[method-assign] + + if hasattr(cls, "_save") and not cls._save.__qualname__.startswith("Abstract"): + cls.save = cls._save # type: ignore[method-assign] + if hasattr(cls, "load") and not cls.load.__qualname__.startswith("Abstract"): cls.load = cls._load_wrapper( # type: ignore[assignment] cls.load @@ -242,6 +275,7 @@ def __init_subclass__(cls, **kwargs: Any) -> None: else cls.save.__wrapped__ # type: ignore[attr-defined] ) + @abc.abstractmethod def load(self) -> _DO: """Loads data by delegation to the provided load method. @@ -252,21 +286,12 @@ def load(self) -> _DO: DatasetError: When underlying load method raises error. """ + raise NotImplementedError( + f"'{self.__class__.__name__}' is a subclass of AbstractDataset and " + f"it must implement the 'load' method" + ) - self._logger.debug("Loading %s", str(self)) - - try: - return self._load() - except DatasetError: - raise - except Exception as exc: - # This exception handling is by design as the composed data sets - # can throw any type of exception. - message = ( - f"Failed while loading data from data set {str(self)}.\n{str(exc)}" - ) - raise DatasetError(message) from exc - + @abc.abstractmethod def save(self, data: _DI) -> None: """Saves data by delegation to the provided save method. @@ -277,59 +302,11 @@ def save(self, data: _DI) -> None: DatasetError: when underlying save method raises error. FileNotFoundError: when save method got file instead of dir, on Windows. NotADirectoryError: when save method got file instead of dir, on Unix. - """ - - if data is None: - raise DatasetError("Saving 'None' to a 'Dataset' is not allowed") - - try: - self._logger.debug("Saving %s", str(self)) - self._save(data) - except (DatasetError, FileNotFoundError, NotADirectoryError): - raise - except Exception as exc: - message = f"Failed while saving data to data set {str(self)}.\n{str(exc)}" - raise DatasetError(message) from exc - - def __str__(self) -> str: - def _to_str(obj: Any, is_root: bool = False) -> str: - """Returns a string representation where - 1. The root level (i.e. the Dataset.__init__ arguments) are - formatted like Dataset(key=value). - 2. Dictionaries have the keys alphabetically sorted recursively. - 3. None values are not shown. - """ - - fmt = "{}={}" if is_root else "'{}': {}" # 1 - - if isinstance(obj, dict): - sorted_dict = sorted(obj.items(), key=lambda pair: str(pair[0])) # 2 - text = ", ".join( - fmt.format(key, _to_str(value)) # 2 - for key, value in sorted_dict - if value is not None # 3 - ) - - return text if is_root else "{" + text + "}" # 1 - - # not a dictionary - return str(obj) - - return f"{type(self).__name__}({_to_str(self._describe(), True)})" - - @abc.abstractmethod - def _load(self) -> _DO: - raise NotImplementedError( - f"'{self.__class__.__name__}' is a subclass of AbstractDataset and " - f"it must implement the '_load' method" - ) - - @abc.abstractmethod - def _save(self, data: _DI) -> None: + """ raise NotImplementedError( f"'{self.__class__.__name__}' is a subclass of AbstractDataset and " - f"it must implement the '_save' method" + f"it must implement the 'save' method" ) @abc.abstractmethod @@ -682,7 +659,7 @@ def _get_versioned_path(self, version: str) -> PurePosixPath: return self._filepath / version / self._filepath.name def load(self) -> _DO: - return super().load() + return super().load() # type: ignore[safe-super] @classmethod def _save_wrapper( @@ -724,7 +701,7 @@ def save(self, data: _DI) -> None: self._version_cache.clear() save_version = self.resolve_save_version() # Make sure last save version is set try: - super().save(data) + super().save(data) # type: ignore[safe-super] except (FileNotFoundError, NotADirectoryError) as err: # FileNotFoundError raised in Win, NotADirectoryError raised in Unix _default_version = "YYYY-MM-DDThh.mm.ss.sssZ" diff --git a/kedro/io/lambda_dataset.py b/kedro/io/lambda_dataset.py index a1e79ad327..126bef6632 100644 --- a/kedro/io/lambda_dataset.py +++ b/kedro/io/lambda_dataset.py @@ -49,7 +49,7 @@ def _to_str(func: Any) -> str | None: return descr - def _load(self) -> Any: + def load(self) -> Any: if not self.__load: raise DatasetError( "Cannot load data set. No 'load' function " @@ -57,9 +57,7 @@ def _load(self) -> Any: ) return self.__load() - load = _load - - def _save(self, data: Any) -> None: + def save(self, data: Any) -> None: if not self.__save: raise DatasetError( "Cannot save to data set. No 'save' function " @@ -67,8 +65,6 @@ def _save(self, data: Any) -> None: ) self.__save(data) - save = _save - def _exists(self) -> bool: if not self.__exists: return super()._exists() diff --git a/kedro/io/memory_dataset.py b/kedro/io/memory_dataset.py index b25b8de348..dff84d5670 100644 --- a/kedro/io/memory_dataset.py +++ b/kedro/io/memory_dataset.py @@ -57,9 +57,9 @@ def __init__( self.metadata = metadata self._EPHEMERAL = True if data is not _EMPTY: - self._save(data) + self.save.__wrapped__(self, data) # type: ignore[attr-defined] - def _load(self) -> Any: + def load(self) -> Any: if self._data is _EMPTY: raise DatasetError("Data for MemoryDataset has not been saved yet.") @@ -67,14 +67,10 @@ def _load(self) -> Any: data = _copy_with_mode(self._data, copy_mode=copy_mode) return data - load = _load - - def _save(self, data: Any) -> None: + def save(self, data: Any) -> None: copy_mode = self._copy_mode or _infer_copy_mode(data) self._data = _copy_with_mode(data, copy_mode=copy_mode) - save = _save - def _exists(self) -> bool: return self._data is not _EMPTY diff --git a/kedro/io/shared_memory_dataset.py b/kedro/io/shared_memory_dataset.py index f08f56fe24..413c9c07a7 100644 --- a/kedro/io/shared_memory_dataset.py +++ b/kedro/io/shared_memory_dataset.py @@ -34,12 +34,10 @@ def __getattr__(self, name: str) -> Any: raise AttributeError() return getattr(self.shared_memory_dataset, name) # pragma: no cover - def _load(self) -> Any: + def load(self) -> Any: return self.shared_memory_dataset.load() - load = _load - - def _save(self, data: Any) -> None: + def save(self, data: Any) -> None: """Calls save method of a shared MemoryDataset in SyncManager.""" try: self.shared_memory_dataset.save(data) @@ -54,8 +52,6 @@ def _save(self, data: Any) -> None: ) from serialisation_exc raise exc # pragma: no cover - save = _save - def _describe(self) -> dict[str, Any]: """SharedMemoryDataset doesn't have any constructor argument to return.""" return {} diff --git a/tests/io/test_core.py b/tests/io/test_core.py index 1a16468bd9..d0bf8062c9 100644 --- a/tests/io/test_core.py +++ b/tests/io/test_core.py @@ -58,14 +58,10 @@ def _exists(self) -> bool: def _load(self): return pd.read_csv(self._filepath) - load = _load - def _save(self, data: str) -> None: with open(self._filepath, mode="w") as file: file.write(data) - save = _save - class MyVersionedDataset(AbstractVersionedDataset[str, str]): def __init__( # noqa: PLR0913 @@ -96,16 +92,12 @@ def _load(self) -> str: with self._fs.open(load_path, mode="r") as fs_file: return fs_file.read() - load = _load - def _save(self, data: str) -> None: save_path = get_filepath_str(self._get_save_path(), self._protocol) with self._fs.open(save_path, mode="w") as fs_file: fs_file.write(data) - save = _save - def _exists(self) -> bool: try: load_path = get_filepath_str(self._get_load_path(), self._protocol) @@ -143,16 +135,12 @@ def _load(self) -> str: with self._fs.open(load_path, mode="r") as fs_file: return fs_file.read() - load = _load - def _save(self, data: str) -> None: save_path = get_filepath_str(self._get_save_path(), self._protocol) with self._fs.open(save_path, mode="w") as fs_file: fs_file.write(data) - save = _save - def _exists(self) -> bool: load_path = get_filepath_str(self._get_load_path(), self._protocol) # no try catch - will return a VersionNotFoundError to be caught be AbstractVersionedDataset.exists() @@ -458,14 +446,10 @@ def _exists(self) -> bool: def _load(self): return pd.read_csv(self._filepath) - # load = _load - def _save(self, data: str) -> None: with open(self._filepath, mode="w") as file: file.write(data) - # save = _save - class MyLegacyVersionedDataset(AbstractVersionedDataset[str, str]): def __init__( # noqa: PLR0913 @@ -496,16 +480,12 @@ def _load(self) -> str: with self._fs.open(load_path, mode="r") as fs_file: return fs_file.read() - # load = _load - def _save(self, data: str) -> None: save_path = get_filepath_str(self._get_save_path(), self._protocol) with self._fs.open(save_path, mode="w") as fs_file: fs_file.write(data) - # save = _save - def _exists(self) -> bool: try: load_path = get_filepath_str(self._get_load_path(), self._protocol)