diff --git a/memcache.py b/memcache.py index 3e6bd67..e8260ba 100644 --- a/memcache.py +++ b/memcache.py @@ -488,13 +488,10 @@ def delete_multi(self, keys, time=None, key_prefix='', noreply=False): for key in server_keys[server]: # These are mangled keys cmd = self._encode_cmd('delete', key, headers, noreply, b'\r\n') write(cmd) - try: + with _socket_guard(server, (socket.error,)) as sg: server.send_cmds(b''.join(bigcmd)) - except socket.error as msg: + if sg.interrupted: rc = 0 - if isinstance(msg, tuple): - msg = msg[1] - server.mark_dead(msg) dead_servers.append(server) # if noreply, just return @@ -506,13 +503,10 @@ def delete_multi(self, keys, time=None, key_prefix='', noreply=False): del server_keys[server] for server, keys in six.iteritems(server_keys): - try: + with _socket_guard(server, (socket.error,)) as sg: for key in keys: server.expect(b"DELETED") - except socket.error as msg: - if isinstance(msg, tuple): - msg = msg[1] - server.mark_dead(msg) + if sg.interrupted: rc = 0 return rc @@ -558,7 +552,7 @@ def _deletetouch(self, expected, cmd, key, time=0, noreply=False): headers = None fullcmd = self._encode_cmd(cmd, key, headers, noreply) - try: + with _socket_guard(server, (socket.error,)): server.send_cmd(fullcmd) if noreply: return 1 @@ -567,10 +561,6 @@ def _deletetouch(self, expected, cmd, key, time=0, noreply=False): return 1 self.debuglog('%s expected %s, got: %r' % (cmd, ' or '.join(expected), line)) - except socket.error as msg: - if isinstance(msg, tuple): - msg = msg[1] - server.mark_dead(msg) return 0 def incr(self, key, delta=1, noreply=False): @@ -633,7 +623,7 @@ def _incrdecr(self, cmd, key, delta, noreply=False): return None self._statlog(cmd) fullcmd = self._encode_cmd(cmd, key, str(delta), noreply) - try: + with _socket_guard(server, (socket.error,)): server.send_cmd(fullcmd) if noreply: return @@ -641,11 +631,6 @@ def _incrdecr(self, cmd, key, delta, noreply=False): if line is None or line.strip() == b'NOT_FOUND': return None return int(line) - except socket.error as msg: - if isinstance(msg, tuple): - msg = msg[1] - server.mark_dead(msg) - return None def add(self, key, val, time=0, min_compress_len=0, noreply=False): '''Add new key with value. @@ -902,7 +887,7 @@ def set_multi(self, mapping, time=0, key_prefix='', min_compress_len=0, for server in six.iterkeys(server_keys): bigcmd = [] write = bigcmd.append - try: + with _socket_guard(server, (socket.error,)) as sg: for key in server_keys[server]: # These are mangled keys store_info = self._val_to_store_info( mapping[prefixed_to_orig_key[key]], @@ -917,10 +902,7 @@ def set_multi(self, mapping, time=0, key_prefix='', min_compress_len=0, else: notstored.append(prefixed_to_orig_key[key]) server.send_cmds(b''.join(bigcmd)) - except socket.error as msg: - if isinstance(msg, tuple): - msg = msg[1] - server.mark_dead(msg) + if sg.interrupted: dead_servers.append(server) # if noreply, just return early @@ -936,17 +918,13 @@ def set_multi(self, mapping, time=0, key_prefix='', min_compress_len=0, return list(mapping.keys()) for server, keys in six.iteritems(server_keys): - try: + with _socket_guard(server, (_Error, socket.error)): for key in keys: if server.readline() == b'STORED': continue else: # un-mangle. notstored.append(prefixed_to_orig_key[key]) - except (_Error, socket.error) as msg: - if isinstance(msg, tuple): - msg = msg[1] - server.mark_dead(msg) return notstored def _val_to_store_info(self, val, min_compress_len): @@ -1032,15 +1010,11 @@ def _unsafe_set(): fullcmd = self._encode_cmd(cmd, key, headers, noreply, b'\r\n', encoded_val) - try: + with _socket_guard(server, (socket.error,)): server.send_cmd(fullcmd) if noreply: return True return server.expect(b"STORED", raise_exception=True) == b"STORED" - except socket.error as msg: - if isinstance(msg, tuple): - msg = msg[1] - server.mark_dead(msg) return 0 try: @@ -1065,7 +1039,7 @@ def _get(self, cmd, key): def _unsafe_get(): self._statlog(cmd) - try: + with _socket_guard(server, (_Error, socket.error)): cmd_bytes = cmd.encode('utf-8') if six.PY3 else cmd fullcmd = b''.join((cmd_bytes, b' ', key)) server.send_cmd(fullcmd) @@ -1085,16 +1059,9 @@ def _unsafe_get(): if not rkey: return None try: - value = self._recv_value(server, flags, rlen) + return self._recv_value(server, flags, rlen) finally: server.expect(b"END", raise_exception=True) - except (_Error, socket.error) as msg: - if isinstance(msg, tuple): - msg = msg[1] - server.mark_dead(msg) - return None - - return value try: return _unsafe_get() @@ -1185,13 +1152,10 @@ def get_multi(self, keys, key_prefix=''): # send out all requests on each server before reading anything dead_servers = [] for server in six.iterkeys(server_keys): - try: + with _socket_guard(server, (socket.error,)) as sg: fullcmd = b"get " + b" ".join(server_keys[server]) server.send_cmd(fullcmd) - except socket.error as msg: - if isinstance(msg, tuple): - msg = msg[1] - server.mark_dead(msg) + if sg.interrupted: dead_servers.append(server) # if any servers died on the way, don't expect them to respond. @@ -1200,7 +1164,7 @@ def get_multi(self, keys, key_prefix=''): retvals = {} for server in six.iterkeys(server_keys): - try: + with _socket_guard(server, (_Error, socket.error)): line = server.readline() while line and line != b'END': rkey, flags, rlen = self._expectvalue(server, line) @@ -1210,10 +1174,6 @@ def get_multi(self, keys, key_prefix=''): # un-prefix returned key. retvals[prefixed_to_orig_key[rkey]] = val line = server.readline() - except (_Error, socket.error) as msg: - if isinstance(msg, tuple): - msg = msg[1] - server.mark_dead(msg) return retvals def _expect_cas_value(self, server, line=None, raise_exception=False): @@ -1394,15 +1354,10 @@ def _get_socket(self): s = socket.socket(self.family, socket.SOCK_STREAM) if hasattr(s, 'settimeout'): s.settimeout(self.socket_timeout) - try: + with _socket_guard(self, (socket.error,), + msg_tmpl='connect: {}') as sg: s.connect(self.address) - except socket.timeout as msg: - self.mark_dead("connect: %s" % msg) - return None - except socket.error as msg: - if isinstance(msg, tuple): - msg = msg[1] - self.mark_dead("connect: %s" % msg) + if sg.interrupted: return None self.socket = s self.buffer = b'' @@ -1497,6 +1452,30 @@ def __str__(self): return "unix:%s%s" % (self.address, d) +class _socket_guard(object): + def __init__(self, server, exceptions, msg_tmpl='{}'): + self._server = server + self._exceptions = exceptions + self._msg_tmpl = msg_tmpl + self.interrupted = False + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, exc_tb): + if exc is not None: + self.interrupted = True + + if isinstance(exc, self._exceptions): + msg = self._msg_tmpl.format(exc) + self._server.mark_dead(msg) + return True + elif exc is not None: + self._server.close_socket() + + return False + + def _doctest(): import doctest import memcache diff --git a/test-requirements.txt b/test-requirements.txt index 8f21390..d80827e 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -1,3 +1,4 @@ nose coverage hacking +mock diff --git a/tests/test_memcache.py b/tests/test_memcache.py index 0072813..6782633 100644 --- a/tests/test_memcache.py +++ b/tests/test_memcache.py @@ -1,11 +1,15 @@ from __future__ import print_function +import socket import unittest import six from memcache import Client, SERVER_MAX_KEY_LENGTH, SERVER_MAX_VALUE_LENGTH # noqa: H301 +from mock import Mock +from mock import patch + class FooStruct(object): @@ -166,6 +170,57 @@ def test_disconnect_all_delete_multi(self): ret = self.mc.delete_multi({'keyhere': 'a', 'keythere': 'b'}) self.assertEqual(ret, 1) + def test_socket_error(self): + """Tests case when socket.error exception was raised""" + self.mc.set('socket.error', 1) + server = Mock( + # Should we catch secket.error when establishing connection? + # connect=Mock(side_effect=socket.error(-1, 'connect error')), + send_cmd=Mock(side_effect=socket.error(-1, 'send cmd error')), + send_cmds=Mock(side_effect=socket.error(-1, 'send cmds error')), + flush=Mock(side_effect=socket.error(-1, 'flush error')), + ) + with patch.object(self.mc, 'servers', [server]), \ + patch.object(self.mc, 'buckets', [server]): + self.assertEqual(self.mc.set('socket.error', 2), 0) + self.assertEqual( + self.mc.set_multi({'socket.error': 2}), + ['socket.error'] + ) + self.assertIs(self.mc.incr('socket.error'), None) + self.assertIs(self.mc.decr('socket.error'), None) + self.assertEqual(self.mc.add('socket.error', 5), 0) + self.assertEqual(self.mc.append('socket.error', 9), 0) + self.assertEqual(self.mc.prepend('socket.error', 1), 0) + self.assertEqual(self.mc.replace('socket.error', 100), 0) + self.assertEqual(self.mc.cas('socket.error', 100), 0) + self.assertEqual(self.mc.delete('socket.error'), 0) + self.assertEqual(self.mc.delete_multi(['socket.error']), 0) + self.assertEqual(self.mc.touch('socket.error'), 0) + self.assertIs(self.mc.get('socket.error'), None) + self.assertIs(self.mc.gets('socket.error'), None) + self.assertEqual(self.mc.get_multi(['socket.error']), {}) + self.assertRaises(socket.error, self.mc.get_stats) + self.assertRaises(socket.error, self.mc.get_slab_stats) + self.assertRaises(socket.error, self.mc.get_slabs) + self.assertRaises(socket.error, self.mc.flush_all) + + def test_exception_handling(self): + """Tests closing socket when custom exception raised""" + class CustomException(Exception): + pass + + self.mc.set('error', 1) + with patch.object(self.mc, '_recv_value', + Mock(side_effect=CustomException('custom error'))): + try: + self.mc.get('error') + except CustomException: + pass + self.assertIs(self.mc.servers[0].socket, None) + self.assertEqual(self.mc.set('error', 2), True) + self.assertEqual(self.mc.get('error'), 2) + if __name__ == '__main__': unittest.main()