Better prompt length handling
Aleksandr Movchan committed Dec 13, 2023
1 parent e3d415a commit d3b4385
Showing 4 changed files with 64 additions and 3 deletions.
18 changes: 16 additions & 2 deletions aana/deployments/vllm_deployment.py
@@ -10,7 +10,7 @@
from vllm.utils import get_gpu_memory, random_uuid

from aana.deployments.base_deployment import BaseDeployment
-from aana.exceptions.general import InferenceException
+from aana.exceptions.general import InferenceException, PromptTooLongException
from aana.models.pydantic.chat_message import ChatDialog, ChatMessage
from aana.models.pydantic.sampling_params import SamplingParams
from aana.utils.chat_template import apply_chat_template
@@ -116,6 +116,7 @@ async def apply_config(self, config: dict[str, Any]):
# create the engine
self.engine = AsyncLLMEngine.from_engine_args(args)
self.tokenizer = self.engine.engine.tokenizer
self.model_config = await self.engine.get_model_config()

async def generate_stream(
self, prompt: str, sampling_params: SamplingParams
@@ -132,6 +133,16 @@ async def generate_stream(
prompt = str(prompt)
sampling_params = merged_options(self.default_sampling_params, sampling_params)
request_id = None

# tokenize the prompt
prompt_token_ids = self.tokenizer.encode(prompt)

if len(prompt_token_ids) > self.model_config.max_model_len:
raise PromptTooLongException(
prompt_len=len(prompt_token_ids),
max_len=self.model_config.max_model_len,
)

try:
# convert SamplingParams to VLLMSamplingParams
sampling_params_vllm = VLLMSamplingParams(
@@ -142,7 +153,10 @@
# set the random seed for reproducibility
set_random_seed(42)
results_generator = self.engine.generate(
-    prompt, sampling_params_vllm, request_id
+    prompt=None,
+    sampling_params=sampling_params_vllm,
+    request_id=request_id,
+    prompt_token_ids=prompt_token_ids,
)

num_returned = 0
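With this change, a prompt longer than the model's context window raises PromptTooLongException before the request ever reaches the vLLM engine, so callers can handle the error explicitly. A minimal caller-side sketch, assuming a deployed VLLMDeployment handle and a dict output with a "text" key (both illustrative, not part of this commit):

# Illustrative only: `handle` stands in for a deployed VLLMDeployment handle;
# the output shape mirrors the tests below but is an assumption here.
from aana.exceptions.general import PromptTooLongException
from aana.models.pydantic.sampling_params import SamplingParams

async def safe_generate(handle, prompt: str) -> str | None:
    try:
        output = await handle.generate.remote(
            prompt=prompt,
            sampling_params=SamplingParams(temperature=0.0, max_tokens=32),
        )
        return output["text"]
    except PromptTooLongException as e:
        # prompt_len and max_len are the attributes added by the new exception class.
        print(f"Prompt is {e.prompt_len} tokens, model limit is {e.max_len}.")
        return None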
24 changes: 24 additions & 0 deletions aana/exceptions/general.py
@@ -125,3 +125,27 @@ class VideoReadingException(VideoException):
"""

pass


class PromptTooLongException(BaseException):
"""Exception raised when the prompt is too long.
Attributes:
prompt_len (int): the length of the prompt in tokens
max_len (int): the maximum allowed length of the prompt in tokens
"""

def __init__(self, prompt_len: int, max_len: int):
"""Initialize the exception.
Args:
prompt_len (int): the length of the prompt in tokens
max_len (int): the maximum allowed length of the prompt in tokens
"""
super().__init__(prompt_len=prompt_len, max_len=max_len)
self.prompt_len = prompt_len
self.max_len = max_len

def __reduce__(self):
"""Used for pickling."""
return (self.__class__, (self.prompt_len, self.max_len))
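The __reduce__ override is what lets the exception survive serialization: Ray pickles exceptions raised inside a deployment replica before re-raising them in the caller's process, and a constructor that takes keyword arguments presumably would not round-trip otherwise. A small illustrative check, not part of the commit:

# Round-trip the exception through pickle, as Ray would when propagating
# the error from the replica back to the caller.
import pickle

from aana.exceptions.general import PromptTooLongException

exc = PromptTooLongException(prompt_len=5000, max_len=4096)
restored = pickle.loads(pickle.dumps(exc))
assert restored.prompt_len == 5000
assert restored.max_len == 4096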
8 changes: 8 additions & 0 deletions aana/tests/deployments/test_vllm_deployment.py
@@ -4,6 +4,7 @@
from ray import serve

from aana.configs.deployments import deployments
from aana.exceptions.general import PromptTooLongException
from aana.models.pydantic.chat_message import ChatDialog, ChatMessage
from aana.models.pydantic.sampling_params import SamplingParams
from aana.tests.utils import compare_texts, is_gpu_available
@@ -111,3 +112,10 @@ async def test_vllm_deployments():
text += chunk["text"]

compare_texts(expected_text, text)

# test generate method with too long prompt
with pytest.raises(PromptTooLongException):
output = await handle.generate.remote(
prompt="[INST] Who is Elon Musk? [/INST]" * 1000,
sampling_params=SamplingParams(temperature=0.0, max_tokens=32),
)
17 changes: 16 additions & 1 deletion aana/utils/video.py
@@ -371,13 +371,22 @@ def load_video_timeline(media_id: str):
return timeline


-def generate_dialog(metadata: dict, timeline: list[dict], question: str) -> ChatDialog:
+def generate_dialog(
+    metadata: dict,
+    timeline: list[dict],
+    question: str,
+    max_timeline_len: int | None = 1024,
+) -> ChatDialog:
"""Generates a dialog from the metadata and timeline of a video.
Args:
metadata (dict): the metadata of the video
timeline (list[dict]): the timeline of the video
question (str): the question to ask
max_timeline_len (int, optional): the maximum length of the timeline in tokens.
Defaults to 1024.
If the timeline is longer than this, it will be truncated.
If None, the timeline will not be truncated.
Returns:
ChatDialog: the generated dialog
@@ -419,6 +428,12 @@ def generate_dialog(metadata: dict, timeline: list[dict], question: str) -> Chat
)

timeline_json = json.dumps(timeline, indent=4, separators=(",", ": "))
# truncate the timeline if it is too long
timeline_tokens = (
timeline_json.split()
) # not an accurate count of tokens, but good enough
if max_timeline_len is not None and len(timeline_tokens) > max_timeline_len:
timeline_json = " ".join(timeline_tokens[:max_timeline_len])

messages = []
messages.append(ChatMessage(content=system_prompt_preamble, role="system"))
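The timeline truncation above is deliberately rough: it splits the serialized JSON on whitespace, keeps at most max_timeline_len pieces, and rejoins them with single spaces, so the result may no longer be valid JSON; that is presumably acceptable because the string is only used as prompt context. A standalone sketch of the same logic with made-up timeline data (illustrative, not from the repository):

# Illustrative data; in the repository the timeline comes from video metadata.
import json

timeline = [{"start": i, "end": i + 1, "caption": f"scene {i}"} for i in range(500)]
timeline_json = json.dumps(timeline, indent=4, separators=(",", ": "))

max_timeline_len = 1024  # whitespace-token budget, matching the new default
timeline_tokens = timeline_json.split()  # rough proxy for tokenizer tokens
if len(timeline_tokens) > max_timeline_len:
    timeline_json = " ".join(timeline_tokens[:max_timeline_len])

print(len(timeline_json.split()))  # at most 1024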
