Skip to content

Commit

Permalink
Update web-rwkv to v0.6.0
Browse files Browse the repository at this point in the history
  • Loading branch information
cryscan committed Jan 31, 2024
1 parent 192da75 commit 4e65c79
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 25 deletions.
6 changes: 3 additions & 3 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "ai00_server"
version = "0.3.12"
version = "0.3.13"
edition = "2021"
authors = ["Gu ZhenNiu <[email protected]>", "Zhang Zhenyuan <[email protected]>"]
license = "MIT OR Apache-2.0"
Expand All @@ -19,7 +19,7 @@ rust-version = "1.75"
name = "converter"

[dependencies]
web-rwkv = "0.5.1"
web-rwkv = "0.6.0"
# web-rwkv = { git = "https://github.com/cryscan/web-rwkv", tag = "v0.5.0" }
tower = { version = "0.4.13", features = ["full"] }
tower-http = { version = "0.5.0", features = ["full"] }
Expand Down
60 changes: 40 additions & 20 deletions src/run.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@ use rayon::prelude::{
};
use tokio::sync::{Mutex, RwLock};
use web_rwkv::{
model::{v4, v5, v6, BackedState, FromBuilder, Model, ModelInfo, ModelState, StateBuilder},
model::{
v4, v5, v6, BackedState, FromBuilder, Model, ModelInfo, ModelInput, ModelOutput,
ModelState, StateBuilder,
},
tokenizer::Tokenizer,
};

Expand Down Expand Up @@ -488,64 +491,81 @@ where
}
};

let mut input_tokens = payloads
let mut inputs = payloads
.iter()
.map(|payload| match payload {
Payload::Busy(context) => context.suffix.0.clone(),
_ => vec![],
})
.map(|tokens| ModelInput {
tokens,
..Default::default()
})
.collect_vec();

// run the model until there is at least one slot finished
let occupancy = payloads.iter().filter(|x| x.is_busy()).count();
let logits = match occupancy {
0 => vec![None; payloads.len()],
let outputs = match occupancy {
0 => vec![ModelOutput::None; payloads.len()],
_ => loop {
let logits = self.model.run(&mut input_tokens, &self.state).await?;
if logits.iter().any(Option::is_some) {
let logits = self.model.run(&mut inputs, &self.state).await?;
if logits.iter().any(|x| matches!(x, ModelOutput::Last(_))) {
break logits;
}
},
};
let penalty_free_tokens = &self.penalty_free_tokens;
let logits: Vec<_> = payloads
let outputs = payloads
.par_iter()
.zip_eq(logits.into_par_iter())
.map(|(payload, logits)| match payload {
Payload::Busy(context) => logits.map(|mut logits| {
.zip_eq(outputs.into_par_iter())
.map(|(payload, output)| match payload {
Payload::Busy(context) => match output {
ModelOutput::None => None,
ModelOutput::Last(data) => Some(data),
ModelOutput::Full(data) => Some(data.into_iter().last()?),
}
.map(|mut data| {
context
.penalties
.iter()
.filter(|(token, _)| !penalty_free_tokens.contains(token))
.for_each(|(token, penalty)| logits[*token as usize] -= penalty);
.for_each(|(token, penalty)| data[*token as usize] -= penalty);
context
.request
.logit_bias
.iter()
.for_each(|(token, bias)| logits[*token as usize] += *bias);
logits
.for_each(|(token, bias)| data[*token as usize] += *bias);
data
}),
_ => None,
})
.map(|x| match x {
Some(data) => ModelOutput::Last(data),
None => ModelOutput::None,
})
.collect();

let probs = match occupancy {
0 => vec![None; payloads.len()],
_ => self.model.softmax(logits).await?,
0 => vec![ModelOutput::None; payloads.len()],
_ => self.model.softmax(outputs).await?,
};
let output_tokens: Vec<_> = payloads
.par_iter()
.zip_eq(probs.into_par_iter())
.map(|(payload, probs)| match payload {
Payload::Busy(context) => probs.map(|probs| context.request.sampler.sample(probs)),
Payload::Busy(context) => match probs {
ModelOutput::None => None,
ModelOutput::Last(data) => Some(context.request.sampler.sample(data)),
ModelOutput::Full(_) => unreachable!(),
},
_ => None,
})
.collect();

for (payload, token, tokens) in itertools::multizip((
for (payload, token, input) in itertools::multizip((
payloads.iter_mut(),
output_tokens.into_iter(),
input_tokens.into_iter(),
inputs.into_iter(),
)) {
let Payload::Busy(context) = payload else {
continue;
Expand All @@ -556,8 +576,8 @@ where
let model_tokens = [prefix.0, suffix.0].concat();

// compute new prefix and suffix using the current remaining tokens
assert!(model_tokens.len() >= tokens.len());
let len = model_tokens.len() - tokens.len();
assert!(model_tokens.len() >= input.tokens.len());
let len = model_tokens.len() - input.tokens.len();
context.prefix = Tokens(model_tokens[..len].to_vec());
context.suffix = Tokens(model_tokens[len..].to_vec());
context
Expand Down

0 comments on commit 4e65c79

Please sign in to comment.