Skip to content

Commit

Permalink
Close socket when any exception was raised
Browse files Browse the repository at this point in the history
  • Loading branch information
anti-social committed Sep 29, 2017
1 parent 95d1a10 commit 1d40693
Showing 1 changed file with 42 additions and 63 deletions.
105 changes: 42 additions & 63 deletions memcache.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,13 +488,10 @@ def delete_multi(self, keys, time=None, key_prefix='', noreply=False):
for key in server_keys[server]: # These are mangled keys
cmd = self._encode_cmd('delete', key, headers, noreply, b'\r\n')
write(cmd)
try:
with _socket_guard(server, (socket.error,)) as sg:
server.send_cmds(b''.join(bigcmd))
except socket.error as msg:
if sg.interrupted:
rc = 0
if isinstance(msg, tuple):
msg = msg[1]
server.mark_dead(msg)
dead_servers.append(server)

# if noreply, just return
Expand All @@ -506,13 +503,10 @@ def delete_multi(self, keys, time=None, key_prefix='', noreply=False):
del server_keys[server]

for server, keys in six.iteritems(server_keys):
try:
with _socket_guard(server, (socket.error,)) as sg:
for key in keys:
server.expect(b"DELETED")
except socket.error as msg:
if isinstance(msg, tuple):
msg = msg[1]
server.mark_dead(msg)
if sg.interrupted:
rc = 0
return rc

Expand Down Expand Up @@ -558,7 +552,7 @@ def _deletetouch(self, expected, cmd, key, time=0, noreply=False):
headers = None
fullcmd = self._encode_cmd(cmd, key, headers, noreply)

try:
with _socket_guard(server, (socket.error,)):
server.send_cmd(fullcmd)
if noreply:
return 1
Expand All @@ -567,10 +561,6 @@ def _deletetouch(self, expected, cmd, key, time=0, noreply=False):
return 1
self.debuglog('%s expected %s, got: %r'
% (cmd, ' or '.join(expected), line))
except socket.error as msg:
if isinstance(msg, tuple):
msg = msg[1]
server.mark_dead(msg)
return 0

def incr(self, key, delta=1, noreply=False):
Expand Down Expand Up @@ -633,19 +623,14 @@ def _incrdecr(self, cmd, key, delta, noreply=False):
return None
self._statlog(cmd)
fullcmd = self._encode_cmd(cmd, key, str(delta), noreply)
try:
with _socket_guard(server, (socket.error,)):
server.send_cmd(fullcmd)
if noreply:
return
line = server.readline()
if line is None or line.strip() == b'NOT_FOUND':
return None
return int(line)
except socket.error as msg:
if isinstance(msg, tuple):
msg = msg[1]
server.mark_dead(msg)
return None

