Persist the node ID across runs: --node-id now defaults to None, and main.py
falls back to get_or_create_node_id(), which reads or creates a ".exo_node_id"
file next to exo/helpers.py. The new file is added to .gitignore.

diff --git a/.gitignore b/.gitignore
index 13e309d3b..fcae47e75 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,6 +2,7 @@ __pycache__/
 .venv
 test_weights.npz
 .exo_used_ports
+.exo_node_id
 .idea
 
 # Byte-compiled / optimized / DLL files
@@ -166,4 +167,4 @@ cython_debug/
 # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
 # and can be added to the global gitignore or merged into this file. For a more nuclear
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
-#.idea/
\ No newline at end of file
+#.idea/
diff --git a/exo/helpers.py b/exo/helpers.py
index 71281a359..2b4027a4a 100644
--- a/exo/helpers.py
+++ b/exo/helpers.py
@@ -5,6 +5,8 @@
 import random
 import platform
 import psutil
+import uuid
+from pathlib import Path
 
 DEBUG = int(os.getenv("DEBUG", default="0"))
 DEBUG_DISCOVERY = int(os.getenv("DEBUG_DISCOVERY", default="0"))
@@ -167,3 +169,35 @@ def find_longest_prefix(self, argument: str) -> Optional[Tuple[K, V]]:
       return None
 
     return max(matches, key=lambda x: len(x[0]))
+
+def is_valid_uuid(val):
+  try:
+    uuid.UUID(str(val))
+    return True
+  except ValueError:
+    return False
+
+def get_or_create_node_id():
+  NODE_ID_FILE = Path(os.path.dirname(os.path.abspath(__file__))) / ".exo_node_id"
+  try:
+    if NODE_ID_FILE.is_file():
+      with open(NODE_ID_FILE, "r") as f:
+        stored_id = f.read().strip()
+      if is_valid_uuid(stored_id):
+        if DEBUG >= 2: print(f"Retrieved existing node ID: {stored_id}")
+        return stored_id
+      else:
+        if DEBUG >= 2: print("Stored ID is not a valid UUID. Generating a new one.")
+
+    new_id = str(uuid.uuid4())
+    with open(NODE_ID_FILE, "w") as f:
+      f.write(new_id)
+
+    if DEBUG >= 2: print(f"Generated and stored new node ID: {new_id}")
+    return new_id
+  except IOError as e:
+    if DEBUG >= 2: print(f"IO error creating node_id: {e}")
+    return str(uuid.uuid4())
+  except Exception as e:
+    if DEBUG >= 2: print(f"Unexpected error creating node_id: {e}")
+    return str(uuid.uuid4())
diff --git a/main.py b/main.py
index a456f2daf..3b47c1c9c 100644
--- a/main.py
+++ b/main.py
@@ -8,11 +8,11 @@
 from exo.networking.grpc.grpc_discovery import GRPCDiscovery
 from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
 from exo.api import ChatGPTAPI
-from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_inference_engine, get_system_info
+from exo.helpers import print_yellow_exo, find_available_port, DEBUG, get_inference_engine, get_system_info, get_or_create_node_id
 
 # parse args
 parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
-parser.add_argument("--node-id", type=str, default=str(uuid.uuid4()), help="Node ID")
+parser.add_argument("--node-id", type=str, default=None, help="Node ID")
 parser.add_argument("--node-host", type=str, default="0.0.0.0", help="Node host")
 parser.add_argument("--node-port", type=int, default=None, help="Node port")
 parser.add_argument("--listen-port", type=int, default=5678, help="Listening port for discovery")
@@ -40,6 +40,7 @@
   args.node_port = find_available_port(args.node_host)
   if DEBUG >= 1: print(f"Using available port: {args.node_port}")
 
+args.node_id = args.node_id or get_or_create_node_id()
 discovery = GRPCDiscovery(args.node_id, args.node_port, args.listen_port, args.broadcast_port, discovery_timeout=args.discovery_timeout)
 node = StandardNode(
   args.node_id,