diff --git a/sleap/gui/utils.py b/sleap/gui/utils.py index 4f8215706..b974e26e8 100644 --- a/sleap/gui/utils.py +++ b/sleap/gui/utils.py @@ -1,6 +1,7 @@ """Generic module containing utilities used for the GUI.""" import zmq +import time from typing import Optional @@ -12,6 +13,7 @@ def is_port_free(port: int, zmq_context: Optional[zmq.Context] = None) -> bool: try: socket.bind(address) socket.unbind(address) + time.sleep(0.1) return True except zmq.error.ZMQError: return False @@ -26,3 +28,28 @@ def select_zmq_port(zmq_context: Optional[zmq.Context] = None) -> int: port = socket.bind_to_random_port("tcp://127.0.0.1") socket.close() return port + + +def find_free_port(port: int, zmq_context: zmq.Context): + """Find free port to bind to. + + Args: + port: The port to start searching from. + zmq_context: The ZMQ context to use. + + Returns: + The free port. + """ + attempts = 0 + max_attempts = 10 + while not is_port_free(port=port, zmq_context=zmq_context): + if attempts >= max_attempts: + raise RuntimeError( + f"Could not find free port to display training progress after " + f"{max_attempts} attempts. Please check your network settings " + "or use the CLI `sleap-train` command." + ) + port = select_zmq_port(zmq_context=zmq_context) + attempts += 1 + + return port diff --git a/sleap/gui/widgets/monitor.py b/sleap/gui/widgets/monitor.py index fff8a0327..9a17eecff 100644 --- a/sleap/gui/widgets/monitor.py +++ b/sleap/gui/widgets/monitor.py @@ -12,7 +12,7 @@ import matplotlib.transforms as mtransforms from qtpy import QtCore, QtWidgets -from sleap.gui.utils import is_port_free, select_zmq_port +from sleap.gui.utils import find_free_port from sleap.gui.widgets.mpl import MplCanvas from sleap.nn.config.training_job import TrainingJobConfig @@ -788,30 +788,6 @@ def _setup_zmq(self, zmq_context: Optional[zmq.Context] = None): self.sub = self.ctx.socket(zmq.SUB) self.sub.subscribe("") - def find_free_port(port: int, zmq_context: zmq.Context): - """Find free port to bind to. - - Args: - port: The port to start searching from. - zmq_context: The ZMQ context to use. - - Returns: - The free port. - """ - attempts = 0 - max_attempts = 10 - while not is_port_free(port=port, zmq_context=zmq_context): - if attempts >= max_attempts: - raise RuntimeError( - f"Could not find free port to display training progress after " - f"{max_attempts} attempts. Please check your network settings " - "or use the CLI `sleap-train` command." - ) - port = select_zmq_port(zmq_context=self.ctx) - attempts += 1 - - return port - # Find a free port and bind to it. self.zmq_ports["publish_port"] = find_free_port( port=self.zmq_ports["publish_port"], zmq_context=self.ctx