diff --git a/src/infer/inference.py b/src/infer/inference.py
index 4eaaea3..77b275f 100644
--- a/src/infer/inference.py
+++ b/src/infer/inference.py
@@ -1,5 +1,7 @@
 """
-This file is based on: https://github.com/microsoft/ProphetNet/tree/master/CRITIC
+This script supports vLLM batch inference with cot/pal/tora prompts.
+It also supports inference of fine-tuned models like WizardMath/ToRA.
+Code based on: https://github.com/microsoft/ProphetNet/tree/master/CRITIC
 """
 import random
 import os
@@ -19,8 +21,10 @@
 def parse_args():
     parser = argparse.ArgumentParser()
     parser.add_argument("--data_name", default="gsm8k", type=str)
+    parser.add_argument("--data_dir", default="./data", type=str)
     parser.add_argument("--model_name_or_path", default="gpt-4", type=str)
-    parser.add_argument("--prompt_type", default="pal", type=str)
+    parser.add_argument("--output_dir", default="./output", type=str)
+    parser.add_argument("--prompt_type", default="tora", type=str)
     parser.add_argument("--split", default="test", type=str)
     parser.add_argument("--num_test_sample", default=-1, type=int) # -1 for full data
     parser.add_argument("--seed", default=0, type=int)
@@ -28,7 +32,8 @@ def parse_args():
     parser.add_argument("--end", default=-1, type=int)
     parser.add_argument("--temperature", default=0, type=float)
    parser.add_argument("--n_sampling", default=1, type=int)
-    parser.add_argument("--top_p", default=0.95, type=float)
+    parser.add_argument("--top_p", default=1, type=float)
+    parser.add_argument("--max_tokens_per_call", default=1024, type=int)
     parser.add_argument("--shuffle", action="store_true")
     parser.add_argument("--use_train_prompt_format", action="store_true")
     args = parser.parse_args()
@@ -36,12 +41,14 @@ def parse_args():
     return args


-def main(args):
-    examples = load_data(args.data_name, args.split)
+def prepare_data(args):
+    examples = load_data(args.data_name, args.split, args.data_dir)

     # sample `num_test_sample` from dataset
     if args.num_test_sample > 0:
         examples = random.sample(examples, args.num_test_sample)
+    elif args.num_test_sample == -1:
+        args.num_test_sample = len(examples)

     # shuffle
     if args.shuffle:
@@ -53,19 +60,18 @@ def main(args):
         args.end = len(examples)
     examples = examples[args.start:args.end]

-    # get out_file
+    # get out_file name
     dt_string = datetime.now().strftime("%m-%d_%H-%M")
     model_name = "/".join(args.model_name_or_path.split("/")[-2:])
-    file_prompt_type = args.prompt_type.replace("program_only", "tora")
-    out_file_prefix = f'{args.split}_{file_prompt_type}_{args.num_test_sample}_seed{args.seed}_t{args.temperature}'
-    out_file = f'outputs/{model_name}/{args.data_name}/{out_file_prefix}_s{args.start}_e{args.end}_{dt_string}.jsonl'
-    os.makedirs(f'outputs/{model_name}/{args.data_name}', exist_ok=True)
+    out_file_prefix = f'{args.split}_{args.prompt_type}_{args.num_test_sample}_seed{args.seed}_t{args.temperature}'
+    out_file = f'{args.output_dir}/{model_name}/{args.data_name}/{out_file_prefix}_s{args.start}_e{args.end}_{dt_string}.jsonl'
+    os.makedirs(f'{args.output_dir}/{model_name}/{args.data_name}', exist_ok=True)

-    # all files in the output folder
-    processed_files = [f for f in os.listdir(f"outputs/{model_name}/{args.data_name}/") if f.endswith(".jsonl") and f.startswith(out_file_prefix)]
+    # load all processed samples
+    processed_files = [f for f in os.listdir(f"{args.output_dir}/{model_name}/{args.data_name}/") if f.endswith(".jsonl") and f.startswith(out_file_prefix)]
     processed_samples = []
     for f in processed_files:
-        processed_samples.extend(list(load_jsonl(f"outputs/{model_name}/{args.data_name}/{f}")))
+        processed_samples.extend(list(load_jsonl(f"{args.output_dir}/{model_name}/{args.data_name}/{f}")))

     # dedepulicate
     processed_samples = {sample['idx']: sample for sample in processed_samples}
@@ -76,9 +82,13 @@ def main(args):
     print(f"Idx {args.start} - {args.end}: Remain {len(examples)}/{total_examples} samples.")
     if len(examples) == 0:
         pass
-        # return
     else:
         print(examples[0])
+    return examples, processed_samples, out_file
+
+
+def main(args):
+    examples, processed_samples, out_file = prepare_data(args)

     # init python executor
     if "pal" in args.prompt_type:
@@ -102,7 +112,8 @@ def main(args):
         sample = {'idx': idx, 'question': example['question'], 'gt_cot': gt_cot, 'gt': gt_ans, 'prompt': full_prompt}

         # add remain fields
-        for key in ['level', 'type', 'unit', 'solution_type', 'choices', 'solution', 'ques_type', 'ans_type']:
+        for key in ['level', 'type', 'unit', 'solution_type', 'choices', 'solution', 'ques_type', \
+            'ans_type', 'answer_type', 'dataset', 'subfield', 'filed', 'theorem', 'answer']:
             if key in example:
                 sample[key] = example[key]
         samples.append(sample)
@@ -119,7 +130,7 @@ def main(args):
     end_prompts = []

     max_func_call = 1 if args.prompt_type in ['cot', 'pal'] else 4
-    stop_tokens = ["</s>", "---", "```output"]
+    stop_tokens = ["</s>", "```output"]
     if args.prompt_type in ['cot']:
         stop_tokens.append("\n\n")

@@ -140,7 +151,7 @@ def main(args):
         outputs = llm.generate(prompts, SamplingParams(
                         temperature=args.temperature,
                         top_p=args.top_p,
-                        max_tokens=1024,
+                        max_tokens=args.max_tokens_per_call,
                         n=1,
                         stop=stop_tokens,
         ))
@@ -158,12 +169,12 @@ def main(args):
             if args.prompt_type == "pal":
                 remain_prompts.append((i, query))
                 if "```python" in output:
-                    output = extract_program(output)
+                    output = extract_program(query)
                 remain_codes.append(output)
             elif args.prompt_type == "cot":
                 end_prompts.append((i, query))
             elif ("boxed" not in output and output.endswith("```")):
-                program = extract_program(output)
+                program = extract_program(query)
                 remain_prompts.append((i, query))
                 remain_codes.append(program)
             else:
@@ -173,13 +184,8 @@ def main(args):
         remain_results = executor.batch_apply(remain_codes)
         for k in range(len(remain_prompts)):
             i, query = remain_prompts[k]
-            pred, report = remain_results[k]
-            pred, report = str(pred).strip(), str(report).strip()
-            if len(pred) > 100:
-                pred = pred[:50] + "..." + pred[-50:]
-            if len(report) > 100:
-                report = report[:50] + "..." + report[-50:]
-            exec_result = pred if pred else report
+            res, report = remain_results[k]
+            exec_result = res if res else report
             if "pal" in args.prompt_type:
                 exec_result = "\\boxed{" + exec_result + "}"
             exec_result = f"\n```output\n{exec_result}\n```\n"
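
# ---------------------------------------------------------------------------
# Example (sketch): the control flow that inference.py drives with vLLM.
# generate() is a hypothetical stand-in for llm.generate(...) truncated at the
# "```output" stop token, and run_program() stands in for PythonExecutor, so
# only the loop structure mirrors the script, not real model output.
import io
import contextlib

MAX_FUNC_CALL = 4

def generate(query: str) -> str:
    # Pretend the model emits one tool call, then a final boxed answer.
    if "```output" not in query:
        return "```python\nprint(18 * 3)\n```\n"   # generation stops before "```output"
    return "The answer is \\boxed{54}."

def extract_program(query: str) -> str:
    return query.split("```python")[-1].split("```")[0]

def run_program(program: str) -> str:
    buf = io.StringIO()
    with contextlib.redirect_stdout(buf):
        exec(program, {})
    return buf.getvalue().strip()

query = "Question: ...\n\nSolution:\n"
for _ in range(MAX_FUNC_CALL):
    output = generate(query).rstrip()
    query += output
    if "boxed" in output or not output.endswith("```"):
        break                                      # final answer reached
    result = run_program(extract_program(query))
    query += f"\n```output\n{result}\n```\n"       # feed the execution result back
print(query)
# ---------------------------------------------------------------------------
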
diff --git a/src/infer/inference_api.py b/src/infer/inference_api.py
new file mode 100644
index 0000000..936687f
--- /dev/null
+++ b/src/infer/inference_api.py
@@ -0,0 +1,208 @@
+"""
+This script supports LLM API inference with cot/pal/tora prompts.
+It can be used to generate the ToRA corpus.
+Code based on: https://github.com/microsoft/ProphetNet/tree/master/CRITIC
+"""
+import json
+import random
+import os
+import pprint
+import re
+import argparse
+import time
+from datetime import datetime
+from tqdm import tqdm
+from sympy.printing.pretty import pretty
+
+from api.llm_api import llm_api # use your own API like OpenAI API
+from utils.python_executor import PythonExecutor
+from utils.utils import *
+from utils.parser import *
+# from utils.trajectory import *
+from eval.grader import *
+from utils.data_loader import load_data
+from infer.inference import prepare_data
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--data_name", default="gsm8k", type=str)
+    parser.add_argument("--data_dir", default="./data", type=str)
+    parser.add_argument("--model_name_or_path", default="gpt-4", type=str)
+    parser.add_argument("--output_dir", default="./output", type=str)
+    parser.add_argument("--prompt_type", default="tora", type=str)
+    parser.add_argument("--split", default="test", type=str)
+    parser.add_argument("--num_test_sample", default=-1, type=int) # -1 for full data
+    parser.add_argument("--seed", default=0, type=int)
+    parser.add_argument("--start", default=0, type=int)
+    parser.add_argument("--end", default=-1, type=int)
+    parser.add_argument("--temperature", default=0, type=float)
+    parser.add_argument("--n_sampling", default=1, type=int)
+    parser.add_argument("--top_p", default=1, type=float)
+    parser.add_argument("--max_tokens_per_call", default=1024, type=int)
+    parser.add_argument("--shuffle", action="store_true")
+    parser.add_argument("--use_train_prompt_format", action="store_true")
+    args = parser.parse_args()
+    args.top_p = 1 if args.temperature == 0 else args.top_p # top_p must be 1 when using greedy sampling (vllm)
+    return args
+
+
+def api_with_func_call(engine, prompt, max_tokens, temperature, n, top_p, executor, max_func_call=4, verbose=False):
+    if n > 1:
+        assert temperature > 0
+
+    if verbose:
+        print("\n======= API with function call (START) =======")
+
+    next_batch_queries = [""] * n
+    end_queries = []
+    for i in range(max_func_call):
+        batch_outputs = []
+        batch_queries = next_batch_queries
+        if len(batch_queries) == 0:
+            break
+        # get all outputs
+        # support batch inference when n > 1
+        if i == 0:
+            results = llm_api(
+                engine=engine, prompt=prompt + batch_queries[0], max_tokens=max_tokens, temperature=temperature,
+                n=n, top_p=top_p, stop=["```output\n", "---"],
+            )
+            batch_outputs.extend(results)
+        else:
+            for k, query in enumerate(batch_queries):
+                print("Call {} / {}".format(k+1, len(batch_queries)))
+                results = llm_api(
+                    engine=engine, prompt=prompt + query, max_tokens=max_tokens, temperature=temperature,
+                    n=1, top_p=top_p, stop=["```output\n", "---"],
+                )
+                batch_outputs.append(results[0])
+
+        # process all outputs
+        next_batch_queries = []
+        for query, output in zip(batch_queries, batch_outputs):
+            output = output.rstrip()
+            query += output
+            if verbose:
+                print("\n", "-" * 20)
+                print(output, end="")
+            if "boxed" not in output and output.endswith("```"):
+                program = extract_program(query)
+                prediction, report = executor.apply(program)
+                exec_result = prediction if prediction else report
+                exec_result = f"\n```output\n{exec_result.strip()}\n```\n"
+                query += exec_result
+                if verbose:
+                    print(exec_result, end="")
+                # not end
+                if i == max_func_call - 1:
+                    query += "\nReach max function call limit."
+                next_batch_queries.append(query)
+            else:
+                end_queries.append(query)
+
+    end_queries.extend(next_batch_queries)
+    return end_queries
+
+
+def main(args):
+    examples, processed_samples, out_file = prepare_data(args)
+    # init python executor
+    if "pal" in args.prompt_type:
+        executor = PythonExecutor(get_answer_expr='solution()')
+    else:
+        executor = PythonExecutor(get_answer_from_stdout=True)
+
+    writer = open(out_file, 'w')
+    correct, wrong = 0, 0
+
+    for example in tqdm(examples, total=len(examples)):
+        idx = example['idx']
+
+        # parse question and answer
+        example['question'] = parse_question(example, args.data_name)
+        gt_cot, gt_ans = parse_ground_truth(example, args.data_name)
+        full_prompt = construct_prompt(args, example)
+
+        # call LLM, return list
+        if "tora" in args.prompt_type:
+            results = api_with_func_call(
+                engine=args.model_name_or_path,
+                prompt=full_prompt,
+                max_tokens=args.max_tokens_per_call,
+                temperature=args.temperature,
+                n=args.n_sampling,
+                top_p=args.top_p,
+                executor=executor,
+            )
+        else:
+            stop_tokens = ["</s>", "---", "```output"]
+            if args.prompt_type in ['cot']:
+                stop_tokens.append("\n\n")
+            results = llm_api(
+                engine=args.model_name_or_path,
+                prompt=full_prompt,
+                max_tokens=args.max_tokens_per_call,
+                temperature=args.temperature,
+                n=args.n_sampling,
+                top_p=args.top_p,
+                stop=stop_tokens,
+            )
+        # deal with error
+        if results == ['error']:
+            print(">>> Error API call")
+            continue
+        print("Get {} results".format(len(results)))
+        # get prediction
+        predictions = []
+        reports = []
+        for r in results:
+            pred, report = run_execute(executor, r, args.prompt_type, execute=True)
+            predictions.append(pred)
+            reports.append(report)
+        print("Executed {} results".format(len(predictions)))
+
+        scores = [math_equal(p, gt_ans, timeout=True) for p in predictions]
+
+        is_correct = scores[0]
+        if is_correct:
+            correct += 1
+        else:
+            wrong += 1
+
+        sample = {'idx': idx, 'question': example['question'], 'gt_cot': gt_cot, 'gt': gt_ans,
+            'pred': predictions, 'score': scores}
+
+        if args.prompt_type == "cot":
+            sample.update({'code': results})
+        elif "tora" in args.prompt_type or "pal" in args.prompt_type:
+            sample.update({'report': reports, 'code': results})
+        # add remain fields
+        for key in ['level', 'type', 'unit', 'solution_type', 'choices', 'solution', 'ques_type', \
+            'ans_type', 'answer_type', 'dataset', 'subfield', 'filed', 'theorem', 'answer']:
+            if key in example:
+                sample[key] = example[key]
+
+        print(idx)
+        show_sample(sample)
+        if correct + wrong > 0:
+            print("Avg Acc:", correct / (correct + wrong))
+        print()
+
+        try:
+            writer.write(json.dumps(sample) + '\n')
+            writer.flush()
+        except:
+            print(">>> Error writing to file")
+            continue
+
+    writer.close()
+    print()
+    print(correct / (correct + wrong))
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    set_seed(args.seed)
+    main(args)
\ No newline at end of file
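
# ---------------------------------------------------------------------------
# Example (sketch): api/llm_api.py itself is not part of this patch.
# inference_api.py only assumes a callable with the signature below that
# returns a list of n completion strings, or ['error'] when the request fails.
# The body here is a dummy offline backend; swap it for your own provider
# (e.g. an OpenAI-compatible client) while keeping the same signature and
# error convention.
from typing import List

def llm_api(engine: str, prompt: str, max_tokens: int, temperature: float,
            n: int, top_p: float, stop: List[str]) -> List[str]:
    try:
        # Call your completion endpoint here; this stub returns a fixed
        # "final answer" so the surrounding pipeline can be exercised offline.
        completion = "The final answer is \\boxed{42}."
        return [completion for _ in range(n)]
    except Exception:
        return ['error']  # inference_api.py compares against exactly this value
# ---------------------------------------------------------------------------
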
diff --git a/src/scripts/infer.sh b/src/scripts/infer.sh
index cd881bf..120ee2b 100644
--- a/src/scripts/infer.sh
+++ b/src/scripts/infer.sh
@@ -1,22 +1,22 @@
 set -ex


-# MODEL_NAME_OR_PATH="llm-agents/tora-code-34b-v1.0"
-MODEL_NAME_OR_PATH="llm-agents/tora-70b-v1.0"
+MODEL_NAME_OR_PATH="llm-agents/tora-code-34b-v1.0"
+# MODEL_NAME_OR_PATH="llm-agents/tora-70b-v1.0"

 # DATA_LIST = ['math', 'gsm8k', 'gsm-hard', 'svamp', 'tabmwp', 'asdiv', 'mawps']
-DATA="math"
-# DATA="gsm8k"
+DATA_NAME="math"
+# DATA_NAME="gsm8k"

 SPLIT="test"
 PROMPT_TYPE="tora"
 NUM_TEST_SAMPLE=-1


-CUDA_VISIBLE_DEVICES=2,3 TOKENIZERS_PARALLELISM=false \
-python -m infer.inference \
+CUDA_VISIBLE_DEVICES=0 TOKENIZERS_PARALLELISM=false \
+python -um infer.inference \
 --model_name_or_path ${MODEL_NAME_OR_PATH} \
---data ${DATA} \
+--data_name ${DATA_NAME} \
 --split ${SPLIT} \
 --prompt_type ${PROMPT_TYPE} \
 --use_train_prompt_format \
@@ -24,6 +24,6 @@ python -m infer.inference \
 --seed 0 \
 --temperature 0 \
 --n_sampling 1 \
---top_p 0.95 \
+--top_p 1 \
 --start 0 \
 --end -1 \
diff --git a/src/scripts/infer_api.sh b/src/scripts/infer_api.sh
new file mode 100644
index 0000000..ff03570
--- /dev/null
+++ b/src/scripts/infer_api.sh
@@ -0,0 +1,23 @@
+set -ex
+
+MODEL_NAME_OR_PATH="gpt-4"
+DATA_NAME="math"
+
+SPLIT="train"
+PROMPT_TYPE="tora"
+NUM_TEST_SAMPLE=-1
+
+python -um infer.inference_api \
+    --model_name_or_path $MODEL_NAME_OR_PATH \
+    --output_dir ../outputs/ \
+    --data_name $DATA_NAME \
+    --split $SPLIT \
+    --prompt_type $PROMPT_TYPE \
+    --num_test_sample $NUM_TEST_SAMPLE \
+    --seed 0 \
+    --temperature 0 \
+    --n_sampling 1 \
+    --top_p 1 \
+    --start 0 \
+    --end -1 \
+
diff --git a/src/utils/data_loader.py b/src/utils/data_loader.py
index 2c4efa5..95010bb 100644
--- a/src/utils/data_loader.py
+++ b/src/utils/data_loader.py
@@ -2,17 +2,19 @@
 import json
 import random
 from datasets import load_dataset, Dataset, concatenate_datasets
-from utils.utils import load_jsonl
+from utils.utils import load_jsonl, lower_keys

-def load_data(data_name, split):
-    data_file = f"data/{data_name}/{split}.json"
+def load_data(data_name, split, data_dir='./data'):
+    data_file = f"{data_dir}/{data_name}/{split}.jsonl"
     if os.path.exists(data_file):
         examples = list(load_jsonl(data_file))
     else:
         if data_name == "math":
-            dataset = load_dataset("competition_math", split=split, name="main", cache_dir="data_name/temp")
+            dataset = load_dataset("competition_math", split=split, name="main", cache_dir=f"{data_dir}/temp")
+        elif data_name == "theorem-qa":
+            dataset = load_dataset("wenhu/TheoremQA", split=split)
         elif data_name == "gsm8k":
-            dataset = load_dataset(data_name, split=split, name="main")
+            dataset = load_dataset(data_name, split=split)
         elif data_name == "gsm-hard":
             dataset = load_dataset("reasoning-machines/gsm-hard", split="train")
         elif data_name == "svamp":
@@ -26,7 +28,7 @@ def load_data(data_name, split):
             examples = []
             # four sub-tasks
             for data_name in ["singleeq", "singleop", "addsub", "multiarith"]:
-                sub_examples = list(load_jsonl(f"data_name/mawps/{data_name}.jsonl"))
+                sub_examples = list(load_jsonl(f"{data_dir}/mawps/{data_name}.jsonl"))
                 for example in sub_examples:
                     example['type'] = data_name
                 examples.extend(sub_examples)
@@ -36,7 +38,7 @@ def load_data(data_name, split):
             dataset = dataset.select(random.sample(range(len(dataset)), 1000))
         elif data_name == "tabmwp":
             examples = []
-            with open(f"data_name/tabmwp/tabmwp_{split}.json", "r") as f:
+            with open(f"{data_dir}/tabmwp/tabmwp_{split}.json", "r") as f:
                 data_dict = json.load(f)
                 examples.extend(data_dict.values())
             dataset = Dataset.from_list(examples)
@@ -45,7 +47,7 @@ def load_data(data_name, split):
             examples = []
             for data_name in ["reasoning_about_colored_objects", "penguins_in_a_table",\
                 "date_understanding", "repeat_copy_logic", "object_counting"]:
-                with open(f"data_name/bbh/bbh/{data_name}.json", "r") as f:
+                with open(f"{data_dir}/bbh/bbh/{data_name}.json", "r") as f:
                     sub_examples = json.load(f)["examples"]
                     for example in sub_examples:
                         example['type'] = data_name
@@ -54,15 +56,16 @@ def load_data(data_name, split):
         else:
             raise NotImplementedError(data_name)

-        if 'idx' not in dataset.column_names:
-            dataset = dataset.map(lambda x, i: {'idx': i, **x}, with_indices=True)
-
-        os.makedirs(f"data_name/{data_name}", exist_ok=True)
-        dataset.to_json(data_file)
         examples = list(dataset)
+        examples = [lower_keys(example) for example in examples]
+        dataset = Dataset.from_list(examples)
+        os.makedirs(f"{data_dir}/{data_name}", exist_ok=True)
+        dataset.to_json(data_file)
+
+    # add 'idx' in the first column
+    if 'idx' not in examples[0]:
+        examples = [{'idx': i, **example} for i, example in enumerate(examples)]

     # dedepulicate & sort
-    examples = {example['idx']: example for example in examples}
-    examples = list(examples.values())
     examples = sorted(examples, key=lambda x: x['idx'])
     return examples
\ No newline at end of file
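
# ---------------------------------------------------------------------------
# Example (sketch): what the new normalization in load_data() does to a freshly
# downloaded split before caching it to {data_dir}/{data_name}/{split}.jsonl.
# lower_keys() is re-implemented inline in an equivalent condensed form (the
# real one is added to utils/utils.py later in this patch); the two records
# below are made up.
def lower_keys(example):
    return {key.lower(): value for key, value in example.items()}

raw = [
    {"Problem": "What is 2 + 3?", "Answer": "5"},
    {"Problem": "What is 7 * 6?", "Answer": "42"},
]

examples = [lower_keys(example) for example in raw]
# add 'idx' in the first column when the dataset does not provide one
if 'idx' not in examples[0]:
    examples = [{'idx': i, **example} for i, example in enumerate(examples)]
examples = sorted(examples, key=lambda x: x['idx'])

print(examples[0])  # {'idx': 0, 'problem': 'What is 2 + 3?', 'answer': '5'}
# ---------------------------------------------------------------------------
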
diff --git a/src/utils/python_executor.py b/src/utils/python_executor.py
old mode 100755
new mode 100644
index ad65727..5ea30a1
--- a/src/utils/python_executor.py
+++ b/src/utils/python_executor.py
@@ -1,10 +1,10 @@
+import os
 import io
 import regex
 import pickle
 import traceback
 import copy
 import datetime
-import multiprocessing
 import dateutil.relativedelta
 import multiprocess
 from multiprocess import Pool
@@ -94,7 +94,7 @@ def execute(
                 with redirect_stdout(program_io):
                     timeout(timeout_length)(runtime.exec_code)('\n'.join(code))
                 program_io.seek(0)
-                result = program_io.readlines()[-1]
+                result = program_io.read()
             elif answer_symbol:
                 timeout(timeout_length)(runtime.exec_code)('\n'.join(code))
                 result = runtime._global_vars[answer_symbol]
@@ -104,23 +104,30 @@ def execute(
             else:
                 timeout(timeout_length)(runtime.exec_code)('\n'.join(code[:-1]))
                 result = timeout(timeout_length)(runtime.eval_code)(code[-1])
-            exec_info = "Done"
+            report = "Done"
             str(result)
             pickle.dumps(result) # serialization check
         except:
             result = ''
-            exec_info = traceback.format_exc().split('\n')[-2]
-        return result, exec_info
+            report = traceback.format_exc().split('\n')[-2]
+        return result, report

     def apply(self, code):
         return self.batch_apply([code])[0]

+    @staticmethod
+    def truncate(s, max_length=400):
+        half = max_length // 2
+        if len(s) > max_length:
+            s = s[:half] + "..." + s[-half:]
+        return s
+
     def batch_apply(self, batch_code):
         all_code_snippets = self.process_generation_to_code(batch_code)

         timeout_cnt = 0
         all_exec_results = []
-        with ProcessPool(max_workers=min(len(all_code_snippets), multiprocessing.cpu_count())) as pool:
+        with ProcessPool(max_workers=min(len(all_code_snippets), os.cpu_count())) as pool:
             executor = partial(
                 self.execute,
                 get_answer_from_stdout=self.get_answer_from_stdout,
@@ -157,8 +164,11 @@ def batch_apply(self, batch_code):
         progress_bar.close()

         batch_results = []
-        for code, (result, exec_info) in zip(all_code_snippets, all_exec_results):
-            batch_results.append((result, exec_info))
+        for code, (res, report) in zip(all_code_snippets, all_exec_results):
+            # post processing
+            res, report = str(res).strip(), str(report).strip()
+            res, report = self.truncate(res), self.truncate(report)
+            batch_results.append((res, report))
         return batch_results


diff --git a/src/utils/utils.py b/src/utils/utils.py
index 230cf4e..c8c283e 100644
--- a/src/utils/utils.py
+++ b/src/utils/utils.py
@@ -36,9 +36,22 @@ def save_jsonl(samples, save_path):
     print("Saved to", save_path)


+def lower_keys(example):
+    new_example = {}
+    for key, value in example.items():
+        if key != key.lower():
+            new_key = key.lower()
+            new_example[new_key] = value
+        else:
+            new_example[key] = value
+    return new_example
+
+
 def load_prompt(data_name, prompt_type):
     if data_name in ['gsm-hard', 'svamp', 'tabmwp', 'asdiv', 'mawps']:
         data_name = "gsm8k"
+    if data_name in ['math-oai']:
+        data_name = "math"
     if prompt_type in ['platypus_fs', 'wizard_zs']:
         prompt_type = "cot"
     prompt_path = "./prompts/{}/{}.md".format(prompt_type, data_name)
@@ -56,7 +69,7 @@ def construct_prompt(args, example):
     demo_prompt = load_prompt(args.data_name, args.prompt_type)
     if args.use_train_prompt_format:
         full_prompt = f"<|user|>\n{example['question']}\n<|assistant|>\n"
-    elif "tora" in args.prompt_type or "pot" in args.prompt_type:
+    elif "tora" in args.prompt_type:
         context = f"Question: {example['question']}\n\nSolution:"
         full_prompt = demo_prompt + context
     elif args.prompt_type in ["direct", "cot"]:
@@ -83,19 +96,33 @@ def construct_prompt(args, example):
         raise NotImplementedError(args.prompt_type)
     return full_prompt

-def show_sample(sample):
+key_map = {
+    "gt": "Ground Truth",
+    "pred": "Prediction",
+    "gt_cot": "Reference CoT",
+    "score": "Score",
+}
+
+def show_sample(sample, print_all_preds=False):
     print("=="*20)
-    print("idx:", sample['idx'])
-    for key in ["type", "level"]:
+    for key in ["idx", "type", "level", "dataset"]:
         if key in sample:
-            print("{}: {}".format(key, sample[key]))
-    print("question:", sample['question'])
+            # capitalize
+            print("{}: {}".format(key[0].upper() + key[1:], sample[key]))
+    print("Question:", repr(sample['question']))
     if 'code' in sample:
-        for code in sample['code']:
-            print('-'*20)
-            print("code:", code)
-        print("execution", sample['report'])
-    for key in ["pred", "gt", "score", "unit", "gt_cot"]:
+        if print_all_preds:
+            for code in sample['code']:
+                print('-'*20)
+                print("code:", code)
+            print("Execution:", sample['report'])
+        else:
+            print("Solution:\n", sample['code'][0])
+            print("Execution:", sample['report'][0])
+    if 'pred' in sample:
+        print("Prediction:", repr(sample['pred'][0]))
+    for key in ["gt", "score", "unit", "gt_cot"]:
         if key in sample:
-            print("{}: {}".format(key, sample[key]))
-    print()
\ No newline at end of file
+            _key = key_map.get(key, key)
+            print("{}: {}".format(_key, repr(sample[key])))
+    print()
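
# ---------------------------------------------------------------------------
# Example (sketch): the output post-processing that moved from inference.py
# into PythonExecutor.batch_apply() above. Long stdout or tracebacks are now
# clipped to 400 characters inside the executor instead of 100 characters in
# the caller. truncate() below is copied from the patch; the sample string is
# made up.
def truncate(s, max_length=400):
    half = max_length // 2
    if len(s) > max_length:
        s = s[:half] + "..." + s[-half:]
    return s

long_stdout = "0123456789" * 100          # 1000 characters of fake program output
res = truncate(str(long_stdout).strip())  # same strip-then-truncate order as batch_apply
print(len(res))        # 403 = 200 head + "..." + 200 tail
print(res[195:210])    # 56789...0123456
# ---------------------------------------------------------------------------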