diff --git a/model-parameters/Cargo.toml b/model-parameters/Cargo.toml index 85e135cc..11b91f14 100644 --- a/model-parameters/Cargo.toml +++ b/model-parameters/Cargo.toml @@ -9,13 +9,9 @@ authors = ["YdrMaster "] [dependencies] common = { path = "../common" } tensor = { path = "../tensor" } -log = "0.4" half = "2.3" rayon = "1.8" memmap2 = "0.9" safetensors = "0.4" serde_json = "1.0" serde = { version = "1.0", features = ["derive"] } - -[dev-dependencies] -env_logger = "0.11" diff --git a/model-parameters/src/lib.rs b/model-parameters/src/lib.rs index 7970ba1d..46487155 100644 --- a/model-parameters/src/lib.rs +++ b/model-parameters/src/lib.rs @@ -3,9 +3,6 @@ mod memory; mod save; mod storage; -#[macro_use] -extern crate log; - use common::utok; use storage::Storage; use tensor::{DataType, Tensor}; @@ -79,15 +76,3 @@ struct ConfigJson { pub rope_theta: f32, pub torch_dtype: DataType, } - -struct LayerParamsOffset { - input_layernorm: usize, - self_attn_q_proj: usize, - self_attn_k_proj: usize, - self_attn_v_proj: usize, - self_attn_o_proj: usize, - post_attention_layernorm: usize, - mlp_gate: usize, - mlp_down: usize, - mlp_up: usize, -} diff --git a/model-parameters/src/memory/cast.rs b/model-parameters/src/memory/cast.rs new file mode 100644 index 00000000..7d400c7a --- /dev/null +++ b/model-parameters/src/memory/cast.rs @@ -0,0 +1,94 @@ +use super::Layer; +use crate::{storage::Storage, ConfigJson, DataType, Llama2, Memory}; +use half::{bf16, f16}; +use std::sync::Arc; +use tensor::Tensor; + +impl Memory { + pub fn cast(src: &dyn Llama2, new_dtype: DataType) -> Self { + Self { + config: ConfigJson { + bos_token_id: src.bos_token_id(), + eos_token_id: src.eos_token_id(), + hidden_size: src.hidden_size(), + intermediate_size: src.intermediate_size(), + max_position_embeddings: src.max_position_embeddings(), + num_attention_heads: src.num_attention_heads(), + num_hidden_layers: src.num_hidden_layers(), + num_key_value_heads: src.num_key_value_heads(), + vocab_size: 
src.vocab_size(), + rms_norm_eps: src.rms_norm_eps(), + rope_theta: src.rope_theta(), + torch_dtype: new_dtype, + }, + embed_tokens: cast(src.embed_tokens(), new_dtype), + layers: (0..src.num_hidden_layers()) + .map(|l| Layer { + input_layernorm: cast(src.input_layernorm(l), new_dtype), + self_attn_q_proj: cast(src.self_attn_q_proj(l), new_dtype), + self_attn_k_proj: cast(src.self_attn_k_proj(l), new_dtype), + self_attn_v_proj: cast(src.self_attn_v_proj(l), new_dtype), + self_attn_o_proj: cast(src.self_attn_o_proj(l), new_dtype), + post_attention_layernorm: cast(src.post_attention_layernorm(l), new_dtype), + mlp_gate: cast(src.mlp_gate(l), new_dtype), + mlp_down: cast(src.mlp_down(l), new_dtype), + mlp_up: cast(src.mlp_up(l), new_dtype), + }) + .collect(), + model_norm: cast(src.model_norm(), new_dtype), + lm_head: cast(src.lm_head(), new_dtype), + } + } +} + +fn cast(src: Tensor, new_dtype: DataType) -> Tensor { + if src.data_type() == new_dtype { + return src; + } + + let src_data = src.physical().as_slice(); + let mut data = vec![0u8; src.size() * new_dtype.size()]; + + macro_rules! 
cast { + ($f:expr; $src:expr, $src_ty:ty => $dst:expr, $dst_ty:ty) => { + use rayon::iter::*; + use std::{mem::size_of, slice::from_raw_parts, slice::from_raw_parts_mut}; + + let len = $src.len() / size_of::<$src_ty>(); + debug_assert_eq!(len * size_of::<$dst_ty>(), $dst.len()); + let src = unsafe { from_raw_parts($src.as_ptr() as *const $src_ty, len) }; + let dst = unsafe { from_raw_parts_mut($dst.as_mut_ptr() as *mut $dst_ty, len) }; + + #[allow(clippy::redundant_closure_call)] + src.par_iter() + .zip(dst) + .for_each(|(src, dst)| *dst = $f(*src)); + }; + } + + match (src.data_type(), new_dtype) { + (DataType::F16, DataType::BF16) => { + cast!(|x: f16| bf16::from_f32(x.to_f32()); src_data, f16 => &mut data, bf16); + } + (DataType::F16, DataType::F32) => { + cast!(|x: f16| x.to_f32(); src_data, f16 => &mut data, f32); + } + (DataType::BF16, DataType::F16) => { + cast!(|x: bf16| f16::from_f32(x.to_f32()); src_data, bf16 => &mut data, f16); + } + (DataType::BF16, DataType::F32) => { + cast!(|x: bf16| x.to_f32(); src_data, bf16 => &mut data, f32); + } + (DataType::F32, DataType::F16) => { + cast!(|x: f32| f16::from_f32(x); src_data, f32 => &mut data, f16); + } + (DataType::F32, DataType::BF16) => { + cast!(|x: f32| bf16::from_f32(x); src_data, f32 => &mut data, bf16); + } + _ => todo!(), + } + + let len = data.len(); + let physical = Storage::new(Arc::new(data), 0, len); + unsafe { src.cast(new_dtype, physical) } +} diff --git a/model-parameters/src/memory/inside_memory.rs b/model-parameters/src/memory/inside_memory.rs deleted file mode 100644 index 0449d618..00000000 --- a/model-parameters/src/memory/inside_memory.rs +++ /dev/null @@ -1,115 +0,0 @@ -use std::sync::Arc; - -use crate::{ConfigJson, DataType, LayerParamsOffset, Llama2, Memory}; -use half::{bf16, f16}; - -impl Memory> { - pub fn cast(src: &dyn Llama2, new_dtype: DataType) -> Self { - let from = src.data_type(); - let mut blob = Vec::with_capacity(src.size() * new_dtype.size() / from.size()); - let mut 
append = |src: &[u8]| { - let start = blob.len(); - let additional = src.len() * new_dtype.size() / from.size(); - let end = start + additional; - - blob.reserve_exact(additional); - unsafe { blob.set_len(end) }; - - cast(from, src, new_dtype, &mut blob[start..end]); - start - }; - - if let Some(n) = std::option_env!("RAYON_PARALLELISM") { - rayon::ThreadPoolBuilder::new() - .num_threads(n.parse().unwrap()) - .build_global() - .unwrap(); - } - - let embed_tokens = append(src.embed_tokens().physical().as_slice()); - let layers = (0..src.num_hidden_layers()) - .map(|layer| LayerParamsOffset { - input_layernorm: append(src.input_layernorm(layer).physical().as_slice()), - self_attn_q_proj: append(src.self_attn_q_proj(layer).physical().as_slice()), - self_attn_k_proj: append(src.self_attn_k_proj(layer).physical().as_slice()), - self_attn_v_proj: append(src.self_attn_v_proj(layer).physical().as_slice()), - self_attn_o_proj: append(src.self_attn_o_proj(layer).physical().as_slice()), - post_attention_layernorm: append( - src.post_attention_layernorm(layer).physical().as_slice(), - ), - mlp_gate: append(src.mlp_gate(layer).physical().as_slice()), - mlp_down: append(src.mlp_down(layer).physical().as_slice()), - mlp_up: append(src.mlp_up(layer).physical().as_slice()), - }) - .collect(); - let model_norm = append(src.model_norm().physical().as_slice()); - let lm_head = append(src.lm_head().physical().as_slice()); - - Self { - config: ConfigJson { - bos_token_id: src.bos_token_id(), - eos_token_id: src.eos_token_id(), - hidden_size: src.hidden_size(), - intermediate_size: src.intermediate_size(), - max_position_embeddings: src.max_position_embeddings(), - num_attention_heads: src.num_attention_heads(), - num_hidden_layers: src.num_hidden_layers(), - num_key_value_heads: src.num_key_value_heads(), - vocab_size: src.vocab_size(), - rms_norm_eps: src.rms_norm_eps(), - rope_theta: src.rope_theta(), - torch_dtype: new_dtype, - }, - blob: Arc::new(blob), - embed_tokens, - layers, - 
model_norm, - lm_head, - } - } -} - -fn cast(src_dtype: DataType, src: &[u8], dst_dtype: DataType, dst: &mut [u8]) { - macro_rules! cast { - ($f:expr; $src:expr, $src_ty:ty => $dst:expr, $dst_ty:ty) => { - use rayon::iter::*; - use std::{mem::size_of, slice::from_raw_parts, slice::from_raw_parts_mut}; - - let len = $src.len() / size_of::<$src_ty>(); - debug_assert_eq!(len * size_of::<$dst_ty>(), $dst.len()); - let src = unsafe { from_raw_parts($src.as_ptr() as *const $src_ty, len) }; - let dst = unsafe { from_raw_parts_mut($dst.as_mut_ptr() as *mut $dst_ty, len) }; - - #[allow(clippy::redundant_closure_call)] - src.par_iter() - .zip(dst) - .for_each(|(src, dst)| *dst = $f(*src)); - }; - } - - match (src_dtype, dst_dtype) { - (DataType::F16, DataType::F16) - | (DataType::BF16, DataType::BF16) - | (DataType::F32, DataType::F32) => dst.copy_from_slice(src), - - (DataType::F16, DataType::BF16) => { - cast!(|x: f16| bf16::from_f32(x.to_f32()); src, f16 => dst, bf16); - } - (DataType::F16, DataType::F32) => { - cast!(|x: f16| x.to_f32(); src, f16 => dst, f32); - } - (DataType::BF16, DataType::F16) => { - cast!(|x: bf16| f16::from_f32(x.to_f32()); src, bf16 => dst, f16); - } - (DataType::BF16, DataType::F32) => { - cast!(|x: bf16| x.to_f32(); src, bf16 => dst, f32); - } - (DataType::F32, DataType::F16) => { - cast!(|x: f32| f16::from_f32(x); src, f32 => dst, f16); - } - (DataType::F32, DataType::BF16) => { - cast!(|x: f32| bf16::from_f32(x); src, f32 => dst, bf16); - } - _ => todo!(), - } -} diff --git a/model-parameters/src/memory/mod.rs b/model-parameters/src/memory/mod.rs index a644fd23..e605eca7 100644 --- a/model-parameters/src/memory/mod.rs +++ b/model-parameters/src/memory/mod.rs @@ -1,23 +1,34 @@ -mod inside_memory; +mod cast; mod safe_tensors; -use crate::{storage::Storage, ConfigJson, DataType, LayerParamsOffset, Llama2}; +use crate::{storage::Storage, ConfigJson, DataType, Llama2}; use common::utok; +use tensor::Tensor; + pub use safe_tensors::SafeTensorError; 
pub(crate) use safe_tensors::SafeTensorHeaderJson; -use std::sync::Arc; -use tensor::{Shape, Tensor}; -pub struct Memory { +pub struct Memory { config: ConfigJson, - blob: Arc, - embed_tokens: usize, - layers: Vec, - model_norm: usize, - lm_head: usize, + embed_tokens: Tensor, + layers: Vec, + model_norm: Tensor, + lm_head: Tensor, +} + +struct Layer { + input_layernorm: Tensor, + self_attn_q_proj: Tensor, + self_attn_k_proj: Tensor, + self_attn_v_proj: Tensor, + self_attn_o_proj: Tensor, + post_attention_layernorm: Tensor, + mlp_gate: Tensor, + mlp_down: Tensor, + mlp_up: Tensor, } -impl> Llama2 for Memory { +impl Llama2 for Memory { #[inline] fn bos_token_id(&self) -> utok { self.config.bos_token_id @@ -80,165 +91,62 @@ impl> Llama2 for Memory { #[inline] fn embed_tokens(&self) -> Tensor { - let d = self.config.hidden_size; - let dv = self.config.vocab_size; - let dt: usize = self.data_type().size(); - Tensor::new( - self.data_type(), - Shape::from_slice(&[dv as _, d as _]), - Storage::new(self.blob.clone(), self.embed_tokens, dv * d * dt), - ) + self.embed_tokens.clone() } #[inline] fn input_layernorm(&self, layer: usize) -> Tensor { - let d = self.config.hidden_size; - let dt: usize = self.data_type().size(); - Tensor::new( - self.data_type(), - Shape::from_slice(&[d as _]), - Storage::new( - self.blob.clone(), - self.layers[layer].input_layernorm, - d * dt, - ), - ) + self.layers[layer].input_layernorm.clone() } #[inline] fn self_attn_q_proj(&self, layer: usize) -> Tensor { - let d = self.config.hidden_size; - let dt: usize = self.data_type().size(); - Tensor::new( - self.data_type(), - Shape::from_slice(&[d as _, d as _]), - Storage::new( - self.blob.clone(), - self.layers[layer].self_attn_q_proj, - d * d * dt, - ), - ) + self.layers[layer].self_attn_q_proj.clone() } #[inline] fn self_attn_k_proj(&self, layer: usize) -> Tensor { - let d = self.config.hidden_size; - let dkv = d * self.config.num_key_value_heads / self.config.num_attention_heads; - let dt: 
usize = self.data_type().size(); - Tensor::new( - self.data_type(), - Shape::from_slice(&[dkv as _, d as _]), - Storage::new( - self.blob.clone(), - self.layers[layer].self_attn_k_proj, - dkv * d * dt, - ), - ) + self.layers[layer].self_attn_k_proj.clone() } #[inline] fn self_attn_v_proj(&self, layer: usize) -> Tensor { - let d = self.config.hidden_size; - let dkv = d * self.config.num_key_value_heads / self.config.num_attention_heads; - let dt: usize = self.data_type().size(); - Tensor::new( - self.data_type(), - Shape::from_slice(&[dkv as _, d as _]), - Storage::new( - self.blob.clone(), - self.layers[layer].self_attn_v_proj, - dkv * d * dt, - ), - ) + self.layers[layer].self_attn_v_proj.clone() } #[inline] fn self_attn_o_proj(&self, layer: usize) -> Tensor { - let d = self.config.hidden_size; - let dt: usize = self.data_type().size(); - Tensor::new( - self.data_type(), - Shape::from_slice(&[d as _, d as _]), - Storage::new( - self.blob.clone(), - self.layers[layer].self_attn_o_proj, - d * d * dt, - ), - ) + self.layers[layer].self_attn_o_proj.clone() } #[inline] fn post_attention_layernorm(&self, layer: usize) -> Tensor { - let d = self.config.hidden_size; - let dt: usize = self.data_type().size(); - Tensor::new( - self.data_type(), - Shape::from_slice(&[d as _]), - Storage::new( - self.blob.clone(), - self.layers[layer].post_attention_layernorm, - d * dt, - ), - ) + self.layers[layer].post_attention_layernorm.clone() } #[inline] fn mlp_gate(&self, layer: usize) -> Tensor { - let d = self.config.hidden_size; - let di = self.config.intermediate_size; - let dt: usize = self.data_type().size(); - Tensor::new( - self.data_type(), - Shape::from_slice(&[di as _, d as _]), - Storage::new(self.blob.clone(), self.layers[layer].mlp_gate, di * d * dt), - ) + self.layers[layer].mlp_gate.clone() } #[inline] fn mlp_down(&self, layer: usize) -> Tensor { - let d = self.config.hidden_size; - let di = self.config.intermediate_size; - let dt: usize = self.data_type().size(); - 
Tensor::new( - self.data_type(), - Shape::from_slice(&[d as _, di as _]), - Storage::new(self.blob.clone(), self.layers[layer].mlp_down, d * di * dt), - ) + self.layers[layer].mlp_down.clone() } #[inline] fn mlp_up(&self, layer: usize) -> Tensor { - let d = self.config.hidden_size; - let di = self.config.intermediate_size; - let dt: usize = self.data_type().size(); - Tensor::new( - self.data_type(), - Shape::from_slice(&[di as _, d as _]), - Storage::new(self.blob.clone(), self.layers[layer].mlp_up, di * d * dt), - ) + self.layers[layer].mlp_up.clone() } #[inline] fn model_norm(&self) -> Tensor { - let d = self.config.hidden_size; - let dt: usize = self.data_type().size(); - Tensor::new( - self.data_type(), - Shape::from_slice(&[d as _]), - Storage::new(self.blob.clone(), self.model_norm, d * dt), - ) + self.model_norm.clone() } #[inline] fn lm_head(&self) -> Tensor { - let d = self.config.hidden_size; - let dv: usize = self.config.vocab_size; - let dt: usize = self.data_type().size(); - Tensor::new( - self.data_type(), - Shape::from_slice(&[dv as _, d as _]), - Storage::new(self.blob.clone(), self.lm_head, dv * d * dt), - ) + self.lm_head.clone() } } @@ -246,9 +154,6 @@ impl> Llama2 for Memory { fn test_load() { use std::time::Instant; - // set env for POWERSHELL: `$env:RUST_LOG="INFO";` - env_logger::init(); - let t0 = Instant::now(); let safetensors = Memory::load_safetensors("../../TinyLlama-1.1B-Chat-v1.0"); let t1 = Instant::now(); diff --git a/model-parameters/src/memory/safe_tensors.rs b/model-parameters/src/memory/safe_tensors.rs index 2a4d4c60..462711c2 100644 --- a/model-parameters/src/memory/safe_tensors.rs +++ b/model-parameters/src/memory/safe_tensors.rs @@ -1,8 +1,9 @@ -use super::Memory; -use crate::{ConfigJson, DataType, LayerParamsOffset}; +use super::{Layer, Memory}; +use crate::{storage::Storage, ConfigJson, DataType}; use memmap2::Mmap; use safetensors::{tensor::TensorInfo, Dtype}; use std::{collections::HashMap, fs::File, path::Path, 
sync::Arc}; +use tensor::Tensor; #[derive(Debug)] pub enum SafeTensorError { @@ -10,19 +11,13 @@ pub enum SafeTensorError { Serde(serde_json::Error), } -impl Memory { +impl Memory { pub fn load_safetensors(model_dir: impl AsRef) -> Result { let dir = model_dir.as_ref(); let config = File::open(dir.join("config.json")).map_err(SafeTensorError::Io)?; let model = File::open(dir.join("model.safetensors")).map_err(SafeTensorError::Io)?; let config: ConfigJson = serde_json::from_reader(config).map_err(SafeTensorError::Serde)?; - let dtype = match config.torch_dtype { - DataType::F16 => Dtype::F16, - DataType::BF16 => Dtype::BF16, - DataType::F32 => Dtype::F32, - _ => todo!(), - }; let mmap = unsafe { Mmap::map(&model) }.map_err(SafeTensorError::Io)?; let len = unsafe { *mmap.as_ptr().cast::() } as usize; @@ -31,114 +26,54 @@ impl Memory { let header: SafeTensorHeaderJson = serde_json::from_slice(header).map_err(SafeTensorError::Serde)?; - let d = config.hidden_size; - let kv_dim = d * config.num_key_value_heads / config.num_attention_heads; - let di = config.intermediate_size; - - let mut embed_tokens = 0; - let mut layers = (0..config.num_hidden_layers) - .map(|_| LayerParamsOffset { - input_layernorm: 0, - self_attn_q_proj: 0, - self_attn_k_proj: 0, - self_attn_v_proj: 0, - self_attn_o_proj: 0, - post_attention_layernorm: 0, - mlp_gate: 0, - mlp_down: 0, - mlp_up: 0, - }) - .collect::>(); - let mut model_norm = 0; - let mut lm_head = 0; - - let header_offset = BASE_OFFSET + len; - for (name, tensor) in header.tensors { - let path = name.split('.').collect::>(); - let offset = header_offset + tensor.data_offsets.0; - - debug!(target: "import safetensors", "detect {offset:#010x} -> \"{name}\""); - match path.as_slice() { - ["model", "embed_tokens", "weight"] => { - assert_eq!(&tensor.shape, &[config.vocab_size, d]); - assert_eq!(tensor.dtype, dtype); - embed_tokens = offset; - } - ["model", "layers", n, path @ .., "weight"] => { - let layer = n.parse::().unwrap(); - - 
match path { - ["input_layernorm"] => { - assert_eq!(&tensor.shape, &[d]); - assert_eq!(tensor.dtype, dtype); - layers[layer].input_layernorm = offset; - } - ["self_attn", "q_proj"] => { - assert_eq!(&tensor.shape, &[d, d]); - assert_eq!(tensor.dtype, dtype); - layers[layer].self_attn_q_proj = offset; - } - ["self_attn", "k_proj"] => { - assert_eq!(&tensor.shape, &[kv_dim, d]); - assert_eq!(tensor.dtype, dtype); - layers[layer].self_attn_k_proj = offset; - } - ["self_attn", "v_proj"] => { - assert_eq!(&tensor.shape, &[kv_dim, d]); - assert_eq!(tensor.dtype, dtype); - layers[layer].self_attn_v_proj = offset; - } - ["self_attn", "o_proj"] => { - assert_eq!(&tensor.shape, &[d, d]); - assert_eq!(tensor.dtype, dtype); - layers[layer].self_attn_o_proj = offset; - } - ["post_attention_layernorm"] => { - assert_eq!(&tensor.shape, &[d]); - assert_eq!(tensor.dtype, dtype); - layers[layer].post_attention_layernorm = offset; - } - ["mlp", "gate_proj"] => { - assert_eq!(&tensor.shape, &[di, d]); - assert_eq!(tensor.dtype, dtype); - layers[layer].mlp_gate = offset; - } - ["mlp", "down_proj"] => { - assert_eq!(&tensor.shape, &[d, di]); - assert_eq!(tensor.dtype, dtype); - layers[layer].mlp_down = offset; - } - ["mlp", "up_proj"] => { - assert_eq!(&tensor.shape, &[di, d]); - assert_eq!(tensor.dtype, dtype); - layers[layer].mlp_up = offset; - } - [..] => { - warn!(target: "import safetensors", "Unknown tensor path: \"{name}\"") - } - }; - } - ["model", "norm", "weight"] => { - assert_eq!(&tensor.shape, &[d]); - assert_eq!(tensor.dtype, dtype); - model_norm = offset; - } - ["lm_head", "weight"] => { - assert_eq!(&tensor.shape, &[config.vocab_size, d]); - assert_eq!(tensor.dtype, dtype); - lm_head = offset; - } - [..] 
=> warn!(target: "import safetensors", "Unknown tensor path: \"{name}\""), - } - } + let mmap = Arc::new(mmap); + let tensor = |name: &str| { + println!("name = {name}"); + let info = &header.tensors[name]; + let (start, end) = info.data_offsets; + Tensor::new( + match info.dtype { + Dtype::BOOL => DataType::Bool, + Dtype::I8 => DataType::I8, + Dtype::I16 => DataType::I16, + Dtype::I32 => DataType::I32, + Dtype::I64 => DataType::I64, + Dtype::U8 => DataType::U8, + Dtype::U16 => DataType::U16, + Dtype::U32 => DataType::U32, + Dtype::U64 => DataType::U64, + Dtype::F16 => DataType::F16, + Dtype::BF16 => DataType::BF16, + Dtype::F32 => DataType::F32, + Dtype::F64 => DataType::F64, + _ => unreachable!(), + }, + info.shape.iter().map(|&d| d as _).collect(), + Storage::new(mmap.clone(), start, end - start), + ) + }; Ok(Self { + embed_tokens: tensor("model.embed_tokens.weight"), + layers: (0..config.num_hidden_layers) + .map(|l| { + let name = |name: &str| format!("model.layers.{l}.{name}.weight"); + Layer { + input_layernorm: tensor(&name("input_layernorm")), + self_attn_q_proj: tensor(&name("self_attn.q_proj")), + self_attn_k_proj: tensor(&name("self_attn.k_proj")), + self_attn_v_proj: tensor(&name("self_attn.v_proj")), + self_attn_o_proj: tensor(&name("self_attn.o_proj")), + post_attention_layernorm: tensor(&name("post_attention_layernorm")), + mlp_gate: tensor(&name("mlp.gate_proj")), + mlp_down: tensor(&name("mlp.down_proj")), + mlp_up: tensor(&name("mlp.up_proj")), + } + }) + .collect(), + model_norm: tensor("model.norm.weight"), + lm_head: tensor("lm_head.weight"), config, - blob: Arc::new(mmap), - embed_tokens, - layers, - model_norm, - lm_head, }) } } diff --git a/tensor/src/tensor.rs b/tensor/src/tensor.rs index b7bf6091..21954bea 100644 --- a/tensor/src/tensor.rs +++ b/tensor/src/tensor.rs @@ -35,6 +35,11 @@ impl Tensor { &self.physical } + #[inline] + pub fn size(&self) -> usize { + self.shape.iter().map(|&d| d as usize).product() + } + pub fn 
is_contiguous(&self) -> bool { let strides = self.pattern.0.as_slice(); let n = strides.len(); @@ -42,6 +47,16 @@ impl Tensor { && (0..n - 2).all(|i| strides[i] == strides[i + 1] * self.shape[i + 1] as idim) } + #[inline] + pub unsafe fn cast(&self, dtype: DataType, physical: Physical) -> Self { + Self { + data_type: dtype, + shape: self.shape.clone(), + pattern: self.pattern.clone(), + physical, + } + } + pub fn reshape(&self, shape: Shape) -> Self { debug_assert!(self.is_contiguous()); debug_assert_eq!(