Skip to content

Commit

Permalink
Fix: Buffering zip downloads to files rather than holding in memory (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
tofarr authored Nov 7, 2024
1 parent fa625fe commit 932de79
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 18 deletions.
5 changes: 3 additions & 2 deletions openhands/runtime/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import json
import os
from abc import abstractmethod
from pathlib import Path
from typing import Callable

from requests.exceptions import ConnectionError
Expand Down Expand Up @@ -274,6 +275,6 @@ def list_files(self, path: str | None = None) -> list[str]:
raise NotImplementedError('This method is not implemented in the base class.')

@abstractmethod
def copy_from(self, path: str) -> bytes:
"""Zip all files in the sandbox and return as a stream of bytes."""
def copy_from(self, path: str) -> Path:
"""Zip all files in the sandbox and return a path in the local filesystem."""
raise NotImplementedError('This method is not implemented in the base class.')
10 changes: 7 additions & 3 deletions openhands/runtime/impl/eventstream/eventstream_runtime.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
from pathlib import Path
import tempfile
import threading
from functools import lru_cache
Expand Down Expand Up @@ -604,7 +605,7 @@ def list_files(self, path: str | None = None) -> list[str]:
except requests.Timeout:
raise TimeoutError('List files operation timed out')

def copy_from(self, path: str) -> bytes:
def copy_from(self, path: str) -> Path:
"""Zip all files in the sandbox and return as a stream of bytes."""
self._refresh_logs()
try:
Expand All @@ -617,8 +618,11 @@ def copy_from(self, path: str) -> bytes:
stream=True,
timeout=30,
)
data = response.content
return data
temp_file = tempfile.NamedTemporaryFile(delete=False)
for chunk in response.iter_content(chunk_size=8192):
if chunk: # filter out keep-alive new chunks
temp_file.write(chunk)
return Path(temp_file.name)
except requests.Timeout:
raise TimeoutError('Copy operation timed out')

Expand Down
10 changes: 8 additions & 2 deletions openhands/runtime/impl/remote/remote_runtime.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
from pathlib import Path
import tempfile
import threading
from typing import Callable, Optional
Expand Down Expand Up @@ -467,13 +468,18 @@ def list_files(self, path: str | None = None) -> list[str]:
assert isinstance(response_json, list)
return response_json

def copy_from(self, path: str) -> bytes:
def copy_from(self, path: str) -> Path:
"""Zip all files in the sandbox and return as a stream of bytes."""
params = {'path': path}
response = self._send_request(
'GET',
f'{self.runtime_url}/download_files',
params=params,
stream=True,
timeout=30,
)
return response.content
temp_file = tempfile.NamedTemporaryFile(delete=False)
for chunk in response.iter_content(chunk_size=8192):
if chunk: # filter out keep-alive new chunks
temp_file.write(chunk)
return Path(temp_file.name)
19 changes: 10 additions & 9 deletions openhands/server/listen.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import asyncio
import io
import os
import re
import tempfile
Expand Down Expand Up @@ -27,14 +26,15 @@

from dotenv import load_dotenv
from fastapi import (
BackgroundTasks,
FastAPI,
HTTPException,
Request,
UploadFile,
WebSocket,
status,
)
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.responses import FileResponse, JSONResponse
from fastapi.security import HTTPBearer
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
Expand Down Expand Up @@ -790,20 +790,21 @@ async def security_api(request: Request):


@app.get('/api/zip-directory')
async def zip_current_workspace(request: Request):
async def zip_current_workspace(request: Request, background_tasks: BackgroundTasks):
try:
logger.debug('Zipping workspace')
runtime: Runtime = request.state.conversation.runtime

path = runtime.config.workspace_mount_path_in_sandbox
zip_file_bytes = await call_sync_from_async(runtime.copy_from, path)
zip_stream = io.BytesIO(zip_file_bytes) # Wrap to behave like a file stream
response = StreamingResponse(
zip_stream,
zip_file = await call_sync_from_async(runtime.copy_from, path)
response = FileResponse(
path=zip_file,
filename='workspace.zip',
media_type='application/x-zip-compressed',
headers={'Content-Disposition': 'attachment; filename=workspace.zip'},
)

# This will execute after the response is sent (So the file is not deleted before being sent)
background_tasks.add_task(zip_file.unlink)

return response
except Exception as e:
logger.error(f'Error zipping workspace: {e}', exc_info=True)
Expand Down
7 changes: 5 additions & 2 deletions tests/runtime/test_bash.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Bash-related tests for the EventStreamRuntime, which connects to the ActionExecutor running in the sandbox."""

import os
from pathlib import Path

import pytest
from conftest import (
Expand Down Expand Up @@ -586,8 +587,10 @@ def test_copy_from_directory(temp_dir, runtime_cls):
path_to_copy_from = f'{sandbox_dir}/test_dir'
result = runtime.copy_from(path=path_to_copy_from)

# Result is returned in bytes
assert isinstance(result, bytes)
# Result is returned as a path
assert isinstance(result, Path)

result.unlink()
finally:
_close_test_runtime(runtime)

Expand Down

0 comments on commit 932de79

Please sign in to comment.