Skip to content

Commit

Permalink
feat(tensor): 添加 tensor crate,实现 transpose 算子
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <[email protected]>
  • Loading branch information
YdrMaster committed Feb 18, 2024
1 parent b5353d1 commit 079f96b
Show file tree
Hide file tree
Showing 7 changed files with 151 additions and 1 deletion.
9 changes: 8 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
[workspace]
members = ["common", "model-parameters", "tokenizer", "transformer-cpu", "xtask"]
members = [
"common",
"tensor",
"model-parameters",
"tokenizer",
"transformer-cpu",
"xtask",
]
resolver = "2"
11 changes: 11 additions & 0 deletions tensor/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
[package]
name = "tensor"
version = "0.0.0"
edition = "2021"
authors = ["YdrMaster <[email protected]>"]

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
smallvec = "1.13"
nalgebra = "0.32"
17 changes: 17 additions & 0 deletions tensor/src/data_type.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
#[repr(u8)]
pub enum DataType {
Bool,
I8,
I16,
I32,
I64,
U8,
U16,
U32,
U64,
F16,
BF16,
F32,
F64,
}
13 changes: 13 additions & 0 deletions tensor/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
mod data_type;
mod operator;
mod tensor;

#[allow(non_camel_case_types)]
pub type udim = u32;

#[allow(non_camel_case_types)]
pub type idim = i32;

pub use data_type::DataType;
pub use operator::Operator;
pub use tensor::{Affine, Pattern, Shape, Tensor};
8 changes: 8 additions & 0 deletions tensor/src/operator/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
mod transpose;

use crate::{udim, Affine, Shape};

pub trait Operator {
fn infer_shape(&self, input: &[udim]) -> Shape;
fn to_affine(&self, input: &[udim]) -> Affine;
}
52 changes: 52 additions & 0 deletions tensor/src/operator/transpose.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
use super::Operator;
use crate::{udim, Affine, Shape};
use smallvec::SmallVec;

pub struct Transpose {
perm: SmallVec<[udim; 4]>,
}

impl Operator for Transpose {
#[inline]
fn infer_shape(&self, input: &[udim]) -> Shape {
debug_assert_eq!(input.len(), self.perm.len());
self.perm.iter().map(|&i| input[i as usize]).collect()
}

fn to_affine(&self, input: &[udim]) -> Affine {
debug_assert_eq!(input.len(), self.perm.len());
let n = self.perm.len();
Affine::from_fn(n + 1, n + 1, |r, c| {
if c == self.perm.get(r).map_or(r, |&p| p as usize) {
1
} else {
0
}
})
}
}

#[test]
fn test() {
let operator = Transpose {
perm: Shape::from_slice(&[0, 2, 1, 3]),
};
assert_eq!(
operator.infer_shape(&[1, 2, 3, 4]),
Shape::from_slice(&[1, 3, 2, 4])
);
assert_eq!(
operator.to_affine(&[1, 2, 3, 4]),
Affine::from_vec(
5,
5,
vec![
1, 0, 0, 0, 0, //
0, 0, 1, 0, 0, //
0, 1, 0, 0, 0, //
0, 0, 0, 1, 0, //
0, 0, 0, 0, 1, //
]
)
);
}
42 changes: 42 additions & 0 deletions tensor/src/tensor.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
use crate::{idim, udim, DataType, Operator};
use nalgebra::{DMatrix, DVector};
use smallvec::SmallVec;

pub struct Tensor<Physical> {
data_type: DataType,
shape: Shape,
pattern: Pattern,
physical: SmallVec<[Physical; 1]>,
}

impl<Physical: Clone> Tensor<Physical> {
#[inline]
pub fn new(
data_type: DataType,
shape: Shape,
pattern: Pattern,
physical: impl IntoIterator<Item = Physical>,
) -> Self {
Self {
data_type,
shape,
pattern,
physical: physical.into_iter().collect(),
}
}

#[inline]
pub fn apply(&self, operator: &impl Operator) -> Self {
Self {
data_type: self.data_type,
shape: operator.infer_shape(&self.shape),
pattern: Pattern(operator.to_affine(&self.shape) * &self.pattern.0),
physical: self.physical.clone(),
}
}
}

pub type Shape = SmallVec<[udim; 4]>;
pub type Affine = DMatrix<idim>;

pub struct Pattern(DVector<idim>);

0 comments on commit 079f96b

Please sign in to comment.