From 187f8a91d90e97fcf2ec787fd16d2171c383b546 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Thu, 3 Oct 2024 12:07:21 -0700 Subject: [PATCH] Work around flaky Triton shared mem by retrying --- qa/L0_e2e/test_model.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/qa/L0_e2e/test_model.py b/qa/L0_e2e/test_model.py index eae5a4a..d53f88c 100644 --- a/qa/L0_e2e/test_model.py +++ b/qa/L0_e2e/test_model.py @@ -281,6 +281,7 @@ def test_small(client, model_data, hypothesis_data): model_inputs, model_data.output_sizes, shared_mem=shared_mem, + attempts=100, ) for name, input_ in model_inputs.items(): all_model_inputs[name].append(input_) @@ -324,7 +325,11 @@ def test_small(client, model_data, hypothesis_data): st.one_of(st.just(mode) for mode in valid_shm_modes()) ) all_triton_outputs = client.predict( - model_data.name, all_model_inputs, total_output_sizes, shared_mem=shared_mem + model_data.name, + all_model_inputs, + total_output_sizes, + shared_mem=shared_mem, + attempts=100, ) for output_name in sorted(ground_truth.keys()): @@ -359,7 +364,11 @@ def test_max_batch(client, model_data, shared_mem): } shared_mem = valid_shm_modes()[0] result = client.predict( - model_data.name, max_inputs, model_output_sizes, shared_mem=shared_mem + model_data.name, + max_inputs, + model_output_sizes, + shared_mem=shared_mem, + attempts=100, ) ground_truth = model_data.ground_truth_model.predict(max_inputs)