From 2ed193adf3d1d23c2d9811c4692f634aafd7f9e4 Mon Sep 17 00:00:00 2001 From: Shawn/Yuxuan Tong Date: Sun, 15 Sep 2024 00:57:42 +0800 Subject: [PATCH] Allow inaccurate majority voting --- README.md | 64 +--------- nbs/index.ipynb | 321 +++--------------------------------------------- symeval/core.py | 21 ++-- 3 files changed, 37 insertions(+), 369 deletions(-) diff --git a/README.md b/README.md index c69ef51..2e0c6ce 100644 --- a/README.md +++ b/README.md @@ -3,31 +3,6 @@ -``` python -``` - - The autoreload extension is already loaded. To reload it, use: - %reload_ext autoreload - - [autoreload of symeval.core failed: Traceback (most recent call last): - File "/ssddata/tongyx/miniconda3/envs/dart-math/lib/python3.11/site-packages/IPython/extensions/autoreload.py", line 276, in check - superreload(m, reload, self.old_objects) - File "/ssddata/tongyx/miniconda3/envs/dart-math/lib/python3.11/site-packages/IPython/extensions/autoreload.py", line 475, in superreload - module = reload(module) - ^^^^^^^^^^^^^^ - File "/ssddata/tongyx/miniconda3/envs/dart-math/lib/python3.11/importlib/__init__.py", line 169, in reload - _bootstrap._exec(spec, module) - File "", line 621, in _exec - File "", line 940, in exec_module - File "", line 241, in _call_with_frames_removed - File "/ssddata/tongyx/projects/symeval/symeval/core.py", line 228, in - class EvaluatorBatchBase(EvaluatorBase): - File "/ssddata/tongyx/projects/symeval/symeval/core.py", line 250, in EvaluatorBatchBase - ) -> Tuple[List[str], List[bool]]: - ~~~~~^^^^^^^^^^^^^^^^^^^^^^^ - TypeError: type 'Tuple' is not subscriptable - ] - ## Installation For common users/developers, please just run the following command the @@ -82,7 +57,8 @@ target="_blank" style="float:right; font-size:smaller">source > EvaluatorMathBatch (strict_extract:bool=True, > include_percentage:bool=True, rel_tol:float=1e-09, > abs_tol:float=1e-08, percent_rel_tol:float=0.001, -> ascii_only:bool=True, timeout:int=5) +> ascii_only:bool=True, timeout:int=5, n_procs:int=2, +> use_tqdm:bool=True) *Batch evaluator for math problems, capable of extracting answer segment from complex resp and processing various mathematical objects @@ -97,7 +73,9 @@ text (e.g. bool values).* | abs_tol | float | 1e-08 | The absolute tolerance for numerical comparisons. Necessary for precision issues. | | percent_rel_tol | float | 0.001 | The absolute tolerance for percentage comparisons. | | ascii_only | bool | True | Only allowing ASCII characters | -| timeout | int | 5 | | +| timeout | int | 5 | The timeout for each evaluation. | +| n_procs | int | 2 | | +| use_tqdm | bool | True | | #### Accurately Extracting Answer Strings @@ -112,8 +90,6 @@ can: math_evaluator.extract_ans("Therefore, $1+1=\\boxed{2}$.") ``` - '2' - ``` python # Answer around "answer" math_evaluator.extract_ans( @@ -121,8 +97,6 @@ math_evaluator.extract_ans( ) ``` - '6' - ``` python # Use the last number by default math_evaluator.extract_ans( @@ -131,15 +105,11 @@ math_evaluator.extract_ans( # More cases ... ``` - '' - ``` python # Normalize fraction math_evaluator.extract_ans("The answer is 1/2") ``` - '\\frac{1}{2}' - ``` python # Normalize pmatrix math_evaluator.extract_ans( @@ -148,8 +118,6 @@ math_evaluator.extract_ans( # More cases ... ``` - '\\begin{array}3\\\\frac{\\pi}{2}\\end{array}' - #### Correctly Processing Various Mathematical Objects / Special Text [`EvaluatorMath`](https://tongyx361.github.io/symeval/core.html#evaluatormath), @@ -165,14 +133,10 @@ calculation, is able to correctly process math_evaluator.eq("x+y", "y+x") == True # Expression ``` - True - ``` python math_evaluator.eq("\\frac{1}{2}", "0.5") == True # LaTeX ``` - True - ``` python math_evaluator.eq( "\\begin{array}1\\\\2\\end{array}", @@ -180,21 +144,15 @@ math_evaluator.eq( ) # Matrix (Vector) ``` - True - ``` python math_evaluator.eq("{1,2}", "{2,1}", compare_sets=True) # Set ``` - True - ``` python math_evaluator.eq("no", "false") # Bool # More mathematical objects and special texts ... ``` - True - More test cases:
@@ -244,8 +202,6 @@ test_eq(math_evaluator.eq("\\frac{2003}{2}", "1001"), False) math_evaluator.get_maj_answers(["", "", "1", "2", "2", "3", "3", "3"]) ``` - ['', '', '1', '1', '2', '2', '2', '3'] - ### Parsing LaTeX #### Interval @@ -258,20 +214,14 @@ from symeval import latex2sympy_interval latex2sympy_interval("(-11,-10)\\cup\\{-\\sqrt{110}\\}") ``` -$\displaystyle \left(-11, -10\right)$ - ``` python latex2sympy_interval("(-\\infty, 0) \\cup (0, \\infty)") ``` -$\displaystyle \left(-\infty, 0\right) \cup \left(0, \infty\right)$ - ``` python latex2sympy_interval("(a+b,b]") ``` -$\displaystyle \left(a + b, b\right]$ - #### Matrix / Vector ``` python @@ -284,16 +234,12 @@ math_evaluator = EvaluatorMathBatch() math_evaluator.latex2matrix(r"\sqrt{400\cos^2(9\pi/44)},\frac{\pi}{4}") ``` -$\displaystyle \left[\begin{matrix}\sqrt{400 \cos^{2}{\left(\frac{9 \pi}{44} \right)}} & \frac{\pi}{4}\end{matrix}\right]$ - ``` python math_evaluator.latex2matrix( r"\begin{pmatrix} \frac{1}{2} & 0 & -\frac{\sqrt{3}}{2} \\ 0 & 1 & 0 \\ \frac{\sqrt{3}}{2} & 0 & \frac{1}{2} \end{pmatrix}" ) ``` -$\displaystyle \left[\begin{matrix}\frac{1}{2} & 0 & - \frac{\sqrt{3}}{2}\\0 & 1 & 0\\\frac{\sqrt{3}}{2} & 0 & \frac{1}{2}\end{matrix}\right]$ - ``` python test_eq( math_evaluator.latex2matrix("\\begin{pmatrix}-18\\\\-49\\\\96\\end{pmatrix}"), diff --git a/nbs/index.ipynb b/nbs/index.ipynb index ad75366..276f3d9 100644 --- a/nbs/index.ipynb +++ b/nbs/index.ipynb @@ -4,40 +4,7 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "The autoreload extension is already loaded. To reload it, use:\n", - " %reload_ext autoreload\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "[autoreload of symeval.core failed: Traceback (most recent call last):\n", - " File \"/ssddata/tongyx/miniconda3/envs/dart-math/lib/python3.11/site-packages/IPython/extensions/autoreload.py\", line 276, in check\n", - " superreload(m, reload, self.old_objects)\n", - " File \"/ssddata/tongyx/miniconda3/envs/dart-math/lib/python3.11/site-packages/IPython/extensions/autoreload.py\", line 475, in superreload\n", - " module = reload(module)\n", - " ^^^^^^^^^^^^^^\n", - " File \"/ssddata/tongyx/miniconda3/envs/dart-math/lib/python3.11/importlib/__init__.py\", line 169, in reload\n", - " _bootstrap._exec(spec, module)\n", - " File \"\", line 621, in _exec\n", - " File \"\", line 940, in exec_module\n", - " File \"\", line 241, in _call_with_frames_removed\n", - " File \"/ssddata/tongyx/projects/symeval/symeval/core.py\", line 228, in \n", - " class EvaluatorBatchBase(EvaluatorBase):\n", - " File \"/ssddata/tongyx/projects/symeval/symeval/core.py\", line 250, in EvaluatorBatchBase\n", - " ) -> Tuple[List[str], List[bool]]:\n", - " ~~~~~^^^^^^^^^^^^^^^^^^^^^^^\n", - "TypeError: type 'Tuple' is not subscriptable\n", - "]\n" - ] - } - ], + "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2" @@ -124,65 +91,7 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/markdown": [ - "---\n", - "\n", - "[source](https://github.com/tongyx361/symeval/blob/main/symeval/core.py#LNone){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", - "\n", - "### EvaluatorMathBatch\n", - "\n", - "> EvaluatorMathBatch (strict_extract:bool=True,\n", - "> include_percentage:bool=True, rel_tol:float=1e-09,\n", - "> abs_tol:float=1e-08, percent_rel_tol:float=0.001,\n", - "> ascii_only:bool=True, timeout:int=5)\n", - "\n", - "*Batch evaluator for math problems, capable of extracting answer segment from complex resp and processing various mathematical objects\n", - "(e.g. fractions, symbolic expressions, matrices, vectors) and special text (e.g. bool values).*\n", - "\n", - "| | **Type** | **Default** | **Details** |\n", - "| -- | -------- | ----------- | ----------- |\n", - "| strict_extract | bool | True | |\n", - "| include_percentage | bool | True | Whether to include percentage comparisons. |\n", - "| rel_tol | float | 1e-09 | The relative tolerance for numerical comparisons. |\n", - "| abs_tol | float | 1e-08 | The absolute tolerance for numerical comparisons. Necessary for precision issues. |\n", - "| percent_rel_tol | float | 0.001 | The absolute tolerance for percentage comparisons. |\n", - "| ascii_only | bool | True | Only allowing ASCII characters |\n", - "| timeout | int | 5 | |" - ], - "text/plain": [ - "---\n", - "\n", - "[source](https://github.com/tongyx361/symeval/blob/main/symeval/core.py#LNone){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", - "\n", - "### EvaluatorMathBatch\n", - "\n", - "> EvaluatorMathBatch (strict_extract:bool=True,\n", - "> include_percentage:bool=True, rel_tol:float=1e-09,\n", - "> abs_tol:float=1e-08, percent_rel_tol:float=0.001,\n", - "> ascii_only:bool=True, timeout:int=5)\n", - "\n", - "*Batch evaluator for math problems, capable of extracting answer segment from complex resp and processing various mathematical objects\n", - "(e.g. fractions, symbolic expressions, matrices, vectors) and special text (e.g. bool values).*\n", - "\n", - "| | **Type** | **Default** | **Details** |\n", - "| -- | -------- | ----------- | ----------- |\n", - "| strict_extract | bool | True | |\n", - "| include_percentage | bool | True | Whether to include percentage comparisons. |\n", - "| rel_tol | float | 1e-09 | The relative tolerance for numerical comparisons. |\n", - "| abs_tol | float | 1e-08 | The absolute tolerance for numerical comparisons. Necessary for precision issues. |\n", - "| percent_rel_tol | float | 0.001 | The absolute tolerance for percentage comparisons. |\n", - "| ascii_only | bool | True | Only allowing ASCII characters |\n", - "| timeout | int | 5 | |" - ] - }, - "execution_count": null, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "show_doc(EvaluatorMathBatch, title_level=3)" ] @@ -208,18 +117,7 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'2'" - ] - }, - "execution_count": null, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "# MATH-style boxed answer\n", "math_evaluator.extract_ans(\"Therefore, $1+1=\\\\boxed{2}$.\")" @@ -229,18 +127,7 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'6'" - ] - }, - "execution_count": null, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "# Answer around \"answer\"\n", "math_evaluator.extract_ans(\n", @@ -252,18 +139,7 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "''" - ] - }, - "execution_count": null, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "# Use the last number by default\n", "math_evaluator.extract_ans(\n", @@ -276,18 +152,7 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'\\\\frac{1}{2}'" - ] - }, - "execution_count": null, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "# Normalize fraction\n", "math_evaluator.extract_ans(\"The answer is 1/2\")" @@ -297,18 +162,7 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'\\\\begin{array}3\\\\\\\\frac{\\\\pi}{2}\\\\end{array}'" - ] - }, - "execution_count": null, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "# Normalize pmatrix\n", "math_evaluator.extract_ans(\n", @@ -338,18 +192,7 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": null, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "math_evaluator.eq(\"x+y\", \"y+x\") == True # Expression" ] @@ -358,18 +201,7 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": null, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "math_evaluator.eq(\"\\\\frac{1}{2}\", \"0.5\") == True # LaTeX" ] @@ -378,18 +210,7 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": null, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "math_evaluator.eq(\n", " \"\\\\begin{array}1\\\\\\\\2\\\\end{array}\",\n", @@ -401,18 +222,7 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": null, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "math_evaluator.eq(\"{1,2}\", \"{2,1}\", compare_sets=True) # Set" ] @@ -421,18 +231,7 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": null, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "math_evaluator.eq(\"no\", \"false\") # Bool\n", "# More mathematical objects and special texts ..." @@ -498,18 +297,7 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "['', '', '1', '1', '2', '2', '2', '3']" - ] - }, - "execution_count": null, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "math_evaluator.get_maj_answers([\"\", \"\", \"1\", \"2\", \"2\", \"3\", \"3\", \"3\"])" ] @@ -541,21 +329,7 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/latex": [ - "$\\displaystyle \\left(-11, -10\\right)$" - ], - "text/plain": [ - "Interval.open(-11, -10)" - ] - }, - "execution_count": null, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "latex2sympy_interval(\"(-11,-10)\\\\cup\\\\{-\\\\sqrt{110}\\\\}\")" ] @@ -564,21 +338,7 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/latex": [ - "$\\displaystyle \\left(-\\infty, 0\\right) \\cup \\left(0, \\infty\\right)$" - ], - "text/plain": [ - "Union(Interval.open(-oo, 0), Interval.open(0, oo))" - ] - }, - "execution_count": null, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "latex2sympy_interval(\"(-\\\\infty, 0) \\\\cup (0, \\\\infty)\")" ] @@ -587,21 +347,7 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/latex": [ - "$\\displaystyle \\left(a + b, b\\right]$" - ], - "text/plain": [ - "Interval.Lopen(a + b, b)" - ] - }, - "execution_count": null, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "latex2sympy_interval(\"(a+b,b]\")" ] @@ -628,21 +374,7 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/latex": [ - "$\\displaystyle \\left[\\begin{matrix}\\sqrt{400 \\cos^{2}{\\left(\\frac{9 \\pi}{44} \\right)}} & \\frac{\\pi}{4}\\end{matrix}\\right]$" - ], - "text/plain": [ - "Matrix([[sqrt(400*cos((9*pi)/44)**2), pi/4]])" - ] - }, - "execution_count": null, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "math_evaluator.latex2matrix(r\"\\sqrt{400\\cos^2(9\\pi/44)},\\frac{\\pi}{4}\")" ] @@ -651,24 +383,7 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/latex": [ - "$\\displaystyle \\left[\\begin{matrix}\\frac{1}{2} & 0 & - \\frac{\\sqrt{3}}{2}\\\\0 & 1 & 0\\\\\\frac{\\sqrt{3}}{2} & 0 & \\frac{1}{2}\\end{matrix}\\right]$" - ], - "text/plain": [ - "Matrix([\n", - "[ 1/2, 0, -1*sqrt(3)/2],\n", - "[ 0, 1, 0],\n", - "[sqrt(3)/2, 0, 1/2]])" - ] - }, - "execution_count": null, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "math_evaluator.latex2matrix(\n", " r\"\\begin{pmatrix} \\frac{1}{2} & 0 & -\\frac{\\sqrt{3}}{2} \\\\ 0 & 1 & 0 \\\\ \\frac{\\sqrt{3}}{2} & 0 & \\frac{1}{2} \\end{pmatrix}\"\n", diff --git a/symeval/core.py b/symeval/core.py index 0bb52dd..12b4887 100644 --- a/symeval/core.py +++ b/symeval/core.py @@ -1,4 +1,4 @@ -import regex +import re as regex import warnings from datetime import datetime from math import isclose @@ -12,6 +12,7 @@ Tuple as T_Tuple, Match, Set, + Counter as T_Counter ) from pebble import ProcessPool from collections import Counter @@ -206,13 +207,13 @@ def get_maj_answers(self, answers: List[str]) -> List[str]: """Get the majority answers.""" maj_answers: List[str] = [] - ans_votes: Counter[str] = Counter() + ans_votes: T_Counter[str] = Counter() # Normalize all the answers for answer in answers: for exist_ans in ans_votes: correct: bool try: - correct = self.eq(exist_ans, answer) + correct = self.eq(answer, exist_ans) except Exception: correct = False if correct: @@ -225,7 +226,7 @@ def get_maj_answers(self, answers: List[str]) -> List[str]: return maj_answers - def get_maj_ans_from_votes(self, ans_votes: Counter[str]) -> str: + def get_maj_ans_from_votes(self, ans_votes: T_Counter[str]) -> str: maj_ans = ans_votes.most_common(1)[0][0] if maj_ans == "" and len(ans_votes) > 1: maj_ans = ans_votes.most_common(2)[1][0] @@ -309,7 +310,9 @@ def batch_eq( for ref, pred in zip(ref_answers, pred_answers) ] - def batch_get_maj_answers(self, answers_list: List[List[str]]) -> List[List[str]]: + def batch_get_maj_answers( + self, answers_list: List[List[str]], accurate: bool = True + ) -> List[List[str]]: """Get the majority answers for a batch of answers.""" maj_answers_list: List[List[str]] = [] # Gather all pairs to evaluate @@ -323,7 +326,11 @@ def batch_get_maj_answers(self, answers_list: List[List[str]]) -> List[List[str] all_ans_is: List[str] all_ans_js: List[str] all_ans_is, all_ans_js = zip(*all_ans_pairs) - all_eqs: List[bool] = self.batch_eq(all_ans_is, all_ans_js) + all_eqs: List[bool] = ( + self.batch_eq(all_ans_is, all_ans_js) + if accurate + else [ans_i == ans_j for ans_i, ans_j in all_ans_pairs] + ) all_pairs2eq: T_Dict[T_Tuple[str, str], bool] = dict( zip(all_ans_pairs, all_eqs) @@ -331,7 +338,7 @@ def batch_get_maj_answers(self, answers_list: List[List[str]]) -> List[List[str] # Get the majority answers for answers in answers_list: maj_answers: List[str] = [] - ans_votes: Counter[str] = Counter() + ans_votes: T_Counter[str] = Counter() for i_ans, answer in enumerate(answers): exist: bool = False for j_ans in range(i_ans):