Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GooglePathHandler for GCS input and output #40

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 1 addition & 14 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ install_dep: &install_dep
- run:
name: Install Dependencies
command: |
pip install --progress-bar off torch shapely flake8 flake8-bugbear flake8-comprehensions isort 'black @ git+https://github.com/psf/black@673327449f86fce558adde153bb6cbe54bfebad2'
pip install --progress-bar off torch shapely

install_fvcore: &install_fvcore
- run:
Expand Down Expand Up @@ -83,19 +83,6 @@ jobs:

- <<: *install_fvcore

- run:
name: isort
command: |
isort -c -sp .
- run:
name: black
command: |
black --check .
- run:
name: flake8
command: |
flake8 .

- <<: *run_unittests

- store_artifacts:
Expand Down
245 changes: 245 additions & 0 deletions fvcore/common/file_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@
import os
import shutil
import traceback
import types
from collections import OrderedDict
from typing import IO, Any, Callable, Dict, List, MutableMapping, Optional, Union
from urllib.parse import urlparse

import portalocker # type: ignore
from fvcore.common.download import download
from google.cloud import storage


__all__ = ["LazyPath", "PathManager", "get_cache_dir", "file_lock"]
Expand Down Expand Up @@ -588,6 +590,248 @@ def _get_local_path(self, path: str, **kwargs: Any) -> str:
return PathManager.get_local_path(os.fspath(direct_url), **kwargs)


# Override for close() on files to write to google cloud
def close_and_upload(self):
mode = self.mode
name = self.name
self._close()
with open(name, mode.replace("w", "r")) as file_to_upload:
self._gc_blob.upload_from_file(file_to_upload)


class GoogleCloudHandler(PathHandler):
"""
Support for Google Cloud Storage file system
"""

def _get_supported_prefixes(self) -> List[str]:
"""
Returns:
List[str]: the list of URI prefixes this PathHandler can support
"""
return ["gs://"]

def _get_local_path(self, path: str, **kwargs: Any) -> str:
"""
Get a filepath which is compatible with native Python I/O such as `open`
and `os.path`.
If URI points to a remote resource, this function may download and cache
the resource to local disk. In this case, the cache stays on filesystem
(under `file_io.get_cache_dir()`) and will be used by a different run.
Therefore this function is meant to be used with read-only resources.
Args:
path (str): A URI supported by this PathHandler
Returns:
local_path (str): a file path which exists on the local file system
"""
self._cache_remote_file(path)
return self._get_local_cache_path(path)

def _copy_from_local(
self, local_path: str, dst_path: str, overwrite: bool = False, **kwargs: Any
) -> bool:
"""
Copies a local file to the specified URI.
If the URI is another local path, this should be functionally identical
to copy.
Args:
local_path (str): a file path which exists on the local file system
dst_path (str): A URI supported by this PathHandler
overwrite (bool): Bool flag for forcing overwrite of existing URI
Returns:
status (bool): True on success
"""
return self._upload_file(dst_path, local_path)

def _open(
self, path: str, mode: str = "r", buffering: int = -1, **kwargs: Any
) -> Union[IO[str], IO[bytes]]:
"""
Open a stream to a URI, similar to the built-in `open`.
Args:
path (str): A URI supported by this PathHandler
mode (str): Specifies the mode in which the file is opened. It defaults
to 'r'.
buffering (int): An optional integer used to set the buffering policy.
Pass 0 to switch buffering off and an integer >= 1 to indicate the
size in bytes of a fixed-size chunk buffer. When no buffering
argument is given, the default buffering policy depends on the
underlying I/O implementation.
Returns:
file: a file-like object.
"""
self._cache_remote_file(path)
return self._open_local_copy(path, mode)

def _copy(
self, src_path: str, dst_path: str, overwrite: bool = False, **kwargs: Any
) -> bool:
"""
Copies a source path to a destination path.
Args:
src_path (str): A URI supported by this PathHandler
dst_path (str): A URI supported by this PathHandler
overwrite (bool): Bool flag for forcing overwrite of existing file
Returns:
status (bool): True on success
"""

if not self._cache_remote_file(src_path):
return False
local_path = self._get_local_cache_path(src_path)
return self._copy_from_local(local_path, dst_path)

def _exists(self, path: str, **kwargs: Any) -> bool:
"""
Checks if there is a resource at the given URI.
Args:
path (str): A URI supported by this PathHandler
Returns:
bool: true if the path exists
"""
return self._get_blob(path).exists()

def _isfile(self, path: str, **kwargs: Any) -> bool:
"""
Checks if the resource at the given URI is a file.
Args:
path (str): A URI supported by this PathHandler
Returns:
bool: true if the path is a file
"""

return "." in path.split("/")[-1]

def _isdir(self, path: str, **kwargs: Any) -> bool:
"""
Checks if the resource at the given URI is a directory.
Args:
path (str): A URI supported by this PathHandler
Returns:
bool: true if the path is a directory
"""
return "/" == path[-1]

def _ls(self, path: str, **kwargs: Any) -> List[str]:
"""
List the contents of the directory at the provided URI.
Args:
path (str): A URI supported by this PathHandler
Returns:
List[str]: list of contents in given path
"""
raise NotImplementedError()

