[Feature] Add api_prompts to GenInferencerOutput #1314

Open · wants to merge 2 commits into main

3 changes: 2 additions & 1 deletion opencompass/models/huggingface_above_v4_33.py
@@ -270,6 +270,7 @@ def generate(self,
             tokenize_kwargs['add_special_tokens'] = False
             tokens = self.tokenizer.batch_encode_plus(messages, **tokenize_kwargs)

+        prompt_list = messages
         tokens = {k: v.to(self.model.device) for k, v in tokens.items()}

         generation_kwargs = self.generation_kwargs.copy()
@@ -292,7 +293,7 @@ def generate(self,
         for stop in stopping_criteria:
             decodeds = [t.split(stop)[0] for t in decodeds]

-        return decodeds
+        return prompt_list, decodeds

     def get_token_len(self, prompt: str) -> int:
         m = _convert_chat_messages([prompt])[0]
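
With this change, the HuggingFace wrapper's generate() hands back the templated prompts alongside the decoded completions instead of the completions alone. A minimal sketch of the new return contract (a toy stand-in, not the OpenCompass implementation; the echo output is a placeholder):

from typing import List, Tuple

def generate(messages: List[str]) -> Tuple[List[str], List[str]]:
    prompt_list = messages                         # prompts after chat templating
    decodeds = [f'echo: {m}' for m in messages]    # placeholder "model" outputs
    return prompt_list, decodeds                   # previously: return decodeds

prompts, outputs = generate(['Hello', 'What is 2 + 2?'])
assert len(prompts) == len(outputs) == 2
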
2 changes: 1 addition & 1 deletion opencompass/models/vllm_with_tf_above_v4_33.py
@@ -118,7 +118,7 @@ def generate(self, inputs: List[str], max_out_len: int, stopping_criteria: List[
             prompt_list.append(prompt)
             output_strs.append(generated_text)

-        return output_strs
+        return prompt_list, output_strs

     def get_token_len(self, prompt: str) -> int:
         """Get lengths of the tokenized strings.
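
The vLLM wrapper makes the same change: its loop over request outputs already collects both the prompt and the generated text, and it now returns both lists. A hedged sketch of that pattern; FakeCompletion and FakeOutput are hypothetical stand-ins used only so the snippet runs without vLLM installed:

from dataclasses import dataclass
from typing import List, Tuple

@dataclass
class FakeCompletion:              # stand-in for a vLLM completion object
    text: str

@dataclass
class FakeOutput:                  # stand-in for a vLLM request output
    prompt: str
    outputs: List[FakeCompletion]

def collect(outputs: List[FakeOutput]) -> Tuple[List[str], List[str]]:
    prompt_list, output_strs = [], []
    for output in outputs:
        prompt_list.append(output.prompt)             # prompt actually sent
        output_strs.append(output.outputs[0].text)    # first completion text
    return prompt_list, output_strs                   # previously: return output_strs

ps, outs = collect([FakeOutput('Hi', [FakeCompletion('Hello!')])])
assert ps == ['Hi'] and outs == ['Hello!']
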
19 changes: 13 additions & 6 deletions opencompass/openicl/icl_inferencer/icl_base_inferencer.py
@@ -111,13 +111,20 @@ def write_to_json(self, save_dir: str, filename: str):
         """Dump the result to a json file."""
         dump_results_dict(self.results_dict, Path(save_dir) / filename)

-    def save_results(self, origin_prompt, prediction, idx, gold=None):
-        self.results_dict[str(idx)] = {
-            'origin_prompt': origin_prompt,
-            'prediction': prediction,
-        }
+    def save_results(self,
+                     origin_prompt,
+                     prediction,
+                     idx,
+                     api_prompts=None,
+                     gold=None):
+        results = {}
+        if api_prompts:
+            results['api_prompts'] = api_prompts
+        results['origin_prompt'] = origin_prompt
+        results['prediction'] = prediction
         if gold:
-            self.results_dict[str(idx)]['gold'] = gold
+            results['gold'] = gold
+        self.results_dict[str(idx)] = results


 class PPLInferencerOutputHandler:
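
The reworked save_results() builds each record as a plain dict and adds api_prompts only when it is provided. A hedged illustration of one entry in results_dict after the change (field values are invented; only the key layout follows the diff above):

example_results_dict = {
    '0': {
        'api_prompts': [{'role': 'HUMAN', 'prompt': 'What is 2 + 2?'}],  # omitted when api_prompts is falsy
        'origin_prompt': 'What is 2 + 2?',
        'prediction': '4',
        'gold': '4',  # omitted when gold is falsy
    },
}
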
13 changes: 10 additions & 3 deletions opencompass/openicl/icl_inferencer/icl_gen_inferencer.py
@@ -152,18 +152,25 @@ def inference(self,
                 results = self.model.generate_from_template(
                     entry, max_out_len=self.max_out_len, **extra_gen_kwargs)
                 generated = results
+                if isinstance(generated, tuple):
+                    api_prompts_list = parsed_entries
+                    prompts, generated = generated
+                else:
+                    api_prompts_list = [None] * len(generated)
+                    prompts = parsed_entries

             num_return_sequences = getattr(self.model, 'generation_kwargs',
                                            {}).get('num_return_sequences', 1)
             # 5-3. Save current output
-            for prompt, prediction, gold in zip(
-                    parsed_entries, batched(generated, num_return_sequences),
-                    golds):
+            for api_prompts, prompt, prediction, gold in zip(
+                    api_prompts_list, prompts,
+                    batched(generated, num_return_sequences), golds):
                 if num_return_sequences == 1:
                     prediction = prediction[0]
                 output_handler.save_results(prompt,
                                             prediction,
                                             index,
+                                            api_prompts=api_prompts,
                                             gold=gold)
                 index = index + 1

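
The inferencer now branches on the shape of what the model returns: a (prompts, outputs) tuple means the model supplied the prompts it actually used, so the raw parsed entries are recorded as api_prompts; a plain list keeps the old behaviour, with api_prompts left as None. A condensed, standalone sketch of that dispatch (names mirror the diff, but this is not a drop-in replacement for the method):

def split_results(results, parsed_entries):
    generated = results
    if isinstance(generated, tuple):
        api_prompts_list = parsed_entries       # raw prompts become api_prompts
        prompts, generated = generated          # model-templated prompts + outputs
    else:
        api_prompts_list = [None] * len(generated)
        prompts = parsed_entries
    return api_prompts_list, prompts, generated

# A tuple-returning model (e.g. the patched HuggingFace/vLLM wrappers above):
api_ps, ps, outs = split_results((['<user> hi'], ['hello']), ['hi'])
assert api_ps == ['hi'] and ps == ['<user> hi'] and outs == ['hello']

# A model that still returns a plain list of completions:
api_ps, ps, outs = split_results(['hello'], ['hi'])
assert api_ps == [None] and ps == ['hi'] and outs == ['hello']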