Skip to content

Commit

Permalink
add evaluation package
Browse files Browse the repository at this point in the history
  • Loading branch information
floatingbigcat committed Jan 14, 2025
1 parent b4a22aa commit ade0e67
Show file tree
Hide file tree
Showing 22 changed files with 2,034 additions and 0 deletions.
14 changes: 14 additions & 0 deletions evaluation/fishfarm/fishfarm/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from . import chat_templates, models, tasks
from .models import Message, Model, Role
from .tasks import Task, TaskResult

__all__ = [
"chat_templates",
"tasks",
"models",
"Task",
"TaskResult",
"Model",
"Message",
"Role",
]
13 changes: 13 additions & 0 deletions evaluation/fishfarm/fishfarm/chat_templates.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
LLAMA3 = (
"{% set loop_messages = messages %}"
"{% for message in loop_messages %}"
"{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>"
"\n\n'+ message['content'] | trim + '<|eot_id|>' %}"
"{% if loop.index0 == 0 %}{% set content = bos_token + content %}"
"{% endif %}"
"{{ content }}"
"{% endfor %}"
"{% if add_generation_prompt %}"
"{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}"
"{% endif %}"
)
94 changes: 94 additions & 0 deletions evaluation/fishfarm/fishfarm/imports.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
from types import TracebackType
from typing import Optional, Tuple, Type


class _DeferredImportExceptionContextManager:
"""Context manager to defer exceptions from imports.
Catches :exc:`ImportError` and :exc:`SyntaxError`.
If any exception is caught, this class raises an :exc:`ImportError` when being checked.
"""

def __init__(self) -> None:
self._deferred: Optional[Tuple[Exception, str]] = None

def __enter__(self) -> "_DeferredImportExceptionContextManager":
"""Enter the context manager.
Returns:
Itself.
"""
return self

def __exit__(
self,
exc_type: Optional[Type[Exception]],
exc_value: Optional[Exception],
traceback: Optional[TracebackType],
) -> Optional[bool]:
"""Exit the context manager.
Args:
exc_type:
Raised exception type. :obj:`None` if nothing is raised.
exc_value:
Raised exception object. :obj:`None` if nothing is raised.
traceback:
Associated traceback. :obj:`None` if nothing is raised.
Returns:
:obj:`None` if nothing is deferred, otherwise :obj:`True`.
:obj:`True` will suppress any exceptions avoiding them from propagating.
"""
if isinstance(exc_value, (ImportError, SyntaxError)):
if isinstance(exc_value, ImportError):
message = (
"Tried to import '{}' but failed. Please make sure that the package is "
"installed correctly to use this feature. Actual error: {}."
).format(exc_value.name, exc_value)
elif isinstance(exc_value, SyntaxError):
message = (
"Tried to import a package but failed due to a syntax error in {}. Please "
"make sure that the Python version is correct to use this feature. Actual "
"error: {}."
).format(exc_value.filename, exc_value)
else:
assert False

self._deferred = (exc_value, message)
return True
return None

def is_successful(self) -> bool:
"""Return whether the context manager has caught any exceptions.
Returns:
:obj:`True` if no exceptions are caught, :obj:`False` otherwise.
"""
return self._deferred is None

def check(self) -> None:
"""Check whether the context manager has caught any exceptions.
Raises:
:exc:`ImportError`:
If any exception was caught from the caught exception.
"""
if self._deferred is not None:
exc_value, message = self._deferred
raise ImportError(message) from exc_value


def try_import() -> _DeferredImportExceptionContextManager:
"""Create a context manager that can wrap imports of optional packages to defer exceptions.
Returns:
Deferred import context manager.
"""
return _DeferredImportExceptionContextManager()
190 changes: 190 additions & 0 deletions evaluation/fishfarm/fishfarm/logging.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
"""
Copied from Optuna repo:
https://github.com/optuna/optuna/blob/2595653638506e1b7e025a966a220984a59ab936/optuna/logging.py
Removed some comments for less verbosity.
In general, `logger.info` is preferred over `print` since it contains module name and timestamp;
We recommend the use of logger object for the fishfarm developers.
Inside fishfarm, we can call `get_logger(__name__)` from each python file.
Then the root logger format and level are applied to that logger object.
"""

from __future__ import annotations

import logging
import os
import sys
import threading
from logging import CRITICAL, DEBUG, ERROR, FATAL, INFO, WARN, WARNING

import colorlog

__all__ = [
"CRITICAL",
"DEBUG",
"ERROR",
"FATAL",
"INFO",
"WARN",
"WARNING",
]

_lock: threading.Lock = threading.Lock()
_default_handler: logging.Handler | None = None


def create_default_formatter() -> logging.Formatter:
"""Create a default formatter of log messages.
This function is not supposed to be directly accessed by library users.
"""
header = "[%(levelname)1.1s %(asctime)s %(name)s]"
message = "%(message)s"
if _color_supported():
return colorlog.ColoredFormatter(
f"%(log_color)s{header}%(reset)s {message}",
)
return logging.Formatter(f"{header} {message}")


