diff --git a/predict.py b/predict.py
index 56e0572..72b4a71 100644
--- a/predict.py
+++ b/predict.py
@@ -6,6 +6,7 @@
 import multiprocessing as mp
 from typing import Optional
 
+import cog
 from cog import BasePredictor, ConcatenateIterator, Input
 
 if mp.current_process().name != "MainProcess":
@@ -29,6 +30,23 @@
 )
 
 
+def format_prompt(
+    prompt: str, prompt_template: str, system_prompt: Optional[str]
+) -> str:
+    if not prompt_template:
+        prompt_template = "{prompt}"
+    if prompt and "{prompt}" not in prompt_template:
+        raise Exception(
+            "You have submitted both a prompt and a prompt template that doesn't include '{prompt}'. "
+            "Your prompt would not be used. "
+            "If you don't want to use formatting, use your full prompt for the prompt argument and set prompt_template to '{prompt}'."
+        )
+    return prompt_template.format(
+        system_prompt=system_prompt or "",
+        prompt=prompt,
+    )
+
+
 class Predictor(BasePredictor):
     async def setup(self, weights: str = "") -> None:
         self.log_performance_metrics = bool(os.getenv("LOG_PERFORMANCE_METRICS", False))
@@ -206,7 +224,7 @@ async def predict(
             )
             return
 
-        formatted_prompt = self._format_prompt(
+        formatted_prompt = format_prompt(
             prompt=prompt, system_prompt=system_prompt, prompt_template=prompt_template
         )
         if formatted_prompt == "":
@@ -229,8 +247,10 @@ async def predict(
                 f"Can't set both min_tokens ({min_tokens}) and min_new_tokens ({min_new_tokens})"
            )
 
+        n_prompt_tokens = self._get_n_tokens(prompt)
         args = self._process_args(
             prompt=formatted_prompt,
+            n_prompt_tokens=n_prompt_tokens,
             max_tokens=max_tokens,
             min_tokens=min_tokens,
             top_k=top_k,
@@ -308,6 +328,8 @@ async def predict(
             f"Serverside time to first token: {round(time_to_first_token, 2)} seconds\n"
         )
 
+        cog.emit_metric("input_token_count", n_prompt_tokens)
+        cog.emit_metric("output_token_count", n_tokens)
         self.log(f"Random seed used: `{args['random_seed']}`\n")
         self.log(
             "Note: Random seed will not impact output if greedy decoding is used.\n"
@@ -323,6 +345,7 @@ async def predict(
     def _process_args(
         self,
         prompt: str,
+        n_prompt_tokens: int,
         max_tokens: int = 250,
         min_tokens: Optional[int] = None,
         top_k: int = 0,
@@ -348,8 +371,6 @@ def _process_args(
         if not seed:
             seed = int(np.random.randint(0, 100000))
 
-        n_prompt_tokens = self._get_n_tokens(prompt)
-
         if self.max_sequence_length:
             token_budget = self.max_sequence_length - n_prompt_tokens
             max_tokens = min(max_tokens, token_budget)
@@ -373,19 +394,5 @@ def _process_args(
         return args
 
-    def _format_prompt(
-        self, prompt: str, prompt_template: str, system_prompt: str
-    ) -> str:
-        if not prompt_template:
-            return prompt
-        if "system_prompt" in prompt_template:
-            system_prompt = system_prompt if system_prompt else ""
-            formatted_prompt = prompt_template.format(
-                system_prompt=system_prompt, prompt=prompt
-            )
-            return formatted_prompt
-        formatted_prompt = prompt_template.format(prompt=prompt)
-        return formatted_prompt
-
     def _get_n_tokens(self, text: str) -> int:
         return len(self.tokenizer(text)["input_ids"])
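
A quick sketch of how the new module-level format_prompt behaves (the Llama-style chat template below is a hypothetical example, not taken from this diff):

    from predict import format_prompt

    # system_prompt may be None; the new function substitutes "" via `or ""`.
    print(
        format_prompt(
            prompt="Why do llamas hum?",
            prompt_template="[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n{prompt} [/INST]",
            system_prompt=None,
        )
    )

    # An empty template now falls back to "{prompt}", so the prompt passes through.
    assert format_prompt("hi", "", None) == "hi"

    # New guard: a non-empty prompt combined with a template that lacks "{prompt}"
    # now raises, instead of silently dropping the prompt as the old
    # _format_prompt did.
    try:
        format_prompt("hi", "{system_prompt}", "You are a llama.")
    except Exception as err:
        print(err)

Hoisting the self._get_n_tokens(prompt) call out of _process_args and into predict() lets the same token count feed both the sequence-length budget in _process_args and the new cog.emit_metric("input_token_count", ...) call without a second tokenizer pass.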