Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multi model refactor #223

Draft
wants to merge 15 commits into
base: main
Choose a base branch
from
51 changes: 51 additions & 0 deletions examples/multi_model/deploy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
import mii

gpu_index_map1 = {'master': [0]}
gpu_index_map2 = {'master': [1]}
gpu_index_map3 = {'master': [0, 1]}

deployments = []

mii_configs1 = {"tensor_parallel": 2, "dtype": "fp16"}
mii_configs2 = {"tensor_parallel": 1}

name = "bigscience/bloom-560m"
deployments.append({
'task': 'text-generation',
'model': name,
'deployment_name': name + "_deployment",
'GPU_index_map': gpu_index_map3,
'tensor_parallel': 2,
'dtype': "fp16"
})

# gpt2
name = "microsoft/DialogRPT-human-vs-rand"
deployments.append({
'task': 'text-classification',
'model': name,
'deployment_name': name + "_deployment",
'GPU_index_map': gpu_index_map2
})

name = "microsoft/DialoGPT-large"
deployments.append({
'task': 'conversational',
'model': name,
'deployment_name': name + "_deployment",
'GPU_index_map': gpu_index_map1,
})

name = "deepset/roberta-large-squad2"
deployments.append({
'task': "question-answering",
'model': name,
'deployment_name': name + "-qa-deployment",
'GPU_index_map': gpu_index_map2
})

mii.deploy(deployment_tag="multi_models", deployment_configs=deployments[:2])
50 changes: 50 additions & 0 deletions examples/multi_model/query.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import mii

results = []
generator = mii.mii_query_handle("multi_models")
result = generator.query(
{
"query": ["DeepSpeed is",
"Seattle is"],
"deployment_name": "bigscience/bloom-560m_deployment"
},
do_sample=True,
max_new_tokens=30,
)
results.append(result)
print(result)

result = generator.query({
'query':
"DeepSpeed is the greatest",
"deployment_name":
"microsoft/DialogRPT-human-vs-rand_deployment"
})
results.append(result)
print(result)

result = generator.query({
'text': "DeepSpeed is the greatest",
'conversation_id': 3,
'past_user_inputs': [],
'generated_responses': [],
"deployment_name": "microsoft/DialoGPT-large_deployment"
})
results.append(result)
print(result)

result = generator.query({
'question':
"What is the greatest?",
'context':
"DeepSpeed is the greatest",
"deployment_name":
"deepset/roberta-large-squad2" + "-qa-deployment"
})
results.append(result)
print(result)
7 changes: 7 additions & 0 deletions examples/multi_model/shutdown.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
import mii

mii.terminate("multi_models")
4 changes: 2 additions & 2 deletions mii/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
from .client import MIIClient, mii_query_handle
from .deployment import deploy
from .terminate import terminate
from .constants import DeploymentType, Tasks
from .constants import DeploymentType, TaskType
from .aml_related.utils import aml_output_path

from .config import MIIConfig, LoadBalancerConfig
from .config import MIIConfig, DeploymentConfig
from .grpc_related.proto import modelresponse_pb2_grpc

__version__ = "0.0.0"
Expand Down
113 changes: 69 additions & 44 deletions mii/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,15 @@
import grpc
import requests
import mii
from mii.utils import get_task
from mii.grpc_related.proto import modelresponse_pb2, modelresponse_pb2_grpc
from mii.constants import GRPC_MAX_MSG_SIZE, Tasks
from mii.constants import GRPC_MAX_MSG_SIZE, TaskType
from mii.method_table import GRPC_METHOD_TABLE
from mii.config import MIIConfig


def _get_deployment_info(deployment_name):
configs = mii.utils.import_score_file(deployment_name).configs
task = configs[mii.constants.TASK_NAME_KEY]
mii_configs_dict = configs[mii.constants.MII_CONFIGS_KEY]
mii_configs = mii.config.MIIConfig(**mii_configs_dict)

assert task is not None, "The task name should be set before calling init"
return task, mii_configs
def _get_mii_config(deployment_name):
mii_config = mii.utils.import_score_file(deployment_name).mii_config
return MIIConfig(**mii_config)


def mii_query_handle(deployment_name):
Expand All @@ -39,40 +34,64 @@ def mii_query_handle(deployment_name):
inference_pipeline, task = mii.non_persistent_models[deployment_name]
return MIINonPersistentClient(task, deployment_name)

task_name, mii_configs = _get_deployment_info(deployment_name)
return MIIClient(task_name, "localhost", mii_configs.port_number)
mii_config = _get_mii_config(deployment_name)
return MIIClient(mii_config, "localhost", mii_config.port_number)


def create_channel(host, port):
return grpc.aio.insecure_channel(f'{host}:{port}',
options=[('grpc.max_send_message_length',
GRPC_MAX_MSG_SIZE),
('grpc.max_receive_message_length',
GRPC_MAX_MSG_SIZE)])


class MIIClient():
return grpc.aio.insecure_channel(
f"{host}:{port}",
options=[
("grpc.max_send_message_length",
GRPC_MAX_MSG_SIZE),
("grpc.max_receive_message_length",
GRPC_MAX_MSG_SIZE),
],
)


class MIIClient:
"""
Client to send queries to a single endpoint.
"""
def __init__(self, task_name, host, port):
def __init__(self, mii_config, host, port):
self.asyncio_loop = asyncio.get_event_loop()
channel = create_channel(host, port)
self.stub = modelresponse_pb2_grpc.ModelResponseStub(channel)
self.task = get_task(task_name)

