Skip to content

Commit

Permalink
refactor: deepseek message to remove dependencies with openai (#283)
Browse files Browse the repository at this point in the history
* refactor: deepseek message to remove dependencies with openai
* chore: apply pr comments
  • Loading branch information
carlos-verdes authored Feb 12, 2025
1 parent 57e536e commit 530a327
Show file tree
Hide file tree
Showing 2 changed files with 251 additions and 35 deletions.
2 changes: 1 addition & 1 deletion rig-core/examples/agent_with_deepseek.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use serde_json::json;
#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
tracing_subscriber::fmt()
.with_max_level(tracing::Level::DEBUG)
.with_max_level(tracing::Level::INFO)
.with_target(false)
.init();

Expand Down
284 changes: 250 additions & 34 deletions rig-core/src/providers/deepseek.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,13 @@
use crate::{
completion::{self, CompletionError, CompletionModel, CompletionRequest},
extractor::ExtractorBuilder,
json_utils,
providers::openai::Message,
OneOrMany,
json_utils, message, OneOrMany,
};
use reqwest::Client as HttpClient;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};

use super::openai::AssistantContent;

// ================================================================
// Main DeepSeek Client
// ================================================================
Expand Down Expand Up @@ -62,7 +58,7 @@ impl Client {
headers
})
.build()
.expect("OpenAI reqwest client should build"),
.expect("DeepSeek reqwest client should build"),
}
}

Expand Down Expand Up @@ -119,9 +115,182 @@ pub struct CompletionResponse {
// you may want usage or other fields
}

#[derive(Clone, Debug, Serialize, Deserialize)]
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct Choice {
pub index: usize,
pub message: Message,
pub logprobs: Option<serde_json::Value>,
pub finish_reason: String,
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
System {
content: String,
#[serde(skip_serializing_if = "Option::is_none")]
name: Option<String>,
},
User {
content: String,
#[serde(skip_serializing_if = "Option::is_none")]
name: Option<String>,
},
Assistant {
content: String,
#[serde(skip_serializing_if = "Option::is_none")]
name: Option<String>,
#[serde(default, deserialize_with = "json_utils::null_or_vec")]
tool_calls: Vec<ToolCall>,
},
#[serde(rename = "Tool")]
ToolResult {
tool_call_id: String,
content: String,
},
}

impl Message {
pub fn system(content: &str) -> Self {
Message::System {
content: content.to_owned(),
name: None,
}
}
}

impl From<message::ToolResult> for Message {
fn from(tool_result: message::ToolResult) -> Self {
let content = match tool_result.content.first() {
message::ToolResultContent::Text(text) => text.text,
message::ToolResultContent::Image(_) => String::from("[Image]"),
};

Message::ToolResult {
tool_call_id: tool_result.id,
content,
}
}
}

impl From<message::ToolCall> for ToolCall {
fn from(tool_call: message::ToolCall) -> Self {
Self {
id: tool_call.id,
// TODO: update index when we have it
index: 0,
r#type: ToolType::Function,
function: Function {
name: tool_call.function.name,
arguments: tool_call.function.arguments,
},
}
}
}

impl TryFrom<message::Message> for Vec<Message> {
type Error = message::MessageError;

fn try_from(message: message::Message) -> Result<Self, Self::Error> {
match message {
message::Message::User { content } => {
// extract tool results
let mut messages = vec![];

let tool_results = content
.clone()
.into_iter()
.filter_map(|content| match content {
message::UserContent::ToolResult(tool_result) => {
Some(Message::from(tool_result))
}
_ => None,
})
.collect::<Vec<_>>();

messages.extend(tool_results);

// extract text results
let text_messages = content
.into_iter()
.filter_map(|content| match content {
message::UserContent::Text(text) => Some(Message::User {
content: text.text,
name: None,
}),
_ => None,
})
.collect::<Vec<_>>();
messages.extend(text_messages);

Ok(messages)
}
message::Message::Assistant { content } => {
let mut messages: Vec<Message> = vec![];

// extract tool calls
let tool_calls = content
.clone()
.into_iter()
.filter_map(|content| match content {
message::AssistantContent::ToolCall(tool_call) => {
Some(ToolCall::from(tool_call))
}
_ => None,
})
.collect::<Vec<_>>();

// if we have tool calls, we add a new Assistant message with them
if !tool_calls.is_empty() {
messages.push(Message::Assistant {
content: "".to_string(),
name: None,
tool_calls,
});
}

// extract text
let text_content = content
.into_iter()
.filter_map(|content| match content {
message::AssistantContent::Text(text) => Some(Message::Assistant {
content: text.text,
name: None,
tool_calls: vec![],
}),
_ => None,
})
.collect::<Vec<_>>();

messages.extend(text_content);

Ok(messages)
}
}
}
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
pub id: String,
pub index: usize,
#[serde(default)]
pub r#type: ToolType,
pub function: Function,
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
pub name: String,
#[serde(with = "json_utils::stringified_json")]
pub arguments: serde_json::Value,
}

