Skip to content

Commit

Permalink
imporve log
Browse files Browse the repository at this point in the history
  • Loading branch information
hiyouga committed Jan 8, 2025
1 parent d23a988 commit 47e17dd
Show file tree
Hide file tree
Showing 16 changed files with 78 additions and 67 deletions.
2 changes: 1 addition & 1 deletion src/llamafactory/chat/hf_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def __init__(
try:
asyncio.get_event_loop()
except RuntimeError:
logger.warning_once("There is no current event loop, creating a new one.")
logger.warning_rank0_once("There is no current event loop, creating a new one.")
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)

Expand Down
4 changes: 2 additions & 2 deletions src/llamafactory/data/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,12 @@ def merge_dataset(
return all_datasets[0]
elif data_args.mix_strategy == "concat":
if data_args.streaming:
logger.warning_once("The samples between different datasets will not be mixed in streaming mode.")
logger.warning_rank0_once("The samples between different datasets will not be mixed in streaming mode.")

return concatenate_datasets(all_datasets)
elif data_args.mix_strategy.startswith("interleave"):
if not data_args.streaming:
logger.warning_once("We recommend using `mix_strategy=concat` in non-streaming mode.")
logger.warning_rank0_once("We recommend using `mix_strategy=concat` in non-streaming mode.")

return interleave_datasets(
datasets=all_datasets,
Expand Down
7 changes: 3 additions & 4 deletions src/llamafactory/data/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,10 @@

import numpy as np
from datasets import DatasetDict, load_dataset, load_from_disk
from transformers.utils.versions import require_version

from ..extras import logging
from ..extras.constants import FILEEXT2TYPE
from ..extras.misc import has_tokenized_data
from ..extras.misc import check_version, has_tokenized_data
from .aligner import align_dataset
from .data_utils import merge_dataset, split_dataset
from .parser import get_dataset_list
Expand Down Expand Up @@ -84,7 +83,7 @@ def _load_single_dataset(
raise NotImplementedError(f"Unknown load type: {dataset_attr.load_from}.")

if dataset_attr.load_from == "ms_hub":
require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")
check_version("modelscope>=1.11.0", mandatory=True)
from modelscope import MsDataset # type: ignore
from modelscope.utils.config_ds import MS_DATASETS_CACHE # type: ignore

Expand All @@ -103,7 +102,7 @@ def _load_single_dataset(
dataset = dataset.to_hf_dataset()

elif dataset_attr.load_from == "om_hub":
require_version("openmind>=0.8.0", "To fix: pip install openmind>=0.8.0")
check_version("openmind>=0.8.0", mandatory=True)
from openmind import OmDataset # type: ignore
from openmind.utils.hub import OM_DATASETS_CACHE # type: ignore

Expand Down
8 changes: 6 additions & 2 deletions src/llamafactory/data/mm_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,14 @@ def _validate_input(
Validates if this model accepts the input modalities.
"""
if len(images) != 0 and self.image_token is None:
raise ValueError("This model does not support image input.")
raise ValueError(
"This model does not support image input. Please check whether the correct `template` is used."
)

if len(videos) != 0 and self.video_token is None:
raise ValueError("This model does not support video input.")
raise ValueError(
"This model does not support video input. Please check whether the correct `template` is used."
)

def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
r"""
Expand Down
4 changes: 2 additions & 2 deletions src/llamafactory/data/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union

from transformers.utils.versions import require_version
from typing_extensions import override

from ..extras import logging
from ..extras.misc import check_version
from .data_utils import Role
from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
from .mm_plugin import get_mm_plugin
Expand Down Expand Up @@ -365,7 +365,7 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
raise ValueError(f"Template {data_args.template} does not exist.")

if template.mm_plugin.__class__.__name__ != "BasePlugin":
require_version("transformers>=4.45.0", "To fix: pip install transformers>=4.45.0")
check_version("transformers>=4.45.0")

if data_args.train_on_prompt and template.efficient_eos:
raise ValueError("Current template does not support `train_on_prompt`.")
Expand Down
8 changes: 4 additions & 4 deletions src/llamafactory/extras/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def close(self) -> None:

class _Logger(logging.Logger):
r"""
A logger that supports info_rank0 and warning_once.
A logger that supports rank0 logging.
"""

def info_rank0(self, *args, **kwargs) -> None:
Expand All @@ -77,7 +77,7 @@ def info_rank0(self, *args, **kwargs) -> None:
def warning_rank0(self, *args, **kwargs) -> None:
self.warning(*args, **kwargs)

def warning_once(self, *args, **kwargs) -> None:
def warning_rank0_once(self, *args, **kwargs) -> None:
self.warning(*args, **kwargs)


Expand Down Expand Up @@ -163,11 +163,11 @@ def warning_rank0(self: "logging.Logger", *args, **kwargs) -> None:


@lru_cache(None)
def warning_once(self: "logging.Logger", *args, **kwargs) -> None:
def warning_rank0_once(self: "logging.Logger", *args, **kwargs) -> None:
if int(os.getenv("LOCAL_RANK", "0")) == 0:
self.warning(*args, **kwargs)


logging.Logger.info_rank0 = info_rank0
logging.Logger.warning_rank0 = warning_rank0
logging.Logger.warning_once = warning_once
logging.Logger.warning_rank0_once = warning_rank0_once
34 changes: 23 additions & 11 deletions src/llamafactory/extras/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,19 +73,31 @@ def update(self, val, n=1):
self.avg = self.sum / self.count


def check_dependencies() -> None:
def check_version(requirement: str, mandatory: bool = False) -> None:
r"""
Checks the version of the required packages.
Optionally checks the package version.
"""
if os.getenv("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
logger.warning_once("Version checking has been disabled, may lead to unexpected behaviors.")
if os.getenv("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"] and not mandatory:
logger.warning_rank0_once("Version checking has been disabled, may lead to unexpected behaviors.")
return

require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
require_version("datasets>=2.16.0,<=3.1.0", "To fix: pip install datasets>=2.16.0,<=3.1.0")
require_version("accelerate>=0.34.0,<=1.0.1", "To fix: pip install accelerate>=0.34.0,<=1.0.1")
require_version("peft>=0.11.1,<=0.12.0", "To fix: pip install peft>=0.11.1,<=0.12.0")
require_version("trl>=0.8.6,<=0.9.6", "To fix: pip install trl>=0.8.6,<=0.9.6")
if mandatory:
hint = f"To fix: run `pip install {requirement}`."
else:
hint = f"To fix: run `pip install {requirement}` or set `DISABLE_VERSION_CHECK=1` to skip this check."

require_version(requirement, hint)


def check_dependencies() -> None:
r"""
Checks the version of the required packages.
"""
check_version("transformers>=4.41.2,<=4.46.1")
check_version("datasets>=2.16.0,<=3.1.0")
check_version("accelerate>=0.34.0,<=1.0.1")
check_version("peft>=0.11.1,<=0.12.0")
check_version("trl>=0.8.6,<=0.9.6")


def calculate_tps(dataset: Sequence[Dict[str, Any]], metrics: Dict[str, float], stage: Literal["sft", "rm"]) -> float:
Expand Down Expand Up @@ -253,7 +265,7 @@ def try_download_model_from_other_hub(model_args: "ModelArguments") -> str:
return model_args.model_name_or_path

if use_modelscope():
require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")
check_version("modelscope>=1.11.0", mandatory=True)
from modelscope import snapshot_download # type: ignore

revision = "master" if model_args.model_revision == "main" else model_args.model_revision
Expand All @@ -264,7 +276,7 @@ def try_download_model_from_other_hub(model_args: "ModelArguments") -> str:
)

if use_openmind():
require_version("openmind>=0.8.0", "To fix: pip install openmind>=0.8.0")
check_version("openmind>=0.8.0", mandatory=True)
from openmind.utils.hub import snapshot_download # type: ignore

return snapshot_download(
Expand Down
30 changes: 13 additions & 17 deletions src/llamafactory/hparams/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,10 @@
from transformers.trainer_utils import get_last_checkpoint
from transformers.training_args import ParallelMode
from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_available
from transformers.utils.versions import require_version

from ..extras import logging
from ..extras.constants import CHECKPOINT_NAMES
from ..extras.misc import check_dependencies, get_current_device
from ..extras.misc import check_dependencies, check_version, get_current_device
from .data_args import DataArguments
from .evaluation_args import EvaluationArguments
from .finetuning_args import FinetuningArguments
Expand Down Expand Up @@ -124,38 +123,35 @@ def _check_extra_dependencies(
finetuning_args: "FinetuningArguments",
training_args: Optional["TrainingArguments"] = None,
) -> None:
if os.getenv("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
logger.warning_once("Version checking has been disabled, may lead to unexpected behaviors.")
return

if model_args.use_unsloth:
require_version("unsloth", "Please install unsloth: https://github.com/unslothai/unsloth")
check_version("unsloth", mandatory=True)

if model_args.enable_liger_kernel:
require_version("liger-kernel", "To fix: pip install liger-kernel")
check_version("liger-kernel", mandatory=True)

if model_args.mixture_of_depths is not None:
require_version("mixture-of-depth>=1.1.6", "To fix: pip install mixture-of-depth>=1.1.6")
check_version("mixture-of-depth>=1.1.6", mandatory=True)

if model_args.infer_backend == "vllm":
require_version("vllm>=0.4.3,<0.6.7", "To fix: pip install vllm>=0.4.3,<0.6.7")
check_version("vllm>=0.4.3,<0.6.7")
check_version("vllm", mandatory=True)

if finetuning_args.use_galore:
require_version("galore_torch", "To fix: pip install galore_torch")
check_version("galore_torch", mandatory=True)

if finetuning_args.use_badam:
require_version("badam>=1.2.1", "To fix: pip install badam>=1.2.1")
check_version("badam>=1.2.1", mandatory=True)

if finetuning_args.use_adam_mini:
require_version("adam-mini", "To fix: pip install adam-mini")
check_version("adam-mini", mandatory=True)

if finetuning_args.plot_loss:
require_version("matplotlib", "To fix: pip install matplotlib")
check_version("matplotlib", mandatory=True)

if training_args is not None and training_args.predict_with_generate:
require_version("jieba", "To fix: pip install jieba")
require_version("nltk", "To fix: pip install nltk")
require_version("rouge_chinese", "To fix: pip install rouge-chinese")
check_version("jieba", mandatory=True)
check_version("nltk", mandatory=True)
check_version("rouge_chinese", mandatory=True)


def _parse_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _TRAIN_CLS:
Expand Down
6 changes: 3 additions & 3 deletions src/llamafactory/model/model_utils/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
from typing import TYPE_CHECKING

from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available
from transformers.utils.versions import require_version

from ...extras import logging
from ...extras.misc import check_version


if TYPE_CHECKING:
Expand All @@ -35,8 +35,8 @@ def configure_attn_implementation(
if getattr(config, "model_type", None) == "gemma2" and is_trainable:
if model_args.flash_attn == "auto" or model_args.flash_attn == "fa2":
if is_flash_attn_2_available():
require_version("transformers>=4.42.4", "To fix: pip install transformers>=4.42.4")
require_version("flash_attn>=2.6.3", "To fix: pip install flash_attn>=2.6.3")
check_version("transformers>=4.42.4")
check_version("flash_attn>=2.6.3")
if model_args.flash_attn != "fa2":
logger.warning_rank0("Gemma-2 should use flash attention 2, change `flash_attn` to fa2.")
model_args.flash_attn = "fa2"
Expand Down
2 changes: 1 addition & 1 deletion src/llamafactory/model/model_utils/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def _gradient_checkpointing_enable(
if "value" in inspect.signature(self._set_gradient_checkpointing).parameters: # old GC format
self.apply(partial(self._set_gradient_checkpointing, value=True))
self.enable_input_require_grads()
logger.warning_once("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
logger.warning_rank0_once("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
else: # have already enabled input require gradients
self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)

Expand Down
4 changes: 2 additions & 2 deletions src/llamafactory/model/model_utils/longlora.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@
apply_rotary_pos_emb,
repeat_kv,
)
from transformers.utils.versions import require_version

from ...extras import logging
from ...extras.constants import SUPPORTED_CLASS_FOR_S2ATTN
from ...extras.misc import check_version
from ...extras.packages import is_transformers_version_greater_than


Expand Down Expand Up @@ -353,7 +353,7 @@ def shift(state: "torch.Tensor") -> "torch.Tensor":


def _apply_llama_patch() -> None:
require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
check_version("transformers>=4.41.2,<=4.46.1")
LlamaAttention.forward = llama_attention_forward
LlamaFlashAttention2.forward = llama_flash_attention_2_forward
LlamaSdpaAttention.forward = llama_sdpa_attention_forward
Expand Down
5 changes: 3 additions & 2 deletions src/llamafactory/model/model_utils/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@

import torch
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.utils.versions import require_version

from ...extras.misc import check_version


if TYPE_CHECKING:
Expand All @@ -26,7 +27,7 @@


def _set_z3_leaf_modules(model: "PreTrainedModel", leaf_modules: Sequence["torch.nn.Module"]) -> None:
require_version("deepspeed>=0.13.0", "To fix: pip install deepspeed>=0.13.0")
check_version("deepspeed>=0.13.0")
from deepspeed.utils import set_z3_leaf_modules # type: ignore

set_z3_leaf_modules(model, leaf_modules)
Expand Down
4 changes: 2 additions & 2 deletions src/llamafactory/model/model_utils/packing.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@

import torch
import torch.nn.functional as F
from transformers.utils.versions import require_version

from ...extras import logging
from ...extras.misc import check_version
from ...extras.packages import is_transformers_version_greater_than


Expand Down Expand Up @@ -118,6 +118,6 @@ def configure_packing(model_args: "ModelArguments", is_trainable: bool) -> None:
if not is_trainable or not model_args.block_diag_attn:
return

require_version("transformers>=4.43.0,<=4.46.1", "To fix: pip install transformers>=4.43.0,<=4.46.1")
check_version("transformers>=4.43.0,<=4.46.1")
transformers.modeling_flash_attention_utils._get_unpad_data = get_unpad_data
logger.info_rank0("Using block diagonal attention for sequence packing without cross-attention.")
Loading

0 comments on commit 47e17dd

Please sign in to comment.