diff --git a/.github/workflows/bench_job.yml b/.github/workflows/bench_job.yml index 64cd632fb..2900f303f 100644 --- a/.github/workflows/bench_job.yml +++ b/.github/workflows/bench_job.yml @@ -74,7 +74,7 @@ jobs: export PATH="/usr/local/bin:/opt/homebrew/bin:$PATH" echo "Starting exo daemon..." - DEBUG=6 DEBUG_DISCOVERY=6 exo --node-id="${MY_NODE_ID}" --node-id-filter="${ALL_NODE_IDS}" --chatgpt-api-port 52415 > output1.log 2>&1 & + DEBUG=6 DEBUG_DISCOVERY=6 exo --node-id="${MY_NODE_ID}" --node-id-filter="${ALL_NODE_IDS}" --interface-type-filter="Ethernet" --chatgpt-api-port 52415 > output1.log 2>&1 & PID1=$! echo "Exo process started with PID: $PID1" tail -f output1.log & diff --git a/exo/main.py b/exo/main.py index 184e04165..a4c964467 100644 --- a/exo/main.py +++ b/exo/main.py @@ -59,6 +59,7 @@ parser.add_argument("--tailscale-api-key", type=str, default=None, help="Tailscale API key") parser.add_argument("--tailnet-name", type=str, default=None, help="Tailnet name") parser.add_argument("--node-id-filter", type=str, default=None, help="Comma separated list of allowed node IDs (only for UDP and Tailscale discovery)") +parser.add_argument("--interface-type-filter", type=str, default=None, help="Comma separated list of allowed interface types (only for UDP discovery)") args = parser.parse_args() print(f"Selected inference engine: {args.inference_engine}") @@ -90,8 +91,9 @@ for chatgpt_api_endpoint in chatgpt_api_endpoints: print(f" - {terminal_link(chatgpt_api_endpoint)}") -# Convert node-id-filter to list if provided +# Convert node-id-filter and interface-type-filter to lists if provided allowed_node_ids = args.node_id_filter.split(',') if args.node_id_filter else None +allowed_interface_types = args.interface_type_filter.split(',') if args.interface_type_filter else None if args.discovery_module == "udp": discovery = UDPDiscovery( @@ -101,7 +103,8 @@ args.broadcast_port, lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities), discovery_timeout=args.discovery_timeout, - allowed_node_ids=allowed_node_ids + allowed_node_ids=allowed_node_ids, + allowed_interface_types=allowed_interface_types ) elif args.discovery_module == "tailscale": discovery = TailscaleDiscovery( diff --git a/exo/networking/udp/udp_discovery.py b/exo/networking/udp/udp_discovery.py index 168ebbee6..ff253a60a 100644 --- a/exo/networking/udp/udp_discovery.py +++ b/exo/networking/udp/udp_discovery.py @@ -3,7 +3,7 @@ import socket import time import traceback -from typing import List, Dict, Callable, Tuple, Coroutine +from typing import List, Dict, Callable, Tuple, Coroutine, Optional from exo.networking.discovery import Discovery from exo.networking.peer_handle import PeerHandle from exo.topology.device_capabilities import DeviceCapabilities, device_capabilities, UNKNOWN_DEVICE_CAPABILITIES @@ -45,7 +45,8 @@ def __init__( broadcast_interval: int = 2.5, discovery_timeout: int = 30, device_capabilities: DeviceCapabilities = UNKNOWN_DEVICE_CAPABILITIES, - allowed_node_ids: List[str] = None, + allowed_node_ids: Optional[List[str]] = None, + allowed_interface_types: Optional[List[str]] = None, ): self.node_id = node_id self.node_port = node_port @@ -56,6 +57,7 @@ def __init__( self.discovery_timeout = discovery_timeout self.device_capabilities = device_capabilities self.allowed_node_ids = allowed_node_ids + self.allowed_interface_types = allowed_interface_types self.known_peers: Dict[str, Tuple[PeerHandle, float, float, int]] = {} self.broadcast_task = None self.listen_task = None @@ -147,6 +149,12 @@ async def on_listen_message(self, data, addr): peer_prio = message["priority"] peer_interface_name = message["interface_name"] peer_interface_type = message["interface_type"] + + # Skip if interface type is not in allowed list + if self.allowed_interface_types and peer_interface_type not in self.allowed_interface_types: + if DEBUG_DISCOVERY >= 2: print(f"Ignoring peer {peer_id} as its interface type {peer_interface_type} is not in the allowed interface types list") + return + device_capabilities = DeviceCapabilities(**message["device_capabilities"]) if peer_id not in self.known_peers or self.known_peers[peer_id][0].addr() != f"{peer_host}:{peer_port}":