Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: deepseek message to remove dependencies with openai #283

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion rig-core/examples/agent_with_deepseek.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use serde_json::json;
#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
tracing_subscriber::fmt()
.with_max_level(tracing::Level::DEBUG)
.with_max_level(tracing::Level::INFO)
.with_target(false)
.init();

Expand Down
284 changes: 250 additions & 34 deletions rig-core/src/providers/deepseek.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,13 @@
use crate::{
completion::{self, CompletionError, CompletionModel, CompletionRequest},
extractor::ExtractorBuilder,
json_utils,
providers::openai::Message,
OneOrMany,
json_utils, message, OneOrMany,
};
use reqwest::Client as HttpClient;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};

use super::openai::AssistantContent;

// ================================================================
// Main DeepSeek Client
// ================================================================
Expand Down Expand Up @@ -62,7 +58,7 @@ impl Client {
headers
})
.build()
.expect("OpenAI reqwest client should build"),
.expect("DeepSeek reqwest client should build"),
}
}

Expand Down Expand Up @@ -119,9 +115,182 @@ pub struct CompletionResponse {
// you may want usage or other fields
}

#[derive(Clone, Debug, Serialize, Deserialize)]
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct Choice {
pub index: usize,
pub message: Message,
pub logprobs: Option<serde_json::Value>,
pub finish_reason: String,
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
System {
content: String,
#[serde(skip_serializing_if = "Option::is_none")]
name: Option<String>,
},
User {
content: String,
#[serde(skip_serializing_if = "Option::is_none")]
name: Option<String>,
},
Assistant {
content: String,
#[serde(skip_serializing_if = "Option::is_none")]
name: Option<String>,
#[serde(default, deserialize_with = "json_utils::null_or_vec")]
tool_calls: Vec<ToolCall>,
},
#[serde(rename = "Tool")]
ToolResult {
tool_call_id: String,
content: String,
},
}

impl Message {
pub fn system(content: &str) -> Self {
Message::System {
content: content.to_owned(),
name: None,
}
}
}

impl From<message::ToolResult> for Message {
fn from(tool_result: message::ToolResult) -> Self {
let content = match tool_result.content.first() {
message::ToolResultContent::Text(text) => text.text,
message::ToolResultContent::Image(_) => String::from("[Image]"),
};

Message::ToolResult {
tool_call_id: tool_result.id,
content,
}
}
}

impl From<message::ToolCall> for ToolCall {
fn from(tool_call: message::ToolCall) -> Self {
Self {
id: tool_call.id,
// TODO: update index when we have it
index: 0,
r#type: ToolType::Function,
function: Function {
name: tool_call.function.name,
arguments: tool_call.function.arguments,
},
}
}
}

impl TryFrom<message::Message> for Vec<Message> {
type Error = message::MessageError;

fn try_from(message: message::Message) -> Result<Self, Self::Error> {
match message {
message::Message::User { content } => {
// extract tool results
let mut messages = vec![];

let tool_results = content
.clone()
.into_iter()
.filter_map(|content| match content {
message::UserContent::ToolResult(tool_result) => {
Some(Message::from(tool_result))
}
_ => None,
})
.collect::<Vec<_>>();

messages.extend(tool_results);

// extract text results
let text_messages = content
.into_iter()
.filter_map(|content| match content {
message::UserContent::Text(text) => Some(Message::User {
content: text.text,
name: None,
}),
_ => None,
})
.collect::<Vec<_>>();
messages.extend(text_messages);

Ok(messages)
}
message::Message::Assistant { content } => {
let mut messages: Vec<Message> = vec![];

// extract tool calls
let tool_calls = content
.clone()
.into_iter()
.filter_map(|content| match content {
message::AssistantContent::ToolCall(tool_call) => {
Some(ToolCall::from(tool_call))
}
_ => None,
})
.collect::<Vec<_>>();

// if we have tool calls, we add a new Assistant message with them
if !tool_calls.is_empty() {
messages.push(Message::Assistant {
content: "".to_string(),
name: None,
tool_calls,
});
}

// extract text
let text_content = content
.into_iter()
.filter_map(|content| match content {
message::AssistantContent::Text(text) => Some(Message::Assistant {
content: text.text,
name: None,
tool_calls: vec![],
}),
_ => None,
})
.collect::<Vec<_>>();

messages.extend(text_content);

Ok(messages)
}
}
}
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
pub id: String,
pub index: usize,
#[serde(default)]
pub r#type: ToolType,
pub function: Function,
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
pub name: String,
#[serde(with = "json_utils::stringified_json")]
pub arguments: serde_json::Value,
}

