Skip to content

Commit

Permalink
update ollama client completion params
Browse files Browse the repository at this point in the history
  • Loading branch information
451846939 committed Feb 8, 2025
1 parent 5b4ccac commit ad91f61
Showing 1 changed file with 71 additions and 79 deletions.
150 changes: 71 additions & 79 deletions rig-core/src/providers/ollama.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,12 @@
//!
//! let req = rig::completion::CompletionRequest {
//! preamble: Some("You are now a humorous AI assistant.".to_owned()),
//! chat_history: vec![], // internal messages, if any
//! prompt: "Please tell me why the sky is blue.".to_owned(),
//! chat_history: vec![], // internal messages (if any)
//! prompt: /* a crate::message::Message value representing the prompt */
//! rig::message::Message::User {
//! content: rig::one_or_many::OneOrMany::one(rig::message::UserContent::text("Please tell me why the sky is blue.")),
//! name: None
//! },
//! temperature: 0.7,
//! additional_params: None,
//! tools: vec![],
Expand All @@ -36,7 +40,6 @@
//! let extractor = client.extractor::<serde_json::Value>("llama3.2");
//! ```

// =================================================================
// Imports
// =================================================================
Expand All @@ -49,7 +52,7 @@ use crate::{
embeddings::{self, EmbeddingError, EmbeddingsBuilder},
extractor::ExtractorBuilder,
json_utils,
message::{AudioMediaType, ImageDetail},
message::{self, AudioMediaType, ImageDetail},
one_or_many::string_or_one_or_many,
Embed, OneOrMany,
};
Expand All @@ -58,8 +61,9 @@ use serde::{Deserialize, Serialize};
use serde_json::json;
use reqwest;


// =================================================================
// FromStr implementations for provider types (used by deserializers)
// FromStr implementations for provider types (for deserialization)
// =================================================================

