From 5c46d3eafc1f85d6a1ef4282f5f15790c3ef6e0f Mon Sep 17 00:00:00 2001
From: Flavia Beo
Date: Mon, 22 Jul 2024 22:14:08 -0300
Subject: [PATCH] Adds TokenizationTaskRequest and InfoService

This includes the TokenizationTaskRequest method for the embeddings
models at the NLP interface + implements the InfoService for the
watsonx product team requirement to be able to retrieve embeddings
models' metadata.

Signed-off-by: Flavia Beo
---
Review notes (ignored by `git am`):
- Reflowed from a whitespace-collapsed export that had also dropped every
  `<...>` span. Type parameters were restored where the added code or the
  proto definitions determine them. The return type on the first context
  line of the third nlp.rs hunk could not be recovered and is left as found;
  context-line indentation is a best guess.
- extract_model_id (info.rs) no longer panics on a non-ASCII `mm-model-id`
  header; unused ModelInfoRequest/ModelInfoResponse imports were dropped from
  nlp.rs; new files now end with a newline. Per-file line counts are
  unchanged, but the `index` blob ids predate these fixes.

 fmaas-router/build.rs                        |  2 +-
 fmaas-router/src/rpc.rs                      |  1 +
 fmaas-router/src/rpc/info.rs                 | 92 +++++++++++++++++++
 fmaas-router/src/rpc/nlp.rs                  | 33 +++++--
 fmaas-router/src/server.rs                   |  9 +-
 proto/caikit_data_model_common_runtime.proto |  4 +-
 proto/caikit_runtime_Nlp.proto               | 40 ++++++++-
 proto/google/protobuf/struct.proto           | 95 ++++++++++++++++++++
 8 files changed, 263 insertions(+), 13 deletions(-)
 create mode 100644 fmaas-router/src/rpc/info.rs
 create mode 100644 proto/google/protobuf/struct.proto

diff --git a/fmaas-router/build.rs b/fmaas-router/build.rs
index 10a9a78..247b801 100644
--- a/fmaas-router/build.rs
+++ b/fmaas-router/build.rs
@@ -8,7 +8,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
         .out_dir("src/pb")
         .include_file("mod.rs")
         .compile(
-            &["../proto/generation.proto", "../proto/caikit_runtime_Nlp.proto"],
+            &["../proto/generation.proto", "../proto/caikit_runtime_Nlp.proto", "../proto/caikit_runtime_info.proto"],
             &["../proto"],
         )
         .unwrap_or_else(|e| panic!("protobuf compilation failed: {}", e));
diff --git a/fmaas-router/src/rpc.rs b/fmaas-router/src/rpc.rs
index 8b3aa74..d588970 100644
--- a/fmaas-router/src/rpc.rs
+++ b/fmaas-router/src/rpc.rs
@@ -1,2 +1,3 @@
 pub mod generation;
 pub mod nlp;
