diff --git a/nodepool/zk/__init__.py b/nodepool/zk/__init__.py index ad47e06d..39b8635d 100644 --- a/nodepool/zk/__init__.py +++ b/nodepool/zk/__init__.py @@ -9,12 +9,15 @@ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. + import logging import time from abc import ABCMeta from threading import Thread -from kazoo.client import KazooClient +import kazoo.client +from nodepool.zk.vendor.client import ZuulKazooClient +from nodepool.zk.vendor.connection import ZuulConnectionHandler from kazoo.handlers.threading import KazooTimeoutError from kazoo.protocol.states import KazooState @@ -22,6 +25,9 @@ from nodepool.zk.handler import PoolSequentialThreadingHandler +kazoo.client.ConnectionHandler = ZuulConnectionHandler + + class ZooKeeperClient(object): log = logging.getLogger("nodepool.zk.ZooKeeperClient") @@ -135,7 +141,7 @@ def connect(self): args['keyfile'] = self.tls_key args['certfile'] = self.tls_cert args['ca'] = self.tls_ca - self.client = KazooClient(**args) + self.client = ZuulKazooClient(**args) self.client.add_listener(self._connectionListener) # Manually retry initial connection attempt while True: diff --git a/nodepool/zk/vendor/__init__.py b/nodepool/zk/vendor/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/nodepool/zk/vendor/client.py b/nodepool/zk/vendor/client.py new file mode 100644 index 00000000..96af00d8 --- /dev/null +++ b/nodepool/zk/vendor/client.py @@ -0,0 +1,108 @@ +# This file is derived from the Kazoo project +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from collections import defaultdict + +from kazoo.client import ( + _prefix_root, + KazooClient, +) +from kazoo.protocol.states import ( + Callback, + EventType, + WatchedEvent, +) + +from nodepool.zk.vendor.serialization import AddWatch + + +class ZuulKazooClient(KazooClient): + def __init__(self, *args, **kw): + self._persistent_watchers = defaultdict(set) + self._persistent_recursive_watchers = defaultdict(set) + super().__init__(*args, **kw) + + def _reset_watchers(self): + watchers = [] + for child_watchers in self._child_watchers.values(): + watchers.extend(child_watchers) + + for data_watchers in self._data_watchers.values(): + watchers.extend(data_watchers) + + for persistent_watchers in self._persistent_watchers.values(): + watchers.extend(persistent_watchers) + + for pr_watchers in self._persistent_recursive_watchers.values(): + watchers.extend(pr_watchers) + + self._child_watchers = defaultdict(set) + self._data_watchers = defaultdict(set) + self._persistent_watchers = defaultdict(set) + self._persistent_recursive_watchers = defaultdict(set) + + ev = WatchedEvent(EventType.NONE, self._state, None) + for watch in watchers: + self.handler.dispatch_callback(Callback("watch", watch, (ev,))) + + def add_watch(self, path, watch, mode): + """Add a watch. + + This method adds persistent watches. Unlike the data and + child watches which may be set by calls to + :meth:`KazooClient.exists`, :meth:`KazooClient.get`, and + :meth:`KazooClient.get_children`, persistent watches are not + removed after being triggered. + + To remove a persistent watch, use + :meth:`KazooClient.remove_all_watches` with an argument of + :attr:`~kazoo.states.WatcherType.ANY`. + + The `mode` argument determines whether or not the watch is + recursive. To set a persistent watch, use + :class:`~kazoo.states.AddWatchMode.PERSISTENT`. To set a + persistent recursive watch, use + :class:`~kazoo.states.AddWatchMode.PERSISTENT_RECURSIVE`. + + :param path: Path of node to watch. + :param watch: Watch callback to set for future changes + to this path. + :param mode: The mode to use. + :type mode: int + + :raises: + :exc:`~kazoo.exceptions.MarshallingError` if mode is + unknown. + + :exc:`~kazoo.exceptions.ZookeeperError` if the server + returns a non-zero error code. + """ + return self.add_watch_async(path, watch, mode).get() + + def add_watch_async(self, path, watch, mode): + """Asynchronously add a watch. Takes the same arguments as + :meth:`add_watch`. + """ + if not isinstance(path, str): + raise TypeError("Invalid type for 'path' (string expected)") + if not callable(watch): + raise TypeError("Invalid type for 'watch' (must be a callable)") + if not isinstance(mode, int): + raise TypeError("Invalid type for 'mode' (int expected)") + + async_result = self.handler.async_result() + self._call( + AddWatch(_prefix_root(self.chroot, path), watch, mode), + async_result, + ) + return async_result diff --git a/nodepool/zk/vendor/connection.py b/nodepool/zk/vendor/connection.py new file mode 100644 index 00000000..a3b7b56f --- /dev/null +++ b/nodepool/zk/vendor/connection.py @@ -0,0 +1,185 @@ +# This file is derived from the Kazoo project +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from kazoo.exceptions import ( + EXCEPTIONS, + NoNodeError, +) +from kazoo.loggingsupport import BLATHER +from kazoo.protocol.connection import ( + ConnectionHandler, + CREATED_EVENT, + DELETED_EVENT, + CHANGED_EVENT, + CHILD_EVENT, + CLOSE_RESPONSE, +) +from kazoo.protocol.serialization import ( + Close, + Exists, + Transaction, + GetChildren, + GetChildren2, + Watch, +) +from kazoo.protocol.states import ( + Callback, + WatchedEvent, + EVENT_TYPE_MAP, +) + +from nodepool.zk.vendor.serialization import ( + AddWatch, + RemoveWatches, +) +from nodepool.zk.vendor.states import ( + AddWatchMode, + WatcherType, +) + + +class ZuulConnectionHandler(ConnectionHandler): + def _find_persistent_recursive_watchers(self, path): + parts = path.split("/") + watchers = [] + for count in range(len(parts)): + candidate = "/".join(parts[: count + 1]) + if not candidate: + candidate = '/' + watchers.extend( + self.client._persistent_recursive_watchers.get(candidate, []) + ) + return watchers + + def _read_watch_event(self, buffer, offset): + client = self.client + watch, offset = Watch.deserialize(buffer, offset) + path = watch.path + + self.logger.debug("Received EVENT: %s", watch) + + watchers = [] + + if watch.type in (CREATED_EVENT, CHANGED_EVENT): + watchers.extend(client._data_watchers.pop(path, [])) + watchers.extend(client._persistent_watchers.get(path, [])) + watchers.extend(self._find_persistent_recursive_watchers(path)) + elif watch.type == DELETED_EVENT: + watchers.extend(client._data_watchers.pop(path, [])) + watchers.extend(client._child_watchers.pop(path, [])) + watchers.extend(client._persistent_watchers.get(path, [])) + watchers.extend(self._find_persistent_recursive_watchers(path)) + elif watch.type == CHILD_EVENT: + watchers.extend(client._child_watchers.pop(path, [])) + else: + self.logger.warn("Received unknown event %r", watch.type) + return + + # Strip the chroot if needed + path = client.unchroot(path) + ev = WatchedEvent(EVENT_TYPE_MAP[watch.type], client._state, path) + + # Last check to ignore watches if we've been stopped + if client._stopped.is_set(): + return + + # Dump the watchers to the watch thread + for watch in watchers: + client.handler.dispatch_callback(Callback("watch", watch, (ev,))) + + def _read_response(self, header, buffer, offset): + client = self.client + request, async_object, xid = client._pending.popleft() + if header.zxid and header.zxid > 0: + client.last_zxid = header.zxid + if header.xid != xid: + exc = RuntimeError( + "xids do not match, expected %r " "received %r", + xid, + header.xid, + ) + async_object.set_exception(exc) + raise exc + + # Determine if its an exists request and a no node error + exists_error = ( + header.err == NoNodeError.code and request.type == Exists.type + ) + + # Set the exception if its not an exists error + if header.err and not exists_error: + callback_exception = EXCEPTIONS[header.err]() + self.logger.debug( + "Received error(xid=%s) %r", xid, callback_exception + ) + if async_object: + async_object.set_exception(callback_exception) + elif request and async_object: + if exists_error: + # It's a NoNodeError, which is fine for an exists + # request + async_object.set(None) + else: + try: + response = request.deserialize(buffer, offset) + except Exception as exc: + self.logger.exception( + "Exception raised during deserialization " + "of request: %s", + request, + ) + async_object.set_exception(exc) + return + self.logger.debug( + "Received response(xid=%s): %r", xid, response + ) + + # We special case a Transaction as we have to unchroot things + if request.type == Transaction.type: + response = Transaction.unchroot(client, response) + + async_object.set(response) + + # Determine if watchers should be registered or unregistered + if not client._stopped.is_set(): + watcher = getattr(request, "watcher", None) + if watcher: + if isinstance(request, AddWatch): + if request.mode == AddWatchMode.PERSISTENT: + client._persistent_watchers[request.path].add( + watcher + ) + elif request.mode == AddWatchMode.PERSISTENT_RECURSIVE: + client._persistent_recursive_watchers[ + request.path + ].add(watcher) + elif isinstance(request, (GetChildren, GetChildren2)): + client._child_watchers[request.path].add(watcher) + else: + client._data_watchers[request.path].add(watcher) + if isinstance(request, RemoveWatches): + if request.watcher_type == WatcherType.CHILDREN: + client._child_watchers.pop(request.path, None) + elif request.watcher_type == WatcherType.DATA: + client._data_watchers.pop(request.path, None) + elif request.watcher_type == WatcherType.ANY: + client._child_watchers.pop(request.path, None) + client._data_watchers.pop(request.path, None) + client._persistent_watchers.pop(request.path, None) + client._persistent_recursive_watchers.pop( + request.path, None + ) + + if isinstance(request, Close): + self.logger.log(BLATHER, "Read close response") + return CLOSE_RESPONSE diff --git a/nodepool/zk/vendor/serialization.py b/nodepool/zk/vendor/serialization.py new file mode 100644 index 00000000..3e88957d --- /dev/null +++ b/nodepool/zk/vendor/serialization.py @@ -0,0 +1,46 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from collections import namedtuple + +from kazoo.protocol.serialization import ( + int_struct, + write_string, +) + + +class RemoveWatches(namedtuple("RemoveWatches", "path watcher_type")): + type = 18 + + def serialize(self): + b = bytearray() + b.extend(write_string(self.path)) + b.extend(int_struct.pack(self.watcher_type)) + return b + + @classmethod + def deserialize(cls, bytes, offset): + return None + + +class AddWatch(namedtuple("AddWatch", "path watcher mode")): + type = 106 + + def serialize(self): + b = bytearray() + b.extend(write_string(self.path)) + b.extend(int_struct.pack(self.mode)) + return b + + @classmethod + def deserialize(cls, bytes, offset): + return None diff --git a/nodepool/zk/vendor/states.py b/nodepool/zk/vendor/states.py new file mode 100644 index 00000000..7b1f69cf --- /dev/null +++ b/nodepool/zk/vendor/states.py @@ -0,0 +1,51 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +class AddWatchMode(object): + """Modes for use with :meth:`~kazoo.client.KazooClient.add_watch` + + .. attribute:: PERSISTENT + + The watch is not removed when trigged. + + .. attribute:: PERSISTENT_RECURSIVE + + The watch is not removed when trigged, and applies to all + paths underneath the supplied path as well. + """ + + PERSISTENT = 0 + PERSISTENT_RECURSIVE = 1 + + +class WatcherType(object): + """Watcher types for use with + :meth:`~kazoo.client.KazooClient.remove_all_watches` + + .. attribute:: CHILDREN + + Child watches. + + .. attribute:: DATA + + Data watches. + + .. attribute:: ANY + + Any type of watch (child, data, persistent, or persistent + recursive). + + """ + + CHILDREN = 1 + DATA = 2 + ANY = 3