[Feature] Support Omni-Math (#1837)
* support omni-math

* update config

* upload README

* Delete opencompass/configs/datasets/omni_math/__init__.py

---------

Co-authored-by: liushz <[email protected]>
jnanliu and liushz authored Jan 23, 2025
1 parent 35ec307 commit 70f2c96
Showing 5 changed files with 220 additions and 3 deletions.
43 changes: 43 additions & 0 deletions opencompass/configs/datasets/omni_math/README.md
@@ -0,0 +1,43 @@
# Omni-Math

[Omni-Math](https://huggingface.co/datasets/KbsdJames/Omni-MATH) contains 4428 competition-level problems, meticulously categorized into 33 (and potentially more) sub-domains and spanning 10 distinct difficulty levels, which enables a nuanced analysis of model performance across various mathematical disciplines and levels of complexity.

* Project Page: https://omni-math.github.io/
* Github Repo: https://github.com/KbsdJames/Omni-MATH
* Omni-Judge (open-source evaluator for this dataset): https://huggingface.co/KbsdJames/Omni-Judge

## Omni-Judge

> Omni-Judge is an open-source mathematical evaluation model designed to assess whether a solution generated by a model is correct given a problem and a standard answer.
To use it as the evaluator, first deploy an Omni-Judge server, e.g. with LMDeploy:
```bash
set -x

lmdeploy serve api_server KbsdJames/Omni-Judge --server-port 8000 \
--tp 1 \
--cache-max-entry-count 0.9 \
--log-level INFO
```
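
The server exposes LMDeploy's OpenAI-compatible REST API, so you can sanity-check it before wiring it into OpenCompass. A minimal sketch, assuming the server runs at a placeholder address `http://127.0.0.1:8000`:

```python
import requests

# Query the OpenAI-compatible model list to confirm the judge is serving.
# The host and port are placeholders for your own deployment.
resp = requests.get('http://127.0.0.1:8000/v1/models', timeout=10)
resp.raise_for_status()
print([m['id'] for m in resp.json()['data']])  # expect KbsdJames/Omni-Judge
```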

Then set the server URL(s) in the OpenCompass config file:

```python
from mmengine.config import read_base

with read_base():
    from opencompass.configs.datasets.omni_math.omni_math_gen import omni_math_datasets


omni_math_dataset = omni_math_datasets[0]
omni_math_dataset['eval_cfg']['evaluator'].update(
    url=['http://172.30.8.45:8000',
         'http://172.30.16.113:8000'],
)
```
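
For an end-to-end run you also need a model to evaluate. A minimal sketch of a complete config, assuming a HuggingFace chat model; the model path, abbreviation, and judge URL below are placeholders, not part of this commit:

```python
from mmengine.config import read_base
from opencompass.models import HuggingFacewithChatTemplate

with read_base():
    from opencompass.configs.datasets.omni_math.omni_math_gen import omni_math_datasets

# Point the evaluator at the deployed Omni-Judge server (placeholder URL).
omni_math_datasets[0]['eval_cfg']['evaluator'].update(
    url=['http://127.0.0.1:8000'],
)

datasets = omni_math_datasets
models = [
    dict(
        type=HuggingFacewithChatTemplate,
        abbr='qwen-2_5-7b-instruct',      # placeholder model to evaluate
        path='Qwen/Qwen2.5-7B-Instruct',  # placeholder HF path
        max_out_len=2048,
        batch_size=8,
        run_cfg=dict(num_gpus=1),
    ),
]
```

Saved as e.g. `eval_omni_math.py`, this can be launched with `opencompass eval_omni_math.py`.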

## Performance

Accuracy (%) on Omni-Math, as judged by Omni-Judge:

| llama-3_1-8b-instruct | qwen-2_5-7b-instruct | InternLM3-8b-Instruct |
| -- | -- | -- |
| 15.18 | 29.97 | 32.75 |
4 changes: 4 additions & 0 deletions opencompass/configs/datasets/omni_math/omni_math_gen.py
@@ -0,0 +1,4 @@
from mmengine.config import read_base

with read_base():
    from .omni_math_gen_18cc08 import omni_math_datasets  # noqa: F401, F403
45 changes: 45 additions & 0 deletions opencompass/configs/datasets/omni_math/omni_math_gen_18cc08.py
@@ -0,0 +1,45 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer

from opencompass.datasets.omni_math import OmniMathDataset, OmniMathEvaluator


reader_cfg = dict(
    input_columns=['problem'],
    output_column='answer',
)

infer_cfg = dict(
    prompt_template=dict(
        type=PromptTemplate,
        template=dict(
            round=[
                dict(role='HUMAN', prompt='please answer the following mathematical question, put your final answer in \\boxed{}.\n\n{problem}'),
            ],
        ),
    ),
    retriever=dict(type=ZeroRetriever),
    inferencer=dict(
        type=GenInferencer,
        max_out_len=2048,
        temperature=0.0,
    ),
)

eval_cfg = dict(
    evaluator=dict(
        type=OmniMathEvaluator,
        url=[],
    ),
)

omni_math_datasets = [
    dict(
        type=OmniMathDataset,
        abbr='OmniMath',
        reader_cfg=reader_cfg,
        infer_cfg=infer_cfg,
        eval_cfg=eval_cfg,
    )
]
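
Since the exported configs are plain Python dicts, individual settings can be overridden after import instead of editing this file. An illustrative sketch (the values are examples, not recommendations from this commit):

```python
from mmengine.config import read_base

with read_base():
    from opencompass.configs.datasets.omni_math.omni_math_gen import omni_math_datasets

# Illustrative overrides: longer generations and a custom abbreviation.
omni_math_datasets[0]['infer_cfg']['inferencer']['max_out_len'] = 4096
omni_math_datasets[0]['abbr'] = 'OmniMath-long'
```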
118 changes: 118 additions & 0 deletions opencompass/datasets/omni_math.py
@@ -0,0 +1,118 @@
import concurrent.futures
from typing import List

import numpy as np
from datasets import load_dataset
from transformers import AutoTokenizer

from opencompass.models.turbomind_api import TurboMindAPIModel
from opencompass.openicl.icl_evaluator import BaseEvaluator
from opencompass.registry import ICL_EVALUATORS, LOAD_DATASET, MODELS

from .base import BaseDataset


@LOAD_DATASET.register_module()
class OmniMathDataset(BaseDataset):

    @staticmethod
    def load():
        dataset = load_dataset('KbsdJames/Omni-MATH')['test']
        return dataset


@ICL_EVALUATORS.register_module()
class OmniMathEvaluator(BaseEvaluator):
    api_meta_template = dict(round=[
        dict(role='HUMAN', api_role='HUMAN'),
        dict(role='BOT', api_role='BOT', generate=True),
    ])

    def __init__(self, url):
        if isinstance(url, str):
            url = [url]

        # One Omni-Judge API client per server address, so judging can be
        # spread across several endpoints.
        self.model = [
            MODELS.build(
                dict(
                    type=TurboMindAPIModel,
                    model_name='KbsdJames/Omni-Judge',
                    api_addr=address,
                    meta_template=self.api_meta_template,
                    temperature=0.0,
                    max_seq_len=8192,
                )) for address in url
        ]
        self.tokenizer = AutoTokenizer.from_pretrained('KbsdJames/Omni-Judge',
                                                       trust_remote_code=True)

    def batch_infer(self, models: List[TurboMindAPIModel],
                    inputs: List[str]) -> List[str]:
        # Split the inputs into contiguous chunks, one per endpoint, and
        # query all endpoints concurrently.
        batch_num = len(models)
        batch_size = (len(inputs) + batch_num - 1) // batch_num
        result_responses = []

        with concurrent.futures.ThreadPoolExecutor(
                max_workers=batch_num) as executor:
            futures = [
                executor.submit(models[i].generate,
                                inputs[i * batch_size:(i + 1) * batch_size])
                for i in range(batch_num)
            ]
            for response in executor.map(lambda f: f.result(), futures):
                result_responses.extend(response)

        return result_responses

    def parse_response(self, response):
        # The judge's reply continues the '## Student Final Answer' section,
        # so re-prepend that header before splitting the markdown sections.
        response = '## Student Final Answer\n' + response.strip()

        parts = response.split('## ')
        info = {}

        for part in parts[1:]:
            lines = part.strip().split('\n')
            title = lines[0].strip()
            content = '\n'.join(lines[1:]).strip()

            if title == 'Justification':
                info[title] = content
            else:
                info[title] = lines[1].strip() if len(lines) > 1 else ''

        if info == {}:
            return False
        try:
            # The judgement section carries a TRUE/FALSE verdict.
            return info['Equivalence Judgement'] == 'TRUE'
        except Exception as e:
            print(e)
            return False

    def score(self, predictions, references, origin_prompt, test_set):
        questions = [d['problem'] for d in test_set]

        # Render (question, reference answer, candidate answer) triples
        # into Omni-Judge prompts via its chat template.
        contexts = []
        for question, reference, candidate in zip(questions, references,
                                                  predictions):
            context = self.tokenizer.get_context(question, reference,
                                                 candidate)
            contexts.append(context)

        responses = self.batch_infer(self.model, contexts)
        labels = list(map(self.parse_response, responses))

        details = []
        for question, reference, candidate, response, label in zip(
                questions, references, predictions, responses, labels):
            details.append({
                'question': question,
                'reference': reference,
                'candidate': candidate,
                'response': response,
                'label': label,
            })
        return {'details': details, 'accuracy': np.mean(labels) * 100}
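
For reference, `parse_response` expects the judge's reply to continue the `## Student Final Answer` section and to contain an `## Equivalence Judgement` section whose first line is `TRUE` or `FALSE`. A sketch with a fabricated judge reply; constructing the evaluator assumes a reachable judge server (placeholder URL) and access to the Omni-Judge tokenizer:

```python
# Fabricated Omni-Judge reply: the evaluator re-prepends the
# '## Student Final Answer' header before splitting on '## '.
sample_reply = (
    '1/2\n'
    '## Equivalence Judgement\n'
    'TRUE\n'
    '## Justification\n'
    'The candidate answer 1/2 matches the reference answer 1/2.'
)

evaluator = OmniMathEvaluator(url='http://127.0.0.1:8000')  # placeholder URL
assert evaluator.parse_response(sample_reply) is True
```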
13 changes: 10 additions & 3 deletions opencompass/models/turbomind_api.py
@@ -39,25 +39,28 @@ class TurboMindAPIModel(BaseModel):
     is_api: bool = True
 
     def __init__(self,
+                 model_name: str = None,
                  api_addr: str = 'http://0.0.0.0:23333',
                  api_key: str | None = None,
                  max_seq_len: int = 2048,
                  meta_template: Optional[Dict] = None,
                  end_str: Optional[str] = None,
+                 temperature: float = None,
                  **kwargs):
         super().__init__(path='',
                          max_seq_len=max_seq_len,
                          meta_template=meta_template)
         from lmdeploy.serve.openai.api_client import APIClient
         self.chatbot = APIClient(api_addr, api_key)
-        self.model_name = self.chatbot.available_models[0]
+        self.model_name = model_name
         self.logger = get_logger()
         self.template_parser = LMTemplateParser(meta_template)
         self.eos_token_id = None
         if meta_template and 'eos_token_id' in meta_template:
             self.eos_token_id = meta_template['eos_token_id']
         self.api_addr = api_addr
         self.end_str = end_str
+        self.temperature = temperature
 
     def generate(
         self,
@@ -84,6 +87,9 @@ def generate(
             List[str]: A list of generated strings.
         """
 
+        if self.temperature is not None:
+            temperature = self.temperature
+
         with ThreadPoolExecutor() as executor:
             results = list(
                 executor.map(self._generate, inputs,
@@ -125,13 +131,14 @@ def _generate(self, prompt: PromptType, max_out_len: int,
 
         response = ''
         for output in self.chatbot.completions_v1(
-                session_id=threading.currentThread().ident,
                 prompt=prompt,
                 model=self.model_name,
                 max_tokens=max_out_len,
                 temperature=temperature,
                 top_p=0.8,
-                top_k=1):
+                top_k=50,
+                session_id=threading.currentThread().ident,
+        ):
             response += output['choices'][0]['text']
         response = valid_str(response)
         if end_str:
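With these changes, a client can pin the served model name and a fixed decoding temperature at construction time, instead of reading the name from the server's model list and passing temperature per call. A usage sketch; the address and prompt are placeholders, and the call assumes a running Omni-Judge server:

```python
from opencompass.models.turbomind_api import TurboMindAPIModel

judge = TurboMindAPIModel(
    model_name='KbsdJames/Omni-Judge',  # previously taken from the server
    api_addr='http://127.0.0.1:8000',   # placeholder address
    temperature=0.0,                    # overrides the per-call temperature
    max_seq_len=8192,
)
print(judge.generate(['1 + 1 = ?'], max_out_len=32)[0])
```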
