Skip to content

Commit

Permalink
Auto-convert from _load/_save to load/save
Browse files Browse the repository at this point in the history
Signed-off-by: Deepyaman Datta <[email protected]>
  • Loading branch information
deepyaman committed Jun 27, 2024
1 parent c25bab3 commit ba8255a
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 111 deletions.
8 changes: 2 additions & 6 deletions kedro/io/cached_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,22 +98,18 @@ def _describe(self) -> dict[str, Any]:
"cache": self._cache._describe(),
}

def _load(self) -> Any:
def load(self) -> Any:
    """Load data, preferring the cache.

    Returns a cached copy when one exists; otherwise delegates to the
    wrapped dataset and populates the cache with the freshly loaded data.

    Returns:
        The loaded data.
    """
    # Probe the cache exactly once. The previous version called
    # ``self._cache.exists()`` a second time before saving, doing
    # redundant work on every cache miss.
    if self._cache.exists():
        return self._cache.load()

    data = self._dataset.load()
    self._cache.save(data)
    return data

load = _load

def _save(self, data: Any) -> None:
def save(self, data: Any) -> None:
self._dataset.save(data)
self._cache.save(data)

save = _save

def _exists(self) -> bool:
    # Data counts as present if either the cache or the wrapped dataset has it.
    return self._cache.exists() or self._dataset.exists()

