diff --git a/kazoo/client.py b/kazoo/client.py index 27b7c384..2151baac 100644 --- a/kazoo/client.py +++ b/kazoo/client.py @@ -25,6 +25,7 @@ from kazoo.protocol.connection import ConnectionHandler from kazoo.protocol.paths import _prefix_root, normpath from kazoo.protocol.serialization import ( + AddWatch, Auth, CheckVersion, CloseInstance, @@ -38,6 +39,7 @@ SetACL, GetData, Reconfig, + RemoveWatches, SetData, Sync, Transaction, @@ -248,6 +250,7 @@ def __init__( self.state_listeners = set() self._child_watchers = defaultdict(set) self._data_watchers = defaultdict(set) + self._persistent_watchers = defaultdict(set) self._reset() self.read_only = read_only @@ -416,8 +419,12 @@ def _reset_watchers(self): for data_watchers in self._data_watchers.values(): watchers.extend(data_watchers) + for persistent_watchers in self._persistent_watchers.values(): + watchers.extend(persistent_watchers) + self._child_watchers = defaultdict(set) self._data_watchers = defaultdict(set) + self._persistent_watchers = defaultdict(set) ev = WatchedEvent(EventType.NONE, self._state, None) for watch in watchers: @@ -1644,8 +1651,100 @@ def reconfig_async(self, joining, leaving, new_members, from_config): return async_result + def add_watch(self, path, watch, mode): + """Add a watch. + + This method adds persistent watches. Unlike the data and + child watches which may be set by calls to + :meth:`KazooClient.exists`, :meth:`KazooClient.get`, and + :meth:`KazooClient.get_children`, persistent watches are not + removed after being triggered. + + To remove a persistent watch, use + :meth:`KazooClient.remove_all_watches` with an argument of + :attr:`~kazoo.states.WatcherType.ANY`. + + The `mode` argument determines whether or not the watch is + recursive. To set a persistent watch, use + :class:`~kazoo.states.AddWatchMode.PERSISTENT`. To set a + persistent recursive watch, use + :class:`~kazoo.states.AddWatchMode.PERSISTENT_RECURSIVE`. + + :param path: Path of node to watch. + :param watch: Watch callback to set for future changes + to this path. + :param mode: The mode to use. + :type mode: int + + :raises: + :exc:`~kazoo.exceptions.MarshallingError` if mode is + unknown. + + :exc:`~kazoo.exceptions.ZookeeperError` if the server + returns a non-zero error code. + """ + return self.add_watch_async(path, watch, mode).get() + + def add_watch_async(self, path, watch, mode): + """Asynchronously add a watch. Takes the same arguments as + :meth:`add_watch`. + """ + if not isinstance(path, str): + raise TypeError("Invalid type for 'path' (string expected)") + if not callable(watch): + raise TypeError("Invalid type for 'watch' (must be a callable)") + if not isinstance(mode, int): + raise TypeError("Invalid type for 'mode' (int expected)") + + async_result = self.handler.async_result() + self._call( + AddWatch(_prefix_root(self.chroot, path), watch, mode), + async_result, + ) + return async_result + + def remove_all_watches(self, path, watcher_type): + """Remove watches from a path. + + This removes all watches of a specified type (data, child, + any) from a given path. + + The `watcher_type` argument specifies which type to use. It + may be one of: + + * :attr:`~kazoo.states.WatcherType.DATA` + * :attr:`~kazoo.states.WatcherType.CHILD` + * :attr:`~kazoo.states.WatcherType.ANY` + + To remove persistent watches, specify a watcher type of + :attr:`~kazoo.states.WatcherType.ANY`. + + :param path: Path of watch to remove. + :param watcher_type: The type of watch to remove. + :type watcher_type: int + """ + + return self.remove_all_watches_async(path, watcher_type).get() + + def remove_all_watches_async(self, path, watcher_type): + """Asynchronously remove watches. Takes the same arguments as + :meth:`remove_all_watches`. + """ + if not isinstance(path, str): + raise TypeError("Invalid type for 'path' (string expected)") + if not isinstance(watcher_type, int): + raise TypeError("Invalid type for 'watcher_type' (int expected)") + + async_result = self.handler.async_result() + self._call( + RemoveWatches(_prefix_root(self.chroot, path), watcher_type), + async_result, + ) + return async_result + class TransactionRequest(object): + """A Zookeeper Transaction Request A Transaction provides a builder object that can be used to diff --git a/kazoo/exceptions.py b/kazoo/exceptions.py index 1307463c..07b6a91d 100644 --- a/kazoo/exceptions.py +++ b/kazoo/exceptions.py @@ -187,6 +187,11 @@ class NotReadOnlyCallError(ZookeeperError): a read-only server""" +@_zookeeper_exception(-121) +class NoWatcherError(ZookeeperError): + """No watcher was found at the supplied path""" + + class ConnectionClosedError(SessionExpiredError): """Connection is closed""" diff --git a/kazoo/protocol/connection.py b/kazoo/protocol/connection.py index 9b5ce2fb..22912ff3 100644 --- a/kazoo/protocol/connection.py +++ b/kazoo/protocol/connection.py @@ -20,6 +20,7 @@ ) from kazoo.loggingsupport import BLATHER from kazoo.protocol.serialization import ( + AddWatch, Auth, Close, Connect, @@ -28,6 +29,7 @@ GetChildren2, Ping, PingInstance, + RemoveWatches, ReplyHeader, SASL, Transaction, @@ -363,6 +365,18 @@ def _write(self, msg, timeout): raise ConnectionDropped("socket connection broken") sent += bytes_sent + def _find_persistent_watchers(self, path): + parts = path.split("/") + watchers = [] + for count in range(len(parts)): + candidate = "/".join(parts[: count + 1]) + if not candidate: + continue + watchers.extend( + self.client._persistent_watchers.get(candidate, []) + ) + return watchers + def _read_watch_event(self, buffer, offset): client = self.client watch, offset = Watch.deserialize(buffer, offset) @@ -374,9 +388,11 @@ def _read_watch_event(self, buffer, offset): if watch.type in (CREATED_EVENT, CHANGED_EVENT): watchers.extend(client._data_watchers.pop(path, [])) + watchers.extend(self._find_persistent_watchers(path)) elif watch.type == DELETED_EVENT: watchers.extend(client._data_watchers.pop(path, [])) watchers.extend(client._child_watchers.pop(path, [])) + watchers.extend(self._find_persistent_watchers(path)) elif watch.type == CHILD_EVENT: watchers.extend(client._child_watchers.pop(path, [])) else: @@ -448,13 +464,25 @@ def _read_response(self, header, buffer, offset): async_object.set(response) - # Determine if watchers should be registered - watcher = getattr(request, "watcher", None) - if not client._stopped.is_set() and watcher: - if isinstance(request, (GetChildren, GetChildren2)): - client._child_watchers[request.path].add(watcher) - else: - client._data_watchers[request.path].add(watcher) + # Determine if watchers should be registered or unregistered + if not client._stopped.is_set(): + watcher = getattr(request, "watcher", None) + if watcher: + if isinstance(request, AddWatch): + client._persistent_watchers[request.path].add(watcher) + elif isinstance(request, (GetChildren, GetChildren2)): + client._child_watchers[request.path].add(watcher) + else: + client._data_watchers[request.path].add(watcher) + if isinstance(request, RemoveWatches): + if request.watcher_type == 1: + client._child_watchers.pop(request.path, None) + if request.watcher_type == 2: + client._data_watchers.pop(request.path, None) + if request.watcher_type == 3: + client._child_watchers.pop(request.path, None) + client._data_watchers.pop(request.path, None) + client._persistent_watchers.pop(request.path, None) if isinstance(request, Close): self.logger.log(BLATHER, "Read close response") diff --git a/kazoo/protocol/serialization.py b/kazoo/protocol/serialization.py index 40e6360c..0a8614da 100644 --- a/kazoo/protocol/serialization.py +++ b/kazoo/protocol/serialization.py @@ -416,6 +416,20 @@ def deserialize(cls, bytes, offset): return data, stat +class RemoveWatches(namedtuple("RemoveWatches", "path watcher_type")): + type = 18 + + def serialize(self): + b = bytearray() + b.extend(write_string(self.path)) + b.extend(int_struct.pack(self.watcher_type)) + return b + + @classmethod + def deserialize(cls, bytes, offset): + return None + + class Auth(namedtuple("Auth", "auth_type scheme auth")): type = 100 @@ -441,6 +455,20 @@ def deserialize(cls, bytes, offset): return challenge, offset +class AddWatch(namedtuple("AddWatch", "path watcher mode")): + type = 106 + + def serialize(self): + b = bytearray() + b.extend(write_string(self.path)) + b.extend(int_struct.pack(self.mode)) + return b + + @classmethod + def deserialize(cls, bytes, offset): + return None + + class Watch(namedtuple("Watch", "type state path")): @classmethod def deserialize(cls, bytes, offset): diff --git a/kazoo/protocol/states.py b/kazoo/protocol/states.py index 480a586e..50fb9258 100644 --- a/kazoo/protocol/states.py +++ b/kazoo/protocol/states.py @@ -251,3 +251,44 @@ def data_length(self): @property def children_count(self): return self.numChildren + + +class AddWatchMode(object): + """Modes for use with :meth:`~kazoo.client.KazooClient.add_watch` + + .. attribute:: PERSISTENT + + The watch is not removed when trigged. + + .. attribute:: PERSISTENT_RECURSIVE + + The watch is not removed when trigged, and applies to all + paths underneath the supplied path as well. + """ + + PERSISTENT = 0 + PERSISTENT_RECURSIVE = 1 + + +class WatcherType(object): + """Watcher types for use with + :meth:`~kazoo.client.KazooClient.remove_all_watches` + + .. attribute:: CHILDREN + + Child watches. + + .. attribute:: DATA + + Data watches. + + .. attribute:: ANY + + Any type of watch (child, data, persistent, or persistent + recursive). + + """ + + CHILDREN = 1 + DATA = 2 + ANY = 3 diff --git a/kazoo/tests/test_client.py b/kazoo/tests/test_client.py index d05cacbd..e511c700 100644 --- a/kazoo/tests/test_client.py +++ b/kazoo/tests/test_client.py @@ -24,7 +24,12 @@ KazooException, ) from kazoo.protocol.connection import _CONNECTION_DROP -from kazoo.protocol.states import KeeperState, KazooState +from kazoo.protocol.states import ( + AddWatchMode, + KazooState, + KeeperState, + WatcherType, +) from kazoo.tests.util import CI_ZK_VERSION @@ -1158,6 +1163,118 @@ def test_request_queuing_session_expired(self): finally: client.stop() + def _require_zk_version(self, major, minor): + skip = False + if CI_ZK_VERSION and CI_ZK_VERSION < (major, minor): + skip = True + elif CI_ZK_VERSION and CI_ZK_VERSION >= (major, minor): + skip = False + else: + ver = self.client.server_version() + if ver[1] < minor: + skip = True + if skip: + pytest.skip("Must use Zookeeper %s.%s or above" % (major, minor)) + + def test_persistent_recursive_watch(self): + # This tests adding and removing a persistent recursive watch. + self._require_zk_version(3, 6) + events = [] + + def callback(event): + events.append(dict(type=event.type, path=event.path)) + + client = self.client + client.add_watch("/a", callback, AddWatchMode.PERSISTENT_RECURSIVE) + client.create("/a") + client.create("/a/b") + client.create("/a/b/c", value=b"1") + client.create("/a/b/d", value=b"1") + client.set("/a/b/c", value=b"2") + client.set("/a/b/d", value=b"2") + client.delete("/a", recursive=True) + # Remove the watch + client.remove_all_watches("/a", WatcherType.ANY) + # Perform one more call that we don't expect to see + client.create("/a") + # Wait in case the callback does arrive + time.sleep(3) + assert events == [ + dict(type="CREATED", path="/a"), + dict(type="CREATED", path="/a/b"), + dict(type="CREATED", path="/a/b/c"), + dict(type="CREATED", path="/a/b/d"), + dict(type="CHANGED", path="/a/b/c"), + dict(type="CHANGED", path="/a/b/d"), + dict(type="DELETED", path="/a/b/c"), + dict(type="DELETED", path="/a/b/d"), + dict(type="DELETED", path="/a/b"), + dict(type="DELETED", path="/a"), + ] + + def test_persistent_watch(self): + # This tests adding and removing a persistent watch. + self._require_zk_version(3, 6) + events = [] + + def callback(event): + events.append(dict(type=event.type, path=event.path)) + + client = self.client + client.add_watch("/a", callback, AddWatchMode.PERSISTENT) + client.create("/a") + # This shouldn't appear since the watch is not recursive + client.create("/a/b") + client.delete("/a", recursive=True) + # Remove the watch + client.remove_all_watches("/a", WatcherType.ANY) + # Perform one more call that we don't expect to see + client.create("/a") + # Wait in case the callback does arrive + time.sleep(3) + assert events == [ + dict(type="CREATED", path="/a"), + dict(type="DELETED", path="/a"), + ] + + def test_remove_data_watch(self): + # Test that removing a data watch leaves a child watch in place. + self._require_zk_version(3, 6) + callback_event = threading.Event() + + def child_callback(event): + callback_event.set() + + def data_callback(event): + pass + + client = self.client + client.create("/a") + client.get("/a", watch=data_callback) + client.get_children("/a", watch=child_callback) + client.remove_all_watches("/a", WatcherType.DATA) + client.create("/a/b") + callback_event.wait(30) + + def test_remove_children_watch(self): + # Test that removing a children watch leaves a data watch in place. + self._require_zk_version(3, 6) + callback_event = threading.Event() + + def data_callback(event): + callback_event.set() + + def child_callback(event): + pass + + client = self.client + client.create("/a") + client.get("/a", watch=data_callback) + client.get_children("/a", watch=child_callback) + client.remove_all_watches("/a", WatcherType.CHILDREN) + client.set("/a", b"1") + callback_event.wait(30) + dummy_dict = { "aversion": 1,