+pub mod info;
diff --git a/fmaas-router/src/rpc/info.rs b/fmaas-router/src/rpc/info.rs
new file mode 100644
index 0000000..6323f84
--- /dev/null
+++ b/fmaas-router/src/rpc/info.rs
@@ -0,0 +1,92 @@
+use std::collections::HashMap;
+
+use ginepro::LoadBalancedChannel;
+use tonic::{transport::ClientTlsConfig, Code, Request, Response, Status};
+use tracing::{debug, instrument};
+
+use crate::{create_clients, pb::{
+    caikit::runtime::info::{
+        info_service_client::InfoServiceClient, info_service_server::InfoService
+    },
+    caikit_data_model::common::runtime::{
+        ModelInfoRequest, ModelInfoResponse, RuntimeInfoRequest, RuntimeInfoResponse
+    }
+}, ServiceAddr};
+
+const METADATA_NAME_MODEL_ID: &str = "mm-model-id";
+
+/// Routes caikit `InfoService` calls to the client registered for a model ID.
+#[derive(Debug, Default)]
+pub struct InfoServicer {
+    clients: HashMap<String, InfoServiceClient<LoadBalancedChannel>>,
+}
+
+impl InfoServicer {
+    /// Builds `clients` by calling `create_clients` with `InfoServiceClient::new`.
+    pub async fn new(
+        default_target_port: u16,
+        client_tls: Option<&ClientTlsConfig>,
+        model_map: &HashMap<String, ServiceAddr>,
+    ) -> Self {
+        let clients = create_clients(
+            default_target_port, client_tls, model_map, InfoServiceClient::new
+        ).await;
+        Self { clients }
+    }
+
+    /// Returns a clone of the client registered for `model_id`, or `NotFound`.
+    async fn client(
+        &self,
+        model_id: &str,
+    ) -> Result<InfoServiceClient<LoadBalancedChannel>, Status> {
+        Ok(self
+            .clients
+            .get(model_id)
+            .ok_or_else(|| Status::not_found(format!("Unrecognized model_id: {model_id}")))?
+            .clone())
+    }
+}
+
+#[tonic::async_trait]
+impl InfoService for InfoServicer {
+    #[instrument(skip_all)]
+    async fn get_models_info(
+        &self,
+        request: Request<ModelInfoRequest>,
+    ) -> Result<Response<ModelInfoResponse>, Status> {
+        let model_id = extract_model_id(&request)?;
+        let models_ids: &ModelInfoRequest = request.get_ref();
+        if models_ids.model_ids.is_empty() {
+            return Ok(Response::new(ModelInfoResponse::default()));
+        }
+        debug!(
+            "Routing get models info request for Model ID {}",
+            model_id
+        );
+        self.client(model_id)
+            .await?
+            .get_models_info(request)
+            .await
+    }
+
+    #[instrument(skip_all)]
+    async fn get_runtime_info(
+        &self,
+        _request: Request<RuntimeInfoRequest>,
+    ) -> Result<Response<RuntimeInfoResponse>, Status> {
+        Err(Status::unimplemented("not implemented"))
+    }
+}
+
+/// Extracts model_id from [`Request`] metadata.
+///
+/// Fails with `InvalidArgument` (instead of panicking) when the header is
+/// missing or its value is not a visible-ASCII string.
+fn extract_model_id<T>(request: &Request<T>) -> Result<&str, Status> {
+    request
+        .metadata()
+        .get(METADATA_NAME_MODEL_ID)
+        .ok_or_else(|| Status::new(Code::InvalidArgument, "Missing required model ID"))?
+        .to_str()
+        .map_err(|_| Status::new(Code::InvalidArgument, "Model ID must be visible ASCII"))
+}
diff --git a/fmaas-router/src/rpc/nlp.rs b/fmaas-router/src/rpc/nlp.rs
index 5bfe2c4..5d613f8 100644
--- a/fmaas-router/src/rpc/nlp.rs
+++ b/fmaas-router/src/rpc/nlp.rs
@@ -4,14 +4,9 @@ use ginepro::LoadBalancedChannel;
 use tonic::{transport::ClientTlsConfig, Code, Request, Response, Status, Streaming};
 use tracing::{debug, instrument};
 