#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
#[default]
Function,
}

#[derive(Clone, Debug, Deserialize, Serialize)]
Expand Down Expand Up @@ -152,15 +321,11 @@ impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionRe
tool_calls,
..
} => {
let mut content = content
.iter()
.map(|c| match c {
AssistantContent::Text { text } => completion::AssistantContent::text(text),
AssistantContent::Refusal { refusal } => {
completion::AssistantContent::text(refusal)
}
})
.collect::<Vec<_>>();
let mut content = if content.trim().is_empty() {
vec![]
} else {
vec![completion::AssistantContent::text(content)]
};

content.extend(
tool_calls
Expand Down Expand Up @@ -295,35 +460,40 @@ pub const DEEPSEEK_REASONER: &str = "deepseek-reasoner";
// Tests
#[cfg(test)]
mod tests {
use crate::providers::openai;

use super::*;

#[test]
fn test_deserialize_vec_choice() {
let data = r#"[{"message":{"role":"assistant","content":"Hello, world!"}}]"#;
let data = r#"[{
"finish_reason": "stop",
"index": 0,
"logprobs": null,
"message":{"role":"assistant","content":"Hello, world!"}
}]"#;

let choices: Vec<Choice> = serde_json::from_str(data).unwrap();
assert_eq!(choices.len(), 1);
match &choices.first().unwrap().message {
Message::Assistant { content, .. } => match &content[0] {
openai::AssistantContent::Text { text } => assert_eq!(text, "Hello, world!"),
_ => panic!("Expected text content"),
},
Message::Assistant { content, .. } => assert_eq!(content, "Hello, world!"),
_ => panic!("Expected assistant message"),
}
}

#[test]
fn test_deserialize_deepseek_response() {
let data = r#"{"choices":[{"message":{"role":"assistant","content":"Hello, world!"}}]}"#;
let data = r#"{"choices":[{
"finish_reason": "stop",
"index": 0,
"logprobs": null,
"message":{"role":"assistant","content":"Hello, world!"}
}]}"#;

let jd = &mut serde_json::Deserializer::from_str(data);
let result: Result<CompletionResponse, _> = serde_path_to_error::deserialize(jd);
match result {
Ok(response) => match &response.choices.first().unwrap().message {
Message::Assistant { content, .. } => match &content[0] {
openai::AssistantContent::Text { text } => assert_eq!(text, "Hello, world!"),
_ => panic!("Expected text content"),
},
Message::Assistant { content, .. } => assert_eq!(content, "Hello, world!"),
_ => panic!("Expected assistant message"),
},
Err(err) => {
Expand Down Expand Up @@ -369,18 +539,64 @@ mod tests {

match result {
Ok(response) => match &response.choices.first().unwrap().message {
Message::Assistant { content, .. } => match &content[0] {
openai::AssistantContent::Text { text } => assert_eq!(
text,
"Why don’t skeletons fight each other? \nBecause they don’t have the guts! 😄"
),
_ => panic!("Expected text content"),
},
Message::Assistant { content, .. } => assert_eq!(
content,
"Why don’t skeletons fight each other? \nBecause they don’t have the guts! 😄"
),
_ => panic!("Expected assistant message"),
},
Err(err) => {
panic!("Deserialization error at {}: {}", err.path(), err);
}
}
}

#[test]
fn test_serialize_deserialize_tool_call_message() {
let tool_call_choice_json = r#"
{
"finish_reason": "tool_calls",
"index": 0,
"logprobs": null,
"message": {
"content": "",
"role": "assistant",
"tool_calls": [
{
"function": {
"arguments": "{\"x\":2,\"y\":5}",
"name": "subtract"
},
"id": "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b",
"index": 0,
"type": "function"
}
]
}
}
"#;

let choice: Choice = serde_json::from_str(tool_call_choice_json).unwrap();

let expected_choice: Choice = Choice {
finish_reason: "tool_calls".to_string(),
index: 0,
logprobs: None,
message: Message::Assistant {
content: "".to_string(),
name: None,
tool_calls: vec![ToolCall {
id: "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b".to_string(),
function: Function {
name: "subtract".to_string(),
arguments: serde_json::from_str(r#"{"x":2,"y":5}"#).unwrap(),
},
index: 0,
r#type: ToolType::Function,
}],
},
};

assert_eq!(choice, expected_choice);
}
}