new: automatically detect model supported features such as system prompt support and function calling
evilsocket committed Dec 10, 2024
1 parent a5c0c84 commit c8ed591
Showing 15 changed files with 111 additions and 46 deletions.
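At its core, the commit replaces the boolean `check_native_tools_support()` with a `check_supported_features()` method returning a `SupportedFeatures` struct. The sketch below condenses the new surface from the `src/agent/generator/mod.rs` diff further down; the real trait also requires `mini_rag::Embedder` and declares `chat()` and other methods, which are elided here.

```rust
use anyhow::Result;
use async_trait::async_trait;

// Condensed from the src/agent/generator/mod.rs diff in this commit:
// the capability report returned by every generator client.
pub struct SupportedFeatures {
    pub system_prompt: bool, // the model accepts a "system" role prompt
    pub tools: bool,         // the model supports native function calling
}

impl Default for SupportedFeatures {
    fn default() -> Self {
        // Conservative default: assume system prompts work, tools do not.
        Self {
            system_prompt: true,
            tools: false,
        }
    }
}

#[async_trait]
pub trait Client: Send + Sync {
    // Replaces check_native_tools_support() -> Result<bool>; providers that
    // do not override this keep the default capabilities above.
    async fn check_supported_features(&self) -> Result<SupportedFeatures> {
        Ok(SupportedFeatures::default())
    }
}
```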
2 changes: 1 addition & 1 deletion README.md
@@ -41,7 +41,7 @@ Nerve features integrations for any model accessible via the following providers
| **Mistral.ai** | `MISTRAL_API_KEY` | `mistral://mistral-large-latest` |
| **Novita** | `NOVITA_API_KEY` | `novita://meta-llama/llama-3.1-70b-instruct` |

¹ **o1-preview and o1 models do not support function calling directly** and do not support a system prompt. It is possible to workaround this by adding the `--user-only` flag to the command line.
¹ **o1-preview and o1 models do not support function calling directly** and do not support a system prompt. Nerve will try to detect this and fall back to the user prompt (a sketch of this fallback follows this file's diff). It is possible to force this behaviour by adding the `--user-only` flag to the command line.

² Refer to [this document](https://huggingface.co/blog/tgi-messages-api#using-inference-endpoints-with-openai-client-libraries) for how to configure a custom Huggingface endpoint.

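The caller-side fallback described in footnote ¹ is not part of this README diff; the sketch below is only a hypothetical illustration of that behaviour, reusing the `SupportedFeatures` struct from the sketch above (`PromptPlan`, `build_prompt_plan`, and `user_only_flag` are invented names, not code from this commit).

```rust
// Hypothetical sketch: when the detected features say the model has no
// system prompt support, or --user-only was passed, fold the system prompt
// into the user message instead. Names here are illustrative only.
struct PromptPlan {
    system_prompt: Option<String>,
    user_prompt: String,
}

fn build_prompt_plan(
    features: &SupportedFeatures, // struct introduced by this commit
    user_only_flag: bool,         // set by the --user-only CLI flag
    system_prompt: String,
    user_prompt: String,
) -> PromptPlan {
    if user_only_flag || !features.system_prompt {
        // Merge the system prompt into the user message, as described for
        // o1-style models that reject the "system" role.
        PromptPlan {
            system_prompt: None,
            user_prompt: format!("{}\n\n{}", system_prompt, user_prompt),
        }
    } else {
        PromptPlan {
            system_prompt: Some(system_prompt),
            user_prompt,
        }
    }
}
```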
22 changes: 14 additions & 8 deletions src/agent/generator/anthropic.rs
@@ -1,7 +1,7 @@
use std::collections::HashMap;

use crate::agent::{
generator::{ChatResponse, Usage},
generator::{ChatResponse, SupportedFeatures, Usage},
state::SharedState,
Invocation,
};
@@ -102,7 +102,7 @@ impl Client for AnthropicClient {
Ok(Self { model, client })
}

async fn check_native_tools_support(&self) -> Result<bool> {
async fn check_supported_features(&self) -> Result<SupportedFeatures> {
let messages = vec![Message::user("Execute the test function.")];
let max_tokens = MaxTokens::new(4096, self.model)?;

@@ -128,9 +128,15 @@ impl Client for AnthropicClient {
log::debug!("response = {:?}", response);

if let Ok(tool_use) = response.content.flatten_into_tool_use() {
Ok(tool_use.name == "test")
Ok(SupportedFeatures {
system_prompt: true,
tools: tool_use.name == "test",
})
} else {
Ok(false)
Ok(SupportedFeatures {
system_prompt: true,
tools: false,
})
}
}

@@ -183,10 +189,10 @@

let request_body = MessagesRequestBody {
model: self.model,
system: match &options.system_prompt {
Some(sp) => Some(SystemPrompt::new(sp.trim())),
None => None,
},
system: options
.system_prompt
.as_ref()
.map(|sp| SystemPrompt::new(sp.trim())),
messages,
max_tokens,
tools: if tools.is_empty() { None } else { Some(tools) },
6 changes: 3 additions & 3 deletions src/agent/generator/deepseek.rs
@@ -3,7 +3,7 @@ use async_trait::async_trait;

use crate::agent::state::SharedState;

use super::{openai::OpenAIClient, ChatOptions, ChatResponse, Client};
use super::{openai::OpenAIClient, ChatOptions, ChatResponse, Client, SupportedFeatures};

pub struct DeepSeekClient {
client: OpenAIClient,
@@ -24,8 +24,8 @@ impl Client for DeepSeekClient {
Ok(Self { client })
}

async fn check_native_tools_support(&self) -> Result<bool> {
self.client.check_native_tools_support().await
async fn check_supported_features(&self) -> Result<SupportedFeatures> {
self.client.check_supported_features().await
}

async fn chat(
6 changes: 5 additions & 1 deletion src/agent/generator/fireworks.rs
@@ -3,7 +3,7 @@ use async_trait::async_trait;

use crate::agent::state::SharedState;

use super::{openai::OpenAIClient, ChatOptions, ChatResponse, Client};
use super::{openai::OpenAIClient, ChatOptions, ChatResponse, Client, SupportedFeatures};

pub struct FireworksClient {
client: OpenAIClient,
@@ -24,6 +24,10 @@ impl Client for FireworksClient {
Ok(Self { client })
}

async fn check_supported_features(&self) -> Result<SupportedFeatures> {
self.client.check_supported_features().await
}

async fn chat(
&self,
state: SharedState,
9 changes: 6 additions & 3 deletions src/agent/generator/groq.rs
@@ -17,7 +17,7 @@ use crate::agent::{
Invocation,
};

use super::{ChatOptions, Client};
use super::{ChatOptions, Client, SupportedFeatures};

lazy_static! {
static ref RETRY_TIME_PARSER: Regex =
@@ -58,7 +58,7 @@ impl Client for GroqClient {
Ok(Self { model, api_key })
}

async fn check_native_tools_support(&self) -> Result<bool> {
async fn check_supported_features(&self) -> Result<SupportedFeatures> {
let chat_history = vec![
crate::api::groq::completion::message::Message::SystemMessage {
role: Some("system".to_string()),
@@ -106,7 +106,10 @@

log::debug!("groq.check_tools_support.resp = {:?}", &resp);

Ok(resp.is_ok())
Ok(SupportedFeatures {
system_prompt: true,
tools: resp.is_ok(),
})
}

async fn chat(
6 changes: 3 additions & 3 deletions src/agent/generator/huggingface.rs
@@ -3,7 +3,7 @@ use async_trait::async_trait;

use crate::agent::state::SharedState;

use super::{openai::OpenAIClient, ChatOptions, ChatResponse, Client};
use super::{openai::OpenAIClient, ChatOptions, ChatResponse, Client, SupportedFeatures};

pub struct HuggingfaceMessageClient {
client: OpenAIClient,
@@ -23,8 +23,8 @@ impl Client for HuggingfaceMessageClient {
Ok(Self { client })
}

async fn check_native_tools_support(&self) -> Result<bool> {
self.client.check_native_tools_support().await
async fn check_supported_features(&self) -> Result<SupportedFeatures> {
self.client.check_supported_features().await
}

async fn chat(
11 changes: 6 additions & 5 deletions src/agent/generator/mistral.rs
@@ -3,7 +3,7 @@ use async_trait::async_trait;

use crate::agent::state::SharedState;

use super::{openai::OpenAIClient, ChatOptions, ChatResponse, Client};
use super::{openai::OpenAIClient, ChatOptions, ChatResponse, Client, SupportedFeatures};

pub struct MistralClient {
client: OpenAIClient,
@@ -15,13 +15,14 @@ impl Client for MistralClient {
where
Self: Sized,
{
let client = OpenAIClient::custom(model_name, "MISTRAL_API_KEY", "https://api.mistral.ai/v1/")?;
let client =
OpenAIClient::custom(model_name, "MISTRAL_API_KEY", "https://api.mistral.ai/v1/")?;

Ok(Self { client })
}

async fn check_native_tools_support(&self) -> Result<bool> {
self.client.check_native_tools_support().await
async fn check_supported_features(&self) -> Result<SupportedFeatures> {
self.client.check_supported_features().await
}

async fn chat(
@@ -33,7 +34,7 @@

if let Err(error) = &response {
if self.check_rate_limit(&error.to_string()).await {
return self.chat(state, options).await;
return self.chat(state, options).await;
}
}

18 changes: 16 additions & 2 deletions src/agent/generator/mod.rs
@@ -89,6 +89,20 @@ pub struct ChatResponse {
pub usage: Option<Usage>,
}

pub struct SupportedFeatures {
pub system_prompt: bool,
pub tools: bool,
}

impl Default for SupportedFeatures {
fn default() -> Self {
Self {
system_prompt: true,
tools: false,
}
}
}

#[async_trait]
pub trait Client: mini_rag::Embedder + Send + Sync {
fn new(url: &str, port: u16, model_name: &str, context_window: u32) -> Result<Self>
@@ -97,8 +111,8 @@ pub trait Client: mini_rag::Embedder + Send + Sync {

async fn chat(&self, state: SharedState, options: &ChatOptions) -> Result<ChatResponse>;

async fn check_native_tools_support(&self) -> Result<bool> {
Ok(false)
async fn check_supported_features(&self) -> Result<SupportedFeatures> {
Ok(SupportedFeatures::default())
}

async fn check_rate_limit(&self, error: &str) -> bool {
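A provider that already knows its capabilities could simply override the default instead of probing the API; the snippet below is a hypothetical override built against the condensed trait sketch above (`StaticClient` is an invented example, unlike the probing implementations in this commit).

```rust
use anyhow::Result;
use async_trait::async_trait;

// Hypothetical client that is known to support both features, so it
// overrides the trait default instead of sending a probe request.
// StaticClient is an invented example, not one of the clients in this commit.
struct StaticClient;

#[async_trait]
impl Client for StaticClient {
    async fn check_supported_features(&self) -> Result<SupportedFeatures> {
        Ok(SupportedFeatures {
            system_prompt: true,
            tools: true,
        })
    }
}
```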
6 changes: 3 additions & 3 deletions src/agent/generator/nim.rs
@@ -3,7 +3,7 @@ use async_trait::async_trait;

use crate::agent::state::SharedState;

use super::{openai::OpenAIClient, ChatOptions, ChatResponse, Client};
use super::{openai::OpenAIClient, ChatOptions, ChatResponse, Client, SupportedFeatures};

pub struct NvidiaNIMClient {
client: OpenAIClient,
@@ -30,8 +30,8 @@ impl Client for NvidiaNIMClient {
Ok(Self { client })
}

async fn check_native_tools_support(&self) -> Result<bool> {
self.client.check_native_tools_support().await
async fn check_supported_features(&self) -> Result<SupportedFeatures> {
self.client.check_supported_features().await
}

async fn chat(
6 changes: 5 additions & 1 deletion src/agent/generator/novita.rs
@@ -3,7 +3,7 @@ use async_trait::async_trait;

use crate::agent::state::SharedState;

use super::{openai::OpenAIClient, ChatOptions, ChatResponse, Client};
use super::{openai::OpenAIClient, ChatOptions, ChatResponse, Client, SupportedFeatures};

pub struct NovitaClient {
client: OpenAIClient,
@@ -24,6 +24,10 @@ impl Client for NovitaClient {
Ok(Self { client })
}

async fn check_supported_features(&self) -> Result<SupportedFeatures> {
self.client.check_supported_features().await
}

async fn chat(
&self,
state: SharedState,
14 changes: 10 additions & 4 deletions src/agent/generator/ollama.rs
@@ -18,7 +18,7 @@ use async_trait::async_trait;

use crate::agent::{state::SharedState, Invocation};

use super::{ChatOptions, ChatResponse, Client, Message};
use super::{ChatOptions, ChatResponse, Client, Message, SupportedFeatures};

pub struct OllamaClient {
model: String,
@@ -51,7 +51,7 @@ impl Client for OllamaClient {
})
}

async fn check_native_tools_support(&self) -> Result<bool> {
async fn check_supported_features(&self) -> Result<SupportedFeatures> {
let chat_history = vec![
ChatMessage::system("You are an helpful assistant.".to_string()),
ChatMessage::user("Call the test function.".to_string()),
@@ -79,12 +79,18 @@

if let Err(err) = self.client.send_chat_messages(request).await {
if err.to_string().contains("does not support tools") {
Ok(false)
Ok(SupportedFeatures {
system_prompt: true,
tools: false,
})
} else {
Err(anyhow!(err))
}
} else {
Ok(true)
Ok(SupportedFeatures {
system_prompt: true,
tools: true,
})
}
}

25 changes: 21 additions & 4 deletions src/agent/generator/openai.rs
@@ -9,7 +9,7 @@ use serde::{Deserialize, Serialize};

use crate::agent::{state::SharedState, Invocation};

use super::{ChatOptions, ChatResponse, Client, Message};
use super::{ChatOptions, ChatResponse, Client, Message, SupportedFeatures};

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenAiToolFunctionParameterProperty {
@@ -130,7 +130,7 @@ impl Client for OpenAIClient {
Self::custom(model_name, "OPENAI_API_KEY", "https://api.openai.com/v1/")
}

async fn check_native_tools_support(&self) -> Result<bool> {
async fn check_supported_features(&self) -> Result<SupportedFeatures> {
let chat_history = vec![
crate::api::openai::Message {
role: Role::System,
@@ -172,18 +172,35 @@

log::debug!("openai.check_tools_support.resp = {:?}", &resp);

let mut system_prompt_support = true;

if let Ok(comp) = resp {
if !comp.choices.is_empty() {
let first = comp.choices.first().unwrap();
if let Some(m) = first.message.as_ref() {
if m.tool_calls.is_some() {
return Ok(true);
return Ok(SupportedFeatures {
system_prompt: true,
tools: true,
});
}
}
}
} else {
let api_error = resp.unwrap_err().to_string();
if api_error.contains("unsupported_value")
&& api_error.contains("does not support 'system' with this model")
{
system_prompt_support = false;
} else {
log::error!("openai.check_tools_support.error = {}", api_error);
}
}

Ok(false)
Ok(SupportedFeatures {
system_prompt: system_prompt_support,
tools: false,
})
}

async fn chat(
Expand Down
6 changes: 3 additions & 3 deletions src/agent/generator/openai_compatible.rs
@@ -3,7 +3,7 @@ use async_trait::async_trait;

use crate::agent::state::SharedState;

use super::{openai::OpenAIClient, ChatOptions, ChatResponse, Client};
use super::{openai::OpenAIClient, ChatOptions, ChatResponse, Client, SupportedFeatures};

pub struct OpenAiCompatibleClient {
client: OpenAIClient,
@@ -30,8 +30,8 @@ impl Client for OpenAiCompatibleClient {
Ok(Self { client })
}

async fn check_native_tools_support(&self) -> Result<bool> {
self.client.check_native_tools_support().await
async fn check_supported_features(&self) -> Result<SupportedFeatures> {
self.client.check_supported_features().await
}

async fn chat(
Expand Down
6 changes: 3 additions & 3 deletions src/agent/generator/xai.rs
@@ -3,7 +3,7 @@ use async_trait::async_trait;

use crate::agent::state::SharedState;

use super::{openai::OpenAIClient, ChatOptions, ChatResponse, Client};
use super::{openai::OpenAIClient, ChatOptions, ChatResponse, Client, SupportedFeatures};

pub struct XAIClient {
client: OpenAIClient,
@@ -20,8 +20,8 @@ impl Client for XAIClient {
Ok(Self { client })
}

async fn check_native_tools_support(&self) -> Result<bool> {
self.client.check_native_tools_support().await
async fn check_supported_features(&self) -> Result<SupportedFeatures> {
self.client.check_supported_features().await
}

async fn chat(
Expand Down