Skip to content

Commit

Permalink
Expose API to specify dtype during loading (#417)
Browse files Browse the repository at this point in the history
* Abstract dtype settings

* Abstract dtype settings

* Better name

* Expose in apis

* Logging

* Fix test

* Docs
  • Loading branch information
EricLBuehler authored Jun 10, 2024
1 parent ac1537d commit 492d767
Show file tree
Hide file tree
Showing 27 changed files with 279 additions and 115 deletions.
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,8 @@ Additionally, for models without quantization, the model architecture should be

### Architecture for plain models

> Note: for plain models, you can specify the data type to load and run in. This must be one of `f32`, `f16`, `bf16` or `auto` to choose based on the device. This is specified in the `--dtype`/`-d` parameter after the model architecture (`plain`).

- `mistral`
- `gemma`
- `mixtral`
Expand All @@ -351,6 +353,8 @@ Additionally, for models without quantization, the model architecture should be

### Architecture for vision models

> Note: for vision models, you can specify the data type to load and run in. This must be one of `f32`, `f16`, `bf16` or `auto` to choose based on the device. This is specified in the `--dtype`/`-d` parameter after the model architecture (`vision-plain`).

- `phi3v`

**Interactive mode:**
Expand Down Expand Up @@ -495,6 +499,8 @@ If you want to add a new model, please contact us via an issue and we can coordi
- Error: `recompile with -fPIE`:
- Some Linux distributions require compiling with `-fPIE`.
- Set the `CUDA_NVCC_FLAGS` environment variable to `-fPIE` during build: `CUDA_NVCC_FLAGS=-fPIE`
- Error `CUDA_ERROR_NOT_FOUND` or symbol not found when using a normal or vision model:
- For non-quantized models, you can specify the data type to load and run in. This must be one of `f32`, `f16`, `bf16` or `auto` to choose based on the device.

## Credits
This project would not be possible without the excellent work at [`candle`](https://github.com/huggingface/candle). Additionally, thank you to all contributors! Contributing can range from raising an issue or suggesting a feature to adding some new functionality.
4 changes: 2 additions & 2 deletions mistralrs-bench/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use candle_core::Device;
use clap::Parser;
use cli_table::{format::Justify, print_stdout, Cell, CellStruct, Style, Table};
use mistralrs_core::{
Constraint, DeviceMapMetadata, Loader, LoaderBuilder, MistralRs, MistralRsBuilder,
Constraint, DeviceMapMetadata, Loader, LoaderBuilder, MistralRs, MistralRsBuilder, ModelDType,
ModelSelected, NormalRequest, Request, RequestMessage, Response, SamplingParams,
SchedulerMethod, TokenSource, Usage,
};
Expand Down Expand Up @@ -313,7 +313,7 @@ fn main() -> anyhow::Result<()> {
let pipeline = loader.load_model_from_hf(
None,
token_source,
None,
&ModelDType::Auto,
&device,
false,
args.num_device_layers
Expand Down
4 changes: 3 additions & 1 deletion mistralrs-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@ mod device_map;
mod engine;
mod lora;
mod model_loader;
pub use model_loader::{get_tgt_non_granular_index, LoaderBuilder};
pub use model_loader::{get_model_dtype, get_tgt_non_granular_index, LoaderBuilder};
mod model_selected;
pub use model_selected::ModelSelected;
pub use toml_selector::get_toml_selected_model_dtype;

mod cublaslt;
mod gguf;
Expand Down Expand Up @@ -62,6 +63,7 @@ pub use scheduler::SchedulerMethod;
use serde::Serialize;
use tokio::runtime::Runtime;
use toml_selector::{TomlLoaderArgs, TomlSelector};
pub use utils::normal::{ModelDType, TryIntoDType};

/// `true` if `MISTRALRS_DEBUG=1`
pub(crate) static DEBUG: AtomicBool = AtomicBool::new(false);
Expand Down
31 changes: 29 additions & 2 deletions mistralrs-core/src/model_loader.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
use std::fs::{self, File};

use crate::{
get_toml_selected_model_dtype,
pipeline::{
GGMLLoaderBuilder, GGMLSpecificConfig, GGUFLoaderBuilder, GGUFSpecificConfig,
NormalSpecificConfig,
},
Loader, ModelSelected, NormalLoaderBuilder, TomlLoaderArgs, TomlSelector, VisionLoaderBuilder,
VisionSpecificConfig,
Loader, ModelDType, ModelSelected, NormalLoaderBuilder, TomlLoaderArgs, TomlSelector,
VisionLoaderBuilder, VisionSpecificConfig,
};

/// A builder for a loader using the selected model.
Expand Down Expand Up @@ -70,6 +71,28 @@ pub fn get_tgt_non_granular_index(model: &ModelSelected) -> Option<usize> {
}
}

/// Determine the dtype the selected model should be loaded in.
///
/// Plain, (X-)LoRA and vision-plain selections carry an explicit `dtype`
/// CLI argument which is returned directly. Quantized selections
/// (GGUF/GGML variants) always resolve to [`ModelDType::Auto`]. For a
/// TOML selector, the file is parsed and the dtype is read from it.
///
/// # Errors
/// Returns an error if the TOML selector file cannot be read or parsed.
pub fn get_model_dtype(model: &ModelSelected) -> anyhow::Result<ModelDType> {
    match model {
        ModelSelected::Plain { dtype, .. }
        | ModelSelected::Lora { dtype, .. }
        | ModelSelected::XLora { dtype, .. }
        | ModelSelected::VisionPlain { dtype, .. } => Ok(*dtype),
        ModelSelected::GGUF { .. }
        | ModelSelected::LoraGGUF { .. }
        | ModelSelected::GGML { .. }
        | ModelSelected::LoraGGML { .. }
        | ModelSelected::XLoraGGUF { .. }
        | ModelSelected::XLoraGGML { .. } => Ok(ModelDType::Auto),
        ModelSelected::Toml { file } => {
            // Propagate the read failure as an `Err` (this function already
            // returns `anyhow::Result`) instead of panicking, and pass the
            // path by reference — no need to clone it just to read the file.
            let contents = fs::read_to_string(file).map_err(|e| {
                anyhow::anyhow!("Could not load toml selector file at {file}: {e}")
            })?;
            let selector: TomlSelector = toml::from_str(&contents)?;
            Ok(get_toml_selected_model_dtype(&selector))
        }
    }
}

fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result<Box<dyn Loader>> {
let use_flash_attn = args.use_flash_attn;
let loader: Box<dyn Loader> = match args.model {
Expand All @@ -90,6 +113,7 @@ fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result<Box<dyn Loa
repeat_last_n,
tokenizer_json,
arch,
dtype: _,
} => NormalLoaderBuilder::new(
NormalSpecificConfig {
use_flash_attn,
Expand All @@ -108,6 +132,7 @@ fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result<Box<dyn Loa
tokenizer_json,
tgt_non_granular_index,
arch,
dtype: _,
} => NormalLoaderBuilder::new(
NormalSpecificConfig {
use_flash_attn,
Expand All @@ -134,6 +159,7 @@ fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result<Box<dyn Loa
repeat_last_n,
order,
arch,
dtype: _,
} => NormalLoaderBuilder::new(
NormalSpecificConfig {
use_flash_attn,
Expand Down Expand Up @@ -285,6 +311,7 @@ fn loader_from_model_selected(args: LoaderBuilder) -> anyhow::Result<Box<dyn Loa
repeat_last_n,
tokenizer_json,
arch,
dtype: _,
} => VisionLoaderBuilder::new(
VisionSpecificConfig {
use_flash_attn,
Expand Down
25 changes: 24 additions & 1 deletion mistralrs-core/src/model_selected.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use clap::Subcommand;

use crate::pipeline::{NormalLoaderType, VisionLoaderType};
use crate::{
pipeline::{NormalLoaderType, VisionLoaderType},
ModelDType,
};

fn parse_arch(x: &str) -> Result<NormalLoaderType, String> {
x.parse()
Expand All @@ -10,6 +13,10 @@ fn parse_vision_arch(x: &str) -> Result<VisionLoaderType, String> {
x.parse()
}

fn parse_model_dtype(x: &str) -> Result<ModelDType, String> {
x.parse()
}

#[derive(Debug, Subcommand)]
pub enum ModelSelected {
/// Select the model from a toml file
Expand All @@ -36,6 +43,10 @@ pub enum ModelSelected {
/// The architecture of the model.
#[arg(short, long, value_parser = parse_arch)]
arch: NormalLoaderType,

/// Model data type. Defaults to `auto`.
#[arg(short, long, default_value_t = ModelDType::Auto, value_parser = parse_model_dtype)]
dtype: ModelDType,
},

/// Select an X-LoRA architecture
Expand Down Expand Up @@ -68,6 +79,10 @@ pub enum ModelSelected {
/// The architecture of the model.
#[arg(short, long, value_parser = parse_arch)]
arch: NormalLoaderType,

/// Model data type. Defaults to `auto`.
#[arg(short, long, default_value_t = ModelDType::Auto, value_parser = parse_model_dtype)]
dtype: ModelDType,
},

/// Select a LoRA architecture
Expand Down Expand Up @@ -95,6 +110,10 @@ pub enum ModelSelected {
/// The architecture of the model.
#[arg(long, value_parser = parse_arch)]
arch: NormalLoaderType,

/// Model data type. Defaults to `auto`.
#[arg(short, long, default_value_t = ModelDType::Auto, value_parser = parse_model_dtype)]
dtype: ModelDType,
},

/// Select a GGUF model.
Expand Down Expand Up @@ -306,5 +325,9 @@ pub enum ModelSelected {
/// The architecture of the model.
#[arg(short, long, value_parser = parse_vision_arch)]
arch: VisionLoaderType,

/// Model data type. Defaults to `auto`.
#[arg(short, long, default_value_t = ModelDType::Auto, value_parser = parse_model_dtype)]
dtype: ModelDType,
},
}
12 changes: 7 additions & 5 deletions mistralrs-core/src/pipeline/ggml.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,16 @@ use crate::utils::debug::setup_logger_and_debug;
use crate::utils::model_config as ModelConfig;
use crate::utils::tokenizer::get_tokenizer;
use crate::xlora_models::NonGranularState;
use crate::{do_sample, get_mut_arcmutex, get_paths, DeviceMapMetadata, Pipeline, DEBUG};
use crate::{
do_sample, get_mut_arcmutex, get_paths, DeviceMapMetadata, Pipeline, TryIntoDType, DEBUG,
};
use crate::{
models::quantized_llama::ModelWeights as QLlama, utils::tokens::get_token,
xlora_models::XLoraQLlama,
};
use anyhow::Result;
use candle_core::quantized::{ggml_file, GgmlDType};
use candle_core::{DType, Device, Tensor};
use candle_core::{Device, Tensor};
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use rand_isaac::Isaac64Rng;
use std::any::Any;
Expand Down Expand Up @@ -232,7 +234,7 @@ impl Loader for GGMLLoader {
fn load_model_from_path(
&self,
paths: &Box<dyn ModelPaths>,
_dtype: Option<DType>,
_: &dyn TryIntoDType,
device: &Device,
silent: bool,
mapper: DeviceMapMetadata,
Expand Down Expand Up @@ -363,7 +365,7 @@ impl Loader for GGMLLoader {
&self,
revision: Option<String>,
token_source: TokenSource,
_dtype: Option<DType>,
dtype: &dyn TryIntoDType,
device: &Device,
silent: bool,
mapper: DeviceMapMetadata,
Expand All @@ -378,7 +380,7 @@ impl Loader for GGMLLoader {
Some(vec![self.quantized_filename.as_ref().unwrap().clone()]),
silent
);
self.load_model_from_path(&paths?, _dtype, device, silent, mapper, in_situ_quant)
self.load_model_from_path(&paths?, dtype, device, silent, mapper, in_situ_quant)
}

fn get_id(&self) -> String {
Expand Down
10 changes: 5 additions & 5 deletions mistralrs-core/src/pipeline/gguf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use crate::utils::tokenizer::get_tokenizer;
use crate::xlora_models::NonGranularState;
use crate::{
do_sample, get_mut_arcmutex, get_paths_gguf, DeviceMapMetadata, LocalModelPaths, Pipeline,
DEBUG,
TryIntoDType, DEBUG,
};
use crate::{
models::quantized_llama::ModelWeights as QLlama,
Expand All @@ -39,7 +39,7 @@ use candle_core::quantized::{
gguf_file::{self, Value as GgufValue},
GgmlDType,
};
use candle_core::{DType, Device, Tensor};
use candle_core::{Device, Tensor};
use either::Either;
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use rand_isaac::Isaac64Rng;
Expand Down Expand Up @@ -294,7 +294,7 @@ impl Loader for GGUFLoader {
&self,
revision: Option<String>,
token_source: TokenSource,
_dtype: Option<DType>,
dtype: &dyn TryIntoDType,
device: &Device,
silent: bool,
mapper: DeviceMapMetadata,
Expand All @@ -309,14 +309,14 @@ impl Loader for GGUFLoader {
self.quantized_filename.clone(),
silent
);
self.load_model_from_path(&paths?, _dtype, device, silent, mapper, in_situ_quant)
self.load_model_from_path(&paths?, dtype, device, silent, mapper, in_situ_quant)
}

#[allow(clippy::type_complexity, clippy::too_many_arguments)]
fn load_model_from_path(
&self,
paths: &Box<dyn ModelPaths>,
_dtype: Option<DType>,
_: &dyn TryIntoDType,
device: &Device,
silent: bool,
mapper: DeviceMapMetadata,
Expand Down
25 changes: 10 additions & 15 deletions mistralrs-core/src/pipeline/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -330,11 +330,11 @@ macro_rules! get_paths_gguf {
#[doc(hidden)]
#[macro_export]
macro_rules! normal_model_loader {
($paths:expr, $dtype:expr, $default_dtype:expr, $device:expr, $config:expr, $loader:expr, $use_flash_attn:expr, $silent:expr, $mapper:expr, $loading_isq:expr, $real_device:expr) => {{
($paths:expr, $dtype:expr, $device:expr, $config:expr, $loader:expr, $use_flash_attn:expr, $silent:expr, $mapper:expr, $loading_isq:expr, $real_device:expr) => {{
let vb = from_mmaped_safetensors(
$paths.get_weight_filenames().to_vec(),
Vec::new(),
$dtype.unwrap_or($default_dtype),
$dtype,
$device,
$silent,
)?;
Expand All @@ -355,11 +355,11 @@ macro_rules! normal_model_loader {
#[doc(hidden)]
#[macro_export]
macro_rules! vision_normal_model_loader {
($paths:expr, $dtype:expr, $default_dtype:expr, $device:expr, $config:expr, $loader:expr, $use_flash_attn:expr, $silent:expr, $mapper:expr, $loading_isq:expr, $real_device:expr) => {{
($paths:expr, $dtype:expr, $device:expr, $config:expr, $loader:expr, $use_flash_attn:expr, $silent:expr, $mapper:expr, $loading_isq:expr, $real_device:expr) => {{
let vb = from_mmaped_safetensors(
$paths.get_weight_filenames().to_vec(),
Vec::new(),
$dtype.unwrap_or($default_dtype),
$dtype,
$device,
$silent,
)?;
Expand All @@ -380,7 +380,7 @@ macro_rules! vision_normal_model_loader {
#[doc(hidden)]
#[macro_export]
macro_rules! xlora_model_loader {
($paths:expr, $dtype:expr, $default_dtype:expr, $device:expr, $config:expr, $loader:expr, $use_flash_attn:expr, $silent:expr, $mapper:expr, $loading_isq:expr, $real_device:expr) => {{
($paths:expr, $dtype:expr, $device:expr, $config:expr, $loader:expr, $use_flash_attn:expr, $silent:expr, $mapper:expr, $loading_isq:expr, $real_device:expr) => {{
let mut safetensors_paths = $paths.get_weight_filenames().iter().collect::<Vec<_>>();
safetensors_paths.push($paths.get_classifier_path().as_ref().unwrap());
let vb = from_mmaped_safetensors(
Expand All @@ -395,7 +395,7 @@ macro_rules! xlora_model_loader {
.iter()
.map(|(_, x)| (*x).to_owned())
.collect::<Vec<_>>(),
$dtype.unwrap_or($default_dtype),
$dtype,
$device,
$silent,
)?;
Expand All @@ -412,20 +412,15 @@ macro_rules! xlora_model_loader {
loading_isq: $loading_isq,
real_device: $real_device,
},
&$crate::utils::varbuilder_utils::load_preload_adapters(
$paths.get_lora_preload_adapter_info(),
$dtype.unwrap_or($default_dtype),
$device,
$silent,
)?,
&None,
)?
}};
}

#[doc(hidden)]
#[macro_export]
macro_rules! lora_model_loader {
($paths:expr, $dtype:expr, $default_dtype:expr, $device:expr, $config:expr, $loader:expr, $use_flash_attn:expr, $silent:expr, $mapper:expr, $loading_isq:expr, $real_device:expr) => {{
($paths:expr, $dtype:expr, $device:expr, $config:expr, $loader:expr, $use_flash_attn:expr, $silent:expr, $mapper:expr, $loading_isq:expr, $real_device:expr) => {{
let safetensors_paths = $paths.get_weight_filenames().iter().collect::<Vec<_>>();
let vb = from_mmaped_safetensors(
safetensors_paths
Expand All @@ -439,7 +434,7 @@ macro_rules! lora_model_loader {
.iter()
.map(|(_, x)| (*x).to_owned())
.collect::<Vec<_>>(),
$dtype.unwrap_or($default_dtype),
$dtype,
$device,
$silent,
)?;
Expand All @@ -458,7 +453,7 @@ macro_rules! lora_model_loader {
},
&$crate::utils::varbuilder_utils::load_preload_adapters(
$paths.get_lora_preload_adapter_info(),
$dtype.unwrap_or($default_dtype),
$dtype,
$device,
$silent,
)?,
Expand Down
Loading

0 comments on commit 492d767

Please sign in to comment.