diff --git a/kazoo/client.py b/kazoo/client.py index 27b7c384..9503cda8 100644 --- a/kazoo/client.py +++ b/kazoo/client.py @@ -25,6 +25,7 @@ from kazoo.protocol.connection import ConnectionHandler from kazoo.protocol.paths import _prefix_root, normpath from kazoo.protocol.serialization import ( + AddWatch, Auth, CheckVersion, CloseInstance, @@ -38,6 +39,7 @@ SetACL, GetData, Reconfig, + RemoveWatches, SetData, Sync, Transaction, @@ -48,6 +50,8 @@ KazooState, KeeperState, WatchedEvent, + AddWatchMode, + WatcherType, ) from kazoo.retry import KazooRetry from kazoo.security import ACL, OPEN_ACL_UNSAFE @@ -248,6 +252,8 @@ def __init__( self.state_listeners = set() self._child_watchers = defaultdict(set) self._data_watchers = defaultdict(set) + self._persistent_watchers = defaultdict(set) + self._persistent_recursive_watchers = defaultdict(set) self._reset() self.read_only = read_only @@ -416,8 +422,16 @@ def _reset_watchers(self): for data_watchers in self._data_watchers.values(): watchers.extend(data_watchers) + for persistent_watchers in self._persistent_watchers.values(): + watchers.extend(persistent_watchers) + + for pr_watchers in self._persistent_recursive_watchers.values(): + watchers.extend(pr_watchers) + self._child_watchers = defaultdict(set) self._data_watchers = defaultdict(set) + self._persistent_watchers = defaultdict(set) + self._persistent_recursive_watchers = defaultdict(set) ev = WatchedEvent(EventType.NONE, self._state, None) for watch in watchers: @@ -1644,8 +1658,111 @@ def reconfig_async(self, joining, leaving, new_members, from_config): return async_result + def add_watch(self, path, watch, mode): + """Add a watch. + + This method adds persistent watches. Unlike the data and + child watches which may be set by calls to + :meth:`KazooClient.exists`, :meth:`KazooClient.get`, and + :meth:`KazooClient.get_children`, persistent watches are not + removed after being triggered. + + To remove a persistent watch, use + :meth:`KazooClient.remove_all_watches` with an argument of + :attr:`~kazoo.protocol.states.WatcherType.ANY`. + + The `mode` argument determines whether or not the watch is + recursive. To set a persistent watch, use + :class:`~kazoo.protocol.states.AddWatchMode.PERSISTENT`. To set a + persistent recursive watch, use + :class:`~kazoo.protocol.states.AddWatchMode.PERSISTENT_RECURSIVE`. + + :param path: Path of node to watch. + :param watch: Watch callback to set for future changes + to this path. + :param mode: The mode to use. + :type mode: int + + :raises: + :exc:`~kazoo.exceptions.MarshallingError` if mode is + unknown. + + :exc:`~kazoo.exceptions.ZookeeperError` if the server + returns a non-zero error code. + """ + return self.add_watch_async(path, watch, mode).get() + + def add_watch_async(self, path, watch, mode): + """Asynchronously add a watch. Takes the same arguments as + :meth:`add_watch`. + """ + if not isinstance(path, str): + raise TypeError("Invalid type for 'path' (string expected)") + if not callable(watch): + raise TypeError("Invalid type for 'watch' (must be a callable)") + if not isinstance(mode, int): + raise TypeError("Invalid type for 'mode' (int expected)") + if mode not in ( + AddWatchMode.PERSISTENT, + AddWatchMode.PERSISTENT_RECURSIVE, + ): + raise ValueError("Invalid value for 'mode'") + + async_result = self.handler.async_result() + self._call( + AddWatch(_prefix_root(self.chroot, path), watch, mode), + async_result, + ) + return async_result + + def remove_all_watches(self, path, watcher_type): + """Remove watches from a path. + + This removes all watches of a specified type (data, child, + any) from a given path. + + The `watcher_type` argument specifies which type to use. It + may be one of: + + * :attr:`~kazoo.protocol.states.WatcherType.DATA` + * :attr:`~kazoo.protocol.states.WatcherType.CHILDREN` + * :attr:`~kazoo.protocol.states.WatcherType.ANY` + + To remove persistent watches, specify a watcher type of + :attr:`~kazoo.protocol.states.WatcherType.ANY`. + + :param path: Path of watch to remove. + :param watcher_type: The type of watch to remove. + :type watcher_type: int + """ + + return self.remove_all_watches_async(path, watcher_type).get() + + def remove_all_watches_async(self, path, watcher_type): + """Asynchronously remove watches. Takes the same arguments as + :meth:`remove_all_watches`. + """ + if not isinstance(path, str): + raise TypeError("Invalid type for 'path' (string expected)") + if not isinstance(watcher_type, int): + raise TypeError("Invalid type for 'watcher_type' (int expected)") + if watcher_type not in ( + WatcherType.ANY, + WatcherType.CHILDREN, + WatcherType.DATA, + ): + raise ValueError("Invalid value for 'watcher_type'") + + async_result = self.handler.async_result() + self._call( + RemoveWatches(_prefix_root(self.chroot, path), watcher_type), + async_result, + ) + return async_result + class TransactionRequest(object): + """A Zookeeper Transaction Request A Transaction provides a builder object that can be used to diff --git a/kazoo/protocol/connection.py b/kazoo/protocol/connection.py index ad4f3b1f..80819b6e 100644 --- a/kazoo/protocol/connection.py +++ b/kazoo/protocol/connection.py @@ -20,6 +20,7 @@ ) from kazoo.loggingsupport import BLATHER from kazoo.protocol.serialization import ( + AddWatch, Auth, Close, Connect, @@ -28,6 +29,7 @@ GetChildren2, Ping, PingInstance, + RemoveWatches, ReplyHeader, SASL, Transaction, @@ -35,10 +37,12 @@ int_struct, ) from kazoo.protocol.states import ( + AddWatchMode, Callback, KeeperState, WatchedEvent, EVENT_TYPE_MAP, + WatcherType, ) from kazoo.retry import ( ForceRetryError, @@ -363,6 +367,18 @@ def _write(self, msg, timeout): raise ConnectionDropped("socket connection broken") sent += bytes_sent + def _find_persistent_recursive_watchers(self, path): + parts = path.split("/") + watchers = [] + for count in range(len(parts)): + candidate = "/".join(parts[: count + 1]) + if not candidate: + continue + watchers.extend( + self.client._persistent_recursive_watchers.get(candidate, []) + ) + return watchers + def _read_watch_event(self, buffer, offset): client = self.client watch, offset = Watch.deserialize(buffer, offset) @@ -374,9 +390,13 @@ def _read_watch_event(self, buffer, offset): if watch.type in (CREATED_EVENT, CHANGED_EVENT): watchers.extend(client._data_watchers.pop(path, [])) + watchers.extend(client._persistent_watchers.get(path, [])) + watchers.extend(self._find_persistent_recursive_watchers(path)) elif watch.type == DELETED_EVENT: watchers.extend(client._data_watchers.pop(path, [])) watchers.extend(client._child_watchers.pop(path, [])) + watchers.extend(client._persistent_watchers.get(path, [])) + watchers.extend(self._find_persistent_recursive_watchers(path)) elif watch.type == CHILD_EVENT: watchers.extend(client._child_watchers.pop(path, [])) else: @@ -448,13 +468,35 @@ def _read_response(self, header, buffer, offset): async_object.set(response) - # Determine if watchers should be registered - watcher = getattr(request, "watcher", None) - if not client._stopped.is_set() and watcher: - if isinstance(request, (GetChildren, GetChildren2)): - client._child_watchers[request.path].add(watcher) - else: - client._data_watchers[request.path].add(watcher) + # Determine if watchers should be registered or unregistered + if not client._stopped.is_set(): + watcher = getattr(request, "watcher", None) + if watcher: + if isinstance(request, AddWatch): + if request.mode == AddWatchMode.PERSISTENT: + client._persistent_watchers[request.path].add( + watcher + ) + elif request.mode == AddWatchMode.PERSISTENT_RECURSIVE: + client._persistent_recursive_watchers[ + request.path + ].add(watcher) + elif isinstance(request, (GetChildren, GetChildren2)): + client._child_watchers[request.path].add(watcher) + else: + client._data_watchers[request.path].add(watcher) + if isinstance(request, RemoveWatches): + if request.watcher_type == WatcherType.CHILDREN: + client._child_watchers.pop(request.path, None) + elif request.watcher_type == WatcherType.DATA: + client._data_watchers.pop(request.path, None) + elif request.watcher_type == WatcherType.ANY: + client._child_watchers.pop(request.path, None) + client._data_watchers.pop(request.path, None) + client._persistent_watchers.pop(request.path, None) + client._persistent_recursive_watchers.pop( + request.path, None + ) if isinstance(request, Close): self.logger.log(BLATHER, "Read close response") diff --git a/kazoo/protocol/serialization.py b/kazoo/protocol/serialization.py index 40e6360c..0a8614da 100644 --- a/kazoo/protocol/serialization.py +++ b/kazoo/protocol/serialization.py @@ -416,6 +416,20 @@ def deserialize(cls, bytes, offset): return data, stat +class RemoveWatches(namedtuple("RemoveWatches", "path watcher_type")): + type = 18 + + def serialize(self): + b = bytearray() + b.extend(write_string(self.path)) + b.extend(int_struct.pack(self.watcher_type)) + return b + + @classmethod + def deserialize(cls, bytes, offset): + return None + + class Auth(namedtuple("Auth", "auth_type scheme auth")): type = 100 @@ -441,6 +455,20 @@ def deserialize(cls, bytes, offset): return challenge, offset +class AddWatch(namedtuple("AddWatch", "path watcher mode")): + type = 106 + + def serialize(self): + b = bytearray() + b.extend(write_string(self.path)) + b.extend(int_struct.pack(self.mode)) + return b + + @classmethod + def deserialize(cls, bytes, offset): + return None + + class Watch(namedtuple("Watch", "type state path")): @classmethod def deserialize(cls, bytes, offset): diff --git a/kazoo/protocol/states.py b/kazoo/protocol/states.py index 480a586e..50fb9258 100644 --- a/kazoo/protocol/states.py +++ b/kazoo/protocol/states.py @@ -251,3 +251,44 @@ def data_length(self): @property def children_count(self): return self.numChildren + + +class AddWatchMode(object): + """Modes for use with :meth:`~kazoo.client.KazooClient.add_watch` + + .. attribute:: PERSISTENT + + The watch is not removed when trigged. + + .. attribute:: PERSISTENT_RECURSIVE + + The watch is not removed when trigged, and applies to all + paths underneath the supplied path as well. + """ + + PERSISTENT = 0 + PERSISTENT_RECURSIVE = 1 + + +class WatcherType(object): + """Watcher types for use with + :meth:`~kazoo.client.KazooClient.remove_all_watches` + + .. attribute:: CHILDREN + + Child watches. + + .. attribute:: DATA + + Data watches. + + .. attribute:: ANY + + Any type of watch (child, data, persistent, or persistent + recursive). + + """ + + CHILDREN = 1 + DATA = 2 + ANY = 3 diff --git a/kazoo/tests/test_client.py b/kazoo/tests/test_client.py index e376baaf..4882179e 100644 --- a/kazoo/tests/test_client.py +++ b/kazoo/tests/test_client.py @@ -26,7 +26,13 @@ KazooException, ) from kazoo.protocol.connection import _CONNECTION_DROP -from kazoo.protocol.states import KeeperState, KazooState +from kazoo.protocol.states import ( + AddWatchMode, + KazooState, + KeeperState, + WatcherType, + EventType, +) from kazoo.tests.util import CI_ZK_VERSION @@ -1158,6 +1164,147 @@ def test_request_queuing_session_expired(self): finally: client.stop() + def _require_zk_version(self, major, minor): + skip = False + if CI_ZK_VERSION and CI_ZK_VERSION < (major, minor): + skip = True + elif CI_ZK_VERSION and CI_ZK_VERSION >= (major, minor): + skip = False + else: + ver = self.client.server_version() + if ver[1] < minor: + skip = True + if skip: + pytest.skip("Must use Zookeeper %s.%s or above" % (major, minor)) + + def test_persistent_recursive_watch(self): + # This tests adding and removing a persistent recursive watch. + self._require_zk_version(3, 6) + events = [] + + def callback(event): + events.append(dict(type=event.type, path=event.path)) + + client = self.client + client.add_watch("/a", callback, AddWatchMode.PERSISTENT_RECURSIVE) + full_path = client.chroot + "/a" + assert len(client._persistent_recursive_watchers[full_path]) == 1 + client.create("/a") + client.create("/a/b") + client.create("/a/b/c", value=b"1") + client.create("/a/b/d", value=b"1") + client.set("/a/b/c", value=b"2") + client.set("/a/b/d", value=b"2") + client.delete("/a", recursive=True) + # Remove the watch + client.remove_all_watches("/a", WatcherType.ANY) + # Perform one more call that we don't expect to see + client.create("/a") + # Wait in case the callback does arrive + time.sleep(3) + assert client._persistent_recursive_watchers[full_path] == set() + assert events == [ + dict(type=EventType.CREATED, path="/a"), + dict(type=EventType.CREATED, path="/a/b"), + dict(type=EventType.CREATED, path="/a/b/c"), + dict(type=EventType.CREATED, path="/a/b/d"), + dict(type=EventType.CHANGED, path="/a/b/c"), + dict(type=EventType.CHANGED, path="/a/b/d"), + dict(type=EventType.DELETED, path="/a/b/c"), + dict(type=EventType.DELETED, path="/a/b/d"), + dict(type=EventType.DELETED, path="/a/b"), + dict(type=EventType.DELETED, path="/a"), + ] + + def test_persistent_watch(self): + # This tests adding and removing a persistent watch. + self._require_zk_version(3, 6) + events = [] + + def callback(event): + events.append(dict(type=event.type, path=event.path)) + + client = self.client + client.add_watch("/a", callback, AddWatchMode.PERSISTENT) + full_path = client.chroot + "/a" + assert len(client._persistent_watchers[full_path]) == 1 + client.create("/a") + # This shouldn't appear since the watch is not recursive + client.create("/a/b") + client.delete("/a", recursive=True) + # Remove the watch + client.remove_all_watches("/a", WatcherType.ANY) + # Perform one more call that we don't expect to see + client.create("/a") + # Wait in case the callback does arrive + time.sleep(3) + assert client._persistent_watchers[full_path] == set() + assert events == [ + dict(type=EventType.CREATED, path="/a"), + dict(type=EventType.DELETED, path="/a"), + ] + + def test_remove_data_watch(self): + # Test that removing a data watch leaves a child watch in place. + self._require_zk_version(3, 6) + callback_event = threading.Event() + + def child_callback(event): + callback_event.set() + + def data_callback(event): + pass + + client = self.client + client.create("/a") + client.get("/a", watch=data_callback) + client.get_children("/a", watch=child_callback) + client.remove_all_watches("/a", WatcherType.DATA) + client.create("/a/b") + callback_event.wait(30) + + def test_remove_children_watch(self): + # Test that removing a children watch leaves a data watch in place. + self._require_zk_version(3, 6) + callback_event = threading.Event() + + def data_callback(event): + callback_event.set() + + def child_callback(event): + pass + + client = self.client + client.create("/a") + client.get("/a", watch=data_callback) + client.get_children("/a", watch=child_callback) + client.remove_all_watches("/a", WatcherType.CHILDREN) + client.set("/a", b"1") + callback_event.wait(30) + + def test_invalid_add_watch_values(self): + def callback(event): + return + + client = self.client + with pytest.raises(TypeError): + client.add_watch(b"/a", callback, AddWatchMode.PERSISTENT) + with pytest.raises(TypeError): + client.add_watch("/a", "test", AddWatchMode.PERSISTENT) + with pytest.raises(TypeError): + client.add_watch("/a", callback, "1") + with pytest.raises(ValueError): + client.add_watch("/a", callback, 42) + + def test_invalid_remove_all_watch_values(self): + client = self.client + with pytest.raises(TypeError): + client.remove_all_watches(b"/a", WatcherType.ANY) + with pytest.raises(TypeError): + client.remove_all_watches("/a", "test") + with pytest.raises(ValueError): + client.remove_all_watches("/a", 42) + class TestSSLClient(KazooTestCase): def setUp(self):