run_cot.py
import os
import json
from typing import Optional

import tiktoken
from tqdm import tqdm
from fire import Fire

from agent import Model
from utils.data import construct_markdown_table
from utils.execute import markdown_to_df, remove_merged_suffixes
from utils.table import transpose, sort_dataframe
from run_helper import load_dataset, get_cot_prompt, query, check_transpose, check_sort, read_json_file
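
# Example invocation (an illustrative sketch: fire.Fire exposes the keyword
# arguments of main() as command-line flags, so every flag below mirrors a
# parameter name; the paths are placeholders):
#
#   python run_cot.py --dataset wtq --perturbation none \
#       --log_dir output/wtq_dp --cache_dir cache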


def main(
    model: Optional[str] = "gpt-3.5-turbo-0125",       # base model of the agent (used for short prompts to save money)
    long_model: Optional[str] = "gpt-3.5-turbo-0125",  # long-context model of the agent (only used for long prompts)
    provider: str = "openai",         # openai, huggingface, vllm
    dataset: str = "wtq",             # wtq or tabfact
    perturbation: str = "none",       # none, transpose, shuffle, transpose_shuffle
    norm: bool = True,                # whether to apply NORM to the table
    disable_resort: bool = True,      # whether to disable the resort stage in NORM
    norm_cache: bool = True,          # whether to cache the normalization results so they can be reused
    sub_sample: bool = True,          # whether to run only on the sampled subset of data points
    resume: int = 0,                  # resume from the i-th data point
    stop_at: int = 1_000_000,         # stop at the i-th data point
    self_consistency: int = 10,       # number of self-consistency samples per question
    temperature: float = 0.8,         # sampling temperature for the model
    log_dir: str = "output/wtq_dp",   # directory to store the logs
    cache_dir: str = "cache",         # directory to store the cache (normalization results)
):
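    """Run a chain-of-thought (CoT) baseline over a table QA dataset.

    Pipeline sketch, as implemented below: load the dataset and its CoT
    prompt template, optionally normalize each table with NORM (transpose
    and/or resort, with cached decisions), query the model with
    self-consistency sampling, and append per-question results to
    result.jsonl while tracking token usage.
    """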
    token_content = []

    #### create log & cache dirs and save the config ####
    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(cache_dir, exist_ok=True)
    tokens_path = os.path.join(log_dir, "token.json")

    # store the config
    config_path = os.path.join(log_dir, "config.json")
    with open(config_path, "w") as f:
        json.dump({key: value for key, value in locals().items() if key != "f"}, f, indent=4)

    tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo-0125")
    input_tokens = 0
    output_tokens = 0

    #### load the dataset and the CoT prompt ####
    data = load_dataset(dataset)
    cot_prompt = get_cot_prompt(dataset)

    #### load the models ####
    if model:
        model = Model(model, provider=provider)
    if long_model:
        long_model = Model(long_model, provider=provider)

    #### load the caches ####
    transpose_cache = read_json_file(os.path.join(cache_dir, "transpose.json"))
    resort_cache = read_json_file(os.path.join(cache_dir, "resort.json"))

    #### prepare the iterator ####
    global_i = 0
    break_flag = False
    total = sum(len(d["sampled_indices"]) for d in data) if sub_sample else sum(len(d["questions"]) for d in data)
    pbar = tqdm(total=min(stop_at, total))
    #### start the main loop ####
    for table_idx, d in enumerate(data):
        if break_flag:
            break
        index_list = d["sampled_indices"] if sub_sample else range(len(d["questions"]))
        # skip tables with no questions to answer
        if len(index_list) == 0:
            continue

        # load table info
        table_id = d["table_id"]
        title = d["title"]
        if perturbation == "none":
            table = construct_markdown_table(**d["table"])
        elif perturbation == "transpose":
            table = construct_markdown_table(**d["transposed_table"])
        elif perturbation == "shuffle":
            table = construct_markdown_table(**d["row_shuffled_table"])
        elif perturbation == "transpose_shuffle":
            table = construct_markdown_table(**d["row_shuffled_transposed_table"])
        else:
            raise ValueError(f"unknown perturbation: {perturbation}")
        df = markdown_to_df(table)

        # transpose and sort the table if NORM decides to
        transpose_flag = False
        resort_list = []
        if norm:
            transpose_flag = check_transpose(model, long_model, table, title, table_id, perturbation, transpose_cache, norm_cache, cache_dir)
            if transpose_flag:
                transposed_df = transpose(df)
                df = remove_merged_suffixes(transposed_df)
            if not disable_resort:
                resort_list = check_sort(model, long_model, df, title, table_id, perturbation, resort_cache, norm_cache, cache_dir)
                df = sort_dataframe(df, resort_list)
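        # Note: resort_list is whatever check_sort returns for sort_dataframe
        # (presumably the column keys to sort by); it stays [] when norm is
        # off or resort is disabled, and is logged as-is in result.jsonl.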
        # re-serialize the (possibly normalized) table
        table = df.to_markdown()

        for idx in index_list:
            # skip until the resume point, stop once stop_at is reached
            if global_i < resume:
                global_i += 1
                pbar.update(1)
                continue
            elif global_i >= stop_at:
                break_flag = True
                break

            question = d["questions"][idx]
            answer = d["answers"][idx]
            question_id = d["ids"][idx]

            prompt = cot_prompt.replace("[TABLE]", table)\
                .replace("[QUESTION]", question)\
                .replace("[TITLE]", title)\
                .strip()
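            # Illustrative shape of the filled prompt (the actual template
            # comes from get_cot_prompt and may differ):
            #   ... few-shot CoT examples ...
            #   Title: <title>
            #   | col A | col B | ...   (the markdown table)
            #   Question: <question>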
            text, response = query(model, long_model, prompt, temperature, self_consistency)

            # count prompt/completion tokens with tiktoken
            tmp_count1 = len(tokenizer.encode(str(prompt)))
            tmp_count2 = len(tokenizer.encode(str(text)))
            token_content.append({
                "idx": global_i,
                "type": "dp",
                "input tokens": tmp_count1,
                "output tokens": tmp_count2,
            })
            input_tokens += tmp_count1
            output_tokens += tmp_count2
            # print("Finished question idx: {}")
            log_path = os.path.join(log_dir, "log", f"{global_i}.txt")
            os.makedirs(os.path.dirname(log_path), exist_ok=True)
            with open(log_path, "w") as f:
                f.write("===================Title===================\n")
                f.write(title + "\n")
                f.write("===================Table===================\n")
                f.write(table + "\n")
                f.write("===================Question===================\n")
                f.write(question + "\n")
                f.write("===================Text===================\n")
                f.write(text if isinstance(text, str) else "\n".join(text))
                f.write("\n")
                f.write("===================Answer===================\n")
                f.write(",".join(answer) if isinstance(answer, list) else str(answer))
                f.write("\n")
            res = {
                "idx": global_i,
                "answer": answer,
                "text": text,
                "transpose": transpose_flag,
                "resort": resort_list,
                "question_id": question_id,
                "table_id": table_id,
                "title": title,
                "table": table,
                "question": question,
            }
            with open(os.path.join(log_dir, "result.jsonl"), "a") as f:
                json.dump(res, f)
                f.write("\n")
            global_i += 1
            pbar.update(1)

    pbar.close()

    # dump the per-question token counts plus the running totals
    token_content.append({
        "input_tokens": input_tokens,
        "output_tokens": output_tokens,
    })
    with open(tokens_path, "w") as f:
        json.dump(token_content, f, indent=4)


if __name__ == "__main__":
    Fire(main)