From 932de7915472f186be0bfcbd715e3d5cf1a91615 Mon Sep 17 00:00:00 2001
From: tofarr
Date: Thu, 7 Nov 2024 10:24:30 -0700
Subject: [PATCH] Fix: Buffering zip downloads to files rather than holding in memory (#4802)

---
 openhands/runtime/base.py                     |  5 +++--
 .../impl/eventstream/eventstream_runtime.py   | 13 +++++++++----
 .../runtime/impl/remote/remote_runtime.py     | 13 ++++++++++---
 openhands/server/listen.py                    | 19 ++++++++++---------
 tests/runtime/test_bash.py                    |  7 +++++--
 5 files changed, 37 insertions(+), 20 deletions(-)

diff --git a/openhands/runtime/base.py b/openhands/runtime/base.py
index 94dfeb3f5b5d..076732a4636c 100644
--- a/openhands/runtime/base.py
+++ b/openhands/runtime/base.py
@@ -3,6 +3,7 @@
 import json
 import os
 from abc import abstractmethod
+from pathlib import Path
 from typing import Callable
 
 from requests.exceptions import ConnectionError
@@ -274,6 +275,6 @@ def list_files(self, path: str | None = None) -> list[str]:
         raise NotImplementedError('This method is not implemented in the base class.')
 
     @abstractmethod
-    def copy_from(self, path: str) -> bytes:
-        """Zip all files in the sandbox and return as a stream of bytes."""
+    def copy_from(self, path: str) -> Path:
+        """Zip all files in the sandbox and return a path in the local filesystem."""
         raise NotImplementedError('This method is not implemented in the base class.')
diff --git a/openhands/runtime/impl/eventstream/eventstream_runtime.py b/openhands/runtime/impl/eventstream/eventstream_runtime.py
index e90fb7680b2e..be05c767f544 100644
--- a/openhands/runtime/impl/eventstream/eventstream_runtime.py
+++ b/openhands/runtime/impl/eventstream/eventstream_runtime.py
@@ -1,4 +1,5 @@
 import os
+from pathlib import Path
 import tempfile
 import threading
 from functools import lru_cache
@@ -604,7 +605,7 @@ def list_files(self, path: str | None = None) -> list[str]:
         except requests.Timeout:
             raise TimeoutError('List files operation timed out')
 
-    def copy_from(self, path: str) -> bytes:
-        """Zip all files in the sandbox and return as a stream of bytes."""
+    def copy_from(self, path: str) -> Path:
+        """Zip all files in the sandbox and return a path in the local filesystem."""
         self._refresh_logs()
         try:
@@ -617,8 +618,12 @@ def copy_from(self, path: str) -> bytes:
                 stream=True,
                 timeout=30,
             )
-            data = response.content
-            return data
+            # Close the handle so all chunks are flushed before the path is used
+            with tempfile.NamedTemporaryFile(delete=False) as temp_file:
+                for chunk in response.iter_content(chunk_size=8192):
+                    if chunk:  # filter out keep-alive new chunks
+                        temp_file.write(chunk)
+            return Path(temp_file.name)
         except requests.Timeout:
             raise TimeoutError('Copy operation timed out')
 
diff --git a/openhands/runtime/impl/remote/remote_runtime.py b/openhands/runtime/impl/remote/remote_runtime.py
index e74d4305be8a..7c2badfb19b1 100644
--- a/openhands/runtime/impl/remote/remote_runtime.py
+++ b/openhands/runtime/impl/remote/remote_runtime.py
@@ -1,4 +1,5 @@
 import os
+from pathlib import Path
 import tempfile
 import threading
 from typing import Callable, Optional
@@ -467,13 +468,19 @@ def list_files(self, path: str | None = None) -> list[str]:
         assert isinstance(response_json, list)
         return response_json
 
-    def copy_from(self, path: str) -> bytes:
-        """Zip all files in the sandbox and return as a stream of bytes."""
+    def copy_from(self, path: str) -> Path:
+        """Zip all files in the sandbox and return a path in the local filesystem."""
         params = {'path': path}
         response = self._send_request(
             'GET',
             f'{self.runtime_url}/download_files',
             params=params,
+            stream=True,
             timeout=30,
         )
-        return response.content
+        # Close the handle so all chunks are flushed before the path is used
+        with tempfile.NamedTemporaryFile(delete=False) as temp_file:
+            for chunk in response.iter_content(chunk_size=8192):
+                if chunk:  # filter out keep-alive new chunks
+                    temp_file.write(chunk)
+        return Path(temp_file.name)
diff --git a/openhands/server/listen.py b/openhands/server/listen.py
index 7ccc7046594a..c679be2ded0f 100644
--- a/openhands/server/listen.py
+++ b/openhands/server/listen.py
@@ -1,5 +1,4 @@
 import asyncio
-import io
 import os
 import re
 import tempfile
@@ -27,6 +26,7 @@
 
 from dotenv import load_dotenv
 from fastapi import (
+    BackgroundTasks,
     FastAPI,
     HTTPException,
     Request,
@@ -34,7 +34,7 @@
     WebSocket,
     status,
 )
-from fastapi.responses import JSONResponse, StreamingResponse
+from fastapi.responses import FileResponse, JSONResponse
 from fastapi.security import HTTPBearer
 from fastapi.staticfiles import StaticFiles
 from pydantic import BaseModel
@@ -790,20 +790,21 @@ async def security_api(request: Request):
 
 
 @app.get('/api/zip-directory')
-async def zip_current_workspace(request: Request):
+async def zip_current_workspace(request: Request, background_tasks: BackgroundTasks):
     try:
         logger.debug('Zipping workspace')
         runtime: Runtime = request.state.conversation.runtime
-
         path = runtime.config.workspace_mount_path_in_sandbox
-        zip_file_bytes = await call_sync_from_async(runtime.copy_from, path)
-        zip_stream = io.BytesIO(zip_file_bytes)  # Wrap to behave like a file stream
-        response = StreamingResponse(
-            zip_stream,
+        zip_file = await call_sync_from_async(runtime.copy_from, path)
+        response = FileResponse(
+            path=zip_file,
+            filename='workspace.zip',
             media_type='application/x-zip-compressed',
-            headers={'Content-Disposition': 'attachment; filename=workspace.zip'},
         )
 
+        # This will execute after the response is sent (So the file is not deleted before being sent)
+        background_tasks.add_task(zip_file.unlink)
+
         return response
     except Exception as e:
         logger.error(f'Error zipping workspace: {e}', exc_info=True)
diff --git a/tests/runtime/test_bash.py b/tests/runtime/test_bash.py
index 3673dd927c68..f8ff95d9a6b2 100644
--- a/tests/runtime/test_bash.py
+++ b/tests/runtime/test_bash.py
@@ -1,6 +1,7 @@
 """Bash-related tests for the EventStreamRuntime, which connects to the ActionExecutor running in the sandbox."""
 
 import os
+from pathlib import Path
 
 import pytest
 from conftest import (
@@ -586,8 +587,10 @@ def test_copy_from_directory(temp_dir, runtime_cls):
         path_to_copy_from = f'{sandbox_dir}/test_dir'
         result = runtime.copy_from(path=path_to_copy_from)
 
-        # Result is returned in bytes
-        assert isinstance(result, bytes)
+        # Result is returned as a path
+        assert isinstance(result, Path)
+
+        result.unlink()
     finally:
         _close_test_runtime(runtime)
 