Implement DeepSeek V3/R1 #2745

Open
wants to merge 8 commits into main

Changes from 1 commit
Tensor parallelism support
EricLBuehler committed Jan 28, 2025
commit 406b991cfbf1e610b6061fb2d119950eaea3cfde
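
This commit's core idea: every rank holds only a shard of each large weight matrix (loaded through ShardedSafeTensors in the example below), computes a partial result locally, and the partial results are presumably combined with an NCCL all-reduce inside the model code (whose diff is not rendered on this page). A minimal single-process sketch of that shard-then-reduce pattern, with plain Vec math standing in for candle tensors and an element-wise sum standing in for the NCCL all_reduce; all names and shapes are illustrative only:

// Row-parallel matmul sketch: the input dimension of the weight is split
// across `num_shards` ranks, every rank computes a partial output from its
// own shard, and the partials are summed ("all-reduced") into the result.
fn partial_matvec(w_shard: &[Vec<f32>], x_shard: &[f32]) -> Vec<f32> {
    // w_shard: one row per output feature, each row holding this rank's
    // slice of the input dimension.
    w_shard
        .iter()
        .map(|row| row.iter().zip(x_shard).map(|(w, x)| w * x).sum())
        .collect()
}

fn main() {
    let num_shards = 2;
    // Full weight (out_dim = 2, in_dim = 4) and full input vector.
    let w = vec![vec![1.0f32, 2.0, 3.0, 4.0], vec![5.0, 6.0, 7.0, 8.0]];
    let x = vec![1.0f32, 1.0, 1.0, 1.0];
    let cols = x.len() / num_shards;

    // Each "rank" keeps only its slice of the input dimension.
    let partials: Vec<Vec<f32>> = (0..num_shards)
        .map(|rank| {
            let w_shard: Vec<Vec<f32>> = w
                .iter()
                .map(|row| row[rank * cols..(rank + 1) * cols].to_vec())
                .collect();
            partial_matvec(&w_shard, &x[rank * cols..(rank + 1) * cols])
        })
        .collect();

    // "All-reduce": sum the partial outputs element-wise across ranks.
    let y: Vec<f32> = (0..w.len())
        .map(|i| partials.iter().map(|p| p[i]).sum())
        .collect();
    assert_eq!(y, vec![10.0f32, 26.0]); // matches the unsharded matmul
}
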
3 changes: 2 additions & 1 deletion .vscode/settings.json
@@ -7,5 +7,6 @@
"candle-pyo3"
],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true
"python.testing.pytestEnabled": true,
"rust-analyzer.cargo.features": ["cuda", "nccl"]
}
2 changes: 1 addition & 1 deletion Cargo.toml
@@ -48,7 +48,7 @@ fancy-regex = "0.13.0"
gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
hf-hub = "0.4.1"
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
float8 = { version = "0.1.0", features = ["num-traits", "rand_distr"] }
float8 = { version = "0.1.3", features = ["num-traits", "rand_distr"] }
hound = "3.5.1"
image = { version = "0.25.2", default-features = false, features = ["jpeg", "png"] }
imageproc = { version = "0.24.0", default-features = false }
18 changes: 7 additions & 11 deletions candle-core/src/cuda_backend/device.rs
@@ -313,17 +313,13 @@ impl BackendDevice for CudaDevice {
elem_count
};
let slice = match dtype {
DType::U8
| DType::U32
| DType::I16
| DType::I32
| DType::I64
| DType::F16
| DType::BF16 => Err(CudaError::UnsupportedDtype {
dtype,
op: "rand_normal",
})
.w()?,
DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 | DType::F8E4M3 => {
Err(CudaError::UnsupportedDtype {
dtype,
op: "rand_normal",
})
.w()?
}
DType::F32 => {
let mut data = unsafe { self.alloc::<f32>(elem_count_round) }.w()?;
curand
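
The hunk above adds the new F8E4M3 variant to the dtypes that rand_normal rejects on the CUDA backend, so float8 tensors cannot be sampled directly. A sketch of the implied workaround, sampling in f32 and casting down; it assumes the usual candle Tensor API and that a cast to DType::F8E4M3 is supported (which is what the float8 integration is for):

use candle::{DType, Device, Result, Tensor};

// Hypothetical helper: rand_normal has no float8 kernel, so sample the values
// in f32 first and cast the result down to the 8-bit float type afterwards.
fn randn_f8e4m3(shape: (usize, usize), device: &Device) -> Result<Tensor> {
    let t = Tensor::randn(0f32, 1f32, shape, device)?;
    t.to_dtype(DType::F8E4M3)
}
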
160 changes: 0 additions & 160 deletions candle-core/src/cuda_backend/mod.rs
@@ -2117,163 +2117,3 @@ unsafe fn gemm_strided_batched_bf16(
sys::cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT_TENSOR_OP,
)
}

pub struct KVConcat {
pub concat_dim: usize,
}
impl crate::CustomOp2 for KVConcat {
fn name(&self) -> &'static str {
"kvconcat"
}

fn cpu_fwd(
&self,
_: &CpuStorage,
_: &Layout,
_: &CpuStorage,
_: &Layout,
) -> Result<(CpuStorage, Shape)> {
crate::bail!("no cpu support for kvconcat")
}

fn cuda_fwd(
&self,
ltensor: &CudaStorage,
ltensor_l: &Layout,
rtensor: &CudaStorage,
rtensor_l: &Layout,
) -> Result<(CudaStorage, Shape)> {
assert!(self.concat_dim == 2 || self.concat_dim == 0); //must be in the dim of sequence len
let dev = &ltensor.device;
let elem_count = ltensor_l.shape().elem_count() + rtensor_l.shape().elem_count();
let dims_l = ltensor_l.shape().dims();
let dims_r = rtensor_l.shape().dims();
let dim_size = dims_l.len();
let cfg = LaunchConfig::for_num_elems(elem_count as u32);

let chunk_l = if dim_size > 3 {
dims_l[0] * dims_l[1]
} else {
dims_l[0]
};
let chunk_r = if dim_size > 3 {
dims_r[0] * dims_r[1]
} else {
dims_r[0]
};
let lstride = if dim_size > 3 {
dims_l[2] * dims_l[3]
} else {
dims_l[1] * dims_l[2]
};
let rstride = if dim_size > 3 {
dims_r[2] * dims_r[3]
} else {
dims_r[1] * dims_r[2]
};

let slice = match (&ltensor.slice, &rtensor.slice) {
(CudaStorageSlice::BF16(left_), CudaStorageSlice::BF16(right_)) => {
let out = unsafe { dev.alloc::<bf16>(elem_count).w()? };
let func = dev.get_or_load_func("kvconcat_bf16", kernels::KVCONCAT)?;
let params = (
left_,
right_,
&out,
self.concat_dim,
chunk_l,
chunk_r,
lstride,
rstride,
);
unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::BF16(out)
}
(CudaStorageSlice::F32(left_), CudaStorageSlice::F32(right_)) => {
let out = unsafe { dev.alloc::<f32>(elem_count).w()? };
let func = dev.get_or_load_func("kvconcat_f32", kernels::KVCONCAT)?;
let params = (
left_,
right_,
&out,
self.concat_dim,
chunk_l,
chunk_r,
lstride,
rstride,
);
unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::F32(out)
}
(CudaStorageSlice::F16(left_), CudaStorageSlice::F16(right_)) => {
let out = unsafe { dev.alloc::<f16>(elem_count).w()? };
let func = dev.get_or_load_func("kvconcat_f16", kernels::KVCONCAT)?;
let params = (
left_,
right_,
&out,
self.concat_dim,
chunk_l,
chunk_r,
lstride,
rstride,
);
unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::F16(out)
}
(CudaStorageSlice::F64(left_), CudaStorageSlice::F64(right_)) => {
let out = unsafe { dev.alloc::<f64>(elem_count).w()? };
let func = dev.get_or_load_func("kvconcat_f64", kernels::KVCONCAT)?;
let params = (
left_,
right_,
&out,
self.concat_dim,
chunk_l,
chunk_r,
lstride,
rstride,
);
unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::F64(out)
}
(CudaStorageSlice::U8(left_), CudaStorageSlice::U8(right_)) => {
let out = unsafe { dev.alloc::<u8>(elem_count).w()? };
let func = dev.get_or_load_func("kvconcat_u8", kernels::KVCONCAT)?;
let params = (
left_,
right_,
&out,
self.concat_dim,
chunk_l,
chunk_r,
lstride,
rstride,
);
unsafe { func.launch(cfg, params) }.w()?;
CudaStorageSlice::U8(out)
}
_ => Err(CudaError::InternalError("dtype mismatch in kvconcat op"))?,
};

let mut lshape: Vec<usize> = ltensor_l.shape().dims().to_vec();
if self.concat_dim == 0 {
lshape[0] += rtensor_l.shape().dims()[0];
} else {
if dim_size > 3 {
lshape[2] += rtensor_l.shape().dims()[2];
} else {
lshape[1] += rtensor_l.shape().dims()[1];
}
}

let device = dev.clone();
Ok((
CudaStorage {
slice: slice,
device,
},
lshape.into(),
))
}
}
9 changes: 9 additions & 0 deletions candle-core/src/sort.rs
@@ -61,6 +61,15 @@ mod cuda {
use crate::cuda_backend::{kernel_name, kernels, CudaStorageSlice as S, WrapErr};
use crate::{CudaDevice, WithDType};

#[allow(unused)]
fn next_power_of_2(x: usize) -> usize {
let mut n = 1;
while n < x {
n *= 2
}
n
}

impl crate::cuda_backend::Map1Any for ArgSort {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
&self,
5 changes: 5 additions & 0 deletions candle-examples/Cargo.toml
@@ -49,6 +49,7 @@ ab_glyph = { workspace = true }
tracing = { workspace = true }
tracing-chrome = { workspace = true }
tracing-subscriber = { workspace = true }
float8 = { workspace = true }
# Necessary to disambiguate with tokio in wasm examples which are 1.28.1
tokio = "1.43.0"

@@ -122,3 +123,7 @@ required-features = ["onnx"]
[[example]]
name = "colpali"
required-features = ["pdf2image"]

[[example]]
name = "deepseekv3"
required-features = ["cuda", "nccl"]
236 changes: 236 additions & 0 deletions candle-examples/examples/deepseekv3/main.rs
@@ -0,0 +1,236 @@
// An example of running DeepSeek V3 / R1 with tensor parallelism: the model is
// sharded across GPUs, with one process per rank communicating over NCCL.
//
// The overall structure follows the llama_multiprocess example in this repository.

#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

use anyhow::{bail, Error as E, Result};
use clap::{Parser, ValueEnum};

use candle::{DType, Device, Tensor};
use candle_transformers::generation::LogitsProcessor;
use cudarc::driver::safe::CudaDevice;
use cudarc::nccl::safe::{Comm, Id};
use hf_hub::{api::sync::Api, Repo, RepoType};
use std::io::Write;
use std::rc::Rc;

mod model;
mod ops;
mod quant;
use model::{DeepSeekV3, DeepSeekV3Config};

const DEFAULT_PROMPT: &str = "My favorite theorem is ";

#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
enum Which {
R1,
V3,
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
#[arg(long)]
num_shards: usize,

#[arg(long)]
rank: Option<usize>,

/// The temperature used to generate samples.
#[arg(long, default_value_t = 0.8)]
temperature: f64,

/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,

/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,

/// The length of the sample to generate (in tokens).
#[arg(long, default_value_t = 100)]
sample_len: usize,

/// Disable the key-value cache.
#[arg(long)]
no_kv_cache: bool,

/// The initial prompt.
#[arg(long)]
prompt: Option<String>,

#[arg(long)]
model_id: Option<String>,

#[arg(long)]
revision: Option<String>,

#[arg(long)]
dtype: Option<String>,

#[arg(long, default_value = "r1")]
which: Which,

#[arg(long, default_value = "nccl_id.txt")]
comm_file: String,
}

fn main() -> Result<()> {
use tokenizers::Tokenizer;

let args = Args::parse();

let dtype = match args.dtype.as_deref() {
Some("f16") => DType::F16,
Some("bf16") => DType::BF16,
Some("f32") => DType::F32,
Some(dtype) => bail!("Unsupported dtype {dtype}"),
None => DType::BF16,
};

let comm_file = std::path::PathBuf::from(&args.comm_file);
if comm_file.exists() {
bail!("comm file {comm_file:?} already exists, please remove it first")
}

let api = Api::new()?;
let model_id = match args.model_id {
Some(model) => model,
None => match args.which {
Which::V3 => "deepseek-ai/DeepSeek-V3".to_string(),
Which::R1 => "deepseek-ai/DeepSeek-R1".to_string(),
},
};
println!("loading the model weights from {model_id}");
let revision = args.revision.unwrap_or("main".to_string());
let api = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
let config_filename = api.get("config.json")?;
let config: DeepSeekV3Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
let tokenizer_filename = api.get("tokenizer.json")?;
let filenames = candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?;

let rank = match args.rank {
None => {
println!("creating {} child processes", args.num_shards);
let children: Vec<_> = (0..args.num_shards)
.map(|rank| {
let mut args: std::collections::VecDeque<_> = std::env::args().collect();
args.push_back("--rank".to_string());
args.push_back(format!("{rank}"));
let name = args.pop_front().unwrap();
std::process::Command::new(name).args(args).spawn().unwrap()
})
.collect();
for mut child in children {
child.wait()?;
}
return Ok(());
}
Some(rank) => rank,
};

let num_shards = args.num_shards;
// Primitive IPC
let id = if rank == 0 {
let id = Id::new().unwrap();
let tmp_file = comm_file.with_extension(".comm.tgz");
std::fs::File::create(&tmp_file)?
.write_all(&id.internal().iter().map(|&i| i as u8).collect::<Vec<_>>())?;
std::fs::rename(&tmp_file, &comm_file)?;
id
} else {
while !comm_file.exists() {
std::thread::sleep(std::time::Duration::from_secs(1));
}
let data = std::fs::read(&comm_file)?;
let internal: [i8; 128] = data
.into_iter()
.map(|i| i as i8)
.collect::<Vec<_>>()
.try_into()
.unwrap();
let id: Id = Id::uninit(internal);
id
};
let device = CudaDevice::new(rank)?;
let comm = match Comm::from_rank(device, rank, num_shards, id) {
Ok(comm) => Rc::new(comm),
Err(err) => anyhow::bail!("nccl error {:?}", err.0),
};
if rank == 0 {
std::fs::remove_file(comm_file)?;
}
println!("Rank {rank:?} spawned");

let device = Device::new_cuda(rank)?;

println!("building the model");
let vb = unsafe {
candle_nn::var_builder::ShardedSafeTensors::var_builder(&filenames, dtype, &device)?
};
let model = DeepSeekV3::new(&config, vb, None, comm)?;
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

let prompt = args.prompt.as_ref().map_or(DEFAULT_PROMPT, |p| p.as_str());
let mut tokens = tokenizer
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let mut tokenizer = candle_examples::token_output_stream::TokenOutputStream::new(tokenizer);

println!("starting the inference loop");
let temperature = if args.temperature <= 0. {
None
} else {
Some(args.temperature)
};
let mut logits_processor = LogitsProcessor::new(args.seed, temperature, args.top_p);
let mut new_tokens = vec![];
let mut start_gen = std::time::Instant::now();
let mut index_pos = 0;
for index in 0..args.sample_len {
// Only start timing at the second token as processing the first token waits for all the
// weights to be loaded in an async way.
if index == 1 {
start_gen = std::time::Instant::now()
};
let context_size = if index > 0 { 1 } else { tokens.len() };
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
let logits = model.forward(&input, index_pos)?;
let logits = logits.squeeze(0)?;
index_pos += ctxt.len();

let next_token = logits_processor.sample(&logits)?;
tokens.push(next_token);
new_tokens.push(next_token);
if next_token == config.eos_token_id {
break;
}

if rank == 0 {
if let Some(t) = tokenizer.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
}
}
println!();
if rank == 0 {
let dt = start_gen.elapsed();
println!(
"\n\n{} tokens generated ({} token/s)\n",
args.sample_len,
(args.sample_len - 1) as f64 / dt.as_secs_f64(),
);
}
Ok(())
}
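
Note on running this: the example registers required-features = ["cuda", "nccl"] and, when --rank is omitted, re-spawns itself once per shard, so a single invocation drives all GPUs; rank 0 writes the NCCL Id to the --comm-file and the other ranks poll for it. A plausible launch, with flag names taken from the Args struct above (the exact command line is an assumption, not part of this PR):

cargo run --release --example deepseekv3 --features "cuda,nccl" -- --num-shards 8 --which r1 --prompt "My favorite theorem is "
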

Large diffs are not rendered by default.

File renamed without changes.
452 changes: 452 additions & 0 deletions candle-examples/examples/deepseekv3/quant.rs

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion candle-flash-attn/cutlass
Submodule cutlass updated 582 files
3 changes: 0 additions & 3 deletions candle-transformers/src/models/deepseekv3/mod.rs

This file was deleted.

116 changes: 0 additions & 116 deletions candle-transformers/src/models/deepseekv3/quant.rs

This file was deleted.

1 change: 0 additions & 1 deletion candle-transformers/src/models/mod.rs
@@ -28,7 +28,6 @@ pub mod colpali;
pub mod convmixer;
pub mod convnext;
pub mod dac;
pub mod deepseekv3;
pub mod depth_anything_v2;
pub mod dinov2;
pub mod dinov2reg4;