Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
mryab committed Dec 14, 2024
1 parent 941c761 commit 0690a94
Show file tree
Hide file tree
Showing 10 changed files with 15 additions and 14 deletions.
1 change: 0 additions & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
#
# Configuration file for the Sphinx documentation builder.
#
Expand Down
4 changes: 2 additions & 2 deletions hivemind/moe/client/expert.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
DUMMY = torch.empty(0, requires_grad=True) # dummy tensor that triggers autograd in RemoteExpert


def get_server_stub(p2p: P2P, server_peer_id: PeerID) -> "ConnectionHandlerStub": # noqa: F821
def get_server_stub(p2p: P2P, server_peer_id: PeerID) -> ConnectionHandlerStub: # noqa: F821
"""Create an RPC stub that can send requests to any expert on the specified remote server"""
return moe.server.connection_handler.ConnectionHandler.get_stub(p2p, server_peer_id)

Expand Down Expand Up @@ -199,7 +199,7 @@ def forward(
ctx,
dummy: torch.Tensor,
uid: str,
stub: "ConnectionHandlerStub", # noqa: F821
stub: ConnectionHandlerStub, # noqa: F821
info: Dict[str, Any],
*inputs: torch.Tensor,
) -> Tuple[torch.Tensor, ...]:
Expand Down
2 changes: 1 addition & 1 deletion hivemind/moe/client/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def forward(self, input: torch.Tensor, *args: torch.Tensor, **kwargs: torch.Tens

if self._expert_info is None:
try:
self._expert_info = next((expert.info for experts_i in chosen_experts for expert in experts_i))
self._expert_info = next(expert.info for experts_i in chosen_experts for expert in experts_i)
except StopIteration:
raise RuntimeError(
"No responding experts found during beam search. Check that UID prefixes and "
Expand Down
2 changes: 1 addition & 1 deletion hivemind/moe/client/switch_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def forward(self, input: torch.Tensor, *args: torch.Tensor, **kwargs: torch.Tens

if self._expert_info is None:
try:
self._expert_info = next((expert.info for experts_i in chosen_experts for expert in experts_i))
self._expert_info = next(expert.info for experts_i in chosen_experts for expert in experts_i)
except StopIteration:
raise RuntimeError(
"No responding experts found during beam search. Check that UID prefixes and "
Expand Down
4 changes: 3 additions & 1 deletion hivemind/moe/expert_uid.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
from hivemind.p2p import PeerID

ExpertUID, ExpertPrefix, Coordinate, Score = str, str, int, float
ExpertInfo = NamedTuple("ExpertInfo", [("uid", ExpertUID), ("peer_id", PeerID)])
class ExpertInfo(NamedTuple):
uid: ExpertUID
peer_id: PeerID
UID_DELIMITER = "." # when declaring experts, the DHT stores all prefixes of that expert's uid, split over this delimiter
FLAT_EXPERT = -1 # grid prefix reserved for storing 1d expert uids. Used to speed up find_best_experts in 1d case.
UID_PATTERN = re.compile("^(([^.])+)([.](?:[0]|([1-9]([0-9]*))))+$") # e.g. ffn_expert.98.76.54 - prefix + some dims
Expand Down
2 changes: 1 addition & 1 deletion hivemind/p2p/p2p_daemon.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@


@dataclass(frozen=True)
class P2PContext(object):
class P2PContext:
handle_name: str
local_id: PeerID
remote_id: PeerID = None
Expand Down
2 changes: 1 addition & 1 deletion hivemind/utils/mpfuture.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ class MPFuture(base.Future, Generic[ResultType]):
_update_lock = mp.Lock() # global lock that prevents simultaneous writing to the same pipe
_global_sender_pipe: Optional[PipeEnd] = None # a pipe that is used to send results/exceptions to this process
_pipe_waiter_thread: Optional[threading.Thread] = None # process-specific thread that receives results/exceptions
_active_futures: Optional[Dict[UID, "ref[MPFuture]"]] = None # non-done futures originated from this process
_active_futures: Optional[Dict[UID, ref[MPFuture]]] = None # non-done futures originated from this process
_active_pid: Optional[PID] = None # pid of currently active process; used to handle forks natively

def __init__(self, *, use_lock: bool = True):
Expand Down
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ required-version = "==0.7.2"
target-version = "py38"

[tool.ruff.lint]
select = ["E", "F"]
ignore = ["E501", "E702"]
select = ["E", "F", "W", "I", "UP", "YTT", "ASYNC", "LOG", "PIE", "SIM", "PLC", "PLE", "FURB"]
ignore = ["E501", "E702", "UP006", "UP007", "SIM102", "SIM103", "SIM108", "SIM300"]
dummy-variable-rgx = "^_$"

[tool.ruff.lint.isort]
known-first-party = ["arguments", "test_utils", "tests", "utils"]
known-local-folder = ["arguments", "test_utils", "tests", "utils"]
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def build_p2p_daemon():
raise FileNotFoundError("Could not find golang installation")
version = parse_version(m.group(1))
if version < parse_version("1.13"):
raise EnvironmentError(f"Newer version of go required: must be >= 1.13, found {version}")
raise OSError(f"Newer version of go required: must be >= 1.13, found {version}")

with tempfile.TemporaryDirectory() as tempdir:
dest = os.path.join(tempdir, "libp2p-daemon.tar.gz")
Expand Down Expand Up @@ -145,7 +145,7 @@ def run(self):

# load the package version from hivemind/__init__.py
with codecs.open(os.path.join(here, "hivemind/__init__.py"), encoding="utf-8") as init_file:
version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", init_file.read(), re.M)
version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", init_file.read(), re.MULTILINE)
version_string = version_match.group(1)

extras = {}
Expand Down
2 changes: 1 addition & 1 deletion tests/test_allreduce_fault_tolerance.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from hivemind.averaging.load_balancing import load_balance_peers
from hivemind.averaging.matchmaking import MatchmakingException
from hivemind.proto import averaging_pb2
from hivemind.utils.asyncio import aenumerate, as_aiter, azip, enter_asynchronously, anext
from hivemind.utils.asyncio import aenumerate, anext, as_aiter, azip, enter_asynchronously
from hivemind.utils.logging import get_logger

logger = get_logger(__name__)
Expand Down

0 comments on commit 0690a94

Please sign in to comment.