From 492d767fdb023a913596fd73c577ad9236d8e05e Mon Sep 17 00:00:00 2001 From: Eric Buehler <65165915+EricLBuehler@users.noreply.github.com> Date: Mon, 10 Jun 2024 15:34:35 -0400 Subject: [PATCH] Expose API to specify dtype during loading (#417) * Abstract dtype settings * Abstract dtype settings * Better name * Expose in apis * Logging * Fix test * Docs --- README.md | 6 ++ mistralrs-bench/src/main.rs | 4 +- mistralrs-core/src/lib.rs | 4 +- mistralrs-core/src/model_loader.rs | 31 +++++++- mistralrs-core/src/model_selected.rs | 25 ++++++- mistralrs-core/src/pipeline/ggml.rs | 12 ++-- mistralrs-core/src/pipeline/gguf.rs | 10 +-- mistralrs-core/src/pipeline/macros.rs | 25 +++---- mistralrs-core/src/pipeline/mod.rs | 12 ++-- mistralrs-core/src/pipeline/normal.rs | 21 ++---- mistralrs-core/src/pipeline/speculative.rs | 8 +-- mistralrs-core/src/pipeline/vision.rs | 19 ++--- mistralrs-core/src/toml_selector.rs | 45 ++++++++++-- mistralrs-core/src/utils/mod.rs | 1 + mistralrs-core/src/utils/normal.rs | 84 ++++++++++++++++++++++ mistralrs-pyo3/src/lib.rs | 8 +-- mistralrs-pyo3/src/which.rs | 12 ++-- mistralrs-server/src/main.rs | 7 +- mistralrs/examples/gguf_locally/main.rs | 4 +- mistralrs/examples/grammar/main.rs | 8 +-- mistralrs/examples/isq/main.rs | 4 +- mistralrs/examples/lora/main.rs | 8 +-- mistralrs/examples/lora_activation/main.rs | 8 +-- mistralrs/examples/phi3v/main.rs | 8 +-- mistralrs/examples/quantized/main.rs | 4 +- mistralrs/examples/simple/main.rs | 8 +-- mistralrs/examples/xlora/main.rs | 8 +-- 27 files changed, 279 insertions(+), 115 deletions(-) create mode 100644 mistralrs-core/src/utils/normal.rs diff --git a/README.md b/README.md index d639b16621..d20419ec46 100644 --- a/README.md +++ b/README.md @@ -341,6 +341,8 @@ Additionally, for models without quantization, the model architecture should be ### Architecture for plain models +> Note: for plain models, you can specify the data type to load and run in. 
This must be one of `f32`, `f16`, `bf16` or `auto` to choose based on the device. This is specified in the `--dtype`/`-d` parameter after the model architecture (`plain`). + - `mistral` - `gemma` - `mixtral` @@ -351,6 +353,8 @@ Additionally, for models without quantization, the model architecture should be ### Architecture for vision models +> Note: for vision models, you can specify the data type to load and run in. This must be one of `f32`, `f16`, `bf16` or `auto` to choose based on the device. This is specified in the `--dtype`/`-d` parameter after the model architecture (`vision-plain`). + - `phi3v` **Interactive mode:** @@ -495,6 +499,8 @@ If you want to add a new model, please contact us via an issue and we can coordi - Error: `recompile with -fPIE`: - Some Linux distributions require compiling with `-fPIE`. - Set the `CUDA_NVCC_FLAGS` environment variable to `-fPIE` during build: `CUDA_NVCC_FLAGS=-fPIE` +- Error `CUDA_ERROR_NOT_FOUND` or symbol not found when using a normal or vision model: + - For non-quantized models, you can specify the data type to load and run in. This must be one of `f32`, `f16`, `bf16` or `auto` to choose based on the device. ## Credits This project would not be possible without the excellent work at [`candle`](https://github.com/huggingface/candle). Additionally, thank you to all contributors! Contributing can range from raising an issue or suggesting a feature to adding some new functionality. 
diff --git a/mistralrs-bench/src/main.rs b/mistralrs-bench/src/main.rs index 348aae8909..0635ee8e8c 100644 --- a/mistralrs-bench/src/main.rs +++ b/mistralrs-bench/src/main.rs @@ -2,7 +2,7 @@ use candle_core::Device; use clap::Parser; use cli_table::{format::Justify, print_stdout, Cell, CellStruct, Style, Table}; use mistralrs_core::{ - Constraint, DeviceMapMetadata, Loader, LoaderBuilder, MistralRs, MistralRsBuilder, + Constraint, DeviceMapMetadata, Loader, LoaderBuilder, MistralRs, MistralRsBuilder, ModelDType, ModelSelected, NormalRequest, Request, RequestMessage, Response, SamplingParams, SchedulerMethod, TokenSource, Usage, }; @@ -313,7 +313,7 @@ fn main() -> anyhow::Result<()> { let pipeline = loader.load_model_from_hf( None, token_source, - None, + &ModelDType::Auto, &device, false, args.num_device_layers diff --git a/mistralrs-core/src/lib.rs b/mistralrs-core/src/lib.rs index 59e543c85a..ab97b245d8 100644 --- a/mistralrs-core/src/lib.rs +++ b/mistralrs-core/src/lib.rs @@ -22,9 +22,10 @@ mod device_map; mod engine; mod lora; mod model_loader; -pub use model_loader::{get_tgt_non_granular_index, LoaderBuilder}; +pub use model_loader::{get_model_dtype, get_tgt_non_granular_index, LoaderBuilder}; mod model_selected; pub use model_selected::ModelSelected; +pub use toml_selector::get_toml_selected_model_dtype; mod cublaslt; mod gguf; @@ -62,6 +63,7 @@ pub use scheduler::SchedulerMethod; use serde::Serialize; use tokio::runtime::Runtime; use toml_selector::{TomlLoaderArgs, TomlSelector}; +pub use utils::normal::{ModelDType, TryIntoDType}; /// `true` if `MISTRALRS_DEBUG=1` pub(crate) static DEBUG: AtomicBool = AtomicBool::new(false); diff --git a/mistralrs-core/src/model_loader.rs b/mistralrs-core/src/model_loader.rs index 852c79d405..1bd5994b47 100644 --- a/mistralrs-core/src/model_loader.rs +++ b/mistralrs-core/src/model_loader.rs @@ -1,12 +1,13 @@ use std::fs::{self, File}; use crate::{ + get_toml_selected_model_dtype, pipeline::{ GGMLLoaderBuilder, 
GGMLSpecificConfig, GGUFLoaderBuilder, GGUFSpecificConfig, NormalSpecificConfig, }, - Loader, ModelSelected, NormalLoaderBuilder, TomlLoaderArgs, TomlSelector, VisionLoaderBuilder, - VisionSpecificConfig, + Loader, ModelDType, ModelSelected, NormalLoaderBuilder, TomlLoaderArgs, TomlSelector, + VisionLoaderBuilder, VisionSpecificConfig, }; /// A builder for a loader using the selected model. @@ -70,6 +71,28 @@ pub fn get_tgt_non_granular_index(model: &ModelSelected) -> Option { } } +pub fn get_model_dtype(model: &ModelSelected) -> anyhow::Result { + match model { + ModelSelected::Plain { dtype, .. } + | ModelSelected::Lora { dtype, .. } + | ModelSelected::XLora { dtype, .. } + | ModelSelected::VisionPlain { dtype, .. } => Ok(*dtype), + ModelSelected::GGUF { .. } + | ModelSelected::LoraGGUF { .. } + | ModelSelected::GGML { .. } + | ModelSelected::LoraGGML { .. } + | ModelSelected::XLoraGGUF { .. } + | ModelSelected::XLoraGGML { .. } => Ok(ModelDType::Auto), + ModelSelected::Toml { file } => { + let selector: TomlSelector = toml::from_str( + &fs::read_to_string(file.clone()) + .unwrap_or_else(|_| panic!("Could not load toml selector file at {file}")), + )?; + Ok(get_toml_selected_model_dtype(&selector)) + } + } +} + fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result> { let use_flash_attn = args.use_flash_attn; let loader: Box = match args.model { @@ -90,6 +113,7 @@ fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result NormalLoaderBuilder::new( NormalSpecificConfig { use_flash_attn, @@ -108,6 +132,7 @@ fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result NormalLoaderBuilder::new( NormalSpecificConfig { use_flash_attn, @@ -134,6 +159,7 @@ fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result NormalLoaderBuilder::new( NormalSpecificConfig { use_flash_attn, @@ -285,6 +311,7 @@ fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result VisionLoaderBuilder::new( VisionSpecificConfig { use_flash_attn, 
diff --git a/mistralrs-core/src/model_selected.rs b/mistralrs-core/src/model_selected.rs index 7c82254284..1a6e06bbff 100644 --- a/mistralrs-core/src/model_selected.rs +++ b/mistralrs-core/src/model_selected.rs @@ -1,6 +1,9 @@ use clap::Subcommand; -use crate::pipeline::{NormalLoaderType, VisionLoaderType}; +use crate::{ + pipeline::{NormalLoaderType, VisionLoaderType}, + ModelDType, +}; fn parse_arch(x: &str) -> Result { x.parse() @@ -10,6 +13,10 @@ fn parse_vision_arch(x: &str) -> Result { x.parse() } +fn parse_model_dtype(x: &str) -> Result { + x.parse() +} + #[derive(Debug, Subcommand)] pub enum ModelSelected { /// Select the model from a toml file @@ -36,6 +43,10 @@ pub enum ModelSelected { /// The architecture of the model. #[arg(short, long, value_parser = parse_arch)] arch: NormalLoaderType, + + /// Model data type. Defaults to `auto`. + #[arg(short, long, default_value_t = ModelDType::Auto, value_parser = parse_model_dtype)] + dtype: ModelDType, }, /// Select an X-LoRA architecture @@ -68,6 +79,10 @@ pub enum ModelSelected { /// The architecture of the model. #[arg(short, long, value_parser = parse_arch)] arch: NormalLoaderType, + + /// Model data type. Defaults to `auto`. + #[arg(short, long, default_value_t = ModelDType::Auto, value_parser = parse_model_dtype)] + dtype: ModelDType, }, /// Select a LoRA architecture @@ -95,6 +110,10 @@ pub enum ModelSelected { /// The architecture of the model. #[arg(long, value_parser = parse_arch)] arch: NormalLoaderType, + + /// Model data type. Defaults to `auto`. + #[arg(short, long, default_value_t = ModelDType::Auto, value_parser = parse_model_dtype)] + dtype: ModelDType, }, /// Select a GGUF model. @@ -306,5 +325,9 @@ pub enum ModelSelected { /// The architecture of the model. #[arg(short, long, value_parser = parse_vision_arch)] arch: VisionLoaderType, + + /// Model data type. Defaults to `auto`. 
+ #[arg(short, long, default_value_t = ModelDType::Auto, value_parser = parse_model_dtype)] + dtype: ModelDType, }, } diff --git a/mistralrs-core/src/pipeline/ggml.rs b/mistralrs-core/src/pipeline/ggml.rs index a7ba185062..d5a711c6ad 100644 --- a/mistralrs-core/src/pipeline/ggml.rs +++ b/mistralrs-core/src/pipeline/ggml.rs @@ -20,14 +20,16 @@ use crate::utils::debug::setup_logger_and_debug; use crate::utils::model_config as ModelConfig; use crate::utils::tokenizer::get_tokenizer; use crate::xlora_models::NonGranularState; -use crate::{do_sample, get_mut_arcmutex, get_paths, DeviceMapMetadata, Pipeline, DEBUG}; +use crate::{ + do_sample, get_mut_arcmutex, get_paths, DeviceMapMetadata, Pipeline, TryIntoDType, DEBUG, +}; use crate::{ models::quantized_llama::ModelWeights as QLlama, utils::tokens::get_token, xlora_models::XLoraQLlama, }; use anyhow::Result; use candle_core::quantized::{ggml_file, GgmlDType}; -use candle_core::{DType, Device, Tensor}; +use candle_core::{Device, Tensor}; use hf_hub::{api::sync::ApiBuilder, Repo, RepoType}; use rand_isaac::Isaac64Rng; use std::any::Any; @@ -232,7 +234,7 @@ impl Loader for GGMLLoader { fn load_model_from_path( &self, paths: &Box, - _dtype: Option, + _: &dyn TryIntoDType, device: &Device, silent: bool, mapper: DeviceMapMetadata, @@ -363,7 +365,7 @@ impl Loader for GGMLLoader { &self, revision: Option, token_source: TokenSource, - _dtype: Option, + dtype: &dyn TryIntoDType, device: &Device, silent: bool, mapper: DeviceMapMetadata, @@ -378,7 +380,7 @@ impl Loader for GGMLLoader { Some(vec![self.quantized_filename.as_ref().unwrap().clone()]), silent ); - self.load_model_from_path(&paths?, _dtype, device, silent, mapper, in_situ_quant) + self.load_model_from_path(&paths?, dtype, device, silent, mapper, in_situ_quant) } fn get_id(&self) -> String { diff --git a/mistralrs-core/src/pipeline/gguf.rs b/mistralrs-core/src/pipeline/gguf.rs index 15d71c90de..9fecb53b90 100644 --- a/mistralrs-core/src/pipeline/gguf.rs +++ 
b/mistralrs-core/src/pipeline/gguf.rs @@ -25,7 +25,7 @@ use crate::utils::tokenizer::get_tokenizer; use crate::xlora_models::NonGranularState; use crate::{ do_sample, get_mut_arcmutex, get_paths_gguf, DeviceMapMetadata, LocalModelPaths, Pipeline, - DEBUG, + TryIntoDType, DEBUG, }; use crate::{ models::quantized_llama::ModelWeights as QLlama, @@ -39,7 +39,7 @@ use candle_core::quantized::{ gguf_file::{self, Value as GgufValue}, GgmlDType, }; -use candle_core::{DType, Device, Tensor}; +use candle_core::{Device, Tensor}; use either::Either; use hf_hub::{api::sync::ApiBuilder, Repo, RepoType}; use rand_isaac::Isaac64Rng; @@ -294,7 +294,7 @@ impl Loader for GGUFLoader { &self, revision: Option, token_source: TokenSource, - _dtype: Option, + dtype: &dyn TryIntoDType, device: &Device, silent: bool, mapper: DeviceMapMetadata, @@ -309,14 +309,14 @@ impl Loader for GGUFLoader { self.quantized_filename.clone(), silent ); - self.load_model_from_path(&paths?, _dtype, device, silent, mapper, in_situ_quant) + self.load_model_from_path(&paths?, dtype, device, silent, mapper, in_situ_quant) } #[allow(clippy::type_complexity, clippy::too_many_arguments)] fn load_model_from_path( &self, paths: &Box, - _dtype: Option, + _: &dyn TryIntoDType, device: &Device, silent: bool, mapper: DeviceMapMetadata, diff --git a/mistralrs-core/src/pipeline/macros.rs b/mistralrs-core/src/pipeline/macros.rs index c3b93bfea2..d02f1fa47d 100644 --- a/mistralrs-core/src/pipeline/macros.rs +++ b/mistralrs-core/src/pipeline/macros.rs @@ -330,11 +330,11 @@ macro_rules! get_paths_gguf { #[doc(hidden)] #[macro_export] macro_rules! 
normal_model_loader { - ($paths:expr, $dtype:expr, $default_dtype:expr, $device:expr, $config:expr, $loader:expr, $use_flash_attn:expr, $silent:expr, $mapper:expr, $loading_isq:expr, $real_device:expr) => {{ + ($paths:expr, $dtype:expr, $device:expr, $config:expr, $loader:expr, $use_flash_attn:expr, $silent:expr, $mapper:expr, $loading_isq:expr, $real_device:expr) => {{ let vb = from_mmaped_safetensors( $paths.get_weight_filenames().to_vec(), Vec::new(), - $dtype.unwrap_or($default_dtype), + $dtype, $device, $silent, )?; @@ -355,11 +355,11 @@ macro_rules! normal_model_loader { #[doc(hidden)] #[macro_export] macro_rules! vision_normal_model_loader { - ($paths:expr, $dtype:expr, $default_dtype:expr, $device:expr, $config:expr, $loader:expr, $use_flash_attn:expr, $silent:expr, $mapper:expr, $loading_isq:expr, $real_device:expr) => {{ + ($paths:expr, $dtype:expr, $device:expr, $config:expr, $loader:expr, $use_flash_attn:expr, $silent:expr, $mapper:expr, $loading_isq:expr, $real_device:expr) => {{ let vb = from_mmaped_safetensors( $paths.get_weight_filenames().to_vec(), Vec::new(), - $dtype.unwrap_or($default_dtype), + $dtype, $device, $silent, )?; @@ -380,7 +380,7 @@ macro_rules! vision_normal_model_loader { #[doc(hidden)] #[macro_export] macro_rules! xlora_model_loader { - ($paths:expr, $dtype:expr, $default_dtype:expr, $device:expr, $config:expr, $loader:expr, $use_flash_attn:expr, $silent:expr, $mapper:expr, $loading_isq:expr, $real_device:expr) => {{ + ($paths:expr, $dtype:expr, $device:expr, $config:expr, $loader:expr, $use_flash_attn:expr, $silent:expr, $mapper:expr, $loading_isq:expr, $real_device:expr) => {{ let mut safetensors_paths = $paths.get_weight_filenames().iter().collect::>(); safetensors_paths.push($paths.get_classifier_path().as_ref().unwrap()); let vb = from_mmaped_safetensors( @@ -395,7 +395,7 @@ macro_rules! 
xlora_model_loader { .iter() .map(|(_, x)| (*x).to_owned()) .collect::>(), - $dtype.unwrap_or($default_dtype), + $dtype, $device, $silent, )?; @@ -412,12 +412,7 @@ macro_rules! xlora_model_loader { loading_isq: $loading_isq, real_device: $real_device, }, - &$crate::utils::varbuilder_utils::load_preload_adapters( - $paths.get_lora_preload_adapter_info(), - $dtype.unwrap_or($default_dtype), - $device, - $silent, - )?, + &None, )? }}; } @@ -425,7 +420,7 @@ macro_rules! xlora_model_loader { #[doc(hidden)] #[macro_export] macro_rules! lora_model_loader { - ($paths:expr, $dtype:expr, $default_dtype:expr, $device:expr, $config:expr, $loader:expr, $use_flash_attn:expr, $silent:expr, $mapper:expr, $loading_isq:expr, $real_device:expr) => {{ + ($paths:expr, $dtype:expr, $device:expr, $config:expr, $loader:expr, $use_flash_attn:expr, $silent:expr, $mapper:expr, $loading_isq:expr, $real_device:expr) => {{ let safetensors_paths = $paths.get_weight_filenames().iter().collect::>(); let vb = from_mmaped_safetensors( safetensors_paths @@ -439,7 +434,7 @@ macro_rules! lora_model_loader { .iter() .map(|(_, x)| (*x).to_owned()) .collect::>(), - $dtype.unwrap_or($default_dtype), + $dtype, $device, $silent, )?; @@ -458,7 +453,7 @@ macro_rules! 
lora_model_loader { }, &$crate::utils::varbuilder_utils::load_preload_adapters( $paths.get_lora_preload_adapter_info(), - $dtype.unwrap_or($default_dtype), + $dtype, $device, $silent, )?, diff --git a/mistralrs-core/src/pipeline/mod.rs b/mistralrs-core/src/pipeline/mod.rs index a5477a9612..b57929495d 100644 --- a/mistralrs-core/src/pipeline/mod.rs +++ b/mistralrs-core/src/pipeline/mod.rs @@ -17,7 +17,7 @@ use crate::aici::toktree::TokTrie; use crate::prefix_cacher::PrefixCacheManager; mod sampling_pipeline; use crate::lora::{LoraConfig, Ordering}; -use crate::DeviceMapMetadata; +use crate::{DeviceMapMetadata, TryIntoDType}; use candle_core::quantized::GgmlDType; use chat_template::ChatTemplate; use core::fmt; @@ -43,7 +43,7 @@ pub use vision::{VisionLoader, VisionLoaderBuilder, VisionSpecificConfig}; pub use vision_loaders::{Phi3VLoader, VisionLoaderType, VisionModelLoader}; use anyhow::Result; -use candle_core::{DType, Device, Tensor}; +use candle_core::{Device, Tensor}; use crate::{ sequence::Sequence, @@ -360,14 +360,14 @@ impl ModelKind { /// /// # Example /// ```no_run -/// use mistralrs_core::{Loader, TokenSource, DeviceMapMetadata}; +/// use mistralrs_core::{Loader, TokenSource, DeviceMapMetadata, ModelDType}; /// use candle_core::Device; /// /// let loader: Box = todo!(); /// let pipeline = loader.load_model_from_hf( /// None, /// TokenSource::CacheToken, -/// None, +/// &ModelDType::Auto, /// &Device::cuda_if_available(0).unwrap(), /// false, /// DeviceMapMetadata::dummy(), @@ -383,7 +383,7 @@ pub trait Loader { &self, revision: Option, token_source: TokenSource, - dtype: Option, + dtype: &dyn TryIntoDType, device: &Device, silent: bool, mapper: DeviceMapMetadata, @@ -400,7 +400,7 @@ pub trait Loader { fn load_model_from_path( &self, paths: &Box, - _dtype: Option, + dtype: &dyn TryIntoDType, device: &Device, silent: bool, mapper: DeviceMapMetadata, diff --git a/mistralrs-core/src/pipeline/normal.rs b/mistralrs-core/src/pipeline/normal.rs index 
6a71da6b43..866442076c 100644 --- a/mistralrs-core/src/pipeline/normal.rs +++ b/mistralrs-core/src/pipeline/normal.rs @@ -26,11 +26,11 @@ use crate::utils::{tokens::get_token, varbuilder_utils::from_mmaped_safetensors} use crate::xlora_models::NonGranularState; use crate::{ do_sample, get_mut_arcmutex, get_paths, lora_model_loader, normal_model_loader, - xlora_model_loader, DeviceMapMetadata, Pipeline, + xlora_model_loader, DeviceMapMetadata, Pipeline, TryIntoDType, }; use anyhow::Result; use candle_core::quantized::GgmlDType; -use candle_core::{DType, Device, Tensor}; +use candle_core::{Device, Tensor}; use hf_hub::{api::sync::ApiBuilder, Repo, RepoType}; use rand_isaac::Isaac64Rng; use std::any::Any; @@ -186,7 +186,7 @@ impl Loader for NormalLoader { &self, revision: Option, token_source: TokenSource, - _dtype: Option, + dtype: &dyn TryIntoDType, device: &Device, silent: bool, mapper: DeviceMapMetadata, @@ -201,27 +201,21 @@ impl Loader for NormalLoader { None, silent ); - self.load_model_from_path(&paths?, _dtype, device, silent, mapper, in_situ_quant) + self.load_model_from_path(&paths?, dtype, device, silent, mapper, in_situ_quant) } #[allow(clippy::type_complexity, clippy::too_many_arguments)] fn load_model_from_path( &self, paths: &Box, - dtype: Option, + dtype: &dyn TryIntoDType, device: &Device, silent: bool, mapper: DeviceMapMetadata, in_situ_quant: Option, ) -> Result>> { let config = std::fs::read_to_string(paths.get_config_filename())?; - let default_dtype = if device.is_cuda() && mapper.is_dummy() { - DType::BF16 - } else if !mapper.is_dummy() { - DType::F16 - } else { - DType::F32 - }; + let dtype = dtype.try_into_dtype(device)?; // Otherwise, the device mapper will print it if mapper.is_dummy() { info!("Loading model `{}` on {device:?}...", self.get_id()); @@ -245,7 +239,6 @@ impl Loader for NormalLoader { ModelKind::Normal => normal_model_loader!( paths, dtype, - default_dtype, &load_device, config, self.inner, @@ -260,7 +253,6 @@ impl Loader for 
NormalLoader { } => xlora_model_loader!( paths, dtype, - default_dtype, &load_device, config, self.inner, @@ -275,7 +267,6 @@ impl Loader for NormalLoader { } => lora_model_loader!( paths, dtype, - default_dtype, &load_device, config, self.inner, diff --git a/mistralrs-core/src/pipeline/speculative.rs b/mistralrs-core/src/pipeline/speculative.rs index 18fbaf0a01..37ef210479 100644 --- a/mistralrs-core/src/pipeline/speculative.rs +++ b/mistralrs-core/src/pipeline/speculative.rs @@ -5,7 +5,7 @@ use std::{ }; use anyhow::Result as anyhowResult; -use candle_core::{quantized::GgmlDType, DType, Device, IndexOp, Result, Tensor}; +use candle_core::{quantized::GgmlDType, Device, IndexOp, Result, Tensor}; use rand_isaac::Isaac64Rng; use tokenizers::Tokenizer; @@ -17,7 +17,7 @@ use crate::{ }, prefix_cacher::PrefixCacheManager, sequence::{Sequence, SequenceRecognizer}, - DeviceMapMetadata, Loader, ModelKind, Pipeline, TokenSource, + DeviceMapMetadata, Loader, ModelKind, Pipeline, TokenSource, TryIntoDType, }; use super::{ @@ -39,7 +39,7 @@ impl Loader for SpeculativeLoader { &self, revision: Option, token_source: TokenSource, - dtype: Option, + dtype: &dyn TryIntoDType, device: &Device, silent: bool, mapper: DeviceMapMetadata, @@ -74,7 +74,7 @@ impl Loader for SpeculativeLoader { fn load_model_from_path( &self, paths: &Box, - dtype: Option, + dtype: &dyn TryIntoDType, device: &Device, silent: bool, mapper: DeviceMapMetadata, diff --git a/mistralrs-core/src/pipeline/vision.rs b/mistralrs-core/src/pipeline/vision.rs index e86c188a6e..e1ea936318 100644 --- a/mistralrs-core/src/pipeline/vision.rs +++ b/mistralrs-core/src/pipeline/vision.rs @@ -20,10 +20,11 @@ use crate::vision_models::processor_config::ProcessorConfig; use crate::vision_models::ModelInputs; use crate::{ do_sample, get_paths, vision_normal_model_loader, DeviceMapMetadata, Ordering, Pipeline, + TryIntoDType, }; use anyhow::Result; use candle_core::quantized::GgmlDType; -use candle_core::{DType, Device, Tensor}; 
+use candle_core::{Device, Tensor}; use hf_hub::{api::sync::ApiBuilder, Repo, RepoType}; use rand_isaac::Isaac64Rng; use std::any::Any; @@ -116,7 +117,7 @@ impl Loader for VisionLoader { &self, revision: Option, token_source: TokenSource, - _dtype: Option, + dtype: &dyn TryIntoDType, device: &Device, silent: bool, mapper: DeviceMapMetadata, @@ -131,27 +132,22 @@ impl Loader for VisionLoader { None, silent ); - self.load_model_from_path(&paths?, _dtype, device, silent, mapper, in_situ_quant) + self.load_model_from_path(&paths?, dtype, device, silent, mapper, in_situ_quant) } #[allow(clippy::type_complexity, clippy::too_many_arguments)] fn load_model_from_path( &self, paths: &Box, - dtype: Option, + dtype: &dyn TryIntoDType, device: &Device, silent: bool, mapper: DeviceMapMetadata, in_situ_quant: Option, ) -> Result>> { let config = std::fs::read_to_string(paths.get_config_filename())?; - let default_dtype = if device.is_cuda() && mapper.is_dummy() { - DType::BF16 - } else if !mapper.is_dummy() { - DType::F16 - } else { - DType::F32 - }; + let dtype = dtype.try_into_dtype(device)?; + // Otherwise, the device mapper will print it if mapper.is_dummy() { info!("Loading model `{}` on {device:?}...", self.get_id()); @@ -173,7 +169,6 @@ impl Loader for VisionLoader { ModelKind::Normal => vision_normal_model_loader!( paths, dtype, - default_dtype, &load_device, config, self.inner, diff --git a/mistralrs-core/src/toml_selector.rs b/mistralrs-core/src/toml_selector.rs index a06ba83190..70c6d1728e 100644 --- a/mistralrs-core/src/toml_selector.rs +++ b/mistralrs-core/src/toml_selector.rs @@ -4,7 +4,7 @@ use serde::Deserialize; use crate::{ GGMLLoaderBuilder, GGMLSpecificConfig, GGUFLoaderBuilder, GGUFSpecificConfig, Loader, - NormalLoaderBuilder, NormalLoaderType, NormalSpecificConfig, SpeculativeConfig, + ModelDType, NormalLoaderBuilder, NormalLoaderType, NormalSpecificConfig, SpeculativeConfig, SpeculativeLoader, VisionLoaderBuilder, VisionLoaderType, VisionSpecificConfig, }; 
@@ -18,7 +18,7 @@ fn default_one() -> usize { #[derive(Debug, Deserialize)] #[serde(untagged)] -enum TomlModelSelected { +pub enum TomlModelSelected { /// Select a plain model, without quantization or adapters Plain { /// Model ID to load from. This may be a HF hub repo or a local path. @@ -26,6 +26,9 @@ enum TomlModelSelected { /// The architecture of the model. arch: NormalLoaderType, + + /// Model data type. Defaults to `auto`. + dtype: ModelDType, }, /// Select an X-LoRA architecture @@ -45,6 +48,9 @@ enum TomlModelSelected { /// The architecture of the model. arch: NormalLoaderType, + + /// Model data type. Defaults to `auto`. + dtype: ModelDType, }, /// Select a LoRA architecture @@ -60,6 +66,9 @@ enum TomlModelSelected { /// The architecture of the model. arch: NormalLoaderType, + + /// Model data type. Defaults to `auto`. + dtype: ModelDType, }, /// Select a GGUF model. @@ -199,6 +208,9 @@ enum TomlModelSelected { /// The architecture of the model. arch: VisionLoaderType, + + /// Model data type. Defaults to `auto`. + dtype: ModelDType, }, } @@ -242,13 +254,32 @@ pub struct TomlLoaderArgs { pub no_kv_cache: bool, } +pub fn get_toml_selected_model_dtype(model: &TomlSelector) -> ModelDType { + match model.model { + TomlModelSelected::Plain { dtype, .. } + | TomlModelSelected::Lora { dtype, .. } + | TomlModelSelected::XLora { dtype, .. } + | TomlModelSelected::VisionPlain { dtype, .. } => dtype, + TomlModelSelected::GGUF { .. } + | TomlModelSelected::LoraGGUF { .. } + | TomlModelSelected::GGML { .. } + | TomlModelSelected::LoraGGML { .. } + | TomlModelSelected::XLoraGGUF { .. } + | TomlModelSelected::XLoraGGML { .. 
} => ModelDType::Auto, + } +} + fn loader_from_selected( args: TomlLoaderInnerParams, model: TomlModelSelected, ) -> anyhow::Result> { let use_flash_attn = args.use_flash_attn; let loader: Box = match model { - TomlModelSelected::Plain { model_id, arch } => NormalLoaderBuilder::new( + TomlModelSelected::Plain { + model_id, + arch, + dtype: _, + } => NormalLoaderBuilder::new( NormalSpecificConfig { use_flash_attn, repeat_last_n: args.repeat_last_n, @@ -264,6 +295,7 @@ fn loader_from_selected( order, tgt_non_granular_index, arch, + dtype: _, } => NormalLoaderBuilder::new( NormalSpecificConfig { use_flash_attn, @@ -288,6 +320,7 @@ fn loader_from_selected( adapters_model_id, order, arch, + dtype: _, } => NormalLoaderBuilder::new( NormalSpecificConfig { use_flash_attn, @@ -440,7 +473,11 @@ fn loader_from_selected( )?, ) .build(), - TomlModelSelected::VisionPlain { model_id, arch } => VisionLoaderBuilder::new( + TomlModelSelected::VisionPlain { + model_id, + arch, + dtype: _, + } => VisionLoaderBuilder::new( VisionSpecificConfig { use_flash_attn, repeat_last_n: args.repeat_last_n, diff --git a/mistralrs-core/src/utils/mod.rs b/mistralrs-core/src/utils/mod.rs index 314007140a..979ca43320 100644 --- a/mistralrs-core/src/utils/mod.rs +++ b/mistralrs-core/src/utils/mod.rs @@ -1,6 +1,7 @@ pub(crate) mod debug; pub(crate) mod gguf_metadata; pub(crate) mod model_config; +pub(crate) mod normal; pub(crate) mod progress; pub(crate) mod tokenizer; pub(crate) mod tokens; diff --git a/mistralrs-core/src/utils/normal.rs b/mistralrs-core/src/utils/normal.rs new file mode 100644 index 0000000000..084ff70266 --- /dev/null +++ b/mistralrs-core/src/utils/normal.rs @@ -0,0 +1,84 @@ +use std::{fmt::Display, str::FromStr}; + +use anyhow::Result; +use candle_core::{DType, Device}; +use serde::Deserialize; +use tracing::info; + +#[derive(Clone, Copy, Default, Debug, Deserialize)] +/// DType for the model. 
+/// +/// If the model is quantized, this is ignored so it is reasonable to use the [`Default`] impl. +/// +/// ## `Auto` rules +/// - If CUDA device or CPU, use BF16 +/// - Fallback to F32 +pub enum ModelDType { + #[default] + #[serde(rename = "auto")] + Auto, + #[serde(rename = "bf16")] + BF16, + #[serde(rename = "f16")] + F16, + #[serde(rename = "f32")] + F32, +} + +impl Display for ModelDType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Auto => write!(f, "auto"), + Self::BF16 => write!(f, "bf16"), + Self::F16 => write!(f, "f16"), + Self::F32 => write!(f, "f32"), + } + } +} + +impl FromStr for ModelDType { + type Err = String; + fn from_str(s: &str) -> std::result::Result { + match s.to_lowercase().as_str() { + "auto" => Ok(Self::Auto), + "bf16" => Ok(Self::BF16), + "f16" => Ok(Self::F16), + "f32" => Ok(Self::F32), + other => Err(format!("Model DType `{other}` is not supported.")), + } + } +} + +/// Type which can be converted to a DType +pub trait TryIntoDType { + fn try_into_dtype(&self, device: &Device) -> Result; +} + +impl TryIntoDType for DType { + fn try_into_dtype(&self, _: &Device) -> Result { + info!("DType selected is {self:?}."); + if !matches!(self, DType::BF16 | DType::F32 | DType::F64 | DType::F16) { + anyhow::bail!("DType must be one of BF16, F16, F32, F64"); + } + Ok(*self) + } +} + +impl TryIntoDType for ModelDType { + fn try_into_dtype(&self, device: &Device) -> Result { + let dtype = match self { + Self::Auto => { + if device.is_cuda() || device.is_cpu() { + Ok(DType::BF16) + } else { + Ok(DType::F32) + } + } + Self::BF16 => Ok(DType::BF16), + Self::F16 => Ok(DType::F16), + Self::F32 => Ok(DType::F32), + }; + info!("DType selected is {:?}.", dtype.as_ref().unwrap()); + dtype + } +} diff --git a/mistralrs-pyo3/src/lib.rs index 6956da8ea7..4cb99e9892 100644 --- a/mistralrs-pyo3/src/lib.rs +++ b/mistralrs-pyo3/src/lib.rs @@ -19,9 +19,9 @@ use candle_core::Device; use 
mistralrs_core::{ ChatCompletionResponse, CompletionResponse, Constraint, DeviceMapMetadata, GGMLLoaderBuilder, GGMLSpecificConfig, GGUFLoaderBuilder, GGUFSpecificConfig, Loader, MistralRs, MistralRsBuilder, - NormalLoaderBuilder, NormalRequest, NormalSpecificConfig, Request as _Request, RequestMessage, - Response, SamplingParams, SchedulerMethod, SpeculativeConfig, SpeculativeLoader, StopTokens, - TokenSource, VisionLoaderBuilder, VisionSpecificConfig, + ModelDType, NormalLoaderBuilder, NormalRequest, NormalSpecificConfig, Request as _Request, + RequestMessage, Response, SamplingParams, SchedulerMethod, SpeculativeConfig, + SpeculativeLoader, StopTokens, TokenSource, VisionLoaderBuilder, VisionSpecificConfig, }; use pyo3::{ exceptions::{PyTypeError, PyValueError}, @@ -412,7 +412,7 @@ impl Runner { None, TokenSource::from_str(token_source) .map_err(|e| PyValueError::new_err(e.to_string()))?, - None, + &ModelDType::Auto, &device, true, // Silent for jupyter num_device_layers diff --git a/mistralrs-pyo3/src/which.rs b/mistralrs-pyo3/src/which.rs index caac702766..f112f958e6 100644 --- a/mistralrs-pyo3/src/which.rs +++ b/mistralrs-pyo3/src/which.rs @@ -13,12 +13,6 @@ pub enum Architecture { Qwen2, } -#[pyclass] -#[derive(Debug, Clone)] -pub enum VisionArchitecture { - Phi3V, -} - impl From for NormalLoaderType { fn from(value: Architecture) -> Self { match value { @@ -33,6 +27,12 @@ impl From for NormalLoaderType { } } +#[pyclass] +#[derive(Debug, Clone)] +pub enum VisionArchitecture { + Phi3V, +} + impl From for VisionLoaderType { fn from(value: VisionArchitecture) -> Self { match value { diff --git a/mistralrs-server/src/main.rs b/mistralrs-server/src/main.rs index 61c65be14a..73a7277a91 100644 --- a/mistralrs-server/src/main.rs +++ b/mistralrs-server/src/main.rs @@ -8,8 +8,8 @@ use axum::{ use candle_core::{quantized::GgmlDType, Device}; use clap::Parser; use mistralrs_core::{ - get_tgt_non_granular_index, DeviceMapMetadata, Loader, LoaderBuilder, MistralRs, - 
MistralRsBuilder, ModelSelected, Request, SchedulerMethod, TokenSource, + get_model_dtype, get_tgt_non_granular_index, DeviceMapMetadata, Loader, LoaderBuilder, + MistralRs, MistralRsBuilder, ModelSelected, Request, SchedulerMethod, TokenSource, }; use openai::{ChatCompletionRequest, Message, ModelObjects, StopTokens}; use serde::{Deserialize, Serialize}; @@ -236,6 +236,7 @@ async fn main() -> Result<()> { let use_flash_attn = true; let tgt_non_granular_index = get_tgt_non_granular_index(&args.model); + let dtype = get_model_dtype(&args.model)?; if tgt_non_granular_index.is_some() { args.max_seqs = 1; @@ -270,7 +271,7 @@ async fn main() -> Result<()> { let pipeline = loader.load_model_from_hf( None, args.token_source, - None, + &dtype, &device, false, args.num_device_layers diff --git a/mistralrs/examples/gguf_locally/main.rs b/mistralrs/examples/gguf_locally/main.rs index b04fc9fa53..e95b971b85 100644 --- a/mistralrs/examples/gguf_locally/main.rs +++ b/mistralrs/examples/gguf_locally/main.rs @@ -3,7 +3,7 @@ use tokio::sync::mpsc::channel; use mistralrs::{ Constraint, Device, DeviceMapMetadata, GGUFLoaderBuilder, GGUFSpecificConfig, MistralRs, - MistralRsBuilder, NormalRequest, Request, RequestMessage, Response, SamplingParams, + MistralRsBuilder, ModelDType, NormalRequest, Request, RequestMessage, Response, SamplingParams, SchedulerMethod, TokenSource, }; @@ -24,7 +24,7 @@ fn setup() -> anyhow::Result> { let pipeline = loader.load_model_from_hf( None, TokenSource::CacheToken, - None, + &ModelDType::Auto, &Device::cuda_if_available(0)?, false, DeviceMapMetadata::dummy(), diff --git a/mistralrs/examples/grammar/main.rs b/mistralrs/examples/grammar/main.rs index 92582a450a..c09209c7cb 100644 --- a/mistralrs/examples/grammar/main.rs +++ b/mistralrs/examples/grammar/main.rs @@ -2,9 +2,9 @@ use std::sync::Arc; use tokio::sync::mpsc::channel; use mistralrs::{ - Constraint, Device, DeviceMapMetadata, MistralRs, MistralRsBuilder, NormalLoaderBuilder, - NormalLoaderType, 
NormalRequest, NormalSpecificConfig, Request, RequestMessage, Response, - SamplingParams, SchedulerMethod, TokenSource, + Constraint, Device, DeviceMapMetadata, MistralRs, MistralRsBuilder, ModelDType, + NormalLoaderBuilder, NormalLoaderType, NormalRequest, NormalSpecificConfig, Request, + RequestMessage, Response, SamplingParams, SchedulerMethod, TokenSource, }; fn setup() -> anyhow::Result> { @@ -23,7 +23,7 @@ fn setup() -> anyhow::Result> { let pipeline = loader.load_model_from_hf( None, TokenSource::CacheToken, - None, + &ModelDType::Auto, &Device::cuda_if_available(0)?, false, DeviceMapMetadata::dummy(), diff --git a/mistralrs/examples/isq/main.rs b/mistralrs/examples/isq/main.rs index 67f2f876ff..8e380ace15 100644 --- a/mistralrs/examples/isq/main.rs +++ b/mistralrs/examples/isq/main.rs @@ -2,7 +2,7 @@ use std::sync::Arc; use tokio::sync::mpsc::channel; use mistralrs::{ - Constraint, Device, DeviceMapMetadata, GgmlDType, MistralRs, MistralRsBuilder, + Constraint, Device, DeviceMapMetadata, GgmlDType, MistralRs, MistralRsBuilder, ModelDType, NormalLoaderBuilder, NormalLoaderType, NormalRequest, NormalSpecificConfig, Request, RequestMessage, Response, SamplingParams, SchedulerMethod, TokenSource, }; @@ -23,7 +23,7 @@ fn setup() -> anyhow::Result> { let pipeline = loader.load_model_from_hf( None, TokenSource::CacheToken, - None, + &ModelDType::Auto, &Device::cuda_if_available(0)?, false, DeviceMapMetadata::dummy(), diff --git a/mistralrs/examples/lora/main.rs b/mistralrs/examples/lora/main.rs index 5afc0e7674..7f858372d3 100644 --- a/mistralrs/examples/lora/main.rs +++ b/mistralrs/examples/lora/main.rs @@ -2,9 +2,9 @@ use std::{fs::File, sync::Arc}; use tokio::sync::mpsc::channel; use mistralrs::{ - Constraint, Device, DeviceMapMetadata, MistralRs, MistralRsBuilder, NormalLoaderBuilder, - NormalLoaderType, NormalRequest, NormalSpecificConfig, Request, RequestMessage, Response, - SamplingParams, SchedulerMethod, TokenSource, + Constraint, Device, 
DeviceMapMetadata, MistralRs, MistralRsBuilder, ModelDType, + NormalLoaderBuilder, NormalLoaderType, NormalRequest, NormalSpecificConfig, Request, + RequestMessage, Response, SamplingParams, SchedulerMethod, TokenSource, }; fn setup() -> anyhow::Result> { @@ -30,7 +30,7 @@ fn setup() -> anyhow::Result> { let pipeline = loader.load_model_from_hf( None, TokenSource::CacheToken, - None, + &ModelDType::Auto, &Device::cuda_if_available(0)?, false, DeviceMapMetadata::dummy(), diff --git a/mistralrs/examples/lora_activation/main.rs b/mistralrs/examples/lora_activation/main.rs index 0aa83e9a4f..7585db4709 100644 --- a/mistralrs/examples/lora_activation/main.rs +++ b/mistralrs/examples/lora_activation/main.rs @@ -2,9 +2,9 @@ use std::{fs::File, sync::Arc}; use tokio::sync::mpsc::channel; use mistralrs::{ - Constraint, Device, DeviceMapMetadata, MistralRs, MistralRsBuilder, NormalLoaderBuilder, - NormalLoaderType, NormalRequest, NormalSpecificConfig, Request, RequestMessage, Response, - SamplingParams, SchedulerMethod, TokenSource, + Constraint, Device, DeviceMapMetadata, MistralRs, MistralRsBuilder, ModelDType, + NormalLoaderBuilder, NormalLoaderType, NormalRequest, NormalSpecificConfig, Request, + RequestMessage, Response, SamplingParams, SchedulerMethod, TokenSource, }; fn setup() -> anyhow::Result> { @@ -30,7 +30,7 @@ fn setup() -> anyhow::Result> { let pipeline = loader.load_model_from_hf( None, TokenSource::CacheToken, - None, + &ModelDType::Auto, &Device::cuda_if_available(0)?, false, DeviceMapMetadata::dummy(), diff --git a/mistralrs/examples/phi3v/main.rs b/mistralrs/examples/phi3v/main.rs index b2a4cc3d26..469718cc68 100644 --- a/mistralrs/examples/phi3v/main.rs +++ b/mistralrs/examples/phi3v/main.rs @@ -5,9 +5,9 @@ use std::sync::Arc; use tokio::sync::mpsc::channel; use mistralrs::{ - Constraint, Device, DeviceMapMetadata, MistralRs, MistralRsBuilder, NormalRequest, Request, - RequestMessage, Response, SamplingParams, SchedulerMethod, TokenSource, 
VisionLoaderBuilder, - VisionLoaderType, VisionSpecificConfig, + Constraint, Device, DeviceMapMetadata, MistralRs, MistralRsBuilder, ModelDType, NormalRequest, + Request, RequestMessage, Response, SamplingParams, SchedulerMethod, TokenSource, + VisionLoaderBuilder, VisionLoaderType, VisionSpecificConfig, }; fn setup() -> anyhow::Result> { @@ -26,7 +26,7 @@ fn setup() -> anyhow::Result> { let pipeline = loader.load_model_from_hf( None, TokenSource::CacheToken, - None, + &ModelDType::Auto, &Device::cuda_if_available(0)?, false, DeviceMapMetadata::dummy(), diff --git a/mistralrs/examples/quantized/main.rs b/mistralrs/examples/quantized/main.rs index b6539edaf2..198976e9f7 100644 --- a/mistralrs/examples/quantized/main.rs +++ b/mistralrs/examples/quantized/main.rs @@ -3,7 +3,7 @@ use tokio::sync::mpsc::channel; use mistralrs::{ Constraint, Device, DeviceMapMetadata, GGUFLoaderBuilder, GGUFSpecificConfig, MistralRs, - MistralRsBuilder, NormalRequest, Request, RequestMessage, Response, SamplingParams, + MistralRsBuilder, ModelDType, NormalRequest, Request, RequestMessage, Response, SamplingParams, SchedulerMethod, TokenSource, }; @@ -22,7 +22,7 @@ fn setup() -> anyhow::Result> { let pipeline = loader.load_model_from_hf( None, TokenSource::CacheToken, - None, + &ModelDType::default(), &Device::cuda_if_available(0)?, false, DeviceMapMetadata::dummy(), diff --git a/mistralrs/examples/simple/main.rs b/mistralrs/examples/simple/main.rs index 310c23c4d9..7378a17333 100644 --- a/mistralrs/examples/simple/main.rs +++ b/mistralrs/examples/simple/main.rs @@ -2,9 +2,9 @@ use std::sync::Arc; use tokio::sync::mpsc::channel; use mistralrs::{ - Constraint, Device, DeviceMapMetadata, MistralRs, MistralRsBuilder, NormalLoaderBuilder, - NormalLoaderType, NormalRequest, NormalSpecificConfig, Request, RequestMessage, Response, - SamplingParams, SchedulerMethod, TokenSource, + Constraint, Device, DeviceMapMetadata, MistralRs, MistralRsBuilder, ModelDType, + NormalLoaderBuilder, 
NormalLoaderType, NormalRequest, NormalSpecificConfig, Request, + RequestMessage, Response, SamplingParams, SchedulerMethod, TokenSource, }; fn setup() -> anyhow::Result> { @@ -23,7 +23,7 @@ fn setup() -> anyhow::Result> { let pipeline = loader.load_model_from_hf( None, TokenSource::CacheToken, - None, + &ModelDType::Auto, &Device::cuda_if_available(0)?, false, DeviceMapMetadata::dummy(), diff --git a/mistralrs/examples/xlora/main.rs b/mistralrs/examples/xlora/main.rs index bc4c52a3a8..40db2768db 100644 --- a/mistralrs/examples/xlora/main.rs +++ b/mistralrs/examples/xlora/main.rs @@ -2,9 +2,9 @@ use std::{fs::File, sync::Arc}; use tokio::sync::mpsc::channel; use mistralrs::{ - Constraint, Device, DeviceMapMetadata, MistralRs, MistralRsBuilder, NormalLoaderBuilder, - NormalLoaderType, NormalRequest, NormalSpecificConfig, Request, RequestMessage, Response, - SamplingParams, SchedulerMethod, TokenSource, + Constraint, Device, DeviceMapMetadata, MistralRs, MistralRsBuilder, ModelDType, + NormalLoaderBuilder, NormalLoaderType, NormalRequest, NormalSpecificConfig, Request, + RequestMessage, Response, SamplingParams, SchedulerMethod, TokenSource, }; fn setup() -> anyhow::Result> { @@ -32,7 +32,7 @@ fn setup() -> anyhow::Result> { let pipeline = loader.load_model_from_hf( None, TokenSource::CacheToken, - None, + &ModelDType::Auto, &Device::cuda_if_available(0)?, false, DeviceMapMetadata::dummy(),