Skip to content

Commit

Permalink
Python/Smart_Open improve read/write API.
Browse files Browse the repository at this point in the history
Signed-off-by: Pascal Spörri <[email protected]>
  • Loading branch information
pspoerri committed Apr 17, 2024
1 parent 29c57ef commit 0e7132c
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 31 deletions.
3 changes: 2 additions & 1 deletion src/python/geds_smart_open/src/geds_smart_open/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@

from . import geds
from .geds import register_object_store
from .geds import relocate

register_transport(geds)

__all__ = ["GEDS", "register_object_store"]
__all__ = ["GEDS", "register_object_store", "relocate"]
35 changes: 19 additions & 16 deletions src/python/geds_smart_open/src/geds_smart_open/geds.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def __init__(
self.bucket = bucket
self.key = key
self.position = 0
self._size = file.size
self.file = file
self.raw = None
self.line_terminator = line_terminator
Expand All @@ -66,7 +67,7 @@ def close(self) -> None:

@property
def size(self) -> int:
return self.file.size
return self._size

@property
def closed(self) -> bool:
Expand Down Expand Up @@ -102,7 +103,7 @@ def seek(self, offset: int, whence: int = io.SEEK_SET) -> int:
elif whence == io.SEEK_CUR:
self.position += offset
elif whence == io.SEEK_END:
self.position = self.file.size + offset
self.position = self.size + offset
return self.position

def tell(self) -> int:
Expand All @@ -119,24 +120,22 @@ def read(self, limit: int = -1):
"""
self.checkClosed()
self.checkReadable()
maxcount = self.file.size - self.position
maxcount = self.size - self.position
assert maxcount >= 0
count = limit
if limit == 0:
return b""
if limit < 0 or limit > maxcount:
count = maxcount
buffer = bytearray(count)
count = self.readinto(buffer)
if count < len(buffer):
return buffer[0:count]
return buffer

count = maxcount - self.position
return self.file.read1(self.position, count)

def readinto(self, buffer):
self.checkReadable()
if self.closed:
return -1
count = self.file.read(buffer, self.position, len(buffer))
count = self.file.readinto1(buffer, self.position, len(buffer))
self.position += count
return count

Expand All @@ -148,13 +147,13 @@ def readline(self, limit: int = -1) -> bytes:
print("readline " + limit)
if limit != -1:
raise NotImplementedError("limits other than -1 not implemented yet")
buffer = bytearray(self.buffer_size)
# buffer = bytearray(self.buffer_size)
line = io.BytesIO()

while True:
previous_position = self.position
count = self.readinto(buffer)
if count == 0:
buffer = self.file.read(self.position, self.buffer_size)
if len(buffer) == 0:
break
index = buffer.find(self.line_terminator, 0)
if index > 0:
Expand All @@ -167,7 +166,7 @@ def readline(self, limit: int = -1) -> bytes:
def readall(self) -> bytes:
self.checkClosed()

length = self.file.size - self.position
length = self.size - self.position
buffer = bytearray(length)
count = self.readinto(buffer)
return buffer[0:count]
Expand Down Expand Up @@ -290,6 +289,8 @@ def register_object_store(
):
GEDSInstance.register_object_store(bucket, endpoint_url, access_key, secret_key)

def relocate(force: bool = False):
GEDSInstance.get().relocate(force)

def parse_uri(uri: str):
path = uri.removeprefix("geds://")
Expand Down Expand Up @@ -328,10 +329,12 @@ def open(bucket: str, key: str, mode: str, client=None):
if mode == constants.READ_BINARY:
f = client.open(bucket, key)
elif mode == constants.WRITE_BINARY:
try:
f = client.create(bucket, key, True)
elif mode == 'ab':
f = client.open(bucket, key)
if not f.writable():
client.copy(bucket, key, bucket, key)
f = client.open(bucket, key)
except:
f = client.create(bucket, key)
else:
raise ValueError(f"Invalid argument for mode: {mode}")
return GEDSRawInputBase(
Expand Down
30 changes: 16 additions & 14 deletions src/python/wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -188,25 +188,27 @@ PYBIND11_MODULE(pygeds, m) {
return self.setMetadata(reinterpret_cast<const uint8_t *>(buffer), length, seal);
},
py::arg("buffer"), py::arg("length") = std::nullopt, py::arg("seal") = true)
.def("read",
[](GEDSFile &self, py::buffer buffer, size_t position,
size_t length) -> absl::StatusOr<size_t> {
.def("read1",
[](GEDSFile &self, size_t position,
size_t length) -> absl::StatusOr<py::array_t<uint8_t>> {
auto result = py::array_t<uint8_t>(length);
py::buffer_info buffer = result.request(true);
py::gil_scoped_release release;
auto status = self.read(static_cast<uint8_t *>(buffer.ptr), position, length);
if (!status.ok()) {
return status.status();
}
result.resize({*status});
return result;
})
.def("readinto1",
[](GEDSFile &self, py::buffer buffer, size_t position) -> absl::StatusOr<size_t> {
py::buffer_info info = buffer.request(true);
if (info.ndim != 1) {
return absl::FailedPreconditionError("Buffer has wrong dimensions!");
}
if ((size_t)info.size < length) {
return absl::FailedPreconditionError("The buffer does not have sufficient space!");
}
length = std::min<size_t>(info.size, length);
py::gil_scoped_release release;
return self.read(static_cast<uint8_t *>(info.ptr), position, length);
})
.def("read",
[](GEDSFile &self, char *array, size_t position,
size_t length) -> absl::StatusOr<size_t> {
py::gil_scoped_release release;
return self.read(reinterpret_cast<uint8_t *>(array), position, length);
return self.read(static_cast<uint8_t *>(info.ptr), position, info.size);
})
.def("write",
[](GEDSFile &self, const py::buffer buffer, size_t position,
Expand Down

0 comments on commit 0e7132c

Please sign in to comment.