From dcc9340e5adb44a5a25f5573b833ee9c0e9516ff Mon Sep 17 00:00:00 2001 From: Jakub Novak Date: Fri, 13 Oct 2023 14:52:13 -0700 Subject: [PATCH] Change artifact folder --- packages/python-sdk/e2b/templates/data_analysis.py | 10 +++++----- .../tests/templates/test_data_analysis.py | 13 +++++++------ 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/packages/python-sdk/e2b/templates/data_analysis.py b/packages/python-sdk/e2b/templates/data_analysis.py index c58bff087..33b5583e1 100644 --- a/packages/python-sdk/e2b/templates/data_analysis.py +++ b/packages/python-sdk/e2b/templates/data_analysis.py @@ -16,6 +16,9 @@ def __init__(self, **data: Any): super().__init__(**data) self._session = data["_session"] + def __hash__(self): + return hash(self.name) + def read(self) -> bytes: return self._session.download_file(self.name) @@ -72,7 +75,7 @@ def register_artifacts(event: Any): except Exception as e: logger.error("Failed to process artifact", exc_info=e) - watcher = self.filesystem.watch_dir("/tmp") + watcher = self.filesystem.watch_dir("/home/user/artifacts") watcher.add_event_listener(register_artifacts) watcher.start() @@ -86,10 +89,7 @@ def register_artifacts(event: Any): watcher.stop() - artifacts = list( - map(lambda artifact: Artifact(name=artifact, _session=self), artifacts) - ) - return process.output.stdout, process.output.stderr, artifacts + return process.output.stdout, process.output.stderr, list(artifacts) def install_python_package(self, package_names: Union[str, List[str]]): if isinstance(package_names, list): diff --git a/packages/python-sdk/tests/templates/test_data_analysis.py b/packages/python-sdk/tests/templates/test_data_analysis.py index d49f7a525..72e7ee031 100644 --- a/packages/python-sdk/tests/templates/test_data_analysis.py +++ b/packages/python-sdk/tests/templates/test_data_analysis.py @@ -3,14 +3,15 @@ def test_create_graph(): s = DataAnalysis() - _, _, artifacts = s.run_python( + a, b, artifacts = s.run_python( """ - import matplotlib.pyplot as plt - - plt.plot([1, 2, 3, 4]) - plt.ylabel('some numbers') - plt.show() +import matplotlib.pyplot as plt + +plt.plot([1, 2, 3, 4]) +plt.ylabel('some numbers') +plt.show() """ ) + print(a, b, artifacts) s.close() assert len(artifacts) == 1