def _mkdirs(self, path: str, **kwargs: Any) -> None:
"""
Recursive directory creation function. Like mkdir(), but makes all
intermediate-level directories needed to contain the leaf directory.
Similar to the native `os.makedirs`.
Args:
path (str): A URI supported by this PathHandler
"""
# GCS does this automatically
pass

def _rm(self, path: str, **kwargs: Any) -> None:
"""
Remove the file (not directory) at the provided URI.
Args:
path (str): A URI supported by this PathHandler
"""
if not self._exists(path):
return
if self._isdir(path):
return
self._delete_remote_resource(path)

def _get_gc_bucket(self, path: str) -> storage.Bucket:
if not hasattr(self, "_gc_client"):
self._create_gc_client(path)
gc_bucket_name = self._extract_gc_bucket_name(path)
return self._gc_client.get_bucket(gc_bucket_name)

def _create_gc_client(self, path: str):
namespace = self._extract_gc_namespace(path)
gc_client = storage.Client(project=namespace)
self._gc_client = gc_client

def _get_blob(self, path: str) -> storage.Blob:
gc_bucket = self._get_gc_bucket(path)
return gc_bucket.blob(self._extract_blob_path(path))

def _cache_blob(self, local_path: str, gc_blob: storage.Blob) -> bool:
if not gc_blob.exists():
return False
with open(local_path, "wb") as file:
gc_blob.download_to_file(file)
return True

def _upload_file(self, destination_path: str, local_path: str):
gc_blob = self._get_blob(destination_path)
if not gc_blob._exists():
return False
with open(local_path, "r") as file:
gc_blob.upload_from_file(file)
return True

def _cache_remote_file(self, remote_path: str):
local_path = self._get_local_cache_path(remote_path)
local_directory = self._get_local_cache_directory(remote_path)
self._maybe_make_directory(local_directory)
gc_blob = self._get_blob(remote_path)
return self._cache_blob(local_path, gc_blob)

def _open_local_copy(self, path: str, mode: str) -> Union[IO[str], IO[bytes]]:
local_path = self._get_local_cache_path(path)
gc_blob = self._get_blob(path)
file = open(local_path, mode)
if "w" in mode:
self._decorate_file_with_gc_methods(file, gc_blob)
return file

def _delete_remote_resource(self, path):
self._get_blob(path).delete()

def _decorate_file_with_gc_methods(
self, file: Union[IO[str], IO[bytes]], gc_blob: storage.Blob
):
file._gc_blob = gc_blob
file._close = file.close
file.close = types.MethodType(close_and_upload, file)

def _maybe_make_directory(self, path: str) -> bool:
is_made = False
with file_lock(path):
if not os.path.exists(path):
os.makedirs(path)
is_made = True
return is_made

def _extract_gc_namespace(self, path: str) -> str:
return self._extract_gc_bucket_name(path).replace("-data", "")

def _extract_gc_bucket_name(self, path: str) -> str:
return self._remove_file_system(path).split("/")[0]

def _remove_file_system(self, path: str) -> str:
return path.replace("gs://", "")

def _remove_bucket_name(self, path: str) -> str:
return path.replace(self._extract_gc_bucket_name(path) + "/", "")

def _extract_blob_path(self, path: str) -> str:
return self._remove_file_system(self._remove_bucket_name(path))

def _get_local_cache_path(self, path: str) -> str:
path = self._extract_blob_path(path)
return "/".join([".", "tmp", path])

def _get_local_cache_directory(self, path: str) -> str:
path = self._get_local_cache_path(path)
return path.replace(path.split("/")[-1], "")


# NOTE: this class should be renamed back to PathManager when it is moved to a new library
class PathManagerBase:
"""
Expand Down Expand Up @@ -896,3 +1140,4 @@ def set_strict_kwargs_checking(self, enable: bool) -> None:

PathManager.register_handler(HTTPURLHandler())
PathManager.register_handler(OneDrivePathHandler())
PathManager.register_handler(GoogleCloudHandler())
12 changes: 12 additions & 0 deletions fvcore/common/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from typing import Dict, Optional

from tabulate import tabulate


class Registry(object):
"""
Expand Down Expand Up @@ -73,3 +75,13 @@ def get(self, name: str) -> object:

def __contains__(self, name: str) -> bool:
return name in self._obj_map

def __repr__(self) -> str:
table_headers = ["Names", "Objects"]
table = tabulate(
self._obj_map.items(), headers=table_headers, tablefmt="fancy_grid"
)
return "Registry of {}:\n".format(self._name) + table

# pyre-fixme[4]: Attribute must be annotated.
__str__ = __repr__
1 change: 1 addition & 0 deletions packaging/fvcore/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ requirements:
- termcolor
- pillow
- tabulate
- google-cloud-storage

build:
string: py{{py}}
Expand Down
1 change: 1 addition & 0 deletions setup.py
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def get_version():
"termcolor>=1.1",
"Pillow",
"tabulate",
"google-cloud-storage",
],
extras_require={"all": ["shapely"]},
packages=find_packages(exclude=("tests",)),
Expand Down
5 changes: 2 additions & 3 deletions tests/bm_main.py
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@

import glob
import importlib

# pyre-fixme[21]: Could not find name `sys` in `os.path`.
from os.path import basename, dirname, isfile, join, sys
import sys
from os.path import basename, dirname, isfile, join


if __name__ == "__main__":
Expand Down
Loading