async def _request_async_response(self, request_dict, **query_kwargs):
if self.task not in GRPC_METHOD_TABLE:
raise ValueError(f"unknown task: {self.task}")

task_methods = GRPC_METHOD_TABLE[self.task]
self.mii_config = mii_config

def _get_deployment_task(self, deployment_name=None):
task = None
if deployment_name is None: #mii.terminate() or single model
if deployment_name is None:
assert len(self.deployments) == 1, "Must pass deployment_name to query when using multiple deployments"
deployment = self.mii_config.deployment_configs[0]
deployment_name = getattr(deployment, deployment_name)
task = getattr(deployment, task)
else:
if deployment_name in self.deployments:
deployment = self.mii_config.deployment_configs[deployment_name]
task = getattr(deployment, task)
else:
assert False, f"{deployment_name} not found in list of deployments"
return deployment_name, task

async def _request_async_response(self, request_dict, task, **query_kwargs):
if task not in GRPC_METHOD_TABLE:
raise ValueError(f"unknown task: {task}")

task_methods = GRPC_METHOD_TABLE[task]
proto_request = task_methods.pack_request_to_proto(request_dict, **query_kwargs)
proto_response = await getattr(self.stub, task_methods.method)(proto_request)
proto_response = await getattr(self.mr_stub, task_methods.method)(proto_request)
return task_methods.unpack_response_from_proto(proto_response)

def query(self, request_dict, **query_kwargs):
deployment_name = request_dict.get(mii.constants.DEPLOYMENT_NAME_KEY)
deployment_name, task = self._get_deployment_task(deployment_name)
request_dict['deployment_name'] = deployment_name
return self.asyncio_loop.run_until_complete(
self._request_async_response(request_dict,
task,
**query_kwargs))

async def terminate_async(self):
Expand All @@ -87,7 +106,9 @@ async def create_session_async(self, session_id):
modelresponse_pb2.SessionID(session_id=session_id))

def create_session(self, session_id):
assert self.task == Tasks.TEXT_GENERATION, f"Session creation only available for task '{Tasks.TEXT_GENERATION}'."
assert (
self.task == TaskType.TEXT_GENERATION
), f"Session creation only available for task '{TaskType.TEXT_GENERATION}'."
return self.asyncio_loop.run_until_complete(
self.create_session_async(session_id))

Expand All @@ -96,18 +117,20 @@ async def destroy_session_async(self, session_id):
)

def destroy_session(self, session_id):
assert self.task == Tasks.TEXT_GENERATION, f"Session deletion only available for task '{Tasks.TEXT_GENERATION}'."
assert (
self.task == TaskType.TEXT_GENERATION
), f"Session deletion only available for task '{TaskType.TEXT_GENERATION}'."
self.asyncio_loop.run_until_complete(self.destroy_session_async(session_id))


class MIITensorParallelClient():
class MIITensorParallelClient:
"""
Client to send queries to multiple endpoints in parallel.
This is used to call multiple servers deployed for tensor parallelism.
"""
def __init__(self, task_name, host, ports):
self.task = get_task(task_name)
self.clients = [MIIClient(task_name, host, port) for port in ports]
def __init__(self, task, host, ports):
self.task = task
self.clients = [MIIClient(task, host, port) for port in ports]
self.asyncio_loop = asyncio.get_event_loop()

# runs task in parallel and return the result from the first task
Expand Down Expand Up @@ -155,30 +178,32 @@ def destroy_session(self, session_id):
client.destroy_session(session_id)


class MIINonPersistentClient():
class MIINonPersistentClient:
def __init__(self, task, deployment_name):
self.task = task
self.deployment_name = deployment_name

def query(self, request_dict, **query_kwargs):
assert self.deployment_name in mii.non_persistent_models, f"deployment: {self.deployment_name} not found"
assert (
self.deployment_name in mii.non_persistent_models
), f"deployment: {self.deployment_name} not found"
task_methods = GRPC_METHOD_TABLE[self.task]
inference_pipeline = mii.non_persistent_models[self.deployment_name][0]

if self.task == Tasks.QUESTION_ANSWERING:
if 'question' not in request_dict or 'context' not in request_dict:
if self.task == TaskType.QUESTION_ANSWERING:
if "question" not in request_dict or "context" not in request_dict:
raise Exception(
"Question Answering Task requires 'question' and 'context' keys")
args = (request_dict["question"], request_dict["context"])
kwargs = query_kwargs

elif self.task == Tasks.CONVERSATIONAL:
elif self.task == TaskType.CONVERSATIONAL:
conv = task_methods.create_conversation(request_dict, **query_kwargs)
args = (conv, )
kwargs = {}

else:
args = (request_dict['query'], )
args = (request_dict["query"], )
kwargs = query_kwargs

return task_methods.run_inference(inference_pipeline, args, query_kwargs)
Expand All @@ -189,6 +214,6 @@ def terminate(self):


def terminate_restful_gateway(deployment_name):
_, mii_configs = _get_deployment_info(deployment_name)
if mii_configs.enable_restful_api:
requests.get(f"http://localhost:{mii_configs.restful_api_port}/terminate")
mii_config = _get_mii_config(deployment_name)
if mii_config.enable_restful_api:
requests.get(f"http://localhost:{mii_config.restful_api_port}/terminate")
Loading