From 198308b1eb22407519473ea9cc0ec71fb7011dd1 Mon Sep 17 00:00:00 2001 From: Alex Cheema Date: Tue, 17 Dec 2024 17:28:55 +0000 Subject: [PATCH] more robust udp broadcast --- exo/networking/udp/udp_discovery.py | 27 ++++++++++++++++++++++----- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/exo/networking/udp/udp_discovery.py b/exo/networking/udp/udp_discovery.py index 66331da32..7117c7cd2 100644 --- a/exo/networking/udp/udp_discovery.py +++ b/exo/networking/udp/udp_discovery.py @@ -23,15 +23,29 @@ def datagram_received(self, data, addr): asyncio.create_task(self.on_message(data, addr)) +def get_broadcast_address(ip_addr: str) -> str: + try: + # Split IP into octets and create broadcast address for the subnet + ip_parts = ip_addr.split('.') + return f"{ip_parts[0]}.{ip_parts[1]}.{ip_parts[2]}.255" + except: + return "255.255.255.255" + + class BroadcastProtocol(asyncio.DatagramProtocol): - def __init__(self, message: str, broadcast_port: int): + def __init__(self, message: str, broadcast_port: int, source_ip: str): self.message = message self.broadcast_port = broadcast_port + self.source_ip = source_ip def connection_made(self, transport): sock = transport.get_extra_info("socket") sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1) - transport.sendto(self.message.encode("utf-8"), ("255.255.255.255", self.broadcast_port)) + # Try both subnet-specific and global broadcast + broadcast_addr = get_broadcast_address(self.source_ip) + transport.sendto(self.message.encode("utf-8"), (broadcast_addr, self.broadcast_port)) + if broadcast_addr != "255.255.255.255": + transport.sendto(self.message.encode("utf-8"), ("255.255.255.255", self.broadcast_port)) class UDPDiscovery(Discovery): @@ -99,14 +113,17 @@ async def task_broadcast_presence(self): transport = None try: - # Create socket with explicit broadcast permission sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + try: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) + except AttributeError: + pass sock.bind((addr, 0)) - # Create transport with the pre-configured socket transport, _ = await asyncio.get_event_loop().create_datagram_endpoint( - lambda: BroadcastProtocol(message, self.broadcast_port), + lambda: BroadcastProtocol(message, self.broadcast_port, addr), sock=sock ) except Exception as e: