From 1be6483b4fc6cd1ffe70ec7e9f892c7ed388f619 Mon Sep 17 00:00:00 2001 From: Laurent Date: Sun, 26 Jan 2025 20:23:14 +0100 Subject: [PATCH 1/2] Remove the old MFA gemm kernels. --- candle-core/src/metal_backend/device.rs | 6 - candle-core/src/metal_backend/mod.rs | 33 +-- .../examples/metal_benchmarks.rs | 88 +++----- candle-metal-kernels/src/lib.rs | 194 +---------------- candle-metal-kernels/src/tests.rs | 206 ------------------ 5 files changed, 40 insertions(+), 487 deletions(-) diff --git a/candle-core/src/metal_backend/device.rs b/candle-core/src/metal_backend/device.rs index 46be6ce4bb..fab80d34ec 100644 --- a/candle-core/src/metal_backend/device.rs +++ b/candle-core/src/metal_backend/device.rs @@ -121,8 +121,6 @@ pub struct MetalDevice { pub(crate) kernels: Arc, /// Seed for random number generation. pub(crate) seed: Arc>, - /// Whether to use the MLX matmul kernels instead of the MFA ones. - pub(crate) use_mlx_mm: bool, } impl std::fmt::Debug for MetalDevice { @@ -140,10 +138,6 @@ impl std::ops::Deref for MetalDevice { } impl MetalDevice { - pub fn set_use_mlx_mm(&mut self, use_mlx_mm: bool) { - self.use_mlx_mm = use_mlx_mm - } - pub fn compile( &self, func_name: &'static str, diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs index 435b2ec549..70a512bc8e 100644 --- a/candle-core/src/metal_backend/mod.rs +++ b/candle-core/src/metal_backend/mod.rs @@ -1469,7 +1469,7 @@ impl BackendStorage for MetalStorage { &buffer, ) .map_err(MetalError::from)?; - } else if self.device.use_mlx_mm { + } else { let dtype = match self.dtype { DType::F32 => candle_metal_kernels::GemmDType::F32, DType::F16 => candle_metal_kernels::GemmDType::F16, @@ -1496,32 +1496,6 @@ impl BackendStorage for MetalStorage { &buffer, ) .map_err(MetalError::from)?; - } else { - let name = match self.dtype { - DType::F32 => "sgemm", - DType::F16 => "hgemm", - dtype => { - return Err( - MetalError::Message(format!("matmul doesn't support {dtype:?}")).into(), - ) - } - }; - - candle_metal_kernels::call_gemm( - &self.device.device, - &command_buffer, - &self.device.kernels, - name, - (b, m, n, k), - lhs_l.stride(), - lhs_l.start_offset() * self.dtype.size_in_bytes(), - &self.buffer, - rhs_l.stride(), - rhs_l.start_offset() * rhs.dtype.size_in_bytes(), - &rhs.buffer, - &buffer, - ) - .map_err(MetalError::from)?; } Ok(Self::new( buffer, @@ -1884,10 +1858,6 @@ impl BackendDevice for MetalDevice { let device = metal::Device::all().swap_remove(ordinal); let command_queue = device.new_command_queue(); let kernels = Arc::new(Kernels::new()); - let use_mlx_mm = match std::env::var("CANDLE_USE_MFA_MM").as_deref() { - Ok("false") | Ok("False") | Ok("FALSE") | Ok("0") | Err(_) => true, - Ok(_) => false, - }; let seed = Arc::new(Mutex::new(device.new_buffer_with_data( [299792458].as_ptr() as *const c_void, 4, @@ -1901,7 +1871,6 @@ impl BackendDevice for MetalDevice { buffers: Arc::new(RwLock::new(HashMap::new())), kernels, seed, - use_mlx_mm, }) } diff --git a/candle-metal-kernels/examples/metal_benchmarks.rs b/candle-metal-kernels/examples/metal_benchmarks.rs index c9c279970d..f0de21e0c2 100644 --- a/candle-metal-kernels/examples/metal_benchmarks.rs +++ b/candle-metal-kernels/examples/metal_benchmarks.rs @@ -44,66 +44,46 @@ fn run_gemm(f32: bool, n: usize) -> Result<()> { ); (lhs, rhs) }; - let (dtype, name, sizeof) = if f32 { - (GemmDType::F32, "sgemm", core::mem::size_of::()) + let (dtype, sizeof) = if f32 { + (GemmDType::F32, core::mem::size_of::()) } else { - (GemmDType::F16, "hgemm", core::mem::size_of::()) + (GemmDType::F16, core::mem::size_of::()) }; let output = device.new_buffer((b * m * n * sizeof) as u64, options); - for mlx in [false, true] { - let mut sum_dt = 0f64; - let mut iters = 0usize; - for idx in 0.. { - let command_buffer = command_queue.new_command_buffer(); - let start_time = std::time::Instant::now(); - if mlx { - candle_metal_kernels::call_mlx_gemm( - &device, - command_buffer, - &kernels, - dtype, - (b, m, n, k), - &[m * k, k, 1], - 0, - &lhs, - &[n * k, n, 1], - 0, - &rhs, - &output, - )?; - } else { - candle_metal_kernels::call_gemm( - &device, - command_buffer, - &kernels, - name, - (b, m, n, k), - &[m * k, k, 1], - 0, - &lhs, - &[n * k, n, 1], - 0, - &rhs, - &output, - )?; - } - command_buffer.commit(); - command_buffer.wait_until_completed(); - let dt = start_time.elapsed().as_secs_f64(); - if idx < WARMUP_ITERS { - continue; - } - sum_dt += dt; - iters += 1; - if sum_dt > MIN_DUR { - break; - } + let mut sum_dt = 0f64; + let mut iters = 0usize; + for idx in 0.. { + let command_buffer = command_queue.new_command_buffer(); + let start_time = std::time::Instant::now(); + candle_metal_kernels::call_mlx_gemm( + &device, + command_buffer, + &kernels, + dtype, + (b, m, n, k), + &[m * k, k, 1], + 0, + &lhs, + &[n * k, n, 1], + 0, + &rhs, + &output, + )?; + command_buffer.commit(); + command_buffer.wait_until_completed(); + let dt = start_time.elapsed().as_secs_f64(); + if idx < WARMUP_ITERS { + continue; + } + sum_dt += dt; + iters += 1; + if sum_dt > MIN_DUR { + break; } - let gflops = (2 * n * n * n * iters) as f64 / (1e9 * sum_dt); - let mlx = if mlx { "MLX" } else { "MFA" }; - println!("{mlx} {dtype:?}, {n:6} gflops {gflops:.0}"); } + let gflops = (2 * n * n * n * iters) as f64 / (1e9 * sum_dt); + println!("{dtype:?}, {n:6} gflops {gflops:.0}"); Ok(()) } diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 79cfb99035..2e001a0f68 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -16,8 +16,6 @@ const CAST: &str = include_str!("cast.metal"); const CONV: &str = include_str!("conv.metal"); const FILL: &str = include_str!("fill.metal"); const INDEXING: &str = include_str!("indexing.metal"); -// Current source: https://github.com/ivarflakstad/metal-flash-attention/tree/candle -const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib"); const MLX_GEMM: &str = include_str!("mlx_gemm.metal"); const QUANTIZED: &str = include_str!("quantized.metal"); const RANDOM: &str = include_str!("random.metal"); @@ -36,7 +34,6 @@ pub enum Source { Fill, Gemm, Indexing, - Mfa, Quantized, Random, Reduce, @@ -221,7 +218,6 @@ impl Kernels { Source::Ternary => TERNARY, Source::Unary => UNARY, Source::Sdpa => SDPA, - Source::Mfa => panic!("Invalid lib"), } } @@ -236,21 +232,11 @@ impl Kernels { if let Some(lib) = libraries.get(&source) { Ok(lib.clone()) } else { - let lib = match source { - Source::Mfa => { - let source_data = MFA; - device.new_library_with_data(source_data).map_err(|e| { - MetalKernelError::LoadLibraryError(format!( - "Candle metal requires macosx > 13.0 or higher, cannot load mfa: {e}" - )) - })? - } - source => { - let source_content = self.get_library_source(source); - device - .new_library_with_source(source_content, &CompileOptions::new()) - .map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))? - } + let lib = { + let source_content = self.get_library_source(source); + device + .new_library_with_source(source_content, &CompileOptions::new()) + .map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))? }; libraries.insert(source, lib.clone()); Ok(lib) @@ -1471,176 +1457,6 @@ impl ConstantValues { } } -#[allow(clippy::too_many_arguments)] -pub fn call_gemm( - device: &Device, - ep: impl EncoderProvider, - kernels: &Kernels, - name: &'static str, - (b, m, n, k): (usize, usize, usize, usize), - lhs_stride: &[usize], - lhs_offset: usize, - lhs_buffer: &Buffer, - rhs_stride: &[usize], - rhs_offset: usize, - rhs_buffer: &Buffer, - output: &Buffer, -) -> Result<(), MetalKernelError> { - assert!(rhs_stride.len() >= 2); - assert!(lhs_stride.len() >= 2); - let rhs_m1 = rhs_stride[rhs_stride.len() - 1]; - let rhs_m2 = rhs_stride[rhs_stride.len() - 2]; - let lhs_m1 = lhs_stride[lhs_stride.len() - 1]; - let lhs_m2 = lhs_stride[lhs_stride.len() - 2]; - // lhs has shape b, m, k - // We also allow for the case where the stride on the minor dimension is not as expected but - // there is a single element. - let a_trans = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) { - false - } else if (lhs_m1 == m || k == 1) && (lhs_m2 == 1 || m == 1) { - true - } else { - return Err(MetalKernelError::MatMulNonContiguous { - lhs_stride: lhs_stride.to_vec(), - rhs_stride: rhs_stride.to_vec(), - mnk: (m, n, k), - })?; - }; - // rhs has shape b, k, n - let b_trans = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) { - false - } else if (rhs_m1 == k || n == 1) && (rhs_m2 == 1 || k == 1) { - true - } else { - return Err(MetalKernelError::MatMulNonContiguous { - lhs_stride: lhs_stride.to_vec(), - rhs_stride: rhs_stride.to_vec(), - mnk: (m, n, k), - })?; - }; - let d_trans = false; - let alpha = 1.0f32; - let beta = 0.0f32; - let batched = b > 1; - let fused_activation = false; - let fused_bias = false; - let (m_simd, n_simd, k_simd, m_splits, n_splits) = if m == 1 { - let m_simd = 8; - let n_simd = 8; - let k_simd = 64; - let m_splits = 1; - let n_splits = 1; - (m_simd, n_simd, k_simd, m_splits, n_splits) - } else { - let m_simd = 40; - let n_simd = 40; - let k_simd = 32; - let m_splits = 1; - let n_splits = 1; - (m_simd, n_simd, k_simd, m_splits, n_splits) - }; - let constants = Some(ConstantValues::new(vec![ - (0, Value::USize(m)), - (1, Value::USize(n)), - (2, Value::USize(k)), - (10, Value::Bool(a_trans)), - (11, Value::Bool(b_trans)), - (13, Value::Bool(d_trans)), - (20, Value::F32(alpha)), - (21, Value::F32(beta)), - (100, Value::Bool(batched)), - (101, Value::Bool(fused_activation)), - // Garbage - (102, Value::Bool(false)), - (103, Value::Bool(false)), - (113, Value::Bool(false)), - (50_000, Value::Bool(false)), - // End garbage - (200, Value::U16(m_simd)), - (201, Value::U16(n_simd)), - (202, Value::U16(k_simd)), - (210, Value::U16(m_splits)), - (211, Value::U16(n_splits)), - (50_001, Value::Bool(fused_bias)), - ])); - let pipeline = kernels.load_pipeline_with_constants(device, Source::Mfa, name, constants)?; - let m_group = m_simd * m_splits; - let n_group = n_simd * n_splits; - - let a_block_length = m_group * k_simd; - let b_block_length = k_simd * n_group; - - let mut block_elements = a_block_length + b_block_length; - if (m % 8 != 0) && (n % 8 != 0) { - let c_block_length = m_group * n_group; - block_elements = std::cmp::max(c_block_length, block_elements) - } - if fused_bias { - if d_trans { - block_elements = std::cmp::max(block_elements, m_group); - } else { - block_elements = std::cmp::max(block_elements, n_group); - } - } - let bytes = match name { - "sgemm" => 4, - "hgemm" => 2, - "bgemm" => 2, - other => { - return Err(MetalKernelError::LoadLibraryError(format!( - "{other} is not a valid kernel for gemm" - ))); - } - }; - let block_bytes = block_elements * bytes; - - let encoder = ep.encoder(); - let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); - encoder.set_compute_pipeline_state(&pipeline); - encoder.set_threadgroup_memory_length(0, block_bytes.into()); - encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger); - encoder.set_buffer(1, Some(rhs_buffer), rhs_offset as NSUInteger); - encoder.set_buffer(2, Some(output), 0); - // TODO Tensor D - - let grid_z = b; - if batched { - let byte_stride_a: usize = lhs_stride[lhs_stride.len() - 3] * bytes as usize; - let byte_stride_b: usize = rhs_stride[rhs_stride.len() - 3] * bytes as usize; - let byte_stride_c = m * n * bytes as usize; - // TODO byte_stride_d - let byte_stride_d = 0; - - let buffer: Vec = vec![ - byte_stride_a as _, - byte_stride_b as _, - byte_stride_c as _, - byte_stride_d as _, - ]; - encoder.set_bytes( - 10, - (buffer.len() * core::mem::size_of::()) as NSUInteger, - buffer.as_ptr() as *const NSUInteger as *const c_void, - ); - } - - let grid_size = MTLSize { - width: divide(n, n_group.into()), - height: divide(m, m_group.into()), - depth: grid_z as NSUInteger, - }; - let group_size = MTLSize { - width: 32 * (m_splits as u64) * (n_splits as u64), - height: 1, - depth: 1, - }; - encoder.use_resource(lhs_buffer, metal::MTLResourceUsage::Read); - encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read); - encoder.use_resource(output, metal::MTLResourceUsage::Write); - encoder.dispatch_thread_groups(grid_size, group_size); - Ok(()) -} - #[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)] pub enum SdpaDType { BF16, diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index 637bf2e243..99e711f151 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -1046,168 +1046,6 @@ fn where_cond_u32_f32() { assert_eq!(approx(results, 4), vec![-1.0f32, 2.0, -3.0, -4.0, 5.0, 6.0]); } -#[allow(clippy::too_many_arguments)] -fn run_gemm( - name: &'static str, - (b, m, n, k): (usize, usize, usize, usize), - lhs: &[T], - lhs_stride: &[usize], - lhs_offset: usize, - rhs: &[T], - rhs_stride: &[usize], - rhs_offset: usize, -) -> Vec { - let device = device(); - let kernels = Kernels::new(); - let command_queue = device.new_command_queue(); - let command_buffer = command_queue.new_command_buffer(); - let options = MTLResourceOptions::StorageModeManaged; - - let lhs = device.new_buffer_with_data( - lhs.as_ptr() as *const core::ffi::c_void, - std::mem::size_of_val(lhs) as u64, - options, - ); - let rhs = device.new_buffer_with_data( - rhs.as_ptr() as *const core::ffi::c_void, - std::mem::size_of_val(rhs) as u64, - options, - ); - let length = b * m * n; - let output = device.new_buffer((length * core::mem::size_of::()) as u64, options); - call_gemm( - &device, - command_buffer, - &kernels, - name, - (b, m, n, k), - lhs_stride, - lhs_offset, - &lhs, - rhs_stride, - rhs_offset, - &rhs, - &output, - ) - .unwrap(); - command_buffer.commit(); - command_buffer.wait_until_completed(); - - read_to_vec(&output, length) -} - -#[test] -fn gemm() { - let (b, m, n, k) = (1, 2, 4, 3); - let lhs_stride = vec![m * k, k, 1]; - let lhs: Vec = (0..b * m * k).map(|f| f as f32).collect(); - let rhs_stride = vec![n * k, n, 1]; - let rhs: Vec = (0..b * n * k).map(|f| f as f32).collect(); - let results = run_gemm( - "sgemm", - (b, m, n, k), - &lhs, - &lhs_stride, - 0, - &rhs, - &rhs_stride, - 0, - ); - assert_eq!( - approx(results, 4), - vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0] - ); - - let (b, m, n, k) = (2, 2, 4, 3); - let lhs_stride = vec![m * k, k, 1]; - let lhs: Vec = (0..b * m * k).map(|f| f as f32).collect(); - let rhs_stride = vec![n * k, n, 1]; - let rhs: Vec = (0..b * n * k).map(|f| f as f32).collect(); - let results = run_gemm( - "sgemm", - (b, m, n, k), - &lhs, - &lhs_stride, - 0, - &rhs, - &rhs_stride, - 0, - ); - assert_eq!( - approx(results, 4), - vec![ - 20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0, 344.0, 365.0, 386.0, 407.0, 488.0, - 518.0, 548.0, 578.0 - ] - ); - - // OFFSET - let (b, m, n, k) = (2, 2, 4, 3); - let lhs_stride = vec![m * k, k, 1]; - let lhs: Vec = (0..b * m * k).map(|f| f as f32).collect(); - let rhs_stride = vec![n * k, n, 1]; - let rhs: Vec = (0..b * n * k).map(|f| f as f32).collect(); - // Manually set batch_size=1 and offset 12 elements * 4 the number of bytes for f32 - let results = run_gemm( - "sgemm", - (1, m, n, k), - &lhs, - &lhs_stride, - 0, - &rhs, - &rhs_stride, - 12 * 4, - ); - assert_eq!( - approx(results, 4), - vec![56.0, 59.0, 62.0, 65.0, 200.0, 212.0, 224.0, 236.0] - ); - - // bgemm sanity test - if false { - let (b, m, n, k) = (1, 2, 4, 3); - let lhs_stride = vec![m * k, k, 1]; - let lhs: Vec = (0..b * m * k).map(|f| bf16::from_f32(f as f32)).collect(); - let rhs_stride = vec![n * k, n, 1]; - let rhs: Vec = (0..b * n * k).map(|f| bf16::from_f32(f as f32)).collect(); - let results = run_gemm( - "bgemm", - (b, m, n, k), - &lhs, - &lhs_stride, - 0, - &rhs, - &rhs_stride, - 0, - ); - assert_eq!( - approx_bf16(results, 4), - vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0] - ); - } - - // hgemm sanity test - let (b, m, n, k) = (1, 2, 4, 3); - let lhs_stride = vec![m * k, k, 1]; - let lhs: Vec = (0..b * m * k).map(|f| f16::from_f32(f as f32)).collect(); - let rhs_stride = vec![n * k, n, 1]; - let rhs: Vec = (0..b * n * k).map(|f| f16::from_f32(f as f32)).collect(); - let results = run_gemm( - "hgemm", - (b, m, n, k), - &lhs, - &lhs_stride, - 0, - &rhs, - &rhs_stride, - 0, - ); - assert_eq!( - approx_f16(results, 4), - vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0] - ); -} - #[allow(clippy::too_many_arguments)] fn run_mlx_gemm( dtype: GemmDType, @@ -1258,50 +1096,6 @@ fn run_mlx_gemm( read_to_vec(&output, length) } -fn mlx_vs_mfa_one(b: usize, m: usize, n: usize, k: usize, dtype: GemmDType) { - use rand::SeedableRng; - use rand_distr::Distribution; - - let mut rng = rand::rngs::StdRng::seed_from_u64(42424242); - let normal = rand_distr::Normal::new(0.0, 1.0).unwrap(); - - let lhs: Vec<_> = (0..b * m * k).map(|_| normal.sample(&mut rng)).collect(); - let rhs: Vec<_> = (0..b * n * k).map(|_| normal.sample(&mut rng)).collect(); - let v1: Vec = run_mlx_gemm( - dtype, - (b, m, n, k), - &lhs, - &[m * k, k, 1], - 0, - &rhs, - &[k * n, n, 1], - 0, - ); - let v2: Vec = run_gemm( - "sgemm", - (b, m, n, k), - &lhs, - &[m * k, k, 1], - 0, - &rhs, - &[k * n, n, 1], - 0, - ); - for (a, b) in v1.iter().zip(v2.iter()) { - let diff = (a - b).abs(); - assert_eq!((diff * 1e4).round(), 0.) - } -} - -#[test] -fn mlx_vs_mfa() { - mlx_vs_mfa_one(1, 32, 32, 25, GemmDType::F32); - mlx_vs_mfa_one(1, 128, 128, 100, GemmDType::F32); - mlx_vs_mfa_one(1, 256, 256, 256, GemmDType::F32); - mlx_vs_mfa_one(1, 192, 200, 75, GemmDType::F32); - mlx_vs_mfa_one(3, 27, 67, 64, GemmDType::F32); -} - #[test] fn mlx_gemm() { let (b, m, n, k) = (1, 2, 4, 3); From ab2b79a65cb97a697b00a6ac4f5d9e1e4cabbac2 Mon Sep 17 00:00:00 2001 From: Laurent Date: Sun, 26 Jan 2025 20:30:02 +0100 Subject: [PATCH 2/2] Use bf16 in helium on metal. --- candle-examples/examples/helium/main.rs | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/candle-examples/examples/helium/main.rs b/candle-examples/examples/helium/main.rs index 31f949bf33..fc7e6b6044 100644 --- a/candle-examples/examples/helium/main.rs +++ b/candle-examples/examples/helium/main.rs @@ -263,11 +263,7 @@ fn main() -> Result<()> { }; let device = candle_examples::device(args.cpu)?; let (model, device) = { - let dtype = if device.is_cuda() { - DType::BF16 - } else { - DType::F32 - }; + let dtype = device.bf16_default_to_f32(); let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? }; let model = Model::new(&config, vb)?; (model, device)