impl FromStr for SystemContent {
Expand Down Expand Up @@ -90,7 +94,6 @@ impl FromStr for AssistantContent {
// Main Ollama Client
// =================================================================

/// Default Ollama API base URL.
const OLLAMA_API_BASE_URL: &str = "http://localhost:11434";

#[derive(Clone)]
Expand All @@ -100,12 +103,9 @@ pub struct Client {
}

impl Client {
/// Create a new Ollama client using the default API URL.
pub fn new() -> Self {
Self::from_url(OLLAMA_API_BASE_URL)
}

/// Create a new Ollama client using the specified API URL.
pub fn from_url(base_url: &str) -> Self {
Self {
base_url: base_url.to_owned(),
Expand All @@ -114,39 +114,25 @@ impl Client {
.expect("Ollama reqwest client should build"),
}
}

/// Create a new HTTP POST request builder.
fn post(&self, path: &str) -> reqwest::RequestBuilder {
let url = format!("{}/{}", self.base_url, path);
self.http_client.post(url)
}

/// Create an embedding model interface.
pub fn embedding_model(&self, model: &str) -> EmbeddingModel {
EmbeddingModel::new(self.clone(), model, 0)
}

/// Create an embedding model interface with a specified number of dimensions.
pub fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
EmbeddingModel::new(self.clone(), model, ndims)
}

/// Create an embeddings builder.
pub fn embeddings<D: Embed>(&self, model: &str) -> EmbeddingsBuilder<EmbeddingModel, D> {
EmbeddingsBuilder::new(self.embedding_model(model))
}

/// Create a completion model interface.
pub fn completion_model(&self, model: &str) -> CompletionModel {
CompletionModel::new(self.clone(), model)
}

/// Create an agent builder.
pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
AgentBuilder::new(self.completion_model(model))
}

/// Create an extractor builder.
pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
&self,
model: &str,
Expand All @@ -156,7 +142,7 @@ impl Client {
}

// =================================================================
// Generic API Error and Response Structures
// API Error and Response Structures
// =================================================================

#[derive(Debug, Deserialize)]
Expand All @@ -175,7 +161,6 @@ enum ApiResponse<T> {
// Embedding API
// =================================================================

/// Example constant for an Ollama embedding model.
pub const ALL_MINILM: &str = "all-minilm";

#[derive(Debug, Serialize, Deserialize)]
Expand Down Expand Up @@ -262,7 +247,6 @@ impl embeddings::EmbeddingModel for EmbeddingModel {
// Completion API
// =================================================================

/// Example constants for Ollama completion models.
pub const LLAMA3_2: &str = "llama3.2";
pub const LLAVA: &str = "llava";
pub const MISTRAL: &str = "mistral";
Expand Down Expand Up @@ -297,12 +281,10 @@ impl From<ApiErrorResponse> for CompletionError {
}
}

/// For single-turn generation, we convert a CompletionResponse into our unified CompletionResponse type
impl TryFrom<CompletionResponse> for completion::CompletionResponse<serde_json::Value> {
type Error = CompletionError;
fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
let assistant = completion::AssistantContent::text(&response.response);
// Instead of using map_err, we check manually:
let choice = OneOrMany::one(assistant);
if choice.is_empty() {
return Err(CompletionError::ResponseError("Empty response".into()));
Expand Down Expand Up @@ -342,8 +324,8 @@ impl TryFrom<ChatResponse> for completion::CompletionResponse<serde_json::Value>
Message::Assistant { ref content, .. } => {
let texts: Vec<completion::AssistantContent> = content.into_iter().map(|c| {
match c {
AssistantContent::Text { text } => completion::AssistantContent::text(text),
AssistantContent::Refusal { refusal } => completion::AssistantContent::text(refusal),
AssistantContent::Text { ref text } => completion::AssistantContent::text(text),
AssistantContent::Refusal { ref refusal } => completion::AssistantContent::text(refusal),
}
}).collect();
let choice = OneOrMany::many(texts)
Expand All @@ -363,7 +345,6 @@ impl TryFrom<ChatResponse> for completion::CompletionResponse<serde_json::Value>
#[derive(Clone)]
pub struct CompletionModel {
client: Client,
/// The model name (e.g. "llama3.2").
pub model: String,
}

Expand All @@ -373,7 +354,7 @@ impl CompletionModel {
}
}

/// We set our associated type Response to be a JSON value.
/// In our unified API, we set the associated Response type to be a JSON value.
impl completion::CompletionModel for CompletionModel {
type Response = serde_json::Value;
async fn completion(
Expand All @@ -382,7 +363,6 @@ impl completion::CompletionModel for CompletionModel {
) -> Result<completion::CompletionResponse<Self::Response>, CompletionError> {
if !completion_request.chat_history.is_empty() {
// Chat mode:
// Assume prompt_with_context() returns a single internal message.
let prompt_internal: crate::message::Message = completion_request.prompt_with_context();
let prompt_msg: Message = prompt_internal.try_into()?;
let chat_history: Vec<Message> = completion_request
Expand Down Expand Up @@ -422,14 +402,24 @@ impl completion::CompletionModel for CompletionModel {
}
} else {
// Single-turn mode:
let full_prompt = completion_request.prompt.clone();
// Convert the internal message (which is our prompt) to a plain string.
let full_prompt = internal_message_to_string(&completion_request.prompt);
let mut request_payload = json!({
"model": self.model,
"prompt": full_prompt,
"prompt": full_prompt.clone(),
"stream": false,
});
if let Some(params) = &completion_request.additional_params {
request_payload = json_utils::merge(request_payload, params.clone());
// Remove any "prompt" key from additional parameters
let mut params = params.clone();
if let Some(map) = params.as_object_mut() {
map.remove("prompt");
}
request_payload = json_utils::merge(request_payload, params);
}
// Forcefully re-set "prompt" as a string.
if let Some(obj) = request_payload.as_object_mut() {
obj.insert("prompt".to_string(), serde_json::Value::String(full_prompt));
}
let response = self.client.post("api/generate")
.json(&request_payload)
Expand All @@ -452,7 +442,7 @@ impl completion::CompletionModel for CompletionModel {
}

// =================================================================
// Provider Message Definitions and Conversions (following openai.rs)
// Provider Message Definitions and Conversions
// =================================================================

#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
Expand Down Expand Up @@ -591,7 +581,8 @@ pub struct Function {

// =================================================================
// Conversion from internal Rig message (crate::message::Message)
// to provider Message. (Only User and Assistant variants are supported.)
// to provider Message.
// (Only User and Assistant variants are supported.)
// =================================================================

impl TryFrom<crate::message::Message> for Message {
Expand All @@ -618,7 +609,8 @@ impl TryFrom<crate::message::Message> for Message {
other => Err(crate::message::MessageError::ConversionError(format!("Unsupported user content: {:?}", other))),
}
}).collect();
let one = OneOrMany::many(converted?).map_err(|e| crate::message::MessageError::ConversionError(e.to_string()))?;
let one = OneOrMany::many(converted?)
.map_err(|e| crate::message::MessageError::ConversionError(e.to_string()))?;
Ok(Message::User { content: one, name: None })
}
InternalMessage::Assistant { content } => {
Expand All @@ -636,10 +628,31 @@ impl TryFrom<crate::message::Message> for Message {
tool_calls: vec![],
})
}
other => Err(crate::message::MessageError::ConversionError(format!("Unsupported internal message variant: {:?}", other))),
}
}
}

// =================================================================
// Helper: Convert internal message (the prompt) to a plain string.
// =================================================================

fn internal_message_to_string(msg: &crate::message::Message) -> String {
use crate::message::{Message as InternalMessage, UserContent};
match msg {
InternalMessage::User { content, .. } => {
content.iter().filter_map(|uc| {
if let UserContent::Text(t) = uc {
Some(t.text.clone())
} else {
None
}
}).collect::<Vec<_>>().join("\n")
},
_ => format!("{:?}", msg),
}
}

// =================================================================
// Tests
// =================================================================
Expand All @@ -648,7 +661,6 @@ impl TryFrom<crate::message::Message> for Message {
mod tests {
use super::*;
use serde_path_to_error::deserialize;

#[test]
fn test_deserialize_message() {
let assistant_message_json = r#"
Expand Down Expand Up @@ -712,74 +724,54 @@ mod tests {
"#;
let assistant_message: Message = {
let jd = &mut serde_json::Deserializer::from_str(assistant_message_json);
deserialize(jd).unwrap_or_else(|err| {
panic!("Deserialization error at {}: {}", err.path(), err.inner())
})
deserialize(jd).unwrap_or_else(|err| panic!("Deserialization error at {}: {}", err.path(), err.inner()))
};
let assistant_message2: Message = {
let jd = &mut serde_json::Deserializer::from_str(assistant_message_json2);
deserialize(jd).unwrap_or_else(|err| {
panic!("Deserialization error at {}: {}", err.path(), err.inner())
})
deserialize(jd).unwrap_or_else(|err| panic!("Deserialization error at {}: {}", err.path(), err.inner()))
};
let assistant_message3: Message = {
let jd = &mut serde_json::Deserializer::from_str(assistant_message_json3);
deserialize(jd).unwrap_or_else(|err| {
panic!("Deserialization error at {}: {}", err.path(), err.inner())
})
deserialize(jd).unwrap_or_else(|err| panic!("Deserialization error at {}: {}", err.path(), err.inner()))
};
let user_message: Message = {
let jd = &mut serde_json::Deserializer::from_str(user_message_json);
deserialize(jd).unwrap_or_else(|err| {
panic!("Deserialization error at {}: {}", err.path(), err.inner())
})
deserialize(jd).unwrap_or_else(|err| panic!("Deserialization error at {}: {}", err.path(), err.inner()))
};
if let Message::Assistant { content, .. } = assistant_message {
assert_eq!(content[0],
AssistantContent::Text { text: "\n\nHello there, how may I assist you today?".to_owned() }
);
assert_eq!(content[0], AssistantContent::Text { text: "\n\nHello there, how may I assist you today?".to_owned() });
} else {
panic!("Expected assistant message");
}
if let Message::Assistant { content, tool_calls, .. } = assistant_message2 {
assert_eq!(content[0],
AssistantContent::Text { text: "\n\nHello there, how may I assist you today?".to_owned() }
);
assert_eq!(content[0], AssistantContent::Text { text: "\n\nHello there, how may I assist you today?".to_owned() });
assert_eq!(tool_calls, vec![]);
} else {
panic!("Expected assistant message");
}
if let Message::Assistant { content, tool_calls, refusal, .. } = assistant_message3 {
assert!(content.is_empty());
assert!(refusal.is_none());
assert_eq!(tool_calls[0],
ToolCall {
id: "call_h89ipqYUjEpCPI6SxspMnoUU".to_owned(),
r#type: ToolType::Function,
function: Function {
name: "subtract".to_owned(),
arguments: serde_json::json!({"x": 2, "y": 5}),
},
}
);
assert_eq!(tool_calls[0], ToolCall {
id: "call_h89ipqYUjEpCPI6SxspMnoUU".to_owned(),
r#type: ToolType::Function,
function: Function {
name: "subtract".to_owned(),
arguments: serde_json::json!({"x": 2, "y": 5}),
},
});
} else {
panic!("Expected assistant message");
}
if let Message::User { content, .. } = user_message {
let mut iter = content.into_iter();
let first = iter.next().unwrap();
let second = iter.next().unwrap();
assert_eq!(first,
UserContent::Text { text: "What's in this image?".to_owned() }
);
assert_eq!(second,
UserContent::Image {
image_url: ImageUrl {
url: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg".to_owned(),
detail: ImageDetail::default()
}
}
);
assert_eq!(first, UserContent::Text { text: "What's in this image?".to_owned() });
assert_eq!(second, UserContent::Image { image_url: ImageUrl {
url: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg".to_owned(),
detail: ImageDetail::default(),
}});
} else {
panic!("Expected user message");
}
Expand Down

0 comments on commit ad91f61

Please sign in to comment.