From 79535686047bdbdd1f446c6b7dd76d8aaea160ec Mon Sep 17 00:00:00 2001
From: utnim2
Date: Thu, 26 Dec 2024 12:44:31 +0530
Subject: [PATCH] feat: add function calling and token count reporting

---
 src/providers/vertexai/models.rs   | 93 ++++++++++++++++++++++++----
 src/providers/vertexai/provider.rs | 24 ++++++---
 src/providers/vertexai/tests.rs    | 95 ++++++++++++++++++++++++--
 tests/vertexai_integration_test.rs | 75 ++++++++++++++++++++++
 4 files changed, 266 insertions(+), 21 deletions(-)

diff --git a/src/providers/vertexai/models.rs b/src/providers/vertexai/models.rs
index 399f8f2..93d26ea 100644
--- a/src/providers/vertexai/models.rs
+++ b/src/providers/vertexai/models.rs
@@ -5,6 +5,7 @@ use crate::models::embeddings::{
     Embeddings, EmbeddingsInput, EmbeddingsRequest, EmbeddingsResponse,
 };
 use crate::models::streaming::{ChatCompletionChunk, Choice, ChoiceDelta};
+use crate::models::tool_calls::{ChatMessageToolCall, FunctionCall};
 use crate::models::usage::Usage;
 use serde::{Deserialize, Serialize};
 
@@ -14,6 +15,9 @@ pub(crate) struct VertexAIChatCompletionRequest {
     pub contents: Vec<Content>,
     #[serde(rename = "generation_config")]
    pub generation_config: Option<GenerationConfig>,
+    #[serde(rename = "tools")]
+    #[serde(skip_serializing_if = "Vec::is_empty")]
+    pub tools: Vec<VertexAITool>,
 }
 
 #[derive(Deserialize, Serialize, Clone, Debug)]
@@ -42,7 +46,10 @@ pub(crate) struct Content {
 
 #[derive(Deserialize, Serialize, Clone, Debug)]
 pub(crate) struct Part {
-    pub text: String,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub text: Option<String>,
+    #[serde(rename = "functionCall", skip_serializing_if = "Option::is_none")]
+    pub function_call: Option<VertexFunctionCall>,
 }
 
 #[derive(Deserialize, Serialize, Clone, Debug)]
@@ -63,6 +70,8 @@ pub(crate) struct GenerateContentResponse {
     pub safety_ratings: Option<Vec<SafetyRating>>,
     #[serde(rename = "avgLogprobs")]
     pub avg_logprobs: Option<f64>,
+    #[serde(rename = "functionCall")]
+    pub function_call: Option<VertexFunctionCall>,
 }
 
 #[derive(Deserialize, Serialize, Clone, Debug)]
@@ -140,6 +149,26 @@ pub(crate) struct VertexAIEmbeddingStatistics {
     pub token_count: u32,
 }
 
+#[derive(Deserialize, Serialize, Clone, Debug)]
+pub(crate) struct VertexAITool {
+    pub function_declarations: Vec<VertexAIFunctionDeclaration>,
+}
+
+#[derive(Deserialize, Serialize, Clone, Debug)]
+pub(crate) struct VertexAIFunctionDeclaration {
+    pub name: String,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub description: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub parameters: Option<serde_json::Value>,
+}
+
+#[derive(Deserialize, Serialize, Clone, Debug)]
+pub(crate) struct VertexFunctionCall {
+    pub name: String,
+    pub args: serde_json::Value,
+}
+
 impl From<crate::models::chat::ChatCompletionRequest> for VertexAIChatCompletionRequest {
     fn from(request: crate::models::chat::ChatCompletionRequest) -> Self {
         let contents = request
@@ -162,11 +191,33 @@ impl From<crate::models::chat::ChatCompletionRequest> for VertexAIChatCompletionRequest {
                     "assistant" => "model".to_string(),
                     _ => "user".to_string(),
                 },
-                parts: vec![Part { text }],
+                parts: vec![Part {
+                    text: Some(text),
+                    function_call: None,
+                }],
                 }
             })
             .collect();
 
+        let tools = if let Some(tools) = request.tools {
+            vec![VertexAITool {
+                function_declarations: tools
+                    .into_iter()
+                    .map(|tool| VertexAIFunctionDeclaration {
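+                        // one Vertex AI function declaration is emitted per incoming OpenAI-style tool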
+                        name: tool.function.name,
+                        description: tool.function.description,
+                        parameters: tool
+                            .function
+                            .parameters
+                            .map(|p| serde_json::to_value(p).unwrap_or_default()),
+                    })
+                    .collect(),
+            }]
+        } else {
+            Vec::new()
+        };
+
         VertexAIChatCompletionRequest {
             contents,
             generation_config: Some(GenerationConfig {
@@ -176,6 +227,7 @@ impl From<crate::models::chat::ChatCompletionRequest> for VertexAIChatCompletionRequest {
                 candidate_count: request.n.map(|n| n as i32),
                 max_output_tokens: request.max_tokens.or(Some(default_max_tokens())),
             }),
+            tools,
         }
     }
 }
@@ -187,10 +239,24 @@ impl From<VertexAIChatCompletionResponse> for ChatCompletion {
             .into_iter()
             .enumerate()
             .map(|(index, candidate)| {
-                let content = if let Some(part) = candidate.content.parts.first() {
-                    ChatMessageContent::String(part.text.clone())
+                let (content, tool_calls) = if let Some(part) = candidate.content.parts.first() {
+                    match (&part.text, &part.function_call) {
+                        (Some(text), None) => (ChatMessageContent::String(text.clone()), None),
+                        (None, Some(func_call)) => (
+                            ChatMessageContent::String(String::new()),
+                            Some(vec![ChatMessageToolCall {
+                                id: uuid::Uuid::new_v4().to_string(),
+                                function: FunctionCall {
+                                    name: func_call.name.clone(),
+                                    arguments: func_call.args.to_string(),
+                                },
+                                r#type: "function".to_string(),
+                            }]),
+                        ),
+                        _ => (ChatMessageContent::String(String::new()), None),
+                    }
                 } else {
-                    ChatMessageContent::String(String::new())
+                    (ChatMessageContent::String(String::new()), None)
                 };
 
                 ChatCompletionChoice {
@@ -199,7 +265,7 @@ impl From<VertexAIChatCompletionResponse> for ChatCompletion {
                         role: "assistant".to_string(),
                         content: Some(content),
                         name: None,
-                        tool_calls: None,
+                        tool_calls,
                     },
                     finish_reason: Some(candidate.finish_reason),
                     logprobs: None,
@@ -207,13 +273,24 @@ impl From<VertexAIChatCompletionResponse> for ChatCompletion {
             })
             .collect();
 
+        let usage = response
+            .usage_metadata
+            .map(|metadata| Usage {
+                prompt_tokens: metadata.prompt_token_count as u32,
+                completion_tokens: metadata.candidates_token_count as u32,
+                total_tokens: metadata.total_token_count as u32,
+                completion_tokens_details: None,
+                prompt_tokens_details: None,
+            })
+            .unwrap_or_default();
+
         ChatCompletion {
             id: uuid::Uuid::new_v4().to_string(),
             object: None,
             created: None,
             model: "gemini-pro".to_string(),
             choices,
-            usage: crate::models::usage::Usage::default(),
+            usage,
             system_fingerprint: None,
         }
     }
@@ -231,7 +308,7 @@
                 delta: ChoiceDelta {
                     content: candidate
                         .content
-                        .and_then(|c| c.parts.first().map(|p| p.text.clone())),
+                        .and_then(|c| c.parts.first().and_then(|p| p.text.clone())),
                     role: Some("assistant".to_string()),
                     tool_calls: None,
                 },
diff --git a/src/providers/vertexai/provider.rs b/src/providers/vertexai/provider.rs
index 36aeca8..93fb4a7 100644
--- a/src/providers/vertexai/provider.rs
+++ b/src/providers/vertexai/provider.rs
@@ -87,10 +87,11 @@ impl Provider for VertexAIProvider {
         );
         headers.insert("Content-Type", HeaderValue::from_static("application/json"));
 
-        let request_body = json!({
-            "contents": request.contents,
-            "generation_config": request.generation_config,
-        });
+        let request_body = json!({
+            "contents": request.contents,
+            "generation_config": request.generation_config,
+            "tools": request.tools,
+        });
 
         let response = self
             .http_client
@@ -125,9 +126,18 @@ impl Provider for VertexAIProvider {
 
             Ok(ChatCompletionResponse::Stream(Box::pin(stream)))
         } else {
+            let response_text = response.text().await.map_err(|e| {
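+                // the body could not be read at all; log it and surface a generic 500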
+                eprintln!("Failed to get response text: {}", e);
+                StatusCode::INTERNAL_SERVER_ERROR
+            })?;
+
             let vertex_response: VertexAIChatCompletionResponse =
-                response.json().await.map_err(|e| {
-                    eprintln!("VertexAI API response error: {}", e);
+                serde_json::from_str(&response_text).map_err(|e| {
+                    eprintln!(
+                        "Failed to parse response: {}. Response was: {}",
+                        e, response_text
+                    );
                     StatusCode::INTERNAL_SERVER_ERROR
                 })?;
 
@@ -172,8 +182,6 @@ impl Provider for VertexAIProvider {
             model = model
         );
 
-        println!("Request {:?}", request);
-
         let mut headers = HeaderMap::new();
         headers.insert(
             "Authorization",
diff --git a/src/providers/vertexai/tests.rs b/src/providers/vertexai/tests.rs
index 780550c..816433e 100644
--- a/src/providers/vertexai/tests.rs
+++ b/src/providers/vertexai/tests.rs
@@ -3,12 +3,14 @@ mod tests {
     use crate::models::chat::{ChatCompletion, ChatCompletionRequest};
     use crate::models::content::{ChatCompletionMessage, ChatMessageContent};
     use crate::models::embeddings::{EmbeddingsInput, EmbeddingsRequest};
+    use crate::models::tool_definition::{FunctionDefinition, ToolDefinition};
     use crate::providers::provider::Provider;
     use crate::providers::vertexai::models::{
         Content, GenerateContentResponse, Part, UsageMetadata, VertexAIChatCompletionRequest,
-        VertexAIChatCompletionResponse, VertexAIEmbeddingsRequest,
+        VertexAIChatCompletionResponse, VertexAIEmbeddingsRequest, VertexFunctionCall,
     };
     use crate::providers::vertexai::provider::VertexAIProvider;
+    use serde_json::json;
     use std::collections::HashMap;
 
     fn create_test_config() -> ProviderConfig {
@@ -33,7 +35,9 @@ mod tests {
             model: "gemini-pro".to_string(),
             messages: vec![ChatCompletionMessage {
                 role: "user".to_string(),
-                content: Some(ChatMessageContent::String("Test message".to_string())),
+                content: Some(ChatMessageContent::String(
+                    "What's the weather in London?".to_string(),
+                )),
                 name: None,
                 tool_calls: None,
             }],
@@ -90,9 +94,15 @@ mod tests {
 
         assert_eq!(vertex_request.contents.len(), 2);
         assert_eq!(vertex_request.contents[0].role, "user");
-        assert_eq!(vertex_request.contents[0].parts[0].text, "Hello");
+        assert_eq!(
+            vertex_request.contents[0].parts[0].text,
+            Some("Hello".to_string())
+        );
         assert_eq!(vertex_request.contents[1].role, "model");
-        assert_eq!(vertex_request.contents[1].parts[0].text, "Hi there!");
+        assert_eq!(
+            vertex_request.contents[1].parts[0].text,
+            Some("Hi there!".to_string())
+        );
 
         let gen_config = vertex_request.generation_config.unwrap();
         assert_eq!(gen_config.temperature, Some(0.7));
@@ -107,12 +117,14 @@ mod tests {
                 content: Content {
                     role: "model".to_string(),
                     parts: vec![Part {
-                        text: "Generated response".to_string(),
+                        text: Some("Generated response".to_string()),
+                        function_call: None,
                     }],
                 },
                 finish_reason: "stop".to_string(),
                 safety_ratings: None,
                 avg_logprobs: None,
+                function_call: None,
             }],
             usage_metadata: Some(UsageMetadata {
                 prompt_token_count: 10,
@@ -227,4 +239,77 @@ mod tests {
         assert_eq!(vertex_request.instances[0].content, "test text");
         assert!(vertex_request.parameters.unwrap().auto_truncate.unwrap());
     }
+
+    #[test]
+    fn test_function_calling_request() {
+        let mut chat_request = create_test_chat_request();
+        chat_request.tools = Some(vec![ToolDefinition {
+            function: FunctionDefinition {
+                name: "get_weather".to_string(),
+                description: Some("Get the current weather in a location".to_string()),
+                parameters: Some(HashMap::from([
+                    ("type".to_string(), json!("object")),
+                    (
+                        "properties".to_string(),
+                        json!({
+                            "location": {
+                                "type": "string",
+                                "description": "The city name"
+                            }
+                        }),
+                    ),
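+                    // JSON Schema: the model must always supply `location`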
+                    ("required".to_string(), json!(["location"])),
+                ])),
+                strict: None,
+            },
+            tool_type: "function".to_string(),
+        }]);
+
+        let vertex_request: VertexAIChatCompletionRequest = chat_request.into();
+
+        assert!(!vertex_request.tools.is_empty());
+        assert_eq!(
+            vertex_request.tools[0].function_declarations[0].name,
+            "get_weather"
+        );
+        assert_eq!(
+            vertex_request.tools[0].function_declarations[0].description,
+            Some("Get the current weather in a location".to_string())
+        );
+    }
+
+    #[test]
+    fn test_function_calling_response() {
+        let vertex_response = VertexAIChatCompletionResponse {
+            candidates: vec![GenerateContentResponse {
+                content: Content {
+                    role: "model".to_string(),
+                    parts: vec![Part {
+                        text: None,
+                        function_call: Some(VertexFunctionCall {
+                            name: "get_weather".to_string(),
+                            args: json!({"location": "London"}),
+                        }),
+                    }],
+                },
+                finish_reason: "stop".to_string(),
+                safety_ratings: None,
+                avg_logprobs: None,
+                function_call: None,
+            }],
+            usage_metadata: None,
+            model_version: None,
+        };
+
+        let chat_completion: ChatCompletion = vertex_response.into();
+
+        let tool_calls = chat_completion.choices[0]
+            .message
+            .tool_calls
+            .as_ref()
+            .unwrap();
+        assert_eq!(tool_calls[0].function.name, "get_weather");
+        assert_eq!(tool_calls[0].function.arguments, r#"{"location":"London"}"#);
+    }
 }
diff --git a/tests/vertexai_integration_test.rs b/tests/vertexai_integration_test.rs
index 30df1af..953dbd4 100644
--- a/tests/vertexai_integration_test.rs
+++ b/tests/vertexai_integration_test.rs
@@ -4,8 +4,10 @@ use hub::config::models::{ModelConfig, Provider as ProviderConfig};
 use hub::models::chat::{ChatCompletionRequest, ChatCompletionResponse};
 use hub::models::content::{ChatCompletionMessage, ChatMessageContent};
 use hub::models::embeddings::{EmbeddingsInput, EmbeddingsRequest};
+use hub::models::tool_definition::{FunctionDefinition, ToolDefinition};
 use hub::providers::provider::Provider;
 use hub::providers::vertexai::VertexAIProvider;
+use serde_json::json;
 use std::collections::HashMap;
 use std::env;
 
@@ -199,3 +201,76 @@ async fn test_error_handling() {
     let response = provider.chat_completions(request, &model_config).await;
     assert!(response.is_err(), "Should fail with invalid model");
 }
+
+#[tokio::test]
+async fn test_function_calling_integration() {
+    let provider = create_live_provider().await;
+    let model_config = create_test_model_config();
+
+    let weather_function = ToolDefinition {
+        function: FunctionDefinition {
+            name: "get_weather".to_string(),
+            description: Some("Get the current weather in a location".to_string()),
+            parameters: Some(HashMap::from([
+                ("type".to_string(), json!("object")),
+                ("properties".to_string(), json!({
+                    "location": {
+                        "type": "string",
+                        "description": "The city name"
+                    }
+                })),
+                ("required".to_string(), json!(["location"]))
+            ])),
+            strict: None,
+        },
+        tool_type: "function".to_string(),
+    };
+
+    let request = ChatCompletionRequest {
+        model: "gemini-pro".to_string(),
+        messages: vec![ChatCompletionMessage {
+            role: "user".to_string(),
+            content: Some(ChatMessageContent::String(
+                "What's the weather like in Paris today?".to_string(),
+            )),
+            name: None,
+            tool_calls: None,
+        }],
+        temperature: Some(0.0),
+        stream: None,
+        max_tokens: Some(100),
+        tools: Some(vec![weather_function]),
+        top_p: None,
+        n: None,
+        stop: None,
+        presence_penalty: None,
+        frequency_penalty: None,
+        logit_bias: None,
+        user: None,
+        tool_choice: None,
+        parallel_tool_calls: None,
+    };
+
+    let response = provider.chat_completions(request, &model_config).await;
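+    // a live model may legitimately answer with a tool call or with plain text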
+    assert!(response.is_ok(), "Function calling request failed");
+
+    if let Ok(ChatCompletionResponse::NonStream(completion)) = response {
+        assert!(!completion.choices.is_empty(), "No choices in response");
+        match (&completion.choices[0].message.content, &completion.choices[0].message.tool_calls) {
+            (_, Some(tool_calls)) if !tool_calls.is_empty() => {
+                assert_eq!(tool_calls[0].function.name, "get_weather");
+                let args: serde_json::Value = serde_json::from_str(&tool_calls[0].function.arguments)
+                    .expect("Failed to parse function arguments");
+                assert!(args["location"].is_string(), "Location should be a string");
+            },
+            (Some(content), _) => {
+                if let ChatMessageContent::String(text) = content {
+                    println!("Got text response: {}", text);
+                    assert!(!text.is_empty(), "Response should not be empty");
+                }
+            },
+            _ => panic!("Expected either tool calls or text content in response"),
+        }
+    }
+}
\ No newline at end of file