Skip to content

Commit

Permalink
Add do_sample in AsyncLMDeployPipeline (lmdeploy_wrapper.py) (#290)
Browse files Browse the repository at this point in the history
  • Loading branch information
shinkungoo authored Jan 2, 2025
1 parent a58c914 commit c337aa8
Showing 1 changed file with 12 additions and 0 deletions.
12 changes: 12 additions & 0 deletions lagent/llms/lmdeploy_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,7 +556,19 @@ async def generate(self,
assert len(inputs) == len(session_ids)

prompt = inputs
do_sample = kwargs.pop('do_sample', None)
gen_params = self.update_gen_params(**kwargs)
if do_sample is None:
do_sample = self.do_sample
if do_sample is not None and self.version < (0, 6, 0):
raise RuntimeError(
'`do_sample` parameter is not supported by lmdeploy until '
f'v0.6.0, but currently using lmdeloy {self.str_version}')
if self.version >= (0, 6, 0):
if do_sample is None:
do_sample = gen_params['top_k'] > 1 or gen_params[
'temperature'] > 0
gen_params.update(do_sample=do_sample)
gen_config = GenerationConfig(
skip_special_tokens=skip_special_tokens, **gen_params)

Expand Down

0 comments on commit c337aa8

Please sign in to comment.