From 585547bcf4362c93fed88afc4332066b25832372 Mon Sep 17 00:00:00 2001 From: Mikhail Denisenko Date: Sat, 25 May 2024 19:09:00 -0400 Subject: [PATCH] Add unit-tests for exponential_backoff and small fix for it --- src/pytds/utils.py | 9 +++++--- tests/utils_tests.py | 51 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 57 insertions(+), 3 deletions(-) diff --git a/src/pytds/utils.py b/src/pytds/utils.py index fabe6b9..9ab7655 100644 --- a/src/pytds/utils.py +++ b/src/pytds/utils.py @@ -49,13 +49,16 @@ def exponential_backoff( work_actual_time, ) ex_handler(ex) - if time.time() >= end_time: - raise TimeoutError() from ex - remaining_attempt_time = try_time - (time.time() - try_start_time) + cur_time = time.time() + remaining_attempt_time = try_time - (cur_time - try_start_time) logger.info("Will retry after %f seconds", remaining_attempt_time) if remaining_attempt_time > 0: time.sleep(remaining_attempt_time) + cur_time += remaining_attempt_time + if cur_time >= end_time: + raise TimeoutError() from ex try_time *= backoff_factor + try_time = min(try_time, end_time - cur_time) def parse_server(server: str) -> tuple[str, str]: diff --git a/tests/utils_tests.py b/tests/utils_tests.py index 857e6fc..60f9538 100644 --- a/tests/utils_tests.py +++ b/tests/utils_tests.py @@ -1,6 +1,57 @@ +import time +import pytest import pytds.utils def test_parse_server(): assert pytds.utils.parse_server(".") == ("localhost", "") assert pytds.utils.parse_server("(local)") == ("localhost", "") + + +def test_exponential_backoff_success_first_attempt(): + """ + Test exponential backoff succeeding on first attempt + """ + got_exception = {'value': None} + + def ex_handler(ex): + got_exception['value'] = ex + + res = pytds.utils.exponential_backoff( + work=lambda t: t, + ex_handler=ex_handler, + max_time_sec=1, + first_attempt_time_sec=0.1, + ) + # result should be what was returned by work lambda + # and it should be equal to what was passed as first attempt timeout + # since this is what is passed to work lambda and what it returns + assert res == 0.1 + assert got_exception['value'] is None + + +def test_exponential_backoff_timeout(): + """ + Should perform 4 attempts with expected timeouts for each when attempts fail + """ + context = {'attempts': 0} + start_time = time.time() + + def work(t): + context['attempts'] += 1 + print(f"attempt {context['attempts']}, timeout {t:0.1f}, start time {time.time() - start_time:0.1f}") + raise RuntimeError("raising test exception") + + with pytest.raises(TimeoutError): + pytds.utils.exponential_backoff( + work=work, + ex_handler=lambda ex: None, + max_time_sec=1, + first_attempt_time_sec=0.1, + ) + # attempts are + # 1: timeout 0.1, ends at 0.1 + # 2: timeout 0.2, ends at 0.3 + # 3: timeout 0.4, ends at 0.7 + # 4: timeout 0.3, ends at 1.0 + assert context['attempts'] == 4