def add(self, key, val, time=0, min_compress_len=0, noreply=False):
'''Add new key with value.
Expand Down Expand Up @@ -902,7 +887,7 @@ def set_multi(self, mapping, time=0, key_prefix='', min_compress_len=0,
for server in six.iterkeys(server_keys):
bigcmd = []
write = bigcmd.append
try:
with _socket_guard(server, (socket.error,)) as sg:
for key in server_keys[server]: # These are mangled keys
store_info = self._val_to_store_info(
mapping[prefixed_to_orig_key[key]],
Expand All @@ -917,10 +902,7 @@ def set_multi(self, mapping, time=0, key_prefix='', min_compress_len=0,
else:
notstored.append(prefixed_to_orig_key[key])
server.send_cmds(b''.join(bigcmd))
except socket.error as msg:
if isinstance(msg, tuple):
msg = msg[1]
server.mark_dead(msg)
if sg.interrupted:
dead_servers.append(server)

# if noreply, just return early
Expand All @@ -936,17 +918,13 @@ def set_multi(self, mapping, time=0, key_prefix='', min_compress_len=0,
return list(mapping.keys())

for server, keys in six.iteritems(server_keys):
try:
with _socket_guard(server, (_Error, socket.error)):
for key in keys:
if server.readline() == b'STORED':
continue
else:
# un-mangle.
notstored.append(prefixed_to_orig_key[key])
except (_Error, socket.error) as msg:
if isinstance(msg, tuple):
msg = msg[1]
server.mark_dead(msg)
return notstored

def _val_to_store_info(self, val, min_compress_len):
Expand Down Expand Up @@ -1032,15 +1010,11 @@ def _unsafe_set():
fullcmd = self._encode_cmd(cmd, key, headers, noreply,
b'\r\n', encoded_val)

try:
with _socket_guard(server, (socket.error,)):
server.send_cmd(fullcmd)
if noreply:
return True
return server.expect(b"STORED", raise_exception=True) == b"STORED"
except socket.error as msg:
if isinstance(msg, tuple):
msg = msg[1]
server.mark_dead(msg)
return 0

try:
Expand All @@ -1065,7 +1039,7 @@ def _get(self, cmd, key):
def _unsafe_get():
self._statlog(cmd)

try:
with _socket_guard(server, (_Error, socket.error)):
cmd_bytes = cmd.encode('utf-8') if six.PY3 else cmd
fullcmd = b''.join((cmd_bytes, b' ', key))
server.send_cmd(fullcmd)
Expand All @@ -1085,16 +1059,9 @@ def _unsafe_get():
if not rkey:
return None
try:
value = self._recv_value(server, flags, rlen)
return self._recv_value(server, flags, rlen)
finally:
server.expect(b"END", raise_exception=True)
except (_Error, socket.error) as msg:
if isinstance(msg, tuple):
msg = msg[1]
server.mark_dead(msg)
return None

return value

try:
return _unsafe_get()
Expand Down Expand Up @@ -1185,13 +1152,10 @@ def get_multi(self, keys, key_prefix=''):
# send out all requests on each server before reading anything
dead_servers = []
for server in six.iterkeys(server_keys):
try:
with _socket_guard(server, (socket.error,)) as sg:
fullcmd = b"get " + b" ".join(server_keys[server])
server.send_cmd(fullcmd)
except socket.error as msg:
if isinstance(msg, tuple):
msg = msg[1]
server.mark_dead(msg)
if sg.interrupted:
dead_servers.append(server)

# if any servers died on the way, don't expect them to respond.
Expand All @@ -1200,7 +1164,7 @@ def get_multi(self, keys, key_prefix=''):

retvals = {}
for server in six.iterkeys(server_keys):
try:
with _socket_guard(server, (_Error, socket.error)):
line = server.readline()
while line and line != b'END':
rkey, flags, rlen = self._expectvalue(server, line)
Expand All @@ -1210,10 +1174,6 @@ def get_multi(self, keys, key_prefix=''):
# un-prefix returned key.
retvals[prefixed_to_orig_key[rkey]] = val
line = server.readline()
except (_Error, socket.error) as msg:
if isinstance(msg, tuple):
msg = msg[1]
server.mark_dead(msg)
return retvals

def _expect_cas_value(self, server, line=None, raise_exception=False):
Expand Down Expand Up @@ -1394,15 +1354,10 @@ def _get_socket(self):
s = socket.socket(self.family, socket.SOCK_STREAM)
if hasattr(s, 'settimeout'):
s.settimeout(self.socket_timeout)
try:
with _socket_guard(self, (socket.error,),
msg_tmpl='connect: {}') as sg:
s.connect(self.address)
except socket.timeout as msg:
self.mark_dead("connect: %s" % msg)
return None
except socket.error as msg:
if isinstance(msg, tuple):
msg = msg[1]
self.mark_dead("connect: %s" % msg)
if sg.interrupted:
return None
self.socket = s
self.buffer = b''
Expand Down Expand Up @@ -1497,6 +1452,30 @@ def __str__(self):
return "unix:%s%s" % (self.address, d)


class _socket_guard(object):
def __init__(self, server, exceptions, msg_tmpl='{}'):
self._server = server
self._exceptions = exceptions
self._msg_tmpl = msg_tmpl
self.interrupted = False

def __enter__(self):
return self

def __exit__(self, exc_type, exc, exc_tb):
if exc is not None:
self.interrupted = True

if isinstance(exc, self._exceptions):
msg = self._msg_tmpl.format(exc)
self._server.mark_dead(msg)
return True
elif exc is not None:
self._server.close_socket()

return False


def _doctest():
import doctest
import memcache
Expand Down

0 comments on commit 1d40693

Please sign in to comment.