Commit
Vllm update DP+TP (EleutherAI#1508)
* use `@ray.remote` with distributed vLLM

* update versions

* bugfix

* unpin vllm

* fix pre-commit

* added version assertion error

* Revert "added version assertion error"

This reverts commit 8041e9b.

* added version assertion for DP

* expand DP note

* add warning

* nit

* pin vllm

* fix typos
baberabb authored Mar 3, 2024
1 parent ae79b12 commit e5e35fc
Showing 5 changed files with 35 additions and 21 deletions.
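For orientation before the per-file diffs, here is a hedged usage sketch of how the updated `VLLM` wrapper could be constructed with both tensor and data parallelism after this change. The model name and parallel sizes are placeholders, and `pretrained` and `tensor_parallel_size` are assumed constructor arguments that are not visible in this diff.

```python
# Hedged usage sketch (not part of the commit): constructing the vLLM wrapper
# with tensor + data parallelism. `pretrained` and `tensor_parallel_size` are
# assumed constructor arguments not shown in this diff; the checkpoint name and
# sizes are placeholders.
from lm_eval.models.vllm_causallms import VLLM

lm = VLLM(
    pretrained="mistralai/Mistral-7B-v0.1",  # placeholder checkpoint
    tensor_parallel_size=2,   # GPUs per replica, handled inside vLLM
    data_parallel_size=2,     # replicas dispatched via @ray.remote (see diff below)
    gpu_memory_utilization=0.9,
)
# With data_parallel_size > 1 the wrapper forces batch_size="auto", sets
# worker_use_ray=True, and asserts vllm < 0.3.3 (matching the vllm==0.3.2 pin below).
```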
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
@@ -2,7 +2,7 @@
exclude: ^tests/testdata/
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.1.0
rev: v4.5.0
hooks:
- id: check-added-large-files
- id: check-ast
@@ -29,7 +29,7 @@ repos:
args: [--fix=lf]
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.1.8
rev: v0.2.2
hooks:
# Run the linter.
- id: ruff
@@ -38,7 +38,7 @@ repos:
# Run the formatter.
- id: ruff-format
- repo: https://github.com/codespell-project/codespell
rev: v2.1.0
rev: v2.2.6
hooks:
- id: codespell
exclude: >
2 changes: 1 addition & 1 deletion lm_eval/models/huggingface.py
@@ -887,7 +887,7 @@ def _collate(req: Tuple[Tuple[str, str], List[int], List[int]]):
def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]):
"""Defines the key to group and lookup one-token continuations"""
# Use with group_by="contexts" (optional)"
# allows for the creation of a lookup, so we can re-use logits in case of one-token continuations.
# allows for the creation of a lookup, so we can reuse logits in case of one-token continuations.
# speeds up some multiple-choice tasks proportionally to the number of choices.
# groups requests by context+continuation[:-1] and infer on one request/group.
return req[-2] + req[-1][:-1]
42 changes: 28 additions & 14 deletions lm_eval/models/vllm_causallms.py
@@ -1,8 +1,10 @@
import copy
from importlib.metadata import version
from importlib.util import find_spec
from typing import List, Literal, Optional, Tuple, Union

from more_itertools import distribute
from packaging.version import parse as parse_version
from tqdm import tqdm

from lm_eval.api.instance import Instance
@@ -18,7 +20,6 @@

try:
import ray
from ray.util.multiprocessing import Pool
from vllm import LLM, SamplingParams
from vllm.transformers_utils.tokenizer import get_tokenizer
except ModuleNotFoundError:
@@ -27,14 +28,6 @@
eval_logger = eval_logger


# adapted from https://github.com/vllm-project/vllm/issues/367#issuecomment-1788341727
def run_inference_one_model(
model_args: dict, sampling_params, requests: List[List[int]]
):
llm = LLM(**model_args)
return llm.generate(prompt_token_ids=requests, sampling_params=sampling_params)


@register_model("vllm")
class VLLM(TemplateLM):
_DEFAULT_MAX_LENGTH = 2048
@@ -61,6 +54,7 @@ def __init__(
gpu_memory_utilization: float = 0.9,
device: str = "cuda",
data_parallel_size: int = 1,
**kwargs,
):
super().__init__()

@@ -93,6 +87,7 @@ def __init__(
"quantization": quantization,
"seed": int(seed),
}
self.model_args.update(kwargs)
self.batch_size = (
"auto"
if isinstance(batch_size, str) and "auto" in batch_size
@@ -101,6 +96,12 @@
if self.data_parallel_size <= 1:
self.model = LLM(**self.model_args)
else:
assert parse_version(version("vllm")) < parse_version(
"0.3.3"
), "data_parallel is only compatible with vllm < v0.3.3."
eval_logger.warning(
"You might experience occasional issues with model weight downloading when data_parallel is in use. To ensure stable performance, run with data_parallel_size=1 until the weights are downloaded and cached."
)
self.model_args["worker_use_ray"] = True
self.batch_size = "auto"
eval_logger.info("Manual batching is not compatible with data parallelism.")
@@ -182,13 +183,26 @@ def _model_generate(
temperature=0, prompt_logprobs=1, max_tokens=1
)
if self.data_parallel_size > 1:
# vLLM hangs if tensor_parallel > 1 and resources are set in ray.remote
# also seems to only work with decorator and not with ray.remote() fn
# see https://github.com/vllm-project/vllm/issues/973
# note: this has changed on 0.3.3, and it only works now if num_gpus are set.
# but then tensor_parallel breaks
@ray.remote
def run_inference_one_model(
model_args: dict, sampling_params, requests: List[List[int]]
):
llm = LLM(**model_args)
return llm.generate(
prompt_token_ids=requests, sampling_params=sampling_params
)

# dispatch requests to all self.data_parallel_size workers, in interleaved fashion
# interleaved important to balance context lengths across workers
requests = [list(x) for x in distribute(self.data_parallel_size, requests)]
inputs = [(self.model_args, sampling_params, req) for req in requests]

with Pool(self.data_parallel_size) as pool:
results = pool.starmap(run_inference_one_model, inputs)
inputs = ((self.model_args, sampling_params, req) for req in requests)
object_refs = [run_inference_one_model.remote(*x) for x in inputs]
results = ray.get(object_refs)
# Invoke ray.shutdown() to prevent hang-ups if subsequent calls required.
ray.shutdown()
# flatten results
@@ -286,7 +300,7 @@ def _collate_gen(_requests):
f"Expected `kwargs` to be of type `dict` but got {gen_kwargs}"
)
# add EOS token to stop sequences
eos = self.tok_decode(self.eot_token_id)
eos = self.tokenizer.decode(self.eot_token_id)
if not until:
until = [eos]
else:
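To make the dispatch path above easier to read in isolation, here is a minimal, self-contained sketch of the same `@ray.remote` data-parallel pattern, assuming vllm < 0.3.3, already-tokenized prompts, and a placeholder model; the engine arguments and sampling settings are illustrative, not values from the commit.

```python
# Minimal sketch of the data-parallel dispatch pattern used above (adapted from
# the diff). Assumes vllm < 0.3.3, token-ID prompts, and a placeholder model.
from typing import List

import ray
from more_itertools import distribute
from vllm import LLM, SamplingParams

MODEL_ARGS = {"model": "facebook/opt-125m", "worker_use_ray": True}  # placeholders
DATA_PARALLEL_SIZE = 2


@ray.remote  # one vLLM engine per data-parallel replica
def run_inference_one_model(
    model_args: dict, sampling_params, requests: List[List[int]]
):
    llm = LLM(**model_args)
    return llm.generate(prompt_token_ids=requests, sampling_params=sampling_params)


def generate_data_parallel(token_id_requests: List[List[int]]):
    sampling_params = SamplingParams(temperature=0.0, max_tokens=32)
    # Interleave requests across replicas to balance context lengths.
    shards = [list(x) for x in distribute(DATA_PARALLEL_SIZE, token_id_requests)]
    object_refs = [
        run_inference_one_model.remote(MODEL_ARGS, sampling_params, shard)
        for shard in shards
    ]
    results = ray.get(object_refs)
    ray.shutdown()  # avoid hang-ups on subsequent calls, as noted in the diff
    # Flatten per-replica outputs back into a single list.
    return [output for shard_outputs in results for output in shard_outputs]
```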
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -71,7 +71,7 @@ optimum = ["optimum[openvino]"]
promptsource = ["promptsource>=0.2.3"]
sentencepiece = ["sentencepiece>=0.1.98", "protobuf>=4.22.1"]
testing = ["pytest", "pytest-cov", "pytest-xdist"]
vllm = ["vllm<=0.2.5"]
vllm = ["vllm==0.3.2"]
zeno = ["pandas", "zeno-client"]
wandb = ["wandb>=0.16.3", "pandas", "numpy"]
all = [
4 changes: 2 additions & 2 deletions scripts/clean_training_data/janitor_util.cpp
@@ -75,7 +75,7 @@ std::vector<std::string> clean_ngram(std::string const &input,
gram_lengths.erase(gram_lengths.begin());
gram_lengths.push_back(0);

// Otherwise, continute building
// Otherwise, continue building
} else {
current_ngram += ' ';
gram_lengths.push_back(0);
@@ -165,7 +165,7 @@ clean_ngram_with_indices(std::string const &input, std::string const &ignore,
gram_start_indices.erase(gram_start_indices.begin());
gram_start_indices.push_back(i + 1);

// Otherwise, continute building
// Otherwise, continue building
} else {
current_ngram += ' ';
gram_lengths.push_back(0);
