diff --git a/README.md b/README.md
index 2ef1b2b..9eabf7a 100644
--- a/README.md
+++ b/README.md
@@ -41,7 +41,7 @@ Nerve features integrations for any model accessible via the following providers
 | **Mistral.ai** | `MISTRAL_API_KEY` | `mistral://mistral-large-latest` |
 | **Novita** | `NOVITA_API_KEY` | `novita://meta-llama/llama-3.1-70b-instruct` |
 
-¹ **o1-preview and o1 models do not support function calling directly** and do not support a system prompt. It is possible to workaround this by adding the `--user-only` flag to the command line.
+¹ **o1-preview and o1 models do not support function calling directly** and do not support a system prompt. Nerve will try to detect this and fall back to a user prompt. It is possible to force this behaviour by adding the `--user-only` flag to the command line.
 
 ² Refer to [this document](https://huggingface.co/blog/tgi-messages-api#using-inference-endpoints-with-openai-client-libraries) for how to configure a custom Huggingface endpoint.
diff --git a/src/agent/generator/anthropic.rs b/src/agent/generator/anthropic.rs
index 2a748e1..2b2c1a4 100644
--- a/src/agent/generator/anthropic.rs
+++ b/src/agent/generator/anthropic.rs
@@ -1,7 +1,7 @@
 use std::collections::HashMap;
 
 use crate::agent::{
-    generator::{ChatResponse, Usage},
+    generator::{ChatResponse, SupportedFeatures, Usage},
     state::SharedState,
     Invocation,
 };
@@ -102,7 +102,7 @@ impl Client for AnthropicClient {
         Ok(Self { model, client })
     }
 
-    async fn check_native_tools_support(&self) -> Result<bool> {
+    async fn check_supported_features(&self) -> Result<SupportedFeatures> {
         let messages = vec![Message::user("Execute the test function.")];
         let max_tokens = MaxTokens::new(4096, self.model)?;
 
@@ -128,9 +128,15 @@
 
         log::debug!("response = {:?}", response);
 
         if let Ok(tool_use) = response.content.flatten_into_tool_use() {
-            Ok(tool_use.name == "test")
+            Ok(SupportedFeatures {
+                system_prompt: true,
+                tools: tool_use.name == "test",
+            })
         } else {
-            Ok(false)
+            Ok(SupportedFeatures {
+                system_prompt: true,
+                tools: false,
+            })
         }
     }
 
@@ -183,10 +189,10 @@
 
         let request_body = MessagesRequestBody {
             model: self.model,
-            system: match &options.system_prompt {
-                Some(sp) => Some(SystemPrompt::new(sp.trim())),
-                None => None,
-            },
+            system: options
+                .system_prompt
+                .as_ref()
+                .map(|sp| SystemPrompt::new(sp.trim())),
             messages,
             max_tokens,
             tools: if tools.is_empty() { None } else { Some(tools) },
diff --git a/src/agent/generator/deepseek.rs b/src/agent/generator/deepseek.rs
index c07f44c..bf24c97 100644
--- a/src/agent/generator/deepseek.rs
+++ b/src/agent/generator/deepseek.rs
@@ -3,7 +3,7 @@ use async_trait::async_trait;
 
 use crate::agent::state::SharedState;
 
-use super::{openai::OpenAIClient, ChatOptions, ChatResponse, Client};
+use super::{openai::OpenAIClient, ChatOptions, ChatResponse, Client, SupportedFeatures};
 
 pub struct DeepSeekClient {
     client: OpenAIClient,
@@ -24,8 +24,8 @@ impl Client for DeepSeekClient {
         Ok(Self { client })
     }
 
-    async fn check_native_tools_support(&self) -> Result<bool> {
-        self.client.check_native_tools_support().await
+    async fn check_supported_features(&self) -> Result<SupportedFeatures> {
+        self.client.check_supported_features().await
     }
 
     async fn chat(
diff --git a/src/agent/generator/fireworks.rs b/src/agent/generator/fireworks.rs
index ce626f9..6dc8ae8 100644
--- a/src/agent/generator/fireworks.rs
+++ b/src/agent/generator/fireworks.rs
@@ -3,7 +3,7 @@ use async_trait::async_trait;
 
 use crate::agent::state::SharedState;
 
-use super::{openai::OpenAIClient, ChatOptions, ChatResponse, Client};
+use super::{openai::OpenAIClient, ChatOptions, ChatResponse, Client, SupportedFeatures};
 
 pub struct FireworksClient {
     client: OpenAIClient,
@@ -24,6 +24,10 @@ impl Client for FireworksClient {
         Ok(Self { client })
     }
 
+    async fn check_supported_features(&self) -> Result<SupportedFeatures> {
+        self.client.check_supported_features().await
+    }
+
     async fn chat(
         &self,
         state: SharedState,
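The OpenAI-compatible wrappers (DeepSeek and Fireworks above, and Huggingface, Mistral, NIM, Novita, and xAI below) all forward feature detection to the inner `OpenAIClient`. A minimal sketch of what a new wrapper looks like under this pattern; `ExampleClient`, `EXAMPLE_API_KEY`, and the endpoint URL are hypothetical placeholders:

```rust
use anyhow::Result;
use async_trait::async_trait;

use crate::agent::state::SharedState;

use super::{openai::OpenAIClient, ChatOptions, ChatResponse, Client, SupportedFeatures};

// Hypothetical OpenAI-compatible provider: only the constructor differs,
// feature detection and chat are forwarded to the wrapped client.
pub struct ExampleClient {
    client: OpenAIClient,
}

#[async_trait]
impl Client for ExampleClient {
    fn new(_url: &str, _port: u16, model_name: &str, _context_window: u32) -> Result<Self>
    where
        Self: Sized,
    {
        // The env var name and base URL are placeholders, not a real provider.
        let client =
            OpenAIClient::custom(model_name, "EXAMPLE_API_KEY", "https://api.example.com/v1/")?;
        Ok(Self { client })
    }

    // Feature detection is delegated to the wrapped OpenAI client.
    async fn check_supported_features(&self) -> Result<SupportedFeatures> {
        self.client.check_supported_features().await
    }

    async fn chat(&self, state: SharedState, options: &ChatOptions) -> Result<ChatResponse> {
        self.client.chat(state, options).await
    }
}
```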
diff --git a/src/agent/generator/groq.rs b/src/agent/generator/groq.rs
index c3cfe39..095f817 100644
--- a/src/agent/generator/groq.rs
+++ b/src/agent/generator/groq.rs
@@ -17,7 +17,7 @@ use crate::agent::{
     Invocation,
 };
 
-use super::{ChatOptions, Client};
+use super::{ChatOptions, Client, SupportedFeatures};
 
 lazy_static! {
     static ref RETRY_TIME_PARSER: Regex =
@@ -58,7 +58,7 @@ impl Client for GroqClient {
         Ok(Self { model, api_key })
     }
 
-    async fn check_native_tools_support(&self) -> Result<bool> {
+    async fn check_supported_features(&self) -> Result<SupportedFeatures> {
         let chat_history = vec![
             crate::api::groq::completion::message::Message::SystemMessage {
                 role: Some("system".to_string()),
@@ -106,7 +106,10 @@
 
         log::debug!("groq.check_tools_support.resp = {:?}", &resp);
 
-        Ok(resp.is_ok())
+        Ok(SupportedFeatures {
+            system_prompt: true,
+            tools: resp.is_ok(),
+        })
     }
 
     async fn chat(
diff --git a/src/agent/generator/huggingface.rs b/src/agent/generator/huggingface.rs
index 6b83c2a..71b8325 100644
--- a/src/agent/generator/huggingface.rs
+++ b/src/agent/generator/huggingface.rs
@@ -3,7 +3,7 @@ use async_trait::async_trait;
 
 use crate::agent::state::SharedState;
 
-use super::{openai::OpenAIClient, ChatOptions, ChatResponse, Client};
+use super::{openai::OpenAIClient, ChatOptions, ChatResponse, Client, SupportedFeatures};
 
 pub struct HuggingfaceMessageClient {
     client: OpenAIClient,
@@ -23,8 +23,8 @@ impl Client for HuggingfaceMessageClient {
         Ok(Self { client })
     }
 
-    async fn check_native_tools_support(&self) -> Result<bool> {
-        self.client.check_native_tools_support().await
+    async fn check_supported_features(&self) -> Result<SupportedFeatures> {
+        self.client.check_supported_features().await
     }
 
     async fn chat(
diff --git a/src/agent/generator/mistral.rs b/src/agent/generator/mistral.rs
index be0b79f..e9751e2 100644
--- a/src/agent/generator/mistral.rs
+++ b/src/agent/generator/mistral.rs
@@ -3,7 +3,7 @@ use async_trait::async_trait;
 
 use crate::agent::state::SharedState;
 
-use super::{openai::OpenAIClient, ChatOptions, ChatResponse, Client};
+use super::{openai::OpenAIClient, ChatOptions, ChatResponse, Client, SupportedFeatures};
 
 pub struct MistralClient {
     client: OpenAIClient,
@@ -15,13 +15,14 @@ impl Client for MistralClient {
     where
         Self: Sized,
    {
-        let client = OpenAIClient::custom(model_name, "MISTRAL_API_KEY", "https://api.mistral.ai/v1/")?;
+        let client =
+            OpenAIClient::custom(model_name, "MISTRAL_API_KEY", "https://api.mistral.ai/v1/")?;
 
         Ok(Self { client })
     }
 
-    async fn check_native_tools_support(&self) -> Result<bool> {
-        self.client.check_native_tools_support().await
+    async fn check_supported_features(&self) -> Result<SupportedFeatures> {
+        self.client.check_supported_features().await
     }
 
     async fn chat(
@@ -33,7 +34,7 @@
 
         if let Err(error) = &response {
             if self.check_rate_limit(&error.to_string()).await {
-                return self.chat(state, options).await;
+                return self.chat(state, options).await;
             }
         }
diff --git a/src/agent/generator/mod.rs b/src/agent/generator/mod.rs
index 325cc54..cda89ee 100644
--- a/src/agent/generator/mod.rs
+++ b/src/agent/generator/mod.rs
@@ -89,6 +89,20 @@ pub struct ChatResponse {
     pub usage: Option<Usage>,
 }
 
+pub struct SupportedFeatures {
+    pub system_prompt: bool,
+    pub tools: bool,
+}
+
+impl Default for SupportedFeatures {
+    fn default() -> Self {
+        Self {
+            system_prompt: true,
+            tools: false,
+        }
+    }
+}
+
 #[async_trait]
 pub trait Client: mini_rag::Embedder + Send + Sync {
     fn new(url: &str, port: u16, model_name: &str, context_window: u32) -> Result<Self>
@@ -97,8 +111,8 @@ pub trait Client: mini_rag::Embedder + Send + Sync {
 
     async fn chat(&self, state: SharedState, options: &ChatOptions) -> Result<ChatResponse>;
 
-    async fn check_native_tools_support(&self) -> Result<bool> {
-        Ok(false)
+    async fn check_supported_features(&self) -> Result<SupportedFeatures> {
+        Ok(SupportedFeatures::default())
     }
 
     async fn check_rate_limit(&self, error: &str) -> bool {
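`SupportedFeatures::default()` encodes the conservative assumption used when a provider has no probe of its own: a system prompt is presumed to work, and native tool calling is presumed absent until verified. An illustrative check, not part of the diff itself:

```rust
#[test]
fn default_is_conservative() {
    // What the Default impl above encodes.
    let assumed = SupportedFeatures::default();
    assert!(assumed.system_prompt);
    assert!(!assumed.tools);

    // Struct-update syntax for a provider that only probes tool support.
    let detected = SupportedFeatures {
        tools: true,
        ..Default::default()
    };
    assert!(detected.system_prompt && detected.tools);
}
```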
diff --git a/src/agent/generator/nim.rs b/src/agent/generator/nim.rs
index ebeca45..43b2431 100644
--- a/src/agent/generator/nim.rs
+++ b/src/agent/generator/nim.rs
@@ -3,7 +3,7 @@ use async_trait::async_trait;
 
 use crate::agent::state::SharedState;
 
-use super::{openai::OpenAIClient, ChatOptions, ChatResponse, Client};
+use super::{openai::OpenAIClient, ChatOptions, ChatResponse, Client, SupportedFeatures};
 
 pub struct NvidiaNIMClient {
     client: OpenAIClient,
@@ -30,8 +30,8 @@ impl Client for NvidiaNIMClient {
         Ok(Self { client })
     }
 
-    async fn check_native_tools_support(&self) -> Result<bool> {
-        self.client.check_native_tools_support().await
+    async fn check_supported_features(&self) -> Result<SupportedFeatures> {
+        self.client.check_supported_features().await
     }
 
     async fn chat(
diff --git a/src/agent/generator/novita.rs b/src/agent/generator/novita.rs
index 7dd8574..5ebf4a1 100644
--- a/src/agent/generator/novita.rs
+++ b/src/agent/generator/novita.rs
@@ -3,7 +3,7 @@ use async_trait::async_trait;
 
 use crate::agent::state::SharedState;
 
-use super::{openai::OpenAIClient, ChatOptions, ChatResponse, Client};
+use super::{openai::OpenAIClient, ChatOptions, ChatResponse, Client, SupportedFeatures};
 
 pub struct NovitaClient {
     client: OpenAIClient,
@@ -24,6 +24,10 @@ impl Client for NovitaClient {
         Ok(Self { client })
     }
 
+    async fn check_supported_features(&self) -> Result<SupportedFeatures> {
+        self.client.check_supported_features().await
+    }
+
     async fn chat(
         &self,
         state: SharedState,
diff --git a/src/agent/generator/ollama.rs b/src/agent/generator/ollama.rs
index 479b53d..1a25b02 100644
--- a/src/agent/generator/ollama.rs
+++ b/src/agent/generator/ollama.rs
@@ -18,7 +18,7 @@ use async_trait::async_trait;
 
 use crate::agent::{state::SharedState, Invocation};
 
-use super::{ChatOptions, ChatResponse, Client, Message};
+use super::{ChatOptions, ChatResponse, Client, Message, SupportedFeatures};
 
 pub struct OllamaClient {
     model: String,
@@ -51,7 +51,7 @@ impl Client for OllamaClient {
         })
     }
 
-    async fn check_native_tools_support(&self) -> Result<bool> {
+    async fn check_supported_features(&self) -> Result<SupportedFeatures> {
         let chat_history = vec![
             ChatMessage::system("You are an helpful assistant.".to_string()),
             ChatMessage::user("Call the test function.".to_string()),
@@ -79,12 +79,18 @@
 
         if let Err(err) = self.client.send_chat_messages(request).await {
             if err.to_string().contains("does not support tools") {
-                Ok(false)
+                Ok(SupportedFeatures {
+                    system_prompt: true,
+                    tools: false,
+                })
             } else {
                 Err(anyhow!(err))
            }
         } else {
-            Ok(true)
+            Ok(SupportedFeatures {
+                system_prompt: true,
+                tools: true,
+            })
         }
     }
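The Ollama probe sends a throwaway `test` tool and classifies the outcome by error text. The same shape distilled into a provider-agnostic helper, as a sketch only; `send_probe` stands in for whatever request the provider actually makes:

```rust
use anyhow::{anyhow, Result};

use super::SupportedFeatures;

// Hypothetical helper distilling the probe above: a successful call means
// native tool support, a "does not support tools" error means none, and
// any other error is propagated to the caller.
async fn detect_tools<F, Fut, T, E>(send_probe: F) -> Result<SupportedFeatures>
where
    F: FnOnce() -> Fut,
    Fut: std::future::Future<Output = std::result::Result<T, E>>,
    E: std::fmt::Display,
{
    match send_probe().await {
        Ok(_) => Ok(SupportedFeatures {
            system_prompt: true,
            tools: true,
        }),
        Err(err) if err.to_string().contains("does not support tools") => {
            Ok(SupportedFeatures {
                system_prompt: true,
                tools: false,
            })
        }
        Err(err) => Err(anyhow!("{}", err)),
    }
}
```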
diff --git a/src/agent/generator/openai.rs b/src/agent/generator/openai.rs
index 453e614..4b4f73b 100644
--- a/src/agent/generator/openai.rs
+++ b/src/agent/generator/openai.rs
@@ -9,7 +9,7 @@ use serde::{Deserialize, Serialize};
 
 use crate::agent::{state::SharedState, Invocation};
 
-use super::{ChatOptions, ChatResponse, Client, Message};
+use super::{ChatOptions, ChatResponse, Client, Message, SupportedFeatures};
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct OpenAiToolFunctionParameterProperty {
@@ -130,7 +130,7 @@ impl Client for OpenAIClient {
         Self::custom(model_name, "OPENAI_API_KEY", "https://api.openai.com/v1/")
     }
 
-    async fn check_native_tools_support(&self) -> Result<bool> {
+    async fn check_supported_features(&self) -> Result<SupportedFeatures> {
         let chat_history = vec![
             crate::api::openai::Message {
                 role: Role::System,
@@ -172,18 +172,35 @@
 
         log::debug!("openai.check_tools_support.resp = {:?}", &resp);
 
+        let mut system_prompt_support = true;
+
         if let Ok(comp) = resp {
             if !comp.choices.is_empty() {
                 let first = comp.choices.first().unwrap();
                 if let Some(m) = first.message.as_ref() {
                     if m.tool_calls.is_some() {
-                        return Ok(true);
+                        return Ok(SupportedFeatures {
+                            system_prompt: true,
+                            tools: true,
+                        });
                     }
                 }
             }
+        } else {
+            let api_error = resp.unwrap_err().to_string();
+            if api_error.contains("unsupported_value")
+                && api_error.contains("does not support 'system' with this model")
+            {
+                system_prompt_support = false;
+            } else {
+                log::error!("openai.check_tools_support.error = {}", api_error);
+            }
         }
 
-        Ok(false)
+        Ok(SupportedFeatures {
+            system_prompt: system_prompt_support,
+            tools: false,
+        })
     }
 
     async fn chat(
diff --git a/src/agent/generator/openai_compatible.rs b/src/agent/generator/openai_compatible.rs
index 1588fd1..31e537b 100644
--- a/src/agent/generator/openai_compatible.rs
+++ b/src/agent/generator/openai_compatible.rs
@@ -3,7 +3,7 @@ use async_trait::async_trait;
 
 use crate::agent::state::SharedState;
 
-use super::{openai::OpenAIClient, ChatOptions, ChatResponse, Client};
+use super::{openai::OpenAIClient, ChatOptions, ChatResponse, Client, SupportedFeatures};
 
 pub struct OpenAiCompatibleClient {
     client: OpenAIClient,
@@ -30,8 +30,8 @@ impl Client for OpenAiCompatibleClient {
         Ok(Self { client })
     }
 
-    async fn check_native_tools_support(&self) -> Result<bool> {
-        self.client.check_native_tools_support().await
+    async fn check_supported_features(&self) -> Result<SupportedFeatures> {
+        self.client.check_supported_features().await
     }
 
     async fn chat(
diff --git a/src/agent/generator/xai.rs b/src/agent/generator/xai.rs
index ce9dff7..6f9285a 100644
--- a/src/agent/generator/xai.rs
+++ b/src/agent/generator/xai.rs
@@ -3,7 +3,7 @@ use async_trait::async_trait;
 
 use crate::agent::state::SharedState;
 
-use super::{openai::OpenAIClient, ChatOptions, ChatResponse, Client};
+use super::{openai::OpenAIClient, ChatOptions, ChatResponse, Client, SupportedFeatures};
 
 pub struct XAIClient {
     client: OpenAIClient,
@@ -20,8 +20,8 @@ impl Client for XAIClient {
         Ok(Self { client })
    }
 
-    async fn check_native_tools_support(&self) -> Result<bool> {
-        self.client.check_native_tools_support().await
+    async fn check_supported_features(&self) -> Result<SupportedFeatures> {
+        self.client.check_supported_features().await
     }
 
     async fn chat(
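The OpenAI probe doubles as the o1 detection mentioned in the README: when the API rejects the request because the model refuses a `system` message, both markers checked in openai.rs appear in the error text. The string test restated in isolation; the sample error message in the test is illustrative, not a verbatim API response:

```rust
// Mirrors the two `contains` checks in openai.rs above.
fn rejects_system_prompt(api_error: &str) -> bool {
    api_error.contains("unsupported_value")
        && api_error.contains("does not support 'system' with this model")
}

#[cfg(test)]
mod tests {
    use super::rejects_system_prompt;

    #[test]
    fn o1_style_rejection_is_detected() {
        // Hypothetical error text carrying both markers.
        let err =
            "unsupported_value: 'messages[0].role' does not support 'system' with this model";
        assert!(rejects_system_prompt(err));
        assert!(!rejects_system_prompt("429: rate limit exceeded"));
    }
}
```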
diff --git a/src/agent/mod.rs b/src/agent/mod.rs
index f7aca43..25fcbc1 100644
--- a/src/agent/mod.rs
+++ b/src/agent/mod.rs
@@ -123,12 +123,14 @@ impl Agent {
         user_only: bool,
         max_iterations: usize,
     ) -> Result<Self> {
+        // check if the model supports tool calling and system prompts natively
+        let supported_features = generator.check_supported_features().await?;
+
         let use_native_tools_format = if force_strategy {
             log::info!("using {:?} serialization strategy", &serializer);
             false
         } else {
-            // check if the model supports tools calling natively
-            match generator.check_native_tools_support().await? {
+            match supported_features.tools {
                 true => {
                     log::debug!("model supports tools calling natively.");
                     true
@@ -140,6 +142,14 @@
             }
         };
 
+        let user_only = if !user_only && !supported_features.system_prompt {
+            log::info!("model does not support system prompt, forcing user prompt");
+            true
+        } else {
+            // leave whatever the user set
+            user_only
+        };
+
         let task_timeout = task.get_timeout();
         let state = Arc::new(tokio::sync::Mutex::new(
             State::new(
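The fallback rule above is small enough to restate as a pure function: an explicit `--user-only` is always honored, and feature detection can only force the flag on, never off. A hypothetical condensation of the `let user_only = ...` block:

```rust
// Hypothetical helper equivalent to the shadowing block above.
fn effective_user_only(user_only_flag: bool, system_prompt_supported: bool) -> bool {
    user_only_flag || !system_prompt_supported
}

#[cfg(test)]
mod tests {
    use super::effective_user_only;

    #[test]
    fn detection_only_forces_the_flag_on() {
        assert!(effective_user_only(false, false)); // forced: no system prompt support
        assert!(effective_user_only(true, true)); // user passed --user-only
        assert!(!effective_user_only(false, true)); // default: system prompt is used
    }
}
```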