-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathxgptscore.py
115 lines (108 loc) · 4.32 KB
/
xgptscore.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import os
import json
import logging
from .process import MODE_PROCESS_MAP
from .process_utils import XPGTItem, truncate_items, get_query_messages
from .openai_utils import openai_completions, _chatml_to_prompt
from typing import List, Union
from dacite import from_dict
from pathlib import Path
from functools import partial
def xgptscore(
items: List[Union[XPGTItem, dict]],
mode: str,
model_name: str,
num_workers: int = None,
batch_size: int = None,
**kwargs,
):
config_path = os.path.join(os.path.dirname(
__file__), f"mode_configs/{mode}.json")
config_path = Path(config_path)
if not config_path.exists():
logging.warning(
f"Config file {config_path} does not exist. Use default config.")
config_path = config_path.with_name("default.json")
with open(config_path, "r") as f:
config = json.load(f)
config.update(kwargs)
if "max_lengths" in config:
items = truncate_items(items, config["max_lengths"])
if isinstance(items[0], dict):
items = [from_dict(data_class=XPGTItem, data=item) for item in items]
process_func = MODE_PROCESS_MAP[mode]
if "process_kwargs" in config:
process_func = partial(process_func, **config["process_kwargs"])
process_results = list(map(process_func, items))
total_round = len([x for x in process_results[0] if x['do_query']])
logging.warning(f"Total chat rounds: {total_round}")
logging.warning(f"Total chat messages: {len(items)}")
# query and process
round = 0
queried_messages = [[] for _ in range(len(items))]
total_price = 0
total_time = 0
round_completions = []
while True:
round += 1
logging.warning(f"Processing chat round {round}/{total_round}")
query_messages = list(
map(get_query_messages, process_results, queried_messages))
query_messages, postprocess_funcs = list(zip(*query_messages))
chatml_prompts = list(map(_chatml_to_prompt, query_messages))
openai_results = openai_completions(
chatml_prompts,
model_name=model_name,
num_procs=num_workers,
batch_size=batch_size,
**config['decoding'],
)
completions = openai_results['completions']
total_price += sum(openai_results['price_per_example'])
total_time += sum(openai_results['time_per_example'])
logging.warning(f"Round {round} price: {total_price}$")
logging.warning(f"Round {round} time: {total_time}")
postprocess_completions = [postprocess_funcs[idx](
completion) for idx, completion in enumerate(completions)]
round_completions.append(postprocess_completions)
for idx, completion in enumerate(completions):
queried_messages[idx] = query_messages[idx] + \
[{"role": "assistant", "content": completion}
] # add the assistant response
if round == total_round:
_query_messages = list(
map(get_query_messages, process_results, queried_messages))
assert all([x is None for x in _query_messages]
), "All messages should be queried"
break
logging.warning(f"Total price: {total_price}$")
logging.warning(f"Total time: {total_time}")
logging.warning(f"Total time per example: {total_time / len(items)}")
round_completions = list(zip(*round_completions))
return dict(
round_completions=round_completions,
messages_records=queried_messages,
)
"""
Example Usage:
task = "translation"
with open("example.json", "r") as f:
items = json.load(f)
xgptitems = []
for item in items:
for cand in item['candidates']:
xgptitems.append(XPGTItem(
task=task,
instruction=item['instruction'],
input=item['input'],
ref_output=item['output'],
hypo_output=cand['text']
))
result = xgptscore(xgptitems, "ea", "ChatGPT")
idx = 0
for item in items:
for cand in item['candidates']:
cand['responses'] = result['round_completions'][idx]
cand['messages_records'] = result['messages_records'][idx]
json.dump(items, open("example_result.json", "w"), indent=4, ensure_ascii=False)
"""