diff --git a/hivemind/moe/server/layers/optim.py b/hivemind/moe/server/layers/optim.py
index f280ba427..00eb85e75 100644
--- a/hivemind/moe/server/layers/optim.py
+++ b/hivemind/moe/server/layers/optim.py
@@ -1,11 +1,10 @@
 import torch
 
 
-class OptimizerWrapper(torch.optim.Optimizer):
+class OptimizerWrapper:
     """A wrapper for pytorch.optim.Optimizer that forwards all methods to the wrapped optimizer"""
 
     def __init__(self, optim: torch.optim.Optimizer):
-        super().__init__(optim.param_groups, optim.defaults)
         self.optim = optim
 
     @property
diff --git a/tests/test_training.py b/tests/test_training.py
index c63b5116d..94c7ea993 100644
--- a/tests/test_training.py
+++ b/tests/test_training.py
@@ -20,7 +20,12 @@ def test_training(max_steps: int = 100, threshold: float = 0.9):
     SGD = partial(torch.optim.SGD, lr=0.05)
 
     with background_server(
-        num_experts=2, device="cpu", optim_cls=SGD, hidden_dim=64, num_handlers=1
+        num_experts=2,
+        device="cpu",
+        optim_cls=SGD,
+        hidden_dim=64,
+        num_handlers=1,
+        clip_grad_norm=1.0,
     ) as server_peer_info:
         dht = DHT(initial_peers=server_peer_info.addrs, start=True)
         expert1, expert2 = create_remote_experts(
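
For context, here is a minimal sketch of the pattern this diff moves to: since `OptimizerWrapper` only delegates to the optimizer it wraps, it does not need to inherit from `torch.optim.Optimizer` (the removed `super().__init__` call re-registered the same param groups a second time). The `clip_grad_norm=1.0` argument added to the test presumably exercises a wrapper that clips gradients before stepping; the `ClippingWrapper` name and its exact interface below are assumptions for illustration, not code taken from this diff.

```python
import torch


class OptimizerWrapper:
    """Sketch: forward attributes and methods to a wrapped torch optimizer
    without inheriting from torch.optim.Optimizer."""

    def __init__(self, optim: torch.optim.Optimizer):
        self.optim = optim

    @property
    def param_groups(self):
        return self.optim.param_groups

    @property
    def state(self):
        return self.optim.state

    def step(self, *args, **kwargs):
        return self.optim.step(*args, **kwargs)

    def zero_grad(self, *args, **kwargs):
        return self.optim.zero_grad(*args, **kwargs)

    def state_dict(self):
        return self.optim.state_dict()

    def load_state_dict(self, state_dict: dict):
        return self.optim.load_state_dict(state_dict)


class ClippingWrapper(OptimizerWrapper):
    """Hypothetical example: clip the global gradient norm before each step."""

    def __init__(self, optim: torch.optim.Optimizer, clip_grad_norm: float):
        super().__init__(optim)
        self.clip_grad_norm = clip_grad_norm

    def step(self, *args, **kwargs):
        # Gather all parameters from the wrapped optimizer's param groups
        parameters = [p for group in self.param_groups for p in group["params"]]
        torch.nn.utils.clip_grad_norm_(parameters, self.clip_grad_norm)
        return super().step(*args, **kwargs)
```

Dropping the inheritance keeps the wrapper a plain delegator, so it cannot get out of sync with the state that `torch.optim.Optimizer.__init__` would otherwise duplicate.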