-
Notifications
You must be signed in to change notification settings - Fork 110
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
b4a22aa
commit ade0e67
Showing
22 changed files
with
2,034 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
from . import chat_templates, models, tasks | ||
from .models import Message, Model, Role | ||
from .tasks import Task, TaskResult | ||
|
||
__all__ = [ | ||
"chat_templates", | ||
"tasks", | ||
"models", | ||
"Task", | ||
"TaskResult", | ||
"Model", | ||
"Message", | ||
"Role", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
LLAMA3 = ( | ||
"{% set loop_messages = messages %}" | ||
"{% for message in loop_messages %}" | ||
"{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>" | ||
"\n\n'+ message['content'] | trim + '<|eot_id|>' %}" | ||
"{% if loop.index0 == 0 %}{% set content = bos_token + content %}" | ||
"{% endif %}" | ||
"{{ content }}" | ||
"{% endfor %}" | ||
"{% if add_generation_prompt %}" | ||
"{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}" | ||
"{% endif %}" | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
from types import TracebackType | ||
from typing import Optional, Tuple, Type | ||
|
||
|
||
class _DeferredImportExceptionContextManager: | ||
"""Context manager to defer exceptions from imports. | ||
Catches :exc:`ImportError` and :exc:`SyntaxError`. | ||
If any exception is caught, this class raises an :exc:`ImportError` when being checked. | ||
""" | ||
|
||
def __init__(self) -> None: | ||
self._deferred: Optional[Tuple[Exception, str]] = None | ||
|
||
def __enter__(self) -> "_DeferredImportExceptionContextManager": | ||
"""Enter the context manager. | ||
Returns: | ||
Itself. | ||
""" | ||
return self | ||
|
||
def __exit__( | ||
self, | ||
exc_type: Optional[Type[Exception]], | ||
exc_value: Optional[Exception], | ||
traceback: Optional[TracebackType], | ||
) -> Optional[bool]: | ||
"""Exit the context manager. | ||
Args: | ||
exc_type: | ||
Raised exception type. :obj:`None` if nothing is raised. | ||
exc_value: | ||
Raised exception object. :obj:`None` if nothing is raised. | ||
traceback: | ||
Associated traceback. :obj:`None` if nothing is raised. | ||
Returns: | ||
:obj:`None` if nothing is deferred, otherwise :obj:`True`. | ||
:obj:`True` will suppress any exceptions avoiding them from propagating. | ||
""" | ||
if isinstance(exc_value, (ImportError, SyntaxError)): | ||
if isinstance(exc_value, ImportError): | ||
message = ( | ||
"Tried to import '{}' but failed. Please make sure that the package is " | ||
"installed correctly to use this feature. Actual error: {}." | ||
).format(exc_value.name, exc_value) | ||
elif isinstance(exc_value, SyntaxError): | ||
message = ( | ||
"Tried to import a package but failed due to a syntax error in {}. Please " | ||
"make sure that the Python version is correct to use this feature. Actual " | ||
"error: {}." | ||
).format(exc_value.filename, exc_value) | ||
else: | ||
assert False | ||
|
||
self._deferred = (exc_value, message) | ||
return True | ||
return None | ||
|
||
def is_successful(self) -> bool: | ||
"""Return whether the context manager has caught any exceptions. | ||
Returns: | ||
:obj:`True` if no exceptions are caught, :obj:`False` otherwise. | ||
""" | ||
return self._deferred is None | ||
|
||
def check(self) -> None: | ||
"""Check whether the context manager has caught any exceptions. | ||
Raises: | ||
:exc:`ImportError`: | ||
If any exception was caught from the caught exception. | ||
""" | ||
if self._deferred is not None: | ||
exc_value, message = self._deferred | ||
raise ImportError(message) from exc_value | ||
|
||
|
||
def try_import() -> _DeferredImportExceptionContextManager: | ||
"""Create a context manager that can wrap imports of optional packages to defer exceptions. | ||
Returns: | ||
Deferred import context manager. | ||
""" | ||
return _DeferredImportExceptionContextManager() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,190 @@ | ||
""" | ||
Copied from Optuna repo: | ||
https://github.com/optuna/optuna/blob/2595653638506e1b7e025a966a220984a59ab936/optuna/logging.py | ||
Removed some comments for less verbosity. | ||
In general, `logger.info` is preferred over `print` since it contains module name and timestamp; | ||
We recommend the use of logger object for the fishfarm developers. | ||
Inside fishfarm, we can call `get_logger(__name__)` from each python file. | ||
Then the root logger format and level are applied to that logger object. | ||
""" | ||
|
||
from __future__ import annotations | ||
|
||
import logging | ||
import os | ||
import sys | ||
import threading | ||
from logging import CRITICAL, DEBUG, ERROR, FATAL, INFO, WARN, WARNING | ||
|
||
import colorlog | ||
|
||
__all__ = [ | ||
"CRITICAL", | ||
"DEBUG", | ||
"ERROR", | ||
"FATAL", | ||
"INFO", | ||
"WARN", | ||
"WARNING", | ||
] | ||
|
||
_lock: threading.Lock = threading.Lock() | ||
_default_handler: logging.Handler | None = None | ||
|
||
|
||
def create_default_formatter() -> logging.Formatter: | ||
"""Create a default formatter of log messages. | ||
This function is not supposed to be directly accessed by library users. | ||
""" | ||
header = "[%(levelname)1.1s %(asctime)s %(name)s]" | ||
message = "%(message)s" | ||
if _color_supported(): | ||
return colorlog.ColoredFormatter( | ||
f"%(log_color)s{header}%(reset)s {message}", | ||
) | ||
return logging.Formatter(f"{header} {message}") | ||
|
||
|
||
def _color_supported() -> bool: | ||
"""Detection of color support.""" | ||
# NO_COLOR environment variable: | ||
if os.environ.get("NO_COLOR", None): | ||
return False | ||
|
||
if not hasattr(sys.stderr, "isatty") or not sys.stderr.isatty(): | ||
return False | ||
else: | ||
return True | ||
|
||
|
||
def _get_library_name() -> str: | ||
return __name__.split(".")[0] | ||
|
||
|
||
def _get_library_root_logger() -> logging.Logger: | ||
return logging.getLogger(_get_library_name()) | ||
|
||
|
||
def _configure_library_root_logger() -> None: | ||
global _default_handler | ||
|
||
with _lock: | ||
if _default_handler: | ||
# This library has already configured the library root logger. | ||
return | ||
_default_handler = logging.StreamHandler() # Set sys.stderr as stream. | ||
_default_handler.setFormatter(create_default_formatter()) | ||
|
||
# Apply our default configuration to the library root logger. | ||
library_root_logger: logging.Logger = _get_library_root_logger() | ||
library_root_logger.addHandler(_default_handler) | ||
library_root_logger.setLevel(logging.INFO) | ||
library_root_logger.propagate = False | ||
|
||
|
||
def _reset_library_root_logger() -> None: | ||
global _default_handler | ||
|
||
with _lock: | ||
if not _default_handler: | ||
return | ||
|
||
library_root_logger: logging.Logger = _get_library_root_logger() | ||
library_root_logger.removeHandler(_default_handler) | ||
library_root_logger.setLevel(logging.NOTSET) | ||
_default_handler = None | ||
|
||
|
||
def get_logger(name: str) -> logging.Logger: | ||
"""Return a logger with the specified name. | ||
name's prefix should be `fishfarm.` (just like __name__ variable), | ||
otherwise root logger settings will be not reflected. | ||
This function is not supposed to be directly accessed by library users. | ||
""" | ||
|
||
_configure_library_root_logger() | ||
return logging.getLogger(name) | ||
|
||
|
||
def get_verbosity() -> int: | ||
"""Return the current level for the fishfarm's root logger. | ||
Returns: | ||
Logging level, e.g., ``fishfarm.logging.DEBUG`` and ``fishfarm.logging.INFO``. | ||
.. note:: | ||
fishfarm has following logging levels: | ||
- ``fishfarm.logging.CRITICAL``, ``fishfarm.logging.FATAL`` | ||
- ``fishfarm.logging.ERROR`` | ||
- ``fishfarm.logging.WARNING``, ``fishfarm.logging.WARN`` | ||
- ``fishfarm.logging.INFO`` | ||
- ``fishfarm.logging.DEBUG`` | ||
""" | ||
|
||
_configure_library_root_logger() | ||
return _get_library_root_logger().getEffectiveLevel() | ||
|
||
|
||
def set_verbosity(verbosity: int) -> None: | ||
"""Set the level for the fishfarm's root logger. | ||
Args: | ||
verbosity: | ||
Logging level, e.g., ``fishfarm.logging.DEBUG`` and ``fishfarm.logging.INFO``. | ||
.. note:: | ||
fishfarm has following logging levels: | ||
- ``fishfarm.logging.CRITICAL``, ``fishfarm.logging.FATAL`` | ||
- ``fishfarm.logging.ERROR`` | ||
- ``fishfarm.logging.WARNING``, ``fishfarm.logging.WARN`` | ||
- ``fishfarm.logging.INFO`` | ||
- ``fishfarm.logging.DEBUG`` | ||
""" | ||
|
||
_configure_library_root_logger() | ||
_get_library_root_logger().setLevel(verbosity) | ||
|
||
|
||
def disable_default_handler() -> None: | ||
"""Disable the default handler of the fishfarm's root logger.""" | ||
|
||
_configure_library_root_logger() | ||
|
||
assert _default_handler is not None | ||
_get_library_root_logger().removeHandler(_default_handler) | ||
|
||
|
||
def enable_default_handler() -> None: | ||
"""Enable the default handler of the fishfarm's root logger.""" | ||
|
||
_configure_library_root_logger() | ||
|
||
assert _default_handler is not None | ||
_get_library_root_logger().addHandler(_default_handler) | ||
|
||
|
||
def disable_propagation() -> None: | ||
"""Disable propagation of the library log outputs. | ||
Note that log propagation is disabled by default. You only need to use this function | ||
to stop log propagation when you use :func:`~fishfarm.logging.enable_propagation()`. | ||
""" | ||
|
||
_configure_library_root_logger() | ||
_get_library_root_logger().propagate = False | ||
|
||
|
||
def enable_propagation() -> None: | ||
"""Enable propagation of the library log outputs. | ||
Please disable the fishfarm's default handler to prevent double logging if the root logger has | ||
been configured. | ||
""" | ||
|
||
_configure_library_root_logger() | ||
_get_library_root_logger().propagate = True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
from .base import (GenerationRequest, GenerationResult, Message, Model, | ||
NLLRequest, NLLResult, Role) | ||
|
||
__all__ = [ | ||
"GenerationRequest", | ||
"GenerationResult", | ||
"NLLRequest", | ||
"NLLResult", | ||
"Model", | ||
"Role", | ||
"Message", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
from __future__ import annotations | ||
|
||
from dataclasses import dataclass | ||
from typing import Iterable, Literal, Optional, Sequence | ||
|
||
Role = Literal["system", "user", "assistant", "assistant_prefill"] | ||
|
||
|
||
@dataclass | ||
class Message: | ||
|
||
role: Role | ||
content: str | ||
|
||
|
||
@dataclass | ||
class GenerationRequest: | ||
|
||
messages: list[Message] | ||
|
||
max_tokens: Optional[int] = None | ||
stop: Sequence[str] = () | ||
|
||
|
||
@dataclass | ||
class GenerationResult: | ||
|
||
request: GenerationRequest | ||
generation: str | ||
|
||
|
||
@dataclass | ||
class NLLRequest: | ||
|
||
messages: list[Message] | ||
|
||
|
||
@dataclass | ||
class NLLResult: | ||
|
||
request: NLLRequest | ||
sum_nll: float | ||
num_considered_tokens: int | ||
|
||
|
||
class Model: | ||
|
||
def generate( | ||
self, requests: Sequence[GenerationRequest] | ||
) -> Iterable[GenerationResult]: | ||
raise NotImplementedError() | ||
|
||
def nll(self, requests: Sequence[NLLRequest]) -> Iterable[NLLResult]: | ||
raise NotImplementedError() |
Oops, something went wrong.