Skip to content

Commit

Permalink
PYTHON-4669 - Update Async GridFS APIs for Motor Compatibility (#1821)
Browse files Browse the repository at this point in the history
  • Loading branch information
NoahStapp authored Sep 4, 2024
1 parent 5a49ccc commit 4e74c82
Show file tree
Hide file tree
Showing 10 changed files with 1,115 additions and 105 deletions.
101 changes: 60 additions & 41 deletions gridfs/asynchronous/grid_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -1176,24 +1176,6 @@ def __getattr__(self, name: str) -> Any:
raise AttributeError("GridIn object has no attribute '%s'" % name)

def __setattr__(self, name: str, value: Any) -> None:
# For properties of this instance like _buffer, or descriptors set on
# the class like filename, use regular __setattr__
if name in self.__dict__ or name in self.__class__.__dict__:
object.__setattr__(self, name, value)
else:
if _IS_SYNC:
# All other attributes are part of the document in db.fs.files.
# Store them to be sent to server on close() or if closed, send
# them now.
self._file[name] = value
if self._closed:
self._coll.files.update_one({"_id": self._file["_id"]}, {"$set": {name: value}})
else:
raise AttributeError(
"AsyncGridIn does not support __setattr__. Use AsyncGridIn.set() instead"
)

async def set(self, name: str, value: Any) -> None:
# For properties of this instance like _buffer, or descriptors set on
# the class like filename, use regular __setattr__
if name in self.__dict__ or name in self.__class__.__dict__:
Expand All @@ -1204,9 +1186,17 @@ async def set(self, name: str, value: Any) -> None:
# them now.
self._file[name] = value
if self._closed:
await self._coll.files.update_one(
{"_id": self._file["_id"]}, {"$set": {name: value}}
)
if _IS_SYNC:
self._coll.files.update_one({"_id": self._file["_id"]}, {"$set": {name: value}})
else:
raise AttributeError(
"AsyncGridIn does not support __setattr__ after being closed(). Set the attribute before closing the file or use AsyncGridIn.set() instead"
)

async def set(self, name: str, value: Any) -> None:
self._file[name] = value
if self._closed:
await self._coll.files.update_one({"_id": self._file["_id"]}, {"$set": {name: value}})

async def _flush_data(self, data: Any, force: bool = False) -> None:
"""Flush `data` to a chunk."""
Expand Down Expand Up @@ -1400,7 +1390,11 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> Any:
return False


class AsyncGridOut(io.IOBase):
GRIDOUT_BASE_CLASS = io.IOBase if _IS_SYNC else object # type: Any


class AsyncGridOut(GRIDOUT_BASE_CLASS): # type: ignore

"""Class to read data out of GridFS."""

def __init__(
Expand Down Expand Up @@ -1460,6 +1454,8 @@ def __init__(
self._position = 0
self._file = file_document
self._session = session
if not _IS_SYNC:
self.closed = False

_id: Any = _a_grid_out_property("_id", "The ``'_id'`` value for this file.")
filename: str = _a_grid_out_property("filename", "Name of this file.")
Expand All @@ -1486,16 +1482,43 @@ def __init__(
_file: Any
_chunk_iter: Any

async def __anext__(self) -> bytes:
return super().__next__()
if not _IS_SYNC:
closed: bool

def __next__(self) -> bytes: # noqa: F811, RUF100
if _IS_SYNC:
return super().__next__()
else:
raise TypeError(
"AsyncGridOut does not support synchronous iteration. Use `async for` instead"
)
async def __anext__(self) -> bytes:
line = await self.readline()
if line:
return line
raise StopAsyncIteration()

async def to_list(self) -> list[bytes]:
return [x async for x in self] # noqa: C416, RUF100

async def readline(self, size: int = -1) -> bytes:
"""Read one line or up to `size` bytes from the file.
:param size: the maximum number of bytes to read
"""
return await self._read_size_or_line(size=size, line=True)

async def readlines(self, size: int = -1) -> list[bytes]:
"""Read one line or up to `size` bytes from the file.
:param size: the maximum number of bytes to read
"""
await self.open()
lines = []
remainder = int(self.length) - self._position
bytes_read = 0
while remainder > 0:
line = await self._read_size_or_line(line=True)
bytes_read += len(line)
lines.append(line)
remainder = int(self.length) - self._position
if 0 < size < bytes_read:
break

return lines

async def open(self) -> None:
if not self._file:
Expand Down Expand Up @@ -1616,18 +1639,11 @@ async def read(self, size: int = -1) -> bytes:
"""
return await self._read_size_or_line(size=size)

async def readline(self, size: int = -1) -> bytes: # type: ignore[override]
"""Read one line or up to `size` bytes from the file.
:param size: the maximum number of bytes to read
"""
return await self._read_size_or_line(size=size, line=True)

def tell(self) -> int:
"""Return the current position of this file."""
return self._position

async def seek(self, pos: int, whence: int = _SEEK_SET) -> int: # type: ignore[override]
async def seek(self, pos: int, whence: int = _SEEK_SET) -> int:
"""Set the current position of this file.
:param pos: the position (or offset if using relative
Expand Down Expand Up @@ -1690,12 +1706,15 @@ def __aiter__(self) -> AsyncGridOut:
"""
return self

async def close(self) -> None: # type: ignore[override]
async def close(self) -> None:
"""Make GridOut more generically file-like."""
if self._chunk_iter:
await self._chunk_iter.close()
self._chunk_iter = None
super().close()
if _IS_SYNC:
super().close()
else:
self.closed = True

def write(self, value: Any) -> NoReturn:
raise io.UnsupportedOperation("write")
Expand Down
19 changes: 19 additions & 0 deletions gridfs/grid_file_shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,31 @@ def _a_grid_in_property(
) -> Any:
"""Create a GridIn property."""

warn_str = ""
if docstring.startswith("DEPRECATED,"):
warn_str = (
f"GridIn property '{field_name}' is deprecated and will be removed in PyMongo 5.0"
)

def getter(self: Any) -> Any:
if warn_str:
warnings.warn(warn_str, stacklevel=2, category=DeprecationWarning)
if closed_only and not self._closed:
raise AttributeError("can only get %r on a closed file" % field_name)
# Protect against PHP-237
if field_name == "length":
return self._file.get(field_name, 0)
return self._file.get(field_name, None)

def setter(self: Any, value: Any) -> Any:
if warn_str:
warnings.warn(warn_str, stacklevel=2, category=DeprecationWarning)
if self._closed:
raise InvalidOperation(
"AsyncGridIn does not support __setattr__ after being closed(). Set the attribute before closing the file or use AsyncGridIn.set() instead"
)
self._file[field_name] = value

if read_only:
docstring += "\n\nThis attribute is read-only."
elif closed_only:
Expand All @@ -56,6 +73,8 @@ def getter(self: Any) -> Any:
"has been called.",
)

if not read_only and not closed_only:
return property(getter, setter, doc=docstring)
return property(getter, doc=docstring)


Expand Down
97 changes: 60 additions & 37 deletions gridfs/synchronous/grid_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -1166,24 +1166,6 @@ def __getattr__(self, name: str) -> Any:
raise AttributeError("GridIn object has no attribute '%s'" % name)

def __setattr__(self, name: str, value: Any) -> None:
# For properties of this instance like _buffer, or descriptors set on
# the class like filename, use regular __setattr__
if name in self.__dict__ or name in self.__class__.__dict__:
object.__setattr__(self, name, value)
else:
if _IS_SYNC:
# All other attributes are part of the document in db.fs.files.
# Store them to be sent to server on close() or if closed, send
# them now.
self._file[name] = value
if self._closed:
self._coll.files.update_one({"_id": self._file["_id"]}, {"$set": {name: value}})
else:
raise AttributeError(
"GridIn does not support __setattr__. Use GridIn.set() instead"
)

def set(self, name: str, value: Any) -> None:
# For properties of this instance like _buffer, or descriptors set on
# the class like filename, use regular __setattr__
if name in self.__dict__ or name in self.__class__.__dict__:
Expand All @@ -1194,7 +1176,17 @@ def set(self, name: str, value: Any) -> None:
# them now.
self._file[name] = value
if self._closed:
self._coll.files.update_one({"_id": self._file["_id"]}, {"$set": {name: value}})
if _IS_SYNC:
self._coll.files.update_one({"_id": self._file["_id"]}, {"$set": {name: value}})
else:
raise AttributeError(
"GridIn does not support __setattr__ after being closed(). Set the attribute before closing the file or use GridIn.set() instead"
)

def set(self, name: str, value: Any) -> None:
self._file[name] = value
if self._closed:
self._coll.files.update_one({"_id": self._file["_id"]}, {"$set": {name: value}})

def _flush_data(self, data: Any, force: bool = False) -> None:
"""Flush `data` to a chunk."""
Expand Down Expand Up @@ -1388,7 +1380,11 @@ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> Any:
return False


class GridOut(io.IOBase):
GRIDOUT_BASE_CLASS = io.IOBase if _IS_SYNC else object # type: Any


class GridOut(GRIDOUT_BASE_CLASS): # type: ignore

"""Class to read data out of GridFS."""

def __init__(
Expand Down Expand Up @@ -1448,6 +1444,8 @@ def __init__(
self._position = 0
self._file = file_document
self._session = session
if not _IS_SYNC:
self.closed = False

_id: Any = _grid_out_property("_id", "The ``'_id'`` value for this file.")
filename: str = _grid_out_property("filename", "Name of this file.")
Expand All @@ -1474,14 +1472,43 @@ def __init__(
_file: Any
_chunk_iter: Any

def __next__(self) -> bytes:
return super().__next__()
if not _IS_SYNC:
closed: bool

def __next__(self) -> bytes: # noqa: F811, RUF100
if _IS_SYNC:
return super().__next__()
else:
raise TypeError("GridOut does not support synchronous iteration. Use `for` instead")
def __next__(self) -> bytes:
line = self.readline()
if line:
return line
raise StopIteration()

def to_list(self) -> list[bytes]:
return [x for x in self] # noqa: C416, RUF100

def readline(self, size: int = -1) -> bytes:
"""Read one line or up to `size` bytes from the file.
:param size: the maximum number of bytes to read
"""
return self._read_size_or_line(size=size, line=True)

def readlines(self, size: int = -1) -> list[bytes]:
"""Read one line or up to `size` bytes from the file.
:param size: the maximum number of bytes to read
"""
self.open()
lines = []
remainder = int(self.length) - self._position
bytes_read = 0
while remainder > 0:
line = self._read_size_or_line(line=True)
bytes_read += len(line)
lines.append(line)
remainder = int(self.length) - self._position
if 0 < size < bytes_read:
break

return lines

def open(self) -> None:
if not self._file:
Expand Down Expand Up @@ -1602,18 +1629,11 @@ def read(self, size: int = -1) -> bytes:
"""
return self._read_size_or_line(size=size)

def readline(self, size: int = -1) -> bytes: # type: ignore[override]
"""Read one line or up to `size` bytes from the file.
:param size: the maximum number of bytes to read
"""
return self._read_size_or_line(size=size, line=True)

def tell(self) -> int:
"""Return the current position of this file."""
return self._position

def seek(self, pos: int, whence: int = _SEEK_SET) -> int: # type: ignore[override]
def seek(self, pos: int, whence: int = _SEEK_SET) -> int:
"""Set the current position of this file.
:param pos: the position (or offset if using relative
Expand Down Expand Up @@ -1676,12 +1696,15 @@ def __iter__(self) -> GridOut:
"""
return self

def close(self) -> None: # type: ignore[override]
def close(self) -> None:
"""Make GridOut more generically file-like."""
if self._chunk_iter:
self._chunk_iter.close()
self._chunk_iter = None
super().close()
if _IS_SYNC:
super().close()
else:
self.closed = True

def write(self, value: Any) -> NoReturn:
raise io.UnsupportedOperation("write")
Expand Down
5 changes: 5 additions & 0 deletions pymongo/asynchronous/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,13 @@ async def inner(*args: Any, **kwargs: Any) -> Any:

if sys.version_info >= (3, 10):
anext = builtins.anext
aiter = builtins.aiter
else:

async def anext(cls: Any) -> Any:
"""Compatibility function until we drop 3.9 support: https://docs.python.org/3/library/functions.html#anext."""
return await cls.__anext__()

def aiter(cls: Any) -> Any:
"""Compatibility function until we drop 3.9 support: https://docs.python.org/3/library/functions.html#anext."""
return cls.__aiter__()
2 changes: 1 addition & 1 deletion pymongo/asynchronous/topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,7 +521,7 @@ async def _process_change(
if server:
await server.pool.reset(interrupt_connections=interrupt_connections)

# Wake waiters in select_servers().
# Wake anything waiting in select_servers().
self._condition.notify_all()

async def on_change(
Expand Down
5 changes: 5 additions & 0 deletions pymongo/synchronous/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,13 @@ def inner(*args: Any, **kwargs: Any) -> Any:

if sys.version_info >= (3, 10):
next = builtins.next
iter = builtins.iter
else:

def next(cls: Any) -> Any:
"""Compatibility function until we drop 3.9 support: https://docs.python.org/3/library/functions.html#next."""
return cls.__next__()

def iter(cls: Any) -> Any:
"""Compatibility function until we drop 3.9 support: https://docs.python.org/3/library/functions.html#next."""
return cls.__iter__()
Loading

0 comments on commit 4e74c82

Please sign in to comment.