From 2b1b0ab6b4eec241b8d47da128313c3d2733d15e Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Thu, 1 Feb 2024 07:19:43 +0000 Subject: [PATCH] Fix test. --- .../unittest/test_engine_with_samplers.py | 26 +++++++++++++------ 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/serve/tests/unittest/test_engine_with_samplers.py b/serve/tests/unittest/test_engine_with_samplers.py index 2a7bf1efd4..8ff7d56d11 100644 --- a/serve/tests/unittest/test_engine_with_samplers.py +++ b/serve/tests/unittest/test_engine_with_samplers.py @@ -342,20 +342,30 @@ def _test_penalty( def _test_logprobs( model_artifact_path, use_staging_engine, - max_num_sequences=4, - max_input_len=512, num_requests=5, top_logprobs=3, + max_num_batched_tokens=2048 ): - prompt = "hi" + prompt = "hi could you please implement merge sort?" engine = create_engine( model_artifact_path, use_staging_engine, - max_num_sequences, - max_input_len, + max_num_batched_tokens, ) - s = 113 - requests = [create_request(idx=str(n-s), prompt=prompt, temp=0, max_tokens=n, stop=None, ignore_eos=True, top_logprobs=top_logprobs, logprobs=True) for n in range(s, s+num_requests)] + requests = [ + create_request( + idx=str(n), + prompt=prompt, + temp=0, + freq_pen=0, + pre_pen=0, + max_tokens=300, + stop=None, + ignore_eos=True, + top_logprobs=top_logprobs, + logprobs=True + ) for n in range(num_requests) + ] engine.add(requests) generated = ["" for _ in range(num_requests)] @@ -366,7 +376,7 @@ def _test_logprobs( assert len(res.sequences) == 1 seq = res.sequences[0] - assert seq.finish_reason is not None or len(list(seq.logprobs.content[0]["top_logprobs"])) == top_logprobs + assert seq.finish_reason is not None or len(seq.logprob_info[0].top_logprobs) == top_logprobs if seq.is_finished: assert seq.num_generated_tokens == requests[int(res.request_id)].stopping_criteria.max_tokens