From 5a321d0d87587d590b7e8afb77e68dbd3d89e31d Mon Sep 17 00:00:00 2001 From: davidparks21 Date: Sun, 9 Oct 2022 21:34:52 -0700 Subject: [PATCH] Rewrite of s3.Reader class to protect S3 servers from open range headers (#725) --- smart_open/s3.py | 695 +++++++++++++++++++++++------------- smart_open/tests/test_s3.py | 179 ++++------ 2 files changed, 515 insertions(+), 359 deletions(-) diff --git a/smart_open/s3.py b/smart_open/s3.py index c5959bdb..b573de23 100644 --- a/smart_open/s3.py +++ b/smart_open/s3.py @@ -38,7 +38,8 @@ DEFAULT_PORT = 443 DEFAULT_HOST = 's3.amazonaws.com' -DEFAULT_BUFFER_SIZE = 128 * 1024 +DEFAULT_BUFFER_SIZE = 1500 +DEFAULT_STREAM_RANGE = 10485760 URI_EXAMPLES = ( 's3://my_bucket/my_key', @@ -232,7 +233,7 @@ def open( buffer_size=DEFAULT_BUFFER_SIZE, min_part_size=DEFAULT_MIN_PART_SIZE, multipart_upload=True, - defer_seek=False, + stream_range=10485760, client=None, client_kwargs=None, writebuffer=None, @@ -260,11 +261,16 @@ def open( version_id: str, optional Version of the object, used when reading object. If None, will fetch the most recent version. - defer_seek: boolean, optional - Default: `False` - If set to `True` on a file opened for reading, GetObject will not be - called until the first seek() or read(). - Avoids redundant API queries when seeking before reading. + stream_range: str, optional + Default: 10485760 bytes (10 MB) + The stream_range setting limits the size of data that may be streamed through a + single HTTP request across multiple read calls, this is an important protection + for the S3 server, ensuring the S3 server doesn't get an open-ended byte-range request + which can cause it to internally queue up a massive file when only a small bit of it may ultimately be + read by the user. Note that the first read call (after opening or seeking) will always set the + byte range header to exactly the read size, an optimization for use cases in which a single + small read is performed against a large file (example: random reading of small data samples + from large files in machine learning contexts). client: object, optional The S3 client to use when working with boto3. If you don't specify this, then smart_open will create a new client for you. @@ -280,6 +286,7 @@ def open( disk IO. If you pass in an open file, then you are responsible for cleaning it up after writing completes. """ + logger.debug('%r', locals()) if mode not in constants.BINARY_MODES: raise NotImplementedError('bad mode: %r expected one of %r' % (mode, constants.BINARY_MODES)) @@ -293,7 +300,7 @@ def open( key_id, version_id=version_id, buffer_size=buffer_size, - defer_seek=defer_seek, + stream_range=stream_range, client=client, client_kwargs=client_kwargs, ) @@ -338,6 +345,22 @@ def _get(client, bucket, key, version, range_string): raise wrapped_error from error +def _head(client, bucket, key, version): + try: + if version: + return client.head_object(Bucket=bucket, Key=key, VersionId=version) + else: + return client.head_object(Bucket=bucket, Key=key) + except botocore.client.ClientError as error: + wrapped_error = IOError( + 'unable to access bucket: %r key: %r version: %r error: %s' % ( + bucket, key, version, error + ) + ) + wrapped_error.backend_error = error + raise wrapped_error from error + + def _unwrap_ioerror(ioe): """Given an IOError from _get, return the 'Error' dictionary from boto.""" try: @@ -346,178 +369,172 @@ def _unwrap_ioerror(ioe): return None -class _SeekableRawReader(object): - """Read an S3 object. - - This class is internal to the S3 submodule. - """ - - def __init__( - self, - client, - bucket, - key, - version_id=None, - ): - self._client = client - self._bucket = bucket - self._key = key - self._version_id = version_id - - self._content_length = None - self._position = 0 - self._body = None - - def seek(self, offset, whence=constants.WHENCE_START): - """Seek to the specified position. - - :param int offset: The offset in bytes. - :param int whence: Where the offset is from. - - :returns: the position after seeking. - :rtype: int - """ - if whence not in constants.WHENCE_CHOICES: - raise ValueError('invalid whence, expected one of %r' % constants.WHENCE_CHOICES) - - # - # Close old body explicitly. - # When first seek() after __init__(), self._body is not exist. - # - if self._body is not None: - self._body.close() - self._body = None - - start = None - stop = None - if whence == constants.WHENCE_START: - start = max(0, offset) - elif whence == constants.WHENCE_CURRENT: - start = max(0, offset + self._position) - else: - stop = max(0, -offset) - - # - # If we can figure out that we've read past the EOF, then we can save - # an extra API call. - # - if self._content_length is None: - reached_eof = False - elif start is not None and start >= self._content_length: - reached_eof = True - elif stop == 0: - reached_eof = True - else: - reached_eof = False - - if reached_eof: - self._body = io.BytesIO() - self._position = self._content_length - else: - self._open_body(start, stop) - - return self._position - - def _open_body(self, start=None, stop=None): - """Open a connection to download the specified range of bytes. Store - the open file handle in self._body. - - If no range is specified, start defaults to self._position. - start and stop follow the semantics of the http range header, - so a stop without a start will read bytes beginning at stop. - - As a side effect, set self._content_length. Set self._position - to self._content_length if start is past end of file. - """ - if start is None and stop is None: - start = self._position - range_string = smart_open.utils.make_range_string(start, stop) - - try: - # Optimistically try to fetch the requested content range. - response = _get( - self._client, - self._bucket, - self._key, - self._version_id, - range_string, - ) - except IOError as ioe: - # Handle requested content range exceeding content size. - error_response = _unwrap_ioerror(ioe) - if error_response is None or error_response.get('Code') != _OUT_OF_RANGE: - raise - self._position = self._content_length = int(error_response['ActualObjectSize']) - self._body = io.BytesIO() - else: - # - # Keep track of how many times boto3's built-in retry mechanism - # activated. - # - # https://boto3.amazonaws.com/v1/documentation/api/latest/guide/retries.html#checking-retry-attempts-in-an-aws-service-response - # - logger.debug( - '%s: RetryAttempts: %d', - self, - response['ResponseMetadata']['RetryAttempts'], - ) - units, start, stop, length = smart_open.utils.parse_content_range(response['ContentRange']) - self._content_length = length - self._position = start - self._body = response['Body'] - - def read(self, size=-1): - """Read from the continuous connection with the remote peer.""" - if self._body is None: - # This is necessary for the very first read() after __init__(). - self._open_body() - if self._position >= self._content_length: - return b'' - - # - # Boto3 has built-in error handling and retry mechanisms: - # - # https://boto3.amazonaws.com/v1/documentation/api/latest/guide/error-handling.html - # https://boto3.amazonaws.com/v1/documentation/api/latest/guide/retries.html - # - # Unfortunately, it isn't always enough. There is still a non-zero - # possibility that an exception will slip past these mechanisms and - # terminate the read prematurely. Luckily, at this stage, it's very - # simple to recover from the problem: wait a little bit, reopen the - # HTTP connection and try again. Usually, a single retry attempt is - # enough to recover, but we try multiple times "just in case". - # - for attempt, seconds in enumerate([1, 2, 4, 8, 16], 1): - try: - if size == -1: - binary = self._body.read() - else: - binary = self._body.read(size) - except ( - ConnectionResetError, - botocore.exceptions.BotoCoreError, - urllib3.exceptions.HTTPError, - ) as err: - logger.warning( - '%s: caught %r while reading %d bytes, sleeping %ds before retry', - self, - err, - size, - seconds, - ) - time.sleep(seconds) - self._open_body() - else: - self._position += len(binary) - return binary - - raise IOError('%s: failed to read %d bytes after %d attempts' % (self, size, attempt)) - - def __str__(self): - return 'smart_open.s3._SeekableReader(%r, %r)' % (self._bucket, self._key) +# class _SeekableRawReader(object): +# """Read an S3 object. +# +# This class is internal to the S3 submodule. +# """ +# +# def __init__( +# self, +# client, +# bucket, +# key, +# stream_range, +# version_id=None, +# ): +# self._client = client +# self._bucket = bucket +# self._key = key +# self._version_id = version_id +# +# self._content_length = None +# self._position = 0 +# self._body = None +# +# # the max_stream_size setting limits how much data will be read in a single HTTP request, this is an +# # important protection for the S3 server, ensuring the S3 server doesn't get an open-ended byte-range request +# # which can cause it to internally queue up a massive file when only a small bit of it may ultimately be +# # read by the user. The variable _stream_range_[from|to] tracks the range of bytes that can be read +# # from the current request body (e.g. from the same HTTP request). Note that the first read call +# # will always set the byte range header to exactly the read size, an optimization for uses cases in which a +# # single small read is performed against a large file (example: random sampling small data samples from +# # large files in machine learning contexts). +# self._stream_range = stream_range +# self._stream_range_from = None # a None value signifies the first call to `read` where this will be set +# self._stream_range_to = None +# +# def seek(self, offset, whence=constants.WHENCE_START): +# """Seek to the specified position. +# +# :param int offset: The offset in bytes. +# :param int whence: Where the offset is from. +# +# :returns: the position after seeking. +# :rtype: int +# """ +# if whence not in constants.WHENCE_CHOICES: +# raise ValueError('invalid whence, expected one of %r' % constants.WHENCE_CHOICES) +# if whence == constants.WHENCE_END and offset > 0: +# raise ValueError('offset must be <= 0 when whence == WHENCE_END, got offset: ' + offset) +# +# if whence == constants.WHENCE_END and self._content_length is None: +# # WHENCE_END: head request is necessary to determine file length if it's not known yet +# # this is necessary to return the absolute position as specified by io.IOBase +# response = _head(self._client, self._bucket, self._key, self._version_id) +# _log_retry_attempts(response) +# self._content_length = int(response['ContentLength']) +# self._position = self._content_length + offset +# elif whence == constants.WHENCE_END: +# # WHENCE_END: we already have file length, no API call needed to compute the absolute position +# self._position = self._content_length + offset +# else: +# # WHENCE_START or WHENCE_CURRENT +# start = 0 if whence == constants.WHENCE_START else self._position +# self._position = start + offset +# +# return self._position +# +# def _open_body(self, start=None, stop=None): +# """Open a connection to download the specified range of bytes. Store +# the open file handle in self._body. +# +# If no range is specified, start defaults to self._position. +# start and stop follow the semantics of the http range header, +# so a stop without a start will read bytes beginning at stop. +# +# As a side effect, set self._content_length. Set self._position +# to self._content_length if start is past end of file. +# """ +# if start is None and stop is None: +# start = self._position +# range_string = smart_open.utils.make_range_string(start, stop) +# +# try: +# # Optimistically try to fetch the requested content range. +# response = _get( +# self._client, +# self._bucket, +# self._key, +# self._version_id, +# range_string, +# ) +# except IOError as ioe: +# # Handle requested content range exceeding content size. +# error_response = _unwrap_ioerror(ioe) +# if error_response is None or error_response.get('Code') != _OUT_OF_RANGE: +# raise +# self._position = self._content_length = int(error_response['ActualObjectSize']) +# self._body = io.BytesIO() +# else: +# _log_retry_attempts(response) # keep track of how many retries boto3 attempted +# units, start, stop, length = smart_open.utils.parse_content_range(response['ContentRange']) +# self._content_length = length +# self._position = start +# self._body = response['Body'] +# +# def read(self, size=-1): +# """Read from the continuous connection with the remote peer.""" +# +# # If we can figure out that we've read past the EOF, then we can save +# # an extra API call. +# reached_eof = True if self._content_length is not None and self._position >= self._content_length else False +# +# if reached_eof or size == 0: +# return b'' +# +# if self._body is None: +# stop = None if size == -1 else self._position + size +# self._open_body(start=self._position, stop=stop) +# +# # +# # Boto3 has built-in error handling and retry mechanisms: +# # +# # https://boto3.amazonaws.com/v1/documentation/api/latest/guide/error-handling.html +# # https://boto3.amazonaws.com/v1/documentation/api/latest/guide/retries.html +# # +# # Unfortunately, it isn't always enough. There is still a non-zero +# # possibility that an exception will slip past these mechanisms and +# # terminate the read prematurely. Luckily, at this stage, it's very +# # simple to recover from the problem: wait a little bit, reopen the +# # HTTP connection and try again. Usually, a single retry attempt is +# # enough to recover, but we try multiple times "just in case". +# # +# for attempt, seconds in enumerate([1, 2, 4, 8, 16], 1): +# try: +# if size == -1: +# binary = self._body.read() +# else: +# binary = self._body.read(size) +# except ( +# ConnectionResetError, +# botocore.exceptions.BotoCoreError, +# urllib3.exceptions.HTTPError, +# ) as err: +# logger.warning( +# '%s: caught %r while reading %d bytes, sleeping %ds before retry', +# self, +# err, +# size, +# seconds, +# ) +# time.sleep(seconds) +# self._open_body() +# else: +# self._position += len(binary) +# if self._optimize == 'reading' or (self._optimize == 'auto' and self._read_call_counter == 0): +# self._body.close() +# self._body = None +# self._read_call_counter += 0 +# return binary +# +# raise IOError('%s: failed to read %d bytes after %d attempts' % (self, size, attempt)) +# +# def __str__(self): +# return 'smart_open.s3._SeekableReader(%r, %r)' % (self._bucket, self._key) -def _initialize_boto3(rw, client, client_kwargs, bucket, key): +def _initialize_boto3(client, client_kwargs, bucket, key): """Created the required objects for accessing S3. Ideally, they have been already created for us and we can just reuse them.""" if client_kwargs is None: @@ -528,9 +545,11 @@ def _initialize_boto3(rw, client, client_kwargs, bucket, key): client = boto3.client('s3', **init_kwargs) assert client - rw._client = _ClientWrapper(client, client_kwargs) - rw._bucket = bucket - rw._key = key + _client = _ClientWrapper(client, client_kwargs) + _bucket = bucket + _key = key + + return _client, _bucket, _key class Reader(io.BufferedIOBase): @@ -545,38 +564,47 @@ def __init__( version_id=None, buffer_size=DEFAULT_BUFFER_SIZE, line_terminator=constants.BINARY_NEWLINE, - defer_seek=False, + stream_range=DEFAULT_STREAM_RANGE, client=None, client_kwargs=None, ): + assert isinstance(stream_range, int), \ + 'stream_range should be an integer number of bytes that restricts the maximum size the S3 server ' \ + 'needs to prepare for a single HTTP request. Got type: ' + type(stream_range) + self._version_id = version_id self._buffer_size = buffer_size - _initialize_boto3(self, client, client_kwargs, bucket, key) + self._client, self._bucket, self._key = _initialize_boto3(client, client_kwargs, bucket, key) + self._version_id = version_id - self._raw_reader = _SeekableRawReader( - self._client, - bucket, - key, - self._version_id, - ) - self._current_pos = 0 - self._buffer = smart_open.bytebuffer.ByteBuffer(buffer_size) + self._content_length = None self._eof = False + self._position = 0 + self._body = None + self._buffer = smart_open.bytebuffer.ByteBuffer(buffer_size) self._line_terminator = line_terminator + # the max_stream_size setting limits how much data will be read in a single HTTP request, this is an + # important protection for the S3 server, ensuring the S3 server doesn't get an open-ended byte-range request + # which can cause it to internally queue up a massive file when only a small bit of it may ultimately be + # read by the user. The variable _stream_range_[from|to] tracks the range of bytes that can be read + # from the current request body (e.g. from the same HTTP request). Note that the first read call + # will always set the byte range header to exactly the read size, an optimization for uses cases in which a + # single small read is performed against a large file (example: random sampling small data samples from + # large files in machine learning contexts). + self._stream_range = stream_range + self._stream_range_from = None # a None value signifies the first call to `read` where this will be set + self._stream_range_to = None + # # This member is part of the io.BufferedIOBase interface. # self.raw = None - if not defer_seek: - self.seek(0) - # # io.BufferedIOBase methods. # - def close(self): """Flush and close this stream.""" pass @@ -587,28 +615,21 @@ def readable(self): def read(self, size=-1): """Read up to size bytes from the object and return them.""" - if size == 0: - return b'' - elif size < 0: - # call read() before setting _current_pos to make sure _content_length is set - out = self._read_from_buffer() + self._raw_reader.read() - self._current_pos = self._raw_reader._content_length - return out - # - # Return unused data first - # - if len(self._buffer) >= size: - return self._read_from_buffer(size) + # if an absolute size can be calculated we do calculate it to determine if it's available in the buffer already + if size < 0: + size = size if self._content_length is None else self._content_length - self._position - # - # If the stream is finished, return what we have. - # - if self._eof: - return self._read_from_buffer() + # Fill the buffer with at least enough data to satisfy the request + if size < 0 or len(self._buffer) < size: + user_request_size = size - len(self._buffer) + fill_amount = -1 if size < 0 else max(user_request_size, self._buffer_size) + self._fill_buffer(fill_amount) + + b = self._buffer.read(size) + self._position += len(b) - self._fill_buffer(size) - return self._read_from_buffer(size) + return b def read1(self, size=-1): """This is the same as read().""" @@ -623,10 +644,12 @@ def readinto(self, b): b[:len(data)] = data return len(data) - def readline(self, limit=-1): - """Read up to and including the next newline. Returns the bytes read.""" - if limit != -1: - raise NotImplementedError('limits other than -1 not implemented yet') + def readline(self, size=-1): + """Read up to and including the next newline. Returns the bytes read.""" + + # smart_open.bytebuffer.ByteBuffer doesn't support this yet + if size != -1: + raise NotImplementedError('size other than -1 not implemented') # # A single line may span multiple buffers. @@ -635,12 +658,13 @@ def readline(self, limit=-1): while not (self._eof and len(self._buffer) == 0): line_part = self._buffer.readline(self._line_terminator) line.write(line_part) - self._current_pos += len(line_part) + self._position += len(line_part) + self._eof = self._position == self._content_length - if line_part.endswith(self._line_terminator): + if line_part.endswith(self._line_terminator) or self._eof: break else: - self._fill_buffer() + self._fill_buffer(self._buffer_size) return line.getvalue() @@ -657,21 +681,37 @@ def seek(self, offset, whence=constants.WHENCE_START): :param int whence: Where the offset is from. Returns the position after seeking.""" - # Convert relative offset to absolute, since self._raw_reader - # doesn't know our current position. - if whence == constants.WHENCE_CURRENT: - whence = constants.WHENCE_START - offset += self._current_pos - - self._current_pos = self._raw_reader.seek(offset, whence) + if whence not in constants.WHENCE_CHOICES: + raise ValueError('invalid whence, expected one of %r' % constants.WHENCE_CHOICES) + if whence == constants.WHENCE_END and offset > 0: + raise ValueError('offset must be <= 0 when whence == WHENCE_END, got offset: ' + offset) + + if whence == constants.WHENCE_END and self._content_length is None: + # WHENCE_END: head request is necessary to determine file length if it's not known yet + # this API call is necessary to return the absolute position as required by io.IOBase + response = _head(self._client, self._bucket, self._key, self._version_id) + self._log_retry_attempts(response) + self._content_length = int(response['ContentLength']) + self._position = self._content_length + offset + elif whence == constants.WHENCE_END: + # WHENCE_END: we already have file length, no API call needed to compute the absolute position + self._position = self._content_length + offset + else: + # WHENCE_START or WHENCE_CURRENT + start = 0 if whence == constants.WHENCE_START else self._position + self._position = start + offset self._buffer.empty() - self._eof = self._current_pos == self._raw_reader._content_length - return self._current_pos + if self._body is not None: + self._body.close() + self._body = self._stream_range_from = self._stream_range_to = None + self._eof = False if self._stream_range_from is None or self._position < self._content_length - 1 else True + + return self._position def tell(self): """Return the current position within the file.""" - return self._current_pos + return self._position def truncate(self, size=None): """Unsupported.""" @@ -700,20 +740,181 @@ def to_boto3(self, resource): # # Internal methods. # - def _read_from_buffer(self, size=-1): - """Remove at most size bytes from our buffer and return them.""" - size = size if size >= 0 else len(self._buffer) - part = self._buffer.read(size) - self._current_pos += len(part) - return part def _fill_buffer(self, size=-1): - size = max(size, self._buffer._chunk_size) - while len(self._buffer) < size and not self._eof: - bytes_read = self._buffer.fill(self._raw_reader) - if bytes_read == 0: - logger.debug('%s: reached EOF while filling buffer', self) - self._eof = True + # check if existing body range is sufficient to satisfy this request, if not close it so that it re-opens + # with an appropriate range. + is_stream_limit_before_eof = self._body is not None \ + and size < 0 \ + and self._stream_range_to < self._content_length - 1 + is_request_beyond_stream_limit = self._body is not None \ + and size > 0 \ + and self._stream_range_to is not None \ + and self._stream_range_to - self._position < size + if is_stream_limit_before_eof or is_request_beyond_stream_limit: + self._body.close() + self._body = self._stream_range_from = self._stream_range_to = None + + # open the HTTP request if it's not open already + if self._body is None: + start = self._position + len(self._buffer) + stop = None if size < 0 else \ + start + size - 1 if self._content_length is None else \ + start + max(size, self._stream_range) - 1 + self._open_body(start=start, stop=stop) + + b = [self._stream_from_body(size)] + bytes_read = self._buffer.fill(b) + if bytes_read == 0: + logger.debug('%s: reached EOF while filling buffer', self) + self._eof = True + + def _raw_read(self, size=-1): + """Internal read from the continuous connection with the remote peer without considering buffering.""" + + # If we can figure out that we've read past the EOF, then we can save + # an extra API call. + reached_eof = True if self._content_length is not None and self._position >= self._content_length else False + + if reached_eof or size == 0: + return b'' + + if self._body is None: + stop = None if size == -1 else self._position + size + self._open_body(start=self._position, stop=stop) + + # + # Boto3 has built-in error handling and retry mechanisms: + # + # https://boto3.amazonaws.com/v1/documentation/api/latest/guide/error-handling.html + # https://boto3.amazonaws.com/v1/documentation/api/latest/guide/retries.html + # + # Unfortunately, it isn't always enough. There is still a non-zero + # possibility that an exception will slip past these mechanisms and + # terminate the read prematurely. Luckily, at this stage, it's very + # simple to recover from the problem: wait a little bit, reopen the + # HTTP connection and try again. Usually, a single retry attempt is + # enough to recover, but we try multiple times "just in case". + # + for attempt, seconds in enumerate([1, 2, 4, 8, 16], 1): + try: + if size == -1: + binary = self._body.read() + else: + binary = self._body.read(size) + except ( + ConnectionResetError, + botocore.exceptions.BotoCoreError, + urllib3.exceptions.HTTPError, + ) as err: + logger.warning( + '%s: caught %r while reading %d bytes, sleeping %ds before retry', + self, + err, + size, + seconds, + ) + time.sleep(seconds) + self._open_body() + else: + self._position += len(binary) + if self._optimize == 'reading' or (self._optimize == 'auto' and self._read_call_counter == 0): + self._body.close() + self._body = None + self._read_call_counter += 0 + return binary + + raise IOError('%s: failed to read %d bytes after %d attempts' % (self, size, attempt)) + + def _open_body(self, start=None, stop=None): + """Open a connection to download the specified range of bytes. Store + the open file handle in self._body. + + If no range is specified, start defaults to self._position. + start and stop follow the semantics of the http range header, + so a stop without a start will read bytes beginning at stop. + + As a side effect, set self._content_length. Set self._position + to self._content_length if start is past end of file. + """ + if start is None and stop is None: + start = self._position + range_string = smart_open.utils.make_range_string(start, stop) + + try: + # Optimistically try to fetch the requested content range. + response = _get( + self._client, + self._bucket, + self._key, + self._version_id, + range_string, + ) + except IOError as ioe: + # Handle requested content range exceeding content size. + error_response = _unwrap_ioerror(ioe) + if error_response is None or error_response.get('Code') != _OUT_OF_RANGE: + raise + # self._position = self._content_length = int(error_response['ActualObjectSize']) + self._content_length = int(error_response['ActualObjectSize']) + self._body = io.BytesIO() + else: + self._log_retry_attempts(response) # keep track of how many retries boto3 attempted + units, start, stop, length = smart_open.utils.parse_content_range(response['ContentRange']) + self._stream_range_from = start + # _stream_range_to is set to the minimal value for the first read, + # after that it's set to the user defined value + self._stream_range_to = stop if self._content_length is None else start + self._stream_range + self._content_length = length + self._body = response['Body'] + + def _stream_from_body(self, size=-1): + """Reads data from an open Body""" + + # + # Boto3 has built-in error handling and retry mechanisms: + # + # https://boto3.amazonaws.com/v1/documentation/api/latest/guide/error-handling.html + # https://boto3.amazonaws.com/v1/documentation/api/latest/guide/retries.html + # + # Unfortunately, it isn't always enough. There is still a non-zero + # possibility that an exception will slip past these mechanisms and + # terminate the read prematurely. Luckily, at this stage, it's very + # simple to recover from the problem: wait a little bit, reopen the + # HTTP connection and try again. Usually, a single retry attempt is + # enough to recover, but we try multiple times "just in case". + # + for attempt, seconds in enumerate([1, 2, 4, 8, 16], 1): + try: + binary = self._body.read(None if size < 0 else size) # botocore requires None rather than -1 + except ( + ConnectionResetError, + botocore.exceptions.BotoCoreError, + urllib3.exceptions.HTTPError, + ) as err: + logger.warning( + '%s: caught %r while reading %d bytes, sleeping %ds before retry', + self, + err, + size, + seconds, + ) + time.sleep(seconds) + self._open_body() + else: + # self._position += len(binary) + return binary + + raise IOError('%s: failed to read %d bytes after %d attempts' % (self, size, attempt)) + + def _log_retry_attempts(self, response): + """Keep track of how many times boto3's built-in retry mechanism activated.""" + # https://boto3.amazonaws.com/v1/documentation/api/latest/guide/retries.html#checking-retry-attempts-in-an-aws-service-response + logger.debug( + '%s: RetryAttempts: %d', + self, + response['ResponseMetadata']['RetryAttempts'], + ) def __str__(self): return "smart_open.s3.Reader(%r, %r)" % (self._bucket, self._key) @@ -750,11 +951,11 @@ def __init__( writebuffer=None, ): if min_part_size < MIN_MIN_PART_SIZE: - logger.warning("S3 requires minimum part size >= 5MB; \ -multipart upload may fail") + logger.warning("S3 requires minimum part size >= 5MB; " + "multipart upload may fail") self._min_part_size = min_part_size - _initialize_boto3(self, client, client_kwargs, bucket, key) + self._client, self._bucket, self._key = _initialize_boto3(client, client_kwargs, bucket, key) try: partial = functools.partial( @@ -968,7 +1169,7 @@ def __init__( client_kwargs=None, writebuffer=None, ): - _initialize_boto3(self, client, client_kwargs, bucket, key) + self._client, self._bucket, self._key = _initialize_boto3(client, client_kwargs, bucket, key) try: self._client.head_bucket(Bucket=bucket) diff --git a/smart_open/tests/test_s3.py b/smart_open/tests/test_s3.py index 4e5ca5b1..f9252c42 100644 --- a/smart_open/tests/test_s3.py +++ b/smart_open/tests/test_s3.py @@ -99,94 +99,6 @@ def mock_make_request(self, operation_model, *args, **kwargs): patcher.stop() -@unittest.skipUnless( - ENABLE_MOTO_SERVER, - 'The test case needs a Moto server running on the local 5000 port.' -) -class SeekableRawReaderTest(unittest.TestCase): - def setUp(self): - self._body = b'123456' - self._local_resource = boto3.resource('s3', endpoint_url='http://localhost:5000') - self._local_resource.Bucket(BUCKET_NAME).create() - self._local_resource.Object(BUCKET_NAME, KEY_NAME).put(Body=self._body) - self._local_client = boto3.client('s3', endpoint_url='http://localhost:5000') - - def tearDown(self): - self._local_resource.Object(BUCKET_NAME, KEY_NAME).delete() - self._local_resource.Bucket(BUCKET_NAME).delete() - - def test_read_from_a_closed_body(self): - reader = smart_open.s3._SeekableRawReader(self._local_client, BUCKET_NAME, KEY_NAME) - self.assertEqual(reader.read(1), b'1') - reader._body.close() - self.assertEqual(reader.read(2), b'23') - - -class CrapStream(io.BytesIO): - """Raises an exception on every second read call.""" - def __init__(self, *args, modulus=2, **kwargs): - super().__init__(*args, **kwargs) - self._count = 0 - self._modulus = modulus - - def read(self, size=-1): - self._count += 1 - if self._count % self._modulus == 0: - raise botocore.exceptions.BotoCoreError() - the_bytes = super().read(size) - return the_bytes - - -class CrapClient: - def __init__(self, data, modulus=2): - self._datasize = len(data) - self._body = CrapStream(data, modulus=modulus) - - def get_object(self, *args, **kwargs): - return { - 'ActualObjectSize': self._datasize, - 'ContentLength': self._datasize, - 'ContentRange': 'bytes 0-%d/%d' % (self._datasize, self._datasize), - 'Body': self._body, - 'ResponseMetadata': {'RetryAttempts': 1}, - } - - -class IncrementalBackoffTest(unittest.TestCase): - def test_every_read_fails(self): - reader = smart_open.s3._SeekableRawReader(CrapClient(b'hello', 1), 'bucket', 'key') - with mock.patch('time.sleep') as mock_sleep: - with self.assertRaises(IOError): - reader.read() - - # - # Make sure our incremental backoff is actually happening here. - # - mock_sleep.assert_has_calls([mock.call(s) for s in (1, 2, 4, 8, 16)]) - - def test_every_second_read_fails(self): - """Can we read from a stream that raises exceptions from time to time?""" - reader = smart_open.s3._SeekableRawReader(CrapClient(b'hello'), 'bucket', 'key') - with mock.patch('time.sleep') as mock_sleep: - assert reader.read(1) == b'h' - mock_sleep.assert_not_called() - - assert reader.read(1) == b'e' - mock_sleep.assert_called_with(1) - mock_sleep.reset_mock() - - assert reader.read(1) == b'l' - mock_sleep.reset_mock() - - assert reader.read(1) == b'l' - mock_sleep.assert_called_with(1) - mock_sleep.reset_mock() - - assert reader.read(1) == b'o' - mock_sleep.assert_called_with(1) - mock_sleep.reset_mock() - - @moto.mock_s3 class ReaderTest(BaseTest): def setUp(self): @@ -226,12 +138,35 @@ def test_iter_context_manager(self): def test_read(self): """Are S3 files read correctly?""" + # 1 api call is expected because the file size is less than the buffer size with self.assertApiCalls(GetObject=1): fin = smart_open.s3.Reader(BUCKET_NAME, KEY_NAME) self.assertEqual(self.body[:6], fin.read(6)) self.assertEqual(self.body[6:14], fin.read(8)) # ř is 2 bytes self.assertEqual(self.body[14:], fin.read()) # read the rest + def test_read_stream_range(self): + """Does stream_range get set minimally for the first read call only?""" + # 2 api calls are expected because the buffer size is small and stream_range is only + # used on the 2nd and subsequent read calls, that is done to optimize the open, single read, close use case. + with self.assertApiCalls(GetObject=2): + fin = smart_open.s3.Reader(BUCKET_NAME, KEY_NAME, buffer_size=10) + self.assertEqual(self.body[:6], fin.read(6)) + self.assertEqual(self.body[6:14], fin.read(8)) # ř is 2 bytes + self.assertEqual(self.body[14:], fin.read()) # read the rest + + def test_read_full_file(self): + """Does a full read work as expected?""" + with self.assertApiCalls(GetObject=1): + fin = smart_open.s3.Reader(BUCKET_NAME, KEY_NAME) + self.assertEqual(self.body, fin.read()) # read the rest + + def test_read_after_seek(self): + with self.assertApiCalls(GetObject=1): + fin = smart_open.s3.Reader(BUCKET_NAME, KEY_NAME, buffer_size=1) + fin.seek(6) + self.assertEqual(self.body[6:14], fin.read(8)) # ř is 2 bytes + def test_seek_beginning(self): """Does seeking to the beginning of S3 files work correctly?""" with self.assertApiCalls(GetObject=1): @@ -250,7 +185,7 @@ def test_seek_beginning(self): def test_seek_start(self): """Does seeking from the start of S3 files work correctly?""" with self.assertApiCalls(GetObject=1): - fin = smart_open.s3.Reader(BUCKET_NAME, KEY_NAME, defer_seek=True) + fin = smart_open.s3.Reader(BUCKET_NAME, KEY_NAME) seek = fin.seek(6) self.assertEqual(seek, 6) self.assertEqual(fin.tell(), 6) @@ -269,17 +204,28 @@ def test_seek_current(self): def test_seek_end(self): """Does seeking from the end of S3 files work correctly?""" - with self.assertApiCalls(GetObject=1): - fin = smart_open.s3.Reader(BUCKET_NAME, KEY_NAME, defer_seek=True) + with self.assertApiCalls(HeadObject=1, GetObject=1): + fin = smart_open.s3.Reader(BUCKET_NAME, KEY_NAME) + seek = fin.seek(-4, whence=smart_open.constants.WHENCE_END) + self.assertEqual(seek, len(self.body) - 4) + self.assertEqual(fin.read(), b'you?') + + def test_seek_end_after_read(self): + """Seeking to end after a read should not invoke API""" + with self.assertApiCalls(GetObject=2): + fin = smart_open.s3.Reader(BUCKET_NAME, KEY_NAME) + fin.read(1) seek = fin.seek(-4, whence=smart_open.constants.WHENCE_END) self.assertEqual(seek, len(self.body) - 4) self.assertEqual(fin.read(), b'you?') def test_seek_past_end(self): + """Seek past end will return the invalid seek value (same as system open) but read 0 bytes.""" with self.assertApiCalls(GetObject=1), patch_invalid_range_response(str(len(self.body))): - fin = smart_open.s3.Reader(BUCKET_NAME, KEY_NAME, defer_seek=True) + fin = smart_open.s3.Reader(BUCKET_NAME, KEY_NAME) seek = fin.seek(60) - self.assertEqual(seek, len(self.body)) + self.assertEqual(seek, 60) # this is consistent with system open + self.assertEqual(b'', fin.read()) def test_detect_eof(self): with self.assertApiCalls(GetObject=1): @@ -315,7 +261,7 @@ def test_read_gzip(self): self.assertEqual(zipfile.read(), expected) logger.debug('starting actual test') - with self.assertApiCalls(GetObject=1): + with self.assertApiCalls(GetObject=2): with smart_open.s3.Reader(BUCKET_NAME, KEY_NAME) as fin: with gzip.GzipFile(fileobj=fin) as zipfile: actual = zipfile.read() @@ -342,7 +288,7 @@ def test_readline_tiny_buffer(self): content = b'englishman\nin\nnew\nyork\n' _resource('s3').Object(BUCKET_NAME, KEY_NAME).put(Body=content) - with self.assertApiCalls(GetObject=1): + with self.assertApiCalls(GetObject=2): with smart_open.s3.Reader(BUCKET_NAME, KEY_NAME, buffer_size=8) as fin: actual = list(fin) @@ -352,7 +298,7 @@ def test_readline_tiny_buffer(self): def test_read0_does_not_return_data(self): with self.assertApiCalls(): # set defer_seek to verify that read(0) doesn't trigger an unnecessary API call - with smart_open.s3.Reader(BUCKET_NAME, KEY_NAME, defer_seek=True) as fin: + with smart_open.s3.Reader(BUCKET_NAME, KEY_NAME) as fin: data = fin.read(0) self.assertEqual(data, b'') @@ -360,7 +306,7 @@ def test_read0_does_not_return_data(self): def test_to_boto3(self): with self.assertApiCalls(): # set defer_seek to verify that to_boto3() doesn't trigger an unnecessary API call - with smart_open.s3.Reader(BUCKET_NAME, KEY_NAME, defer_seek=True) as fin: + with smart_open.s3.Reader(BUCKET_NAME, KEY_NAME) as fin: returned_obj = fin.to_boto3(_resource('s3')) boto3_body = returned_obj.get()['Body'].read() @@ -375,21 +321,6 @@ def test_binary_iterator(self): actual = [line.rstrip() for line in fin] self.assertEqual(expected, actual) - def test_defer_seek(self): - content = b'englishman\nin\nnew\nyork\n' - _resource('s3').Object(BUCKET_NAME, KEY_NAME).put(Body=content) - - with self.assertApiCalls(): - fin = smart_open.s3.Reader(BUCKET_NAME, KEY_NAME, defer_seek=True) - with self.assertApiCalls(GetObject=1): - self.assertEqual(fin.read(), content) - - with self.assertApiCalls(): - fin = smart_open.s3.Reader(BUCKET_NAME, KEY_NAME, defer_seek=True) - with self.assertApiCalls(GetObject=1): - fin.seek(10) - self.assertEqual(fin.read(), content[10:]) - def test_read_empty_file(self): _resource('s3').Object(BUCKET_NAME, KEY_NAME).put(Body=b'') @@ -669,6 +600,10 @@ def test_writebuffer(self): class IterBucketTest(unittest.TestCase): def setUp(self): ignore_resource_warnings() + # ensure a clean start by removing any existing buckets from previous tests + for bucket in _resource('s3').buckets.all(): + bucket.objects.all().delete() + bucket.delete() _resource('s3').create_bucket(Bucket=BUCKET_NAME).wait_until_exists() @pytest.mark.skipif(condition=sys.platform == 'win32', reason="does not run on windows") @@ -739,6 +674,11 @@ def setUp(self): smart_open.concurrency._MULTIPROCESSING = False ignore_resource_warnings() + # ensure a clean start by removing any existing buckets from previous tests + for bucket in _resource('s3').buckets.all(): + bucket.objects.all().delete() + bucket.delete() + _resource('s3').create_bucket(Bucket=BUCKET_NAME).wait_until_exists() def tearDown(self): @@ -770,6 +710,11 @@ def setUp(self): smart_open.concurrency._CONCURRENT_FUTURES = False ignore_resource_warnings() + # ensure a clean start by removing any existing buckets from previous tests + for bucket in _resource('s3').buckets.all(): + bucket.objects.all().delete() + bucket.delete() + _resource('s3').create_bucket(Bucket=BUCKET_NAME).wait_until_exists() def tearDown(self): @@ -795,6 +740,11 @@ def setUp(self): ignore_resource_warnings() + # ensure a clean start by removing any existing buckets from previous tests + for bucket in _resource('s3').buckets.all(): + bucket.objects.all().delete() + bucket.delete() + _resource('s3').create_bucket(Bucket=BUCKET_NAME).wait_until_exists() def tearDown(self): @@ -818,6 +768,11 @@ def test(self): @moto.mock_s3 class IterBucketCredentialsTest(unittest.TestCase): def test(self): + # ensure a clean start by removing any existing buckets from previous tests + for bucket in _resource('s3').buckets.all(): + bucket.objects.all().delete() + bucket.delete() + _resource('s3').create_bucket(Bucket=BUCKET_NAME).wait_until_exists() num_keys = 10 populate_bucket(num_keys=num_keys)