
Commit

Refactor utilities into a separate model utils file. (EleutherAI#1429)
baberabb authored Feb 14, 2024
1 parent 620d6a1 commit 2d0a646
Showing 11 changed files with 554 additions and 527 deletions.
2 changes: 1 addition & 1 deletion lm_eval/models/anthropic_llms.py
@@ -5,7 +5,7 @@
from lm_eval import utils
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
-from lm_eval.utils import retry_on_specific_exceptions
+from lm_eval.models.utils import retry_on_specific_exceptions


eval_logger = utils.eval_logger
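Note: retry_on_specific_exceptions is the helper these API-backed models use to retry transient request failures; only its import path changes in this commit. A minimal sketch of how such a retry decorator can be structured — the parameter names and exponential-backoff policy below are assumptions, not the harness's actual implementation:

import time
from typing import Callable, List, Type


def retry_on_specific_exceptions_sketch(
    on_exceptions: List[Type[Exception]],
    max_retries: int = 3,
    backoff_seconds: float = 1.0,
) -> Callable:
    """Retry the wrapped callable whenever one of on_exceptions is raised."""

    def decorator(func: Callable) -> Callable:
        def wrapper(*args, **kwargs):
            for attempt in range(max_retries):
                try:
                    return func(*args, **kwargs)
                except tuple(on_exceptions):
                    if attempt == max_retries - 1:
                        raise
                    # illustrative policy: exponential backoff between attempts
                    time.sleep(backoff_seconds * (2**attempt))

        return wrapper

    return decorator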
24 changes: 15 additions & 9 deletions lm_eval/models/huggingface.py
@@ -26,7 +26,13 @@
from lm_eval.api.instance import Instance
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
-from lm_eval.utils import Collator, stop_sequences_criteria
+from lm_eval.models.utils import (
+    Collator,
+    clear_torch_cache,
+    get_dtype,
+    pad_and_concat,
+    stop_sequences_criteria,
+)


eval_logger = utils.eval_logger
@@ -503,13 +509,13 @@ def _create_model(
if transformers.__version__ >= "4.30.0":
if model_kwargs.get("load_in_4bit", None):
if model_kwargs.get("bnb_4bit_compute_dtype", None):
-model_kwargs["bnb_4bit_compute_dtype"] = utils.get_dtype(
+model_kwargs["bnb_4bit_compute_dtype"] = get_dtype(
model_kwargs["bnb_4bit_compute_dtype"]
)
self._model = self.AUTO_MODEL_CLASS.from_pretrained(
pretrained,
revision=revision,
-torch_dtype=utils.get_dtype(dtype),
+torch_dtype=get_dtype(dtype),
trust_remote_code=trust_remote_code,
**model_kwargs,
)
@@ -639,10 +645,10 @@ def forward_batch(batch_size):
self.accelerator.gather(max_rnk_bs).cpu().detach().numpy().tolist()
)
batch_size = min(gathered)
-utils.clear_torch_cache()
+clear_torch_cache()
return batch_size

-utils.clear_torch_cache()
+clear_torch_cache()
return batch_size

def tok_encode(
@@ -997,18 +1003,18 @@ def _collate(x):
# create encoder attn mask and batched conts, if seq2seq
call_kwargs = {}
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
-batched_inps = utils.pad_and_concat(
+batched_inps = pad_and_concat(
padding_len_inp, inps, padding_side="right"
) # [batch, padding_len_inp]
elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
# TODO: left-pad encoder inps and mask?
-batched_inps = utils.pad_and_concat(
+batched_inps = pad_and_concat(
padding_len_inp, inps
) # [batch, padding_len_inp]
-batched_conts = utils.pad_and_concat(
+batched_conts = pad_and_concat(
padding_len_cont, conts
) # [batch, padding_len_cont]
-batched_encoder_mask = utils.pad_and_concat(
+batched_encoder_mask = pad_and_concat(
padding_len_inp, encoder_attns
) # [batch, padding_len_inp]
call_kwargs = {
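Note: the pad_and_concat calls above pad a list of variable-length token tensors to a shared length and stack them into a single [batch, padding_len] tensor, with padding_side picking which end receives the padding. A rough sketch of that behavior inferred only from these call sites (the zero pad value and 1-D input shape are assumptions):

from typing import List

import torch
import torch.nn.functional as F


def pad_and_concat_sketch(
    max_length: int,
    tensors: List[torch.Tensor],
    padding_side: str = "right",
) -> torch.Tensor:
    """Zero-pad each 1-D tensor to max_length, then stack into [batch, max_length]."""
    padded = []
    for t in tensors:
        pad_amount = max_length - t.size(0)
        if padding_side == "right":
            padded.append(F.pad(t, (0, pad_amount)))  # pad on the right
        else:
            padded.append(F.pad(t, (pad_amount, 0)))  # pad on the left
    return torch.stack(padded, dim=0)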
6 changes: 4 additions & 2 deletions lm_eval/models/mamba_lm.py
@@ -2,7 +2,7 @@

import torch

-from lm_eval import utils
+import lm_eval.models.utils
from lm_eval.api.registry import register_model
from lm_eval.models.huggingface import HFLM

@@ -97,7 +97,9 @@ def _create_model(
self._model = MambaLMHeadModel.from_pretrained(
pretrained,
device=self._device,
-dtype=torch.float16 if dtype == "auto" else utils.get_dtype(dtype),
+dtype=torch.float16
+if dtype == "auto"
+else lm_eval.models.utils.get_dtype(dtype),
)

def _model_generate(self, context, max_length, stop, **generation_kwargs):
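Note: get_dtype resolves the user-supplied dtype argument (a string such as "float16", or an actual torch.dtype) into a torch.dtype; the "auto" case is handled by the caller in the hunk above. A plausible sketch under those assumptions — not the module's actual code:

from typing import Union

import torch


def get_dtype_sketch(dtype: Union[str, torch.dtype]) -> torch.dtype:
    """Turn a dtype string such as "bfloat16" into the matching torch.dtype."""
    if isinstance(dtype, torch.dtype):
        return dtype
    resolved = getattr(torch, dtype, None)  # e.g. "float16" -> torch.float16
    if not isinstance(resolved, torch.dtype):
        raise ValueError(f"Unrecognized dtype string: {dtype!r}")
    return resolved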
15 changes: 9 additions & 6 deletions lm_eval/models/neuron_optimum.py
@@ -13,10 +13,11 @@
from transformers import GenerationConfig
from transformers.generation import StoppingCriteriaList

+import lm_eval.models.utils
from lm_eval import utils
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
-from lm_eval.utils import stop_sequences_criteria
+from lm_eval.models.utils import stop_sequences_criteria


try:
@@ -239,7 +240,7 @@ def __init__(
revision=revision,
trust_remote_code=trust_remote_code,
)
-torch_dtype = utils.get_dtype(dtype)
+torch_dtype = lm_eval.models.utils.get_dtype(dtype)

assert torch_dtype in [
torch.float16,
@@ -550,7 +551,7 @@ def _collate(x):
# automatic (variable) batch size detection for vectorization
# pull longest context sample from request

-chunks = utils.chunks(
+chunks = lm_eval.models.utils.chunks(
re_ord.get_reordered(),
n=self.batch_size,
fn=None,
@@ -603,7 +604,7 @@ def _collate(x):

# create encoder attn mask and batched conts, if seq2seq
call_kwargs = {}
-batched_inps = utils.pad_and_concat(
+batched_inps = lm_eval.models.utils.pad_and_concat(
padding_len_inp, inps, padding_side="right"
) # [batch, padding_len_inp]

@@ -663,7 +664,7 @@ def _collate(x):
# we group requests by their generation_kwargs,
# so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
# in the same batch.
-grouper = utils.Grouper(requests, lambda x: str(x.args[1]))
+grouper = lm_eval.models.utils.Grouper(requests, lambda x: str(x.args[1]))
for key, reqs in grouper.get_grouped().items():
# within each set of reqs for given kwargs, we reorder by token length, descending.
re_ords[key] = utils.Reorderer([req.args for req in reqs], _collate)
@@ -672,7 +673,9 @@

# for each different set of kwargs, we execute all requests, by batch.
for key, re_ord in re_ords.items():
-chunks = utils.chunks(re_ord.get_reordered(), n=self.batch_size)
+chunks = lm_eval.models.utils.chunks(
+    re_ord.get_reordered(), n=self.batch_size
+)
for chunk in tqdm(chunks, disable=self.rank != 0):
contexts, all_gen_kwargs = zip(*chunk)
# we assume all gen kwargs in the batch are the same
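Note: chunks splits the reordered requests into batches of up to n items (the harness's version also accepts an fn callback, visible in the first batching hunk of this file, for computing the batch size dynamically). A simplified sketch that ignores that callback:

from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def chunks_sketch(items: Iterable[T], n: int) -> Iterator[List[T]]:
    """Yield consecutive lists of at most n items from items."""
    batch: List[T] = []
    for item in items:
        batch.append(item)
        if len(batch) == n:
            yield batch
            batch = []
    if batch:  # emit the final, possibly shorter, batch
        yield batch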
10 changes: 6 additions & 4 deletions lm_eval/models/openai_completions.py
@@ -6,10 +6,12 @@

from tqdm import tqdm

+import lm_eval.models.utils
from lm_eval import utils
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
-from lm_eval.utils import eval_logger, retry_on_specific_exceptions
+from lm_eval.models.utils import retry_on_specific_exceptions
+from lm_eval.utils import eval_logger


def get_result(response, ctxlen: int) -> Tuple[float, bool]:
@@ -219,7 +221,7 @@ def _collate(x):
re_ord = utils.Reorderer(requests, _collate)

for chunk in tqdm(
-list(utils.chunks(re_ord.get_reordered(), self.batch_size)),
+list(lm_eval.models.utils.chunks(re_ord.get_reordered(), self.batch_size)),
disable=disable_tqdm,
):
inps = []
@@ -429,7 +431,7 @@ def generate_until(self, requests) -> List[str]:
# we group requests by their generation_kwargs,
# so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
# in the same batch.
-grouper = utils.Grouper(requests, lambda x: str(x.args[1]))
+grouper = lm_eval.models.utils.Grouper(requests, lambda x: str(x.args[1]))
for key, reqs in grouper.get_grouped().items():
# within each set of reqs for given kwargs, we reorder by token length, descending.
re_ords[key] = utils.Reorderer(
@@ -441,7 +443,7 @@
# n needs to be 1 because messages in
# chat completion are not batch but
# is regarded as a single conversation.
-chunks = utils.chunks(re_ord.get_reordered(), n=1)
+chunks = lm_eval.models.utils.chunks(re_ord.get_reordered(), n=1)
for chunk in chunks:
contexts, all_gen_kwargs = zip(*chunk)
inps = [{"role": "user", "content": context} for context in contexts]
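Note: Grouper buckets requests by a key — here the stringified generation kwargs — so that, for example, greedy and temperature-0.8 requests never share a batch. A minimal sketch of that grouping step (illustrative only; the real helper also tracks ordering so results can be mapped back to the original request order):

from collections import defaultdict
from typing import Callable, Dict, Hashable, Iterable, List, TypeVar

T = TypeVar("T")


class GrouperSketch:
    """Group items by key_fn so each group can be run with shared generation settings."""

    def __init__(self, items: Iterable[T], key_fn: Callable[[T], Hashable]) -> None:
        self._grouped: Dict[Hashable, List[T]] = defaultdict(list)
        for item in items:
            self._grouped[key_fn(item)].append(item)

    def get_grouped(self) -> Dict[Hashable, List[T]]:
        return dict(self._grouped)


# Usage mirroring the call site above:
# grouper = GrouperSketch(requests, lambda x: str(x.args[1]))
# for key, reqs in grouper.get_grouped().items():
#     ...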
2 changes: 1 addition & 1 deletion lm_eval/models/textsynth.py
@@ -19,7 +19,7 @@

from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
-from lm_eval.utils import retry_on_specific_exceptions
+from lm_eval.models.utils import retry_on_specific_exceptions


logger = logging.getLogger(__name__)
(Diff for the remaining 5 changed files not shown.)