-use crate::{pb::{
+use crate::{create_clients, pb::{
     caikit::runtime::nlp::{
-        nlp_service_client::NlpServiceClient, nlp_service_server::NlpService,
-        BidiStreamingTokenClassificationTaskRequest, EmbeddingTaskRequest,
-        EmbeddingTasksRequest, RerankTaskRequest, RerankTasksRequest,
-        SentenceSimilarityTaskRequest, SentenceSimilarityTasksRequest,
-        ServerStreamingTextGenerationTaskRequest, TextClassificationTaskRequest,
-        TextGenerationTaskRequest, TokenClassificationTaskRequest,
+        nlp_service_client::NlpServiceClient, nlp_service_server::NlpService, BidiStreamingTokenClassificationTaskRequest, EmbeddingTaskRequest, EmbeddingTasksRequest, RerankTaskRequest, RerankTasksRequest, SentenceSimilarityTaskRequest, SentenceSimilarityTasksRequest, ServerStreamingTextGenerationTaskRequest, TextClassificationTaskRequest, TextGenerationTaskRequest, TokenClassificationTaskRequest, TokenizationTaskRequest
     },
     caikit_data_model::{
         caikit_nlp::{
@@ -21,9 +16,10 @@ use crate::{pb::{
         nlp::{
             ClassificationResults, GeneratedTextResult, GeneratedTextStreamResult,
             TokenClassificationResults, TokenClassificationStreamResult,
+            TokenizationResults,
         },
     },
-}, create_clients, ServiceAddr};
+}, ServiceAddr};
 
 const METADATA_NAME_MODEL_ID: &str = "mm-model-id";
 
@@ -220,6 +216,27 @@ impl NlpService for NlpServicer {
     ) -> Result, Status> {
         Err(Status::unimplemented("not implemented"))
     }
+
+    /// Forwards a tokenization request to the client registered for its model ID.
+    #[instrument(skip_all)]
+    async fn tokenization_task_predict(
+        &self,
+        request: Request<TokenizationTaskRequest>,
+    ) -> Result<Response<TokenizationResults>, Status> {
+        let model_id = extract_model_id(&request)?;
+        let task_request: &TokenizationTaskRequest = request.get_ref();
+        if task_request.text.is_empty() {
+            return Ok(Response::new(TokenizationResults::default()));
+        }
+        debug!(
+            "Routing tokenization task predict request for Model ID {}",
+            model_id
+        );
+        self.client(model_id)
+            .await?
+            .tokenization_task_predict(request)
+            .await
+    }
 }
 
 /// Extracts model_id from [`Request`] metadata.
diff --git a/fmaas-router/src/server.rs b/fmaas-router/src/server.rs
index f73fa7a..8f7a25b 100644
--- a/fmaas-router/src/server.rs
+++ b/fmaas-router/src/server.rs
@@ -11,8 +11,9 @@ use crate::{
     pb::{
         caikit::runtime::nlp::nlp_service_server::NlpServiceServer,
         fmaas::generation_service_server::GenerationServiceServer,
+        caikit::runtime::info::info_service_server::InfoServiceServer,
     },
-    rpc::{generation::GenerationServicer, nlp::NlpServicer},
+    rpc::{generation::GenerationServicer, info::InfoServicer, nlp::NlpServicer},
     ModelMap,
 };
 
@@ -73,6 +74,12 @@ pub async fn run(
             NlpServicer::new(default_target_port, client_tls.as_ref(), model_map).await;
         routes_builder.add_service(NlpServiceServer::new(nlp_servicer));
     }
+    if let Some(model_map) = model_map.embeddings() {
+        info!("Enabling InfoService");
+        let info_servicer =
+            InfoServicer::new(default_target_port, client_tls.as_ref(), model_map).await;
+        routes_builder.add_service(InfoServiceServer::new(info_servicer));
+    }
     let grpc_server = builder
         .add_routes(routes_builder.routes())
         .serve_with_shutdown(grpc_addr, shutdown_signal());
diff --git a/proto/caikit_data_model_common_runtime.proto b/proto/caikit_data_model_common_runtime.proto
index 4b6190f..386c4d2 100644
--- a/proto/caikit_data_model_common_runtime.proto
+++ b/proto/caikit_data_model_common_runtime.proto
@@ -6,7 +6,7 @@ syntax = "proto3";
 package caikit_data_model.common.runtime;
 import "caikit_data_model_common.proto";
 
-
+import "google/protobuf/struct.proto";
 
 /*-- MESSAGES ----------------------------------------------------------------*/
 
@@ -18,7 +18,7 @@ message ModelInfo {
   string model_path = 1;
   string name = 2;
   int64 size = 3;
-  map<string, string> metadata = 4;
+  google.protobuf.Struct metadata = 4;
   bool loaded = 7;
   string module_id = 5;
   map<string, string> module_metadata = 6;
diff --git a/proto/caikit_runtime_Nlp.proto b/proto/caikit_runtime_Nlp.proto
index e153cab..0ed3273 100644
--- a/proto/caikit_runtime_Nlp.proto
+++ b/proto/caikit_runtime_Nlp.proto
@@ -10,7 +10,7 @@ import "caikit_data_model_caikit_nlp.proto";
 import "caikit_data_model_common.proto";
 import "caikit_data_model_nlp.proto";
 import "caikit_data_model_runtime.proto";
-
+import "caikit_data_model_common_runtime.proto";
 
 /*-- MESSAGES ----------------------------------------------------------------*/
 
@@ -201,6 +201,38 @@ message TokenClassificationTaskRequest {
   optional double threshold = 2;
 }
 
+message TokenizationTaskRequest {
+
+  /*-- fields --*/
+  string text = 1;
+}
+
+message ModelInfo {
+
+  /*-- nested messages --*/
+
+  /*-- fields --*/
+  string model_path = 1;
+  string name = 2;
+  int64 size = 3;
+  google.protobuf.Struct metadata = 4;
+  bool loaded = 7;
+  string module_id = 5;
+  map<string, string> module_metadata = 6;
+}
+
+message ModelInfoRequest {
+
+  /*-- fields --*/
+  repeated string model_ids = 1;
+}
+
+message ModelInfoResponse {
+
+  /*-- fields --*/
+  repeated ModelInfo models = 1;
+}
+
 
 /*-- SERVICES ----------------------------------------------------------------*/
 
@@ -216,9 +248,15 @@ service NlpService {
   rpc TextClassificationTaskPredict(caikit.runtime.Nlp.TextClassificationTaskRequest) returns (caikit_data_model.nlp.ClassificationResults);
   rpc TextGenerationTaskPredict(caikit.runtime.Nlp.TextGenerationTaskRequest) returns (caikit_data_model.nlp.GeneratedTextResult);
   rpc TokenClassificationTaskPredict(caikit.runtime.Nlp.TokenClassificationTaskRequest) returns (caikit_data_model.nlp.TokenClassificationResults);
+  rpc TokenizationTaskPredict(caikit.runtime.Nlp.TokenizationTaskRequest) returns (caikit_data_model.nlp.TokenizationResults);
 }
 
 service NlpTrainingService {
   rpc TextGenerationTaskPeftPromptTuningTrain(caikit.runtime.Nlp.TextGenerationTaskPeftPromptTuningTrainRequest) returns (caikit_data_model.runtime.TrainingJob);
   rpc TextGenerationTaskTextGenerationTrain(caikit.runtime.Nlp.TextGenerationTaskTextGenerationTrainRequest) returns (caikit_data_model.runtime.TrainingJob);
 }
+
+service InfoService {
+  rpc GetRuntimeInfo(caikit_data_model.common.runtime.RuntimeInfoRequest) returns (caikit_data_model.common.runtime.RuntimeInfoResponse);
+  rpc GetModelsInfo(caikit_data_model.common.runtime.ModelInfoRequest) returns (caikit_data_model.common.runtime.ModelInfoResponse);
+}
diff --git a/proto/google/protobuf/struct.proto b/proto/google/protobuf/struct.proto
new file mode 100644
index 0000000..e07e343
--- /dev/null
+++ b/proto/google/protobuf/struct.proto
@@ -0,0 +1,95 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2008 Google Inc. All rights reserved.
+// https://developers.google.com/protocol-buffers/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+//     * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+//     * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+//     * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+syntax = "proto3";
+
+package google.protobuf;
+
+option cc_enable_arenas = true;
+option go_package = "google.golang.org/protobuf/types/known/structpb";
+option java_package = "com.google.protobuf";
+option java_outer_classname = "StructProto";
+option java_multiple_files = true;
+option objc_class_prefix = "GPB";
+option csharp_namespace = "Google.Protobuf.WellKnownTypes";
+
+// `Struct` represents a structured data value, consisting of fields
+// which map to dynamically typed values. In some languages, `Struct`
+// might be supported by a native representation. For example, in
+// scripting languages like JS a struct is represented as an
+// object. The details of that representation are described together
+// with the proto support for the language.
+//
+// The JSON representation for `Struct` is JSON object.
+message Struct {
+  // Unordered map of dynamically typed values.
+  map<string, Value> fields = 1;
+}
+
+// `Value` represents a dynamically typed value which can be either
+// null, a number, a string, a boolean, a recursive struct value, or a
+// list of values. A producer of value is expected to set one of these
+// variants. Absence of any variant indicates an error.
+//
+// The JSON representation for `Value` is JSON value.
+message Value {
+  // The kind of value.
+  oneof kind {
+    // Represents a null value.
+    NullValue null_value = 1;
+    // Represents a double value.
+    double number_value = 2;
+    // Represents a string value.
+    string string_value = 3;
+    // Represents a boolean value.
+    bool bool_value = 4;
+    // Represents a structured value.
+    Struct struct_value = 5;
+    // Represents a repeated `Value`.
+    ListValue list_value = 6;
+  }
+}
+
+// `NullValue` is a singleton enumeration to represent the null value for the
+// `Value` type union.
+//
+// The JSON representation for `NullValue` is JSON `null`.
+enum NullValue {
+  // Null value.
+  NULL_VALUE = 0;
+}
+
+// `ListValue` is a wrapper around a repeated field of values.
+//
+// The JSON representation for `ListValue` is JSON array.
+message ListValue {
+  // Repeated field of dynamically typed values.
+  repeated Value values = 1;
+}