diff --git a/evaluation/fishfarm/fishfarm/__init__.py b/evaluation/fishfarm/fishfarm/__init__.py new file mode 100644 index 0000000..5b60b93 --- /dev/null +++ b/evaluation/fishfarm/fishfarm/__init__.py @@ -0,0 +1,14 @@ +from . import chat_templates, models, tasks +from .models import Message, Model, Role +from .tasks import Task, TaskResult + +__all__ = [ + "chat_templates", + "tasks", + "models", + "Task", + "TaskResult", + "Model", + "Message", + "Role", +] diff --git a/evaluation/fishfarm/fishfarm/chat_templates.py b/evaluation/fishfarm/fishfarm/chat_templates.py new file mode 100644 index 0000000..0c9ff7e --- /dev/null +++ b/evaluation/fishfarm/fishfarm/chat_templates.py @@ -0,0 +1,13 @@ +LLAMA3 = ( + "{% set loop_messages = messages %}" + "{% for message in loop_messages %}" + "{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>" + "\n\n'+ message['content'] | trim + '<|eot_id|>' %}" + "{% if loop.index0 == 0 %}{% set content = bos_token + content %}" + "{% endif %}" + "{{ content }}" + "{% endfor %}" + "{% if add_generation_prompt %}" + "{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}" + "{% endif %}" +) diff --git a/evaluation/fishfarm/fishfarm/imports.py b/evaluation/fishfarm/fishfarm/imports.py new file mode 100644 index 0000000..75baed3 --- /dev/null +++ b/evaluation/fishfarm/fishfarm/imports.py @@ -0,0 +1,94 @@ +from types import TracebackType +from typing import Optional, Tuple, Type + + +class _DeferredImportExceptionContextManager: + """Context manager to defer exceptions from imports. + + Catches :exc:`ImportError` and :exc:`SyntaxError`. + If any exception is caught, this class raises an :exc:`ImportError` when being checked. + + """ + + def __init__(self) -> None: + self._deferred: Optional[Tuple[Exception, str]] = None + + def __enter__(self) -> "_DeferredImportExceptionContextManager": + """Enter the context manager. + + Returns: + Itself. + + """ + return self + + def __exit__( + self, + exc_type: Optional[Type[Exception]], + exc_value: Optional[Exception], + traceback: Optional[TracebackType], + ) -> Optional[bool]: + """Exit the context manager. + + Args: + exc_type: + Raised exception type. :obj:`None` if nothing is raised. + exc_value: + Raised exception object. :obj:`None` if nothing is raised. + traceback: + Associated traceback. :obj:`None` if nothing is raised. + + Returns: + :obj:`None` if nothing is deferred, otherwise :obj:`True`. + :obj:`True` will suppress any exceptions avoiding them from propagating. + + """ + if isinstance(exc_value, (ImportError, SyntaxError)): + if isinstance(exc_value, ImportError): + message = ( + "Tried to import '{}' but failed. Please make sure that the package is " + "installed correctly to use this feature. Actual error: {}." + ).format(exc_value.name, exc_value) + elif isinstance(exc_value, SyntaxError): + message = ( + "Tried to import a package but failed due to a syntax error in {}. Please " + "make sure that the Python version is correct to use this feature. Actual " + "error: {}." + ).format(exc_value.filename, exc_value) + else: + assert False + + self._deferred = (exc_value, message) + return True + return None + + def is_successful(self) -> bool: + """Return whether the context manager has caught any exceptions. + + Returns: + :obj:`True` if no exceptions are caught, :obj:`False` otherwise. + + """ + return self._deferred is None + + def check(self) -> None: + """Check whether the context manager has caught any exceptions. 
+ + Raises: + :exc:`ImportError`: + If any exception was caught from the caught exception. + + """ + if self._deferred is not None: + exc_value, message = self._deferred + raise ImportError(message) from exc_value + + +def try_import() -> _DeferredImportExceptionContextManager: + """Create a context manager that can wrap imports of optional packages to defer exceptions. + + Returns: + Deferred import context manager. + + """ + return _DeferredImportExceptionContextManager() diff --git a/evaluation/fishfarm/fishfarm/logging.py b/evaluation/fishfarm/fishfarm/logging.py new file mode 100644 index 0000000..ef66768 --- /dev/null +++ b/evaluation/fishfarm/fishfarm/logging.py @@ -0,0 +1,190 @@ +""" +Copied from Optuna repo: +https://github.com/optuna/optuna/blob/2595653638506e1b7e025a966a220984a59ab936/optuna/logging.py +Removed some comments for less verbosity. + +In general, `logger.info` is preferred over `print` since it contains module name and timestamp; +We recommend the use of logger object for the fishfarm developers. + +Inside fishfarm, we can call `get_logger(__name__)` from each python file. +Then the root logger format and level are applied to that logger object. +""" + +from __future__ import annotations + +import logging +import os +import sys +import threading +from logging import CRITICAL, DEBUG, ERROR, FATAL, INFO, WARN, WARNING + +import colorlog + +__all__ = [ + "CRITICAL", + "DEBUG", + "ERROR", + "FATAL", + "INFO", + "WARN", + "WARNING", +] + +_lock: threading.Lock = threading.Lock() +_default_handler: logging.Handler | None = None + + +def create_default_formatter() -> logging.Formatter: + """Create a default formatter of log messages. + + This function is not supposed to be directly accessed by library users. + """ + header = "[%(levelname)1.1s %(asctime)s %(name)s]" + message = "%(message)s" + if _color_supported(): + return colorlog.ColoredFormatter( + f"%(log_color)s{header}%(reset)s {message}", + ) + return logging.Formatter(f"{header} {message}") + + +def _color_supported() -> bool: + """Detection of color support.""" + # NO_COLOR environment variable: + if os.environ.get("NO_COLOR", None): + return False + + if not hasattr(sys.stderr, "isatty") or not sys.stderr.isatty(): + return False + else: + return True + + +def _get_library_name() -> str: + return __name__.split(".")[0] + + +def _get_library_root_logger() -> logging.Logger: + return logging.getLogger(_get_library_name()) + + +def _configure_library_root_logger() -> None: + global _default_handler + + with _lock: + if _default_handler: + # This library has already configured the library root logger. + return + _default_handler = logging.StreamHandler() # Set sys.stderr as stream. + _default_handler.setFormatter(create_default_formatter()) + + # Apply our default configuration to the library root logger. + library_root_logger: logging.Logger = _get_library_root_logger() + library_root_logger.addHandler(_default_handler) + library_root_logger.setLevel(logging.INFO) + library_root_logger.propagate = False + + +def _reset_library_root_logger() -> None: + global _default_handler + + with _lock: + if not _default_handler: + return + + library_root_logger: logging.Logger = _get_library_root_logger() + library_root_logger.removeHandler(_default_handler) + library_root_logger.setLevel(logging.NOTSET) + _default_handler = None + + +def get_logger(name: str) -> logging.Logger: + """Return a logger with the specified name. 
+ name's prefix should be `fishfarm.` (just like __name__ variable), + otherwise root logger settings will be not reflected. + This function is not supposed to be directly accessed by library users. + """ + + _configure_library_root_logger() + return logging.getLogger(name) + + +def get_verbosity() -> int: + """Return the current level for the fishfarm's root logger. + + Returns: + Logging level, e.g., ``fishfarm.logging.DEBUG`` and ``fishfarm.logging.INFO``. + + .. note:: + fishfarm has following logging levels: + + - ``fishfarm.logging.CRITICAL``, ``fishfarm.logging.FATAL`` + - ``fishfarm.logging.ERROR`` + - ``fishfarm.logging.WARNING``, ``fishfarm.logging.WARN`` + - ``fishfarm.logging.INFO`` + - ``fishfarm.logging.DEBUG`` + """ + + _configure_library_root_logger() + return _get_library_root_logger().getEffectiveLevel() + + +def set_verbosity(verbosity: int) -> None: + """Set the level for the fishfarm's root logger. + + Args: + verbosity: + Logging level, e.g., ``fishfarm.logging.DEBUG`` and ``fishfarm.logging.INFO``. + + .. note:: + fishfarm has following logging levels: + + - ``fishfarm.logging.CRITICAL``, ``fishfarm.logging.FATAL`` + - ``fishfarm.logging.ERROR`` + - ``fishfarm.logging.WARNING``, ``fishfarm.logging.WARN`` + - ``fishfarm.logging.INFO`` + - ``fishfarm.logging.DEBUG`` + """ + + _configure_library_root_logger() + _get_library_root_logger().setLevel(verbosity) + + +def disable_default_handler() -> None: + """Disable the default handler of the fishfarm's root logger.""" + + _configure_library_root_logger() + + assert _default_handler is not None + _get_library_root_logger().removeHandler(_default_handler) + + +def enable_default_handler() -> None: + """Enable the default handler of the fishfarm's root logger.""" + + _configure_library_root_logger() + + assert _default_handler is not None + _get_library_root_logger().addHandler(_default_handler) + + +def disable_propagation() -> None: + """Disable propagation of the library log outputs. + + Note that log propagation is disabled by default. You only need to use this function + to stop log propagation when you use :func:`~fishfarm.logging.enable_propagation()`. + """ + + _configure_library_root_logger() + _get_library_root_logger().propagate = False + + +def enable_propagation() -> None: + """Enable propagation of the library log outputs. + + Please disable the fishfarm's default handler to prevent double logging if the root logger has + been configured. 
+ """ + + _configure_library_root_logger() + _get_library_root_logger().propagate = True diff --git a/evaluation/fishfarm/fishfarm/models/__init__.py b/evaluation/fishfarm/fishfarm/models/__init__.py new file mode 100644 index 0000000..5c8ad93 --- /dev/null +++ b/evaluation/fishfarm/fishfarm/models/__init__.py @@ -0,0 +1,12 @@ +from .base import (GenerationRequest, GenerationResult, Message, Model, + NLLRequest, NLLResult, Role) + +__all__ = [ + "GenerationRequest", + "GenerationResult", + "NLLRequest", + "NLLResult", + "Model", + "Role", + "Message", +] diff --git a/evaluation/fishfarm/fishfarm/models/base.py b/evaluation/fishfarm/fishfarm/models/base.py new file mode 100644 index 0000000..7a97f0d --- /dev/null +++ b/evaluation/fishfarm/fishfarm/models/base.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Iterable, Literal, Optional, Sequence + +Role = Literal["system", "user", "assistant", "assistant_prefill"] + + +@dataclass +class Message: + + role: Role + content: str + + +@dataclass +class GenerationRequest: + + messages: list[Message] + + max_tokens: Optional[int] = None + stop: Sequence[str] = () + + +@dataclass +class GenerationResult: + + request: GenerationRequest + generation: str + + +@dataclass +class NLLRequest: + + messages: list[Message] + + +@dataclass +class NLLResult: + + request: NLLRequest + sum_nll: float + num_considered_tokens: int + + +class Model: + + def generate( + self, requests: Sequence[GenerationRequest] + ) -> Iterable[GenerationResult]: + raise NotImplementedError() + + def nll(self, requests: Sequence[NLLRequest]) -> Iterable[NLLResult]: + raise NotImplementedError() diff --git a/evaluation/fishfarm/fishfarm/models/tokenization_utils.py b/evaluation/fishfarm/fishfarm/models/tokenization_utils.py new file mode 100644 index 0000000..7fb92df --- /dev/null +++ b/evaluation/fishfarm/fishfarm/models/tokenization_utils.py @@ -0,0 +1,62 @@ +import dataclasses +from typing import Optional + +from transformers import PreTrainedTokenizerBase + +from .base import Message + + +class MaskedTokens: + + text: str + token_ids: list[int] + mask: list[bool] + + def __init__(self) -> None: + self.text = "" + self.token_ids = [] + self.mask = [] + + def extend( + self, + messages: list[Message], + mask_value: bool, + tokenizer: PreTrainedTokenizerBase, + chat_template: Optional[str], + add_generation_prompt: bool, + ) -> None: + if len(messages) == 0: + # `tokenizer.apply_chat_template` does not accept an empty list. 
+ raise ValueError("At least one message is required.") + + all_text: str = tokenizer.apply_chat_template( + conversation=[dataclasses.asdict(message) for message in messages], + chat_template=chat_template, + tokenize=False, + add_generation_prompt=add_generation_prompt, + ) + assert all_text.startswith(self.text) + new_text = all_text[len(self.text) :] + new_token_ids: list[int] = tokenizer.encode(new_text, add_special_tokens=False) + + self.token_ids.extend(new_token_ids) + self.mask.extend([mask_value] * len(new_token_ids)) + self.text = all_text + + +def tokenize_messages( + messages: list[Message], + tokenizer: PreTrainedTokenizerBase, + chat_template: Optional[str], +) -> MaskedTokens: + masked_tokens = MaskedTokens() + + for i, message in enumerate(messages): + if message.role != "assistant": + continue + + masked_tokens.extend(messages[:i], False, tokenizer, chat_template, True) + masked_tokens.extend(messages[: i + 1], True, tokenizer, chat_template, False) + + masked_tokens.extend(messages, False, tokenizer, chat_template, True) + return masked_tokens diff --git a/evaluation/fishfarm/fishfarm/models/vllm_model.py b/evaluation/fishfarm/fishfarm/models/vllm_model.py new file mode 100644 index 0000000..e779829 --- /dev/null +++ b/evaluation/fishfarm/fishfarm/models/vllm_model.py @@ -0,0 +1,145 @@ +import copy +import dataclasses +from typing import Any, Iterable, Optional, Sequence + +from fishfarm.models.base import NLLRequest, NLLResult +from transformers import PreTrainedTokenizerBase + +from ..imports import try_import +from .base import GenerationRequest, GenerationResult, Message, Model +from .tokenization_utils import tokenize_messages + +with try_import() as _imports: + import vllm + +_imports.check() + + +class VLLMModel(Model): + + def __init__( + self, + llm: vllm.LLM, + sampling_params: vllm.SamplingParams, + chat_template: Optional[str], + ) -> None: + self.llm = llm + self.chat_template = chat_template + self.sampling_params = sampling_params + + def get_tokenizer(self) -> PreTrainedTokenizerBase: + tokenizer = self.llm.get_tokenizer() + + if not hasattr(tokenizer, "apply_chat_template"): + if hasattr(tokenizer, "tokenizer"): + tokenizer = tokenizer.tokenizer + else: + raise ValueError( + "The tokenizer does not have the 'apply_chat_template' method. " + "This is likely because of the versions of vLLM or transformers." 
+ ) + + return tokenizer + + def _into_prompt(self, messages: Sequence[Message]) -> str: + tokenizer = self.get_tokenizer() + prefill_text = "" + n_assistant_prefill = sum([m.role == "assistant_prefill" for m in messages]) + if n_assistant_prefill > 1: + raise ValueError( + f"There must be at most one assistant_prefill role, but got {n_assistant_prefill}", + ) + if n_assistant_prefill: + assert ( + messages[-1].role == "assistant_prefill" + ), "assistant_prefill role must be the last message" + prefill_text = messages[-1].content + messages = messages[:-1] + prompt: str = tokenizer.apply_chat_template( + conversation=[dataclasses.asdict(message) for message in messages], + chat_template=self.chat_template, + tokenize=False, + add_generation_prompt=True, + ) + prompt += prefill_text + return prompt + + def _predict_log_probs(self, token_ids_list: list[list[int]]) -> list[list[float]]: + sampling_params = copy.copy(self.sampling_params) + sampling_params.prompt_logprobs = 1 + sampling_params.max_tokens = 1 + + completions = self.llm.generate( + prompt_token_ids=token_ids_list, + sampling_params=sampling_params, + ) + + log_probs_list = [] + for token_ids, completion in zip(token_ids_list, completions): + log_probs = [] + assert completion.prompt_logprobs is not None + assert token_ids == completion.prompt_token_ids + assert len(token_ids) == len(completion.prompt_logprobs) + for token_id, logprob_dict in zip(token_ids, completion.prompt_logprobs): + if logprob_dict is None: + log_probs.append(0.0) + else: + logprob_entry: Any = logprob_dict[token_id] + + if isinstance(logprob_entry, float): + log_probs.append(logprob_entry) + else: + log_probs.append(logprob_entry.logprob) + + log_probs_list.append(log_probs) + + return log_probs_list + + def generate( + self, requests: Sequence[GenerationRequest] + ) -> Iterable[GenerationResult]: + + prompts = [self._into_prompt(request.messages) for request in requests] + completions = self.llm.generate( + prompts=prompts, + sampling_params=self.sampling_params, + ) + + for request, completion in zip(requests, completions): + yield GenerationResult( + request=request, generation=completion.outputs[0].text + ) + + def nll(self, requests: Sequence[NLLRequest]) -> Iterable[NLLResult]: + masked_tokens_list = [ + tokenize_messages( + request.messages, self.get_tokenizer(), self.chat_template + ) + for request in requests + ] + log_probs_list = self._predict_log_probs( + [masked_tokens.token_ids for masked_tokens in masked_tokens_list] + ) + + results = [] + for log_probs, masked_tokens, request in zip( + log_probs_list, masked_tokens_list, requests + ): + assert len(log_probs) == len(masked_tokens.mask) + + sum_nll = 0.0 + num_considered_tokens = 0 + for log_prob, mask_value in zip(log_probs, masked_tokens.mask): + if mask_value: + sum_nll += -log_prob + num_considered_tokens += 1 + + results.append( + NLLResult( + request=request, + sum_nll=sum_nll, + num_considered_tokens=num_considered_tokens, + ) + ) + + return results diff --git a/evaluation/fishfarm/fishfarm/tasks/__init__.py b/evaluation/fishfarm/fishfarm/tasks/__init__.py new file mode 100644 index 0000000..4e3728c --- /dev/null +++ b/evaluation/fishfarm/fishfarm/tasks/__init__.py @@ -0,0 +1,8 @@ +from . 
import base +from .base import Task, TaskResult + +__all__ = [ + "base", + "TaskResult", + "Task", +] diff --git a/evaluation/fishfarm/fishfarm/tasks/ai2_arc.py b/evaluation/fishfarm/fishfarm/tasks/ai2_arc.py new file mode 100644 index 0000000..345d67e --- /dev/null +++ b/evaluation/fishfarm/fishfarm/tasks/ai2_arc.py @@ -0,0 +1,118 @@ +import random +import re +from dataclasses import dataclass +from typing import Iterable, Optional, Sequence + +from ..models import GenerationRequest, Message, Model +from .base import Task, TaskResult + + +def extract_answer(text: str) -> Optional[str]: + pattern = r"answer is \(?([A-J])\)?" + match = re.search(pattern, text) + if match: + return match.group(1) + else: + return extract_again(text) + + +def extract_again(text: str) -> Optional[str]: + match = re.search(r".*[aA]nswer:\s*([A-J])", text) + if match: + return match.group(1) + else: + return extract_final(text) + + +def extract_final(text: str) -> Optional[str]: + pattern = r"\b[A-J]\b(?!.*\b[A-J]\b)" + match = re.search(pattern, text, re.DOTALL) + if match: + return match.group(0) + else: + return None + + +def is_correct(pred: Optional[str], answer: str, options: list[str]) -> bool: + if not pred: + random.seed(42) + x = random.randint(0, len(options) - 1) + if ["A", "B", "C", "D", "E"][x] == answer: + return True + else: + return False + elif pred == answer: + return True + else: + return False + + +@dataclass +class Ai2ArcSample: + + question: str + question_id: str + options: list[str] + answer: str + + +def mean(iterable: Iterable[float]) -> float: + total, count = 0.0, 0 + for x in iterable: + total += x + count += 1 + return total / count + + +class Ai2ArcTask(Task): + def __init__( + self, + samples: Sequence[Ai2ArcSample], + context_messages: Sequence[Message] = (), + ): + self.samples = list(samples) + self.context_messages = context_messages + + @property + def num_samples(self) -> int: + return len(self.samples) + + def evaluate( + self, + model: Model, + sample_ids: Optional[Sequence[int]] = None, + ) -> TaskResult: + if sample_ids is None: + sample_ids = range(len(self.samples)) + samples = [self.samples[sample_id] for sample_id in sample_ids] + + requests = [] + for sample in samples: + messages = list(self.context_messages) + messages.append(Message(role="user", content=sample.question)) + requests.append(GenerationRequest(messages=messages)) + + sample_details = [] + for sample, result in zip(samples, model.generate(requests)): + output = result.generation + prediction = extract_answer(result.generation) + + sample_details.append( + dict( + problem=sample.question, + output=output, + answer=sample.answer, + prediction=prediction, + correct=is_correct(prediction, sample.answer, sample.options), + ) + ) + + aggregate_metrics = { + "acc": mean( + float(sd["correct"]) if isinstance(sd["correct"], (bool)) else 0.0 + for sd in sample_details + ) + } + return TaskResult( + aggregate_metrics=aggregate_metrics, sample_details=sample_details + ) diff --git a/evaluation/fishfarm/fishfarm/tasks/base.py b/evaluation/fishfarm/fishfarm/tasks/base.py new file mode 100644 index 0000000..dbfa6eb --- /dev/null +++ b/evaluation/fishfarm/fishfarm/tasks/base.py @@ -0,0 +1,28 @@ +import abc +from dataclasses import dataclass +from typing import Any, Optional, Sequence + +from ..models import Model + + +@dataclass +class TaskResult: + + aggregate_metrics: dict[str, float] + sample_details: list[dict[str, Any]] + + +class Task(abc.ABC): + + @property + @abc.abstractmethod + def num_samples(self) -> 
int: + raise NotImplementedError() + + @abc.abstractmethod + def evaluate( + self, + model: Model, + sample_ids: Optional[Sequence[int]] = None, + ) -> TaskResult: + raise NotImplementedError() diff --git a/evaluation/fishfarm/fishfarm/tasks/competation_math.py b/evaluation/fishfarm/fishfarm/tasks/competation_math.py new file mode 100644 index 0000000..1b01e3a --- /dev/null +++ b/evaluation/fishfarm/fishfarm/tasks/competation_math.py @@ -0,0 +1,391 @@ +from dataclasses import dataclass +from math import isclose +from typing import Any, Iterable, Optional, Sequence, Union + +from sympy import N, simplify +from sympy.parsing.latex import parse_latex +from sympy.parsing.sympy_parser import parse_expr + +from ..models import GenerationRequest, Message, Model +from .base import Task, TaskResult + + +def _fix_fracs(string: str) -> str: + substrs = string.split("\\frac") + new_str = substrs[0] + if len(substrs) > 1: + substrs = substrs[1:] + for substr in substrs: + new_str += "\\frac" + if substr[0] == "{": + new_str += substr + else: + try: + assert len(substr) >= 2 + except AssertionError: + return string + a = substr[0] + b = substr[1] + if b != "{": + if len(substr) > 2: + post_substr = substr[2:] + new_str += "{" + a + "}{" + b + "}" + post_substr + else: + new_str += "{" + a + "}{" + b + "}" + else: + if len(substr) > 2: + post_substr = substr[2:] + new_str += "{" + a + "}" + b + post_substr + else: + new_str += "{" + a + "}" + b + string = new_str + return string + + +def _fix_a_slash_b(string: str) -> str: + if len(string.split("/")) != 2: + return string + a: str = string.split("/")[0] + b: str = string.split("/")[1] + try: + a_int: int = int(a) + b_int: int = int(b) + assert string == "{}/{}".format(a_int, b_int) + new_string = "\\frac{" + str(a) + "}{" + str(b) + "}" + return new_string + except (AssertionError, ValueError): + return string + + +def _remove_right_units(string: str) -> str: + if "\\text{ " in string: + splits = string.split("\\text{ ") + assert len(splits) == 2 + return splits[0] + else: + return string + + +def _fix_sqrt(string: str) -> str: + if "\\sqrt" not in string: + return string + splits = string.split("\\sqrt") + new_string = splits[0] + for split in splits[1:]: + if split[0] != "{": + a = split[0] + new_substr = "\\sqrt{" + a + "}" + split[1:] + else: + new_substr = "\\sqrt" + split + new_string += new_substr + return new_string + + +def _strip_string(string: str) -> str: + string = string.replace("\n", "") + + string = string.replace("\\!", "") + + string = string.replace("\\\\", "\\") + + string = string.replace("tfrac", "frac") + string = string.replace("dfrac", "frac") + + string = string.replace("\\left", "") + string = string.replace("\\right", "") + + string = string.replace("^{\\circ}", "") + string = string.replace("^\\circ", "") + + string = string.replace("\\$", "") + + string = _remove_right_units(string) + + string = string.replace(r"\\%", "") + string = string.replace(r"\%", "") + + string = string.replace(" .", " 0.") + string = string.replace("{.", "{0.") + if len(string) == 0: + return string + if string[0] == ".": + string = "0" + string + + if len(string.split("=")) == 2: + if len(string.split("=")[0]) <= 2: + string = string.split("=")[1] + + string = _fix_sqrt(string) + + string = string.replace(" ", "") + + string = _fix_fracs(string) + + if string == "0.5": + string = "\\frac{1}{2}" + + string = _fix_a_slash_b(string) + + return string + + +def is_digit(s: Union[bool, float, str]) -> bool: + try: + float(str(s).replace(",", "")) + 
return True + except ValueError: + return False + + +def symbolic_equal(a: str, b: str) -> bool: + def _parse(s: str) -> Any: + for f in [parse_latex, parse_expr]: + try: + return f(s) + except Exception: + pass + return s + + a = _parse(a) + b = _parse(b) + + try: + if simplify(a - b) == 0: + return True + except Exception: + pass + + try: + if isclose(N(a), N(b), rel_tol=1e-3): + return True + except Exception: + pass + return False + + +def math_equal( + prediction: Union[bool, float, str], + reference: Union[float, str], + include_percentage: bool = True, + is_close: bool = True, +) -> bool: + """ + Exact match of math if and only if: + 1. numerical equal: both can convert to float and are equal + 2. symbolic equal: both can convert to sympy expression and are equal + """ + try: + if is_digit(prediction) and is_digit(reference): + prediction = float(str(prediction).replace(",", "")) + reference = float(str(reference).replace(",", "")) + if include_percentage: + gt_result = [reference / 100, reference, reference * 100] + else: + gt_result = [reference] + for item in gt_result: + try: + if is_close: + if isclose(item, prediction, rel_tol=1e-4): + return True + else: + if item == prediction: + return True + except Exception: + continue + return False + except Exception: + pass + + if not prediction and prediction not in [0, False]: + return False + + reference = str(reference).strip() + prediction = str(prediction).strip() + + pred_str, ref_str = prediction, reference + if ( + prediction.startswith("[") + and prediction.endswith("]") + and not reference.startswith("(") + ) or ( + prediction.startswith("(") + and prediction.endswith(")") + and not reference.startswith("[") + ): + pred_str = pred_str.strip("[]()") + ref_str = ref_str.strip("[]()") + for s in ["{", "}", "(", ")"]: + ref_str = ref_str.replace(s, "") + pred_str = pred_str.replace(s, "") + if pred_str == ref_str: + return True + + if ( + (prediction.startswith("[") and prediction.endswith("]")) + and (reference.startswith("[") and reference.endswith("]")) + or (prediction.startswith("(") and prediction.endswith(")")) + and (reference.startswith("(") and reference.endswith(")")) + ): + pred_parts = prediction[1:-1].split(",") + ref_parts = reference[1:-1].split(",") + if len(pred_parts) == len(ref_parts): + if all( + [ + math_equal( + pred_parts[i], ref_parts[i], include_percentage, is_close + ) + for i in range(len(pred_parts)) + ] + ): + return True + + if symbolic_equal(prediction, reference): + return True + + return False + + +def is_equiv(str1: Optional[str], str2: Optional[str]) -> bool: + if str1 is None and str2 is None: + return True + if str1 is None or str2 is None: + return False + + try: + ss1 = _strip_string(str1) + ss2 = _strip_string(str2) + return math_equal(ss1, ss2) or ss1 == ss2 + except (AssertionError, TypeError, ValueError): + return math_equal(str1, str2) or str1 == str2 + + +def last_boxed_only_string(string: str) -> Optional[str]: + idx = string.rfind("\\boxed") + if idx < 0: + idx = string.rfind("\\fbox") + if idx < 0: + return None + + i = idx + right_brace_idx: Optional[int] = None + + num_left_braces_open = 0 + while i < len(string): + if string[i] == "{": + num_left_braces_open += 1 + if string[i] == "}": + num_left_braces_open -= 1 + if num_left_braces_open == 0: + right_brace_idx = i + break + i += 1 + + if right_brace_idx is None: + retval = None + else: + assert right_brace_idx is not None + retval = string[idx : right_brace_idx + 1] + + return retval + + +def remove_boxed(s: Optional[str]) -> 
Optional[str]: + left = "\\boxed{" + if s is None: + return None + else: + try: + assert s[: len(left)] == left + assert s[-1] == "}" + return s[len(left) : -1] + except (AssertionError, TypeError, ValueError): + return None + + +@dataclass +class MathSample: + + problem: str + answer: Optional[str] = None + type: Optional[str] = None + + +def mean(iterable: Iterable[float]) -> float: + total, count = 0.0, 0 + for x in iterable: + total += x + count += 1 + return total / count + + +def extract_ans(completion: str) -> Optional[str]: + + split_ans = completion.split("The answer is: ") + if len(split_ans) > 1: + ans = split_ans[-1] + extract_ans_temp = ans.split(".\n")[0] + extract_ans_temp = extract_ans_temp.strip() + if len(extract_ans_temp) > 0 and extract_ans_temp[-1] == ".": + extract_ans = extract_ans_temp[0:-1] + else: + extract_ans = extract_ans_temp + extract_ans = extract_ans.strip() + return extract_ans + else: + return remove_boxed(last_boxed_only_string(completion)) + + +class LatexFormatMathTask(Task): + def __init__( + self, + samples: Sequence[MathSample], + context_messages: Sequence[Message] = (), + ): + self.samples = list(samples) + self.context_messages = context_messages + + @property + def num_samples(self) -> int: + return len(self.samples) + + def evaluate( + self, + model: Model, + sample_ids: Optional[Sequence[int]] = None, + ) -> TaskResult: + if sample_ids is None: + sample_ids = range(len(self.samples)) + samples = [self.samples[sample_id] for sample_id in sample_ids] + + requests = [] + for sample in samples: + messages = list(self.context_messages) + messages.append(Message(role="user", content=sample.problem)) + requests.append(GenerationRequest(messages=messages)) + + sample_details = [] + for sample, result in zip(samples, model.generate(requests)): + output = result.generation + prediction = extract_ans(output) + + sample_details.append( + dict( + problem=sample.problem, + output=output, + answer=sample.answer, + type=sample.type, + prediction=prediction, + correct=is_equiv(sample.answer, prediction), + ) + ) + + aggregate_metrics = { + "acc": mean( + float(sd["correct"]) if isinstance(sd["correct"], (bool)) else 0.0 + for sd in sample_details + ) + } + + return TaskResult( + aggregate_metrics=aggregate_metrics, sample_details=sample_details + ) diff --git a/evaluation/fishfarm/fishfarm/tasks/evalplus/__init__.py b/evaluation/fishfarm/fishfarm/tasks/evalplus/__init__.py new file mode 100644 index 0000000..56aa3dc --- /dev/null +++ b/evaluation/fishfarm/fishfarm/tasks/evalplus/__init__.py @@ -0,0 +1,4 @@ +from .data import load_dataset +from .task import EvalplusTask + +__all__ = ["EvalplusTask", "load_dataset"] diff --git a/evaluation/fishfarm/fishfarm/tasks/evalplus/data.py b/evaluation/fishfarm/fishfarm/tasks/evalplus/data.py new file mode 100644 index 0000000..a2ad377 --- /dev/null +++ b/evaluation/fishfarm/fishfarm/tasks/evalplus/data.py @@ -0,0 +1,94 @@ +from dataclasses import dataclass + +from evalplus.data import get_human_eval_plus, get_mbpp_plus + + +@dataclass +class TextToCodeProblem: + id: str + instruction: str + response_prefix: str + + +def get_mbpp_raw_problems() -> list[dict]: + problems = get_mbpp_plus() + return list(problems.values()) + + +def get_humaneval_raw_problems() -> list[dict]: + problems = get_human_eval_plus() + return list(problems.values()) + + +def read_mbpp_plus( + plus_path: str, err_incomplete: bool = True, mini: bool = False +) -> dict[str, dict]: + from evalplus.data.mbpp import (completeness_check, + 
mbpp_deserialize_inputs, stream_jsonl) + + plus = {task["task_id"]: task for task in stream_jsonl(plus_path)} + for task_id, task in plus.items(): + task["base_input"] = mbpp_deserialize_inputs(task_id, task["base_input"]) + task["plus_input"] = mbpp_deserialize_inputs(task_id, task["plus_input"]) + + if err_incomplete: + completeness_check("MBPP+", plus) + return plus + + +def map_mbpp_problem(p: dict) -> TextToCodeProblem: + id = p["task_id"] + prompt = p["prompt"] + start_index = prompt.index('"""') + end_index = prompt.rindex('"""') + prompt = prompt[start_index + 3 : end_index] + assert_index = prompt.index("assert") + instruction = prompt[:assert_index].strip() + if not instruction.endswith("."): + instruction += "." + assertion = prompt[assert_index:].strip() + instruction = f"""{instruction} Your code should satisfy the following assertion: +```python +{assertion} +```""" + response_prefix = """```python""" + return TextToCodeProblem( + id=str(id), instruction=instruction, response_prefix=response_prefix + ) + + +def map_humaneval_problem(p: dict) -> TextToCodeProblem: + id = p["task_id"] + prompt = p["prompt"] + prompt = prompt.strip() + instruction = f"""Write a solution to the following problem: +```python +{prompt} +```""" + response_prefix = f"""```python +{prompt}""" + return TextToCodeProblem( + id=id, instruction=instruction, response_prefix=response_prefix + ) + + +def load_dataset(source_dataset: str) -> list[TextToCodeProblem]: + if source_dataset not in ("humaneval", "mbpp"): + raise ValueError(f"Unknown source_dataset: {source_dataset}") + + raw_problem_fn = { + "humaneval": get_humaneval_raw_problems, + "mbpp": get_mbpp_raw_problems, + }[source_dataset] + + if source_dataset.startswith("humaneval"): + map_problem_fn = map_humaneval_problem + elif source_dataset.startswith("mbpp"): + map_problem_fn = map_mbpp_problem + else: + raise ValueError(f"Unknown source_dataset: {source_dataset}") + + raw_problems = raw_problem_fn() + problems = list(map(map_problem_fn, raw_problems)) + + return problems diff --git a/evaluation/fishfarm/fishfarm/tasks/evalplus/evaluation.py b/evaluation/fishfarm/fishfarm/tasks/evalplus/evaluation.py new file mode 100644 index 0000000..925fe25 --- /dev/null +++ b/evaluation/fishfarm/fishfarm/tasks/evalplus/evaluation.py @@ -0,0 +1,257 @@ +import json +import multiprocessing +import os +import threading +import time +from collections import Counter, defaultdict +from concurrent.futures import ProcessPoolExecutor, as_completed +from datetime import datetime +from typing import Any +from warnings import warn + +import numpy as np +from evalplus.data import (get_human_eval_plus, get_human_eval_plus_hash, + get_mbpp_plus, get_mbpp_plus_hash, load_solutions) +from evalplus.eval import SUCCESS, estimate_pass_at_k, untrusted_check +from evalplus.eval._special_oracle import MBPP_OUTPUT_NOT_NONE_TASKS +from evalplus.evaluate import Result, get_groundtruth +from termcolor import cprint +from tqdm.auto import tqdm + +from ...logging import get_logger + +logger = get_logger(__name__) + + +def check_correctness( + dataset: str, + completion_id: int, + problem: dict[str, Any], + solution: str, + expected_output: dict[str, list], + base_only: bool = False, + fast_check: bool = False, + identifier: str = "HumanEval/0_0", + min_time_limit: float = 0.1, + gt_time_limit_factor: float = 2.0, +) -> dict[str, Result]: + ret = { + "completion_id": completion_id, + "task_id": problem["task_id"], + "_identifier": identifier, + "solution": solution, + } + ret["base"] = 
untrusted_check( + dataset, + solution, + problem["base_input"], + problem["entry_point"], + expected=expected_output["base"], + atol=problem["atol"], + ref_time=expected_output["base_time"], + fast_check=fast_check, + min_time_limit=min_time_limit, + gt_time_limit_factor=gt_time_limit_factor, + ) + + if not base_only: + ret["plus"] = untrusted_check( + dataset, + solution, + problem["plus_input"], + problem["entry_point"], + expected=expected_output["plus"], + atol=problem["atol"], + ref_time=expected_output["plus_time"], + fast_check=fast_check, + min_time_limit=min_time_limit, + gt_time_limit_factor=gt_time_limit_factor, + ) + return ret + + +def evaluate( + source_dataset: str, + output_path: str, + base_only: bool = False, + parallel: int = 0, + i_just_wanna_run: bool = False, + test_details: bool = False, + min_time_limit: float = 0.2, + gt_time_limit_factor: float = 4.0, + mini: bool = False, +) -> tuple[Any, list[dict[str, Any]]]: + if parallel == 0: + n_workers = max(1, multiprocessing.cpu_count() // 2) + else: + n_workers = parallel + + if os.path.isdir(output_path): + result_path = os.path.join(output_path, "eval_results.json") + else: + assert output_path.endswith(".jsonl") + result_path = output_path.replace(".jsonl", "_eval_results.json") + + if source_dataset == "humaneval": + problems = get_human_eval_plus(mini=mini) + dataset_hash = get_human_eval_plus_hash() + expected_output = get_groundtruth(problems, dataset_hash, []) + elif source_dataset == "mbpp": + problems = get_mbpp_plus(mini=mini) + dataset_hash = get_mbpp_plus_hash() + expected_output = get_groundtruth( + problems, + dataset_hash, + MBPP_OUTPUT_NOT_NONE_TASKS, + ) + + results = { + "date": datetime.now().strftime("%Y-%m-%d %H:%M"), + "hash": dataset_hash, + "eval": {}, + } + + with ProcessPoolExecutor(max_workers=n_workers) as executor: + futures = [] + completion_id: Counter[str] = Counter() + n_samples = 0 + eval_results = defaultdict(list) + remainings = set() + sample_details = [] + + logger.info("Reading samples...") + for sample in tqdm(load_solutions(output_path)): + task_id = sample["task_id"] + explanation = sample.get("explanation", "") + solution = ( + sample["solution"] + if "solution" in sample + else problems[task_id]["prompt"] + sample["completion"] + ) + remainings.add(sample["_identifier"]) + + args = ( + source_dataset, + completion_id[task_id], + problems[task_id], + solution, + expected_output[task_id], + base_only, + not test_details, + sample["_identifier"], + min_time_limit, + gt_time_limit_factor, + ) + + futures.append(executor.submit(check_correctness, *args)) + completion_id[task_id] += 1 + n_samples += 1 + + sample_details.append( + dict( + task_id=task_id, + solution=solution, + explanation=explanation, + problems=problems[task_id], + expected_output=expected_output[task_id], + ) + ) + + assert n_samples == len(remainings), "Missing problems in unfinished" + if len(completion_id) != len(problems): + logger.warning("Warning: Missing problems in samples") + + def stucking_checker() -> None: + while remainings: + last_size = len(remainings) + time.sleep(20) + if last_size != len(remainings) or len(remainings) == 0: + continue + warn("No samples had finished testing in the last 20s") + warn(f"{len(remainings)} samples to be tested: {remainings}") + + threading.Thread(target=stucking_checker).start() + + for future in tqdm(as_completed(futures), total=n_samples): + result = future.result() + remainings.remove(result["_identifier"]) + eval_results[result["task_id"]].append(result) + + for 
task_id, task_results in eval_results.items(): + task_results.sort(key=lambda x: x["completion_id"]) + results["eval"][task_id] = { + "nfiles": len(task_results), + "base": [x["base"] for x in task_results], + "plus": ([x["plus"] for x in task_results] if not base_only else []), + } + + if os.path.isfile(result_path) and i_just_wanna_run: + decision = "" + while decision.lower() not in ["y", "n"]: + logger.info( + f"{result_path} already exists. Press [Y/N] to overwrite or exit..." + ) + decision = input() + + if decision.lower() == "y": + new_path = result_path + ".bak" + while os.path.isfile(new_path): + new_path += ".bak" + os.rename(result_path, new_path) + logger.info(f"Backup {result_path} to {new_path}") + + if not os.path.isfile(result_path): + with open(result_path, "w") as f: + json.dump(results, f) + + total = np.array([r["nfiles"] for r in results["eval"].values()]) + base_correct = [] + new_correct = [] + + for key, res in results["eval"].items(): + elements = [element for element in sample_details if element["task_id"] == key] + assert ( + len(elements) == 1 + ), f"Expected an element with task_id {key}, found {len(elements)}" + element = elements[0] + + bc = sum([r[0] == SUCCESS for r in res["base"]]) + base_correct.append(bc) + element["base_correct"] = bc + if res["plus"]: + new_bc = sum( + [ + res["plus"][i][0] == res["base"][i][0] == SUCCESS + for i in range(len(res["plus"])) + ] + ) + new_correct.append(new_bc) + element["plus_correct"] = new_bc + + base_correct_array = np.array(base_correct) + + pass_at_k = { + f"pass@{k}": estimate_pass_at_k(total, base_correct_array, k).mean() + for k in [1, 10, 100] + if total.min() >= k + } + + result = {f"{source_dataset}_base_{key}": value for key, value in pass_at_k.items()} + cprint(f"{source_dataset} (base tests)", "red") + for k, v in pass_at_k.items(): + cprint(f"{k}:\t{v:.3f}", "red") + + if new_correct: + cprint(f"{source_dataset}+ (base + extra tests)", "green") + pass_at_k = { + f"pass@{k}": estimate_pass_at_k(total, np.array(new_correct), k).mean() + for k in [1, 10, 100] + if (total >= k).all() + } + result.update( + {f"{source_dataset}_plus_{key}": value for key, value in pass_at_k.items()} + ) + for k, v in pass_at_k.items(): + cprint(f"{k}:\t{v:.3f}", "green") + + return result, sample_details diff --git a/evaluation/fishfarm/fishfarm/tasks/evalplus/generation.py b/evaluation/fishfarm/fishfarm/tasks/evalplus/generation.py new file mode 100644 index 0000000..855369f --- /dev/null +++ b/evaluation/fishfarm/fishfarm/tasks/evalplus/generation.py @@ -0,0 +1,77 @@ +import itertools +from pathlib import Path +from typing import Iterable, List, Sequence, TypeVar + +from evalplus.data import write_jsonl +from tqdm.auto import tqdm + +from ...models import GenerationRequest, Message, Model +from .data import TextToCodeProblem + +_T = TypeVar("_T") + + +def chunked(seq: Sequence[_T], n: int) -> Iterable[Sequence[_T]]: + """Yield successive n-sized chunks from seq.""" + return (seq[i : i + n] for i in range(0, len(seq), n)) + + +def generate( + model: Model, + problems: list[TextToCodeProblem], + context_messages: Sequence[Message], + output_path: str, + n_batches: int = 1, + n_problems_per_batch: int = 1_000_000_000, + n_samples_per_problem: int = 1, +) -> List[str]: + problems_chunked = list(chunked(list(problems), n_problems_per_batch)) + iter = itertools.product(problems_chunked, range(n_batches)) + n_total = len(problems_chunked) * n_batches + + Path(output_path).write_text("") + for problems, batch_idx in tqdm(iter, 
total=n_total): + task_ids = [problem.id for problem in problems] + all_task_ids = task_ids * n_samples_per_problem + + requests = [] + for problem in problems: + messages = list(context_messages) + messages.append(Message(role="user", content=problem.instruction)) + messages.append( + Message(role="assistant_prefill", content=problem.response_prefix) + ) + requests.append(GenerationRequest(messages=messages)) + completes = model.generate(requests) + completions = [c.generation for c in completes] + + assert len(problems) <= n_problems_per_batch + assert len(completions) == len(problems) * n_samples_per_problem + + samples = [] + for task_id, completion in zip(all_task_ids, completions): + completion_body = completion[ + : ( + index + if (index := completion.find("```")) != -1 + else len(completion) + ) + ] + explanation = completion[ + ( + index + if (index := completion.find("```") + 3) != -1 + else len(completion) + ) : + ].strip() + + samples.append( + dict( + task_id=task_id, + completion=completion_body, + explanation=explanation, + ) + ) + + write_jsonl(output_path, samples, append=True) + return completions diff --git a/evaluation/fishfarm/fishfarm/tasks/evalplus/sanitization.py b/evaluation/fishfarm/fishfarm/tasks/evalplus/sanitization.py new file mode 100644 index 0000000..7ea6946 --- /dev/null +++ b/evaluation/fishfarm/fishfarm/tasks/evalplus/sanitization.py @@ -0,0 +1,195 @@ +import ast +import os +import pathlib +import re +import traceback +from typing import Optional + +from evalplus.data import (get_human_eval_plus, get_mbpp_plus, load_solutions, + write_directory, write_jsonl) +from tqdm.auto import tqdm + +from ...logging import get_logger + +logger = get_logger(__name__) + + +def syntax_check(code: str, verbose: bool = False) -> bool: + try: + ast.parse(code) + return True + except (SyntaxError, MemoryError): + if verbose: + traceback.print_exc() + return False + + +def remove_unindented_lines( + code: str, protect_before: str, execeptions: list[str], trim_tails: list[str] +) -> str: + lines = code.splitlines() + cut_idx = [] + cut_enabled = False + for i, line in enumerate(lines): + if not cut_enabled and line.startswith(protect_before): + cut_enabled = True + continue + if line.strip() == "": + continue + if any(line.startswith(e) for e in execeptions): + continue + + lspace = len(line) - len(line.lstrip()) + if lspace == 0: + cut_idx.append(i) + + if any(line.rstrip().startswith(t) for t in trim_tails): + cut_idx.extend(list(range(i, len(lines)))) + break + + return "\n".join([line for i, line in enumerate(lines) if i not in cut_idx]) + + +def to_four_space_indents(old_code: str) -> str: + new_code = "" + for line in old_code.splitlines(): + lspace = len(line) - len(line.lstrip()) + if lspace == 3: + new_code += " " + new_code += line + "\n" + return new_code + + +def sanitize_code( + old_code: str, + entry_point: str, + rm_prefix_lines: Optional[str] = None, + eofs: list = [], +) -> str: + new_code = old_code + if rm_prefix_lines is not None: + new_code = "\n".join( + [ + line + for line in old_code.splitlines() + if not line.startswith(rm_prefix_lines) + ] + ) + + new_code = "\n" + new_code + def_left = "def " + entry_point + + new_code = new_code.replace("\n```python\n", "\n```\n") + for chunk in new_code.split("\n```\n"): + if def_left in chunk: + new_code = chunk + break + + chunks = [chunk for chunk in re.split(rf"{def_left}\s*\(", new_code)] + bodies = [chunk for chunk in chunks[1:] if " return " in chunk.split("\ndef")[0]] + def_left = def_left + "(" + 
new_code = def_left + def_left.join(bodies) if len(bodies) > 0 else "" + new_code = to_four_space_indents(new_code) + + for eof in eofs or []: + new_code = new_code.split(eof)[0] + + new_code = remove_unindented_lines( + new_code, + protect_before=def_left, + execeptions=["def ", "import ", "from "], + trim_tails=['"""', "if", "print"], + ) + new_code = chunks[0] + new_code + + parts = new_code.split("\ndef ") + includes = [parts[0]] + for fn in new_code.split("\ndef ")[1:]: + if ( + fn.strip().startswith(entry_point + " ") + or fn.strip().startswith(entry_point + "(") + or syntax_check("\ndef " + fn) + ): + includes.append(fn) + new_code = "\ndef ".join(includes) + return new_code.strip() + + +def sanitize( + source_dataset: str, + input_path: str, + eofs: list = [], + inplace: bool = False, + rm_prefix_lines: Optional[str] = None, + debug_task: Optional[str] = None, +) -> str: + entry_point = {} + + if source_dataset == "humaneval": + dataset = get_human_eval_plus() + elif source_dataset == "mbpp": + dataset = get_mbpp_plus() + + for task_id, problem in dataset.items(): + entry_point[task_id] = problem["entry_point"] + + is_folder = os.path.isdir(input_path) + target_path = pathlib.Path(input_path) + if not inplace: + if is_folder: + new_name = target_path.name + "-sanitized" + else: + new_name = target_path.name.replace(".jsonl", "-sanitized.jsonl") + target_path = target_path.parent / new_name + output_path = str(target_path) + + nsan = 0 + ntotal = 0 + + new_solutions = [] + + for solution in tqdm(load_solutions(input_path)): + task_id = solution["task_id"] + dbg_identifier = solution["_identifier"] + if debug_task is not None and task_id != debug_task: + continue + + ntotal += 1 + if "solution" in solution: + old_code = solution["solution"] + else: + assert "completion" in solution + old_code = dataset[task_id]["prompt"] + "\n" + solution["completion"] + + old_code = old_code.strip() + + new_code = sanitize_code( + old_code=old_code, + entry_point=entry_point[task_id], + rm_prefix_lines=rm_prefix_lines, + eofs=eofs, + ).strip() + + if new_code != old_code: + msg = "Sanitized: " + dbg_identifier + if is_folder: + msg += " -> " + dbg_identifier.replace(input_path, output_path) + logger.info(msg) + nsan += 1 + + new_solutions.append( + { + "task_id": task_id, + "solution": new_code, + "explanation": solution["explanation"], + } + ) + + if is_folder: + write_directory(output_path, new_solutions) + else: + write_jsonl(output_path, new_solutions) + + logger.info(f"Sanitized {nsan} out of {ntotal} files.") + + return output_path diff --git a/evaluation/fishfarm/fishfarm/tasks/evalplus/task.py b/evaluation/fishfarm/fishfarm/tasks/evalplus/task.py new file mode 100644 index 0000000..7bc4fb3 --- /dev/null +++ b/evaluation/fishfarm/fishfarm/tasks/evalplus/task.py @@ -0,0 +1,54 @@ +import tempfile +from typing import Literal, Optional, Sequence + +from ...models import Message, Model +from ..base import Task, TaskResult +from . 
import evaluation, generation, sanitization +from .data import TextToCodeProblem + + +class EvalplusTask(Task): + + def __init__( + self, + samples: Sequence[TextToCodeProblem], + context_messages: Sequence[Message] = (), + source_dataset: Literal["humaneval", "mbpp"] = "humaneval", + ): + self.samples = list(samples) + self.context_messages = context_messages + self.source_dataset = source_dataset + if source_dataset not in ("humaneval", "mbpp"): + raise ValueError(f"Unknown source_dataset: {source_dataset}") + + @property + def num_samples(self) -> int: + return len(self.samples) + + def evaluate( + self, + model: Model, + sample_ids: Optional[Sequence[int]] = None, + ) -> TaskResult: + if sample_ids is None: + sample_ids = range(len(self.samples)) + samples = [self.samples[sample_id] for sample_id in sample_ids] + + with tempfile.TemporaryDirectory() as save_dir: + output_path = f"{save_dir}/outputs.jsonl" + + completions = generation.generate( + model, samples, self.context_messages, output_path + ) + + if self.source_dataset == "mbpp": + output_path = sanitization.sanitize(self.source_dataset, output_path) + + result, sample_details = evaluation.evaluate( + self.source_dataset, output_path + ) + + for i, completion in enumerate(completions): + sample_details[i]["output"] = completion + + return TaskResult(aggregate_metrics=result, sample_details=sample_details) diff --git a/evaluation/fishfarm/fishfarm/tasks/language_restricted_math.py b/evaluation/fishfarm/fishfarm/tasks/language_restricted_math.py new file mode 100644 index 0000000..3479094 --- /dev/null +++ b/evaluation/fishfarm/fishfarm/tasks/language_restricted_math.py @@ -0,0 +1,106 @@ +import re +from dataclasses import dataclass +from typing import Iterable, Optional, Sequence + +import huggingface_hub + +from ..imports import try_import +from ..models import GenerationRequest, Message, Model +from .base import Task, TaskResult + +with try_import() as _imports: + import fasttext + +_imports.check() + + +@dataclass +class MathSample: + + problem: str + answer: int + + +def mean(iterable: Iterable[float]) -> float: + total, count = 0.0, 0 + for x in iterable: + total += x + count += 1 + return total / count + + +def extract_answer_number(completion: str) -> Optional[float]: + matches = re.findall(r"\d*\.?\d+", completion) + if not matches: + return None + text = matches[-1] + return float(text.replace(",", "")) + + +class LanguageRestrictedMathTask(Task): + def __init__( + self, + samples: Sequence[MathSample], + context_messages: Sequence[Message] = (), + languages: Sequence[str] = ("ja", "en"), + ): + self.samples = list(samples) + self.languages = languages + self.context_messages = context_messages + if len(self.languages) != 0: + lid176ftz_path = huggingface_hub.hf_hub_download( + "julien-c/fasttext-language-id", "lid.176.ftz" + ) + self.lid_model = fasttext.load_model(lid176ftz_path) + + @property + def num_samples(self) -> int: + return len(self.samples) + + def evaluate( + self, + model: Model, + sample_ids: Optional[Sequence[int]] = None, + ) -> TaskResult: + if sample_ids is None: + sample_ids = range(len(self.samples)) + samples = [self.samples[sample_id] for sample_id in sample_ids] + + requests = [] + for sample in samples: + messages = list(self.context_messages) + messages.append(Message(role="user", content=sample.problem)) + requests.append(GenerationRequest(messages=messages)) + + sample_details = [] + for sample, result in zip(samples, model.generate(requests)): + output = result.generation + prediction = 
extract_answer_number(result.generation) + if len(self.languages) != 0: + lid_probs = dict( + zip(*self.lid_model.predict(output.replace("\n", ""), k=-1)) + ) + + sample_details.append( + dict( + problem=sample.problem, + output=output, + answer=sample.answer, + prediction=prediction, + correct=sample.answer == prediction, + **{ + f"lang_{lang}": lid_probs.get(f"__label__{lang}", 0.0) + for lang in self.languages + }, + ) + ) + + aggregate_metrics = {"acc": mean(sd["correct"] for sd in sample_details)} + for lang in self.languages: + aggregate_metrics[f"acc_{lang}"] = mean( + (sd["correct"] and sd[f"lang_{lang}"] > 0.5) for sd in sample_details + ) + + return TaskResult( + aggregate_metrics=aggregate_metrics, sample_details=sample_details + ) diff --git a/evaluation/fishfarm/fishfarm/version.py b/evaluation/fishfarm/fishfarm/version.py new file mode 100644 index 0000000..4494976 --- /dev/null +++ b/evaluation/fishfarm/fishfarm/version.py @@ -0,0 +1 @@ +__version__ = "0.1.0dev" diff --git a/evaluation/fishfarm/pyproject.toml b/evaluation/fishfarm/pyproject.toml new file mode 100644 index 0000000..8375b35 --- /dev/null +++ b/evaluation/fishfarm/pyproject.toml @@ -0,0 +1,109 @@ +[project] +name = "fishfarm" +description = "" +readme = "README.md" +license = {file = "LICENSE"} +authors = [ + {name = "Takuya Akiba"}, + {email = "takiba@sakana.ai"} +] +classifiers = [ + "Development Status :: 2 - Pre-Alpha", + "Intended Audience :: Science/Research", + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3 :: Only", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Mathematics", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", +] +requires-python = ">=3.10" +dependencies = [ + "huggingface_hub", + "transformers", + "pydantic", + "colorlog" +] +dynamic = ["version"] + +[project.optional-dependencies] +development = [ + "black", + "blackdoc", + "flake8", + "isort", + "mypy", + "pytest", + "pytest-mock", + "types-PyYAML", +] + +full = [ + "vllm", + "langchain", + "langchain-openai", + "fasttext-wheel", + "datasets", + "mysql-connector-python==8.0.32", + "docker==6.1.2", + "evalplus @ git+https://github.com/evalplus/evalplus@1895d2f6aa8895044a7cf69defc24bd57695e885", + "rouge-score" +] + +[project.urls] +repository = "https://github.com/SakanaAI/fishfarm" + +[tool.setuptools.packages.find] +include = ["fishfarm*"] + +[tool.setuptools.dynamic] +version = {attr = "fishfarm.version.__version__"} + +[tool.black] +line-length = 99 +target-version = ['py310'] +exclude = ''' +/( + \.eggs + | \.git + | \.hg + | \.mypy_cache + | \.venv + | venv + | _build + | buck-out + | build + | dist + | docs + | data +)/ +''' + +[tool.isort] +profile = 'black' +src_paths = ['fishfarm', 'tests'] +line_length = 99 +lines_after_imports = 2 + +[tool.mypy] +python_version = "3.10" +strict = true +ignore_missing_imports = true +warn_unused_configs = true +disallow_untyped_defs = true +warn_redundant_casts = true +warn_unused_ignores = true +warn_unreachable = true +disallow_any_generics = false +exclude = ".venv|venv|build|docs|tutorial|data" + +[tool.pytest] 
+mock_use_standalone_module = true diff --git a/evaluation/fishfarm/tox.ini b/evaluation/fishfarm/tox.ini new file mode 100644 index 0000000..68b3967 --- /dev/null +++ b/evaluation/fishfarm/tox.ini @@ -0,0 +1,8 @@ +[flake8] +max-line-length = 99 +statistics = True +exclude = .venv,venv,build,notebooks,.asv,data +ignore = + E203, + W503, + E704 \ No newline at end of file
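
The sketches below illustrate the intended use of the APIs added in this patch; checkpoint and module names are examples only. First, the deferred-import helper in fishfarm/imports.py, mirroring how models/vllm_model.py and tasks/language_restricted_math.py wrap their optional dependencies:

from fishfarm.imports import try_import

with try_import() as _imports:
    import vllm  # noqa: F401  -- optional dependency; ImportError/SyntaxError is deferred, not raised here

if _imports.is_successful():
    print("vllm is available")
else:
    _imports.check()  # raises ImportError with a message explaining what failed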
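
The logging helpers follow the Optuna-style library-root-logger setup; a minimal sketch of the intended calls (the logger name is illustrative, but its `fishfarm.` prefix is required for the root-logger settings to apply):

from fishfarm import logging as fishfarm_logging

logger = fishfarm_logging.get_logger("fishfarm.my_module")  # the "fishfarm." prefix is required
logger.info("hello from fishfarm")

fishfarm_logging.set_verbosity(fishfarm_logging.DEBUG)  # level of the library root logger
fishfarm_logging.disable_default_handler()              # let the application own the handlers
fishfarm_logging.enable_propagation()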
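
Generation through the vLLM backend ties together Message, GenerationRequest, the LLAMA3 chat template, and VLLMModel. A sketch assuming the `full` extras are installed; the checkpoint name is an example, not a requirement:

import vllm

from fishfarm import Message
from fishfarm.chat_templates import LLAMA3
from fishfarm.models import GenerationRequest
from fishfarm.models.vllm_model import VLLMModel

llm = vllm.LLM(model="meta-llama/Meta-Llama-3-8B-Instruct")  # example checkpoint
sampling_params = vllm.SamplingParams(temperature=0.0, max_tokens=512)
model = VLLMModel(llm=llm, sampling_params=sampling_params, chat_template=LLAMA3)

request = GenerationRequest(
    messages=[
        Message(role="system", content="You are a helpful assistant."),
        Message(role="user", content="What is 2 + 2?"),
        # an optional trailing assistant_prefill message seeds the assistant's reply
        Message(role="assistant_prefill", content="The answer is"),
    ]
)
for result in model.generate([request]):
    print(result.generation)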
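
Tasks consume a Model and return a TaskResult; a sketch with one hand-written AI2-ARC-style sample, reusing the `model` object from the generation sketch above:

from fishfarm import Message
from fishfarm.tasks.ai2_arc import Ai2ArcSample, Ai2ArcTask

arc_task = Ai2ArcTask(
    samples=[
        Ai2ArcSample(
            question=(
                "Which gas do plants absorb from the atmosphere? "
                "(A) oxygen (B) carbon dioxide (C) nitrogen (D) helium"
            ),
            question_id="example-0",
            options=["A", "B", "C", "D"],
            answer="B",
        )
    ],
    context_messages=[Message(role="system", content="Answer with 'The answer is (X)'.")],
)
arc_result = arc_task.evaluate(model)  # `model` from the previous sketch
print(arc_result.aggregate_metrics["acc"], arc_result.sample_details[0]["prediction"])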
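
The NLL path scores only assistant tokens: tokenize_messages marks assistant spans as True in MaskedTokens.mask, and VLLMModel.nll sums -log p over exactly those positions. A sketch of both entry points, assuming a tokenizer whose chat template keeps earlier turns as a text prefix (the tokenizer name is again only an example):

from transformers import AutoTokenizer

from fishfarm import Message
from fishfarm.models import NLLRequest
from fishfarm.models.tokenization_utils import tokenize_messages

nll_messages = [
    Message(role="user", content="What is 2 + 2?"),
    Message(role="assistant", content="4"),
]

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")  # example
masked = tokenize_messages(nll_messages, tokenizer, chat_template=None)
print(sum(masked.mask), "of", len(masked.token_ids), "tokens are scored")

nll_result = list(model.nll([NLLRequest(messages=nll_messages)]))[0]  # `model` as above
print(nll_result.sum_nll, nll_result.num_considered_tokens)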
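
The evalplus wrapper generates solutions, sanitizes MBPP outputs, and scores pass@k end to end; a sketch assuming the `evalplus` dependency from the `full` extras is installed:

from fishfarm.tasks.evalplus import EvalplusTask, load_dataset

coding_task = EvalplusTask(
    samples=load_dataset("humaneval"),
    context_messages=[],
    source_dataset="humaneval",
)
coding_result = coding_task.evaluate(model)  # `model` as above
print(coding_result.aggregate_metrics)  # keys like "humaneval_base_pass@1"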
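
New benchmarks subclass Task from tasks/base.py and implement num_samples and evaluate. A toy, hypothetical EchoTask showing the required shape:

from typing import Optional, Sequence

from fishfarm import Model, Task, TaskResult
from fishfarm.models import GenerationRequest, Message


class EchoTask(Task):
    """Toy task: a sample counts as correct if the model echoes the prompt back."""

    def __init__(self, prompts: Sequence[str]) -> None:
        self.prompts = list(prompts)

    @property
    def num_samples(self) -> int:
        return len(self.prompts)

    def evaluate(
        self, model: Model, sample_ids: Optional[Sequence[int]] = None
    ) -> TaskResult:
        if sample_ids is None:
            sample_ids = range(len(self.prompts))
        prompts = [self.prompts[i] for i in sample_ids]
        requests = [
            GenerationRequest(messages=[Message(role="user", content=p)]) for p in prompts
        ]
        details = [
            dict(prompt=p, output=r.generation, correct=p in r.generation)
            for p, r in zip(prompts, model.generate(requests))
        ]
        acc = sum(d["correct"] for d in details) / max(len(details), 1)
        return TaskResult(aggregate_metrics={"acc": acc}, sample_details=details)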