From 25bc793081cb5a257c47239abf1811762a1ea53c Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Mon, 27 Jan 2025 12:59:53 -0500 Subject: [PATCH 1/8] Add the model --- .../src/models/deepseekv3/mod.rs | 2 + .../src/models/deepseekv3/model.rs | 1184 +++++++++++++++++ 2 files changed, 1186 insertions(+) create mode 100644 candle-transformers/src/models/deepseekv3/mod.rs create mode 100644 candle-transformers/src/models/deepseekv3/model.rs diff --git a/candle-transformers/src/models/deepseekv3/mod.rs b/candle-transformers/src/models/deepseekv3/mod.rs new file mode 100644 index 0000000000..766b13284b --- /dev/null +++ b/candle-transformers/src/models/deepseekv3/mod.rs @@ -0,0 +1,2 @@ +mod ops; +pub mod model; diff --git a/candle-transformers/src/models/deepseekv3/model.rs b/candle-transformers/src/models/deepseekv3/model.rs new file mode 100644 index 0000000000..55971cf36f --- /dev/null +++ b/candle-transformers/src/models/deepseekv3/model.rs @@ -0,0 +1,1184 @@ +#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)] + +use std::{f32::consts::PI, str::FromStr, sync::Arc}; + +use candle::{ + shape::Dim, CpuStorage, CustomOp1, DType, Device, Error, IndexOp, Layout, Result, Shape, + Tensor, WithDType, D, +}; +use candle_nn::{embedding, rms_norm, Activation, Embedding, Linear, Module, RmsNorm, VarBuilder}; +use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; +use serde::Deserialize; + +struct NonZero {} + +impl NonZero { + // Sequential version + fn nonzero(&self, vs: &[T], layout: &Layout) -> Vec { + let n = layout.dims().len(); + let mut result = Vec::new(); + let mut indices = vec![0u32; n]; + for (i, v) in vs.iter().enumerate() { + if !v.is_zero() { + let mut idx = i; + for (dim_index, dim) in layout.dims().iter().enumerate().rev() { + let d = idx % dim; + indices[dim_index] = u32::try_from(d).unwrap(); + idx /= dim; + } + result.extend_from_slice(&indices); + } + } + result + } +} + +#[cfg(feature = "cuda")] +fn count_nonzero_cuda(dtype: candle::DType, d_in: *const c_void, n: u32) -> u32 { + unsafe { + match dtype { + candle::DType::U8 => ffi::count_nonzero_u8(d_in, n), + candle::DType::U32 => ffi::count_nonzero_u32(d_in, n), + candle::DType::I64 => ffi::count_nonzero_i64(d_in, n), + candle::DType::I16 => ffi::count_nonzero_i16(d_in, n), + candle::DType::I32 => ffi::count_nonzero_i32(d_in, n), + candle::DType::BF16 => ffi::count_nonzero_bf16(d_in, n), + candle::DType::F16 => ffi::count_nonzero_f16(d_in, n), + candle::DType::F32 => ffi::count_nonzero_f32(d_in, n), + candle::DType::F64 => ffi::count_nonzero_f64(d_in, n), + candle::DType::F8E4M3 => todo!(), + } + } +} + +#[cfg(feature = "cuda")] +fn nonzero_cuda( + dtype: candle::DType, + d_in: *const c_void, + n: u32, + num_nonzero: u32, + dims: *const c_void, + num_dims: u32, + d_out: *mut c_void, +) { + unsafe { + match dtype { + candle::DType::U8 => ffi::nonzero_u8(d_in, n, num_nonzero, dims, num_dims, d_out), + candle::DType::U32 => ffi::nonzero_u32(d_in, n, num_nonzero, dims, num_dims, d_out), + candle::DType::I64 => ffi::nonzero_i64(d_in, n, num_nonzero, dims, num_dims, d_out), + candle::DType::I32 => ffi::nonzero_i64(d_in, n, num_nonzero, dims, num_dims, d_out), + candle::DType::I16 => ffi::nonzero_i16(d_in, n, num_nonzero, dims, num_dims, d_out), + candle::DType::BF16 => ffi::nonzero_bf16(d_in, n, num_nonzero, dims, num_dims, d_out), + candle::DType::F16 => ffi::nonzero_f16(d_in, n, num_nonzero, dims, num_dims, d_out), + candle::DType::F32 => ffi::nonzero_f32(d_in, n, num_nonzero, dims, num_dims, d_out), + 
candle::DType::F64 => ffi::nonzero_f64(d_in, n, num_nonzero, dims, num_dims, d_out), + candle::DType::F8E4M3 => todo!(), + } + } +} + +impl CustomOp1 for NonZero { + fn name(&self) -> &'static str { + "nonzero" + } + + fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> { + if !layout.is_contiguous() { + return Err(Error::RequiresContiguous { op: "nonzero" }); + } + let result = match storage { + candle::CpuStorage::U8(vs) => self.nonzero(vs, layout), + candle::CpuStorage::U32(vs) => self.nonzero(vs, layout), + candle::CpuStorage::I64(vs) => self.nonzero(vs, layout), + candle::CpuStorage::BF16(vs) => self.nonzero(vs, layout), + candle::CpuStorage::F16(vs) => self.nonzero(vs, layout), + candle::CpuStorage::F32(vs) => self.nonzero(vs, layout), + candle::CpuStorage::F64(vs) => self.nonzero(vs, layout), + }; + let index_len = layout.dims().len(); + let result_len = result.len() / index_len; + let result = CpuStorage::U32(result); + let shape = Shape::from_dims(&[result_len, index_len]); + Ok((result, shape)) + } +} + +pub trait NonZeroOp { + fn nonzero(&self) -> Result; +} + +impl NonZeroOp for Tensor { + fn nonzero(&self) -> Result { + if !self.is_contiguous() { + return Err(candle::Error::RequiresContiguous { op: "nonzero" }); + } + let original_device = self.device(); + self.to_device(&candle::Device::Cpu)? + .apply_op1_no_bwd(&NonZero {})? + .to_device(original_device) + } +} + +pub struct TopKOutput { + pub values: Tensor, + pub indices: Tensor, +} + +pub trait TopKLastDimOp { + /// Topk in the last dim. `values` retains a gradient but `indices` has none w.r.t self. + /// This expects a contiguous tensor. + /// Note: this implements torch.topk with sorted=True. + fn topk(&self, topk: usize) -> Result; + + /// Topk in the last dim. `values` retains a gradient but `indices` has none w.r.t self. + /// This expects a contiguous tensor. + /// Note: this implements torch.topk with sorted=False. + fn topk_unsorted(&self, topk: usize) -> Result; +} + +impl TopKLastDimOp for Tensor { + fn topk(&self, topk: usize) -> Result { + // Sorted descending + let sorted_indices = self.arg_sort_last_dim(false)?; + let topk_indices = sorted_indices.narrow(D::Minus1, 0, topk)?.contiguous()?; + Ok(TopKOutput { + values: self.gather(&topk_indices, D::Minus1)?, + indices: topk_indices, + }) + } + + fn topk_unsorted(&self, topk: usize) -> Result { + // Sorted descending + let sorted_indices_all = self.arg_sort_last_dim(false)?; + let topk_indices_sorted = sorted_indices_all + .narrow(D::Minus1, 0, topk)? + .contiguous()?; + let topk_values_sorted = self.gather(&topk_indices_sorted, D::Minus1)?; + + // Reorder the indices ascending + let reorder_indices = topk_indices_sorted.arg_sort_last_dim(true)?; + let topk_indices_unsorted = topk_indices_sorted + .to_dtype(DType::F32)? + .gather(&reorder_indices, D::Minus1)? 
+ .to_dtype(DType::U32)?; + let topk_values_unsorted = topk_values_sorted.gather(&reorder_indices, D::Minus1)?; + Ok(TopKOutput { + values: topk_values_unsorted, + indices: topk_indices_unsorted, + }) + } +} + +pub trait SplitOp { + fn split(&self, splits: &[usize], dim: D) -> Result>; +} + +impl SplitOp for Tensor { + fn split(&self, splits: &[usize], dim: D) -> Result> { + let dim = dim.to_index(self.shape(), "split")?; + let mut split_res = Vec::new(); + let mut index = 0; + for split in splits { + split_res.push(self.narrow(dim, index, *split)?); + index += *split; + } + Ok(split_res) + } +} + +pub trait BincountOp { + fn bincount(&self, minlength: u32) -> Result>; +} + +fn bincount(values: &[u32], minlength: u32) -> Vec { + // Find the maximum value in `values` (or zero if empty) + let max_val = values.par_iter().max().copied().unwrap_or(0); + + // The final size of the bin counts must be at least `minlength` + // and large enough to include the largest value in `values`. + let result_len = (max_val + 1).max(minlength); + + // Each thread creates a local histogram (`fold`), + // and then they are merged together (`reduce`). + values + .par_iter() + .fold( + // Create a local histogram + || vec![0u32; result_len as usize], + // Update the local histogram + |mut local_counts, &val| { + local_counts[val as usize] += 1; + local_counts + }, + ) + // Merge histograms from all threads + .reduce( + // Identity (empty histogram) + || vec![0u32; result_len as usize], + // Combine two histograms + |mut global_counts, local_counts| { + for (g, l) in global_counts.iter_mut().zip(local_counts) { + *g += l; + } + global_counts + }, + ) +} + +impl BincountOp for Tensor { + fn bincount(&self, minlength: u32) -> Result> { + let values = self.to_vec1::()?; + + Ok(bincount(&values, minlength)) + } +} + +fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result { + let shape = mask.shape(); + let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?; + let m = mask.where_cond(&on_true, on_false)?; + Ok(m) +} + +#[doc(hidden)] +#[macro_export] +macro_rules! 
serde_default_fn { + ($t:ty, $name:ident, $v:expr) => { + fn $name() -> $t { + $v + } + }; +} + +serde_default_fn!(f64, routed_scaling_factor, 1.0); +serde_default_fn!(TopkMethod, topk_method, TopkMethod::Greedy); +serde_default_fn!(usize, moe_layer_freq, 1); +serde_default_fn!(usize, first_k_dense_replace, 0); +serde_default_fn!(bool, norm_topk_prob, false); +serde_default_fn!(ScoringFunc, scoring_func, ScoringFunc::Softmax); +serde_default_fn!(Activation, hidden_act, Activation::Silu); +serde_default_fn!(bool, tie_word_embeddings, false); + +#[derive(Deserialize, Clone, Debug)] +enum TopkMethod { + #[serde(rename = "noaux_tc")] + NoAuxTc, + #[serde(rename = "greedy")] + Greedy, + #[serde(rename = "group_limited_greedy")] + GroupLimitedGreedy, +} + +#[derive(Deserialize, Clone, Debug)] +enum ScoringFunc { + #[serde(rename = "softmax")] + Softmax, +} + +#[derive(Deserialize, Clone, Debug)] +pub struct DeepSeekV2Config { + pub(crate) vocab_size: usize, + pub(crate) hidden_size: usize, + pub(crate) intermediate_size: usize, + pub(crate) moe_intermediate_size: usize, + pub(crate) num_hidden_layers: usize, + pub(crate) num_attention_heads: usize, + pub(crate) n_shared_experts: Option, + pub(crate) n_routed_experts: Option, + #[serde(default = "routed_scaling_factor")] + pub(crate) routed_scaling_factor: f64, + #[serde(default = "topk_method")] + topk_method: TopkMethod, + pub(crate) num_experts_per_tok: Option, + #[serde(default = "moe_layer_freq")] + pub(crate) moe_layer_freq: usize, + #[serde(default = "first_k_dense_replace")] + pub(crate) first_k_dense_replace: usize, + // k dense layers + #[serde(default = "norm_topk_prob")] + pub(crate) norm_topk_prob: bool, + #[serde(default = "scoring_func")] + scoring_func: ScoringFunc, + #[serde(default = "hidden_act")] + pub(crate) hidden_act: Activation, + pub(crate) max_position_embeddings: usize, + pub(crate) rms_norm_eps: f64, + #[serde(default = "tie_word_embeddings")] + pub(crate) tie_word_embeddings: bool, + pub(crate) rope_theta: f32, + pub(crate) rope_scaling: Option, + pub(crate) attention_bias: bool, + pub(crate) q_lora_rank: Option, + pub(crate) qk_rope_head_dim: usize, + pub(crate) kv_lora_rank: usize, + pub(crate) v_head_dim: usize, + pub(crate) qk_nope_head_dim: usize, + pub(crate) n_group: usize, + pub(crate) topk_group: usize, +} + +#[derive(Debug, Clone, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum ScaledRopeType { + #[serde(alias = "su")] + #[serde(alias = "longrope")] + Su, + #[serde(alias = "yarn")] + Yarn, + #[serde(alias = "dynamic")] + Dynamic, + #[serde(alias = "linear")] + Linear, +} + +impl FromStr for ScaledRopeType { + type Err = candle::Error; + fn from_str(s: &str) -> std::result::Result { + match s { + "su" | "longrope" => Ok(Self::Su), + "yarn" => Ok(Self::Yarn), + "linear" => Ok(Self::Linear), + "dynamic" => Ok(Self::Dynamic), + _ => Err(candle::Error::Msg( + "Expected either `su` or `yarn` scaled RoPE type.".to_string(), + )), + } + } +} + +#[derive(Debug, Clone)] +pub struct DeepSeekV2RotaryEmbedding { + sin: Tensor, + cos: Tensor, +} + +#[derive(Debug, Clone, Deserialize)] +#[serde(untagged)] +pub enum DeepSeekV2RopeScaling { + Yarn { + original_max_position_embeddings: usize, + beta_fast: f32, + beta_slow: f32, + mscale: f32, + mscale_all_dim: f32, + factor: f32, + #[serde(rename = "type")] + scaling_type: ScaledRopeType, + }, + LinearOrDynamic { + #[serde(rename = "type")] + scaling_type: ScaledRopeType, + factor: f64, + }, +} + +pub struct DeepSeekV2RopeConfig { + pub rope_scaling: Option, + pub 
max_position_embeddings: usize, + pub rope_theta: f32, + pub qk_rope_head_dim: usize, +} + +impl DeepSeekV2RotaryEmbedding { + fn new_unscaled(cfg: &DeepSeekV2RopeConfig, dtype: DType, dev: &Device) -> Result { + let max_seq_len = cfg.max_position_embeddings; + let dim = cfg.qk_rope_head_dim; + + let inv_freq: Vec<_> = (0..dim) + .step_by(2) + .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / dim as f32)) + .collect(); + let inv_freq_len = inv_freq.len(); + let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?; + let t = Tensor::arange(0u32, max_seq_len as u32, dev)? + .to_dtype(DType::F32)? + .reshape((max_seq_len, 1))?; + let freqs = t.matmul(&inv_freq)?; + + let sin = freqs.sin()?.to_dtype(dtype)?; + let cos = freqs.cos()?.to_dtype(dtype)?; + + Ok(Self { sin, cos }) + } + + fn yarn_find_correction_dim( + num_rot: f32, + dim: usize, + base: f32, + max_position_embeddings: usize, + ) -> f32 { + (dim as f32 * (max_position_embeddings as f32 / (num_rot * 2. * PI)).ln()) + / (2. * base.ln()) + } + + fn yarn_find_correction_range( + low_rot: f32, + high_rot: f32, + dim: usize, + base: f32, + max_position_embeddings: usize, + ) -> (f32, f32) { + let low = + Self::yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings).floor(); + let high = + Self::yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings).ceil(); + (low.max(0.), high.min(dim as f32 - 1.)) + } + + fn yarn_linear_ramp_mask(min: f32, mut max: f32, dim: usize, dev: &Device) -> Result { + if min == max { + // https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite/blob/604d5664dddd88a0433dbae533b7fe9472482de0/modeling_deepseek.py#L255 + max += 0.001; + } + let linear_func = + ((Tensor::arange(0f32, dim as f32, dev)? - min as f64)? / (max as f64 - min as f64))?; + linear_func.clamp(0., 1.) + } + + pub(crate) fn yarn_get_mscale(scale: f32, mscale: f32) -> f32 { + if scale <= 1. { + return 1.; + } + 0.1 * mscale * scale.ln() + 1. + } + + #[allow(clippy::too_many_arguments)] + fn new_yarn( + cfg: &DeepSeekV2RopeConfig, + dtype: DType, + dev: &Device, + original_max_position_embeddings: usize, + beta_fast: f32, + beta_slow: f32, + factor: f32, + mscale: f32, + mscale_all_dim: f32, + ) -> Result { + let freq_extra: Vec<_> = (0..cfg.qk_rope_head_dim) + .step_by(2) + .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / cfg.qk_rope_head_dim as f32)) + .collect(); + let freq_extra_len = freq_extra.len(); + let freq_extra = Tensor::from_vec(freq_extra, freq_extra_len, dev)?; + let freq_inter: Vec<_> = (0..cfg.qk_rope_head_dim) + .step_by(2) + .map(|i| 1f32 / (factor * cfg.rope_theta.powf(i as f32 / cfg.qk_rope_head_dim as f32))) + .collect(); + let freq_inter_len = freq_inter.len(); + let freq_inter = Tensor::from_vec(freq_inter, (1, freq_inter_len), dev)?; + + let (low, high) = Self::yarn_find_correction_range( + beta_fast, + beta_slow, + cfg.qk_rope_head_dim, + cfg.rope_theta, + original_max_position_embeddings, + ); + let inv_freq_mask = + (1. - Self::yarn_linear_ramp_mask(low, high, cfg.qk_rope_head_dim / 2, dev)?)?; + let inv_freq = freq_inter + .broadcast_mul(&(1. - &inv_freq_mask)?)? + .broadcast_add(&freq_extra.broadcast_mul(&inv_freq_mask)?)?; + + let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)? + .to_dtype(DType::F32)? + .reshape((cfg.max_position_embeddings, 1))?; + let freqs = t.matmul(&inv_freq)?; + + let mscale = + Self::yarn_get_mscale(factor, mscale) / Self::yarn_get_mscale(factor, mscale_all_dim); + let sin = (freqs.sin()? 
* mscale as f64)?.to_dtype(dtype)?; + let cos = (freqs.cos()? * mscale as f64)?.to_dtype(dtype)?; + + Ok(Self { sin, cos }) + } + + pub fn new(cfg: &DeepSeekV2RopeConfig, dtype: DType, dev: &Device) -> Result { + match &cfg.rope_scaling { + Some(DeepSeekV2RopeScaling::LinearOrDynamic { + scaling_type: _, + factor: _, + }) => candle::bail!("linear and dynamic rope are not implemented yet!"), + Some(DeepSeekV2RopeScaling::Yarn { + original_max_position_embeddings, + beta_fast, + beta_slow, + factor, + mscale, + mscale_all_dim, + scaling_type: _, + }) => Self::new_yarn( + cfg, + dtype, + dev, + *original_max_position_embeddings, + *beta_fast, + *beta_slow, + *factor, + *mscale, + *mscale_all_dim, + ), + None => Self::new_unscaled(cfg, dtype, dev), + } + } + + pub fn forward( + &self, + q: &Tensor, + k: &Tensor, + seqlen_offset: usize, + ) -> Result<(Tensor, Tensor)> { + let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?; + + let sin = self.sin.narrow(0, seqlen_offset, seq_len)?; + let cos = self.cos.narrow(0, seqlen_offset, seq_len)?; + + let q_embed = candle_nn::rotary_emb::rope_i(&q.contiguous()?, &cos, &sin)?; + let k_embed = candle_nn::rotary_emb::rope_i(&k.contiguous()?, &cos, &sin)?; + + Ok((q_embed, k_embed)) + } +} + +impl DeepSeekV2Config { + pub(crate) fn q_head_dim(&self) -> usize { + self.qk_rope_head_dim + self.qk_nope_head_dim + } + + fn softmax_scale(&self) -> f32 { + let mut softmax_scale = 1.0 / (self.q_head_dim() as f32).sqrt(); + if let Some(DeepSeekV2RopeScaling::Yarn { + mscale_all_dim, + factor, + .. + }) = self.rope_scaling + { + let mscale = DeepSeekV2RotaryEmbedding::yarn_get_mscale(factor, mscale_all_dim); + softmax_scale = softmax_scale * mscale * mscale; + } + softmax_scale + } +} + +enum QProj { + Plain(Linear), + Lora { a: Linear, norm: RmsNorm, b: Linear }, +} + +impl QProj { + fn forward(&self, xs: &Tensor) -> Result { + match self { + Self::Lora { a, norm, b } => b.forward(&norm.forward(&a.forward(xs)?)?), + Self::Plain(lin) => lin.forward(xs), + } + } +} + +struct Attention { + q: QProj, + kv_a_proj_with_mqa: Linear, + kv_a_layernorm: RmsNorm, + kv_b_proj: Linear, + o_proj: Linear, + rotary_emb: Arc, + cfg: DeepSeekV2Config, + q_head_dim: usize, + softmax_scale: f64, +} + +impl Attention { + fn new( + rotary_emb: Arc, + cfg: &DeepSeekV2Config, + vb: VarBuilder, + ) -> Result { + let q_head_dim = cfg.q_head_dim(); + let q = match cfg.q_lora_rank { + Some(lora_rank) => { + let a = candle_nn::linear_b( + cfg.hidden_size, + lora_rank, + cfg.attention_bias, + vb.pp("q_a_proj"), + )?; + let norm = rms_norm(lora_rank, cfg.rms_norm_eps, vb.pp("q_a_layernorm"))?; + let b = candle_nn::linear_no_bias( + lora_rank, + cfg.num_attention_heads * q_head_dim, + vb.pp("q_b_proj"), + )?; + QProj::Lora { a, norm, b } + } + None => QProj::Plain(candle_nn::linear_no_bias( + cfg.hidden_size, + cfg.num_attention_heads * q_head_dim, + vb.pp("q_proj"), + )?), + }; + + let kv_a_proj_with_mqa = candle_nn::linear_b( + cfg.hidden_size, + cfg.kv_lora_rank + cfg.qk_rope_head_dim, + cfg.attention_bias, + vb.pp("kv_a_proj_with_mqa"), + )?; + let kv_a_layernorm = rms_norm(cfg.kv_lora_rank, cfg.rms_norm_eps, vb.pp("kv_a_layernorm"))?; + let kv_b_proj = candle_nn::linear_no_bias( + cfg.kv_lora_rank, + cfg.num_attention_heads * (q_head_dim - cfg.qk_rope_head_dim + cfg.v_head_dim), + vb.pp("kv_b_proj"), + )?; + + let o_proj = candle_nn::linear_b( + cfg.num_attention_heads * cfg.v_head_dim, + cfg.hidden_size, + cfg.attention_bias, + vb.pp("o_proj"), + )?; + + Ok(Self { + q, + kv_a_proj_with_mqa, + 
kv_a_layernorm, + kv_b_proj, + o_proj, + rotary_emb, + cfg: cfg.clone(), + q_head_dim, + softmax_scale: cfg.softmax_scale() as f64, + }) + } + + fn forward( + &self, + xs: &Tensor, + attention_mask: Option<&Tensor>, + seqlen_offset: usize, + ) -> Result { + let (bs, seq_len, _) = xs.dims3()?; + + let mut q = self.q.forward(xs)?; + q = q + .reshape((bs, seq_len, self.cfg.num_attention_heads, self.q_head_dim))? + .transpose(1, 2)?; + let q_split = q.split( + &[self.cfg.qk_nope_head_dim, self.cfg.qk_rope_head_dim], + D::Minus1, + )?; + let q_nope = q_split[0].clone(); + let mut q_pe = q_split[1].clone(); + + let mut compressed_kv = self.kv_a_proj_with_mqa.forward(xs)?; + let ckv_split = compressed_kv.split( + &[self.cfg.kv_lora_rank, self.cfg.qk_rope_head_dim], + D::Minus1, + )?; + compressed_kv = ckv_split[0].clone(); + let mut k_pe = ckv_split[1].clone(); + k_pe = k_pe + .reshape((bs, seq_len, 1, self.cfg.qk_rope_head_dim))? + .transpose(1, 2)?; + let mut kv = self + .kv_b_proj + .forward(&self.kv_a_layernorm.forward(&compressed_kv)?)?; + kv = kv + .reshape(( + bs, + seq_len, + self.cfg.num_attention_heads, + self.cfg.qk_nope_head_dim + self.cfg.v_head_dim, + ))? + .transpose(1, 2)?; + + let kv_split = kv.split(&[self.cfg.qk_nope_head_dim, self.cfg.v_head_dim], D::Minus1)?; + let k_nope = kv_split[0].clone(); + let v = kv_split[1].clone(); + + (q_pe, k_pe) = self.rotary_emb.forward(&q_pe, &k_pe, seqlen_offset)?; + + let mut q = Tensor::zeros( + (bs, self.cfg.num_attention_heads, seq_len, self.q_head_dim), + q_pe.dtype(), + q_pe.device(), + )?; + q = q.slice_assign( + &[ + 0..q.dim(0)?, + 0..q.dim(1)?, + 0..q.dim(2)?, + 0..self.cfg.qk_nope_head_dim, + ], + &q_nope, + )?; + q = q.slice_assign( + &[ + 0..q.dim(0)?, + 0..q.dim(1)?, + 0..q.dim(2)?, + self.cfg.qk_nope_head_dim..q.dim(3)?, + ], + &q_pe, + )?; + + let mut k = Tensor::zeros( + (bs, self.cfg.num_attention_heads, seq_len, self.q_head_dim), + k_pe.dtype(), + k_pe.device(), + )?; + k = k.slice_assign( + &[ + 0..k.dim(0)?, + 0..k.dim(1)?, + 0..k.dim(2)?, + 0..self.cfg.qk_nope_head_dim, + ], + &k_nope, + )?; + let k_pe = k_pe.repeat((1, k.dim(1)?, 1, 1))?; + k = k.slice_assign( + &[ + 0..k.dim(0)?, + 0..k.dim(1)?, + 0..k.dim(2)?, + self.cfg.qk_nope_head_dim..k.dim(3)?, + ], + &k_pe, + )?; + + let mut attn_out = { + let att = (q.matmul(&k.t()?)? * self.softmax_scale)?; + let att = match attention_mask { + Some(mask) => att.broadcast_add(mask)?, + None => att, + }; + + let att = candle_nn::ops::softmax_last_dim(&att)?; + // Convert to contiguous as matmul doesn't support strided vs for now. + att.matmul(&v.contiguous()?)? + }; + + attn_out = if attention_mask.is_some() { + attn_out.transpose(1, 2)?.reshape((bs, seq_len, ()))? + } else { + attn_out.reshape((bs, seq_len, ()))? 
+ }; + + self.o_proj.forward(&attn_out) + } +} + +struct Mlp { + gate: Linear, + up: Linear, + down: Linear, + act: Activation, +} + +impl Mlp { + fn new( + cfg: &DeepSeekV2Config, + vb: VarBuilder, + hidden_size: Option, + intermediate_size: Option, + ) -> Result { + let hidden_size = hidden_size.unwrap_or(cfg.hidden_size); + let intermediate_size = intermediate_size.unwrap_or(cfg.intermediate_size); + + Ok(Self { + gate: candle_nn::linear_no_bias(hidden_size, intermediate_size, vb.pp("gate_proj"))?, + up: candle_nn::linear_no_bias(hidden_size, intermediate_size, vb.pp("up_proj"))?, + down: candle_nn::linear_no_bias(intermediate_size, hidden_size, vb.pp("down_proj"))?, + act: cfg.hidden_act, + }) + } + + fn forward(&self, xs: &Tensor) -> Result { + let lhs = self.gate.forward(xs)?.apply(&self.act)?; + let rhs = self.up.forward(xs)?; + self.down.forward(&(&lhs * &rhs)?) + } +} + +struct MoeGate { + weight: Tensor, + cfg: DeepSeekV2Config, + top_k: usize, + n_routed_experts: usize, + e_score_correction_bias: Option, +} + +impl MoeGate { + fn new(cfg: &DeepSeekV2Config, vb: VarBuilder, n_routed_experts: usize) -> Result { + let weight = vb.get((n_routed_experts, cfg.hidden_size), "weight")?; + let e_score_correction_bias = if matches!(cfg.topk_method, TopkMethod::NoAuxTc) { + Some(vb.get_with_hints_dtype( + n_routed_experts, + "e_score_correction_bias", + Default::default(), + DType::F32, + )?) + } else { + None + }; + Ok(Self { + weight, + cfg: cfg.clone(), + top_k: cfg.num_experts_per_tok.unwrap(), + n_routed_experts, + e_score_correction_bias, + }) + } + + /// (topk_idx, topk_weight) + fn forward(&self, xs: &Tensor) -> Result<(Tensor, Tensor)> { + let (bs, seq_len, h) = xs.dims3()?; + // Compute gating score + let xs = xs.reshape(((), h))?; + let logits = xs + .to_dtype(DType::F32)? + .broadcast_matmul(&self.weight.t()?.to_dtype(DType::F32)?)?; + let scores = match self.cfg.scoring_func { + ScoringFunc::Softmax => candle_nn::ops::softmax_last_dim(&logits)?, + }; + + // Select top-k experts + let (mut topk_weight, topk_idx) = match self.cfg.topk_method { + TopkMethod::Greedy => { + let TopKOutput { values, indices } = scores.topk_unsorted(self.top_k)?; + (values, indices) + } + TopkMethod::NoAuxTc => { + let Some(e_score_correction_bias) = &self.e_score_correction_bias else { + candle::bail!("Expected e_score_correction_bias") + }; + let scores_for_choice = scores + .reshape((bs * seq_len, ()))? + .broadcast_add(&e_score_correction_bias.unsqueeze(0)?)?; + // (n, n_group) + let group_scores = scores_for_choice + .reshape((bs * seq_len, self.cfg.n_group, ()))? + .topk(2)? + .values + .sum(D::Minus1)?; + // (n, topk_group) + let group_idx = group_scores.topk(self.cfg.topk_group)?.indices; + // (n, n_group) + let mut group_mask = group_scores.zeros_like()?; + // (n, n_group) + group_mask = group_mask.scatter_add( + &group_idx, + &group_idx.ones_like()?.to_dtype(group_mask.dtype())?, + 1, + )?; + // (n, e) + let score_mask = group_mask + .unsqueeze(D::Minus1)? + .expand(( + bs * seq_len, + self.cfg.n_group, + self.n_routed_experts / self.cfg.n_group, + ))? + .reshape((bs * seq_len, ()))?; + // (n, e) + // Invert the mask + let tmp_scores = scores_for_choice.broadcast_mul(&score_mask)?; + let topk_idx = tmp_scores.topk(self.top_k)?.indices; + (scores.gather(&topk_idx, 1)?, topk_idx) + } + TopkMethod::GroupLimitedGreedy => { + // (n, n_group) + let group_scores = scores + .reshape((bs * seq_len, self.cfg.n_group, ()))? 
+ .max(D::Minus1)?; + // (n, topk_group) + let group_idx = scores.topk_unsorted(self.cfg.topk_group)?.indices; + // (n, n_group) + let mut group_mask = group_scores.zeros_like()?; + // (n, n_group) + group_mask = group_mask.scatter_add( + &group_idx, + &group_idx.ones_like()?.to_dtype(group_mask.dtype())?, + 1, + )?; + // (n, e) + let score_mask = group_mask + .unsqueeze(D::Minus1)? + .expand(( + bs * seq_len, + self.cfg.n_group, + self.n_routed_experts / self.cfg.n_group, + ))? + .reshape((bs, seq_len, ()))?; + // (n, e) + // Invert the mask + let tmp_scores = masked_fill(&score_mask, &(1. - &score_mask.ne(0.)?)?, 0.)?; + let TopKOutput { values, indices } = tmp_scores.topk_unsorted(self.top_k)?; + (values, indices) + } + }; + + if self.top_k > 1 && self.cfg.norm_topk_prob { + let denmoninator = (topk_weight.sum_keepdim(D::Minus1)? + 1e-20)?; + topk_weight = (topk_weight / denmoninator)?; + } else { + topk_weight = (topk_weight * self.cfg.routed_scaling_factor)?; + } + Ok((topk_idx, topk_weight)) + } +} + +struct Moe { + experts: Vec, + shared_experts: Option, + gate: MoeGate, +} + +impl Moe { + fn new( + cfg: &DeepSeekV2Config, + vb: VarBuilder, + + n_shared_experts: Option, + n_routed_experts: usize, + ) -> Result { + let mut experts = Vec::with_capacity(n_routed_experts); + for i in 0..n_routed_experts { + let vb_e = vb.pp("experts").pp(i); + experts.push(Mlp::new(cfg, vb_e, None, Some(cfg.moe_intermediate_size))?); + } + let shared_experts = if let Some(n_shared_experts) = n_shared_experts { + let intermediate_size = cfg.moe_intermediate_size * n_shared_experts; + Some(Mlp::new( + cfg, + vb.pp("shared_experts"), + None, + Some(intermediate_size), + )?) + } else { + None + }; + let gate = MoeGate::new(cfg, vb.pp("gate"), n_routed_experts)?; + Ok(Self { + experts, + shared_experts, + gate, + }) + } + + fn moe_infer(&self, xs: &Tensor, topk_ids: &Tensor, topk_weight: &Tensor) -> Result { + let mut y = xs.zeros_like()?; + let counts = topk_ids + .flatten_all()? + .bincount(self.experts.len() as u32)?; + for (i, expert) in self.experts.iter().enumerate() { + if counts[i] == 0 { + continue; + } + let idx_top = topk_ids.eq(i as f64)?.nonzero()?.t()?; + let idx = &idx_top.i(0)?.contiguous()?; + let top = &idx_top.i(1)?.contiguous()?; + + y = y.index_add( + idx, + &expert.forward(&xs.index_select(idx, 0)?)?.broadcast_mul( + &topk_weight + .index_select(idx, 0)? + .gather(&top.unsqueeze(1)?, 1)? + .squeeze(1)? + .unsqueeze(D::Minus1)? + .to_dtype(xs.dtype())?, + )?, + 0, + )?; + } + + Ok(y) + } + + fn forward(&self, xs: &Tensor) -> Result { + let identity = xs.clone(); + let orig_shape = xs.shape(); + let (topk_idx, topk_weight) = self.gate.forward(xs)?; + let xs = xs.reshape(((), xs.dim(D::Minus1)?))?; + + let mut y = self + .moe_infer(&xs, &topk_idx, &topk_weight)? 
+ .reshape(orig_shape)?; + if let Some(ref shared_experts) = self.shared_experts { + y = (y + shared_experts.forward(&identity)?)?; + } + Ok(y) + } +} + +enum MoeOrMlp { + Moe(Moe), + Mlp(Mlp), +} + +impl MoeOrMlp { + fn forward(&self, xs: &Tensor) -> Result { + match self { + Self::Mlp(mlp) => mlp.forward(xs), + Self::Moe(moe) => moe.forward(xs), + } + } +} + +struct DecoderLayer { + input_layernorm: RmsNorm, + post_attention_layernorm: RmsNorm, + attn: Attention, + moe_or_mlp: MoeOrMlp, +} + +impl DecoderLayer { + fn new( + rotary_emb: Arc, + cfg: &DeepSeekV2Config, + vb: VarBuilder, + layer_idx: usize, + ) -> Result { + let attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?; + let input_layernorm = + rms_norm(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?; + let post_attention_layernorm = rms_norm( + cfg.hidden_size, + cfg.rms_norm_eps, + vb.pp("post_attention_layernorm"), + )?; + let moe_or_mlp = if cfg.n_routed_experts.is_some() + && layer_idx >= cfg.first_k_dense_replace + && layer_idx % cfg.moe_layer_freq == 0 + { + MoeOrMlp::Moe(Moe::new( + cfg, + vb.pp("mlp"), + cfg.n_shared_experts, + cfg.n_routed_experts.unwrap(), + )?) + } else { + MoeOrMlp::Mlp(Mlp::new(cfg, vb.pp("mlp"), None, None)?) + }; + + Ok(Self { + input_layernorm, + post_attention_layernorm, + attn, + moe_or_mlp, + }) + } + + fn forward( + &self, + xs: &Tensor, + attention_mask: Option<&Tensor>, + seqlen_offset: usize, + ) -> Result { + let residual = xs; + let xs = self.input_layernorm.forward(xs)?; + let xs = self.attn.forward(&xs, attention_mask, seqlen_offset)?; + let xs = (xs + residual)?; + let residual = &xs; + let xs = self + .moe_or_mlp + .forward(&xs.apply(&self.post_attention_layernorm)?)?; + residual + xs + } +} + +pub struct DeepSeekV2 { + lm_head: Linear, + embed_tokens: Embedding, + norm: RmsNorm, + layers: Vec, + dtype: DType, + device: Device, +} + +impl DeepSeekV2 { + pub fn new(cfg: &DeepSeekV2Config, vb: VarBuilder) -> Result { + let vb_m = vb.pp("model"); + + let embed_tokens = embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?; + let lm_head = if !cfg.tie_word_embeddings { + candle_nn::linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))? + } else { + candle_nn::Linear::new(embed_tokens.embeddings().clone(), None) + }; + let norm = rms_norm(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?; + + let rope_cfg = DeepSeekV2RopeConfig { + rope_scaling: cfg.rope_scaling.clone(), + max_position_embeddings: cfg.max_position_embeddings, + rope_theta: cfg.rope_theta, + qk_rope_head_dim: cfg.qk_rope_head_dim, + }; + let rotary_emb = Arc::new(DeepSeekV2RotaryEmbedding::new( + &rope_cfg, + vb.dtype(), + vb.device(), + )?); + + let mut layers = Vec::with_capacity(cfg.num_hidden_layers); + let vb_l = vb_m.pp("layers"); + for layer_idx in 0..cfg.num_hidden_layers { + let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx), layer_idx)?; + layers.push(layer) + } + + Ok(Self { + lm_head, + embed_tokens, + norm, + layers, + dtype: vb.dtype(), + device: vb.device().clone(), + }) + } + + fn prepare_decoder_attention_mask( + &self, + b_size: usize, + tgt_len: usize, + seqlen_offset: usize, + ) -> Result { + let mask: Vec<_> = (0..tgt_len) + .flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. 
})) + .collect(); + let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?; + let mask = if seqlen_offset > 0 { + let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?; + Tensor::cat(&[&mask0, &mask], D::Minus1)? + } else { + mask + }; + mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))? + .to_dtype(self.dtype) + } + + pub fn forward(&self, input_ids: &Tensor, seqlen_offset: usize) -> Result { + let (bs, seq_len) = input_ids.dims2()?; + let mut xs = self.embed_tokens.forward(input_ids)?; + let attention_mask = if seq_len == 1 { + None + } else { + let mask = self.prepare_decoder_attention_mask(bs, seq_len, seqlen_offset)?; + Some(mask) + }; + for layer in &self.layers { + xs = layer.forward( + &xs, + attention_mask + .as_ref() + .map(|m| m.to_device(xs.device()).unwrap()) + .as_ref(), + seqlen_offset, + )?; + } + let xs = xs.apply(&self.norm)?; + let xs = xs.i((.., seq_len - 1, ..))?.contiguous()?; + let logits = self.lm_head.forward(&xs)?; + logits.to_dtype(DType::F32) + } +} From e367cd7894ff6d268b924a09e636e4cc728d028b Mon Sep 17 00:00:00 2001 From: Eric Buehler <65165915+EricLBuehler@users.noreply.github.com> Date: Thu, 17 Oct 2024 07:28:05 -0400 Subject: [PATCH 2/8] Add the f8 e4m3 dtype --- Cargo.toml | 1 + candle-core/Cargo.toml | 3 +- candle-core/src/convert.rs | 6 + candle-core/src/cpu_backend/mod.rs | 121 ++++++++++++++++ candle-core/src/cpu_backend/utils.rs | 82 +++++++++++ candle-core/src/cuda_backend/device.rs | 50 ++++++- candle-core/src/cuda_backend/mod.rs | 193 +++++++++++++++++++++++++ candle-core/src/cuda_backend/utils.rs | 8 + candle-core/src/display.rs | 9 ++ candle-core/src/dtype.rs | 23 ++- candle-core/src/metal_backend/mod.rs | 4 + candle-core/src/npy.rs | 10 ++ candle-core/src/op.rs | 67 +++++++++ candle-core/src/safetensors.rs | 5 + candle-core/src/sort.rs | 28 ++++ candle-kernels/src/affine.cu | 26 ++-- candle-kernels/src/binary.cu | 15 ++ candle-kernels/src/cast.cu | 86 +++++++++++ candle-kernels/src/compatibility.cuh | 1 + candle-kernels/src/conv.cu | 12 ++ candle-kernels/src/cuda_utils.cuh | 23 +++ candle-kernels/src/fill.cu | 5 + candle-kernels/src/indexing.cu | 95 ++++++++++++ candle-kernels/src/reduce.cu | 8 + candle-kernels/src/sort.cu | 3 + candle-kernels/src/ternary.cu | 6 + candle-kernels/src/unary.cu | 27 ++++ candle-pyo3/Cargo.toml | 1 + candle-pyo3/src/lib.rs | 3 + candle-transformers/src/models/mod.rs | 1 + 30 files changed, 901 insertions(+), 21 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index e8d1f76988..d5a527b105 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -48,6 +48,7 @@ fancy-regex = "0.13.0" gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] } hf-hub = "0.4.1" half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] } +float8 = { version = "0.1.0", features = ["num-traits", "rand_distr"] } hound = "3.5.1" image = { version = "0.25.2", default-features = false, features = ["jpeg", "png"] } imageproc = { version = "0.24.0", default-features = false } diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml index 4ffc869ff8..799c40308b 100644 --- a/candle-core/Cargo.toml +++ b/candle-core/Cargo.toml @@ -18,6 +18,7 @@ metal = { workspace = true, optional = true} cudarc = { workspace = true, optional = true } gemm = { workspace = true } half = { workspace = true } +float8 = { workspace = true } intel-mkl-src = { workspace = true, optional = true } libc = { workspace = true, optional = true } memmap2 = { workspace = true } @@ -42,7 +43,7 @@ 
criterion = { workspace = true } [features] default = [] -cuda = ["cudarc", "dep:candle-kernels", "dep:ug-cuda"] +cuda = ["cudarc", "dep:candle-kernels", "dep:ug-cuda", "float8/cuda"] cudnn = ["cuda", "cudarc/cudnn"] mkl = ["dep:libc", "dep:intel-mkl-src"] accelerate = ["dep:libc", "dep:accelerate-src"] diff --git a/candle-core/src/convert.rs b/candle-core/src/convert.rs index 5ea5612a7c..db7bf6a4a8 100644 --- a/candle-core/src/convert.rs +++ b/candle-core/src/convert.rs @@ -1,5 +1,6 @@ //! Implement conversion traits for tensors use crate::{DType, Device, Error, Tensor, WithDType}; +use float8::F8E4M3; use half::{bf16, f16, slice::HalfFloatSliceExt}; use std::convert::TryFrom; @@ -139,6 +140,11 @@ impl Tensor { let vs = vs.to_vec1::()?; f.write_all(&vs)?; } + DType::F8E4M3 => { + for v in vs.to_vec1::()? { + f.write_u8(v.to_bits())? + } + } } Ok(()) } diff --git a/candle-core/src/cpu_backend/mod.rs b/candle-core/src/cpu_backend/mod.rs index 11ff1a406f..6d1185dbef 100644 --- a/candle-core/src/cpu_backend/mod.rs +++ b/candle-core/src/cpu_backend/mod.rs @@ -2,6 +2,7 @@ use crate::backend::{BackendDevice, BackendStorage}; use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT}; use crate::{DType, Error, IntDType, Layout, Result, Shape, WithDType}; +use float8::F8E4M3; use half::{bf16, f16}; use rayon::prelude::*; @@ -25,6 +26,7 @@ pub enum CpuStorage { F16(Vec), F32(Vec), F64(Vec), + F8E4M3(Vec), } #[derive(Debug, Clone)] @@ -36,6 +38,7 @@ pub enum CpuStorageRef<'a> { F16(&'a [f16]), F32(&'a [f32]), F64(&'a [f64]), + F8E4M3(&'a [F8E4M3]), } #[derive(Debug, Clone)] @@ -1623,6 +1626,17 @@ impl CpuStorage { .concat(); Self::F64(storages) } + Self::F8E4M3(_) => { + let storages = storages + .iter() + .map(|s| match s { + Self::F8E4M3(s) => Ok(s.as_slice()), + _ => crate::bail!("dtype mismatch"), + }) + .collect::>>()? 
+ .concat(); + Self::F8E4M3(storages) + } }; Ok(s) } @@ -1640,6 +1654,7 @@ impl BackendStorage for CpuStorage { Self::F16(_) => DType::F16, Self::F32(_) => DType::F32, Self::F64(_) => DType::F64, + Self::F8E4M3(_) => DType::F8E4M3, } } @@ -1674,6 +1689,10 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, bf16::from_f64); Ok(Self::BF16(data)) } + (Self::F8E4M3(storage), DType::BF16) => { + let data = unary_map(storage, layout, |v| bf16::from_f32(v.to_f32())); + Ok(Self::BF16(data)) + } (Self::U8(storage), DType::F16) => { let data = unary_map(storage, layout, |v| f16::from_f32(v as f32)); Ok(Self::F16(data)) @@ -1702,6 +1721,10 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, f16::from_f64); Ok(Self::F16(data)) } + (Self::F8E4M3(storage), DType::F16) => { + let data = unary_map(storage, layout, |v| f16::from_f32(v.to_f32())); + Ok(Self::F16(data)) + } (Self::U8(storage), DType::F32) => { let data = unary_map(storage, layout, |v| v as f32); Ok(Self::F32(data)) @@ -1730,6 +1753,10 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, |v| v as f32); Ok(Self::F32(data)) } + (Self::F8E4M3(storage), DType::F32) => { + let data = unary_map(storage, layout, |v| v.to_f32()); + Ok(Self::F32(data)) + } (Self::U8(storage), DType::U8) => { let data = unary_map(storage, layout, |v| v); Ok(Self::U8(data)) @@ -1758,6 +1785,14 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, |v| v as u8); Ok(Self::U8(data)) } + (Self::F8E4M3(storage), DType::U8) => { + let data = unary_map(storage, layout, |v| v.to_f32() as u8); + Ok(Self::U8(data)) + } + (Self::F8E4M3(storage), DType::U8) => { + let data = unary_map(storage, layout, |v| v.to_f32() as u8); + Ok(Self::U8(data)) + } (Self::U8(storage), DType::U32) => { let data = unary_map(storage, layout, |v| v as u32); Ok(Self::U32(data)) @@ -1786,6 +1821,10 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, |v| v as u32); Ok(Self::U32(data)) } + (Self::F8E4M3(storage), DType::U32) => { + let data = unary_map(storage, layout, |v| v.to_f32() as u32); + Ok(Self::U32(data)) + } (Self::U8(storage), DType::I64) => { let data = unary_map(storage, layout, |v| v as i64); Ok(Self::I64(data)) @@ -1814,6 +1853,10 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, |v| v as i64); Ok(Self::I64(data)) } + (Self::F8E4M3(storage), DType::I64) => { + let data = unary_map(storage, layout, |v| v.to_f32() as i64); + Ok(Self::I64(data)) + } (Self::U8(storage), DType::F64) => { let data = unary_map(storage, layout, |v| v as f64); Ok(Self::F64(data)) @@ -1842,6 +1885,42 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, |v| v); Ok(Self::F64(data)) } + (Self::F8E4M3(storage), DType::F64) => { + let data = unary_map(storage, layout, |v| v.to_f64()); + Ok(Self::F64(data)) + } + (Self::U8(storage), DType::F8E4M3) => { + let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v as f32)); + Ok(Self::F8E4M3(data)) + } + (Self::U32(storage), DType::F8E4M3) => { + let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v as f32)); + Ok(Self::F8E4M3(data)) + } + (Self::I64(storage), DType::F8E4M3) => { + let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v as f32)); + Ok(Self::F8E4M3(data)) + } + (Self::BF16(storage), DType::F8E4M3) => { + let data = unary_map(storage, layout, |v| F8E4M3::from(v.to_f32())); + Ok(Self::F8E4M3(data)) + } + (Self::F16(storage), DType::F8E4M3) => { + let data = 
unary_map(storage, layout, |v| F8E4M3::from_f32(v.to_f32())); + Ok(Self::F8E4M3(data)) + } + (Self::F32(storage), DType::F8E4M3) => { + let data = unary_map(storage, layout, F8E4M3::from_f32); + Ok(Self::F8E4M3(data)) + } + (Self::F64(storage), DType::F8E4M3) => { + let data = unary_map(storage, layout, F8E4M3::from_f64); + Ok(Self::F8E4M3(data)) + } + (Self::F8E4M3(storage), DType::F8E4M3) => { + let data = unary_map(storage, layout, |v| v); + Ok(Self::F8E4M3(data)) + } } } @@ -1955,6 +2034,10 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, |v| v.powf(e)); Ok(Self::F64(data)) } + Self::F8E4M3(storage) => { + let data = unary_map(storage, layout, |v| v.powf(F8E4M3::from_f64(e))); + Ok(Self::F8E4M3(data)) + } Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "elu").bt()), Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "elu").bt()), Self::I64(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "elu").bt()), @@ -1980,6 +2063,10 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, |v| elu(v, alpha)); Ok(Self::F64(data)) } + Self::F8E4M3(storage) => { + let data = unary_map(storage, layout, |v| elu(v, F8E4M3::from_f64(alpha))); + Ok(Self::F8E4M3(data)) + } Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "elu").bt()), Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "elu").bt()), Self::I64(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "elu").bt()), @@ -2024,6 +2111,15 @@ impl BackendStorage for CpuStorage { Ok(Self::F64(data)) } } + Self::F8E4M3(storage) => { + if B::F8E4M3_VEC { + let data = unary_map_vec(storage, layout, B::f8e4m3, B::f8e4m3_vec); + Ok(Self::F8E4M3(data)) + } else { + let data = unary_map(storage, layout, B::f8e4m3); + Ok(Self::F8E4M3(data)) + } + } Self::U8(storage) => { let data = unary_map(storage, layout, B::u8); Ok(Self::U8(data)) @@ -2505,6 +2601,15 @@ impl BackendDevice for CpuDevice { } Ok(CpuStorage::F16(data)) } + DType::F8E4M3 => { + let mut data = Vec::with_capacity(elem_count); + let uniform = + rand::distributions::Uniform::new(F8E4M3::from_f64(min), F8E4M3::from_f64(max)); + for _i in 0..elem_count { + data.push(rng.sample::(uniform)) + } + Ok(CpuStorage::F8E4M3(data)) + } DType::F32 => { let mut data = Vec::with_capacity(elem_count); let uniform = rand::distributions::Uniform::new(min as f32, max as f32); @@ -2551,6 +2656,15 @@ impl BackendDevice for CpuDevice { } Ok(CpuStorage::F16(data)) } + DType::F8E4M3 => { + let mut data = Vec::with_capacity(elem_count); + let normal = rand_distr::Normal::new(F8E4M3::from_f64(mean), F8E4M3::from_f64(std)) + .map_err(Error::wrap)?; + for _i in 0..elem_count { + data.push(normal.sample(&mut rng)) + } + Ok(CpuStorage::F8E4M3(data)) + } DType::F32 => { let mut data = Vec::with_capacity(elem_count); let normal = @@ -2614,6 +2728,11 @@ impl BackendDevice for CpuDevice { v.set_len(elem_count); CpuStorage::F64(v) } + DType::F8E4M3 => { + let mut v = Vec::with_capacity(elem_count); + v.set_len(elem_count); + CpuStorage::F8E4M3(v) + } }; Ok(storage) } @@ -2626,6 +2745,7 @@ impl BackendDevice for CpuDevice { DType::I64 => CpuStorage::I64(vec![1i64; elem_count]), DType::BF16 => CpuStorage::BF16(vec![bf16::ONE; elem_count]), DType::F16 => CpuStorage::F16(vec![f16::ONE; elem_count]), + DType::F8E4M3 => CpuStorage::F8E4M3(vec![F8E4M3::ONE; elem_count]), DType::F32 => CpuStorage::F32(vec![1f32; elem_count]), DType::F64 => CpuStorage::F64(vec![1f64; elem_count]), }; @@ -2640,6 +2760,7 @@ impl BackendDevice for CpuDevice { DType::I64 => 
CpuStorage::I64(vec![0i64; elem_count]), DType::BF16 => CpuStorage::BF16(vec![bf16::ZERO; elem_count]), DType::F16 => CpuStorage::F16(vec![f16::ZERO; elem_count]), + DType::F8E4M3 => CpuStorage::F8E4M3(vec![F8E4M3::ZERO; elem_count]), DType::F32 => CpuStorage::F32(vec![0f32; elem_count]), DType::F64 => CpuStorage::F64(vec![0f64; elem_count]), }; diff --git a/candle-core/src/cpu_backend/utils.rs b/candle-core/src/cpu_backend/utils.rs index 3e0c69b4f7..f61005a9b0 100644 --- a/candle-core/src/cpu_backend/utils.rs +++ b/candle-core/src/cpu_backend/utils.rs @@ -15,6 +15,7 @@ pub trait Map1 { C::F16(vs) => Ok(C::F16(self.f(vs, layout)?)), C::F32(vs) => Ok(C::F32(self.f(vs, layout)?)), C::F64(vs) => Ok(C::F64(self.f(vs, layout)?)), + C::F8E4M3(vs) => Ok(C::F8E4M3(self.f(vs, layout)?)), } } } @@ -31,6 +32,7 @@ pub trait Map1Any { C::F16(vs) => Ok(self.f(vs, layout, C::F16)?), C::F32(vs) => Ok(self.f(vs, layout, C::F32)?), C::F64(vs) => Ok(self.f(vs, layout, C::F64)?), + C::F8E4M3(vs) => Ok(self.f(vs, layout, C::F8E4M3)?), } } } @@ -48,6 +50,85 @@ pub trait Map2 { (C::F16(v1), C::F16(v2)) => Ok(C::F16(self.f(v1, l1, v2, l2)?)), (C::F32(v1), C::F32(v2)) => Ok(C::F32(self.f(v1, l1, v2, l2)?)), (C::F64(v1), C::F64(v2)) => Ok(C::F64(self.f(v1, l1, v2, l2)?)), + (C::F8E4M3(v1), C::F8E4M3(v2)) => Ok(C::F8E4M3(self.f(v1, l1, v2, l2)?)), + _ => Err(Error::DTypeMismatchBinaryOp { + lhs: v1.dtype(), + rhs: v2.dtype(), + op: Self::OP, + } + .bt()), + } + } +} + +pub trait Map3 { + const OP: &'static str; + #[allow(clippy::too_many_arguments)] + fn f( + &self, + v1: &[T], + l1: &Layout, + v2: &[T], + l2: &Layout, + v3: &mut [T], + l3: &Layout, + s: Option, + ) -> Result<()>; + + #[allow(clippy::too_many_arguments)] + fn map( + &self, + v1: &C, + l1: &Layout, + v2: &C, + l2: &Layout, + v3: &mut C, + l3: &Layout, + s: Option, + ) -> Result<()> { + let v3d = v3.dtype(); + match (v1, v2, v3) { + (C::U8(v1), C::U8(v2), C::U8(v3)) => Ok(self.f(v1, l1, v2, l2, v3, l3, s)?), + (C::U32(v1), C::U32(v2), C::U32(v3)) => Ok(self.f(v1, l1, v2, l2, v3, l3, s)?), + (C::I64(v1), C::I64(v2), C::I64(v3)) => Ok(self.f(v1, l1, v2, l2, v3, l3, s)?), + (C::BF16(v1), C::BF16(v2), C::BF16(v3)) => Ok(self.f(v1, l1, v2, l2, v3, l3, s)?), + (C::F16(v1), C::F16(v2), C::F16(v3)) => Ok(self.f(v1, l1, v2, l2, v3, l3, s)?), + (C::F32(v1), C::F32(v2), C::F32(v3)) => Ok(self.f(v1, l1, v2, l2, v3, l3, s)?), + (C::F64(v1), C::F64(v2), C::F64(v3)) => Ok(self.f(v1, l1, v2, l2, v3, l3, s)?), + (C::F8E4M3(v1), C::F8E4M3(v2), C::F8E4M3(v3)) => Ok(self.f(v1, l1, v2, l2, v3, l3, s)?), + _ => Err(Error::DTypeMismatchBinaryOp3 { + lhs: v1.dtype(), + rhs: v2.dtype(), + c: v3d, + op: Self::OP, + } + .bt()), + } + } +} + +pub trait Map2Alpha { + const OP: &'static str; + #[allow(clippy::too_many_arguments)] + fn f( + &self, + v1: &[T], + l1: &Layout, + v2: &[T], + l2: &Layout, + s: Option, + ) -> Result>; + + #[allow(clippy::too_many_arguments)] + fn map(&self, v1: &C, l1: &Layout, v2: &C, l2: &Layout, s: Option) -> Result { + match (v1, v2) { + (C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2, s)?)), + (C::U32(v1), C::U32(v2)) => Ok(C::U32(self.f(v1, l1, v2, l2, s)?)), + (C::I64(v1), C::I64(v2)) => Ok(C::I64(self.f(v1, l1, v2, l2, s)?)), + (C::BF16(v1), C::BF16(v2)) => Ok(C::BF16(self.f(v1, l1, v2, l2, s)?)), + (C::F16(v1), C::F16(v2)) => Ok(C::F16(self.f(v1, l1, v2, l2, s)?)), + (C::F32(v1), C::F32(v2)) => Ok(C::F32(self.f(v1, l1, v2, l2, s)?)), + (C::F64(v1), C::F64(v2)) => Ok(C::F64(self.f(v1, l1, v2, l2, s)?)), _ => 
Err(Error::DTypeMismatchBinaryOp { lhs: v1.dtype(), rhs: v2.dtype(), @@ -71,6 +152,7 @@ pub trait Map2U8 { (C::F16(v1), C::F16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), (C::F32(v1), C::F32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), (C::F64(v1), C::F64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), + (C::F8E4M3(v1), C::F8E4M3(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), _ => Err(Error::DTypeMismatchBinaryOp { lhs: v1.dtype(), rhs: v2.dtype(), diff --git a/candle-core/src/cuda_backend/device.rs b/candle-core/src/cuda_backend/device.rs index d3bd29030e..cc597cc702 100644 --- a/candle-core/src/cuda_backend/device.rs +++ b/candle-core/src/cuda_backend/device.rs @@ -3,6 +3,7 @@ use crate::{CpuStorage, CpuStorageRef, DType, Layout, Result, Shape}; pub use candle_kernels as kernels; pub use cudarc; use cudarc::driver::{CudaFunction, LaunchAsync, LaunchConfig}; +use float8::F8E4M3; use half::{bf16, f16}; use std::sync::{Arc, Mutex}; @@ -136,6 +137,14 @@ impl CudaDevice { unsafe { func.launch(cfg, params) }.w()?; CudaStorageSlice::F64(data) } + DType::F8E4M3 => { + // SAFETY: Set later by running the fill kernel. + let data = unsafe { self.alloc::(elem_count) }.w()?; + let func = self.get_or_load_func("fill_f8_e4m3", kernels::FILL)?; + let params = (&data, v, elem_count); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::F8E4M3(data) + } }; Ok(CudaStorage { slice, @@ -243,6 +252,10 @@ impl BackendDevice for CudaDevice { let data = self.alloc_zeros::(elem_count).w()?; CudaStorageSlice::F64(data) } + DType::F8E4M3 => { + let data = self.alloc_zeros::(elem_count).w()?; + CudaStorageSlice::F8E4M3(data) + } }; Ok(CudaStorage { slice, @@ -256,7 +269,8 @@ impl BackendDevice for CudaDevice { let slice = match dtype { // TODO: Add support for F16 and BF16 though this is likely to require some upstream // cudarc changes. - DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => { + DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 + | DType::F8E4M3 => { Err(CudaError::UnsupportedDtype { dtype, op: "rand_uniform", @@ -300,13 +314,17 @@ impl BackendDevice for CudaDevice { elem_count }; let slice = match dtype { - DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => { - Err(CudaError::UnsupportedDtype { - dtype, - op: "rand_normal", - }) - .w()? 
- } + DType::U8 + | DType::U32 + | DType::I16 + | DType::I32 + | DType::I64 + | DType::F16 + | DType::BF16 => Err(CudaError::UnsupportedDtype { + dtype, + op: "rand_normal", + }) + .w()?, DType::F32 => { let mut data = unsafe { self.alloc::(elem_count_round) }.w()?; curand @@ -362,6 +380,10 @@ impl BackendDevice for CudaDevice { let data = self.alloc::(elem_count).w()?; CudaStorageSlice::F64(data) } + DType::F8E4M3 => { + let data = self.alloc::(elem_count).w()?; + CudaStorageSlice::F8E4M3(data) + } }; Ok(CudaStorage { slice, @@ -399,6 +421,10 @@ impl BackendDevice for CudaDevice { let data = self.htod_sync_copy(storage).w()?; CudaStorageSlice::F64(data) } + CpuStorageRef::F8E4M3(storage) => { + let data = self.htod_sync_copy(storage).w()?; + CudaStorageSlice::F8E4M3(data) + } }; Ok(CudaStorage { slice, @@ -436,6 +462,10 @@ impl BackendDevice for CudaDevice { let data = self.htod_sync_copy(storage).w()?; CudaStorageSlice::F64(data) } + CpuStorage::F8E4M3(storage) => { + let data = self.htod_sync_copy(storage).w()?; + CudaStorageSlice::F8E4M3(data) + } }; Ok(CudaStorage { slice, @@ -473,6 +503,10 @@ impl BackendDevice for CudaDevice { let data = self.htod_copy(storage).w()?; CudaStorageSlice::F64(data) } + CpuStorage::F8E4M3(storage) => { + let data = self.htod_copy(storage).w()?; + CudaStorageSlice::F8E4M3(data) + } }; Ok(CudaStorage { slice, diff --git a/candle-core/src/cuda_backend/mod.rs b/candle-core/src/cuda_backend/mod.rs index 2cd97c182e..8b90fc3e35 100644 --- a/candle-core/src/cuda_backend/mod.rs +++ b/candle-core/src/cuda_backend/mod.rs @@ -9,6 +9,7 @@ use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig}; use cudarc::driver::{ CudaSlice, DevicePtr, DeviceRepr, DeviceSlice, LaunchAsync, LaunchConfig, ValidAsZeroBits, }; +use float8::F8E4M3; use half::{bf16, f16}; #[cfg(feature = "cudnn")] @@ -54,6 +55,7 @@ pub enum CudaStorageSlice { F16(CudaSlice), F32(CudaSlice), F64(CudaSlice), + F8E4M3(CudaSlice), } struct Clone; @@ -1033,6 +1035,7 @@ cuda_dtype!(f16, F16); cuda_dtype!(bf16, BF16); cuda_dtype!(f32, F32); cuda_dtype!(f64, F64); +cuda_dtype!(F8E4M3, F8E4M3); impl CudaStorage { pub fn wrap_cuda_slice(slice: CudaSlice, device: CudaDevice) -> CudaStorage { @@ -1155,6 +1158,7 @@ impl BackendStorage for CudaStorage { CudaStorageSlice::F16(_) => DType::F16, CudaStorageSlice::F32(_) => DType::F32, CudaStorageSlice::F64(_) => DType::F64, + CudaStorageSlice::F8E4M3(_) => DType::F8E4M3, } } @@ -1181,6 +1185,7 @@ impl BackendStorage for CudaStorage { CudaStorageSlice::F16(inp) => *inp.slice(start_o..).device_ptr(), CudaStorageSlice::F32(inp) => *inp.slice(start_o..).device_ptr(), CudaStorageSlice::F64(inp) => *inp.slice(start_o..).device_ptr(), + CudaStorageSlice::F8E4M3(inp) => *inp.slice(start_o..).device_ptr(), }; let inp = &inp; @@ -1229,6 +1234,12 @@ impl BackendStorage for CudaStorage { unsafe { func.launch(cfg, params) }.w()?; CudaStorageSlice::F64(out) } + DType::F8E4M3 => { + let out = unsafe { dev.alloc::(el) }.w()?; + let params = (el, dims.len(), &ds, *inp, &out); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::F8E4M3(out) + } }; Ok(Self { slice, @@ -1320,6 +1331,11 @@ impl BackendStorage for CudaStorage { let cpu_storage = dev.dtoh_sync_copy(slice).w()?; Ok(CpuStorage::F64(cpu_storage)) } + CudaStorageSlice::F8E4M3(slice) => { + let dev = slice.device(); + let cpu_storage = dev.dtoh_sync_copy(slice).w()?; + Ok(CpuStorage::F8E4M3(cpu_storage)) + } } } @@ -1772,6 +1788,11 @@ impl BackendStorage for CudaStorage { *d.slice(dst_o..).device_ptr(), 
"copy2d_f64", ), + (S::F8E4M3(s), S::F8E4M3(d)) => ( + *s.slice(src_o..).device_ptr(), + *d.slice(dst_o..).device_ptr(), + "copy2d_f8_e4m3", + ), _ => Err(CudaError::InternalError("dtype mismatch in copy2d"))?, }; let func = dev.get_or_load_func(kname, kernels::FILL)?; @@ -1829,6 +1850,18 @@ impl BackendStorage for CudaStorage { unsafe { func.launch(cfg, params) }.w()? } } + (CudaStorageSlice::F8E4M3(src), CudaStorageSlice::F8E4M3(dst)) => { + let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset); + if src_l.is_contiguous() { + dev.dtod_copy(&src, &mut dst).w()? + } else { + let func = dev.get_or_load_func("ucopy_f8_e4m3", kernels::UNARY)?; + // SAFETY: Set later by running the kernel. + let params = (el_count, dims.len(), &ds, &src, &mut dst); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }.w()? + } + } (CudaStorageSlice::U8(src), CudaStorageSlice::U8(dst)) => { let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset); if src_l.is_contiguous() { @@ -2084,3 +2117,163 @@ unsafe fn gemm_strided_batched_bf16( sys::cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT_TENSOR_OP, ) } + +pub struct KVConcat { + pub concat_dim: usize, +} +impl crate::CustomOp2 for KVConcat { + fn name(&self) -> &'static str { + "kvconcat" + } + + fn cpu_fwd( + &self, + _: &CpuStorage, + _: &Layout, + _: &CpuStorage, + _: &Layout, + ) -> Result<(CpuStorage, Shape)> { + crate::bail!("no cpu support for kvconcat") + } + + fn cuda_fwd( + &self, + ltensor: &CudaStorage, + ltensor_l: &Layout, + rtensor: &CudaStorage, + rtensor_l: &Layout, + ) -> Result<(CudaStorage, Shape)> { + assert!(self.concat_dim == 2 || self.concat_dim == 0); //must be in the dim of sequence len + let dev = <ensor.device; + let elem_count = ltensor_l.shape().elem_count() + rtensor_l.shape().elem_count(); + let dims_l = ltensor_l.shape().dims(); + let dims_r = rtensor_l.shape().dims(); + let dim_size = dims_l.len(); + let cfg = LaunchConfig::for_num_elems(elem_count as u32); + + let chunk_l = if dim_size > 3 { + dims_l[0] * dims_l[1] + } else { + dims_l[0] + }; + let chunk_r = if dim_size > 3 { + dims_r[0] * dims_r[1] + } else { + dims_r[0] + }; + let lstride = if dim_size > 3 { + dims_l[2] * dims_l[3] + } else { + dims_l[1] * dims_l[2] + }; + let rstride = if dim_size > 3 { + dims_r[2] * dims_r[3] + } else { + dims_r[1] * dims_r[2] + }; + + let slice = match (<ensor.slice, &rtensor.slice) { + (CudaStorageSlice::BF16(left_), CudaStorageSlice::BF16(right_)) => { + let out = unsafe { dev.alloc::(elem_count).w()? }; + let func = dev.get_or_load_func("kvconcat_bf16", kernels::KVCONCAT)?; + let params = ( + left_, + right_, + &out, + self.concat_dim, + chunk_l, + chunk_r, + lstride, + rstride, + ); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::BF16(out) + } + (CudaStorageSlice::F32(left_), CudaStorageSlice::F32(right_)) => { + let out = unsafe { dev.alloc::(elem_count).w()? }; + let func = dev.get_or_load_func("kvconcat_f32", kernels::KVCONCAT)?; + let params = ( + left_, + right_, + &out, + self.concat_dim, + chunk_l, + chunk_r, + lstride, + rstride, + ); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::F32(out) + } + (CudaStorageSlice::F16(left_), CudaStorageSlice::F16(right_)) => { + let out = unsafe { dev.alloc::(elem_count).w()? 
}; + let func = dev.get_or_load_func("kvconcat_f16", kernels::KVCONCAT)?; + let params = ( + left_, + right_, + &out, + self.concat_dim, + chunk_l, + chunk_r, + lstride, + rstride, + ); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::F16(out) + } + (CudaStorageSlice::F64(left_), CudaStorageSlice::F64(right_)) => { + let out = unsafe { dev.alloc::(elem_count).w()? }; + let func = dev.get_or_load_func("kvconcat_f64", kernels::KVCONCAT)?; + let params = ( + left_, + right_, + &out, + self.concat_dim, + chunk_l, + chunk_r, + lstride, + rstride, + ); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::F64(out) + } + (CudaStorageSlice::U8(left_), CudaStorageSlice::U8(right_)) => { + let out = unsafe { dev.alloc::(elem_count).w()? }; + let func = dev.get_or_load_func("kvconcat_u8", kernels::KVCONCAT)?; + let params = ( + left_, + right_, + &out, + self.concat_dim, + chunk_l, + chunk_r, + lstride, + rstride, + ); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::U8(out) + } + _ => Err(CudaError::InternalError("dtype mismatch in kvconcat op"))?, + }; + + let mut lshape: Vec = ltensor_l.shape().dims().to_vec(); + if self.concat_dim == 0 { + lshape[0] += rtensor_l.shape().dims()[0]; + } else { + if dim_size > 3 { + lshape[2] += rtensor_l.shape().dims()[2]; + } else { + lshape[1] += rtensor_l.shape().dims()[1]; + } + } + + let device = dev.clone(); + Ok(( + CudaStorage { + slice: slice, + device, + }, + lshape.into(), + )) + } +} diff --git a/candle-core/src/cuda_backend/utils.rs b/candle-core/src/cuda_backend/utils.rs index c1210727ad..e6bb92fe13 100644 --- a/candle-core/src/cuda_backend/utils.rs +++ b/candle-core/src/cuda_backend/utils.rs @@ -24,6 +24,7 @@ pub trait Map1 { S::F16(s) => S::F16(self.f(s, d, l)?), S::F32(s) => S::F32(self.f(s, d, l)?), S::F64(s) => S::F64(self.f(s, d, l)?), + S::F8E4M3(s) => S::F8E4M3(self.f(s, d, l)?), }; Ok(out) } @@ -48,6 +49,7 @@ pub trait Map2 { (S::F16(s1), S::F16(s2)) => S::F16(self.f(s1, l1, s2, l2, d)?), (S::F32(s1), S::F32(s2)) => S::F32(self.f(s1, l1, s2, l2, d)?), (S::F64(s1), S::F64(s2)) => S::F64(self.f(s1, l1, s2, l2, d)?), + (S::F8E4M3(s1), S::F8E4M3(s2)) => S::F8E4M3(self.f(s1, l1, s2, l2, d)?), _ => Err(CudaError::InternalError("dtype mismatch in binary op"))?, }; Ok(out) @@ -86,6 +88,9 @@ pub trait Map3 { (S::F16(s1), S::F16(s2), S::F16(s3)) => S::F16(self.f(s1, l1, s2, l2, s3, l3, d)?), (S::F32(s1), S::F32(s2), S::F32(s3)) => S::F32(self.f(s1, l1, s2, l2, s3, l3, d)?), (S::F64(s1), S::F64(s2), S::F64(s3)) => S::F64(self.f(s1, l1, s2, l2, s3, l3, d)?), + (S::F8E4M3(s1), S::F8E4M3(s2), S::F8E4M3(s3)) => { + S::F8E4M3(self.f(s1, l1, s2, l2, s3, l3, d)?) 
+ } _ => Err(CudaError::InternalError("dtype mismatch in ternary op"))?, }; Ok(out) @@ -118,6 +123,7 @@ pub trait Map2InPlace { (S::F16(dst), S::F16(src)) => self.f(dst, dst_s, src, src_l, d), (S::F32(dst), S::F32(src)) => self.f(dst, dst_s, src, src_l, d), (S::F64(dst), S::F64(src)) => self.f(dst, dst_s, src, src_l, d), + (S::F8E4M3(dst), S::F8E4M3(src)) => self.f(dst, dst_s, src, src_l, d), _ => Err(CudaError::InternalError("dtype mismatch in binary op"))?, } } @@ -141,6 +147,7 @@ pub trait Map1Any { S::F16(s) => self.f(s, d, l, S::F16)?, S::F32(s) => self.f(s, d, l, S::F32)?, S::F64(s) => self.f(s, d, l, S::F64)?, + S::F8E4M3(s) => self.f(s, d, l, S::F8E4M3)?, }; Ok(out) } @@ -165,6 +172,7 @@ pub trait Map2Any { (S::F16(s1), S::F16(s2)) => self.f(s1, l1, s2, l2, d)?, (S::F32(s1), S::F32(s2)) => self.f(s1, l1, s2, l2, d)?, (S::F64(s1), S::F64(s2)) => self.f(s1, l1, s2, l2, d)?, + (S::F8E4M3(s1), S::F8E4M3(s2)) => self.f(s1, l1, s2, l2, d)?, _ => Err(CudaError::InternalError("dtype mismatch in binary op")).w()?, }; Ok(out) diff --git a/candle-core/src/display.rs b/candle-core/src/display.rs index 76d39010a9..cdc930615d 100644 --- a/candle-core/src/display.rs +++ b/candle-core/src/display.rs @@ -3,6 +3,7 @@ //! This implementation should be in line with the [PyTorch version](https://github.com/pytorch/pytorch/blob/7b419e8513a024e172eae767e24ec1b849976b13/torch/_tensor_str.py). //! use crate::{DType, Result, Tensor, WithDType}; +use float8::F8E4M3; use half::{bf16, f16}; impl Tensor { @@ -61,6 +62,7 @@ impl std::fmt::Debug for Tensor { DType::F16 => self.fmt_dt::<f16>(f), DType::F32 => self.fmt_dt::<f32>(f), DType::F64 => self.fmt_dt::<f64>(f), + DType::F8E4M3 => self.fmt_dt::<F8E4M3>(f), } } } @@ -498,6 +500,13 @@ impl std::fmt::Display for Tensor { writeln!(f)?; } } + DType::F8E4M3 => { + if let Ok(tf) = FloatFormatter::<F8E4M3>::new(&to_display, &po) { + let max_w = tf.max_width(&to_display); + tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?; + writeln!(f)?; + } + } }; let device_str = match self.device().location() { diff --git a/candle-core/src/dtype.rs b/candle-core/src/dtype.rs index de6cddc3a3..94dc8c1062 100644 --- a/candle-core/src/dtype.rs +++ b/candle-core/src/dtype.rs @@ -1,11 +1,14 @@ //! Types for elements that can be stored and manipulated using tensors. #![allow(clippy::redundant_closure_call)] use crate::backend::BackendStorage; +use crate::cpu::kernels::VecOps; use crate::{CpuStorage, CpuStorageRef, Error, Result}; /// The different types of elements allowed in tensors. #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] pub enum DType { + // 8-bit floating point (4-bit exponent, 3-bit mantissa). + F8E4M3, // Unsigned 8 bits integer. U8, // Unsigned 32 bits integer.
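As a small usage sketch (editorial, not part of the patch): once this variant exists, the string form, element size, and float/int classification added in the hunks that follow behave like the other dtypes. Everything below only relies on APIs shown in this patch.

use candle::DType;
use std::str::FromStr;

fn f8_dtype_smoke() {
    // "f8_e4m3" is the string name registered for the new dtype.
    let dt = DType::from_str("f8_e4m3").unwrap();
    assert_eq!(dt, DType::F8E4M3);
    assert_eq!(dt.as_str(), "f8_e4m3");
    // One byte per element, classified as a float rather than an int.
    assert_eq!(dt.size_in_bytes(), 1);
    assert!(dt.is_float());
    assert!(!dt.is_int());
}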
@@ -44,6 +47,7 @@ impl std::str::FromStr for DType { "f16" => Ok(Self::F16), "f32" => Ok(Self::F32), "f64" => Ok(Self::F64), + "f8_e4m3" => Ok(Self::F8E4M3), _ => Err(DTypeParseError(s.to_string())), } } @@ -60,6 +64,7 @@ impl DType { Self::F16 => "f16", Self::F32 => "f32", Self::F64 => "f64", + Self::F8E4M3 => "f8_e4m3", } } @@ -67,6 +72,7 @@ impl DType { pub fn size_in_bytes(&self) -> usize { match self { Self::U8 => 1, + Self::F8E4M3 => 1, Self::U32 => 4, Self::I64 => 8, Self::BF16 => 2, @@ -79,14 +85,14 @@ impl DType { pub fn is_int(&self) -> bool { match self { Self::U8 | Self::U32 | Self::I64 => true, - Self::BF16 | Self::F16 | Self::F32 | Self::F64 => false, + Self::BF16 | Self::F16 | Self::F32 | Self::F64 | Self::F8E4M3 => false, } } pub fn is_float(&self) -> bool { match self { Self::U8 | Self::U32 | Self::I64 => false, - Self::BF16 | Self::F16 | Self::F32 | Self::F64 => true, + Self::BF16 | Self::F16 | Self::F32 | Self::F64 | Self::F8E4M3 => true, } } } @@ -165,6 +171,7 @@ macro_rules! with_dtype { } }; } +use float8::F8E4M3; use half::{bf16, f16}; with_dtype!(u8, U8, |v: f64| v as u8, |v: u8| v as f64); @@ -174,6 +181,17 @@ with_dtype!(f16, F16, f16::from_f64, f16::to_f64); with_dtype!(bf16, BF16, bf16::from_f64, bf16::to_f64); with_dtype!(f32, F32, |v: f64| v as f32, |v: f32| v as f64); with_dtype!(f64, F64, |v: f64| v, |v: f64| v); +with_dtype!(F8E4M3, F8E4M3, |v: f64| F8E4M3::from_f64(v), |v: F8E4M3| v + .to_f64()); + +impl VecOps for F8E4M3 { + fn max(self, rhs: Self) -> Self { + F8E4M3::max(self, rhs) + } + fn min(self, rhs: Self) -> Self { + F8E4M3::min(self, rhs) + } +} pub trait IntDType: WithDType { fn is_true(&self) -> bool; @@ -213,3 +231,4 @@ impl FloatDType for f16 {} impl FloatDType for bf16 {} impl FloatDType for f32 {} impl FloatDType for f64 {} +impl FloatDType for F8E4M3 {} diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs index 70a512bc8e..27592ad9a4 100644 --- a/candle-core/src/metal_backend/mod.rs +++ b/candle-core/src/metal_backend/mod.rs @@ -103,6 +103,7 @@ impl BackendStorage for MetalStorage { DType::BF16 => Ok(CpuStorage::BF16(self.to_cpu()?)), DType::F32 => Ok(CpuStorage::F32(self.to_cpu()?)), DType::F64 => Ok(CpuStorage::F64(self.to_cpu()?)), + DType::F8E4M3 => Ok(CpuStorage::F64(self.to_cpu()?)), } } @@ -1913,6 +1914,7 @@ impl BackendDevice for MetalDevice { DType::F16 => "fill_f16", DType::BF16 => "fill_bf16", DType::F32 => "fill_f32", + DType::F8E4M3 => crate::bail!("Metal device does not yet support F8E4M3."), DType::F64 => { let cpu_storage = crate::cpu_backend::CpuDevice.ones_impl(shape, dtype)?; return self.storage_from_cpu_storage(&cpu_storage); @@ -1948,6 +1950,7 @@ impl BackendDevice for MetalDevice { CpuStorageRef::F16(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorageRef::F32(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorageRef::F64(storage) => (storage.len(), self.new_buffer_with_data(storage)), + CpuStorageRef::F8E4M3(_) => crate::bail!("Metal device does not yet support F8E4M3."), }; Ok(Self::Storage::new(buffer?, self.clone(), count, T::DTYPE)) } @@ -1961,6 +1964,7 @@ impl BackendDevice for MetalDevice { CpuStorage::F16(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorage::F32(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorage::F64(storage) => (storage.len(), self.new_buffer_with_data(storage)), + CpuStorage::F8E4M3(_) => crate::bail!("Metal device does not yet support F8E4M3."), }; 
Ok(Self::Storage::new( buffer?, diff --git a/candle-core/src/npy.rs b/candle-core/src/npy.rs index 83e4f6527f..c2d06d4d19 100644 --- a/candle-core/src/npy.rs +++ b/candle-core/src/npy.rs @@ -27,11 +27,13 @@ //! ``` use crate::{DType, Device, Error, Result, Shape, Tensor}; use byteorder::{LittleEndian, ReadBytesExt}; +use float8::F8E4M3; use half::{bf16, f16, slice::HalfFloatSliceExt}; use std::collections::HashMap; use std::fs::File; use std::io::{BufReader, Read, Write}; use std::path::Path; +use std::slice; const NPY_MAGIC_STRING: &[u8] = b"\x93NUMPY"; const NPY_SUFFIX: &str = ".npy"; @@ -88,6 +90,7 @@ impl Header { DType::I64 => "i8", DType::U32 => "u4", DType::U8 => "u1", + DType::F8E4M3 => Err(Error::Npy("f8e4m3 is not supported".into()))?, }; if !shape.is_empty() { shape.push(',') @@ -239,6 +242,13 @@ impl Tensor { reader.read_i64_into::(&mut data_t)?; Tensor::from_vec(data_t, shape, &Device::Cpu) } + DType::F8E4M3 => { + let mut data_t = vec![F8E4M3::ZERO; elem_count]; + let ptr = data_t.as_mut_ptr().cast::(); + let len = data_t.len(); + reader.read_i8_into(unsafe { slice::from_raw_parts_mut(ptr, len) })?; + Tensor::from_vec(data_t, shape, &Device::Cpu) + } } } diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs index c5fc3fc475..2170341a69 100644 --- a/candle-core/src/op.rs +++ b/candle-core/src/op.rs @@ -2,6 +2,7 @@ //! #![allow(clippy::redundant_closure_call)] use crate::Tensor; +use float8::F8E4M3; use half::{bf16, f16}; use num_traits::float::Float; @@ -189,6 +190,7 @@ pub trait UnaryOpT { fn f16(v1: f16) -> f16; fn f32(v1: f32) -> f32; fn f64(v1: f64) -> f64; + fn f8e4m3(v1: F8E4M3) -> F8E4M3; fn u8(v1: u8) -> u8; fn u32(v1: u32) -> u32; fn i64(v1: i64) -> i64; @@ -199,6 +201,8 @@ pub trait UnaryOpT { fn bf16_vec(_xs: &[bf16], _ys: &mut [bf16]) {} const F16_VEC: bool = false; fn f16_vec(_xs: &[f16], _ys: &mut [f16]) {} + const F8E4M3_VEC: bool = false; + fn f8e4m3_vec(_xs: &[F8E4M3], _ys: &mut [F8E4M3]) {} const F32_VEC: bool = false; fn f32_vec(_xs: &[f32], _ys: &mut [f32]) {} const F64_VEC: bool = false; @@ -213,6 +217,7 @@ pub trait BinaryOpT { fn f16(v1: f16, v2: f16) -> f16; fn f32(v1: f32, v2: f32) -> f32; fn f64(v1: f64, v2: f64) -> f64; + fn f8e4m3(v1: F8E4M3, v2: F8E4M3) -> F8E4M3; fn u8(v1: u8, v2: u8) -> u8; fn u32(v1: u32, v2: u32) -> u32; fn i64(v1: i64, v2: i64) -> i64; @@ -225,6 +230,8 @@ pub trait BinaryOpT { fn f32_vec(_xs1: &[f32], _xs2: &[f32], _ys: &mut [f32]) {} const F64_VEC: bool = false; fn f64_vec(_xs1: &[f64], _xs2: &[f64], _ys: &mut [f64]) {} + const F8E4M3_VEC: bool = false; + fn f8e4m3_vec(_xs1: &[F8E4M3], __xs2: &[F8E4M3], _ys: &mut [F8E4M3]) {} const U8_VEC: bool = false; fn u8_vec(_xs1: &[u8], _xs2: &[u8], _ys: &mut [u8]) {} const U32_VEC: bool = false; @@ -282,6 +289,10 @@ macro_rules! bin_op { $e(v1, v2) } #[inline(always)] + fn f8e4m3(v1: F8E4M3, v2: F8E4M3) -> F8E4M3 { + $e(v1, v2) + } + #[inline(always)] fn u8(v1: u8, v2: u8) -> u8 { $e(v1, v2) } @@ -362,6 +373,10 @@ macro_rules! unary_op { $e } #[inline(always)] + fn f8e4m3($a: F8E4M3) -> F8E4M3 { + $e + } + #[inline(always)] fn f32($a: f32) -> f32 { $e } @@ -406,6 +421,10 @@ macro_rules! 
unary_op { $e } #[inline(always)] + fn f8e4m3($a: F8E4M3) -> F8E4M3 { + $e + } + #[inline(always)] fn u8(_: u8) -> u8 { todo!("no unary function for u8") } @@ -497,6 +516,17 @@ impl UnaryOpT for Gelu { )) } #[inline(always)] + fn f8e4m3(v: F8E4M3) -> F8E4M3 { + F8E4M3::from_f32(0.5) + * v + * (F8E4M3::ONE + + F8E4M3::tanh( + F8E4M3::from_f32(SQRT_TWO_OVER_PI_F32) + * v + * (F8E4M3::ONE + F8E4M3::from_f32(0.044715) * v * v), + )) + } + #[inline(always)] fn f32(v: f32) -> f32 { 0.5 * v * (1.0 + f32::tanh(SQRT_TWO_OVER_PI_F32 * v * (1.0 + 0.044715 * v * v))) } @@ -570,6 +600,10 @@ impl UnaryOpT for Erf { f16::from_f64(Self::f64(v.to_f64())) } #[inline(always)] + fn f8e4m3(v: F8E4M3) -> F8E4M3 { + F8E4M3::from_f64(Self::f64(v.to_f64())) + } + #[inline(always)] fn f32(v: f32) -> f32 { Self::f64(v as f64) as f32 } @@ -604,6 +638,10 @@ impl UnaryOpT for Silu { v / (f16::ONE + (-v).exp()) } #[inline(always)] + fn f8e4m3(v: F8E4M3) -> F8E4M3 { + v / (F8E4M3::ONE + (-v).exp()) + } + #[inline(always)] fn f32(v: f32) -> f32 { v / (1.0 + (-v).exp()) } @@ -675,6 +713,10 @@ impl UnaryOpT for Abs { v.abs() } #[inline(always)] + fn f8e4m3(v: F8E4M3) -> F8E4M3 { + v.abs() + } + #[inline(always)] fn f32(v: f32) -> f32 { v.abs() } @@ -709,6 +751,10 @@ impl UnaryOpT for Ceil { v.ceil() } #[inline(always)] + fn f8e4m3(v: F8E4M3) -> F8E4M3 { + v.ceil() + } + #[inline(always)] fn f32(v: f32) -> f32 { v.ceil() } @@ -743,6 +789,10 @@ impl UnaryOpT for Floor { v.floor() } #[inline(always)] + fn f8e4m3(v: F8E4M3) -> F8E4M3 { + v.floor() + } + #[inline(always)] fn f32(v: f32) -> f32 { v.floor() } @@ -777,6 +827,10 @@ impl UnaryOpT for Round { v.round() } #[inline(always)] + fn f8e4m3(v: F8E4M3) -> F8E4M3 { + v.round() + } + #[inline(always)] fn f32(v: f32) -> f32 { v.round() } @@ -811,6 +865,10 @@ impl UnaryOpT for GeluErf { f16::from_f64(Self::f64(v.to_f64())) } #[inline(always)] + fn f8e4m3(v: F8E4M3) -> F8E4M3 { + F8E4M3::from_f64(Self::f64(v.to_f64())) + } + #[inline(always)] fn f32(v: f32) -> f32 { Self::f64(v as f64) as f32 } @@ -845,6 +903,10 @@ impl UnaryOpT for Relu { v.max(f16::ZERO) } #[inline(always)] + fn f8e4m3(v: F8E4M3) -> F8E4M3 { + v.max(F8E4M3::ZERO) + } + #[inline(always)] fn f32(v: f32) -> f32 { v.max(0f32) } @@ -943,6 +1005,11 @@ impl UnaryOpT for Sign { f16::from((v > f16::ZERO) as i8) - f16::from((v < f16::ZERO) as i8) } #[inline(always)] + fn f8e4m3(v: F8E4M3) -> F8E4M3 { + F8E4M3::from((v > F8E4M3::ZERO) as i8 as f32) + - F8E4M3::from((v < F8E4M3::ZERO) as i8 as f32) + } + #[inline(always)] fn f32(v: f32) -> f32 { f32::from(v > 0.) - f32::from(v < 0.) } diff --git a/candle-core/src/safetensors.rs b/candle-core/src/safetensors.rs index d402d6b8e0..67ca079155 100644 --- a/candle-core/src/safetensors.rs +++ b/candle-core/src/safetensors.rs @@ -10,6 +10,7 @@ //! `Tensor::save_safetensors` method. //! 
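As a hedged usage sketch (not part of the patch): with the F8_E4M3 dtype mapping added below, an fp8 tensor should round-trip through safetensors like any other dtype. This assumes the CPU-side F8E4M3 storage from the rest of this series is in place; the key name "w" and the shape are illustrative.

use candle::{DType, Device, Tensor};

fn roundtrip_f8(path: &std::path::Path) -> candle::Result<()> {
    let dev = Device::Cpu;
    // Build a small fp8 tensor (ones are exactly representable in e4m3).
    let t = Tensor::ones((4, 4), DType::F8E4M3, &dev)?;
    // Save under the key "w" and load it back; the dtype should be preserved.
    t.save_safetensors("w", path)?;
    let loaded = candle::safetensors::load(path, &dev)?;
    assert_eq!(loaded["w"].dtype(), DType::F8E4M3);
    Ok(())
}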
use crate::{DType, Device, Error, Result, Tensor, WithDType}; +use float8::F8E4M3; use safetensors::tensor as st; use safetensors::tensor::SafeTensors; use std::borrow::Cow; @@ -26,6 +27,7 @@ impl From for st::Dtype { DType::F16 => st::Dtype::F16, DType::F32 => st::Dtype::F32, DType::F64 => st::Dtype::F64, + DType::F8E4M3 => st::Dtype::F8_E4M3, } } } @@ -41,6 +43,7 @@ impl TryFrom for DType { st::Dtype::F16 => Ok(DType::F16), st::Dtype::F32 => Ok(DType::F32), st::Dtype::F64 => Ok(DType::F64), + st::Dtype::F8_E4M3 => Ok(DType::F8E4M3), dtype => Err(Error::UnsupportedSafeTensorDtype(dtype)), } } @@ -203,6 +206,7 @@ impl Tensor { DType::F16 => convert_slice::(data, shape, device), DType::F32 => convert_slice::(data, shape, device), DType::F64 => convert_slice::(data, shape, device), + DType::F8E4M3 => convert_slice::(data, shape, device), } } } @@ -239,6 +243,7 @@ fn convert_back(tensor: &Tensor) -> Result> { DType::BF16 => Ok(convert_back_::(tensor.to_vec1()?)), DType::F32 => Ok(convert_back_::(tensor.to_vec1()?)), DType::F64 => Ok(convert_back_::(tensor.to_vec1()?)), + DType::F8E4M3 => Ok(convert_back_::(tensor.to_vec1()?)), } } diff --git a/candle-core/src/sort.rs b/candle-core/src/sort.rs index 0ebb18357d..9f741da84f 100644 --- a/candle-core/src/sort.rs +++ b/candle-core/src/sort.rs @@ -52,6 +52,32 @@ impl ArgSort { } } +impl crate::CustomOp1 for ArgSort { + fn name(&self) -> &'static str { + "argsort" + } + + fn cpu_fwd( + &self, + storage: &crate::CpuStorage, + layout: &crate::Layout, + ) -> Result<(crate::CpuStorage, crate::Shape)> { + let sort_indexes = match storage { + crate::CpuStorage::U8(vs) => self.asort(vs, layout), + crate::CpuStorage::U32(vs) => self.asort(vs, layout), + crate::CpuStorage::I16(vs) => self.asort(vs, layout), + crate::CpuStorage::I32(vs) => self.asort(vs, layout), + crate::CpuStorage::I64(vs) => self.asort(vs, layout), + crate::CpuStorage::BF16(vs) => self.asort(vs, layout), + crate::CpuStorage::F16(vs) => self.asort(vs, layout), + crate::CpuStorage::F32(vs) => self.asort(vs, layout), + crate::CpuStorage::F64(vs) => self.asort(vs, layout), + crate::CpuStorage::F8E4M3(vs) => self.asort(vs, layout), + }; + let sort_indexes = crate::CpuStorage::U32(sort_indexes); + Ok((sort_indexes, layout.shape().into())) + } + #[cfg(feature = "cuda")] mod cuda { use super::*; @@ -154,6 +180,7 @@ impl crate::CustomOp1 for ArgSort { DType::U8 => "asort_asc_u8", DType::U32 => "asort_asc_u32", DType::I64 => "asort_asc_i64", + DType::F8E4M3 => crate::bail!("Metal device does not yet support F8E4M3."), } } else { match storage.dtype() { @@ -164,6 +191,7 @@ impl crate::CustomOp1 for ArgSort { DType::U8 => "asort_desc_u8", DType::U32 => "asort_desc_u32", DType::I64 => "asort_desc_i64", + DType::F8E4M3 => crate::bail!("Metal device does not yet support F8E4M3."), } } }; diff --git a/candle-kernels/src/affine.cu b/candle-kernels/src/affine.cu index 540d0819f5..ef75dffd36 100644 --- a/candle-kernels/src/affine.cu +++ b/candle-kernels/src/affine.cu @@ -1,7 +1,7 @@ #include "cuda_utils.cuh" #include -#define AFFINE_OP(TYPENAME, FN_NAME) \ +#define AFFINE_OP(TYPENAME, FN_NAME, AFFINE) \ extern "C" __global__ void FN_NAME( \ const size_t numel, \ const size_t num_dims, \ @@ -16,28 +16,34 @@ extern "C" __global__ void FN_NAME( \ if (info == nullptr || is_contiguous(num_dims, dims, strides)) { \ for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \ TYPENAME x = inp ? 
inp[i] : out[i]; \ - out[i] = x * mul + add; \ + out[i] = AFFINE; \ } \ } \ else { \ for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \ unsigned strided_i = get_strided_index(i, num_dims, dims, strides); \ TYPENAME x = inp ? inp[strided_i] : out[i]; \ - out[i] = x * mul + add; \ + out[i] = AFFINE; \ } \ } \ } \ #if __CUDA_ARCH__ >= 800 -AFFINE_OP(__nv_bfloat16, affine_bf16) +AFFINE_OP(__nv_bfloat16, affine_bf16, x * mul + add) + +#define F8E4M3_TO_FLOAT(x) __half2float(__nv_cvt_fp8_to_halfraw(x.__x, __NV_E4M3)) + +AFFINE_OP(__nv_fp8_e4m3, affine_f8_e4m3, __nv_fp8_e4m3(F8E4M3_TO_FLOAT(x) * F8E4M3_TO_FLOAT(mul) + F8E4M3_TO_FLOAT(add))) #endif #if __CUDA_ARCH__ >= 530 -AFFINE_OP(__half, affine_f16) +AFFINE_OP(__half, affine_f16, x * mul + add) #endif -AFFINE_OP(float, affine_f32) -AFFINE_OP(double, affine_f64) -AFFINE_OP(uint8_t, affine_u8) -AFFINE_OP(uint32_t, affine_u32) -AFFINE_OP(int64_t, affine_i64) +AFFINE_OP(float, affine_f32, x * mul + add) +AFFINE_OP(double, affine_f64, x * mul + add) +AFFINE_OP(uint8_t, affine_u8, x * mul + add) +AFFINE_OP(uint32_t, affine_u32, x * mul + add) +AFFINE_OP(int16_t, affine_i16, x * mul + add) +AFFINE_OP(int32_t, affine_i32, x * mul + add) +AFFINE_OP(int64_t, affine_i64, x * mul + add) diff --git a/candle-kernels/src/binary.cu b/candle-kernels/src/binary.cu index d44e3b20ee..971a2c433c 100644 --- a/candle-kernels/src/binary.cu +++ b/candle-kernels/src/binary.cu @@ -14,6 +14,21 @@ BINARY_OP_OUT(__nv_bfloat16, uint8_t, lt_bf16, x < y) BINARY_OP_OUT(__nv_bfloat16, uint8_t, le_bf16, x <= y) BINARY_OP_OUT(__nv_bfloat16, uint8_t, gt_bf16, x > y) BINARY_OP_OUT(__nv_bfloat16, uint8_t, ge_bf16, x >= y) + +#define F8E4M3_TO_FLOAT(x) __half2float(__nv_cvt_fp8_to_halfraw(x.__x, __NV_E4M3)) + +BINARY_OP(__nv_fp8_e4m3, badd_f8_e4m3, __nv_fp8_e4m3(F8E4M3_TO_FLOAT(x) + F8E4M3_TO_FLOAT(y))) +BINARY_OP(__nv_fp8_e4m3, bdiv_f8_e4m3, __nv_fp8_e4m3(F8E4M3_TO_FLOAT(x) / F8E4M3_TO_FLOAT(y))) +BINARY_OP(__nv_fp8_e4m3, bmul_f8_e4m3, __nv_fp8_e4m3(F8E4M3_TO_FLOAT(x) * F8E4M3_TO_FLOAT(y))) +BINARY_OP(__nv_fp8_e4m3, bsub_f8_e4m3, __nv_fp8_e4m3(F8E4M3_TO_FLOAT(x) - F8E4M3_TO_FLOAT(y))) +BINARY_OP(__nv_fp8_e4m3, bmaximum_f8_e4m3, maxg(x, y)) +BINARY_OP(__nv_fp8_e4m3, bminimum_f8_e4m3, ming(x, y)) +BINARY_OP_OUT(__nv_fp8_e4m3, uint8_t, eq_f8_e4m3, F8E4M3_TO_FLOAT(x) == F8E4M3_TO_FLOAT(y)) +BINARY_OP_OUT(__nv_fp8_e4m3, uint8_t, ne_f8_e4m3, F8E4M3_TO_FLOAT(x) != F8E4M3_TO_FLOAT(y)) +BINARY_OP_OUT(__nv_fp8_e4m3, uint8_t, lt_f8_e4m3, F8E4M3_TO_FLOAT(x) < F8E4M3_TO_FLOAT(y)) +BINARY_OP_OUT(__nv_fp8_e4m3, uint8_t, le_f8_e4m3, F8E4M3_TO_FLOAT(x) <= F8E4M3_TO_FLOAT(y)) +BINARY_OP_OUT(__nv_fp8_e4m3, uint8_t, gt_f8_e4m3, F8E4M3_TO_FLOAT(x) > F8E4M3_TO_FLOAT(y)) +BINARY_OP_OUT(__nv_fp8_e4m3, uint8_t, ge_f8_e4m3, F8E4M3_TO_FLOAT(x) >= F8E4M3_TO_FLOAT(y)) #endif #if __CUDA_ARCH__ >= 530 diff --git a/candle-kernels/src/cast.cu b/candle-kernels/src/cast.cu index 90f5e7ba48..1b38f58e1c 100644 --- a/candle-kernels/src/cast.cu +++ b/candle-kernels/src/cast.cu @@ -24,6 +24,53 @@ __device__ void cast_( } } +#define F8E4M3_TO_FLOAT(x) __half2float(__nv_cvt_fp8_to_halfraw(x.__x, __NV_E4M3)) + +template +__device__ void cast_fp8_( + const size_t numel, + const size_t num_dims, + const size_t *info, + const __nv_fp8_e4m3 *inp, + T *out +) { + const size_t *dims = info; + const size_t *strides = info + num_dims; + if (info == nullptr || is_contiguous(num_dims, dims, strides)) { + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += 
blockDim.x * gridDim.x) { + out[i] = F8E4M3_TO_FLOAT(inp[i]); + } + } + else { + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { + unsigned strided_i = get_strided_index(i, num_dims, dims, strides); + out[i] = F8E4M3_TO_FLOAT(inp[strided_i]); + } + } +} +template +__device__ void cast_fp8_into_( + const size_t numel, + const size_t num_dims, + const size_t *info, + const S *inp, + __nv_fp8_e4m3 *out +) { + const size_t *dims = info; + const size_t *strides = info + num_dims; + if (info == nullptr || is_contiguous(num_dims, dims, strides)) { + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { + out[i] = __nv_fp8_e4m3((float)inp[i]); + } + } + else { + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { + unsigned strided_i = get_strided_index(i, num_dims, dims, strides); + out[i] = __nv_fp8_e4m3((float)inp[strided_i]); + } + } +} + template __device__ void cast_through( const size_t numel, @@ -59,6 +106,30 @@ extern "C" __global__ void FN_NAME( \ cast_(numel, num_dims, info, inp, out); \ } \ + +#define CAST_OP_FP8(SRC_TYPENAME, DST_TYPENAME, FN_NAME) \ +extern "C" __global__ void FN_NAME( \ + const size_t numel, \ + const size_t num_dims, \ + const size_t *info, \ + const SRC_TYPENAME *inp, \ + DST_TYPENAME *out \ +) { \ + cast_fp8_(numel, num_dims, info, inp, out); \ +} \ + + +#define CAST_OP_FP8_INTO(SRC_TYPENAME, DST_TYPENAME, FN_NAME) \ +extern "C" __global__ void FN_NAME( \ + const size_t numel, \ + const size_t num_dims, \ + const size_t *info, \ + const SRC_TYPENAME *inp, \ + DST_TYPENAME *out \ +) { \ + cast_fp8_into_(numel, num_dims, info, inp, out); \ +} \ + #define CAST_THROUGH_OP(SRC_TYPENAME, DST_TYPENAME, INT_TYPENAME, FN_NAME) \ extern "C" __global__ void FN_NAME( \ const size_t numel, \ @@ -72,6 +143,7 @@ extern "C" __global__ void FN_NAME( \ #if __CUDA_ARCH__ >= 800 CAST_OP(__nv_bfloat16, __nv_bfloat16, cast_bf16_bf16) +CAST_OP(__nv_fp8_e4m3, __nv_fp8_e4m3, cast_f8_e4m3_f8_e4m3) CAST_OP(__nv_bfloat16, uint32_t, cast_bf16_u32) CAST_OP(__nv_bfloat16, float, cast_bf16_f32) @@ -83,6 +155,19 @@ CAST_OP(double, __nv_bfloat16, cast_f64_bf16) CAST_THROUGH_OP(__nv_bfloat16, uint8_t, float, cast_bf16_u8) CAST_THROUGH_OP(__nv_bfloat16, __half, float, cast_bf16_f16) CAST_THROUGH_OP(__half, __nv_bfloat16, float, cast_f16_bf16) + +CAST_OP_FP8(__nv_fp8_e4m3, float, cast_f8_e4m3_f32) +CAST_OP_FP8_INTO(float, __nv_fp8_e4m3, cast_f32_f8_e4m3) +CAST_OP_FP8(__nv_fp8_e4m3, uint8_t, cast_f8_e4m3_u8) +CAST_OP_FP8(__nv_fp8_e4m3, __half, cast_f8_e4m3_f16) +CAST_OP_FP8(__nv_fp8_e4m3, double, cast_f8_e4m3_f64) +CAST_OP_FP8_INTO(__half, __nv_fp8_e4m3, cast_f16_f8_e4m3) +CAST_OP_FP8_INTO(double, __nv_fp8_e4m3, cast_f64_f8_e4m3) +CAST_OP_FP8_INTO(uint8_t, __nv_fp8_e4m3, cast_u8_f8_e4m3) +CAST_OP_FP8_INTO(int32_t, __nv_fp8_e4m3, cast_i32_f8_e4m3) +CAST_OP_FP8(__nv_fp8_e4m3, int32_t, cast_f8_e4m3_i32) +CAST_OP_FP8(__nv_fp8_e4m3, __nv_bfloat16, cast_f8_e4m3_bf16) +CAST_OP_FP8_INTO(__nv_bfloat16, __nv_fp8_e4m3, cast_bf16_f8_e4m3) #else #include #if CUDA_VERSION >= 11000 @@ -94,6 +179,7 @@ CAST_THROUGH_OP(__nv_bfloat16, double, float, cast_bf16_f64) CAST_THROUGH_OP(__half, __nv_bfloat16, float, cast_f16_bf16) CAST_THROUGH_OP(double, __nv_bfloat16, float, cast_f64_bf16) CAST_THROUGH_OP(uint8_t, __nv_bfloat16, float, cast_u8_bf16) +CAST_THROUGH_OP(__nv_bfloat16, __nv_fp8_e4m3, float, cast_bf16_f8_e4m3) #endif #endif diff --git 
a/candle-kernels/src/compatibility.cuh b/candle-kernels/src/compatibility.cuh index d0791749bb..1e4cf215c1 100644 --- a/candle-kernels/src/compatibility.cuh +++ b/candle-kernels/src/compatibility.cuh @@ -1,5 +1,6 @@ #include "cuda_fp16.h" #include "cuda_bf16.h" +#include "cuda_fp8.h" // Table showing which features are supported on which compute capability // https://docs.nvidia.com/cuda/cuda-c-programming-guide/#features-and-technical-specifications diff --git a/candle-kernels/src/conv.cu b/candle-kernels/src/conv.cu index fa834faa3a..6ca6fd7c2b 100644 --- a/candle-kernels/src/conv.cu +++ b/candle-kernels/src/conv.cu @@ -702,6 +702,18 @@ UPSAMPLE_NEAREST2D_OP(__nv_bfloat16, upsample_nearest2d_bf16) IM2COL_OP(__nv_bfloat16, im2col_bf16) IM2COL1D_OP(__nv_bfloat16, im2col1d_bf16) COL2IM1D_OP(__nv_bfloat16, col2im1d_bf16) + +// NOTE: No conv ops for f8 +// CONV1D_OP(__nv_bfloat16, float, conv1d_f8_e5m) +// CONV2D_OP(__nv_fp8_e4m3, float, conv2d_f8_e5m) +// CONVT1D_OP(__nv_fp8_e4m3, float, conv_transpose1d_f8_e5m) +// CONVT2D_OP(__nv_fp8_e4m3, float, conv_transpose2d_f8_e5m) +// AVG_POOL2D_OP(__nv_fp8_e4m3, float, avg_pool2d_f8_e5m) +// MAX_POOL2D_OP(__nv_fp8_e4m3, max_pool2d_f8_e5m) +// UPSAMPLE_NEAREST2D_OP(__nv_fp8_e4m3, upsample_nearest2d_f8_e5m) +// IM2COL_OP(__nv_fp8_e4m3, im2col_f8_e5m) +// IM2COL1D_OP(__nv_fp8_e4m3, im2col1d_f8_e5m) +// COL2IM1D_OP(__nv_fp8_e4m3, col2im1d_f8_e5m) #endif #if __CUDA_ARCH__ >= 530 diff --git a/candle-kernels/src/cuda_utils.cuh b/candle-kernels/src/cuda_utils.cuh index 2673b8aaf1..eb1400b4da 100644 --- a/candle-kernels/src/cuda_utils.cuh +++ b/candle-kernels/src/cuda_utils.cuh @@ -198,4 +198,27 @@ __device__ __forceinline__ __nv_bfloat16 logg(__nv_bfloat16 a) { return hlog(a); __device__ __forceinline__ __nv_bfloat16 expg(__nv_bfloat16 a) { return hexp(a); } __device__ __forceinline__ __nv_bfloat16 absg(__nv_bfloat16 a) { return __habs(a); } __device__ __forceinline__ __nv_bfloat16 copysigng(__nv_bfloat16 a, __nv_bfloat16 b) { return __float2bfloat16(copysignf(__bfloat162float(a), __bfloat162float(b))); } + +#define F8E4M3_TO_FLOAT(x) __half2float(__nv_cvt_fp8_to_halfraw(x.__x, __NV_E4M3)) + +__device__ __forceinline__ __nv_fp8_e4m3 powg(__nv_fp8_e4m3 a, __nv_fp8_e4m3 b) { return __nv_fp8_e4m3(powf(F8E4M3_TO_FLOAT(a), F8E4M3_TO_FLOAT(b))); } +__device__ __forceinline__ bool isnang(__nv_fp8_e4m3 a) { return isnan(F8E4M3_TO_FLOAT(a)); } +__device__ __forceinline__ __nv_fp8_e4m3 sqrtg(__nv_fp8_e4m3 a) { return __nv_fp8_e4m3(sqrtf(F8E4M3_TO_FLOAT(a))); } +__device__ __forceinline__ __nv_fp8_e4m3 cosg(__nv_fp8_e4m3 a) { return __nv_fp8_e4m3(cosf(F8E4M3_TO_FLOAT(a))); } +__device__ __forceinline__ __nv_fp8_e4m3 sing(__nv_fp8_e4m3 a) { return __nv_fp8_e4m3(sinf(F8E4M3_TO_FLOAT(a))); } +__device__ __forceinline__ __nv_fp8_e4m3 recipg(__nv_fp8_e4m3 a) { return __nv_fp8_e4m3(1. 
/ F8E4M3_TO_FLOAT(a)); } +__device__ __forceinline__ __nv_fp8_e4m3 maxg(__nv_fp8_e4m3 a, __nv_fp8_e4m3 b) { return __nv_fp8_e4m3(fmaxf(F8E4M3_TO_FLOAT(a), F8E4M3_TO_FLOAT(b))); } +__device__ __forceinline__ __nv_fp8_e4m3 tanhg(__nv_fp8_e4m3 a) { return __nv_fp8_e4m3(tanhf(F8E4M3_TO_FLOAT(a))); } +__device__ __forceinline__ __nv_fp8_e4m3 erfg(__nv_fp8_e4m3 a) { return __nv_fp8_e4m3(erff(F8E4M3_TO_FLOAT(a))); } +__device__ __forceinline__ __nv_fp8_e4m3 ceilg(__nv_fp8_e4m3 a) { return __nv_fp8_e4m3(ceilf(F8E4M3_TO_FLOAT(a))); } +__device__ __forceinline__ __nv_fp8_e4m3 floorg(__nv_fp8_e4m3 a) { return __nv_fp8_e4m3(floorf(F8E4M3_TO_FLOAT(a))); } +__device__ __forceinline__ __nv_fp8_e4m3 roundg(__nv_fp8_e4m3 a) { return __nv_fp8_e4m3(roundf(F8E4M3_TO_FLOAT(a))); } +__device__ __forceinline__ __nv_fp8_e4m3 normcdfg(__nv_fp8_e4m3 a) { return __nv_fp8_e4m3(normcdff(F8E4M3_TO_FLOAT(a))); } +__device__ __forceinline__ __nv_fp8_e4m3 ming(__nv_fp8_e4m3 a, __nv_fp8_e4m3 b) { return __nv_fp8_e4m3(fminf(F8E4M3_TO_FLOAT(a), F8E4M3_TO_FLOAT(b))); } +__device__ __forceinline__ __nv_fp8_e4m3 logg(__nv_fp8_e4m3 a) { return __nv_fp8_e4m3(logf(F8E4M3_TO_FLOAT(a))); } +__device__ __forceinline__ __nv_fp8_e4m3 expg(__nv_fp8_e4m3 a) { return __nv_fp8_e4m3(expf(F8E4M3_TO_FLOAT(a))); } +__device__ __forceinline__ __nv_fp8_e4m3 absg(__nv_fp8_e4m3 a) { return __nv_fp8_e4m3(fabsf(F8E4M3_TO_FLOAT(a))); } +__device__ __forceinline__ __nv_fp8_e4m3 copysigng(__nv_fp8_e4m3 a, __nv_fp8_e4m3 b) { return __nv_fp8_e4m3(copysignf(F8E4M3_TO_FLOAT(a), F8E4M3_TO_FLOAT(b))); } + + #endif diff --git a/candle-kernels/src/fill.cu b/candle-kernels/src/fill.cu index ca448d989f..1b72a901f2 100644 --- a/candle-kernels/src/fill.cu +++ b/candle-kernels/src/fill.cu @@ -43,6 +43,11 @@ COPY2D_OP(__half, copy2d_f16) #if __CUDA_ARCH__ >= 800 #include +#include + extern "C" __global__ void fill_bf16(__nv_bfloat16 *buf, __nv_bfloat16 value, const size_t numel) { fill_with(buf, value, numel); } COPY2D_OP(__nv_bfloat16, copy2d_bf16) + +extern "C" __global__ void fill_f8_e4m3(__nv_fp8_e4m3 *buf, __nv_fp8_e4m3 value, const size_t numel) { fill_with(buf, value, numel); } +COPY2D_OP(__nv_fp8_e4m3, copy2d_f8_e4m3) #endif diff --git a/candle-kernels/src/indexing.cu b/candle-kernels/src/indexing.cu index 8af2954d13..32cc4e9ad1 100644 --- a/candle-kernels/src/indexing.cu +++ b/candle-kernels/src/indexing.cu @@ -99,6 +99,57 @@ __device__ void index_add( } } +#if __CUDA_ARCH__ >= 800 +#define F8E4M3_TO_FLOAT(x) __half2float(__nv_cvt_fp8_to_halfraw(x.__x, __NV_E4M3)) + +template +__device__ void scatter_add_f8( + const I *ids, + const __nv_fp8_e4m3 *inp, + __nv_fp8_e4m3 *out, + const size_t left_size, + const size_t src_dim_size, + const size_t dst_dim_size, + const size_t right_size +) { + const size_t numel = left_size * right_size; + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { + const size_t pre = i / right_size; + const size_t post = i % right_size; + for (unsigned int j = 0; j < src_dim_size; ++j) { + const size_t src_i = (pre * src_dim_size + j) * right_size + post; + const size_t idx = ids[src_i]; + const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post; + out[dst_i] = __nv_fp8_e4m3(F8E4M3_TO_FLOAT(out[dst_i]) + F8E4M3_TO_FLOAT(inp[src_i])); + } + } +} + +template +__device__ void index_add_f8( + const I *ids, + const size_t ids_dim_size, + const __nv_fp8_e4m3 *inp, + __nv_fp8_e4m3 *out, + const size_t left_size, + const size_t src_dim_size, + const size_t dst_dim_size, + const 
size_t right_size +) { + const size_t numel = left_size * right_size; + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { + const size_t pre = i / right_size; + const size_t post = i % right_size; + for (unsigned int j = 0; j < ids_dim_size; ++j) { + const size_t idx = ids[j]; + const size_t src_i = (pre * ids_dim_size + j) * right_size + post; + const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post; + out[dst_i] = __nv_fp8_e4m3(F8E4M3_TO_FLOAT(out[dst_i]) + F8E4M3_TO_FLOAT(inp[src_i])); + } + } +} +#endif + #define IA_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \ extern "C" __global__ void FN_NAME( \ const INDEX_TYPENAME *ids, \ @@ -111,6 +162,18 @@ extern "C" __global__ void FN_NAME( \ const size_t right_size \ ) { index_add(ids, ids_dim_size, inp, out, left_size, src_dim_size, dst_dim_size, right_size); } \ +#define IA_OP_F8(TYPENAME, INDEX_TYPENAME, FN_NAME) \ +extern "C" __global__ void FN_NAME( \ + const INDEX_TYPENAME *ids, \ + const size_t ids_dim_size, \ + const TYPENAME *inp, \ + TYPENAME *out, \ + const size_t left_size, \ + const size_t src_dim_size, \ + const size_t dst_dim_size, \ + const size_t right_size \ +) { index_add_f8(ids, ids_dim_size, inp, out, left_size, src_dim_size, dst_dim_size, right_size); } \ + template __device__ void scatter_add( const I *ids, @@ -145,6 +208,17 @@ extern "C" __global__ void FN_NAME( \ const size_t right_size \ ) { scatter_add(ids, inp, out, left_size, src_dim_size, dst_dim_size, right_size); } \ +#define SA_OP_F8(TYPENAME, INDEX_TYPENAME, FN_NAME) \ +extern "C" __global__ void FN_NAME( \ + const INDEX_TYPENAME *ids, \ + const TYPENAME *inp, \ + TYPENAME *out, \ + const size_t left_size, \ + const size_t src_dim_size, \ + const size_t dst_dim_size, \ + const size_t right_size \ +) { scatter_add_f8(ids, inp, out, left_size, src_dim_size, dst_dim_size, right_size); } \ + #if __CUDA_ARCH__ >= 800 IS_OP(__nv_bfloat16, int64_t, is_i64_bf16) @@ -159,6 +233,27 @@ IA_OP(__nv_bfloat16, uint8_t, ia_u8_bf16) SA_OP(__nv_bfloat16, int64_t, sa_i64_bf16) SA_OP(__nv_bfloat16, uint32_t, sa_u32_bf16) SA_OP(__nv_bfloat16, uint8_t, sa_u8_bf16) + +IS_OP(__nv_fp8_e4m3, int16_t, is_i16_f8_e4m3) +IS_OP(__nv_fp8_e4m3, int32_t, is_i32_f8_e4m3) +IS_OP(__nv_fp8_e4m3, int64_t, is_i64_f8_e4m3) +IS_OP(__nv_fp8_e4m3, uint32_t, is_u32_f8_e4m3) +IS_OP(__nv_fp8_e4m3, uint8_t, is_u8_f8_e4m3) +GATHER_OP(__nv_fp8_e4m3, int16_t, gather_i16_f8_e4m3) +GATHER_OP(__nv_fp8_e4m3, int32_t, gather_i32_f8_e4m3) +GATHER_OP(__nv_fp8_e4m3, int64_t, gather_i64_f8_e4m3) +GATHER_OP(__nv_fp8_e4m3, uint32_t, gather_u32_f8_e4m3) +GATHER_OP(__nv_fp8_e4m3, uint8_t, gather_u8_f8_e4m3) +IA_OP_F8(__nv_fp8_e4m3, int16_t, ia_i16_f8_e4m3) +IA_OP_F8(__nv_fp8_e4m3, int32_t, ia_i32_f8_e4m3) +IA_OP_F8(__nv_fp8_e4m3, int64_t, ia_i64_f8_e4m3) +IA_OP_F8(__nv_fp8_e4m3, uint32_t, ia_u32_f8_e4m3) +IA_OP_F8(__nv_fp8_e4m3, uint8_t, ia_u8_f8_e4m3) +SA_OP_F8(__nv_fp8_e4m3, int16_t, sa_i16_f8_e4m3) +SA_OP_F8(__nv_fp8_e4m3, int32_t, sa_i32_f8_e4m3) +SA_OP_F8(__nv_fp8_e4m3, int64_t, sa_i64_f8_e4m3) +SA_OP_F8(__nv_fp8_e4m3, uint32_t, sa_u32_f8_e4m3) +SA_OP_F8(__nv_fp8_e4m3, uint8_t, sa_u8_f8_e4m3) #endif #if __CUDA_ARCH__ >= 530 diff --git a/candle-kernels/src/reduce.cu b/candle-kernels/src/reduce.cu index 079c370873..2738d8254e 100644 --- a/candle-kernels/src/reduce.cu +++ b/candle-kernels/src/reduce.cu @@ -578,6 +578,14 @@ LAYERNORM_OP(__nv_bfloat16, layernorm_bf16) ROPE_OP(__nv_bfloat16, rope_bf16, rope_i_bf16, rope_thd_bf16) SUM_OP(__nv_bfloat16, sum_bf16) 
FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_argmin_bf16, fast_argmax_bf16, fast_sum_bf16) + +// NOTE: No reduce ops for f8 +// SUM_OP(__nv_fp8_e4m3, sum_fp8_e4m3) +// SOFTMAX_OP(__nv_fp8_e4m3, float, softmax_fp8_e4m3) +// RMSNORM_OP(__nv_fp8_e4m3, rmsnorm_fp8_e4m3) +// LAYERNORM_OP(__nv_fp8_e4m3, layernorm_fp8_e4m3) +// ROPE_OP(__nv_fp8_e4m3, rope_fp8_e4m3, rope_i_fp8_e4m3, rope_thd_fp8_e4m3) +// FAST_OP(__nv_fp8_e4m3, fast_min_fp8_e4m3, fast_max_fp8_e4m3, fast_argmin_fp8_e4m3, fast_argmax_fp8_e4m3, fast_sum_fp8_e4m3) #endif #if __CUDA_ARCH__ >= 530 diff --git a/candle-kernels/src/sort.cu b/candle-kernels/src/sort.cu index 08f1f9fc29..a7ad4f79c4 100644 --- a/candle-kernels/src/sort.cu +++ b/candle-kernels/src/sort.cu @@ -75,6 +75,9 @@ extern "C" __global__ void asort_desc_##RUST_NAME( \ #if __CUDA_ARCH__ >= 800 ASORT_OP(__nv_bfloat16, bf16) + +// NOTE: No sort ops for f8 +// ASORT_OP(__nv_fp8_e4m3, fp8_e4m3) #endif #if __CUDA_ARCH__ >= 530 diff --git a/candle-kernels/src/ternary.cu b/candle-kernels/src/ternary.cu index aaa8a881fb..ef4009e3e0 100644 --- a/candle-kernels/src/ternary.cu +++ b/candle-kernels/src/ternary.cu @@ -36,6 +36,12 @@ extern "C" __global__ void FN_NAME( \ WHERE_OP(__nv_bfloat16, int64_t, where_i64_bf16) WHERE_OP(__nv_bfloat16, uint32_t, where_u32_bf16) WHERE_OP(__nv_bfloat16, uint8_t, where_u8_bf16) + +WHERE_OP(__nv_fp8_e4m3, int16_t, where_i16_fp8_e4m3) +WHERE_OP(__nv_fp8_e4m3, int32_t, where_i32_fp8_e4m3) +WHERE_OP(__nv_fp8_e4m3, int64_t, where_i64_fp8_e4m3) +WHERE_OP(__nv_fp8_e4m3, uint32_t, where_u32_fp8_e4m3) +WHERE_OP(__nv_fp8_e4m3, uint8_t, where_u8_fp8_e4m3) #endif #if __CUDA_ARCH__ >= 530 diff --git a/candle-kernels/src/unary.cu b/candle-kernels/src/unary.cu index c82a88375d..5fcb5e2b1a 100644 --- a/candle-kernels/src/unary.cu +++ b/candle-kernels/src/unary.cu @@ -122,6 +122,33 @@ UNARY_OP(__nv_bfloat16, usilu_bf16, silu_fwd(x)) UNARY_OP1(__nv_bfloat16, upowf_bf16, powg(x, param)) UNARY_OP(__nv_bfloat16, usign_bf16, sign_(x)) UNARY_OP(__nv_bfloat16, usigmoid_bf16, sigmoid_fwd(x)) + +#define F8E4M3_TO_FLOAT(x) __half2float(__nv_cvt_fp8_to_halfraw(x.__x, __NV_E4M3)) + +UNARY_OP(__nv_fp8_e4m3, ucopy_f8_e4m3, x) +UNARY_OP(__nv_fp8_e4m3, uneg_fp8_e4m3, __nv_fp8_e4m3(-F8E4M3_TO_FLOAT(x))) +UNARY_OP(__nv_fp8_e4m3, urecip_fp8_e4m3, recipg(x)) +UNARY_OP(__nv_fp8_e4m3, uexp_fp8_e4m3, expg(x)) +UNARY_OP(__nv_fp8_e4m3, ulog_fp8_e4m3, logg(x)) +UNARY_OP(__nv_fp8_e4m3, usin_fp8_e4m3, sing(x)) +UNARY_OP(__nv_fp8_e4m3, ucos_fp8_e4m3, cosg(x)) +UNARY_OP(__nv_fp8_e4m3, utanh_fp8_e4m3, tanhg(x)) +UNARY_OP(__nv_fp8_e4m3, uerf_fp8_e4m3, erfg(x)) +UNARY_OP(__nv_fp8_e4m3, uceil_fp8_e4m3, ceilg(x)) +UNARY_OP(__nv_fp8_e4m3, ufloor_fp8_e4m3, floorg(x)) +UNARY_OP(__nv_fp8_e4m3, uround_fp8_e4m3, roundg(x)) +UNARY_OP(__nv_fp8_e4m3, unormcdf_fp8_e4m3, normcdfg(x)) +UNARY_OP(__nv_fp8_e4m3, uabs_fp8_e4m3, absg(x)) +UNARY_OP(__nv_fp8_e4m3, usqr_fp8_e4m3, __nv_fp8_e4m3(F8E4M3_TO_FLOAT(x)*F8E4M3_TO_FLOAT(x))) +UNARY_OP(__nv_fp8_e4m3, usqrt_fp8_e4m3, sqrtg(x)) +UNARY_OP(__nv_fp8_e4m3, ugelu_fp8_e4m3, __nv_fp8_e4m3(gelu_fwd(F8E4M3_TO_FLOAT(x)))) +UNARY_OP(__nv_fp8_e4m3, ugelu_erf_fp8_e4m3, __nv_fp8_e4m3(gelu_erf_fwd(F8E4M3_TO_FLOAT(x)))) +UNARY_OP(__nv_fp8_e4m3, urelu_fp8_e4m3, __nv_fp8_e4m3(relu_fwd(F8E4M3_TO_FLOAT(x)))) +UNARY_OP1(__nv_fp8_e4m3, uelu_fp8_e4m3, __nv_fp8_e4m3(elu_fwd(F8E4M3_TO_FLOAT(x), F8E4M3_TO_FLOAT(param)))) +UNARY_OP(__nv_fp8_e4m3, usilu_fp8_e4m3, __nv_fp8_e4m3(silu_fwd(F8E4M3_TO_FLOAT(x)))) +UNARY_OP1(__nv_fp8_e4m3, upowf_fp8_e4m3, powg(x, param)) 
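// Illustration (not part of this patch): every fp8 kernel above follows the same
// pattern: widen with F8E4M3_TO_FLOAT, compute in f32, then narrow back through the
// __nv_fp8_e4m3 constructor. A hypothetical additional op would be one more line, e.g.
// UNARY_OP(__nv_fp8_e4m3, uexpm1_fp8_e4m3, __nv_fp8_e4m3(expm1f(F8E4M3_TO_FLOAT(x))))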
+UNARY_OP(__nv_fp8_e4m3, usign_fp8_e4m3, __nv_fp8_e4m3(sign_(F8E4M3_TO_FLOAT(x)))) +UNARY_OP(__nv_fp8_e4m3, usigmoid_fp8_e4m3, __nv_fp8_e4m3(sigmoid_fwd(F8E4M3_TO_FLOAT(x)))) #endif #if __CUDA_ARCH__ >= 530 diff --git a/candle-pyo3/Cargo.toml b/candle-pyo3/Cargo.toml index d91619fbb3..e884381c26 100644 --- a/candle-pyo3/Cargo.toml +++ b/candle-pyo3/Cargo.toml @@ -19,6 +19,7 @@ candle = { workspace = true } candle-nn = { workspace = true } candle-onnx = { workspace = true, optional = true } half = { workspace = true } +float8 = { workspace = true } intel-mkl-src = { workspace = true, optional = true } pyo3 = { version = "0.22.0", features = ["extension-module", "abi3-py311"] } diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index b8695cc8a0..85bf7c4710 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -1,4 +1,5 @@ #![allow(clippy::redundant_closure_call)] +use float8::F8E4M3; use pyo3::exceptions::{PyTypeError, PyValueError}; use pyo3::prelude::*; use pyo3::pyclass::CompareOp; @@ -157,6 +158,7 @@ pydtype!(f16, f32::from); pydtype!(bf16, f32::from); pydtype!(f32, |v| v); pydtype!(f64, |v| v); +pydtype!(F8E4M3, f32::from); fn actual_index(t: &Tensor, dim: usize, index: i64) -> ::candle::Result { let dim = t.dim(dim)?; @@ -204,6 +206,7 @@ trait MapDType { DType::F16 => self.f::(t), DType::F32 => self.f::(t), DType::F64 => self.f::(t), + DType::F8E4M3 => self.f::(t), } } } diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index df1de0b276..269c2a02da 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -28,6 +28,7 @@ pub mod colpali; pub mod convmixer; pub mod convnext; pub mod dac; +pub mod deepseekv3; pub mod depth_anything_v2; pub mod dinov2; pub mod dinov2reg4; From dcfc5632de83cb4481018896a0891518cef01e2f Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Mon, 27 Jan 2025 13:06:50 -0500 Subject: [PATCH 3/8] Add ops --- .../src/models/deepseekv3/model.rs | 45 +---- .../src/models/deepseekv3/ops.rs | 163 ++++++++++++++++++ 2 files changed, 164 insertions(+), 44 deletions(-) create mode 100644 candle-transformers/src/models/deepseekv3/ops.rs diff --git a/candle-transformers/src/models/deepseekv3/model.rs b/candle-transformers/src/models/deepseekv3/model.rs index 55971cf36f..79ef3406f2 100644 --- a/candle-transformers/src/models/deepseekv3/model.rs +++ b/candle-transformers/src/models/deepseekv3/model.rs @@ -33,50 +33,6 @@ impl NonZero { } } -#[cfg(feature = "cuda")] -fn count_nonzero_cuda(dtype: candle::DType, d_in: *const c_void, n: u32) -> u32 { - unsafe { - match dtype { - candle::DType::U8 => ffi::count_nonzero_u8(d_in, n), - candle::DType::U32 => ffi::count_nonzero_u32(d_in, n), - candle::DType::I64 => ffi::count_nonzero_i64(d_in, n), - candle::DType::I16 => ffi::count_nonzero_i16(d_in, n), - candle::DType::I32 => ffi::count_nonzero_i32(d_in, n), - candle::DType::BF16 => ffi::count_nonzero_bf16(d_in, n), - candle::DType::F16 => ffi::count_nonzero_f16(d_in, n), - candle::DType::F32 => ffi::count_nonzero_f32(d_in, n), - candle::DType::F64 => ffi::count_nonzero_f64(d_in, n), - candle::DType::F8E4M3 => todo!(), - } - } -} - -#[cfg(feature = "cuda")] -fn nonzero_cuda( - dtype: candle::DType, - d_in: *const c_void, - n: u32, - num_nonzero: u32, - dims: *const c_void, - num_dims: u32, - d_out: *mut c_void, -) { - unsafe { - match dtype { - candle::DType::U8 => ffi::nonzero_u8(d_in, n, num_nonzero, dims, num_dims, d_out), - candle::DType::U32 => ffi::nonzero_u32(d_in, n, 
num_nonzero, dims, num_dims, d_out), - candle::DType::I64 => ffi::nonzero_i64(d_in, n, num_nonzero, dims, num_dims, d_out), - candle::DType::I32 => ffi::nonzero_i64(d_in, n, num_nonzero, dims, num_dims, d_out), - candle::DType::I16 => ffi::nonzero_i16(d_in, n, num_nonzero, dims, num_dims, d_out), - candle::DType::BF16 => ffi::nonzero_bf16(d_in, n, num_nonzero, dims, num_dims, d_out), - candle::DType::F16 => ffi::nonzero_f16(d_in, n, num_nonzero, dims, num_dims, d_out), - candle::DType::F32 => ffi::nonzero_f32(d_in, n, num_nonzero, dims, num_dims, d_out), - candle::DType::F64 => ffi::nonzero_f64(d_in, n, num_nonzero, dims, num_dims, d_out), - candle::DType::F8E4M3 => todo!(), - } - } -} - impl CustomOp1 for NonZero { fn name(&self) -> &'static str { "nonzero" @@ -94,6 +50,7 @@ impl CustomOp1 for NonZero { candle::CpuStorage::F16(vs) => self.nonzero(vs, layout), candle::CpuStorage::F32(vs) => self.nonzero(vs, layout), candle::CpuStorage::F64(vs) => self.nonzero(vs, layout), + candle::CpuStorage::F8E4M3(vs) => self.nonzero(vs, layout), }; let index_len = layout.dims().len(); let result_len = result.len() / index_len; diff --git a/candle-transformers/src/models/deepseekv3/ops.rs b/candle-transformers/src/models/deepseekv3/ops.rs new file mode 100644 index 0000000000..44090af2f7 --- /dev/null +++ b/candle-transformers/src/models/deepseekv3/ops.rs @@ -0,0 +1,163 @@ +use candle::{CpuStorage, CustomOp2, DType, Result, Tensor, WithDType}; +use float8::F8E4M3; +use rayon::iter::{IntoParallelIterator, ParallelIterator}; + +struct Fp8BlockwiseDequantize { + weight_block_size: Vec, + out_ty: DType, +} + +impl Fp8BlockwiseDequantize { + fn dispatch_dequant_blockwise( + &self, + weight: &[F8E4M3], + scale: &[f32], + weight_l: &candle::Layout, + scale_l: &candle::Layout, + ) -> candle::Result> { + let grid_y = weight_l.dim(0)?.div_ceil(self.weight_block_size[0]); + let grid_x = weight_l.dim(1)?.div_ceil(self.weight_block_size[1]); + + let res = vec![T::zero(); weight.len()]; + + (0..grid_y).into_par_iter().for_each(|y| { + (0..grid_x).into_par_iter().for_each(|x| { + let res_ptr = res.as_ptr() as *mut T; + + let scale = scale[y * scale_l.stride()[0] + x]; + + let start_y = y * self.weight_block_size[0]; + let end_y = start_y + self.weight_block_size[0]; + + let start_x = x * self.weight_block_size[1]; + let end_x = start_x + self.weight_block_size[1]; + + for weight_y in start_y..end_y { + if weight_y >= weight_l.dims()[0] { + break; + } + + let row_offset = weight_y * weight_l.stride()[0]; + for weight_x in start_x..end_x { + if weight_x >= weight_l.dims()[1] { + break; + } + + let weight_pos = row_offset + weight_x; + + // SAFETY: We know each thread will only update indepedant values! 
+ unsafe { + *res_ptr.wrapping_add(weight_pos) = + T::from_f64((weight[weight_pos].to_f32() * scale) as f64); + } + } + } + }); + }); + + Ok(res) + } +} + +impl CustomOp2 for Fp8BlockwiseDequantize { + fn name(&self) -> &'static str { + "fp8-blockwise-dequantize" + } + + fn cpu_fwd( + &self, + scale_s: &candle::CpuStorage, + scale_l: &candle::Layout, + weight_s: &candle::CpuStorage, + weight_l: &candle::Layout, + ) -> candle::Result<(candle::CpuStorage, candle::Shape)> { + let candle::CpuStorage::F8E4M3(weight) = weight_s else { + candle::bail!("Expected F8E4M3 weight!"); + }; + let candle::CpuStorage::F32(scale) = scale_s else { + candle::bail!("Expected F8E4M3 weight!"); + }; + if weight_l.start_offset() != 0 || !weight_l.is_contiguous() { + candle::bail!("Expected weight to have start offset 0, continuous"); + } + if scale_l.start_offset() != 0 || !scale_l.is_contiguous() { + candle::bail!("Expected scales to have start offset 0, continuous"); + } + if weight_l.dims().len() != 2 { + candle::bail!("Expected weight to be rank 2"); + } + if scale_l.dims().len() != 2 || self.weight_block_size.len() != 2 { + candle::bail!("Expected scale to be rank 2"); + } + + match self.out_ty { + DType::F32 => Ok(( + CpuStorage::F32(self.dispatch_dequant_blockwise(weight, scale, weight_l, scale_l)?), + weight_l.shape().clone(), + )), + DType::BF16 => Ok(( + CpuStorage::BF16( + self.dispatch_dequant_blockwise(weight, scale, weight_l, scale_l)?, + ), + weight_l.shape().clone(), + )), + DType::F16 => Ok(( + CpuStorage::F16(self.dispatch_dequant_blockwise(weight, scale, weight_l, scale_l)?), + weight_l.shape().clone(), + )), + other => candle::bail!("unexpected out type of fp8 blockwise dequant {other:?}"), + } + } +} + +/// FP8 blockwise dequantize. +/// - Expects weight to be fp8 +/// - Expects inv_scales to be f32 +/// - weight * inv_scale = dequantized +/// - Only works on the CPU +pub fn fp8_blockwise_dequantize( + weight: &Tensor, + inv_scales: &Tensor, + weight_block_size: Vec, + out_ty: DType, +) -> Result { + inv_scales.apply_op2_no_bwd( + weight, + &Fp8BlockwiseDequantize { + weight_block_size, + out_ty, + }, + ) +} + +#[cfg(test)] +mod tests { + use candle::{DType, Device, Result, Tensor}; + + use crate::models::deepseekv3::ops; + + #[test] + fn test_fp8_blockwise_dequant() -> Result<()> { + let dev = &Device::Cpu; + let weight = Tensor::ones((5, 5), DType::F8E4M3, dev)?; + let weight_block_size = vec![2, 2]; + let inv_scales = Tensor::arange(0f32, (3 * 3) as f32, dev)?.reshape((3, 3))?; + + let dequant = + ops::fp8_blockwise_dequantize(&weight, &inv_scales, weight_block_size, DType::F32)?; + + let res = dequant.to_vec2::()?; + assert_eq!( + res, + vec![ + vec![0., 0., 1., 1., 2.], + vec![0., 0., 1., 1., 2.], + vec![3., 3., 4., 4., 5.], + vec![3., 3., 4., 4., 5.], + vec![6., 6., 7., 7., 8.], + ] + ); + + Ok(()) + } +} From 28cb4bb3a5075d4d12d76b705e249031fc9b1350 Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Mon, 27 Jan 2025 13:34:45 -0500 Subject: [PATCH 4/8] Add blockwise fp8 linear --- candle-core/src/cpu_backend/utils.rs | 78 ------------ candle-core/src/cuda_backend/device.rs | 3 +- candle-core/src/sort.rs | 36 +----- candle-transformers/Cargo.toml | 1 + .../src/models/deepseekv3/mod.rs | 3 +- .../src/models/deepseekv3/model.rs | 89 ++++++++++---- .../src/models/deepseekv3/quant.rs | 116 ++++++++++++++++++ 7 files changed, 187 insertions(+), 139 deletions(-) create mode 100644 candle-transformers/src/models/deepseekv3/quant.rs diff --git a/candle-core/src/cpu_backend/utils.rs 
b/candle-core/src/cpu_backend/utils.rs index f61005a9b0..6a7e32b00e 100644 --- a/candle-core/src/cpu_backend/utils.rs +++ b/candle-core/src/cpu_backend/utils.rs @@ -61,84 +61,6 @@ pub trait Map2 { } } -pub trait Map3 { - const OP: &'static str; - #[allow(clippy::too_many_arguments)] - fn f( - &self, - v1: &[T], - l1: &Layout, - v2: &[T], - l2: &Layout, - v3: &mut [T], - l3: &Layout, - s: Option, - ) -> Result<()>; - - #[allow(clippy::too_many_arguments)] - fn map( - &self, - v1: &C, - l1: &Layout, - v2: &C, - l2: &Layout, - v3: &mut C, - l3: &Layout, - s: Option, - ) -> Result<()> { - let v3d = v3.dtype(); - match (v1, v2, v3) { - (C::U8(v1), C::U8(v2), C::U8(v3)) => Ok(self.f(v1, l1, v2, l2, v3, l3, s)?), - (C::U32(v1), C::U32(v2), C::U32(v3)) => Ok(self.f(v1, l1, v2, l2, v3, l3, s)?), - (C::I64(v1), C::I64(v2), C::I64(v3)) => Ok(self.f(v1, l1, v2, l2, v3, l3, s)?), - (C::BF16(v1), C::BF16(v2), C::BF16(v3)) => Ok(self.f(v1, l1, v2, l2, v3, l3, s)?), - (C::F16(v1), C::F16(v2), C::F16(v3)) => Ok(self.f(v1, l1, v2, l2, v3, l3, s)?), - (C::F32(v1), C::F32(v2), C::F32(v3)) => Ok(self.f(v1, l1, v2, l2, v3, l3, s)?), - (C::F64(v1), C::F64(v2), C::F64(v3)) => Ok(self.f(v1, l1, v2, l2, v3, l3, s)?), - (C::F8E4M3(v1), C::F8E4M3(v2), C::F8E4M3(v3)) => Ok(self.f(v1, l1, v2, l2, v3, l3, s)?), - _ => Err(Error::DTypeMismatchBinaryOp3 { - lhs: v1.dtype(), - rhs: v2.dtype(), - c: v3d, - op: Self::OP, - } - .bt()), - } - } -} - -pub trait Map2Alpha { - const OP: &'static str; - #[allow(clippy::too_many_arguments)] - fn f( - &self, - v1: &[T], - l1: &Layout, - v2: &[T], - l2: &Layout, - s: Option, - ) -> Result>; - - #[allow(clippy::too_many_arguments)] - fn map(&self, v1: &C, l1: &Layout, v2: &C, l2: &Layout, s: Option) -> Result { - match (v1, v2) { - (C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2, s)?)), - (C::U32(v1), C::U32(v2)) => Ok(C::U32(self.f(v1, l1, v2, l2, s)?)), - (C::I64(v1), C::I64(v2)) => Ok(C::I64(self.f(v1, l1, v2, l2, s)?)), - (C::BF16(v1), C::BF16(v2)) => Ok(C::BF16(self.f(v1, l1, v2, l2, s)?)), - (C::F16(v1), C::F16(v2)) => Ok(C::F16(self.f(v1, l1, v2, l2, s)?)), - (C::F32(v1), C::F32(v2)) => Ok(C::F32(self.f(v1, l1, v2, l2, s)?)), - (C::F64(v1), C::F64(v2)) => Ok(C::F64(self.f(v1, l1, v2, l2, s)?)), - _ => Err(Error::DTypeMismatchBinaryOp { - lhs: v1.dtype(), - rhs: v2.dtype(), - op: Self::OP, - } - .bt()), - } - } -} - pub trait Map2U8 { const OP: &'static str; fn f(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result>; diff --git a/candle-core/src/cuda_backend/device.rs b/candle-core/src/cuda_backend/device.rs index cc597cc702..d9ec3e28ef 100644 --- a/candle-core/src/cuda_backend/device.rs +++ b/candle-core/src/cuda_backend/device.rs @@ -269,8 +269,7 @@ impl BackendDevice for CudaDevice { let slice = match dtype { // TODO: Add support for F16 and BF16 though this is likely to require some upstream // cudarc changes. 
- DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 - | DType::F8E4M3 => { + DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 | DType::F8E4M3 => { Err(CudaError::UnsupportedDtype { dtype, op: "rand_uniform", diff --git a/candle-core/src/sort.rs b/candle-core/src/sort.rs index 9f741da84f..7c84701bc6 100644 --- a/candle-core/src/sort.rs +++ b/candle-core/src/sort.rs @@ -52,32 +52,6 @@ impl ArgSort { } } -impl crate::CustomOp1 for ArgSort { - fn name(&self) -> &'static str { - "argsort" - } - - fn cpu_fwd( - &self, - storage: &crate::CpuStorage, - layout: &crate::Layout, - ) -> Result<(crate::CpuStorage, crate::Shape)> { - let sort_indexes = match storage { - crate::CpuStorage::U8(vs) => self.asort(vs, layout), - crate::CpuStorage::U32(vs) => self.asort(vs, layout), - crate::CpuStorage::I16(vs) => self.asort(vs, layout), - crate::CpuStorage::I32(vs) => self.asort(vs, layout), - crate::CpuStorage::I64(vs) => self.asort(vs, layout), - crate::CpuStorage::BF16(vs) => self.asort(vs, layout), - crate::CpuStorage::F16(vs) => self.asort(vs, layout), - crate::CpuStorage::F32(vs) => self.asort(vs, layout), - crate::CpuStorage::F64(vs) => self.asort(vs, layout), - crate::CpuStorage::F8E4M3(vs) => self.asort(vs, layout), - }; - let sort_indexes = crate::CpuStorage::U32(sort_indexes); - Ok((sort_indexes, layout.shape().into())) - } - #[cfg(feature = "cuda")] mod cuda { use super::*; @@ -139,6 +113,7 @@ impl crate::CustomOp1 for ArgSort { crate::CpuStorage::F16(vs) => self.asort(vs, layout), crate::CpuStorage::F32(vs) => self.asort(vs, layout), crate::CpuStorage::F64(vs) => self.asort(vs, layout), + crate::CpuStorage::F8E4M3(vs) => self.asort(vs, layout), }; let sort_indexes = crate::CpuStorage::U32(sort_indexes); Ok((sort_indexes, layout.shape().into())) @@ -224,15 +199,6 @@ impl crate::CustomOp1 for ArgSort { } } -#[allow(unused)] -fn next_power_of_2(x: usize) -> usize { - let mut n = 1; - while n < x { - n *= 2 - } - n -} - impl Tensor { /// Returns the indices that sort the tensor along the last dimension. 
/// diff --git a/candle-transformers/Cargo.toml b/candle-transformers/Cargo.toml index 6589b4b146..c9d4860e74 100644 --- a/candle-transformers/Cargo.toml +++ b/candle-transformers/Cargo.toml @@ -24,6 +24,7 @@ serde = { workspace = true } serde_json = { workspace = true } serde_plain = { workspace = true } tracing = { workspace = true } +float8 = { workspace = true } [features] default = [] diff --git a/candle-transformers/src/models/deepseekv3/mod.rs b/candle-transformers/src/models/deepseekv3/mod.rs index 766b13284b..bec5297816 100644 --- a/candle-transformers/src/models/deepseekv3/mod.rs +++ b/candle-transformers/src/models/deepseekv3/mod.rs @@ -1,2 +1,3 @@ -mod ops; pub mod model; +mod ops; +mod quant; diff --git a/candle-transformers/src/models/deepseekv3/model.rs b/candle-transformers/src/models/deepseekv3/model.rs index 79ef3406f2..cf7dfc4b56 100644 --- a/candle-transformers/src/models/deepseekv3/model.rs +++ b/candle-transformers/src/models/deepseekv3/model.rs @@ -6,10 +6,12 @@ use candle::{ shape::Dim, CpuStorage, CustomOp1, DType, Device, Error, IndexOp, Layout, Result, Shape, Tensor, WithDType, D, }; -use candle_nn::{embedding, rms_norm, Activation, Embedding, Linear, Module, RmsNorm, VarBuilder}; +use candle_nn::{embedding, rms_norm, Activation, Embedding, Module, RmsNorm, VarBuilder}; use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; use serde::Deserialize; +use super::quant::{self, BlockwiseFP8Linear, QuantizedConfig}; + struct NonZero {} impl NonZero { @@ -272,6 +274,7 @@ pub struct DeepSeekV2Config { pub(crate) qk_nope_head_dim: usize, pub(crate) n_group: usize, pub(crate) topk_group: usize, + pub(crate) quantization_config: Option, } #[derive(Debug, Clone, Deserialize)] @@ -285,7 +288,7 @@ pub enum ScaledRopeType { #[serde(alias = "dynamic")] Dynamic, #[serde(alias = "linear")] - Linear, + BlockwiseFP8Linear, } impl FromStr for ScaledRopeType { @@ -294,7 +297,7 @@ impl FromStr for ScaledRopeType { match s { "su" | "longrope" => Ok(Self::Su), "yarn" => Ok(Self::Yarn), - "linear" => Ok(Self::Linear), + "linear" => Ok(Self::BlockwiseFP8Linear), "dynamic" => Ok(Self::Dynamic), _ => Err(candle::Error::Msg( "Expected either `su` or `yarn` scaled RoPE type.".to_string(), @@ -322,7 +325,7 @@ pub enum DeepSeekV2RopeScaling { #[serde(rename = "type")] scaling_type: ScaledRopeType, }, - LinearOrDynamic { + BlockwiseFP8LinearOrDynamic { #[serde(rename = "type")] scaling_type: ScaledRopeType, factor: f64, @@ -452,7 +455,7 @@ impl DeepSeekV2RotaryEmbedding { pub fn new(cfg: &DeepSeekV2RopeConfig, dtype: DType, dev: &Device) -> Result { match &cfg.rope_scaling { - Some(DeepSeekV2RopeScaling::LinearOrDynamic { + Some(DeepSeekV2RopeScaling::BlockwiseFP8LinearOrDynamic { scaling_type: _, factor: _, }) => candle::bail!("linear and dynamic rope are not implemented yet!"), @@ -518,8 +521,12 @@ impl DeepSeekV2Config { } enum QProj { - Plain(Linear), - Lora { a: Linear, norm: RmsNorm, b: Linear }, + Plain(BlockwiseFP8Linear), + Lora { + a: BlockwiseFP8Linear, + norm: RmsNorm, + b: BlockwiseFP8Linear, + }, } impl QProj { @@ -533,10 +540,10 @@ impl QProj { struct Attention { q: QProj, - kv_a_proj_with_mqa: Linear, + kv_a_proj_with_mqa: BlockwiseFP8Linear, kv_a_layernorm: RmsNorm, - kv_b_proj: Linear, - o_proj: Linear, + kv_b_proj: BlockwiseFP8Linear, + o_proj: BlockwiseFP8Linear, rotary_emb: Arc, cfg: DeepSeekV2Config, q_head_dim: usize, @@ -552,44 +559,59 @@ impl Attention { let q_head_dim = cfg.q_head_dim(); let q = match cfg.q_lora_rank { Some(lora_rank) => { - let a = 
candle_nn::linear_b( + let a = quant::blockwise_fp8_linear_b( cfg.hidden_size, lora_rank, + &cfg.quantization_config, cfg.attention_bias, + None, vb.pp("q_a_proj"), )?; let norm = rms_norm(lora_rank, cfg.rms_norm_eps, vb.pp("q_a_layernorm"))?; - let b = candle_nn::linear_no_bias( + let b = quant::blockwise_fp8_linear_b( lora_rank, cfg.num_attention_heads * q_head_dim, + &cfg.quantization_config, + false, + None, vb.pp("q_b_proj"), )?; QProj::Lora { a, norm, b } } - None => QProj::Plain(candle_nn::linear_no_bias( + None => QProj::Plain(quant::blockwise_fp8_linear_b( cfg.hidden_size, cfg.num_attention_heads * q_head_dim, + &cfg.quantization_config, + false, + None, vb.pp("q_proj"), )?), }; - let kv_a_proj_with_mqa = candle_nn::linear_b( + let kv_a_proj_with_mqa = quant::blockwise_fp8_linear_b( cfg.hidden_size, cfg.kv_lora_rank + cfg.qk_rope_head_dim, + &cfg.quantization_config, cfg.attention_bias, + None, vb.pp("kv_a_proj_with_mqa"), )?; let kv_a_layernorm = rms_norm(cfg.kv_lora_rank, cfg.rms_norm_eps, vb.pp("kv_a_layernorm"))?; - let kv_b_proj = candle_nn::linear_no_bias( + let kv_b_proj = quant::blockwise_fp8_linear_b( cfg.kv_lora_rank, cfg.num_attention_heads * (q_head_dim - cfg.qk_rope_head_dim + cfg.v_head_dim), + &cfg.quantization_config, + false, + None, vb.pp("kv_b_proj"), )?; - let o_proj = candle_nn::linear_b( + let o_proj = quant::blockwise_fp8_linear_b( cfg.num_attention_heads * cfg.v_head_dim, cfg.hidden_size, + &cfg.quantization_config, cfg.attention_bias, + None, vb.pp("o_proj"), )?; @@ -725,9 +747,9 @@ impl Attention { } struct Mlp { - gate: Linear, - up: Linear, - down: Linear, + gate: BlockwiseFP8Linear, + up: BlockwiseFP8Linear, + down: BlockwiseFP8Linear, act: Activation, } @@ -742,9 +764,30 @@ impl Mlp { let intermediate_size = intermediate_size.unwrap_or(cfg.intermediate_size); Ok(Self { - gate: candle_nn::linear_no_bias(hidden_size, intermediate_size, vb.pp("gate_proj"))?, - up: candle_nn::linear_no_bias(hidden_size, intermediate_size, vb.pp("up_proj"))?, - down: candle_nn::linear_no_bias(intermediate_size, hidden_size, vb.pp("down_proj"))?, + gate: quant::blockwise_fp8_linear_b( + hidden_size, + intermediate_size, + &cfg.quantization_config, + false, + None, + vb.pp("gate_proj"), + )?, + up: quant::blockwise_fp8_linear_b( + hidden_size, + intermediate_size, + &cfg.quantization_config, + false, + None, + vb.pp("up_proj"), + )?, + down: quant::blockwise_fp8_linear_b( + intermediate_size, + hidden_size, + &cfg.quantization_config, + false, + None, + vb.pp("down_proj"), + )?, act: cfg.hidden_act, }) } @@ -1045,7 +1088,7 @@ impl DecoderLayer { } pub struct DeepSeekV2 { - lm_head: Linear, + lm_head: candle_nn::Linear, embed_tokens: Embedding, norm: RmsNorm, layers: Vec, diff --git a/candle-transformers/src/models/deepseekv3/quant.rs b/candle-transformers/src/models/deepseekv3/quant.rs new file mode 100644 index 0000000000..9bae465769 --- /dev/null +++ b/candle-transformers/src/models/deepseekv3/quant.rs @@ -0,0 +1,116 @@ +use candle::{ + quantized::{GgmlDType, QMatMul, QTensor}, + DType, Module, Result, Tensor, +}; +use candle_nn::{Linear, VarBuilder}; +use serde::Deserialize; + +use super::ops; + +#[derive(Debug, Clone, Deserialize)] +pub enum QuantMethodType { + #[serde(rename = "fp8")] + Fp8, +} + +#[derive(Debug, Clone, Deserialize)] +pub struct QuantizedConfig { + pub weight_block_size: Option>, + pub quant_method: QuantMethodType, +} + +pub enum BlockwiseFP8Linear { + Quantized { w: QMatMul, b: Option }, + Unquantized(Linear), +} + +impl Module for BlockwiseFP8Linear 
{ + fn forward(&self, xs: &Tensor) -> Result { + match self { + Self::Quantized { w, b } => { + let in_ty = xs.dtype(); + let mut xs = xs.to_dtype(DType::F32)?; + + xs = w.forward(&xs)?; + if let Some(bias) = b { + xs = xs.broadcast_add(bias)?; + } + + xs.to_dtype(in_ty) + } + Self::Unquantized(l) => xs.apply(l), + } + } +} + +/// Load a blockwise quantized FP8 layer and optionally quantize it in-place for faster inference. +pub fn blockwise_fp8_linear_b( + in_dim: usize, + out_dim: usize, + config: &Option, + bias: bool, + quant: Option, + vb: VarBuilder, +) -> Result { + let Some(config) = config else { + return Ok(BlockwiseFP8Linear::Unquantized(candle_nn::linear_b( + in_dim, out_dim, bias, vb, + )?)); + }; + + if !matches!(config.quant_method, QuantMethodType::Fp8) { + candle::bail!("Expected FP8 quant method!"); + } + + let weight_block_size = config + .weight_block_size + .as_ref() + .expect("Blockwise FP8 requires weight_block_size in config"); + if weight_block_size.len() != 2 { + candle::bail!("Expected weight_block_size to have length 2, got {weight_block_size:?}") + } + let weight = vb.get_with_hints_dtype( + (out_dim, in_dim), + "weight", + Default::default(), + DType::F8E4M3, + )?; + let weight_scale_inv = vb.get_with_hints_dtype( + ( + out_dim.div_ceil(weight_block_size[0]), + in_dim.div_ceil(weight_block_size[1]), + ), + "weight_scale_inv", + Default::default(), + DType::F32, + )?; + + let bias_ty = if quant.is_some() { + DType::F32 + } else { + vb.dtype() + }; + + let bias = if bias { + Some(vb.get((out_dim,), "bias")?.to_dtype(bias_ty)?) + } else { + None + }; + + let dequant = ops::fp8_blockwise_dequantize( + &weight, + &weight_scale_inv, + weight_block_size.to_vec(), + vb.dtype(), + )?; + + let layer = match quant { + Some(q) => BlockwiseFP8Linear::Quantized { + w: QMatMul::from_qtensor(QTensor::quantize(&dequant, q)?)?, + b: bias, + }, + None => BlockwiseFP8Linear::Unquantized(Linear::new(dequant, None)), + }; + + Ok(layer) +} From e91d89a0914691fece34c8c43cdfe7aa4cc8109d Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Mon, 27 Jan 2025 13:36:40 -0500 Subject: [PATCH 5/8] Expose quant setting --- .../src/models/deepseekv3/model.rs | 51 ++++++++++++------- 1 file changed, 34 insertions(+), 17 deletions(-) diff --git a/candle-transformers/src/models/deepseekv3/model.rs b/candle-transformers/src/models/deepseekv3/model.rs index cf7dfc4b56..da89d6af25 100644 --- a/candle-transformers/src/models/deepseekv3/model.rs +++ b/candle-transformers/src/models/deepseekv3/model.rs @@ -3,8 +3,8 @@ use std::{f32::consts::PI, str::FromStr, sync::Arc}; use candle::{ - shape::Dim, CpuStorage, CustomOp1, DType, Device, Error, IndexOp, Layout, Result, Shape, - Tensor, WithDType, D, + quantized::GgmlDType, shape::Dim, CpuStorage, CustomOp1, DType, Device, Error, IndexOp, Layout, + Result, Shape, Tensor, WithDType, D, }; use candle_nn::{embedding, rms_norm, Activation, Embedding, Module, RmsNorm, VarBuilder}; use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; @@ -555,6 +555,7 @@ impl Attention { rotary_emb: Arc, cfg: &DeepSeekV2Config, vb: VarBuilder, + quant: Option, ) -> Result { let q_head_dim = cfg.q_head_dim(); let q = match cfg.q_lora_rank { @@ -564,7 +565,7 @@ impl Attention { lora_rank, &cfg.quantization_config, cfg.attention_bias, - None, + quant, vb.pp("q_a_proj"), )?; let norm = rms_norm(lora_rank, cfg.rms_norm_eps, vb.pp("q_a_layernorm"))?; @@ -573,7 +574,7 @@ impl Attention { cfg.num_attention_heads * q_head_dim, &cfg.quantization_config, false, - None, + quant, 
vb.pp("q_b_proj"), )?; QProj::Lora { a, norm, b } @@ -583,7 +584,7 @@ impl Attention { cfg.num_attention_heads * q_head_dim, &cfg.quantization_config, false, - None, + quant, vb.pp("q_proj"), )?), }; @@ -593,7 +594,7 @@ impl Attention { cfg.kv_lora_rank + cfg.qk_rope_head_dim, &cfg.quantization_config, cfg.attention_bias, - None, + quant, vb.pp("kv_a_proj_with_mqa"), )?; let kv_a_layernorm = rms_norm(cfg.kv_lora_rank, cfg.rms_norm_eps, vb.pp("kv_a_layernorm"))?; @@ -602,7 +603,7 @@ impl Attention { cfg.num_attention_heads * (q_head_dim - cfg.qk_rope_head_dim + cfg.v_head_dim), &cfg.quantization_config, false, - None, + quant, vb.pp("kv_b_proj"), )?; @@ -611,7 +612,7 @@ impl Attention { cfg.hidden_size, &cfg.quantization_config, cfg.attention_bias, - None, + quant, vb.pp("o_proj"), )?; @@ -759,6 +760,7 @@ impl Mlp { vb: VarBuilder, hidden_size: Option, intermediate_size: Option, + quant: Option, ) -> Result { let hidden_size = hidden_size.unwrap_or(cfg.hidden_size); let intermediate_size = intermediate_size.unwrap_or(cfg.intermediate_size); @@ -769,7 +771,7 @@ impl Mlp { intermediate_size, &cfg.quantization_config, false, - None, + quant, vb.pp("gate_proj"), )?, up: quant::blockwise_fp8_linear_b( @@ -777,7 +779,7 @@ impl Mlp { intermediate_size, &cfg.quantization_config, false, - None, + quant, vb.pp("up_proj"), )?, down: quant::blockwise_fp8_linear_b( @@ -785,7 +787,7 @@ impl Mlp { hidden_size, &cfg.quantization_config, false, - None, + quant, vb.pp("down_proj"), )?, act: cfg.hidden_act, @@ -937,14 +939,20 @@ impl Moe { fn new( cfg: &DeepSeekV2Config, vb: VarBuilder, - n_shared_experts: Option, n_routed_experts: usize, + quant: Option, ) -> Result { let mut experts = Vec::with_capacity(n_routed_experts); for i in 0..n_routed_experts { let vb_e = vb.pp("experts").pp(i); - experts.push(Mlp::new(cfg, vb_e, None, Some(cfg.moe_intermediate_size))?); + experts.push(Mlp::new( + cfg, + vb_e, + None, + Some(cfg.moe_intermediate_size), + quant, + )?); } let shared_experts = if let Some(n_shared_experts) = n_shared_experts { let intermediate_size = cfg.moe_intermediate_size * n_shared_experts; @@ -953,6 +961,7 @@ impl Moe { vb.pp("shared_experts"), None, Some(intermediate_size), + quant, )?) } else { None @@ -1038,8 +1047,9 @@ impl DecoderLayer { cfg: &DeepSeekV2Config, vb: VarBuilder, layer_idx: usize, + quant: Option, ) -> Result { - let attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?; + let attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"), quant)?; let input_layernorm = rms_norm(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?; let post_attention_layernorm = rms_norm( @@ -1056,9 +1066,10 @@ impl DecoderLayer { vb.pp("mlp"), cfg.n_shared_experts, cfg.n_routed_experts.unwrap(), + quant, )?) } else { - MoeOrMlp::Mlp(Mlp::new(cfg, vb.pp("mlp"), None, None)?) + MoeOrMlp::Mlp(Mlp::new(cfg, vb.pp("mlp"), None, None, quant)?) 
}; Ok(Self { @@ -1097,7 +1108,7 @@ pub struct DeepSeekV2 { } impl DeepSeekV2 { - pub fn new(cfg: &DeepSeekV2Config, vb: VarBuilder) -> Result { + pub fn new(cfg: &DeepSeekV2Config, vb: VarBuilder, quant: Option) -> Result { let vb_m = vb.pp("model"); let embed_tokens = embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?; @@ -1123,7 +1134,13 @@ impl DeepSeekV2 { let mut layers = Vec::with_capacity(cfg.num_hidden_layers); let vb_l = vb_m.pp("layers"); for layer_idx in 0..cfg.num_hidden_layers { - let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx), layer_idx)?; + let layer = DecoderLayer::new( + rotary_emb.clone(), + cfg, + vb_l.pp(layer_idx), + layer_idx, + quant, + )?; layers.push(layer) } From f0bbeea583dfab353200ec81f4680b50f3140c4c Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Mon, 27 Jan 2025 13:38:58 -0500 Subject: [PATCH 6/8] Fix clippy --- candle-core/src/cpu_backend/mod.rs | 4 ---- 1 file changed, 4 deletions(-) diff --git a/candle-core/src/cpu_backend/mod.rs b/candle-core/src/cpu_backend/mod.rs index 6d1185dbef..09eec61e09 100644 --- a/candle-core/src/cpu_backend/mod.rs +++ b/candle-core/src/cpu_backend/mod.rs @@ -1789,10 +1789,6 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, |v| v.to_f32() as u8); Ok(Self::U8(data)) } - (Self::F8E4M3(storage), DType::U8) => { - let data = unary_map(storage, layout, |v| v.to_f32() as u8); - Ok(Self::U8(data)) - } (Self::U8(storage), DType::U32) => { let data = unary_map(storage, layout, |v| v as u32); Ok(Self::U32(data)) From 406b991cfbf1e610b6061fb2d119950eaea3cfde Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Tue, 28 Jan 2025 08:44:42 -0500 Subject: [PATCH 7/8] Tensor parallelism support --- .vscode/settings.json | 3 +- Cargo.toml | 2 +- candle-core/src/cuda_backend/device.rs | 18 +- candle-core/src/cuda_backend/mod.rs | 160 ------- candle-core/src/sort.rs | 9 + candle-examples/Cargo.toml | 5 + candle-examples/examples/deepseekv3/main.rs | 236 +++++++++ .../examples}/deepseekv3/model.rs | 186 ++++--- .../examples}/deepseekv3/ops.rs | 0 candle-examples/examples/deepseekv3/quant.rs | 452 ++++++++++++++++++ candle-flash-attn/cutlass | 2 +- .../src/models/deepseekv3/mod.rs | 3 - .../src/models/deepseekv3/quant.rs | 116 ----- candle-transformers/src/models/mod.rs | 1 - 14 files changed, 826 insertions(+), 367 deletions(-) create mode 100644 candle-examples/examples/deepseekv3/main.rs rename {candle-transformers/src/models => candle-examples/examples}/deepseekv3/model.rs (89%) rename {candle-transformers/src/models => candle-examples/examples}/deepseekv3/ops.rs (100%) create mode 100644 candle-examples/examples/deepseekv3/quant.rs delete mode 100644 candle-transformers/src/models/deepseekv3/mod.rs delete mode 100644 candle-transformers/src/models/deepseekv3/quant.rs diff --git a/.vscode/settings.json b/.vscode/settings.json index b2dbd68012..9beadf9fb9 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -7,5 +7,6 @@ "candle-pyo3" ], "python.testing.unittestEnabled": false, - "python.testing.pytestEnabled": true + "python.testing.pytestEnabled": true, + "rust-analyzer.cargo.features": ["cuda", "nccl"] } \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index d5a527b105..03459de660 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -48,7 +48,7 @@ fancy-regex = "0.13.0" gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] } hf-hub = "0.4.1" half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] } -float8 = { 
version = "0.1.0", features = ["num-traits", "rand_distr"] } +float8 = { version = "0.1.3", features = ["num-traits", "rand_distr"] } hound = "3.5.1" image = { version = "0.25.2", default-features = false, features = ["jpeg", "png"] } imageproc = { version = "0.24.0", default-features = false } diff --git a/candle-core/src/cuda_backend/device.rs b/candle-core/src/cuda_backend/device.rs index d9ec3e28ef..fa4e8480b7 100644 --- a/candle-core/src/cuda_backend/device.rs +++ b/candle-core/src/cuda_backend/device.rs @@ -313,17 +313,13 @@ impl BackendDevice for CudaDevice { elem_count }; let slice = match dtype { - DType::U8 - | DType::U32 - | DType::I16 - | DType::I32 - | DType::I64 - | DType::F16 - | DType::BF16 => Err(CudaError::UnsupportedDtype { - dtype, - op: "rand_normal", - }) - .w()?, + DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 | DType::F8E4M3 => { + Err(CudaError::UnsupportedDtype { + dtype, + op: "rand_normal", + }) + .w()? + } DType::F32 => { let mut data = unsafe { self.alloc::(elem_count_round) }.w()?; curand diff --git a/candle-core/src/cuda_backend/mod.rs b/candle-core/src/cuda_backend/mod.rs index 8b90fc3e35..d6fed7813b 100644 --- a/candle-core/src/cuda_backend/mod.rs +++ b/candle-core/src/cuda_backend/mod.rs @@ -2117,163 +2117,3 @@ unsafe fn gemm_strided_batched_bf16( sys::cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT_TENSOR_OP, ) } - -pub struct KVConcat { - pub concat_dim: usize, -} -impl crate::CustomOp2 for KVConcat { - fn name(&self) -> &'static str { - "kvconcat" - } - - fn cpu_fwd( - &self, - _: &CpuStorage, - _: &Layout, - _: &CpuStorage, - _: &Layout, - ) -> Result<(CpuStorage, Shape)> { - crate::bail!("no cpu support for kvconcat") - } - - fn cuda_fwd( - &self, - ltensor: &CudaStorage, - ltensor_l: &Layout, - rtensor: &CudaStorage, - rtensor_l: &Layout, - ) -> Result<(CudaStorage, Shape)> { - assert!(self.concat_dim == 2 || self.concat_dim == 0); //must be in the dim of sequence len - let dev = <ensor.device; - let elem_count = ltensor_l.shape().elem_count() + rtensor_l.shape().elem_count(); - let dims_l = ltensor_l.shape().dims(); - let dims_r = rtensor_l.shape().dims(); - let dim_size = dims_l.len(); - let cfg = LaunchConfig::for_num_elems(elem_count as u32); - - let chunk_l = if dim_size > 3 { - dims_l[0] * dims_l[1] - } else { - dims_l[0] - }; - let chunk_r = if dim_size > 3 { - dims_r[0] * dims_r[1] - } else { - dims_r[0] - }; - let lstride = if dim_size > 3 { - dims_l[2] * dims_l[3] - } else { - dims_l[1] * dims_l[2] - }; - let rstride = if dim_size > 3 { - dims_r[2] * dims_r[3] - } else { - dims_r[1] * dims_r[2] - }; - - let slice = match (<ensor.slice, &rtensor.slice) { - (CudaStorageSlice::BF16(left_), CudaStorageSlice::BF16(right_)) => { - let out = unsafe { dev.alloc::(elem_count).w()? }; - let func = dev.get_or_load_func("kvconcat_bf16", kernels::KVCONCAT)?; - let params = ( - left_, - right_, - &out, - self.concat_dim, - chunk_l, - chunk_r, - lstride, - rstride, - ); - unsafe { func.launch(cfg, params) }.w()?; - CudaStorageSlice::BF16(out) - } - (CudaStorageSlice::F32(left_), CudaStorageSlice::F32(right_)) => { - let out = unsafe { dev.alloc::(elem_count).w()? }; - let func = dev.get_or_load_func("kvconcat_f32", kernels::KVCONCAT)?; - let params = ( - left_, - right_, - &out, - self.concat_dim, - chunk_l, - chunk_r, - lstride, - rstride, - ); - unsafe { func.launch(cfg, params) }.w()?; - CudaStorageSlice::F32(out) - } - (CudaStorageSlice::F16(left_), CudaStorageSlice::F16(right_)) => { - let out = unsafe { dev.alloc::(elem_count).w()? 
}; - let func = dev.get_or_load_func("kvconcat_f16", kernels::KVCONCAT)?; - let params = ( - left_, - right_, - &out, - self.concat_dim, - chunk_l, - chunk_r, - lstride, - rstride, - ); - unsafe { func.launch(cfg, params) }.w()?; - CudaStorageSlice::F16(out) - } - (CudaStorageSlice::F64(left_), CudaStorageSlice::F64(right_)) => { - let out = unsafe { dev.alloc::(elem_count).w()? }; - let func = dev.get_or_load_func("kvconcat_f64", kernels::KVCONCAT)?; - let params = ( - left_, - right_, - &out, - self.concat_dim, - chunk_l, - chunk_r, - lstride, - rstride, - ); - unsafe { func.launch(cfg, params) }.w()?; - CudaStorageSlice::F64(out) - } - (CudaStorageSlice::U8(left_), CudaStorageSlice::U8(right_)) => { - let out = unsafe { dev.alloc::(elem_count).w()? }; - let func = dev.get_or_load_func("kvconcat_u8", kernels::KVCONCAT)?; - let params = ( - left_, - right_, - &out, - self.concat_dim, - chunk_l, - chunk_r, - lstride, - rstride, - ); - unsafe { func.launch(cfg, params) }.w()?; - CudaStorageSlice::U8(out) - } - _ => Err(CudaError::InternalError("dtype mismatch in kvconcat op"))?, - }; - - let mut lshape: Vec = ltensor_l.shape().dims().to_vec(); - if self.concat_dim == 0 { - lshape[0] += rtensor_l.shape().dims()[0]; - } else { - if dim_size > 3 { - lshape[2] += rtensor_l.shape().dims()[2]; - } else { - lshape[1] += rtensor_l.shape().dims()[1]; - } - } - - let device = dev.clone(); - Ok(( - CudaStorage { - slice: slice, - device, - }, - lshape.into(), - )) - } -} diff --git a/candle-core/src/sort.rs b/candle-core/src/sort.rs index 7c84701bc6..06b34db912 100644 --- a/candle-core/src/sort.rs +++ b/candle-core/src/sort.rs @@ -61,6 +61,15 @@ mod cuda { use crate::cuda_backend::{kernel_name, kernels, CudaStorageSlice as S, WrapErr}; use crate::{CudaDevice, WithDType}; + #[allow(unused)] + fn next_power_of_2(x: usize) -> usize { + let mut n = 1; + while n < x { + n *= 2 + } + n + } + impl crate::cuda_backend::Map1Any for ArgSort { fn f) -> S>( &self, diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml index e679d01b60..70499de1c1 100644 --- a/candle-examples/Cargo.toml +++ b/candle-examples/Cargo.toml @@ -49,6 +49,7 @@ ab_glyph = { workspace = true } tracing = { workspace = true } tracing-chrome = { workspace = true } tracing-subscriber = { workspace = true } +float8 = { workspace = true } # Necessary to disambiguate with tokio in wasm examples which are 1.28.1 tokio = "1.43.0" @@ -122,3 +123,7 @@ required-features = ["onnx"] [[example]] name = "colpali" required-features = ["pdf2image"] + +[[example]] +name = "deepseekv3" +required-features = ["cuda", "nccl"] diff --git a/candle-examples/examples/deepseekv3/main.rs b/candle-examples/examples/deepseekv3/main.rs new file mode 100644 index 0000000000..f50578ca52 --- /dev/null +++ b/candle-examples/examples/deepseekv3/main.rs @@ -0,0 +1,236 @@ +// An implementation of LLaMA https://github.com/facebookresearch/llama +// +// This is based on nanoGPT in a similar way to: +// https://github.com/Lightning-AI/lit-llama/blob/main/lit_llama/model.py +// +// The tokenizer config can be retrieved from: +// https://huggingface.co/hf-internal-testing/llama-tokenizer/raw/main/tokenizer.json + +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +use anyhow::{bail, Error as E, Result}; +use clap::{Parser, ValueEnum}; + +use candle::{DType, Device, Tensor}; +use candle_transformers::generation::LogitsProcessor; +use cudarc::driver::safe::CudaDevice; +use cudarc::nccl::safe::{Comm, Id}; +use hf_hub::{api::sync::Api, Repo, RepoType}; +use 
std::io::Write; +use std::rc::Rc; + +mod model; +mod ops; +mod quant; +use model::{DeepSeekV3, DeepSeekV3Config}; + +const DEFAULT_PROMPT: &str = "My favorite theorem is "; + +#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)] +enum Which { + R1, + V3, +} + +#[derive(Parser, Debug)] +#[command(author, version, about, long_about = None)] +struct Args { + #[arg(long)] + num_shards: usize, + + #[arg(long)] + rank: Option, + + /// The temperature used to generate samples. + #[arg(long, default_value_t = 0.8)] + temperature: f64, + + /// Nucleus sampling probability cutoff. + #[arg(long)] + top_p: Option, + + /// The seed to use when generating random samples. + #[arg(long, default_value_t = 299792458)] + seed: u64, + + /// The length of the sample to generate (in tokens). + #[arg(long, default_value_t = 100)] + sample_len: usize, + + /// Disable the key-value cache. + #[arg(long)] + no_kv_cache: bool, + + /// The initial prompt. + #[arg(long)] + prompt: Option, + + #[arg(long)] + model_id: Option, + + #[arg(long)] + revision: Option, + + #[arg(long)] + dtype: Option, + + #[arg(long, default_value = "r1")] + which: Which, + + #[arg(long, default_value = "nccl_id.txt")] + comm_file: String, +} + +fn main() -> Result<()> { + use tokenizers::Tokenizer; + + let args = Args::parse(); + + let dtype = match args.dtype.as_deref() { + Some("f16") => DType::F16, + Some("bf16") => DType::BF16, + Some("f32") => DType::F32, + Some(dtype) => bail!("Unsupported dtype {dtype}"), + None => DType::BF16, + }; + + let comm_file = std::path::PathBuf::from(&args.comm_file); + if comm_file.exists() { + bail!("comm file {comm_file:?} already exists, please remove it first") + } + + let api = Api::new()?; + let model_id = match args.model_id { + Some(model) => model, + None => match args.which { + Which::V3 => "deepseek-ai/DeepSeek-V3".to_string(), + Which::R1 => "deepseek-ai/DeepSeek-R1".to_string(), + }, + }; + println!("loading the model weights from {model_id}"); + let revision = args.revision.unwrap_or("main".to_string()); + let api = api.repo(Repo::with_revision(model_id, RepoType::Model, revision)); + let config_filename = api.get("config.json")?; + let config: DeepSeekV3Config = serde_json::from_slice(&std::fs::read(config_filename)?)?; + let tokenizer_filename = api.get("tokenizer.json")?; + let filenames = candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?; + + let rank = match args.rank { + None => { + println!("creating {} child processes", args.num_shards); + let children: Vec<_> = (0..args.num_shards) + .map(|rank| { + let mut args: std::collections::VecDeque<_> = std::env::args().collect(); + args.push_back("--rank".to_string()); + args.push_back(format!("{rank}")); + let name = args.pop_front().unwrap(); + std::process::Command::new(name).args(args).spawn().unwrap() + }) + .collect(); + for mut child in children { + child.wait()?; + } + return Ok(()); + } + Some(rank) => rank, + }; + + let num_shards = args.num_shards; + // Primitive IPC + let id = if rank == 0 { + let id = Id::new().unwrap(); + let tmp_file = comm_file.with_extension(".comm.tgz"); + std::fs::File::create(&tmp_file)? 
+ .write_all(&id.internal().iter().map(|&i| i as u8).collect::>())?; + std::fs::rename(&tmp_file, &comm_file)?; + id + } else { + while !comm_file.exists() { + std::thread::sleep(std::time::Duration::from_secs(1)); + } + let data = std::fs::read(&comm_file)?; + let internal: [i8; 128] = data + .into_iter() + .map(|i| i as i8) + .collect::>() + .try_into() + .unwrap(); + let id: Id = Id::uninit(internal); + id + }; + let device = CudaDevice::new(rank)?; + let comm = match Comm::from_rank(device, rank, num_shards, id) { + Ok(comm) => Rc::new(comm), + Err(err) => anyhow::bail!("nccl error {:?}", err.0), + }; + if rank == 0 { + std::fs::remove_file(comm_file)?; + } + println!("Rank {rank:?} spawned"); + + let device = Device::new_cuda(rank)?; + + println!("building the model"); + let vb = unsafe { + candle_nn::var_builder::ShardedSafeTensors::var_builder(&filenames, dtype, &device)? + }; + let llama = DeepSeekV3::new(&config, vb, None, comm)?; + let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; + + let prompt = args.prompt.as_ref().map_or(DEFAULT_PROMPT, |p| p.as_str()); + let mut tokens = tokenizer + .encode(prompt, true) + .map_err(E::msg)? + .get_ids() + .to_vec(); + let mut tokenizer = candle_examples::token_output_stream::TokenOutputStream::new(tokenizer); + + println!("starting the inference loop"); + let temperature = if args.temperature <= 0. { + None + } else { + Some(args.temperature) + }; + let mut logits_processor = LogitsProcessor::new(args.seed, temperature, args.top_p); + let mut new_tokens = vec![]; + let mut start_gen = std::time::Instant::now(); + let mut index_pos = 0; + for index in 0..args.sample_len { + // Only start timing at the second token as processing the first token waits for all the + // weights to be loaded in an async way. + if index == 1 { + start_gen = std::time::Instant::now() + }; + let context_size = if index > 0 { 1 } else { tokens.len() }; + let ctxt = &tokens[tokens.len().saturating_sub(context_size)..]; + let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?; + let logits = llama.forward(&input, index_pos)?; + let logits = logits.squeeze(0)?; + index_pos += ctxt.len(); + + let next_token = logits_processor.sample(&logits)?; + tokens.push(next_token); + new_tokens.push(next_token); + if next_token == config.eos_token_id { + break; + } + + if rank == 0 { + if let Some(t) = tokenizer.next_token(next_token)? 
{ + print!("{t}"); + std::io::stdout().flush()?; + } + } + } + println!(); + if rank == 0 { + let dt = start_gen.elapsed(); + println!( + "\n\n{} tokens generated ({} token/s)\n", + args.sample_len, + (args.sample_len - 1) as f64 / dt.as_secs_f64(), + ); + } + Ok(()) +} diff --git a/candle-transformers/src/models/deepseekv3/model.rs b/candle-examples/examples/deepseekv3/model.rs similarity index 89% rename from candle-transformers/src/models/deepseekv3/model.rs rename to candle-examples/examples/deepseekv3/model.rs index da89d6af25..0bf948ad0f 100644 --- a/candle-transformers/src/models/deepseekv3/model.rs +++ b/candle-examples/examples/deepseekv3/model.rs @@ -1,16 +1,21 @@ #![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)] -use std::{f32::consts::PI, str::FromStr, sync::Arc}; +use std::{f32::consts::PI, rc::Rc, str::FromStr, sync::Arc}; use candle::{ quantized::GgmlDType, shape::Dim, CpuStorage, CustomOp1, DType, Device, Error, IndexOp, Layout, Result, Shape, Tensor, WithDType, D, }; -use candle_nn::{embedding, rms_norm, Activation, Embedding, Module, RmsNorm, VarBuilder}; +use candle_nn::{var_builder::ShardedVarBuilder, Activation, Embedding, Linear, Module, RmsNorm}; +use cudarc::nccl::Comm; use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; use serde::Deserialize; -use super::quant::{self, BlockwiseFP8Linear, QuantizedConfig}; +use crate::quant::BlockwiseFP8ReplicatedLinear; + +use super::quant::{ + self, BlockwiseFP8ParallelColumnLinear, BlockwiseFP8ParallelRowLinear, QuantizedConfig, +}; struct NonZero {} @@ -199,8 +204,6 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result Ok(m) } -#[doc(hidden)] -#[macro_export] macro_rules! serde_default_fn { ($t:ty, $name:ident, $v:expr) => { fn $name() -> $t { @@ -235,7 +238,7 @@ enum ScoringFunc { } #[derive(Deserialize, Clone, Debug)] -pub struct DeepSeekV2Config { +pub struct DeepSeekV3Config { pub(crate) vocab_size: usize, pub(crate) hidden_size: usize, pub(crate) intermediate_size: usize, @@ -265,7 +268,7 @@ pub struct DeepSeekV2Config { #[serde(default = "tie_word_embeddings")] pub(crate) tie_word_embeddings: bool, pub(crate) rope_theta: f32, - pub(crate) rope_scaling: Option, + pub(crate) rope_scaling: Option, pub(crate) attention_bias: bool, pub(crate) q_lora_rank: Option, pub(crate) qk_rope_head_dim: usize, @@ -275,6 +278,7 @@ pub struct DeepSeekV2Config { pub(crate) n_group: usize, pub(crate) topk_group: usize, pub(crate) quantization_config: Option, + pub(crate) eos_token_id: u32, } #[derive(Debug, Clone, Deserialize)] @@ -307,14 +311,14 @@ impl FromStr for ScaledRopeType { } #[derive(Debug, Clone)] -pub struct DeepSeekV2RotaryEmbedding { +pub struct DeepSeekV3RotaryEmbedding { sin: Tensor, cos: Tensor, } #[derive(Debug, Clone, Deserialize)] #[serde(untagged)] -pub enum DeepSeekV2RopeScaling { +pub enum DeepSeekV3RopeScaling { Yarn { original_max_position_embeddings: usize, beta_fast: f32, @@ -322,25 +326,18 @@ pub enum DeepSeekV2RopeScaling { mscale: f32, mscale_all_dim: f32, factor: f32, - #[serde(rename = "type")] - scaling_type: ScaledRopeType, - }, - BlockwiseFP8LinearOrDynamic { - #[serde(rename = "type")] - scaling_type: ScaledRopeType, - factor: f64, }, } -pub struct DeepSeekV2RopeConfig { - pub rope_scaling: Option, +pub struct DeepSeekV3RopeConfig { + pub rope_scaling: Option, pub max_position_embeddings: usize, pub rope_theta: f32, pub qk_rope_head_dim: usize, } -impl DeepSeekV2RotaryEmbedding { - fn new_unscaled(cfg: &DeepSeekV2RopeConfig, dtype: DType, dev: &Device) 
-> Result { +impl DeepSeekV3RotaryEmbedding { + fn new_unscaled(cfg: &DeepSeekV3RopeConfig, dtype: DType, dev: &Device) -> Result { let max_seq_len = cfg.max_position_embeddings; let dim = cfg.qk_rope_head_dim; @@ -404,7 +401,7 @@ impl DeepSeekV2RotaryEmbedding { #[allow(clippy::too_many_arguments)] fn new_yarn( - cfg: &DeepSeekV2RopeConfig, + cfg: &DeepSeekV3RopeConfig, dtype: DType, dev: &Device, original_max_position_embeddings: usize, @@ -453,20 +450,15 @@ impl DeepSeekV2RotaryEmbedding { Ok(Self { sin, cos }) } - pub fn new(cfg: &DeepSeekV2RopeConfig, dtype: DType, dev: &Device) -> Result { + pub fn new(cfg: &DeepSeekV3RopeConfig, dtype: DType, dev: &Device) -> Result { match &cfg.rope_scaling { - Some(DeepSeekV2RopeScaling::BlockwiseFP8LinearOrDynamic { - scaling_type: _, - factor: _, - }) => candle::bail!("linear and dynamic rope are not implemented yet!"), - Some(DeepSeekV2RopeScaling::Yarn { + Some(DeepSeekV3RopeScaling::Yarn { original_max_position_embeddings, beta_fast, beta_slow, factor, mscale, mscale_all_dim, - scaling_type: _, }) => Self::new_yarn( cfg, dtype, @@ -500,32 +492,50 @@ impl DeepSeekV2RotaryEmbedding { } } -impl DeepSeekV2Config { +impl DeepSeekV3Config { pub(crate) fn q_head_dim(&self) -> usize { self.qk_rope_head_dim + self.qk_nope_head_dim } fn softmax_scale(&self) -> f32 { let mut softmax_scale = 1.0 / (self.q_head_dim() as f32).sqrt(); - if let Some(DeepSeekV2RopeScaling::Yarn { + if let Some(DeepSeekV3RopeScaling::Yarn { mscale_all_dim, factor, .. }) = self.rope_scaling { - let mscale = DeepSeekV2RotaryEmbedding::yarn_get_mscale(factor, mscale_all_dim); + let mscale = DeepSeekV3RotaryEmbedding::yarn_get_mscale(factor, mscale_all_dim); softmax_scale = softmax_scale * mscale * mscale; } softmax_scale } } +fn shard(dim: usize, rank: usize, world_size: usize) -> candle_nn::var_builder::Shard { + candle_nn::var_builder::Shard { + dim, + rank, + world_size, + } +} + +fn rms_norm(size: usize, eps: f64, vb: ShardedVarBuilder) -> Result { + let weight = vb.get_with_hints(size, "weight", shard(0, 0, 1))?; + Ok(RmsNorm::new(weight, eps)) +} + +fn embedding(in_size: usize, out_size: usize, vb: ShardedVarBuilder) -> Result { + let embeddings = vb.get((in_size, out_size), "weight")?; + Ok(Embedding::new(embeddings, out_size)) +} + enum QProj { - Plain(BlockwiseFP8Linear), + Plain(BlockwiseFP8ParallelColumnLinear), Lora { - a: BlockwiseFP8Linear, + a: BlockwiseFP8ReplicatedLinear, norm: RmsNorm, - b: BlockwiseFP8Linear, + b: BlockwiseFP8ParallelColumnLinear, }, } @@ -540,27 +550,28 @@ impl QProj { struct Attention { q: QProj, - kv_a_proj_with_mqa: BlockwiseFP8Linear, + kv_a_proj_with_mqa: BlockwiseFP8ReplicatedLinear, kv_a_layernorm: RmsNorm, - kv_b_proj: BlockwiseFP8Linear, - o_proj: BlockwiseFP8Linear, - rotary_emb: Arc, - cfg: DeepSeekV2Config, + kv_b_proj: BlockwiseFP8ParallelColumnLinear, + o_proj: BlockwiseFP8ParallelRowLinear, + rotary_emb: Arc, + cfg: DeepSeekV3Config, q_head_dim: usize, softmax_scale: f64, } impl Attention { fn new( - rotary_emb: Arc, - cfg: &DeepSeekV2Config, - vb: VarBuilder, + rotary_emb: Arc, + cfg: &DeepSeekV3Config, + vb: ShardedVarBuilder, quant: Option, + comm: Rc, ) -> Result { let q_head_dim = cfg.q_head_dim(); let q = match cfg.q_lora_rank { Some(lora_rank) => { - let a = quant::blockwise_fp8_linear_b( + let a = quant::blockwise_fp8_linear_b_replicated( cfg.hidden_size, lora_rank, &cfg.quantization_config, @@ -569,27 +580,29 @@ impl Attention { vb.pp("q_a_proj"), )?; let norm = rms_norm(lora_rank, cfg.rms_norm_eps, 
vb.pp("q_a_layernorm"))?; - let b = quant::blockwise_fp8_linear_b( + let b = quant::blockwise_fp8_linear_b_parallel_column( lora_rank, cfg.num_attention_heads * q_head_dim, &cfg.quantization_config, false, quant, + comm.clone(), vb.pp("q_b_proj"), )?; QProj::Lora { a, norm, b } } - None => QProj::Plain(quant::blockwise_fp8_linear_b( + None => QProj::Plain(quant::blockwise_fp8_linear_b_parallel_column( cfg.hidden_size, cfg.num_attention_heads * q_head_dim, &cfg.quantization_config, false, quant, + comm.clone(), vb.pp("q_proj"), )?), }; - let kv_a_proj_with_mqa = quant::blockwise_fp8_linear_b( + let kv_a_proj_with_mqa = quant::blockwise_fp8_linear_b_replicated( cfg.hidden_size, cfg.kv_lora_rank + cfg.qk_rope_head_dim, &cfg.quantization_config, @@ -598,21 +611,23 @@ impl Attention { vb.pp("kv_a_proj_with_mqa"), )?; let kv_a_layernorm = rms_norm(cfg.kv_lora_rank, cfg.rms_norm_eps, vb.pp("kv_a_layernorm"))?; - let kv_b_proj = quant::blockwise_fp8_linear_b( + let kv_b_proj = quant::blockwise_fp8_linear_b_parallel_column( cfg.kv_lora_rank, cfg.num_attention_heads * (q_head_dim - cfg.qk_rope_head_dim + cfg.v_head_dim), &cfg.quantization_config, false, quant, + comm.clone(), vb.pp("kv_b_proj"), )?; - let o_proj = quant::blockwise_fp8_linear_b( + let o_proj = quant::blockwise_fp8_linear_b_parallel_row( cfg.num_attention_heads * cfg.v_head_dim, cfg.hidden_size, &cfg.quantization_config, cfg.attention_bias, quant, + comm, vb.pp("o_proj"), )?; @@ -748,46 +763,50 @@ impl Attention { } struct Mlp { - gate: BlockwiseFP8Linear, - up: BlockwiseFP8Linear, - down: BlockwiseFP8Linear, + gate: BlockwiseFP8ParallelColumnLinear, + up: BlockwiseFP8ParallelColumnLinear, + down: BlockwiseFP8ParallelRowLinear, act: Activation, } impl Mlp { fn new( - cfg: &DeepSeekV2Config, - vb: VarBuilder, + cfg: &DeepSeekV3Config, + vb: ShardedVarBuilder, hidden_size: Option, intermediate_size: Option, quant: Option, + comm: Rc, ) -> Result { let hidden_size = hidden_size.unwrap_or(cfg.hidden_size); let intermediate_size = intermediate_size.unwrap_or(cfg.intermediate_size); Ok(Self { - gate: quant::blockwise_fp8_linear_b( + gate: quant::blockwise_fp8_linear_b_parallel_column( hidden_size, intermediate_size, &cfg.quantization_config, false, quant, + comm.clone(), vb.pp("gate_proj"), )?, - up: quant::blockwise_fp8_linear_b( + up: quant::blockwise_fp8_linear_b_parallel_column( hidden_size, intermediate_size, &cfg.quantization_config, false, quant, + comm.clone(), vb.pp("up_proj"), )?, - down: quant::blockwise_fp8_linear_b( + down: quant::blockwise_fp8_linear_b_parallel_row( intermediate_size, hidden_size, &cfg.quantization_config, false, quant, + comm.clone(), vb.pp("down_proj"), )?, act: cfg.hidden_act, @@ -803,14 +822,14 @@ impl Mlp { struct MoeGate { weight: Tensor, - cfg: DeepSeekV2Config, + cfg: DeepSeekV3Config, top_k: usize, n_routed_experts: usize, e_score_correction_bias: Option, } impl MoeGate { - fn new(cfg: &DeepSeekV2Config, vb: VarBuilder, n_routed_experts: usize) -> Result { + fn new(cfg: &DeepSeekV3Config, vb: ShardedVarBuilder, n_routed_experts: usize) -> Result { let weight = vb.get((n_routed_experts, cfg.hidden_size), "weight")?; let e_score_correction_bias = if matches!(cfg.topk_method, TopkMethod::NoAuxTc) { Some(vb.get_with_hints_dtype( @@ -937,11 +956,12 @@ struct Moe { impl Moe { fn new( - cfg: &DeepSeekV2Config, - vb: VarBuilder, + cfg: &DeepSeekV3Config, + vb: ShardedVarBuilder, n_shared_experts: Option, n_routed_experts: usize, quant: Option, + comm: Rc, ) -> Result { let mut experts = 
Vec::with_capacity(n_routed_experts); for i in 0..n_routed_experts { @@ -952,6 +972,7 @@ impl Moe { None, Some(cfg.moe_intermediate_size), quant, + comm.clone(), )?); } let shared_experts = if let Some(n_shared_experts) = n_shared_experts { @@ -962,6 +983,7 @@ impl Moe { None, Some(intermediate_size), quant, + comm, )?) } else { None @@ -1043,13 +1065,14 @@ struct DecoderLayer { impl DecoderLayer { fn new( - rotary_emb: Arc, - cfg: &DeepSeekV2Config, - vb: VarBuilder, + rotary_emb: Arc, + cfg: &DeepSeekV3Config, + vb: ShardedVarBuilder, layer_idx: usize, quant: Option, + comm: Rc, ) -> Result { - let attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"), quant)?; + let attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"), quant, comm.clone())?; let input_layernorm = rms_norm(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?; let post_attention_layernorm = rms_norm( @@ -1067,9 +1090,10 @@ impl DecoderLayer { cfg.n_shared_experts, cfg.n_routed_experts.unwrap(), quant, + comm, )?) } else { - MoeOrMlp::Mlp(Mlp::new(cfg, vb.pp("mlp"), None, None, quant)?) + MoeOrMlp::Mlp(Mlp::new(cfg, vb.pp("mlp"), None, None, quant, comm)?) }; Ok(Self { @@ -1098,8 +1122,8 @@ impl DecoderLayer { } } -pub struct DeepSeekV2 { - lm_head: candle_nn::Linear, +pub struct DeepSeekV3 { + lm_head: BlockwiseFP8ReplicatedLinear, embed_tokens: Embedding, norm: RmsNorm, layers: Vec, @@ -1107,25 +1131,40 @@ pub struct DeepSeekV2 { device: Device, } -impl DeepSeekV2 { - pub fn new(cfg: &DeepSeekV2Config, vb: VarBuilder, quant: Option) -> Result { +impl DeepSeekV3 { + pub fn new( + cfg: &DeepSeekV3Config, + vb: ShardedVarBuilder, + quant: Option, + comm: Rc, + ) -> Result { let vb_m = vb.pp("model"); let embed_tokens = embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?; let lm_head = if !cfg.tie_word_embeddings { - candle_nn::linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))? + quant::blockwise_fp8_linear_b_replicated( + cfg.hidden_size, + cfg.vocab_size, + &cfg.quantization_config, + false, + None, + vb.pp("lm_head"), + )? 
} else { - candle_nn::Linear::new(embed_tokens.embeddings().clone(), None) + BlockwiseFP8ReplicatedLinear::Unquantized(Linear::new( + embed_tokens.embeddings().clone(), + None, + )) }; let norm = rms_norm(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?; - let rope_cfg = DeepSeekV2RopeConfig { + let rope_cfg = DeepSeekV3RopeConfig { rope_scaling: cfg.rope_scaling.clone(), max_position_embeddings: cfg.max_position_embeddings, rope_theta: cfg.rope_theta, qk_rope_head_dim: cfg.qk_rope_head_dim, }; - let rotary_emb = Arc::new(DeepSeekV2RotaryEmbedding::new( + let rotary_emb = Arc::new(DeepSeekV3RotaryEmbedding::new( &rope_cfg, vb.dtype(), vb.device(), @@ -1140,6 +1179,7 @@ impl DeepSeekV2 { vb_l.pp(layer_idx), layer_idx, quant, + comm.clone(), )?; layers.push(layer) } diff --git a/candle-transformers/src/models/deepseekv3/ops.rs b/candle-examples/examples/deepseekv3/ops.rs similarity index 100% rename from candle-transformers/src/models/deepseekv3/ops.rs rename to candle-examples/examples/deepseekv3/ops.rs diff --git a/candle-examples/examples/deepseekv3/quant.rs b/candle-examples/examples/deepseekv3/quant.rs new file mode 100644 index 0000000000..bf364f9031 --- /dev/null +++ b/candle-examples/examples/deepseekv3/quant.rs @@ -0,0 +1,452 @@ +use std::rc::Rc; + +use candle::{ + quantized::{GgmlDType, QMatMul, QTensor}, + CpuStorage, CustomOp1, DType, Layout, Module, Result, Shape, Tensor, +}; +use candle_nn::{var_builder::ShardedVarBuilder, Linear}; +use cudarc::nccl::Comm; +use serde::Deserialize; + +use super::ops; + +#[derive(Debug, Clone, Deserialize)] +pub enum QuantMethodType { + #[serde(rename = "fp8")] + Fp8, +} + +#[derive(Debug, Clone, Deserialize)] +pub struct QuantizedConfig { + pub weight_block_size: Option>, + pub quant_method: QuantMethodType, +} + +pub struct AllReduce { + comm: Rc, +} + +/// This is actually not safe: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/threadsafety.html +/// But for this example purposes, this will work +unsafe impl Sync for AllReduce {} +unsafe impl Send for AllReduce {} + +impl CustomOp1 for AllReduce { + fn name(&self) -> &'static str { + "allreduce" + } + + fn cpu_fwd(&self, _s: &CpuStorage, _l: &Layout) -> Result<(CpuStorage, Shape)> { + candle::bail!("AllReduce is never used on cpu") + } + + #[cfg(feature = "cuda")] + fn cuda_fwd( + &self, + s: &candle::CudaStorage, + l: &Layout, + ) -> Result<(candle::CudaStorage, Shape)> { + use candle::{backend::BackendStorage, cuda_backend::WrapErr}; + use cudarc::{driver::DeviceSlice, nccl::ReduceOp}; + use half::{bf16, f16}; + + let elem_count = l.shape().elem_count(); + let dev = s.device().clone(); + let dst = match s.dtype() { + DType::BF16 => { + let s = s.as_cuda_slice::()?; + let s = match l.contiguous_offsets() { + Some((0, l)) if l == s.len() => s, + Some(_) | None => candle::bail!("input has to be contiguous"), + }; + let mut dst = unsafe { dev.alloc::(elem_count) }.w()?; + self.comm + .all_reduce(s, &mut dst, &ReduceOp::Sum) + .map_err(candle::Error::debug)?; + candle::CudaStorage::wrap_cuda_slice(dst, dev) + } + DType::F16 => { + let s = s.as_cuda_slice::()?; + let s = match l.contiguous_offsets() { + Some((0, l)) if l == s.len() => s, + Some(_) | None => candle::bail!("input has to be contiguous"), + }; + let mut dst = unsafe { dev.alloc::(elem_count) }.w()?; + self.comm + .all_reduce(s, &mut dst, &ReduceOp::Sum) + .map_err(candle::Error::debug)?; + candle::CudaStorage::wrap_cuda_slice(dst, dev) + } + dtype => candle::bail!("unsupported dtype {dtype:?}"), + }; + Ok((dst, 
l.shape().clone())) + } +} + +fn shard(dim: usize, rank: usize, world_size: usize) -> candle_nn::var_builder::Shard { + candle_nn::var_builder::Shard { + dim, + rank, + world_size, + } +} + +/// This linear layer has a weight that is parallelized along the input dimension, +/// returning the "full" output dimension. +pub enum BlockwiseFP8ParallelRowLinear { + Quantized { + w: QMatMul, + b: Option, + all_reduce: AllReduce, + }, + Unquantized { + w: Linear, + b: Option, + all_reduce: AllReduce, + }, +} + +impl Module for BlockwiseFP8ParallelRowLinear { + fn forward(&self, xs: &Tensor) -> Result { + match self { + Self::Quantized { w, b, all_reduce } => { + let in_ty = xs.dtype(); + let mut xs = xs.to_dtype(DType::F32)?; + + xs = w.forward(&xs)?.apply_op1_no_bwd(all_reduce)?; + if let Some(bias) = b { + xs = xs.broadcast_add(bias)?; + } + + xs.to_dtype(in_ty) + } + Self::Unquantized { w, b, all_reduce } => { + let mut xs = w.forward(&xs)?.apply_op1_no_bwd(all_reduce)?; + if let Some(bias) = b { + xs = xs.broadcast_add(bias)?; + } + Ok(xs) + } + } + } +} + +/// This linear layer has a weight that is parallelized along the output dimension, +/// taking the "full" input dimension. +pub enum BlockwiseFP8ParallelColumnLinear { + Quantized { w: QMatMul, b: Option }, + Unquantized(Linear), +} + +impl Module for BlockwiseFP8ParallelColumnLinear { + fn forward(&self, xs: &Tensor) -> Result { + match self { + Self::Quantized { w, b } => { + let in_ty = xs.dtype(); + let mut xs = xs.to_dtype(DType::F32)?; + + xs = w.forward(&xs)?; + if let Some(bias) = b { + xs = xs.broadcast_add(bias)?; + } + + xs.to_dtype(in_ty) + } + Self::Unquantized(l) => xs.apply(l), + } + } +} + +/// This linear layer has a weight that is parallelized along the output dimension, +/// taking the "full" input dimension. +pub enum BlockwiseFP8ReplicatedLinear { + Quantized { w: QMatMul, b: Option }, + Unquantized(Linear), +} + +impl Module for BlockwiseFP8ReplicatedLinear { + fn forward(&self, xs: &Tensor) -> Result { + match self { + Self::Quantized { w, b } => { + let in_ty = xs.dtype(); + let mut xs = xs.to_dtype(DType::F32)?; + + xs = w.forward(&xs)?; + if let Some(bias) = b { + xs = xs.broadcast_add(bias)?; + } + + xs.to_dtype(in_ty) + } + Self::Unquantized(l) => xs.apply(l), + } + } +} + +/// Load a blockwise quantized FP8 layer and optionally quantize it in-place for faster inference. +/// This linear layer has a weight that is parallelized along the output dimension, +/// taking the "full" input dimension. +/// +/// The bias is parallelized. +pub fn blockwise_fp8_linear_b_parallel_column( + in_dim: usize, + out_dim: usize, + config: &Option, + bias: bool, + quant: Option, + comm: Rc, + vb: ShardedVarBuilder, +) -> Result { + let rank = comm.rank(); + let size = comm.world_size(); + + let Some(config) = config else { + let weight = vb.get_with_hints((out_dim, in_dim), "weight", shard(0, rank, size))?; + + let bias = if bias { + Some(vb.get_with_hints((out_dim,), "bias", shard(0, rank, size))?) 
+ } else { + None + }; + + return Ok(BlockwiseFP8ParallelColumnLinear::Unquantized(Linear::new( + weight, bias, + ))); + }; + + if !matches!(config.quant_method, QuantMethodType::Fp8) { + candle::bail!("Expected FP8 quant method!"); + } + + let weight_block_size = config + .weight_block_size + .as_ref() + .expect("Blockwise FP8 requires weight_block_size in config"); + if weight_block_size.len() != 2 { + candle::bail!("Expected weight_block_size to have length 2, got {weight_block_size:?}") + } + let weight = vb.get_with_hints_dtype( + (out_dim, in_dim), + "weight", + shard(0, rank, size), + DType::F8E4M3, + )?; + let weight_scale_inv = vb.get_with_hints_dtype( + ( + out_dim.div_ceil(weight_block_size[0]), + in_dim.div_ceil(weight_block_size[1]), + ), + "weight_scale_inv", + shard(0, rank, size), + DType::F32, + )?; + + let bias_ty = if quant.is_some() { + DType::F32 + } else { + vb.dtype() + }; + + let bias = if bias { + Some( + vb.get_with_hints((out_dim,), "bias", shard(0, rank, size))? + .to_dtype(bias_ty)?, + ) + } else { + None + }; + + let dequant = ops::fp8_blockwise_dequantize( + &weight, + &weight_scale_inv, + weight_block_size.to_vec(), + vb.dtype(), + )?; + + let layer = match quant { + Some(q) => BlockwiseFP8ParallelColumnLinear::Quantized { + w: QMatMul::from_qtensor(QTensor::quantize(&dequant, q)?)?, + b: bias, + }, + None => BlockwiseFP8ParallelColumnLinear::Unquantized(Linear::new(dequant, None)), + }; + + Ok(layer) +} + +/// Load a blockwise quantized FP8 layer and optionally quantize it in-place for faster inference. +/// This linear layer has a weight that is parallelized along the input dimension, +/// returning the "full" output dimension. +/// +/// The bias is not parallelized. +pub fn blockwise_fp8_linear_b_parallel_row( + in_dim: usize, + out_dim: usize, + config: &Option, + bias: bool, + quant: Option, + comm: Rc, + vb: ShardedVarBuilder, +) -> Result { + let rank = comm.rank(); + let size = comm.world_size(); + + let all_reduce = AllReduce { comm }; + + let Some(config) = config else { + let weight = vb.get_with_hints((out_dim, in_dim), "weight", shard(0, rank, size))?; + + let bias = if bias { + Some(vb.get((out_dim,), "bias")?) + } else { + None + }; + + return Ok(BlockwiseFP8ParallelRowLinear::Unquantized { + w: Linear::new(weight, None), + b: bias, + all_reduce, + }); + }; + + if !matches!(config.quant_method, QuantMethodType::Fp8) { + candle::bail!("Expected FP8 quant method!"); + } + + let weight_block_size = config + .weight_block_size + .as_ref() + .expect("Blockwise FP8 requires weight_block_size in config"); + if weight_block_size.len() != 2 { + candle::bail!("Expected weight_block_size to have length 2, got {weight_block_size:?}") + } + let weight = vb.get_with_hints_dtype( + (out_dim, in_dim), + "weight", + shard(1, rank, size), + DType::F8E4M3, + )?; + let weight_scale_inv = vb.get_with_hints_dtype( + ( + out_dim.div_ceil(weight_block_size[0]), + in_dim.div_ceil(weight_block_size[1]), + ), + "weight_scale_inv", + shard(1, rank, size), + DType::F32, + )?; + + let bias_ty = if quant.is_some() { + DType::F32 + } else { + vb.dtype() + }; + + let bias = if bias { + Some(vb.get((out_dim,), "bias")?.to_dtype(bias_ty)?) 
+ } else { + None + }; + + let dequant = ops::fp8_blockwise_dequantize( + &weight, + &weight_scale_inv, + weight_block_size.to_vec(), + vb.dtype(), + )?; + + let layer = match quant { + Some(q) => BlockwiseFP8ParallelRowLinear::Quantized { + w: QMatMul::from_qtensor(QTensor::quantize(&dequant, q)?)?, + b: bias, + all_reduce, + }, + None => BlockwiseFP8ParallelRowLinear::Unquantized { + w: Linear::new(dequant, None), + b: bias, + all_reduce, + }, + }; + + Ok(layer) +} + +/// Load a blockwise quantized FP8 layer and optionally quantize it in-place for faster inference. +pub fn blockwise_fp8_linear_b_replicated( + in_dim: usize, + out_dim: usize, + config: &Option, + bias: bool, + quant: Option, + vb: ShardedVarBuilder, +) -> Result { + let Some(config) = config else { + let weight = vb.get((out_dim, in_dim), "weight")?; + + let bias = if bias { + Some(vb.get((out_dim,), "bias")?) + } else { + None + }; + + return Ok(BlockwiseFP8ReplicatedLinear::Unquantized(Linear::new( + weight, bias, + ))); + }; + + if !matches!(config.quant_method, QuantMethodType::Fp8) { + candle::bail!("Expected FP8 quant method!"); + } + + let weight_block_size = config + .weight_block_size + .as_ref() + .expect("Blockwise FP8 requires weight_block_size in config"); + if weight_block_size.len() != 2 { + candle::bail!("Expected weight_block_size to have length 2, got {weight_block_size:?}") + } + let weight = vb.get_with_hints_dtype( + (out_dim, in_dim), + "weight", + Default::default(), + DType::F8E4M3, + )?; + let weight_scale_inv = vb.get_with_hints_dtype( + ( + out_dim.div_ceil(weight_block_size[0]), + in_dim.div_ceil(weight_block_size[1]), + ), + "weight_scale_inv", + Default::default(), + DType::F32, + )?; + + let bias_ty = if quant.is_some() { + DType::F32 + } else { + vb.dtype() + }; + + let bias = if bias { + Some(vb.get((out_dim,), "bias")?.to_dtype(bias_ty)?) 
+ } else { + None + }; + + let dequant = ops::fp8_blockwise_dequantize( + &weight, + &weight_scale_inv, + weight_block_size.to_vec(), + vb.dtype(), + )?; + + let layer = match quant { + Some(q) => BlockwiseFP8ReplicatedLinear::Quantized { + w: QMatMul::from_qtensor(QTensor::quantize(&dequant, q)?)?, + b: bias, + }, + None => BlockwiseFP8ReplicatedLinear::Unquantized(Linear::new(dequant, bias)), + }; + + Ok(layer) +} diff --git a/candle-flash-attn/cutlass b/candle-flash-attn/cutlass index 4c42f73fda..7d49e6c7e2 160000 --- a/candle-flash-attn/cutlass +++ b/candle-flash-attn/cutlass @@ -1 +1 @@ -Subproject commit 4c42f73fdab5787e3bb57717f35a8cb1b3c0dc6d +Subproject commit 7d49e6c7e2f8896c47f586706e67e1fb215529dc diff --git a/candle-transformers/src/models/deepseekv3/mod.rs b/candle-transformers/src/models/deepseekv3/mod.rs deleted file mode 100644 index bec5297816..0000000000 --- a/candle-transformers/src/models/deepseekv3/mod.rs +++ /dev/null @@ -1,3 +0,0 @@ -pub mod model; -mod ops; -mod quant; diff --git a/candle-transformers/src/models/deepseekv3/quant.rs b/candle-transformers/src/models/deepseekv3/quant.rs deleted file mode 100644 index 9bae465769..0000000000 --- a/candle-transformers/src/models/deepseekv3/quant.rs +++ /dev/null @@ -1,116 +0,0 @@ -use candle::{ - quantized::{GgmlDType, QMatMul, QTensor}, - DType, Module, Result, Tensor, -}; -use candle_nn::{Linear, VarBuilder}; -use serde::Deserialize; - -use super::ops; - -#[derive(Debug, Clone, Deserialize)] -pub enum QuantMethodType { - #[serde(rename = "fp8")] - Fp8, -} - -#[derive(Debug, Clone, Deserialize)] -pub struct QuantizedConfig { - pub weight_block_size: Option>, - pub quant_method: QuantMethodType, -} - -pub enum BlockwiseFP8Linear { - Quantized { w: QMatMul, b: Option }, - Unquantized(Linear), -} - -impl Module for BlockwiseFP8Linear { - fn forward(&self, xs: &Tensor) -> Result { - match self { - Self::Quantized { w, b } => { - let in_ty = xs.dtype(); - let mut xs = xs.to_dtype(DType::F32)?; - - xs = w.forward(&xs)?; - if let Some(bias) = b { - xs = xs.broadcast_add(bias)?; - } - - xs.to_dtype(in_ty) - } - Self::Unquantized(l) => xs.apply(l), - } - } -} - -/// Load a blockwise quantized FP8 layer and optionally quantize it in-place for faster inference. -pub fn blockwise_fp8_linear_b( - in_dim: usize, - out_dim: usize, - config: &Option, - bias: bool, - quant: Option, - vb: VarBuilder, -) -> Result { - let Some(config) = config else { - return Ok(BlockwiseFP8Linear::Unquantized(candle_nn::linear_b( - in_dim, out_dim, bias, vb, - )?)); - }; - - if !matches!(config.quant_method, QuantMethodType::Fp8) { - candle::bail!("Expected FP8 quant method!"); - } - - let weight_block_size = config - .weight_block_size - .as_ref() - .expect("Blockwise FP8 requires weight_block_size in config"); - if weight_block_size.len() != 2 { - candle::bail!("Expected weight_block_size to have length 2, got {weight_block_size:?}") - } - let weight = vb.get_with_hints_dtype( - (out_dim, in_dim), - "weight", - Default::default(), - DType::F8E4M3, - )?; - let weight_scale_inv = vb.get_with_hints_dtype( - ( - out_dim.div_ceil(weight_block_size[0]), - in_dim.div_ceil(weight_block_size[1]), - ), - "weight_scale_inv", - Default::default(), - DType::F32, - )?; - - let bias_ty = if quant.is_some() { - DType::F32 - } else { - vb.dtype() - }; - - let bias = if bias { - Some(vb.get((out_dim,), "bias")?.to_dtype(bias_ty)?) 
- } else { - None - }; - - let dequant = ops::fp8_blockwise_dequantize( - &weight, - &weight_scale_inv, - weight_block_size.to_vec(), - vb.dtype(), - )?; - - let layer = match quant { - Some(q) => BlockwiseFP8Linear::Quantized { - w: QMatMul::from_qtensor(QTensor::quantize(&dequant, q)?)?, - b: bias, - }, - None => BlockwiseFP8Linear::Unquantized(Linear::new(dequant, None)), - }; - - Ok(layer) -} diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index 269c2a02da..df1de0b276 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -28,7 +28,6 @@ pub mod colpali; pub mod convmixer; pub mod convnext; pub mod dac; -pub mod deepseekv3; pub mod depth_anything_v2; pub mod dinov2; pub mod dinov2reg4; From 533232c10358434404381211613252ac3e288a71 Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Tue, 28 Jan 2025 08:58:59 -0500 Subject: [PATCH 8/8] Add quant support --- candle-examples/examples/deepseekv3/main.rs | 23 ++++++++++-- candle-examples/examples/deepseekv3/model.rs | 37 +++++++++++++++++--- 2 files changed, 53 insertions(+), 7 deletions(-) diff --git a/candle-examples/examples/deepseekv3/main.rs b/candle-examples/examples/deepseekv3/main.rs index f50578ca52..f77d8aa7b7 100644 --- a/candle-examples/examples/deepseekv3/main.rs +++ b/candle-examples/examples/deepseekv3/main.rs @@ -10,6 +10,7 @@ extern crate intel_mkl_src; use anyhow::{bail, Error as E, Result}; +use candle::quantized::GgmlDType; use clap::{Parser, ValueEnum}; use candle::{DType, Device, Tensor}; @@ -75,6 +76,10 @@ struct Args { #[arg(long)] dtype: Option, + /// Quantization to apply to the model for faster inference. Defaults to q4k. One of: q2k,q3k,q4k,q5k,q8_0 + #[arg(long)] + quant: Option, + #[arg(long, default_value = "r1")] which: Which, @@ -171,11 +176,25 @@ fn main() -> Result<()> { let device = Device::new_cuda(rank)?; + let quant = match args.quant { + Some(x) => match x.to_lowercase().as_str() { + "q2k" => GgmlDType::Q2K, + "q3k" => GgmlDType::Q3K, + "q4k" => GgmlDType::Q4K, + "q5k" => GgmlDType::Q5K, + "q8_0" => GgmlDType::Q8_0, + other => { + anyhow::bail!("Quantization {other} is not supported, try q2k,q3k,q4k,q5k,q8_0") + } + }, + None => GgmlDType::Q4K, + }; + println!("building the model"); let vb = unsafe { candle_nn::var_builder::ShardedSafeTensors::var_builder(&filenames, dtype, &device)? 
}; - let llama = DeepSeekV3::new(&config, vb, None, comm)?; + let mut model = DeepSeekV3::new(&config, vb, Some(quant), comm)?; let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; let prompt = args.prompt.as_ref().map_or(DEFAULT_PROMPT, |p| p.as_str()); @@ -205,7 +224,7 @@ fn main() -> Result<()> { let context_size = if index > 0 { 1 } else { tokens.len() }; let ctxt = &tokens[tokens.len().saturating_sub(context_size)..]; let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?; - let logits = llama.forward(&input, index_pos)?; + let logits = model.forward(&input, index_pos)?; let logits = logits.squeeze(0)?; index_pos += ctxt.len(); diff --git a/candle-examples/examples/deepseekv3/model.rs b/candle-examples/examples/deepseekv3/model.rs index 0bf948ad0f..e9a4446fa4 100644 --- a/candle-examples/examples/deepseekv3/model.rs +++ b/candle-examples/examples/deepseekv3/model.rs @@ -558,6 +558,7 @@ struct Attention { cfg: DeepSeekV3Config, q_head_dim: usize, softmax_scale: f64, + kv_cache: Option<(Tensor, Tensor)>, } impl Attention { @@ -641,11 +642,12 @@ impl Attention { cfg: cfg.clone(), q_head_dim, softmax_scale: cfg.softmax_scale() as f64, + kv_cache: None, }) } fn forward( - &self, + &mut self, xs: &Tensor, attention_mask: Option<&Tensor>, seqlen_offset: usize, @@ -687,7 +689,7 @@ impl Attention { let kv_split = kv.split(&[self.cfg.qk_nope_head_dim, self.cfg.v_head_dim], D::Minus1)?; let k_nope = kv_split[0].clone(); - let v = kv_split[1].clone(); + let mut v = kv_split[1].clone(); (q_pe, k_pe) = self.rotary_emb.forward(&q_pe, &k_pe, seqlen_offset)?; @@ -740,6 +742,16 @@ impl Attention { &k_pe, )?; + (k, v) = match &self.kv_cache { + None => (k, v), + Some((prev_k, prev_v)) => { + let key_states = Tensor::cat(&[prev_k, &k], 2)?; + let value_states = Tensor::cat(&[prev_v, &v], 2)?; + (key_states, value_states) + } + }; + self.kv_cache = Some((k.clone(), v.clone())); + let mut attn_out = { let att = (q.matmul(&k.t()?)? * self.softmax_scale)?; let att = match attention_mask { @@ -760,6 +772,10 @@ impl Attention { self.o_proj.forward(&attn_out) } + + fn clear_kv_cache(&mut self) { + self.kv_cache = None; + } } struct Mlp { @@ -1105,7 +1121,7 @@ impl DecoderLayer { } fn forward( - &self, + &mut self, xs: &Tensor, attention_mask: Option<&Tensor>, seqlen_offset: usize, @@ -1120,6 +1136,10 @@ impl DecoderLayer { .forward(&xs.apply(&self.post_attention_layernorm)?)?; residual + xs } + + fn clear_kv_cache(&mut self) { + self.attn.clear_kv_cache(); + } } pub struct DeepSeekV3 { @@ -1214,7 +1234,7 @@ impl DeepSeekV3 { .to_dtype(self.dtype) } - pub fn forward(&self, input_ids: &Tensor, seqlen_offset: usize) -> Result { + pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result { let (bs, seq_len) = input_ids.dims2()?; let mut xs = self.embed_tokens.forward(input_ids)?; let attention_mask = if seq_len == 1 { @@ -1223,7 +1243,7 @@ impl DeepSeekV3 { let mask = self.prepare_decoder_attention_mask(bs, seq_len, seqlen_offset)?; Some(mask) }; - for layer in &self.layers { + for layer in &mut self.layers { xs = layer.forward( &xs, attention_mask @@ -1238,4 +1258,11 @@ impl DeepSeekV3 { let logits = self.lm_head.forward(&xs)?; logits.to_dtype(DType::F32) } + + #[allow(unused)] + pub fn clear_kv_cache(&mut self) { + for layer in &mut self.layers { + layer.clear_kv_cache(); + } + } }
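For reference, here is a minimal, dependency-free sketch of the blockwise dequantization that the ops::fp8_blockwise_dequantize call in quant.rs is expected to perform (the body of that op is not included in this series, so the exact behavior is an assumption). It assumes a row-major (out_dim, in_dim) weight, a scale tensor of shape (ceil(out_dim / bs0), ceil(in_dim / bs1)), and that each scale is applied multiplicatively to its block, matching how the blockwise_fp8_linear_b_* loaders size weight_scale_inv.

// Plain-f32 stand-in for the FP8 weights: each (bs0 x bs1) block is multiplied
// by its own scale, mirroring the (out_dim, in_dim) weight / weight_scale_inv
// layout used by the loaders in quant.rs.
fn blockwise_dequantize(
    weight: &[f32],
    scale_inv: &[f32],
    out_dim: usize,
    in_dim: usize,
    bs0: usize,
    bs1: usize,
) -> Vec<f32> {
    let scale_cols = in_dim.div_ceil(bs1);
    let mut dequant = vec![0f32; out_dim * in_dim];
    for r in 0..out_dim {
        for c in 0..in_dim {
            // One scale per (bs0 x bs1) block, indexed by block row and column.
            let scale = scale_inv[(r / bs0) * scale_cols + (c / bs1)];
            dequant[r * in_dim + c] = weight[r * in_dim + c] * scale;
        }
    }
    dequant
}

fn main() {
    // 4x4 weight split into 2x2 blocks, one scale per block.
    let weight = vec![1.0f32; 16];
    let scales = vec![0.5f32, 2.0, 4.0, 8.0];
    let w = blockwise_dequantize(&weight, &scales, 4, 4, 2, 2);
    assert_eq!(w[0], 0.5); // top-left block
    assert_eq!(w[3], 2.0); // top-right block
    assert_eq!(w[15], 8.0); // bottom-right block
}

After dequantization, the weight is either wrapped directly in an unquantized Linear or re-quantized to the requested GgmlDType via QTensor::quantize, which is what the optional quant argument threaded through the model controls.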