def _color_supported() -> bool:
"""Detection of color support."""
# NO_COLOR environment variable:
if os.environ.get("NO_COLOR", None):
return False

if not hasattr(sys.stderr, "isatty") or not sys.stderr.isatty():
return False
else:
return True


def _get_library_name() -> str:
return __name__.split(".")[0]


def _get_library_root_logger() -> logging.Logger:
return logging.getLogger(_get_library_name())


def _configure_library_root_logger() -> None:
global _default_handler

with _lock:
if _default_handler:
# This library has already configured the library root logger.
return
_default_handler = logging.StreamHandler() # Set sys.stderr as stream.
_default_handler.setFormatter(create_default_formatter())

# Apply our default configuration to the library root logger.
library_root_logger: logging.Logger = _get_library_root_logger()
library_root_logger.addHandler(_default_handler)
library_root_logger.setLevel(logging.INFO)
library_root_logger.propagate = False


def _reset_library_root_logger() -> None:
global _default_handler

with _lock:
if not _default_handler:
return

library_root_logger: logging.Logger = _get_library_root_logger()
library_root_logger.removeHandler(_default_handler)
library_root_logger.setLevel(logging.NOTSET)
_default_handler = None


def get_logger(name: str) -> logging.Logger:
"""Return a logger with the specified name.
name's prefix should be `fishfarm.` (just like __name__ variable),
otherwise root logger settings will be not reflected.
This function is not supposed to be directly accessed by library users.
"""

_configure_library_root_logger()
return logging.getLogger(name)


def get_verbosity() -> int:
"""Return the current level for the fishfarm's root logger.
Returns:
Logging level, e.g., ``fishfarm.logging.DEBUG`` and ``fishfarm.logging.INFO``.
.. note::
fishfarm has following logging levels:
- ``fishfarm.logging.CRITICAL``, ``fishfarm.logging.FATAL``
- ``fishfarm.logging.ERROR``
- ``fishfarm.logging.WARNING``, ``fishfarm.logging.WARN``
- ``fishfarm.logging.INFO``
- ``fishfarm.logging.DEBUG``
"""

_configure_library_root_logger()
return _get_library_root_logger().getEffectiveLevel()


def set_verbosity(verbosity: int) -> None:
"""Set the level for the fishfarm's root logger.
Args:
verbosity:
Logging level, e.g., ``fishfarm.logging.DEBUG`` and ``fishfarm.logging.INFO``.
.. note::
fishfarm has following logging levels:
- ``fishfarm.logging.CRITICAL``, ``fishfarm.logging.FATAL``
- ``fishfarm.logging.ERROR``
- ``fishfarm.logging.WARNING``, ``fishfarm.logging.WARN``
- ``fishfarm.logging.INFO``
- ``fishfarm.logging.DEBUG``
"""

_configure_library_root_logger()
_get_library_root_logger().setLevel(verbosity)


def disable_default_handler() -> None:
"""Disable the default handler of the fishfarm's root logger."""

_configure_library_root_logger()

assert _default_handler is not None
_get_library_root_logger().removeHandler(_default_handler)


def enable_default_handler() -> None:
"""Enable the default handler of the fishfarm's root logger."""

_configure_library_root_logger()

assert _default_handler is not None
_get_library_root_logger().addHandler(_default_handler)


def disable_propagation() -> None:
"""Disable propagation of the library log outputs.
Note that log propagation is disabled by default. You only need to use this function
to stop log propagation when you use :func:`~fishfarm.logging.enable_propagation()`.
"""

_configure_library_root_logger()
_get_library_root_logger().propagate = False


def enable_propagation() -> None:
"""Enable propagation of the library log outputs.
Please disable the fishfarm's default handler to prevent double logging if the root logger has
been configured.
"""

_configure_library_root_logger()
_get_library_root_logger().propagate = True
12 changes: 12 additions & 0 deletions evaluation/fishfarm/fishfarm/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from .base import (GenerationRequest, GenerationResult, Message, Model,
NLLRequest, NLLResult, Role)

__all__ = [
"GenerationRequest",
"GenerationResult",
"NLLRequest",
"NLLResult",
"Model",
"Role",
"Message",
]
54 changes: 54 additions & 0 deletions evaluation/fishfarm/fishfarm/models/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Iterable, Literal, Optional, Sequence

Role = Literal["system", "user", "assistant", "assistant_prefill"]


@dataclass
class Message:

role: Role
content: str


@dataclass
class GenerationRequest:

messages: list[Message]

max_tokens: Optional[int] = None
stop: Sequence[str] = ()


@dataclass
class GenerationResult:

request: GenerationRequest
generation: str


@dataclass
class NLLRequest:

messages: list[Message]


@dataclass
class NLLResult:

request: NLLRequest
sum_nll: float
num_considered_tokens: int


class Model:

def generate(
self, requests: Sequence[GenerationRequest]
) -> Iterable[GenerationResult]:
raise NotImplementedError()

def nll(self, requests: Sequence[NLLRequest]) -> Iterable[NLLResult]:
raise NotImplementedError()
Loading

0 comments on commit ade0e67

Please sign in to comment.