Skip to content

Commit

Permalink
Add torch.manual_seed for test_fault_tolerance
Browse files Browse the repository at this point in the history
  • Loading branch information
mryab committed Jul 14, 2024
1 parent f151e23 commit a12451e
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions tests/test_allreduce_fault_tolerance.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
from __future__ import annotations

import asyncio
from enum import Enum, auto

import pytest
import torch

import hivemind
from hivemind.averaging.averager import *
from hivemind.averaging.averager import AllReduceRunner, AveragingMode, GatheredData
from hivemind.averaging.group_info import GroupInfo
from hivemind.averaging.load_balancing import load_balance_peers
from hivemind.averaging.matchmaking import MatchmakingException
from hivemind.proto import averaging_pb2
from hivemind.utils.asyncio import aenumerate, as_aiter, azip, enter_asynchronously
from hivemind.utils.asyncio import AsyncIterator, aenumerate, as_aiter, azip, enter_asynchronously
from hivemind.utils.logging import get_logger

logger = get_logger(__name__)
Expand Down Expand Up @@ -138,6 +140,8 @@ async def _generate_input_for_peer(self, peer_index: int) -> AsyncIterator[avera
],
)
def test_fault_tolerance(fault0: Fault, fault1: Fault):
torch.manual_seed(0)

def _make_tensors():
return [torch.rand(16, 1024), -torch.rand(3, 8192), 2 * torch.randn(4, 4, 4), torch.randn(1024, 1024)]

Expand Down

0 comments on commit a12451e

Please sign in to comment.