Add warning for server hangs (#2333)
* Add warning for server hangs

* Add http nonstream server support

* bump fastdeploy_llm to v1.0.0

* add time log for each request

* baidu-fastdeploy-fastdeploy-3 fix time format
rainyfly authored Dec 27, 2023
1 parent 67ca253 commit 7bddc67
Showing 11 changed files with 470 additions and 19 deletions.
8 changes: 8 additions & 0 deletions llm/fastdeploy_llm/config.py
@@ -100,6 +100,7 @@ def __init__(self, model_dir, decode_strategy="sampling", mp_num=None):
         self.model_prompt_dir_path = config.get("prompt_dir_path",
                                                 "./prompt_embedding")
         self.max_prefix_len = config.get("max_prefix_len", 128)
+        self.inference_response_timeout = 20  # timeout (seconds) for the inference engine to output each token
 
     def is_arch(self, arch):
         return arch in self.architecture
@@ -158,3 +159,10 @@ def load_environment_variables(self):
                 "detect environment DISABLE_DYNAMIC_BATCHING={}, will reset `disable_dynamic_batching` to {}!".
                 format(self.disable_dynamic_batching,
                        self.disable_dynamic_batching))
+
+        if os.getenv("INFERENCE_RESPONSE_TIMEOUT", None):
+            self.inference_response_timeout = int(os.getenv("INFERENCE_RESPONSE_TIMEOUT"))
+            logger.warning(
+                "detect environment INFERENCE_RESPONSE_TIMEOUT={}, will reset `inference_response_timeout` to {}!".
+                format(self.inference_response_timeout,
+                       self.inference_response_timeout))
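
For readers tracing the new option: a minimal sketch, under the same rules as the hunk above, of how the per-token timeout ends up being resolved (built-in default of 20 seconds, optionally overridden by the INFERENCE_RESPONSE_TIMEOUT environment variable). The helper name below is illustrative, not part of the diff.

import os

def resolve_inference_response_timeout(default_seconds=20):
    # Start from the built-in default, then let the environment override it,
    # mirroring the load_environment_variables() logic shown above.
    value = os.getenv("INFERENCE_RESPONSE_TIMEOUT", None)
    return int(value) if value else default_seconds

# Example: export INFERENCE_RESPONSE_TIMEOUT=60 before launching the server to
# tolerate a slower engine; without the override the 20-second default applies.
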
32 changes: 25 additions & 7 deletions llm/fastdeploy_llm/engine.py
@@ -63,6 +63,11 @@ def parse_args():
         type=int,
         default=64,
         help='num_attention_heads')
+    parser.add_argument(
+        '--num_key_value_heads',
+        type=int,
+        default=4,
+        help='num_key_value_heads')
     parser.add_argument(
         '--hidden_size', type=int, default=8192, help='hidden_size')
     parser.add_argument(
@@ -140,15 +145,28 @@ def init_dist_env(world_size, seed=20):
 
 cache_kvs = []
 for _ in range(args.num_layers):
-    cache_kvs.append(
-        paddle.cast(
-            paddle.to_tensor(
-                np.zeros(
-                    (2, args.batch_size, args.num_attention_heads // nranks,
-                     args.max_seq_len + args.max_dec_len, args.hidden_size //
-                     args.num_attention_heads),
-                    dtype='float32')),
-            args.dtype))
+    if 'llama' in args.architecture:
+        ## llama in PaddleNLP after https://github.com/PaddlePaddle/PaddleNLP/pull/7516/files changed cache kv shape
+        cache_kvs.append(
+            paddle.cast(
+                paddle.to_tensor(
+                    np.zeros(
+                        (2, args.batch_size, args.num_key_value_heads,
+                         args.max_seq_len + args.max_dec_len, args.hidden_size //
+                         args.num_attention_heads),
+                        dtype='float32')),
+                args.dtype))
+
+    else:
+        cache_kvs.append(
+            paddle.cast(
+                paddle.to_tensor(
+                    np.zeros(
+                        (2, args.batch_size, args.num_attention_heads // nranks,
+                         args.max_seq_len + args.max_dec_len, args.hidden_size //
+                         args.num_attention_heads),
+                        dtype='float32')),
+                args.dtype))
 
 pre_ids = paddle.to_tensor(np.full((args.batch_size, 2048), -1, dtype='int64'))
 tgt_generation_mask = paddle.zeros(
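
To make the branch above concrete, here is a small sketch of the per-layer cache shape it allocates; the parameter names mirror the argparse flags above, and the grouped-query-attention reading of num_key_value_heads is an assumption based on the linked PaddleNLP change, not something stated in this diff.

def cache_kv_shape(architecture, batch_size, num_attention_heads,
                   num_key_value_heads, nranks, max_seq_len,
                   max_dec_len, hidden_size):
    # Per-head width and total sequence capacity are common to both layouts.
    head_dim = hidden_size // num_attention_heads
    total_len = max_seq_len + max_dec_len
    if 'llama' in architecture:
        # New PaddleNLP llama layout: key/value heads are kept whole (GQA-style).
        heads = num_key_value_heads
    else:
        # Legacy layout: attention heads are sharded across tensor-parallel ranks.
        heads = num_attention_heads // nranks
    return (2, batch_size, heads, total_len, head_dim)

# Example with illustrative values:
# cache_kv_shape('llama', batch_size=8, num_attention_heads=64, num_key_value_heads=4,
#                nranks=2, max_seq_len=1024, max_dec_len=1024, hidden_size=8192)
# -> (2, 8, 4, 2048, 128)
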
11 changes: 10 additions & 1 deletion llm/fastdeploy_llm/model.py
@@ -25,7 +25,8 @@
 from fastdeploy_llm.utils.utils import deserialize_from_file, get_files, remove_files
 from fastdeploy_llm.config import Config
 from fastdeploy_llm.task import Task, TaskStatus
-from fastdeploy_llm.utils.logging_util import logger
+from fastdeploy_llm.utils.logging_util import logger, warning_logger
+from fastdeploy_llm.utils.logging_util import error_format, ErrorCode, ErrorType
 
 from concurrent.futures import ThreadPoolExecutor
 
@@ -229,6 +230,7 @@ def async_predict(self, batch_tasks, stop_num=None):
 
     def _update_task_results(self, tasks):
         step_index = 1
+        last_response_time = time.time()
         while True:
             filepath = f"./real_time_save.temp_ids_rank_0_step_{step_index}"
             if os.path.exists(filepath):
@@ -240,6 +242,7 @@ def _update_task_results(self, tasks):
                 except:
                     fin.close()
                 token_ids = deserialize_from_file(fin)
+                last_response_time = time.time()
                 fin.close()
                 step_index += 1
                 for b, token_id in enumerate(token_ids):
@@ -283,6 +286,12 @@ def _update_task_results(self, tasks):
             else:
                 if not self._is_engine_busy():
                     break
+                if time.time() - last_response_time > self.config.inference_response_timeout:
+                    error_type = ErrorType.Server
+                    error_code = ErrorCode.S0003
+                    error_info = "Inference engine output token timeout due to some unexpected exceptions."
+                    error_msg = error_format.format(error_type.name, error_code.name, error_info)
+                    warning_logger.error(error_msg)
                 ret = self.engine_proc.poll()
                 if ret is not None:
                     logger.error(
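
Taken together, the model.py hunks implement a simple watchdog: last_response_time is refreshed every time the engine writes a new batch of token ids, and the polling loop emits a server-side warning once the silence exceeds config.inference_response_timeout. A stripped-down sketch of the same pattern, assuming plain logging rather than the module's error-code machinery:

import time
import logging

logger = logging.getLogger("engine_watchdog_sketch")

def watch_engine(read_step_output, engine_is_busy, timeout_seconds=20):
    # read_step_output() returns a list of token ids or None when nothing new
    # has arrived; engine_is_busy() reports whether requests are still in flight.
    last_response_time = time.time()
    while True:
        token_ids = read_step_output()
        if token_ids is not None:
            last_response_time = time.time()  # engine is alive, reset the clock
            continue
        if not engine_is_busy():
            break
        if time.time() - last_response_time > timeout_seconds:
            logger.error("Inference engine produced no token for %ss; "
                         "the server may be hanging.", timeout_seconds)
            last_response_time = time.time()  # rate-limit the warning
        time.sleep(0.01)
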
3 changes: 2 additions & 1 deletion llm/fastdeploy_llm/serving/__init__.py
@@ -12,4 +12,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .triton_model import TritonPythonModel
+from .triton_model_stream import TritonPythonModelStream
+from .triton_model_nonstream import TritonPythonModelNonStream
14 changes: 10 additions & 4 deletions llm/fastdeploy_llm/serving/serving_model.py
@@ -77,13 +77,17 @@ def runner(self):
             get_tasks = list()
             for i in range(get_tasks_num):
                 try:
-                    task = self.requests_queue.get(timeout=0.1)
+                    if i == 0:
+                        task = self.requests_queue.get()  # only block when get first data
+                    else:
+                        task = self.requests_queue.get(timeout=0.01)  # wait only 10ms for batch
                     get_tasks.append(task)
                 except Exception as e:
                     break
-            if len(get_tasks) == 0 and batch_tasks.unfinished_size() == 0:
-                time.sleep(0.1)
-                continue
+
+            # if len(get_tasks) == 0 and batch_tasks.unfinished_size() == 0:  # should not arrive here for performance
+            #     time.sleep(0.1)
+            #     continue
 
             sender_size = 0
             if self.model.stream_sender is not None:
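
The queue change above removes the fixed 0.1 s poll: the runner now blocks until the first task arrives and then gives later arrivals only 10 ms to join the same batch, so an idle server no longer spins while a busy one still batches aggressively. A hedged sketch of that collection step in isolation (the queue and task objects are stand-ins, not the repo's types):

import queue

def collect_batch(requests_queue, max_batch_size):
    # Block indefinitely for the first task, then give each additional task
    # only 10 ms to show up, so batching barely delays the first request.
    tasks = []
    for i in range(max_batch_size):
        try:
            if i == 0:
                task = requests_queue.get()              # block until work exists
            else:
                task = requests_queue.get(timeout=0.01)  # brief wait for batch mates
        except queue.Empty:
            break
        tasks.append(task)
    return tasks
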
@@ -120,6 +124,8 @@ def runner(self):
             logger.debug(
                 "BatchTasks with task_id={} is send to engine to predict.".
                 format([t.task_id for t in batch_tasks.tasks]))
+            for task in batch_tasks.tasks:
+                task.set_inference_start_time(time.time())
             self.model.predict(batch_tasks, stop_nums)
             logger.debug("The last batch tasks' status = {}.".format(
                 [t.status for t in batch_tasks.tasks]))
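
The two added lines stamp every task with the moment it is handed to the engine, which is what enables the per-request time log mentioned in the commit message. A minimal sketch of how such a timestamp might be consumed; the Task fields and the latency print are illustrative assumptions, not the repo's exact API:

import time
from dataclasses import dataclass

@dataclass
class TaskSketch:
    task_id: str
    inference_start_time: float = 0.0

    def set_inference_start_time(self, t):
        self.inference_start_time = t

def log_request_latency(task):
    # Wall-clock time since the task was sent to the engine.
    elapsed = time.time() - task.inference_start_time
    print(f"task {task.task_id} finished after {elapsed:.3f}s of inference")
    return elapsed

# Usage: call task.set_inference_start_time(time.time()) right before predict(),
# then log_request_latency(task) once the final token for that task arrives.
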