From 78ae4d89f512061481840d6c4bd4cc0211515f5b Mon Sep 17 00:00:00 2001 From: xxczxp Date: Tue, 25 Jun 2024 00:45:39 +0800 Subject: [PATCH] =?UTF-8?q?fix(service):=20=E8=BE=93=E5=85=A5=E8=B6=85?= =?UTF-8?q?=E9=95=BF=E6=88=AA=E5=8F=96=E6=97=B6=E4=BC=9A=E4=BF=9D=E7=95=99?= =?UTF-8?q?=E4=B8=80=E6=AE=B5=E8=B5=B7=E5=A7=8B=E7=9A=84=E9=83=A8=E5=88=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- service/Cargo.toml | 1 + service/src/session/cache.rs | 262 +++++++++++++++++++++++++++----- service/src/session/dispatch.rs | 7 +- service/src/session/mod.rs | 7 +- service/src/session/task.rs | 4 +- 5 files changed, 232 insertions(+), 49 deletions(-) diff --git a/service/Cargo.toml b/service/Cargo.toml index 523d3d61..becac627 100644 --- a/service/Cargo.toml +++ b/service/Cargo.toml @@ -14,6 +14,7 @@ causal-lm = { path = "../causal-lm" } log.workspace = true tokio.workspace = true lru = "0.12" +rangemap = { version = "1" } [dev-dependencies] colored = "2.1" diff --git a/service/src/session/cache.rs b/service/src/session/cache.rs index b60c1dea..1ea4a7b1 100644 --- a/service/src/session/cache.rs +++ b/service/src/session/cache.rs @@ -1,6 +1,8 @@ use causal_lm::{CausalLM, QueryContext}; use common::{upos, utok}; -use std::ops::Range; +use log::{debug, info}; +use rangemap::{range_set, RangeSet}; +use std::{cmp::min, ops::Range}; use tensor::Tensor; pub(super) struct Cache { @@ -8,78 +10,169 @@ pub(super) struct Cache { tokens: Vec, /// token 序列在整个对话中的位置。 pos: usize, - /// 缓存在 token 序列中的范围。 - cached: Range, + /// 需要缓存的 token 在 token 序列中的范围。 + cached: RangeSet, + /// 已缓存的 token 在 cached_range 中的范围 + to_be_cached: RangeSet, /// 计算缓存。 cache: Tensor, } +pub struct CacheQuery<'a> { + tokens: &'a [utok], + to_be_cached: &'a RangeSet, +} + +impl<'a> CacheQuery<'a> { + fn new(tokens: &'a [utok], to_be_cached: &'a RangeSet) -> Self { + Self { + tokens, + to_be_cached, + } + } + + pub fn len(&self) -> usize { + self.to_be_cached.iter().map(|range| range.len()).sum() + } + + pub fn is_empty(&self) -> bool { + self.len() == 0 + } +} + +impl<'a> IntoIterator for CacheQuery<'a> { + type Item = &'a utok; + + type IntoIter = CacheQueryIter<'a>; + + fn into_iter(self) -> Self::IntoIter { + CacheQueryIter::new(self.tokens, self.to_be_cached) + } +} + +pub struct CacheQueryIter<'a> { + tokens: &'a [utok], + ranges: rangemap::set::Iter<'a, usize>, + current_iter: Option>, +} + +impl<'a> CacheQueryIter<'a> { + fn new(tokens: &'a [utok], to_be_cached: &'a RangeSet) -> Self { + let mut ranges_iter = to_be_cached.iter(); + + let current_iter = ranges_iter.next().cloned(); + + Self { + tokens, + ranges: ranges_iter, + current_iter, + } + } +} + +impl<'a> Iterator for CacheQueryIter<'a> { + type Item = &'a utok; + + fn next(&mut self) -> Option { + if let Some(range) = &mut self.current_iter { + if let Some(i) = range.next() { + Some(&self.tokens[i]) + } else { + self.current_iter = self.ranges.next().cloned(); + self.current_iter.as_mut()?.next().map(|i| &self.tokens[i]) + } + } else { + None + } + } +} + impl Cache { /// 生成一个空白的缓存结构,准备填充 `tokens`。 #[inline] pub fn new(t: &impl CausalLM, tokens: Vec) -> Self { + let tokens_len = tokens.len(); Self { tokens, pos: 0, - cached: 0..0, + cached: RangeSet::new(), + to_be_cached: range_set![0..tokens_len], cache: t.new_cache(), } } + /// 复制缓存结构。 #[inline] pub fn duplicate(&self, t: &impl CausalLM) -> Self { - assert_eq!(self.cached.start, 0); + debug!("call duplicate"); Self { tokens: self.tokens.clone(), pos: self.pos, cached: self.cached.clone(), - cache: t.duplicate_cache(&self.cache, self.cached.end as _), + to_be_cached: self.to_be_cached.clone(), + cache: t.duplicate_cache(&self.cache, self.cached_len() as _), } } /// 回滚缓存到 `pos`,并返回剩余的有效缓存长度。 - pub fn revert(&mut self, pos: usize) -> usize { - // 只能在闲时回滚,因此 cache 和 tokens 起始位置对齐 - assert_eq!(self.cached.start, 0); + pub fn revert(&mut self, pos: usize) -> Option { + debug!("call revert"); // 回滚之后,tokens.len()、cached.end、pos 不能大于新的 pos - let len = pos.saturating_sub(self.pos); - // 1. tokens.len() 不大于 pos; - self.tokens.truncate(len); + // 1. pos 不大于 pos; + let len = pos.checked_sub(self.pos)?; // 2. cached.end 不大于 pos; - self.cached.end = self.cached.end.min(len); - // 3. pos 不大于 pos; - self.pos = self.pos.min(pos); + if len != 0 && self.cached.contains(&(len - 1)) { + self.to_be_cached.clear(); + self.cached.remove(len..self.cached.last().unwrap().end); + } else { + return None; + } + // 3. tokens.len() 不大于 pos; + self.tokens.truncate(len); // 返回当前的缓存长度 - self.cached.len() + Some(self.cached_len()) } /// 扩展待填充 token。 #[inline] pub fn extend(&mut self, tokens: &[utok]) { + debug!("call extend"); + let before_len = self.tokens.len(); self.tokens.extend_from_slice(tokens); + self.to_be_cached.insert(before_len..self.tokens.len()); } /// 所有 token 中还没有加入缓存的部分就是这次的查询。 #[inline] - pub fn query(&self) -> &[utok] { - &self.tokens[self.cached.end..] + pub fn query(&self) -> CacheQuery { + CacheQuery::new(&self.tokens, &self.to_be_cached) } /// 生成对应的查询上下文。 #[inline] pub fn as_ctx(&mut self) -> QueryContext { - let Cache { - pos: _pos, - cache, - tokens, - cached, - } = self; + debug!("call as_ctx"); + debug!( + "cache reset\ncached is {:?}\nto_be_cached is {:?}", + self.cached, self.to_be_cached + ); QueryContext { - cache: Some(cache), - range: cached.len() as upos..(tokens.len() - cached.start) as upos, + range: self.cached_len() as upos..(self.cached_len() + self.to_be_cached_len()) as upos, + cache: Some(&mut (self.cache)), } } - /// 将新采样的值加入缓存。 + /// 将新采样的值加入缓存。默认to_be_cached不为空 #[inline] pub fn push(&mut self, token: utok) { - self.cached.end = self.tokens.len(); + debug!("call push"); + assert!(self.is_continue()); + + //to_be_cached 全部变为cached + self.to_be_cached + .iter() + .for_each(|range| self.cached.insert(range.clone())); + //清空to_be_cached 并插入新的需要缓存的token + self.to_be_cached.clear(); + self.to_be_cached + .insert(self.tokens.len()..self.tokens.len() + 1); + //插入token self.tokens.push(token); } /// 已采样的最后一个词在对话中的位置。 @@ -87,35 +180,122 @@ impl Cache { pub fn end(&self) -> usize { self.pos + self.tokens.len() } - /// 提取尾部词序列。 + /// 提取尾部词序列,默认尾部序列在cached和to_be_cached中是连续的 #[inline] pub fn slice_tail(&self, pos: usize) -> &[utok] { let known = pos.checked_sub(self.pos).unwrap(); &self.tokens[known..] } - /// 重置缓存窗口。 - pub fn reset_within(&mut self, min: usize, max: usize) { - if self.tokens.len() - self.cached.start >= max { - self.cached.start = self.tokens.len() - min; - self.cached.end = self.cached.start; + /// 重置缓存窗口,并将起始点设置为尾部一部分之前 + #[allow(unused)] + pub fn reset_within_one_range(&mut self, min: usize, max: usize) { + if self.cached_len() + self.to_be_cached_len() >= max { + self.cached.clear(); + self.to_be_cached = range_set![(self.tokens.len() - min..self.tokens.len())]; + } + } + /// 重置缓存窗口,保留起始的一部分,并将起始点设置为尾部一部分之前 + pub fn reset_within_start_and_end_range( + &mut self, + start_size: usize, + end_size: usize, + max: usize, + ) { + assert!(start_size + end_size <= max); + if self.cached_len() + self.to_be_cached_len() >= max { + let mut uncached_start: usize = 0; + + self.cached.clear(); + if let Some(mut first_range) = self.cached.first().cloned() { + first_range.end = min(first_range.end, start_size); + if first_range.start == 0 && !first_range.is_empty() { + uncached_start = first_range.end; + self.cached = range_set![first_range]; + } + } + // 为to_be_cached 赋值 + if uncached_start != start_size { + self.to_be_cached = range_set![ + uncached_start..start_size, + self.tokens.len() - end_size..self.tokens.len() + ]; + } else { + self.to_be_cached = range_set![self.tokens.len() - end_size..self.tokens.len()]; + } + + info!( + "cache reset\ncached is {:?}\nto_be_cached is {:?}", + self.cached, self.to_be_cached + ); } } - /// 重置缓存窗口。 + /// 重置并清空缓存窗口。 pub fn reset_with(&mut self, tokens: Vec, pos: usize) { self.tokens = tokens; self.pos = pos; - self.cached = 0..0; + self.cached.clear(); + self.to_be_cached = range_set![0..self.tokens.len()]; } - /// 清理缓存中已脱离缓存窗口的部分。 - pub fn cleanup(&mut self) { - let to_remove = self.cached.start; + /// 清理缓存中在缓存窗口之前的部分。 + pub fn cleanup_before_start(&mut self) { + let to_remove = self.cached.first().unwrap().start; if to_remove > 0 { self.tokens.copy_within(to_remove.., 0); self.pos += to_remove; self.tokens.truncate(self.tokens.len() - to_remove); - self.cached.start = 0; - self.cached.end -= to_remove; + + // 整体减小cached和to_be_cached + self.cached + .iter() + .fold(RangeSet::new(), |mut set, range| { + set.insert(range.start - to_remove..range.end - to_remove); + set + }) + .clone_into(&mut self.cached); + self.to_be_cached + .iter() + .fold(RangeSet::new(), |mut set, range| { + set.insert(range.start - to_remove..range.end - to_remove); + set + }) + .clone_into(&mut self.to_be_cached); } } + /// 获取cached中最后一个区间的长度,如果cached为空则会panic + pub fn get_last_cached_range_len(&self) -> usize { + self.cached.last().unwrap().len() + } + + /// 判定需要缓存的部分包含tokens的结尾 + fn is_continue(&self) -> bool { + if self.to_be_cached.is_empty() { + self.cached.last().unwrap().end == self.tokens.len() + } else { + self.to_be_cached.last().unwrap().end == self.tokens.len() + } + } + + /// 获取cached 总长度 + #[inline] + fn cached_len(&self) -> usize { + self.cached.iter().map(|range| range.len()).sum() + } + + /// 获取 to_be_cached 总长度 + #[inline] + fn to_be_cached_len(&self) -> usize { + self.to_be_cached.iter().map(|range| range.len()).sum() + } +} + +#[test] +fn test_cache_query() { + let v: Vec = (0..100).into_iter().collect(); + CacheQuery::new( + &v, + &range_set![10 as usize..20 as usize, 40 as usize..50 as usize], + ) + .into_iter() + .for_each(|a| println!("{:?}", a)); } diff --git a/service/src/session/dispatch.rs b/service/src/session/dispatch.rs index 04d3426e..2d9d2ef1 100644 --- a/service/src/session/dispatch.rs +++ b/service/src/session/dispatch.rs @@ -29,7 +29,7 @@ impl TaskHandle { impl ServiceComponent { pub(super) fn infer(&self, sample: SampleArgs, mut cache: Cache) -> TaskHandle { let max = self.handle.model.max_seq_len() as usize; - cache.reset_within(max / 4, max / 4 * 3); + cache.reset_within_start_and_end_range(max / 4, max / 4, max / 4 * 3); // 生成推理任务与会话的交互管道 let cache = Arc::new(Mutex::new(Some(cache))); let (sender, receiver) = unbounded_channel(); @@ -137,14 +137,15 @@ where tokio::task::spawn_blocking(move || { let eos = self_.model.eos_token(); let max = self_.model.max_seq_len() as usize; - let min = max / 4; + let end_size = max / 4; + let start_size = max / 4; zip(tasks, num_decode) .filter(|(_, n)| *n > 0) .map(|(t, _)| t) .zip(tokens) .filter(|(_, token)| *token != eos) .for_each(|(mut task, token)| { - if task.push(token, min, max) { + if task.push(token, start_size, end_size, max) { self_.batcher.enq(task); } }); diff --git a/service/src/session/mod.rs b/service/src/session/mod.rs index 2873dbed..0fc8ed7f 100644 --- a/service/src/session/mod.rs +++ b/service/src/session/mod.rs @@ -81,9 +81,10 @@ impl Session { let cache = self.cache.as_mut().unwrap(); self.dialog.revert(dialog_pos); - let cached = cache.revert(self.dialog.num_tokens()); let last_prompt = self.dialog.last_prompt().map_or(0, |p| p.len()); - if cached < last_prompt { + if cache.revert(self.dialog.num_tokens()).is_none() + || cache.get_last_cached_range_len() < last_prompt + { let len = self.component.handle.model.max_seq_len() as usize; let (tokens, pos) = self.dialog.window(len); cache.reset_with(tokens, pos); @@ -141,7 +142,7 @@ impl Session { // 只要忙会话收集到任何 token,就生成一个新的句子 self.dialog.push(cache.slice_tail(end).to_vec()); } - cache.cleanup(); + cache.cleanup_before_start(); info!("Cache restored at {} tokens", cache.end()); self.cache = Some(cache); } diff --git a/service/src/session/task.rs b/service/src/session/task.rs index cae8c08e..9a61f78b 100644 --- a/service/src/session/task.rs +++ b/service/src/session/task.rs @@ -39,11 +39,11 @@ impl Task { } #[inline] - pub fn push(&mut self, token: utok, min: usize, max: usize) -> bool { + pub fn push(&mut self, token: utok, start_size: usize, end_size: usize, max: usize) -> bool { if self.sender.send(token).is_ok() { if let Some(cache) = self.cache.lock().unwrap().as_mut() { cache.push(token); - cache.reset_within(min, max); + cache.reset_within_start_and_end_range(start_size, end_size, max); return true; } }