Skip to content

Commit

Permalink
fix type errors
Browse files (browse the repository at this point in the history)
  • Loading branch information
technillogue committed May 1, 2024
1 parent 94c0f8e commit 21ff680
Showing 1 changed file with 14 additions and 10 deletions.
24 changes: 14 additions & 10 deletions predict.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -4,6 +4,7 @@
import subprocess
import time
import multiprocessing as mp
from typing import Optional

from cog import BasePredictor, ConcatenateIterator, Input

Expand Down Expand Up @@ -63,8 +64,9 @@ async def setup(self, weights: str = "") -> None:
self.trt_llm_config = config = json.load(f)
print(f"tensorrt_llm config: {config}")

if os.getenv("MAX_SEQUENCE_LENGTH", None):
self.max_sequence_length = int(os.getenv("MAX_SEQUENCE_LENGTH"))
max_seqlen_env = os.getenv("MAX_SEQUENCE_LENGTH", None)
if max_seqlen_env :
self.max_sequence_length = int(max_seqlen_env )
else:
try:
self.max_sequence_length = self.trt_llm_config["pretrained_config"][
Expand All @@ -87,7 +89,7 @@ async def setup(self, weights: str = "") -> None:
return
raise Exception(f"Couldn't start Triton (exit code {self.proc.poll()})")

async def start_triton(self) -> None:
async def start_triton(self) -> bool:
# # launch triton server
# # python3 scripts/launch_triton_server.py --world_size=1 --model_repo=/src/tensorrtllm_backend/triton_model
world_size = int(os.getenv("WORLD_SIZE", "1"))
Expand Down Expand Up @@ -248,6 +250,7 @@ async def predict(
start_time = time.time()
n_tokens = 0
tokens = np.array([], dtype=np.int32)
first_token_time = None

async with req as resp:
async for event in receive_sse(resp):
Expand Down Expand Up @@ -289,13 +292,14 @@ async def predict(
if self.log_performance_metrics or log_performance_metrics:
latency = end_time - start_time
actual_tps = n_tokens / latency
time_to_first_token = first_token_time - start_time
self.log(f"Tokens processed: {n_tokens}\n")
self.log(f"Serverside tokens per second: {round(actual_tps, 2)}\n")
self.log(f"Serverside execution time: {round(latency, 2)} seconds\n")
self.log(
f"Serverside time to first token: {round(time_to_first_token, 2)} seconds\n"
)
if first_token_time:
time_to_first_token = first_token_time - start_time
self.log(
f"Serverside time to first token: {round(time_to_first_token, 2)} seconds\n"
)

self.log(f"Random seed used: `{args['random_seed']}`\n")
self.log(
Expand All @@ -313,14 +317,14 @@ def _process_args(
self,
prompt: str,
max_tokens: int = 250,
min_tokens: int = None,
min_tokens: Optional[int] = None,
top_k: int = 0,
top_p: float = 0.0,
temperature: float = 1.0,
length_penalty: float = 1.0,
presence_penalty: float = 0.0,
stop_words: str = None,
seed: int = None,
stop_words: Optional[str] = None,
seed: Optional[int] = None,
stream: bool = True,
):
stop_words_list = stop_words.split(",") if stop_words else []
Expand Down

0 comments on commit 21ff680

Please sign in to comment.