
emit input_token_count and output_token_count metrics
technillogue committed May 17, 2024
1 parent 3e555aa · commit b726e20
Showing 1 changed file with 24 additions and 17 deletions.
predict.py — 41 changes: 24 additions & 17 deletions
@@ -6,6 +6,7 @@
 import multiprocessing as mp
 from typing import Optional

+import cog
 from cog import BasePredictor, ConcatenateIterator, Input

 if mp.current_process().name != "MainProcess":
@@ -29,6 +30,23 @@
 )


+def format_prompt(
+    prompt: str, prompt_template: str, system_prompt: Optional[str]
+) -> str:
+    if not prompt_template:
+        prompt_template = "{prompt}"
+    if prompt and "{prompt}" not in prompt_template:
+        raise Exception(
+            "You have submitted both a prompt and a prompt template that doesn't include '{prompt}'. "
+            "Your prompt would not be used. "
+            "If you don't want to use formatting, use your full prompt for the prompt argument and set prompt_template to '{prompt}'."
+        )
+    return prompt_template.format(
+        system_prompt=system_prompt or "",
+        prompt=prompt,
+    )
+
+
 class Predictor(BasePredictor):
     async def setup(self, weights: str = "") -> None:
         self.log_performance_metrics = bool(os.getenv("LOG_PERFORMANCE_METRICS", False))
@@ -206,7 +224,7 @@ async def predict(
             )
             return

-        formatted_prompt = self._format_prompt(
+        formatted_prompt = format_prompt(
             prompt=prompt, system_prompt=system_prompt, prompt_template=prompt_template
         )
         if formatted_prompt == "":
@@ -229,8 +247,10 @@ async def predict(
                 f"Can't set both min_tokens ({min_tokens}) and min_new_tokens ({min_new_tokens})"
             )

+        n_prompt_tokens = self._get_n_tokens(prompt)
         args = self._process_args(
             prompt=formatted_prompt,
+            n_prompt_tokens=n_prompt_tokens,
             max_tokens=max_tokens,
             min_tokens=min_tokens,
             top_k=top_k,
@@ -308,6 +328,8 @@ async def predict(
                 f"Serverside time to first token: {round(time_to_first_token, 2)} seconds\n"
             )

+        cog.emit_metric("input_token_count", n_prompt_tokens)
+        cog.emit_metric("output_token_count", n_tokens)
         self.log(f"Random seed used: `{args['random_seed']}`\n")
         self.log(
             "Note: Random seed will not impact output if greedy decoding is used.\n"
@@ -323,6 +345,7 @@
     def _process_args(
         self,
         prompt: str,
+        n_prompt_tokens: int,
         max_tokens: int = 250,
         min_tokens: Optional[int] = None,
         top_k: int = 0,
@@ -348,8 +371,6 @@ def _process_args(
         if not seed:
             seed = int(np.random.randint(0, 100000))

-        n_prompt_tokens = self._get_n_tokens(prompt)
-
         if self.max_sequence_length:
             token_budget = self.max_sequence_length - n_prompt_tokens
             max_tokens = min(max_tokens, token_budget)
@@ -373,19 +394,5 @@

         return args

-    def _format_prompt(
-        self, prompt: str, prompt_template: str, system_prompt: str
-    ) -> str:
-        if not prompt_template:
-            return prompt
-        if "system_prompt" in prompt_template:
-            system_prompt = system_prompt if system_prompt else ""
-            formatted_prompt = prompt_template.format(
-                system_prompt=system_prompt, prompt=prompt
-            )
-            return formatted_prompt
-        formatted_prompt = prompt_template.format(prompt=prompt)
-        return formatted_prompt
-
     def _get_n_tokens(self, text: str) -> int:
         return len(self.tokenizer(text)["input_ids"])

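Note on the refactor: prompt formatting moves from the Predictor._format_prompt method to a module-level format_prompt helper, and its behavior changes slightly: an empty prompt_template now defaults to "{prompt}", and passing a prompt together with a template that does not contain "{prompt}" now raises instead of silently dropping the prompt. Below is a minimal, self-contained sketch of that behavior, based only on the helper as it appears in this diff; the example prompts and chat template are made up for illustration.

from typing import Optional

def format_prompt(
    prompt: str, prompt_template: str, system_prompt: Optional[str]
) -> str:
    # Abridged copy of the helper added in this commit, so the sketch runs standalone.
    if not prompt_template:
        prompt_template = "{prompt}"
    if prompt and "{prompt}" not in prompt_template:
        raise Exception("prompt_template must contain '{prompt}' when a prompt is given")
    return prompt_template.format(
        system_prompt=system_prompt or "",
        prompt=prompt,
    )

# No template: the prompt passes through unchanged.
assert format_prompt("hello", "", None) == "hello"

# A chat-style template interpolates both the system prompt and the prompt.
assert (
    format_prompt("hello", "[INST] {system_prompt} {prompt} [/INST]", "be brief")
    == "[INST] be brief hello [/INST]"
)

As for the metrics themselves: input_token_count is computed with self._get_n_tokens on the raw prompt (before templating), while output_token_count reports the n_tokens counter maintained in the generation loop (not shown in this diff).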