refactor(transformer): make the Transformer trait open to arbitrary model structures
Signed-off-by: YdrMaster <[email protected]>
YdrMaster committed Apr 17, 2024
1 parent 72cba6b commit 2fc5b1a
Showing 6 changed files with 29 additions and 13 deletions.
9 changes: 7 additions & 2 deletions nvidia/distributed/src/lib.rs
@@ -31,8 +31,13 @@ impl transformer::Transformer for Transformer {
     type Cache = Cache;
 
     #[inline]
-    fn model(&self) -> &dyn Llama2 {
-        &self.host
+    fn max_position_embeddings(&self) -> usize {
+        self.host.max_position_embeddings()
+    }
+
+    #[inline]
+    fn eos_token(&self) -> utok {
+        self.host.eos_token_id()
     }
 
     #[inline]
9 changes: 7 additions & 2 deletions nvidia/transformer/src/lib.rs
@@ -36,8 +36,13 @@ impl transformer::Transformer for Transformer {
     type Cache = Cache;
 
     #[inline]
-    fn model(&self) -> &dyn Llama2 {
-        &self.host
+    fn max_position_embeddings(&self) -> usize {
+        self.host.max_position_embeddings()
+    }
+
+    #[inline]
+    fn eos_token(&self) -> utok {
+        self.host.eos_token_id()
     }
 
     fn new_cache(&self) -> Vec<LayerCache<Self::Cache>> {
6 changes: 3 additions & 3 deletions service/src/dispatch.rs
@@ -68,7 +68,7 @@ where
 
     pub fn manage(&mut self) -> UnboundedSender<Message<T::Cache>> {
         let (sender, mut receiver) = unbounded_channel();
-        let max_seq_len = self.transformer.model().max_position_embeddings();
+        let max_seq_len = self.transformer.max_position_embeddings();
         let transformer = self.transformer.clone();
         let batcher = self.batcher.clone();
         self.set.spawn(async move {
@@ -108,8 +108,8 @@ where
         sample: Arc<Mutex<SampleArgs>>,
         sender: UnboundedSender<Message<T::Cache>>,
     ) {
-        let max_seq_len = self.transformer.model().max_position_embeddings();
-        let eos = self.transformer.model().eos_token_id();
+        let max_seq_len = self.transformer.max_position_embeddings();
+        let eos = self.transformer.eos_token();
         let transformer = self.transformer.clone();
         let batcher = self.batcher.clone();
         self.set.spawn_blocking(move || loop {
13 changes: 9 additions & 4 deletions transformer-cpu/src/lib.rs
@@ -1,6 +1,6 @@
 mod kernel;
 
-use common::Blob;
+use common::{utok, Blob};
 use gemm::f16;
 use kernel::CpuKernels;
 use tensor::{reslice, slice, udim, DataType, LocalSplitable, Tensor};
@@ -12,8 +12,13 @@ impl transformer::Transformer for Transformer {
     type Cache = Blob;
 
     #[inline]
-    fn model(&self) -> &dyn Llama2 {
-        &self.0
+    fn max_position_embeddings(&self) -> usize {
+        self.0.max_position_embeddings()
+    }
+
+    #[inline]
+    fn eos_token(&self) -> utok {
+        self.0.eos_token_id()
     }
 
     #[inline]
@@ -66,7 +71,7 @@ impl transformer::Transformer for Transformer {
         args: &SampleArgs,
         requests: Vec<Id>,
         logits: Tensor<Self::Cache>,
-    ) -> Vec<(Id, common::utok)> {
+    ) -> Vec<(Id, utok)> {
         let &[_, voc] = logits.shape() else { panic!() };
         let dt = logits.data_type();
 
3 changes: 2 additions & 1 deletion transformer/src/lib.rs
@@ -28,7 +28,8 @@ use tensor::Tensor;
 pub trait Transformer {
     type Cache;
 
-    fn model(&self) -> &dyn Llama2;
+    fn max_position_embeddings(&self) -> usize;
+    fn eos_token(&self) -> utok;
     fn new_cache(&self) -> Vec<LayerCache<Self::Cache>>;
     fn decode<Id>(&self, requests: Vec<Request<Id, Self::Cache>>)
         -> (Vec<Id>, Tensor<Self::Cache>);
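The trait now exposes the two scalar queries the service layer actually uses instead of a whole &dyn Llama2, so a backend no longer has to be a Llama2 model at all. A minimal sketch of what this opens up — TinyModel and its fields are hypothetical and not part of this repo, the imports are assumed from the crates shown in this diff, and the trait methods truncated by the diff view are stubbed or omitted:

use common::utok;
use tensor::Tensor;
use transformer::{LayerCache, Request, Transformer};

// Hypothetical toy model, for illustration only.
struct TinyModel {
    max_pos: usize,
    eos: utok,
}

impl Transformer for TinyModel {
    type Cache = Vec<u8>;

    // The two new trait methods are plain scalar queries; no `&dyn Llama2` needed.
    #[inline]
    fn max_position_embeddings(&self) -> usize {
        self.max_pos
    }

    #[inline]
    fn eos_token(&self) -> utok {
        self.eos
    }

    fn new_cache(&self) -> Vec<LayerCache<Self::Cache>> {
        Vec::new() // this toy model keeps no KV cache
    }

    fn decode<Id>(
        &self,
        _requests: Vec<Request<Id, Self::Cache>>,
    ) -> (Vec<Id>, Tensor<Self::Cache>) {
        todo!("run the model and return logits")
    }

    // Any remaining trait methods hidden by this diff's truncation (e.g. `sample`)
    // would be implemented here as well.
}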
2 changes: 1 addition & 1 deletion transformer/src/parameters/safe_tensors.rs
@@ -137,7 +137,7 @@ fn concat0(tensors: &[&Tensor<Storage>]) -> Tensor<Storage> {
     assert!(!tensors.is_empty());
     assert!(tensors
         .windows(2)
-        .all(|t| t[0].data_type() == t[1].data_type() && t[0].shape()[1..] == t[1].shape()[1..]));
+        .all(|t| t[0].data_type() == t[1].data_type()));
 
     let data_type = tensors[0].data_type();
     let mut shape = Shape::from_slice(tensors[0].shape());