#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
#[default]
Function,
}

#[derive(Clone, Debug, Deserialize, Serialize)]
Expand Down Expand Up @@ -152,15 +321,11 @@ impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionRe
tool_calls,
..
} => {
let mut content = content
.iter()
.map(|c| match c {
AssistantContent::Text { text } => completion::AssistantContent::text(text),
AssistantContent::Refusal { refusal } => {
completion::AssistantContent::text(refusal)
}
})
.collect::<Vec<_>>();
let mut content = if content.trim().is_empty() {
vec![]
} else {
vec![completion::AssistantContent::text(content)]
};

content.extend(
tool_calls
Expand Down Expand Up @@ -295,35 +460,40 @@ pub const DEEPSEEK_REASONER: &str = "deepseek-reasoner";
// Tests
#[cfg(test)]
mod tests {
use crate::providers::openai;

use super::*;

#[test]
fn test_deserialize_vec_choice() {
let data = r#"[{"message":{"role":"assistant","content":"Hello, world!"}}]"#;
let data = r#"[{
"finish_reason": "stop",
"index": 0,
"logprobs": null,
"message":{"role":"assistant","content":"Hello, world!"}
}]"#;

let choices: Vec<Choice> = serde_json::from_str(data).unwrap();
assert_eq!(choices.len(), 1);
match &choices.first().unwrap().message {
Message::Assistant { content, .. } => match &content[0] {
openai::AssistantContent::Text { text } => assert_eq!(text, "Hello, world!"),
_ => panic!("Expected text content"),
},
Message::Assistant { content, .. } => assert_eq!(content, "Hello, world!"),
_ => panic!("Expected assistant message"),
}
}

#[test]
fn test_deserialize_deepseek_response() {
let data = r#"{"choices":[{"message":{"role":"assistant","content":"Hello, world!"}}]}"#;
let data = r#"{"choices":[{
"finish_reason": "stop",
"index": 0,
"logprobs": null,
"message":{"role":"assistant","content":"Hello, world!"}
}]}"#;

let jd = &mut serde_json::Deserializer::from_str(data);
let result: Result<CompletionResponse, _> = serde_path_to_error::deserialize(jd);
match result {
Ok(response) => match &response.choices.first().unwrap().message {
Message::Assistant { content, .. } => match &content[0] {
openai::AssistantContent::Text { text } => assert_eq!(text, "Hello, world!"),
_ => panic!("Expected text content"),
},
Message::Assistant { content, .. } => assert_eq!(content, "Hello, world!"),
_ => panic!("Expected assistant message"),
},
Err(err) => {
Expand Down Expand Up @@ -369,18 +539,64 @@ mod tests {

match result {
Ok(response) => match &response.choices.first().unwrap().message {
Message::Assistant { content, .. } => match &content[0] {
openai::AssistantContent::Text { text } => assert_eq!(
text,
"Why don’t skeletons fight each other? \nBecause they don’t have the guts! 😄"
),
_ => panic!("Expected text content"),
},
Message::Assistant { content, .. } => assert_eq!(
content,
"Why don’t skeletons fight each other? \nBecause they don’t have the guts! 😄"
),
_ => panic!("Expected assistant message"),
},
Err(err) => {
panic!("Deserialization error at {}: {}", err.path(), err);
}
}
}

#[test]
fn test_serialize_deserialize_tool_call_message() {
let tool_call_choice_json = r#"
{
"finish_reason": "tool_calls",
"index": 0,
"logprobs": null,
"message": {
"content": "",
"role": "assistant",
"tool_calls": [
{
"function": {
"arguments": "{\"x\":2,\"y\":5}",
"name": "subtract"
},
"id": "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b",
"index": 0,
"type": "function"
}
]
}
}
"#;

let choice: Choice = serde_json::from_str(tool_call_choice_json).unwrap();

let expected_choice: Choice = Choice {
finish_reason: "tool_calls".to_string(),
index: 0,
logprobs: None,
message: Message::Assistant {
content: "".to_string(),
name: None,
tool_calls: vec![ToolCall {
id: "call_0_2b4a85ee-b04a-40ad-a16b-a405caf6e65b".to_string(),
function: Function {
name: "subtract".to_string(),
arguments: serde_json::from_str(r#"{"x":2,"y":5}"#).unwrap(),
},
index: 0,
r#type: ToolType::Function,
}],
},
};

assert_eq!(choice, expected_choice);
}
}

0 comments on commit 530a327

Please sign in to comment.