diff --git a/llm/server/server/triton_server.py b/llm/server/server/triton_server.py
index 12024c251a..6b9ae6cec4 100644
--- a/llm/server/server/triton_server.py
+++ b/llm/server/server/triton_server.py
@@ -98,11 +98,37 @@ def _push_mode_sender_thread(self):
             except Exception as e:
                 model_server_logger.error("Unexcepted error happend: {}, {}".format(e, str(traceback.format_exc())))
 
+    def _cache_special_tokens(self, batch_result):
+        for i in range(len(batch_result)):
+            is_end = batch_result[i].get("is_end", 0)
+            token_ids = batch_result[i]["token_ids"]
+            return_all_tokens = batch_result[i].get("return_all_tokens", False)
+            cache_special_token = False if is_end == 1 else (13 <= token_ids[0] <= 268)
+            if is_end != 1 and (cache_special_token or return_all_tokens or self.cfg.disable_streaming):
+                if batch_result[i]["req_id"] not in self.token_buffer:
+                    self.token_buffer[batch_result[i]["req_id"]] = list()
+                    self.score_buffer[batch_result[i]["req_id"]] = list()
+                self.token_buffer[batch_result[i]["req_id"]].extend(token_ids)
+                self.score_buffer[batch_result[i]["req_id"]].extend(batch_result[i].get("token_scores", []))
+                batch_result[i]["token_ids"] = []
+                if "token_scores" in batch_result[i]:
+                    batch_result[i]["token_scores"] = []
+            else:
+                if batch_result[i]["req_id"] in self.token_buffer:
+                    batch_result[i]["token_ids"] = self.token_buffer[batch_result[i]
+                                                                     ["req_id"]] + batch_result[i]["token_ids"]
+                    del self.token_buffer[batch_result[i]["req_id"]]
+                    if "token_scores" in batch_result[i]:
+                        batch_result[i]["token_scores"] = self.score_buffer[batch_result[i]
+                                                                            ["req_id"]] + batch_result[i]["token_scores"]
+                        del self.score_buffer[batch_result[i]["req_id"]]
+
     def postprocess(self, batch_result, exist_finished_task=False):
         """
         single postprocess for triton
         """
         try:
+            self._cache_special_tokens(batch_result)
             self.cached_generated_tokens.put(batch_result)
         except Exception as e:
             model_server_logger.info(
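
For reference, a minimal standalone sketch (not the server class itself) of the buffering rule this patch adds to postprocess: a non-final chunk whose leading token id falls in the special range [13, 268], or any chunk for a request with return_all_tokens or disable_streaming set, is held back per req_id in token_buffer/score_buffer and prepended to the next chunk that is allowed through. The module-level dicts, the plain disable_streaming argument, and the sample inputs below are illustrative assumptions rather than the server's actual wiring.

# Illustrative sketch only: module-level buffers stand in for the server's
# self.token_buffer / self.score_buffer, and disable_streaming is passed as
# an argument instead of being read from self.cfg.
token_buffer = {}   # req_id -> buffered token ids
score_buffer = {}   # req_id -> buffered token scores


def cache_special_tokens(batch_result, disable_streaming=False):
    for result in batch_result:
        req_id = result["req_id"]
        is_end = result.get("is_end", 0)
        token_ids = result["token_ids"]
        return_all_tokens = result.get("return_all_tokens", False)

        # Hold back a chunk only if it is not the final one and either its
        # leading token id is in the special range or streaming is suppressed.
        is_special = bool(token_ids) and 13 <= token_ids[0] <= 268
        if is_end != 1 and (is_special or return_all_tokens or disable_streaming):
            token_buffer.setdefault(req_id, []).extend(token_ids)
            score_buffer.setdefault(req_id, []).extend(result.get("token_scores", []))
            result["token_ids"] = []
            if "token_scores" in result:
                result["token_scores"] = []
        elif req_id in token_buffer:
            # Flush: prepend everything buffered for this request.
            result["token_ids"] = token_buffer.pop(req_id) + result["token_ids"]
            if "token_scores" in result:
                result["token_scores"] = score_buffer.pop(req_id) + result["token_scores"]
            else:
                score_buffer.pop(req_id, None)


# Hypothetical two-step stream for one request: the first chunk starts with a
# special token id (20) and is buffered; the final chunk flushes it.
step1 = [{"req_id": "r1", "is_end": 0, "token_ids": [20, 300], "token_scores": [0.1, 0.2]}]
step2 = [{"req_id": "r1", "is_end": 1, "token_ids": [301], "token_scores": [0.3]}]
cache_special_tokens(step1)
cache_special_tokens(step2)
print(step1[0]["token_ids"])   # -> []
print(step2[0]["token_ids"])   # -> [20, 300, 301]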