Skip to content

Commit

Permalink
Implement sliced downloads in GSClient (#389) (#391)
Browse files Browse the repository at this point in the history
* WIP: Implement sliced downloads in GSClient (#389)

* feat: Implement sliced downloads in GSClient

* fix: remove unintended import changes

* Mock transfer_manager. Test both worker types.

* Update HISTORY.md

---------

Co-authored-by: Joe O'Connor <[email protected]>
  • Loading branch information
pjbull and joconnor-ecaa authored Dec 29, 2023
1 parent 6bce0f9 commit 9936a10
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 1 deletion.
3 changes: 3 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# cloudpathlib Changelog

## UNRELEASED
- Implement sliced downloads in GSClient. (Issue [#387](https://github.com/drivendataorg/cloudpathlib/issues/387), PR [#389](https://github.com/drivendataorg/cloudpathlib/pull/389))

## 0.17.0 (2023-12-21)

- Fix `S3Client` cleanup via `Client.__del__` when `S3Client` encounters an exception during initialization. (Issue [#372](https://github.com/drivendataorg/cloudpathlib/issues/372), PR [#373](https://github.com/drivendataorg/cloudpathlib/pull/373), thanks to [@bryanwweber](https://github.com/bryanwweber))
Expand Down
15 changes: 14 additions & 1 deletion cloudpathlib/gs/gsclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from google.api_core.exceptions import NotFound
from google.auth.exceptions import DefaultCredentialsError
from google.cloud.storage import Client as StorageClient
from google.cloud.storage import transfer_manager


except ModuleNotFoundError:
Expand All @@ -39,6 +40,7 @@ def __init__(
file_cache_mode: Optional[Union[str, FileCacheMode]] = None,
local_cache_dir: Optional[Union[str, os.PathLike]] = None,
content_type_method: Optional[Callable] = mimetypes.guess_type,
download_chunks_concurrently_kwargs: Optional[Dict[str, Any]] = None,
):
"""Class constructor. Sets up a [`Storage
Client`](https://googleapis.dev/python/storage/latest/client.html).
Expand Down Expand Up @@ -76,6 +78,9 @@ def __init__(
the `CLOUDPATHLIB_LOCAL_CACHE_DIR` environment variable.
content_type_method (Optional[Callable]): Function to call to guess media type (mimetype) when
writing a file to the cloud. Defaults to `mimetypes.guess_type`. Must return a tuple (content type, content encoding).
download_chunks_concurrently_kwargs (Optional[Dict[str, Any]]): Keyword arguments to pass to
[`download_chunks_concurrently`](https://cloud.google.com/python/docs/reference/storage/latest/google.cloud.storage.transfer_manager#google_cloud_storage_transfer_manager_download_chunks_concurrently)
for sliced parallel downloads.
"""
if application_credentials is None:
application_credentials = os.getenv("GOOGLE_APPLICATION_CREDENTIALS")
Expand All @@ -92,6 +97,8 @@ def __init__(
except DefaultCredentialsError:
self.client = StorageClient.create_anonymous_client()

self.download_chunks_concurrently_kwargs = download_chunks_concurrently_kwargs

super().__init__(
local_cache_dir=local_cache_dir,
content_type_method=content_type_method,
Expand All @@ -118,7 +125,13 @@ def _download_file(self, cloud_path: GSPath, local_path: Union[str, os.PathLike]

local_path = Path(local_path)

blob.download_to_filename(local_path)
if self.download_chunks_concurrently_kwargs is not None:
transfer_manager.download_chunks_concurrently(
blob, local_path, **self.download_chunks_concurrently_kwargs
)
else:
blob.download_to_filename(local_path)

return local_path

def _is_file_or_dir(self, cloud_path: GSPath) -> Optional[str]:
Expand Down
6 changes: 6 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from .mock_clients.mock_gs import (
mocked_client_class_factory as mocked_gsclient_class_factory,
DEFAULT_GS_BUCKET_NAME,
MockTransferManager,
)
from .mock_clients.mock_s3 import mocked_session_class_factory, DEFAULT_S3_BUCKET_NAME

Expand Down Expand Up @@ -184,6 +185,11 @@ def gs_rig(request, monkeypatch, assets_dir):
"StorageClient",
mocked_gsclient_class_factory(test_dir),
)
monkeypatch.setattr(
cloudpathlib.gs.gsclient,
"transfer_manager",
MockTransferManager,
)

rig = CloudProviderTestRig(
path_class=GSPath,
Expand Down
31 changes: 31 additions & 0 deletions tests/mock_clients/mock_gs.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,21 @@ def patch(self):
if "updated" in self.metadata:
(self.bucket / self.name).touch()

def reload(
self,
client=None,
projection="noAcl",
if_etag_match=None,
if_etag_not_match=None,
if_generation_match=None,
if_generation_not_match=None,
if_metageneration_match=None,
if_metageneration_not_match=None,
timeout=None,
retry=None,
):
pass

def upload_from_filename(self, filename, content_type=None):
data = Path(filename).read_bytes()
path = self.bucket / self.name
Expand Down Expand Up @@ -153,3 +168,19 @@ def __next__(self):
@property
def prefixes(self):
return self.sub_directories


class MockTransferManager:
@staticmethod
def download_chunks_concurrently(
blob,
filename,
chunk_size=32 * 1024 * 1024,
download_kwargs=None,
deadline=None,
worker_type="process",
max_workers=8,
*,
crc32c_checksum=True,
):
blob.download_to_filename(filename)
10 changes: 10 additions & 0 deletions tests/test_gs_specific.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,13 @@ def test_gspath_properties(path_class):
p2 = path_class("gs://mybucket/")
assert p2.blob == ""
assert p2.bucket == "mybucket"


@pytest.mark.parametrize("worker_type", ["process", "thread"])
def test_concurrent_download(gs_rig, tmp_path, worker_type):
client = gs_rig.client_class(download_chunks_concurrently_kwargs={"worker_type": worker_type})
p = gs_rig.create_cloud_path("dir_0/file0_0.txt", client=client)
dl_dir = tmp_path
assert not (dl_dir / p.name).exists()
p.download_to(dl_dir)
assert (dl_dir / p.name).is_file()

0 comments on commit 9936a10

Please sign in to comment.