diff --git a/Cargo.lock b/Cargo.lock index 010a3a3..7ee346d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -308,6 +308,18 @@ dependencies = [ "serde_json", ] +[[package]] +name = "async-openai" +version = "0.32.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f4d91c2005832450ad9ac92411b9d4cec39e30be21adc76553efc47f812d59c7" +dependencies = [ + "getrandom 0.3.3", + "reqwest", + "serde", + "serde_json", +] + [[package]] name = "async-stream" version = "0.3.6" @@ -2085,6 +2097,7 @@ dependencies = [ "hyper", "hyper-util", "rustls", + "rustls-native-certs", "rustls-pki-types", "tokio", "tokio-rustls", @@ -4159,6 +4172,7 @@ dependencies = [ "pin-project-lite", "quinn", "rustls", + "rustls-native-certs", "rustls-pki-types", "serde", "serde_json", @@ -4902,6 +4916,7 @@ version = "0.1.0" dependencies = [ "anyhow", "approx 0.5.1", + "async-openai", "async-trait", "chrono", "dotenvy", diff --git a/Cargo.toml b/Cargo.toml index bc21d69..3f3c88a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,6 +31,7 @@ dotenvy = "0.15.7" tokio-util = "0.7.16" toml = "0.9.5" thiserror = "2.0.17" +async-openai = "0.32.4" [profile.release] lto = true diff --git a/src/memory.rs b/src/memory.rs index 54aa198..7a8ff63 100644 --- a/src/memory.rs +++ b/src/memory.rs @@ -2,4 +2,5 @@ pub mod embedding; pub mod memory_cluster; pub mod memory_links; pub mod memory_note; +pub mod working_memory; pub mod record; diff --git a/src/memory/working_memory.rs b/src/memory/working_memory.rs new file mode 100644 index 0000000..8472d2f --- /dev/null +++ b/src/memory/working_memory.rs @@ -0,0 +1 @@ +pub mod sliding_window; diff --git a/src/memory/working_memory/sliding_window.rs b/src/memory/working_memory/sliding_window.rs new file mode 100644 index 0000000..35e966b --- /dev/null +++ b/src/memory/working_memory/sliding_window.rs @@ -0,0 +1,226 @@ +use std::collections::VecDeque; +use tokio::sync::mpsc; +use std::sync::{Arc, RwLock}; +use std::error::Error; +use tokio::time::{sleep, Duration}; +use tokio::runtime::Runtime; + +// use async_openai::{ +// types::{CreateChatCompletionRequest, ChatCompletionRequestMessage}, +// Client, +// }; + + +//滑动窗口(容器、容量、标记计数、摘要用临时储存) +pub struct SlidingWindow { + window: VecDeque, + capacity: usize, + tag_count: usize, + summary: Arc>, +} + +impl SlidingWindow { + //新建 + pub fn new(capacity: usize) -> Self { + Self { + window: VecDeque::with_capacity(capacity+1), + capacity, + tag_count: capacity, + summary: Arc::new(RwLock::new(String::new())), + } + } + //信息滑入 + pub async fn push(&mut self, mut value: Information) -> Result<(), Box> { + value = self.auto_tag(value); + self.window.push_back(value); + if self.window.len() == (self.capacity+1) { + self.pop().await?; + } + Ok(()) + } + //信息滑出,若信息被标记则进行摘要 + pub async fn pop(&mut self) -> Result<(), Box> { + let target = self.window.pop_front(); + if let Some(value) = target { + if value.is_tagged() { + self.summarize().await?; + } + } + Ok(()) + } + //获取窗口大小 + pub fn len(&self) -> usize { + self.window.len() + } + //获取窗口容量 + pub fn get_capacity(&self) -> usize { + self.capacity + } + //获取窗口容量(可变) + pub fn get_mut_capacity(&mut self) -> &mut usize { + &mut self.capacity + } + //获取窗口中指定索引的信息 + pub fn get(&self, index: usize) -> Option<&Information> { + self.window.get(index) + } + + //判断窗口是否为空 + pub fn is_empty(&self) -> bool { + self.window.is_empty() + } + //清空窗口内容 + pub fn clear(&mut self) { + self.window.clear(); + self.tag_count = 0; + } + //标记用 + pub fn tag_information(&mut self, index: usize) { + if index < self.capacity { + self.window[index].tag_information(); + } + } + //取消标记用 + pub fn untag_information(&mut self, index: usize) { + if index < self.capacity { + self.window[index].untag_information(); + } + } + //每滑入capacity次信息时进行一次标记 + fn auto_tag(&mut self, mut value: Information) -> Information { + self.tag_count += 1; + if self.tag_count >= self.capacity { + value.tag_information(); + self.tag_count = 0; + } + value + } + + //将摘要记忆和当前滑动窗口信息合并提供LLM + async fn summarize(&self) -> Result>, Box> { + let mut summary_text = match self.summary.write() { + Ok(mut value) => value.clone(), + Err(e) => { + eprintln!("Summary Error: {}", e); + String::new() + } + }; + + for (index, i) in self.window.iter().enumerate() { + summary_text.push_str(&index.to_string()); + summary_text.push_str(&i.to_string()); + } + + let summary_arc = self.summary.clone(); + + match call_llm(&summary_text).await { + Ok(response) => { + match summary_arc.write() { + Ok(mut value) => { + println!("Summary updated in background."); + *value = response + }, + Err(e) => eprintln!("LLM Error: {}", e), + } + } + Err(e) => eprintln!("LLM Error: {}", e), + } + + Ok(summary_arc) + } + +} + +pub struct Information { + pub text: String, + pub tag: bool, +} + +impl Information { + pub fn new(text: String) -> Self { + Self { text, tag: false } + } + pub fn tag_information(&mut self) { + self.tag = true + } + pub fn untag_information(&mut self) { + self.tag = false + } + pub fn is_tagged(&self) -> bool { + self.tag + } + pub fn to_string(&self) -> String { + self.text.clone() + } +} + +fn test_summary(summary: String) -> String { + println!("{}", summary.clone()); + summary +} + + +async fn call_llm(summary: &String) -> Result> { + // let client = Client::new(); + + // let request = CreateChatCompletionRequest { + // model: "unknown".to_string(), + // messages: vec![ChatCompletionRequestMessage { + // role: "user".to_string(), + // content: summary, + // ..Default::default() + // }], + // ..Default::default() + // }; + // let response = client.chat().create(request).await?; + // let output = response + // .choices + // .first() + // .and_then(|c| c.message.content.clone()) + // .unwrap_or_default(); + sleep(Duration::from_millis(500)).await; + let output = summary.clone(); + Ok(output) +} + + +#[cfg(test)] +mod slidingwindow_test{ + use super::*; + + #[tokio::test] + async fn sliding_window_test_push(){ + let mut window = SlidingWindow::new(10); + let info = Information::new("test1".to_string()); + window.push(info).await; + let info2 = Information::new("test2".to_string()); + window.push(info2).await; + assert_eq!(window.get(0).expect("not found this information").text, "test1"); + assert_eq!(window.get(1).expect("not found this information").text, "test2"); + } + #[tokio::test] + async fn sliding_window_test_pop(){ + let mut window = SlidingWindow::new(10); + let info = Information::new("test1".to_string()); + window.push(info).await; + let info2 = Information::new("test2".to_string()); + window.push(info2).await; + window.pop().await; + assert_eq!(window.get(0).expect("not found this information").text, "test2"); + } + #[tokio::test] + async fn sliding_window_test_summary_and_tag(){ + let mut window = SlidingWindow::new(2); + let info = Information::new("test1".to_string()); + window.push(info).await; + let info2 = Information::new("test2".to_string()); + window.push(info2).await; + let info3 = Information::new("test3".to_string()); + window.push(info3).await; + assert_eq!(window.summary.read().unwrap().as_str(), "0test21test3"); + let test = window.get(1); + if let Some(value) = test { + assert!(value.is_tagged()); + } + } +}