diff --git a/pymilvus/decorators.py b/pymilvus/decorators.py index 53ab1201e..f9a8b0250 100644 --- a/pymilvus/decorators.py +++ b/pymilvus/decorators.py @@ -42,25 +42,31 @@ def handler(*args, **kwargs): # This has to make sure every timeout parameter is passing # throught kwargs form as `timeout=10` _timeout = kwargs.get("timeout", None) + _retry_times = kwargs.get("retry_times", None) _retry_on_rate_limit = kwargs.get("retry_on_rate_limit", True) retry_timeout = _timeout if _timeout is not None and isinstance(_timeout, int) else None + final_retry_times = ( + _retry_times + if _retry_times is not None and isinstance(_retry_times, int) + else retry_times + ) counter = 1 back_off = initial_back_off start_time = time.time() def timeout(start_time: Optional[float] = None) -> bool: """If timeout is valid, use timeout as the retry limits, - If timeout is None, use retry_times as the retry limits. + If timeout is None, use final_retry_times as the retry limits. """ if retry_timeout is not None: return time.time() - start_time >= retry_timeout - return counter > retry_times + return counter > final_retry_times to_msg = ( f"Retry timeout: {retry_timeout}s" if retry_timeout is not None - else f"Retry run out of {retry_times} retry times" + else f"Retry run out of {final_retry_times} retry times" ) while True: diff --git a/tests/test_decorators.py b/tests/test_decorators.py index e26de53d3..500887157 100644 --- a/tests/test_decorators.py +++ b/tests/test_decorators.py @@ -102,6 +102,20 @@ def test_api(self, code): # the first execute + 0 retry times assert self.execute_times == 1 + def test_retry_decorators_set_retry_times(self): + self.count_retry_times = 0 + + @retry_on_rpc_failure() + def test_api(self, code, retry_on_rate_limit, **kwargs): + self.count_retry_times += 1 + self.mock_milvus_exception(code) + + with pytest.raises(MilvusException) as e: + test_api(self, ErrorCode.RATE_LIMIT, retry_on_rate_limit=True, retry_times=3) + + # the first execute + 0 retry times + assert self.count_retry_times == 3 + 1 + @pytest.mark.parametrize("times", [0, 1, 2, 3]) def test_retry_decorators_rate_limit_without_retry(self, times): self.count_test_retry_decorators_force_deny = 0