Skip to content

Commit

Permalink
more robust udp broadcast
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexCheema committed Dec 17, 2024
1 parent 1f108a0 commit 198308b
Showing 1 changed file with 22 additions and 5 deletions.
27 changes: 22 additions & 5 deletions exo/networking/udp/udp_discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,29 @@ def datagram_received(self, data, addr):
asyncio.create_task(self.on_message(data, addr))


def get_broadcast_address(ip_addr: str) -> str:
try:
# Split IP into octets and create broadcast address for the subnet
ip_parts = ip_addr.split('.')
return f"{ip_parts[0]}.{ip_parts[1]}.{ip_parts[2]}.255"
except:
return "255.255.255.255"


class BroadcastProtocol(asyncio.DatagramProtocol):
def __init__(self, message: str, broadcast_port: int):
def __init__(self, message: str, broadcast_port: int, source_ip: str):
self.message = message
self.broadcast_port = broadcast_port
self.source_ip = source_ip

def connection_made(self, transport):
sock = transport.get_extra_info("socket")
sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
transport.sendto(self.message.encode("utf-8"), ("255.255.255.255", self.broadcast_port))
# Try both subnet-specific and global broadcast
broadcast_addr = get_broadcast_address(self.source_ip)
transport.sendto(self.message.encode("utf-8"), (broadcast_addr, self.broadcast_port))
if broadcast_addr != "255.255.255.255":
transport.sendto(self.message.encode("utf-8"), ("255.255.255.255", self.broadcast_port))


class UDPDiscovery(Discovery):
Expand Down Expand Up @@ -99,14 +113,17 @@ async def task_broadcast_presence(self):

transport = None
try:
# Create socket with explicit broadcast permission
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
try:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
except AttributeError:
pass
sock.bind((addr, 0))

# Create transport with the pre-configured socket
transport, _ = await asyncio.get_event_loop().create_datagram_endpoint(
lambda: BroadcastProtocol(message, self.broadcast_port),
lambda: BroadcastProtocol(message, self.broadcast_port, addr),
sock=sock
)
except Exception as e:
Expand Down

0 comments on commit 198308b

Please sign in to comment.