Fix bug (#121)
* rename parameter 'max_out_len' of base_api to 'max_tokens'

* fix bug of hf
liujiangning30 authored Jan 31, 2024
1 parent 559275d commit ae3c7c3
Showing 4 changed files with 16 additions and 23 deletions.
lagent/llms/base_api.py: 12 changes (4 additions, 8 deletions)
@@ -1,7 +1,5 @@
 import re
 import threading
-import warnings
-from abc import abstractclassmethod
 from time import sleep
 from typing import Dict, List, Optional, Tuple, Union

@@ -140,8 +138,6 @@ class BaseAPIModel(BaseModel):
         query_per_second (int): The maximum queries allowed per second
             between two consecutive calls of the API. Defaults to 1.
         retry (int): Number of retires if the API call fails. Defaults to 2.
-        max_seq_len (int): The maximum sequence length of the model. Defaults
-            to 2048.
         meta_template (Dict, optional): The model's meta prompt
             template if needed, in case the requirement of injecting or
             wrapping of any meta instructions.
@@ -153,27 +149,27 @@ def __init__(self,
                  model_type: str,
                  query_per_second: int = 1,
                  retry: int = 2,
-                 max_seq_len: int = 2048,
                  template_parser: 'APITemplateParser' = APITemplateParser,
                  meta_template: Optional[Dict] = None,
                  *,
-                 max_out_len: int = 512,
+                 max_tokens: int = 512,
                  top_p: float = 0.8,
                  top_k: float = None,
                  temperature: float = 0.8,
                  repetition_penalty: float = 0.0,
                  stop_words: Union[List[str], str] = None):
         self.model_type = model_type
-        self.max_seq_len = max_seq_len
         self.meta_template = meta_template
         self.retry = retry
         self.query_per_second = query_per_second
         self.token_bucket = TokenBucket(query_per_second)
         if template_parser:
             self.template_parser = template_parser(meta_template)

+        if isinstance(stop_words, str):
+            stop_words = [stop_words]
         self.gen_params = dict(
-            max_out_len=max_out_len,
+            max_tokens=max_tokens,
             top_p=top_p,
             top_k=top_k,
             temperature=temperature,
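
The effect of the rename in base_api.py can be summarized with a short, self-contained sketch. It is not part of the diff: build_gen_params is a hypothetical helper that only mirrors what the constructor above now stores in self.gen_params, with the old max_out_len key replaced by max_tokens.

# Illustrative only: mirrors the gen_params assembled by BaseAPIModel.__init__
# after this commit (keyword renamed from 'max_out_len' to 'max_tokens').
from typing import Dict, List, Optional, Union

def build_gen_params(max_tokens: int = 512,
                     top_p: float = 0.8,
                     top_k: Optional[float] = None,
                     temperature: float = 0.8,
                     repetition_penalty: float = 0.0,
                     stop_words: Union[List[str], str, None] = None) -> Dict:
    if isinstance(stop_words, str):  # normalization added in this commit
        stop_words = [stop_words]
    return dict(
        max_tokens=max_tokens,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        stop_words=stop_words,
    )

print(build_gen_params(max_tokens=256, stop_words='<eoa>'))
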
lagent/llms/base_llm.py: 2 changes (2 additions, 0 deletions)
@@ -130,6 +130,8 @@ def __init__(self,
         if meta_template and 'eos_token_id' in meta_template:
             self.eos_token_id = meta_template['eos_token_id']

+        if isinstance(stop_words, str):
+            stop_words = [stop_words]
         self.gen_params = dict(
             max_tokens=max_tokens,
             top_p=top_p,
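
The same guard now protects the local-model base class. A standalone sketch (normalize_stop_words is an illustrative name, not a function in lagent) shows the normalization performed before gen_params is built:

# A bare string stop word becomes a single-element list, so downstream
# code can always iterate over stop_words.
def normalize_stop_words(stop_words):
    if isinstance(stop_words, str):
        stop_words = [stop_words]
    return stop_words

assert normalize_stop_words('<eoa>') == ['<eoa>']
assert normalize_stop_words(['<eoa>', '</s>']) == ['<eoa>', '</s>']
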
lagent/llms/huggingface.py: 8 changes (4 additions, 4 deletions)
@@ -56,15 +56,15 @@ def __init__(self,
         if not tokenizer_only:
             self._load_model(path=path, model_kwargs=model_kwargs)

-        from transformers.generation.utils import (LogitsProcessorList,
-                                                    StoppingCriteriaList)
+        from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList  # noqa: E501
         self.logits_processor = LogitsProcessorList()
         self.stopping_criteria = StoppingCriteriaList()
         self.prefix_allowed_tokens_fn = None

         stop_words_id = []
-        for sw in self.gen_params.get('stop_words', []):
-            stop_words_id.append(self.tokenizer(sw)['input_ids'][1])
+        if self.gen_params.get('stop_words'):
+            for sw in self.gen_params.get('stop_words'):
+                stop_words_id.append(self.tokenizer(sw)['input_ids'][-1])
         self.additional_eos_token_id = stop_words_id

     def _load_tokenizer(self, path: str, tokenizer_path: Optional[str],
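
The HuggingFace fix addresses two issues: stop_words may be absent, so the new if guard skips the loop instead of iterating over None, and the token id of a stop word is now taken from the end of the encoded sequence, since many tokenizers prepend special tokens such as BOS, which made a fixed index of 1 fragile. A hedged sketch of the idea, with an assumed tokenizer checkpoint that is not part of the diff:

# Illustration only; the checkpoint name is an assumption.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('internlm/internlm-chat-7b',
                                          trust_remote_code=True)
stop_words = ['<eoa>']

stop_words_id = []
if stop_words:  # the new guard: do nothing when stop_words is None or empty
    for sw in stop_words:
        ids = tokenizer(sw)['input_ids']
        # The stop word's own id is the last element; earlier positions may
        # hold special tokens (e.g. BOS) that a fixed ids[1] could mistake for it.
        stop_words_id.append(ids[-1])

print(stop_words_id)
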
lagent/llms/openai.py: 17 changes (6 additions, 11 deletions)
@@ -18,9 +18,6 @@ class GPTAPI(BaseAPIModel):
     Args:
         model_type (str): The name of OpenAI's model.
-        max_seq_len (int): The maximum allowed sequence length of a model.
-            Note that the length of prompt + generated tokens shall not exceed
-            this value. Defaults to 2048.
         query_per_second (int): The maximum queries allowed per second
             between two consecutive calls of the API. Defaults to 1.
         retry (int): Number of retires if the API call fails. Defaults to 2.
@@ -38,15 +35,14 @@ class GPTAPI(BaseAPIModel):
             wrapping of any meta instructions.
         openai_api_base (str): The base url of OpenAI's API. Defaults to
             'https://api.openai.com/v1/chat/completions'.
-        gen_params: Default generation configuration which could be overrided
+        gen_params: Default generation configuration which could be overridden
             on the fly of generation.
     """

     is_api: bool = True

     def __init__(self,
                  model_type: str = 'gpt-3.5-turbo',
-                 max_seq_len: int = 4096,
                  query_per_second: int = 1,
                  retry: int = 2,
                  key: Union[str, List[str]] = 'ENV',
@@ -60,7 +56,6 @@ def __init__(self,
                  **gen_params):
         super().__init__(
             model_type=model_type,
-            max_seq_len=max_seq_len,
             meta_template=meta_template,
             query_per_second=query_per_second,
             retry=retry,
@@ -103,7 +98,7 @@ def chat(
         """Generate responses given the contexts.

-        inputs (Union[List[dict], List[List[dict]]]): a list of messages
+            inputs (Union[List[dict], List[List[dict]]]): a list of messages
                 or list of lists of messages
             gen_params: additional generation configuration
@@ -137,10 +132,10 @@ def _chat(self, messages: List[dict], **gen_params) -> str:
         gen_params = gen_params.copy()

         # Hold out 100 tokens due to potential errors in tiktoken calculation
-        max_out_len = min(
-            gen_params.pop('max_out_len'),
+        max_tokens = min(
+            gen_params.pop('max_tokens'),
             self.context_window - len(self.tokenize(str(input))) - 100)
-        if max_out_len <= 0:
+        if max_tokens <= 0:
             return ''

         max_num_retries = 0
@@ -178,7 +173,7 @@ def _chat(self, messages: List[dict], **gen_params) -> str:
         data = dict(
             model=self.model_type,
             messages=messages,
-            max_tokens=max_out_len,
+            max_tokens=max_tokens,
             n=1,
             stop=gen_params.pop('stop_words'),
             frequency_penalty=gen_params.pop('repetition_penalty'),
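
In openai.py the requested budget is clamped to what remains of the context window, minus a 100-token safety margin for tiktoken miscounts, and the API call is skipped when nothing remains. A standalone sketch of that arithmetic (clamp_max_tokens is an illustrative name, not a function in lagent):

# Illustrative helper mirroring the clamping performed in GPTAPI._chat above.
def clamp_max_tokens(requested: int, context_window: int, prompt_tokens: int) -> int:
    # Hold out 100 tokens as a buffer against tokenizer under-counting.
    return min(requested, context_window - prompt_tokens - 100)

# A 4096-token window with a 3000-token prompt leaves at most 996 tokens.
assert clamp_max_tokens(2048, 4096, 3000) == 996
# When the prompt already fills the window, _chat returns '' instead of calling the API.
assert clamp_max_tokens(512, 4096, 4050) <= 0
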
