diff --git a/rig-core/src/providers/ollama.rs b/rig-core/src/providers/ollama.rs index f55f6d27..443cc257 100644 --- a/rig-core/src/providers/ollama.rs +++ b/rig-core/src/providers/ollama.rs @@ -52,8 +52,7 @@ use crate::{ embeddings::{self, EmbeddingError, EmbeddingsBuilder}, extractor::ExtractorBuilder, json_utils, - message::{self, AudioMediaType, ImageDetail}, - one_or_many::string_or_one_or_many, + message::{AudioMediaType, ImageDetail}, Embed, OneOrMany, }; use schemars::JsonSchema; @@ -61,21 +60,10 @@ use serde::{Deserialize, Serialize}; use serde_json::json; use reqwest; - // ================================================================= // FromStr implementations for provider types (for deserialization) // ================================================================= -impl FromStr for SystemContent { - type Err = Infallible; - fn from_str(s: &str) -> Result { - Ok(SystemContent { - r#type: SystemContentType::Text, - text: s.to_owned(), - }) - } -} - impl FromStr for UserContent { type Err = Infallible; fn from_str(s: &str) -> Result { @@ -162,6 +150,7 @@ enum ApiResponse { // ================================================================= pub const ALL_MINILM: &str = "all-minilm"; +pub const NOMIC_EMBED_TEXT: &str = "nomic-embed-text"; #[derive(Debug, Serialize, Deserialize)] pub struct EmbeddingResponse { @@ -322,14 +311,11 @@ impl TryFrom for completion::CompletionResponse fn try_from(resp: ChatResponse) -> Result { match resp.message { Message::Assistant { ref content, .. } => { - let texts: Vec = content.into_iter().map(|c| { - match c { - AssistantContent::Text { ref text } => completion::AssistantContent::text(text), - AssistantContent::Refusal { ref refusal } => completion::AssistantContent::text(refusal), - } - }).collect(); - let choice = OneOrMany::many(texts) - .map_err(|e| CompletionError::ResponseError(e.to_string()))?; + // Since the provider Message's content is now a String, + // create a single AssistantContent from it. + let assistant_content = completion::AssistantContent::text(content); + // Directly construct OneOrMany from the assistant_content. + let choice = OneOrMany::one(assistant_content); if choice.is_empty() { return Err(CompletionError::ResponseError("Empty chat response".into())); } @@ -337,7 +323,9 @@ impl TryFrom for completion::CompletionResponse .map_err(|e| CompletionError::ResponseError(e.to_string()))?; Ok(completion::CompletionResponse { choice, raw_response: raw }) }, - _ => Err(CompletionError::ResponseError("Chat response does not include an assistant message".into())), + _ => Err(CompletionError::ResponseError( + "Chat response does not include an assistant message".into(), + )), } } } @@ -355,153 +343,209 @@ impl CompletionModel { } /// In our unified API, we set the associated Response type to be a JSON value. +// ----------------------------- +// Additional conversion implementations +// ----------------------------- + +// This implementation allows converting an internal message (crate::message::Message) +// into a Vec of provider Message. This is used when combining prompt context. +impl TryFrom for Vec { + type Error = crate::message::MessageError; + fn try_from(internal_msg: crate::message::Message) -> Result { + // For now, simply convert the internal message to a provider Message and wrap it in a Vec. + Ok(vec![Message::try_from(internal_msg)?]) + } +} + +// This implementation allows the '?' operator to convert an Infallible error into a CompletionError. +impl From for CompletionError { + fn from(_: std::convert::Infallible) -> Self { + CompletionError::ProviderError("Infallible error".to_string()) + } +} + +// ----------------------------- +// CompletionModel implementation +// ----------------------------- +// Helper method for provider Message conversion to a plain prompt string. + impl completion::CompletionModel for CompletionModel { type Response = serde_json::Value; async fn completion( &self, completion_request: CompletionRequest, ) -> Result, CompletionError> { - if !completion_request.chat_history.is_empty() { - // Chat mode: - let prompt_internal: crate::message::Message = completion_request.prompt_with_context(); - let prompt_msg: Message = prompt_internal.try_into()?; + // Convert internal prompt using prompt_with_context() into Vec + let prompt: Vec = completion_request.prompt_with_context().try_into()?; + let default_options = json!({ + "temperature": completion_request.temperature, + }); + // Determine chat mode: if chat history is non-empty OR prompt returns more than one message, use chat mode. + if !completion_request.chat_history.is_empty() || prompt.len() > 1 { + // Chat mode: build full conversation history as an array. + let mut full_history: Vec = match &completion_request.preamble { + Some(preamble) => vec![Message::system(preamble)], + None => vec![], + }; + + // Convert chat history: each internal message may yield multiple provider messages. let chat_history: Vec = completion_request .chat_history .into_iter() .map(|m| m.try_into()) - .collect::>()?; - let mut full_history = Vec::new(); - if let Some(preamble) = &completion_request.preamble { - full_history.push(Message::system(preamble)); - } + .collect::>, _>>()? + .into_iter() + .flatten() + .collect(); + full_history.extend(chat_history); - full_history.push(prompt_msg); - let mut request_payload = json!({ + full_history.extend(prompt); + let options = if let Some(extra) = completion_request.additional_params { + json_utils::merge(default_options, extra) + } else { + default_options + }; + + let request_payload = json!({ "model": self.model, - "messages": full_history, + "messages": full_history, // Send as an array, per API specification. + "temperature": options, "stream": false, }); - if let Some(params) = &completion_request.additional_params { - request_payload = json_utils::merge(request_payload, params.clone()); - } + + tracing::debug!(target: "rig", "Chat mode payload: {}", request_payload); let response = self.client.post("api/chat") .json(&request_payload) .send() .await .map_err(|e| CompletionError::ProviderError(e.to_string()))?; if response.status().is_success() { - let text = response.text().await.map_err(|e| CompletionError::ProviderError(e.to_string()))?; + let text = response.text().await + .map_err(|e| CompletionError::ProviderError(e.to_string()))?; tracing::debug!(target: "rig", "Ollama chat response: {}", text); let chat_resp: ChatResponse = serde_json::from_str(&text) .map_err(|e| CompletionError::ProviderError(e.to_string()))?; let conv: completion::CompletionResponse = chat_resp.try_into()?; Ok(conv) } else { - let err_text = response.text().await.map_err(|e| CompletionError::ProviderError(e.to_string()))?; + let err_text = response.text().await + .map_err(|e| CompletionError::ProviderError(e.to_string()))?; Err(CompletionError::ProviderError(err_text)) } } else { - // Single-turn mode: - // Convert the internal message (which is our prompt) to a plain string. - let full_prompt = internal_message_to_string(&completion_request.prompt); + // Single-turn mode: if prompt_with_context() returns empty, fallback to converting internal prompt to a plain string. + let full_prompt = provider_messages_to_string(&prompt); let mut request_payload = json!({ "model": self.model, - "prompt": full_prompt.clone(), + "prompt": full_prompt, // prompt must be a string + "temperature": completion_request.temperature, "stream": false, }); - if let Some(params) = &completion_request.additional_params { - // Remove any "prompt" key from additional parameters - let mut params = params.clone(); - if let Some(map) = params.as_object_mut() { - map.remove("prompt"); - } + if let Some(params) = completion_request.additional_params { request_payload = json_utils::merge(request_payload, params); } - // Forcefully re-set "prompt" as a string. - if let Some(obj) = request_payload.as_object_mut() { - obj.insert("prompt".to_string(), serde_json::Value::String(full_prompt)); - } + tracing::debug!(target: "rig", "Single-turn payload: {}", request_payload); let response = self.client.post("api/generate") .json(&request_payload) .send() .await .map_err(|e| CompletionError::ProviderError(e.to_string()))?; if response.status().is_success() { - let text = response.text().await.map_err(|e| CompletionError::ProviderError(e.to_string()))?; + let text = response.text().await + .map_err(|e| CompletionError::ProviderError(e.to_string()))?; tracing::debug!(target: "rig", "Ollama generate response: {}", text); let gen_resp: CompletionResponse = serde_json::from_str(&text) .map_err(|e| CompletionError::ProviderError(e.to_string()))?; let conv: completion::CompletionResponse = gen_resp.try_into()?; Ok(conv) } else { - let err_text = response.text().await.map_err(|e| CompletionError::ProviderError(e.to_string()))?; + let err_text = response.text().await + .map_err(|e| CompletionError::ProviderError(e.to_string()))?; Err(CompletionError::ProviderError(err_text)) } } } } +// Helper function: convert a slice of provider Message into a plain string. +// For each message, we extract the text from User, Assistant or System variants. +fn provider_messages_to_string(messages: &[Message]) -> String { + messages + .iter() + .map(|msg| msg.to_prompt()) + .collect::>() + .join("\n") +} + + // ================================================================= // Provider Message Definitions and Conversions // ================================================================= +// ================================================================= +// Provider Message Definitions for Ollama (simplified for API) +// ================================================================= + #[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] #[serde(tag = "role", rename_all = "lowercase")] pub enum Message { User { - #[serde(deserialize_with = "string_or_one_or_many")] - content: OneOrMany, + content: String, + #[serde(skip_serializing_if = "Option::is_none")] + images: Option>, #[serde(skip_serializing_if = "Option::is_none")] name: Option, }, Assistant { - #[serde(default, deserialize_with = "crate::json_utils::string_or_vec")] - content: Vec, + content: String, + #[serde(skip_serializing_if = "Option::is_none")] + images: Option>, #[serde(skip_serializing_if = "Option::is_none")] refusal: Option, #[serde(skip_serializing_if = "Option::is_none")] audio: Option, #[serde(skip_serializing_if = "Option::is_none")] name: Option, - #[serde(default, deserialize_with = "crate::json_utils::null_or_vec")] - tool_calls: Vec, }, System { - #[serde(deserialize_with = "string_or_one_or_many")] - content: OneOrMany, + content: String, + #[serde(skip_serializing_if = "Option::is_none")] + images: Option>, #[serde(skip_serializing_if = "Option::is_none")] name: Option, }, #[serde(rename = "Tool")] ToolResult { tool_call_id: String, - #[serde(deserialize_with = "string_or_one_or_many")] - content: OneOrMany, + content: String, }, } +// Implement a helper method on provider Message to extract text for prompt. impl Message { + pub fn to_prompt(&self) -> String { + match self { + Message::User { content, .. } => content.clone(), + Message::Assistant { content, .. } => content.clone(), + Message::System { content, .. } => content.clone(), + Message::ToolResult { content, .. } => content.clone(), + } + } + + // A convenience method to create a system message from a string. pub fn system(content: &str) -> Self { Message::System { - content: OneOrMany::many(vec![SystemContent { - r#type: SystemContentType::Text, - text: content.to_owned(), - }]).expect("Non-empty system content"), + content: content.to_owned(), + images: None, name: None, } } } - -#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] -pub struct AudioAssistant { - pub id: String, -} - #[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] pub struct SystemContent { #[serde(default)] - pub r#type: SystemContentType, - pub text: String, + r#type: SystemContentType, + text: String, } #[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)] @@ -511,6 +555,29 @@ pub enum SystemContentType { Text, } +impl From for SystemContent { + fn from(s: String) -> Self { + SystemContent { + r#type: SystemContentType::default(), + text: s, + } + } +} + +impl FromStr for SystemContent { + type Err = Infallible; + + fn from_str(s: &str) -> Result { + Ok(SystemContent { + r#type: SystemContentType::default(), + text: s.to_string(), + }) + } +} +#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] +pub struct AudioAssistant { + pub id: String, +} #[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] #[serde(tag = "type", rename_all = "lowercase")] pub enum AssistantContent { @@ -582,7 +649,7 @@ pub struct Function { // ================================================================= // Conversion from internal Rig message (crate::message::Message) // to provider Message. -// (Only User and Assistant variants are supported.) +// (Only User, Assistant and System variants are supported.) // ================================================================= impl TryFrom for Message { @@ -590,211 +657,35 @@ impl TryFrom for Message { fn try_from(internal_msg: crate::message::Message) -> Result { use crate::message::Message as InternalMessage; match internal_msg { - InternalMessage::User { content } => { - let converted: Result, _> = content.into_iter().map(|uc| { + InternalMessage::User { content, .. } => { + let mut texts = Vec::new(); + let mut images = Vec::new(); + for uc in content.into_iter() { match uc { - crate::message::UserContent::Text(t) => Ok(UserContent::Text { text: t.text }), - crate::message::UserContent::Image(img) => Ok(UserContent::Image { - image_url: ImageUrl { - url: img.data, - detail: img.detail.unwrap_or_default(), - }, - }), - crate::message::UserContent::Audio(audio) => Ok(UserContent::Audio { - input_audio: InputAudio { - data: audio.data, - format: audio.media_type.unwrap_or(AudioMediaType::MP3), - }, - }), - other => Err(crate::message::MessageError::ConversionError(format!("Unsupported user content: {:?}", other))), + crate::message::UserContent::Text(t) => texts.push(t.text), + crate::message::UserContent::Image(img) => images.push(img.data), + crate::message::UserContent::Audio(_audio) => { + } + _ => {} } - }).collect(); - let one = OneOrMany::many(converted?) - .map_err(|e| crate::message::MessageError::ConversionError(e.to_string()))?; - Ok(Message::User { content: one, name: None }) + } + let content_str = texts.join(" "); + let images_opt = if images.is_empty() { None } else { Some(images) }; + Ok(Message::User { content: content_str, images: images_opt, name: None }) } - InternalMessage::Assistant { content } => { - let converted: Result, _> = content.into_iter().map(|ac| { + InternalMessage::Assistant { content, .. } => { + let mut texts = Vec::new(); + let images = Vec::new(); + for ac in content.into_iter() { match ac { - crate::message::AssistantContent::Text(t) => Ok(AssistantContent::Text { text: t.text }), - other => Err(crate::message::MessageError::ConversionError(format!("Unsupported assistant content: {:?}", other))), + crate::message::AssistantContent::Text(t) => texts.push(t.text), + _ => {} } - }).collect(); - Ok(Message::Assistant { - content: converted?, - refusal: None, - audio: None, - name: None, - tool_calls: vec![], - }) + } + let content_str = texts.join(" "); + let images_opt = if images.is_empty() { None } else { Some(images) }; + Ok(Message::Assistant { content: content_str, images: images_opt, refusal: None, audio: None, name: None, }) } - other => Err(crate::message::MessageError::ConversionError(format!("Unsupported internal message variant: {:?}", other))), } } } - -// ================================================================= -// Helper: Convert internal message (the prompt) to a plain string. -// ================================================================= - -fn internal_message_to_string(msg: &crate::message::Message) -> String { - use crate::message::{Message as InternalMessage, UserContent}; - match msg { - InternalMessage::User { content, .. } => { - content.iter().filter_map(|uc| { - if let UserContent::Text(t) = uc { - Some(t.text.clone()) - } else { - None - } - }).collect::>().join("\n") - }, - _ => format!("{:?}", msg), - } -} - -// ================================================================= -// Tests -// ================================================================= - -#[cfg(test)] -mod tests { - use super::*; - use serde_path_to_error::deserialize; - #[test] - fn test_deserialize_message() { - let assistant_message_json = r#" - { - "role": "assistant", - "content": "\n\nHello there, how may I assist you today?" - } - "#; - let assistant_message_json2 = r#" - { - "role": "assistant", - "content": [ - { - "type": "text", - "text": "\n\nHello there, how may I assist you today?" - } - ], - "tool_calls": null - } - "#; - let assistant_message_json3 = r#" - { - "role": "assistant", - "tool_calls": [ - { - "id": "call_h89ipqYUjEpCPI6SxspMnoUU", - "type": "function", - "function": { - "name": "subtract", - "arguments": "{\"x\": 2, \"y\": 5}" - } - } - ], - "content": null, - "refusal": null - } - "#; - let user_message_json = r#" - { - "role": "user", - "content": [ - { - "type": "text", - "text": "What's in this image?" - }, - { - "type": "image", - "image_url": { - "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" - } - }, - { - "type": "audio", - "input_audio": { - "data": "...", - "format": "mp3" - } - } - ] - } - "#; - let assistant_message: Message = { - let jd = &mut serde_json::Deserializer::from_str(assistant_message_json); - deserialize(jd).unwrap_or_else(|err| panic!("Deserialization error at {}: {}", err.path(), err.inner())) - }; - let assistant_message2: Message = { - let jd = &mut serde_json::Deserializer::from_str(assistant_message_json2); - deserialize(jd).unwrap_or_else(|err| panic!("Deserialization error at {}: {}", err.path(), err.inner())) - }; - let assistant_message3: Message = { - let jd = &mut serde_json::Deserializer::from_str(assistant_message_json3); - deserialize(jd).unwrap_or_else(|err| panic!("Deserialization error at {}: {}", err.path(), err.inner())) - }; - let user_message: Message = { - let jd = &mut serde_json::Deserializer::from_str(user_message_json); - deserialize(jd).unwrap_or_else(|err| panic!("Deserialization error at {}: {}", err.path(), err.inner())) - }; - if let Message::Assistant { content, .. } = assistant_message { - assert_eq!(content[0], AssistantContent::Text { text: "\n\nHello there, how may I assist you today?".to_owned() }); - } else { - panic!("Expected assistant message"); - } - if let Message::Assistant { content, tool_calls, .. } = assistant_message2 { - assert_eq!(content[0], AssistantContent::Text { text: "\n\nHello there, how may I assist you today?".to_owned() }); - assert_eq!(tool_calls, vec![]); - } else { - panic!("Expected assistant message"); - } - if let Message::Assistant { content, tool_calls, refusal, .. } = assistant_message3 { - assert!(content.is_empty()); - assert!(refusal.is_none()); - assert_eq!(tool_calls[0], ToolCall { - id: "call_h89ipqYUjEpCPI6SxspMnoUU".to_owned(), - r#type: ToolType::Function, - function: Function { - name: "subtract".to_owned(), - arguments: serde_json::json!({"x": 2, "y": 5}), - }, - }); - } else { - panic!("Expected assistant message"); - } - if let Message::User { content, .. } = user_message { - let mut iter = content.into_iter(); - let first = iter.next().unwrap(); - let second = iter.next().unwrap(); - assert_eq!(first, UserContent::Text { text: "What's in this image?".to_owned() }); - assert_eq!(second, UserContent::Image { image_url: ImageUrl { - url: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg".to_owned(), - detail: ImageDetail::default(), - }}); - } else { - panic!("Expected user message"); - } - } - #[test] - fn test_message_to_message_conversion() { - let internal_user = crate::message::Message::User { - content: OneOrMany::many(vec![crate::message::UserContent::text("Hello")]).expect("Non-empty"), - }; - let internal_assistant = crate::message::Message::Assistant { - content: OneOrMany::many(vec![crate::message::AssistantContent::text("Hi there!")]).expect("Non-empty"), - }; - let converted_user: Message = Message::try_from(internal_user).unwrap(); - let converted_assistant: Message = Message::try_from(internal_assistant).unwrap(); - if let Message::User { ref content, .. } = converted_user { - assert_eq!(content.first(), UserContent::Text { text: "Hello".to_owned() }); - } else { - panic!("Expected user message"); - } - if let Message::Assistant { ref content, .. } = converted_assistant { - assert_eq!(content.first(), Some(&AssistantContent::Text { text: "Hi there!".to_owned() })); - } else { - panic!("Expected assistant message"); - } - } -} \ No newline at end of file