-
Notifications
You must be signed in to change notification settings - Fork 483
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* support omni-math * update config * upload README * Delete opencompass/configs/datasets/omni_math/__init__.py --------- Co-authored-by: liushz <[email protected]>
- Loading branch information
Showing
5 changed files
with
220 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
# Omni-Math | ||
|
||
[Omni-Math](https://huggingface.co/datasets/KbsdJames/Omni-MATH) contains 4428 competition-level problems. These problems are meticulously categorized into 33 (and potentially more) sub-domains and span across 10 distinct difficulty levels, enabling a nuanced analysis of model performance across various mathematical disciplines and levels of complexity. | ||
|
||
* Project Page: https://omni-math.github.io/ | ||
* Github Repo: https://github.com/KbsdJames/Omni-MATH | ||
* Omni-Judge (open-source evaluator for this dataset): https://huggingface.co/KbsdJames/Omni-Judge | ||
|
||
## Omni-Judge | ||
|
||
> Omni-Judge is an open-source mathematical evaluation model designed to assess whether a solution generated by a model is correct given a problem and a standard answer. | ||
Deploy the Omni-Judge server with a command like the following: | ||
```bash | ||
set -x | ||
|
||
lmdeploy serve api_server KbsdJames/Omni-Judge --server-port 8000 \ | ||
--tp 1 \ | ||
--cache-max-entry-count 0.9 \ | ||
--log-level INFO | ||
``` | ||
|
||
Then set the server URL(s) in the OpenCompass config file: | ||
|
||
```python | ||
from mmengine.config import read_base | ||
|
||
with read_base(): | ||
from opencompass.configs.datasets.omni_math.omni_math_gen import omni_math_datasets | ||
|
||
|
||
omni_math_dataset = omni_math_datasets[0] | ||
omni_math_dataset['eval_cfg']['evaluator'].update( | ||
url=['http://172.30.8.45:8000', | ||
'http://172.30.16.113:8000'], | ||
) | ||
``` | ||
|
||
## Performance | ||
|
||
| llama-3_1-8b-instruct | qwen-2_5-7b-instruct | InternLM3-8b-Instruct | | ||
| -- | -- | -- | | ||
| 15.18 | 29.97 | 32.75 | |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from mmengine.config import read_base | ||
|
||
with read_base(): | ||
from .omni_math_gen_18cc08 import omni_math_datasets # noqa: F401, F403 |
45 changes: 45 additions & 0 deletions
45
opencompass/configs/datasets/omni_math/omni_math_gen_18cc08.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
from opencompass.openicl.icl_prompt_template import PromptTemplate | ||
from opencompass.openicl.icl_retriever import ZeroRetriever | ||
from opencompass.openicl.icl_inferencer import GenInferencer | ||
|
||
from opencompass.datasets.omni_math import OmniMathDataset, OmniMathEvaluator | ||
|
||
|
||
# Reader settings: the model is shown the `problem` column and the
# `answer` column is used as the reference.
reader_cfg = {
    'input_columns': ['problem'],
    'output_column': 'answer',
}
|
||
# Inference settings: a single-turn prompt with no in-context examples
# (ZeroRetriever), generated with GenInferencer at temperature 0.0 and an
# output budget of 2048.
infer_cfg = {
    'prompt_template': {
        'type': PromptTemplate,
        'template': {
            'round': [
                {
                    'role': 'HUMAN',
                    'prompt': 'please answer the following mathematical question, put your final answer in \\boxed{}.\n\n{problem}',
                },
            ],
        },
    },
    'retriever': {'type': ZeroRetriever},
    'inferencer': {
        'type': GenInferencer,
        'max_out_len': 2048,
        'temperature': 0.0,
    },
}
|
||
# Evaluation settings: `url` is intentionally empty here and is expected to
# be filled in with the Omni-Judge server address(es) by the user config.
eval_cfg = {
    'evaluator': {
        'type': OmniMathEvaluator,
        'url': [],
    },
}
|
||
# Dataset entries exported by this config; a single Omni-MATH entry that
# ties together the reader, inference and evaluation settings above.
omni_math_datasets = [
    {
        'type': OmniMathDataset,
        'abbr': 'OmniMath',
        'reader_cfg': reader_cfg,
        'infer_cfg': infer_cfg,
        'eval_cfg': eval_cfg,
    },
]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
import concurrent.futures | ||
from typing import List | ||
|
||
import numpy as np | ||
from datasets import load_dataset | ||
from transformers import AutoTokenizer | ||
|
||
from opencompass.models.turbomind_api import TurboMindAPIModel | ||
from opencompass.openicl.icl_evaluator import BaseEvaluator | ||
from opencompass.registry import ICL_EVALUATORS, LOAD_DATASET, MODELS | ||
|
||
from .base import BaseDataset | ||
|
||
|
||
@LOAD_DATASET.register_module()
class OmniMathDataset(BaseDataset):
    """Omni-MATH dataset, fetched through ``datasets.load_dataset``."""

    @staticmethod
    def load():
        """Return the ``test`` split of ``KbsdJames/Omni-MATH``."""
        return load_dataset('KbsdJames/Omni-MATH')['test']
|
||
|
||
@ICL_EVALUATORS.register_module()
class OmniMathEvaluator(BaseEvaluator):
    """Judge Omni-MATH predictions with one or more Omni-Judge API servers.

    For every sample, the problem, the reference answer and the model
    prediction are turned into a judge prompt via the Omni-Judge tokenizer,
    sent to one of the configured servers, and the ``Equivalence Judgement``
    section of the reply decides whether the prediction counts as correct.

    Args:
        url (str | list[str]): Address or addresses of the API servers
            hosting ``KbsdJames/Omni-Judge``. One API client is built per
            address and the judge prompts are split across the clients.
    """

    api_meta_template = dict(round=[
        dict(role='HUMAN', api_role='HUMAN'),
        dict(role='BOT', api_role='BOT', generate=True),
    ])

    def __init__(self, url):
        super().__init__()
        # Accept a single address as well as a list of addresses.
        urls = [url] if isinstance(url, str) else list(url)

        # One judge client per server address.
        self.model = [
            MODELS.build(
                dict(
                    type=TurboMindAPIModel,
                    model_name='KbsdJames/Omni-Judge',
                    api_addr=api_addr,
                    meta_template=self.api_meta_template,
                    temperature=0.0,
                    max_seq_len=8192,
                )) for api_addr in urls
        ]
        # NOTE(review): ``get_context`` used in ``score`` is presumably
        # supplied by the model's remote tokenizer code, hence
        # ``trust_remote_code=True`` -- confirm against the model card.
        self.tokenizer = AutoTokenizer.from_pretrained('KbsdJames/Omni-Judge',
                                                       trust_remote_code=True)

    def batch_infer(self, models: List[TurboMindAPIModel],
                    inputs: List[str]) -> List[str]:
        """Split ``inputs`` across ``models`` and generate concurrently.

        Args:
            models: Judge clients; each one receives one contiguous chunk.
            inputs: Judge prompts.

        Returns:
            The generated responses, in the same order as ``inputs``.

        Raises:
            ValueError: If ``models`` is empty, i.e. no server URL was
                configured for the evaluator.
        """
        if not models:
            raise ValueError(
                'OmniMathEvaluator has no Omni-Judge server configured; '
                'set `url` in the evaluator config.')
        if not inputs:
            return []

        batch_num = len(models)
        # Ceiling division so every input lands in exactly one chunk.
        batch_size = (len(inputs) + batch_num - 1) // batch_num
        chunk_starts = range(0, len(inputs), batch_size)

        result_responses = []
        with concurrent.futures.ThreadPoolExecutor(
                max_workers=batch_num) as executor:
            # ``zip`` stops at the last non-empty chunk, so servers are never
            # called with an empty list when inputs are fewer than servers.
            futures = [
                executor.submit(model.generate,
                                inputs[start:start + batch_size])
                for model, start in zip(models, chunk_starts)
            ]
            # Collect in submission order to keep responses aligned with
            # their inputs.
            for future in futures:
                result_responses.extend(future.result())

        return result_responses

    def parse_response(self, response):
        """Return ``True`` only if the judge reply marks the answer correct.

        The reply is treated as the continuation of a
        ``## Student Final Answer`` section and split into ``## ``-prefixed
        sections. The ``Justification`` section keeps its full body; every
        other section keeps only the first line after its title.

        Args:
            response: Raw text generated by the judge model.

        Returns:
            bool: ``True`` when the ``Equivalence Judgement`` section is
            exactly ``'TRUE'``; ``False`` otherwise, including when the
            section is missing or ``response`` is not a string.
        """
        if not isinstance(response, str):
            return False

        response = '## Student Final Answer\n' + response.strip()

        info = {}
        for part in response.split('## ')[1:]:
            lines = part.strip().split('\n')
            title = lines[0].strip()
            if title == 'Justification':
                info[title] = '\n'.join(lines[1:]).strip()
            else:
                info[title] = lines[1].strip() if len(lines) > 1 else ''

        return info.get('Equivalence Judgement') == 'TRUE'

    def score(self, predictions, references, origin_prompt, test_set):
        """Judge all predictions and compute the accuracy.

        Args:
            predictions: Model outputs, one per sample.
            references: Reference answers, one per sample.
            origin_prompt: Unused here.
            test_set: Iterable of samples; each must provide ``'problem'``.

        Returns:
            dict: ``'details'`` with one record per sample (question,
            reference, candidate, judge response, label) and ``'accuracy'``
            as a percentage (``0.0`` when there is nothing to score).
        """
        questions = [d['problem'] for d in test_set]

        contexts = [
            self.tokenizer.get_context(question, reference, candidate)
            for question, reference, candidate in zip(
                questions, references, predictions)
        ]

        responses = self.batch_infer(self.model, contexts)
        labels = [self.parse_response(response) for response in responses]

        details = [{
            'question': question,
            'reference': reference,
            'candidate': candidate,
            'response': response,
            'label': label
        } for question, reference, candidate, response, label in zip(
            questions, references, predictions, responses, labels)]

        # ``np.mean`` of an empty list is NaN (with a warning); report 0.0.
        accuracy = np.mean(labels) * 100 if labels else 0.0
        return {'details': details, 'accuracy': accuracy}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters