Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: translate socket.timeout to TTransportException #160

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions thriftpy2/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,15 @@ def make_client(service, host="localhost", port=9090, unix_socket=None,
proto_factory=TBinaryProtocolFactory(),
trans_factory=TBufferedTransportFactory(),
timeout=3000, cafile=None, ssl_context=None, certfile=None,
keyfile=None, url="", socket_family=socket.AF_INET):
keyfile=None, url="", socket_family=socket.AF_INET,
handle_timeout_error=False):
if url:
parsed_url = urllib.parse.urlparse(url)
host = parsed_url.hostname or host
port = parsed_url.port or port
if unix_socket:
socket = TSocket(unix_socket=unix_socket, socket_timeout=timeout)
socket = TSocket(unix_socket=unix_socket, socket_timeout=timeout,
)
if certfile:
warnings.warn("SSL only works with host:port, not unix_socket.")
elif host and port:
Expand All @@ -47,7 +49,9 @@ def make_client(service, host="localhost", port=9090, unix_socket=None,
certfile=certfile, keyfile=keyfile,
ssl_context=ssl_context)
else:
socket = TSocket(host, port, socket_family=socket_family, socket_timeout=timeout)
socket = TSocket(host, port, socket_family=socket_family,
socket_timeout=timeout,
handle_timeout_error=handle_timeout_error)
else:
raise ValueError("Either host/port or unix_socket or url must be provided.")

Expand Down Expand Up @@ -91,7 +95,7 @@ def client_context(service, host="localhost", port=9090, unix_socket=None,
trans_factory=TBufferedTransportFactory(),
timeout=None, socket_timeout=3000, connect_timeout=3000,
cafile=None, ssl_context=None, certfile=None, keyfile=None,
url=""):
url="", handle_timeout_error=False):
if url:
parsed_url = urllib.parse.urlparse(url)
host = parsed_url.hostname or host
Expand Down Expand Up @@ -119,7 +123,8 @@ def client_context(service, host="localhost", port=9090, unix_socket=None,
else:
socket = TSocket(host, port,
connect_timeout=connect_timeout,
socket_timeout=socket_timeout)
socket_timeout=socket_timeout,
handle_timeout_error=handle_timeout_error)
else:
raise ValueError("Either host/port or unix_socket or url must be provided.")

Expand Down
24 changes: 22 additions & 2 deletions thriftpy2/transport/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ class TSocket(object):

def __init__(self, host=None, port=None, unix_socket=None,
sock=None, socket_family=socket.AF_INET,
socket_timeout=3000, connect_timeout=None):
socket_timeout=3000, connect_timeout=None,
handle_timeout_error=False):
"""Initialize a TSocket

TSocket can be initialized in 3 ways:
Expand All @@ -35,6 +36,8 @@ def __init__(self, host=None, port=None, unix_socket=None,
@param socket_timeout socket timeout in ms
@param connect_timeout connect timeout in ms, only used in
connection, will be set to socket_timeout if not set.
@param handle_timeout_error(bool) Whether translate socket.timeout
error to TTransportException. Default is False for compalibility.
"""
if sock:
self.sock = sock
Expand All @@ -54,6 +57,8 @@ def __init__(self, host=None, port=None, unix_socket=None,
self.connect_timeout = connect_timeout / 1000 if connect_timeout \
else self.socket_timeout

self.handle_timeout_error = handle_timeout_error

def _init_sock(self):
if self.unix_socket:
_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
Expand Down Expand Up @@ -108,6 +113,13 @@ def read(self, sz):
while True:
try:
buff = self.sock.recv(sz)
except socket.timeout:
if not self.handle_timeout_error:
raise
addr = self.sock.getsockname()
typ = TTransportException.TIMED_OUT
msg = "Timeouted when read from %s" % str(addr)
raise TTransportException(type=typ, message=msg)
except socket.error as e:
if e.errno == errno.EINTR:
continue
Expand All @@ -133,7 +145,15 @@ def read(self, sz):
return buff

def write(self, buff):
self.sock.sendall(buff)
try:
self.sock.sendall(buff)
except socket.timeout:
if not self.handle_timeout_error:
raise
addr = self.sock.getsockname()
typ = TTransportException.TIMED_OUT
msg = "Timeouted when write to %s" % str(addr)
raise TTransportException(type=typ, message=msg)

def flush(self):
pass
Expand Down
5 changes: 3 additions & 2 deletions thriftpy2/transport/sslsocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def __init__(self, host, port, socket_family=socket.AF_INET,
socket_timeout=3000, connect_timeout=None,
ssl_context=None, validate=True,
cafile=None, capath=None, certfile=None, keyfile=None,
ciphers=DEFAULT_CIPHERS):
ciphers=DEFAULT_CIPHERS, handle_timeout_error=False):
"""Initialize a TSSLSocket

@param validate(bool) Set to False to disable SSL certificate
Expand All @@ -47,7 +47,8 @@ def __init__(self, host, port, socket_family=socket.AF_INET,
"""
super(TSSLSocket, self).__init__(
host=host, port=port, socket_family=socket_family,
connect_timeout=connect_timeout, socket_timeout=socket_timeout)
connect_timeout=connect_timeout, socket_timeout=socket_timeout,
handle_timeout_error=handle_timeout_error)

if ssl_context:
self.ssl_context = ssl_context
Expand Down