Expand Down
109 changes: 43 additions & 66 deletions kedro/io/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,33 @@ def from_config(
def _logger(self) -> logging.Logger:
return logging.getLogger(__name__)

def __str__(self) -> str:
    """Return a one-line human-readable description built from
    ``self._describe()``, e.g. ``SomeDataset(filepath=data.csv)``."""

    def _to_str(obj: Any, is_root: bool = False) -> str:
        """Returns a string representation where
        1. The root level (i.e. the Dataset.__init__ arguments) are
        formatted like Dataset(key=value).
        2. Dictionaries have the keys alphabetically sorted recursively.
        3. None values are not shown.
        """

        fmt = "{}={}" if is_root else "'{}': {}"  # 1

        if isinstance(obj, dict):
            # Sort by the string form of the key so mixed key types compare.
            sorted_dict = sorted(obj.items(), key=lambda pair: str(pair[0]))  # 2

            text = ", ".join(
                fmt.format(key, _to_str(value))  # 2
                for key, value in sorted_dict
                if value is not None  # 3
            )

            return text if is_root else "{" + text + "}"  # 1

        # not a dictionary
        return str(obj)

    return f"{type(self).__name__}({_to_str(self._describe(), True)})"

@classmethod
def _load_wrapper(cls, load_func: Callable[[Self], _DO]) -> Callable[[Self], _DO]:
@wraps(load_func)
Expand Down Expand Up @@ -228,6 +255,12 @@ def save(self: Self, data: _DI) -> None:
def __init_subclass__(cls, **kwargs: Any) -> None:
super().__init_subclass__(**kwargs)

if hasattr(cls, "_load") and not cls._load.__qualname__.startswith("Abstract"):
cls.load = cls._load # type: ignore[method-assign]

if hasattr(cls, "_save") and not cls._save.__qualname__.startswith("Abstract"):
cls.save = cls._save # type: ignore[method-assign]

if hasattr(cls, "load") and not cls.load.__qualname__.startswith("Abstract"):
cls.load = cls._load_wrapper( # type: ignore[assignment]
cls.load
Expand All @@ -242,6 +275,7 @@ def __init_subclass__(cls, **kwargs: Any) -> None:
else cls.save.__wrapped__ # type: ignore[attr-defined]
)

@abc.abstractmethod
def load(self) -> _DO:
"""Loads data by delegation to the provided load method.
Expand All @@ -252,21 +286,12 @@ def load(self) -> _DO:
DatasetError: When underlying load method raises error.
"""
raise NotImplementedError(
f"'{self.__class__.__name__}' is a subclass of AbstractDataset and "
f"it must implement the 'load' method"
)

self._logger.debug("Loading %s", str(self))

try:
return self._load()
except DatasetError:
raise
except Exception as exc:
# This exception handling is by design as the composed data sets
# can throw any type of exception.
message = (
f"Failed while loading data from data set {str(self)}.\n{str(exc)}"
)
raise DatasetError(message) from exc

@abc.abstractmethod
def save(self, data: _DI) -> None:
"""Saves data by delegation to the provided save method.
Expand All @@ -277,59 +302,11 @@ def save(self, data: _DI) -> None:
DatasetError: when underlying save method raises error.
FileNotFoundError: when save method got file instead of dir, on Windows.
NotADirectoryError: when save method got file instead of dir, on Unix.
"""

if data is None:
raise DatasetError("Saving 'None' to a 'Dataset' is not allowed")

try:
self._logger.debug("Saving %s", str(self))
self._save(data)
except (DatasetError, FileNotFoundError, NotADirectoryError):
raise
except Exception as exc:
message = f"Failed while saving data to data set {str(self)}.\n{str(exc)}"
raise DatasetError(message) from exc

def __str__(self) -> str:
def _to_str(obj: Any, is_root: bool = False) -> str:
"""Returns a string representation where
1. The root level (i.e. the Dataset.__init__ arguments) are
formatted like Dataset(key=value).
2. Dictionaries have the keys alphabetically sorted recursively.
3. None values are not shown.
"""

fmt = "{}={}" if is_root else "'{}': {}" # 1

if isinstance(obj, dict):
sorted_dict = sorted(obj.items(), key=lambda pair: str(pair[0])) # 2
text = ", ".join(
fmt.format(key, _to_str(value)) # 2
for key, value in sorted_dict
if value is not None # 3
)

return text if is_root else "{" + text + "}" # 1

# not a dictionary
return str(obj)

return f"{type(self).__name__}({_to_str(self._describe(), True)})"

@abc.abstractmethod
def _load(self) -> _DO:
raise NotImplementedError(
f"'{self.__class__.__name__}' is a subclass of AbstractDataset and "
f"it must implement the '_load' method"
)

@abc.abstractmethod
def _save(self, data: _DI) -> None:
"""
raise NotImplementedError(
f"'{self.__class__.__name__}' is a subclass of AbstractDataset and "
f"it must implement the '_save' method"
f"it must implement the 'save' method"
)

@abc.abstractmethod
Expand Down Expand Up @@ -682,7 +659,7 @@ def _get_versioned_path(self, version: str) -> PurePosixPath:
return self._filepath / version / self._filepath.name

def load(self) -> _DO:
return super().load()
return super().load() # type: ignore[safe-super]

@classmethod
def _save_wrapper(
Expand Down Expand Up @@ -724,7 +701,7 @@ def save(self, data: _DI) -> None:
self._version_cache.clear()
save_version = self.resolve_save_version() # Make sure last save version is set
try:
super().save(data)
super().save(data) # type: ignore[safe-super]
except (FileNotFoundError, NotADirectoryError) as err:
# FileNotFoundError raised in Win, NotADirectoryError raised in Unix
_default_version = "YYYY-MM-DDThh.mm.ss.sssZ"
Expand Down
8 changes: 2 additions & 6 deletions kedro/io/lambda_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,26 +49,22 @@ def _to_str(func: Any) -> str | None:

return descr

def _load(self) -> Any:
def load(self) -> Any:
    """Invoke the user-supplied load callable and return its result.

    Raises:
        DatasetError: If no load function was provided at construction time.
    """
    if not self.__load:
        raise DatasetError(
            "Cannot load data set. No 'load' function "
            "provided when LambdaDataset was created."
        )
    return self.__load()

load = _load

def _save(self, data: Any) -> None:
def save(self, data: Any) -> None:
    """Pass ``data`` to the user-supplied save callable.

    Raises:
        DatasetError: If no save function was provided at construction time.
    """
    if not self.__save:
        raise DatasetError(
            "Cannot save to data set. No 'save' function "
            "provided when LambdaDataset was created."
        )
    self.__save(data)

save = _save

def _exists(self) -> bool:
if not self.__exists:
return super()._exists()
Expand Down
10 changes: 3 additions & 7 deletions kedro/io/memory_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,24 +57,20 @@ def __init__(
self.metadata = metadata
self._EPHEMERAL = True
if data is not _EMPTY:
self._save(data)
self.save.__wrapped__(self, data) # type: ignore[attr-defined]

def _load(self) -> Any:
def load(self) -> Any:
    """Return a copy of the in-memory data.

    Raises:
        DatasetError: If nothing has been saved yet (sentinel ``_EMPTY``).
    """
    if self._data is _EMPTY:
        raise DatasetError("Data for MemoryDataset has not been saved yet.")

    # Copy on the way out so callers cannot mutate the stored object.
    # Copy-mode semantics are defined by _copy_with_mode/_infer_copy_mode
    # (not visible here).
    copy_mode = self._copy_mode or _infer_copy_mode(self._data)
    data = _copy_with_mode(self._data, copy_mode=copy_mode)
    return data

load = _load

def _save(self, data: Any) -> None:
def save(self, data: Any) -> None:
    """Store a copy of ``data`` in memory (copied so later mutation of the
    caller's object does not change the stored value)."""
    copy_mode = self._copy_mode or _infer_copy_mode(data)
    self._data = _copy_with_mode(data, copy_mode=copy_mode)

save = _save

def _exists(self) -> bool:
    # Saved at least once iff the sentinel has been replaced.
    return self._data is not _EMPTY

Expand Down
8 changes: 2 additions & 6 deletions kedro/io/shared_memory_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,10 @@ def __getattr__(self, name: str) -> Any:
raise AttributeError()
return getattr(self.shared_memory_dataset, name) # pragma: no cover

def _load(self) -> Any:
def load(self) -> Any:
    """Delegate loading to the shared MemoryDataset held by the SyncManager."""
    return self.shared_memory_dataset.load()

load = _load

def _save(self, data: Any) -> None:
def save(self, data: Any) -> None:
"""Calls save method of a shared MemoryDataset in SyncManager."""
try:
self.shared_memory_dataset.save(data)
Expand All @@ -54,8 +52,6 @@ def _save(self, data: Any) -> None:
) from serialisation_exc
raise exc # pragma: no cover

save = _save

def _describe(self) -> dict[str, Any]:
    """SharedMemoryDataset doesn't have any constructor argument to return."""
    # Empty mapping keeps __str__ output minimal for this wrapper.
    return {}
20 changes: 0 additions & 20 deletions tests/io/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,10 @@ def _exists(self) -> bool:
def _load(self):
    # NOTE(review): legacy-style hook name (underscore-prefixed); presumably
    # here to exercise the _load -> load auto-conversion — confirm in tests.
    return pd.read_csv(self._filepath)

load = _load

def _save(self, data: str) -> None:
    # Write the raw string to the configured filepath (text mode, truncating).
    with open(self._filepath, mode="w") as file:
        file.write(data)

save = _save


class MyVersionedDataset(AbstractVersionedDataset[str, str]):
def __init__( # noqa: PLR0913
Expand Down Expand Up @@ -96,16 +92,12 @@ def _load(self) -> str:
with self._fs.open(load_path, mode="r") as fs_file:
return fs_file.read()

load = _load

def _save(self, data: str) -> None:
save_path = get_filepath_str(self._get_save_path(), self._protocol)

with self._fs.open(save_path, mode="w") as fs_file:
fs_file.write(data)

save = _save

def _exists(self) -> bool:
try:
load_path = get_filepath_str(self._get_load_path(), self._protocol)
Expand Down Expand Up @@ -143,16 +135,12 @@ def _load(self) -> str:
with self._fs.open(load_path, mode="r") as fs_file:
return fs_file.read()

load = _load

def _save(self, data: str) -> None:
save_path = get_filepath_str(self._get_save_path(), self._protocol)

with self._fs.open(save_path, mode="w") as fs_file:
fs_file.write(data)

save = _save

def _exists(self) -> bool:
load_path = get_filepath_str(self._get_load_path(), self._protocol)
# no try catch - will return a VersionNotFoundError to be caught be AbstractVersionedDataset.exists()
Expand Down Expand Up @@ -458,14 +446,10 @@ def _exists(self) -> bool:
def _load(self):
return pd.read_csv(self._filepath)

# load = _load

def _save(self, data: str) -> None:
with open(self._filepath, mode="w") as file:
file.write(data)

# save = _save


class MyLegacyVersionedDataset(AbstractVersionedDataset[str, str]):
def __init__( # noqa: PLR0913
Expand Down Expand Up @@ -496,16 +480,12 @@ def _load(self) -> str:
with self._fs.open(load_path, mode="r") as fs_file:
return fs_file.read()

# load = _load

def _save(self, data: str) -> None:
save_path = get_filepath_str(self._get_save_path(), self._protocol)

with self._fs.open(save_path, mode="w") as fs_file:
fs_file.write(data)

# save = _save

def _exists(self) -> bool:
try:
load_path = get_filepath_str(self._get_load_path(), self._protocol)
Expand Down

0 comments on commit ba8255a

Please sign in to comment.