Skip to content

Commit

Permalink
Conditionally use metal or cuda
Browse files Browse the repository at this point in the history
  • Loading branch information
domWinter authored May 28, 2024
1 parent 3faf9b9 commit e058de5
Showing 1 changed file with 33 additions and 15 deletions.
48 changes: 33 additions & 15 deletions src/chat_completion.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
use std::sync::Arc;
use candle_core::Device;
use indexmap::IndexMap;
use mistralrs::{
Constraint, DeviceMapMetadata, GGUFLoaderBuilder, GGUFSpecificConfig, MistralRs, MistralRsBuilder, NormalRequest, Request, RequestMessage, SamplingParams, SchedulerMethod, TokenSource
Constraint, DeviceMapMetadata, GGUFLoaderBuilder, GGUFSpecificConfig, MistralRs,
MistralRsBuilder, NormalRequest, Request, RequestMessage, SamplingParams, SchedulerMethod,
TokenSource,
};
use indexmap::IndexMap;
use std::sync::Arc;
use tokio::sync::mpsc::channel;
use tokio_stream::wrappers::ReceiverStream;


#[derive(Clone)]
pub struct CompletionModel {
pub mistralrs: Arc<MistralRs>
pub mistralrs: Arc<MistralRs>,
}

impl CompletionModel {
pub fn new() -> anyhow::Result<Self> {
pub fn new() -> anyhow::Result<Self> {
// Select a Mistral model
let loader = GGUFLoaderBuilder::new(
GGUFSpecificConfig { repeat_last_n: 64 },
Expand All @@ -25,23 +26,40 @@ impl CompletionModel {
"mistral-7b-instruct-v0.2.Q4_K_M.gguf".to_string(),
)
.build();

let pipeline = loader.load_model_from_hf(
None,
TokenSource::CacheToken,
None,
&Device::new_metal(0)?,
&Self::device()?,
false,
DeviceMapMetadata::dummy(),
None,
)?;
// Create the MistralRs, which is a runner
Ok(Self {
mistralrs: MistralRsBuilder::new(pipeline, SchedulerMethod::Fixed(5.try_into().unwrap())).build()
Ok(Self {
mistralrs: MistralRsBuilder::new(
pipeline,
SchedulerMethod::Fixed(5.try_into().unwrap()),
)
.build(),
})
}

pub async fn complete(&self, request: &str) -> anyhow::Result<ReceiverStream<mistralrs::Response>> {
#[cfg(feature = "metal")]
fn device() -> anyhow::Result<Device> {
Ok(Device::new_metal(0)?)
}

#[cfg(not(feature = "metal"))]
fn device() -> anyhow::Result<Device> {
Ok(Device::cuda_if_available(0)?)
}

pub async fn complete(
&self,
request: &str,
) -> anyhow::Result<ReceiverStream<mistralrs::Response>> {
let (tx, rx) = channel(10_000);

let mut messages = Vec::new();
Expand All @@ -61,8 +79,8 @@ impl CompletionModel {
adapters: None,
});

self.mistralrs.get_sender().send(request).await?;
Ok(ReceiverStream::new(rx))
self.mistralrs.get_sender().send(request).await?;

Ok(ReceiverStream::new(rx))
}
}
}

0 comments on commit e058de5

Please sign in to comment.