-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathextract_answers.py
97 lines (78 loc) · 3.25 KB
/
extract_answers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import argparse
import pandas as pd
from typing import Any
from math_verify.parser import LatexExtractionConfig, ExprExtractionConfig, parse
def parse_args(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse the command-line options of the answer-extraction script.

    Args:
        argv: Argument strings to parse, without the program name. When
            ``None`` (the default) argparse falls back to ``sys.argv[1:]``,
            so a plain ``parse_args()`` call behaves as before.

    Returns:
        Namespace with the two required string options ``input_csv`` and
        ``output_csv``.

    Raises:
        SystemExit: Raised by argparse when a required option is missing or
            an unknown option is given.
    """
    parser = argparse.ArgumentParser(description='Extract and evaluate answers using sympy')
    parser.add_argument('--input_csv', type=str, required=True, help='Path to input CSV file containing model outputs')
    parser.add_argument('--output_csv', type=str, required=True, help='Path to output CSV file for extracted answers')
    return parser.parse_args(argv)
def load_csv_data(csv_path: str) -> pd.DataFrame:
    """Load the input CSV and check that it has the columns the script needs.

    Args:
        csv_path: Path of the CSV file to read with ``pd.read_csv``.

    Returns:
        The loaded DataFrame, unchanged.

    Raises:
        Exception: If the file cannot be read or parsed. The original error
            is attached as ``__cause__``.
        ValueError: If the ``answer`` column is missing.
    """
    required_columns = ['answer']

    # Only the read itself is guarded, so the validation error below is not
    # re-wrapped into a generic Exception.
    try:
        df = pd.read_csv(csv_path)
    except Exception as e:
        raise Exception(f"Error loading CSV file: {str(e)}") from e

    missing = [col for col in required_columns if col not in df.columns]
    if missing:
        # Message prefix kept identical to the previous wrapped error so that
        # callers matching on the text keep working.
        raise ValueError(
            f"Error loading CSV file: CSV must contain columns: {required_columns}"
            f" (missing: {missing})"
        )
    return df
def serialize_sympy_object(obj: Any) -> str:
    """Render an extraction result as plain text for the output CSV.

    ``None`` becomes the empty string. A list or tuple becomes its items'
    ``str()`` forms joined by ", ", with ``None`` items rendered as empty
    strings. Anything else is passed through ``str()``. If stringification
    raises, the text ``"Error: <message>"`` is returned instead of propagating.
    """
    if obj is None:
        return ""

    try:
        if not isinstance(obj, (list, tuple)):
            return str(obj)

        pieces = []
        for item in obj:
            pieces.append("" if item is None else str(item))
        return ", ".join(pieces)
    except Exception as e:
        return f"Error: {str(e)}"
def process_answers(df: pd.DataFrame) -> pd.DataFrame:
    """Run every value of the ``answer`` column through ``parse``.

    Args:
        df: Input frame; must contain an ``answer`` column. All other columns
            are copied to the output unchanged.

    Returns:
        A new DataFrame with one row per input row holding:

        * ``original_answer``: the raw ``answer`` value,
        * ``extracted_answer``: first extraction result, serialized to text,
        * ``extracted_feedback``: second extraction result if present, or
          ``"No valid extraction found"`` when nothing was extracted,
        * ``extraction_success``: whether a first result was obtained,
        * every other input column,
        * ``error``: the exception message, only on rows where processing
          raised.
    """
    # Try plain-expression extraction as well as LaTeX extraction.
    extraction_target = (ExprExtractionConfig(), LatexExtractionConfig())
    passthrough_columns = [col for col in df.columns if col != 'answer']

    results = []
    for _, row in df.iterrows():
        answer = row['answer']
        # Collected up front so that rows which fail below still carry their
        # identifying columns and can be joined back to the input.
        passthrough = {col: row[col] for col in passthrough_columns}

        try:
            # pandas yields non-string values for numeric-only columns and NaN
            # for empty cells; hand the parser text in every case.
            if isinstance(answer, str):
                text = answer
            elif pd.isna(answer):
                text = ""
            else:
                text = str(answer)

            extracted = parse(text, extraction_config=extraction_target)

            # NOTE(review): presumably `parse` returns a list with the parsed
            # expression first and the matched string second — confirm
            # against math_verify. Any non-empty result counts as a success,
            # not only results of length one or two.
            if extracted:
                extracted_answer = extracted[0]
                feedback = extracted[1] if len(extracted) > 1 else None
            else:
                extracted_answer = None
                feedback = "No valid extraction found"

            result = {
                'original_answer': answer,
                'extracted_answer': serialize_sympy_object(extracted_answer),
                'extracted_feedback': feedback,
                'extraction_success': extracted_answer is not None,
            }
            # Input columns are applied last, as before, so an input column
            # sharing a name with a result field takes precedence.
            result.update(passthrough)
        except Exception as e:
            result = {
                'original_answer': answer,
                'extracted_answer': '',
                'extracted_feedback': None,
                'extraction_success': False,
            }
            result.update(passthrough)
            result['error'] = str(e)

        results.append(result)

    return pd.DataFrame(results)
def main():
    """Command-line entry point: read the input CSV, extract, write the output CSV."""
    cli = parse_args()

    # Load, then run the extraction over every row.
    extracted_df = process_answers(load_csv_data(cli.input_csv))

    # Persist without the pandas index column and report where it went.
    extracted_df.to_csv(cli.output_csv, index=False)
    print(f"Results saved to {cli.output_csv}")
# Run the command-line workflow only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()