Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ dotenvy = "0.15.7"
tokio-util = "0.7.16"
toml = "0.9.5"
thiserror = "2.0.17"
async-openai = "0.32.4"

[profile.release]
lto = true
Expand Down
1 change: 1 addition & 0 deletions src/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@ pub mod embedding;
pub mod memory_cluster;
pub mod memory_links;
pub mod memory_note;
pub mod working_memory;
pub mod record;
1 change: 1 addition & 0 deletions src/memory/working_memory.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pub mod sliding_window;
226 changes: 226 additions & 0 deletions src/memory/working_memory/sliding_window.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
use std::collections::VecDeque;
use tokio::sync::mpsc;
use std::sync::{Arc, RwLock};
use std::error::Error;
use tokio::time::{sleep, Duration};
use tokio::runtime::Runtime;

// use async_openai::{
// types::{CreateChatCompletionRequest, ChatCompletionRequestMessage},
// Client,
// };


//滑动窗口(容器、容量、标记计数、摘要用临时储存)
pub struct SlidingWindow {
window: VecDeque<Information>,
capacity: usize,
tag_count: usize,
summary: Arc<RwLock<String>>,
}

impl SlidingWindow {
//新建
pub fn new(capacity: usize) -> Self {
Self {
window: VecDeque::with_capacity(capacity+1),
capacity,
tag_count: capacity,
summary: Arc::new(RwLock::new(String::new())),
}
}
//信息滑入
pub async fn push(&mut self, mut value: Information) -> Result<(), Box<dyn Error>> {
value = self.auto_tag(value);
self.window.push_back(value);
if self.window.len() == (self.capacity+1) {
self.pop().await?;
}
Ok(())
}
//信息滑出,若信息被标记则进行摘要
pub async fn pop(&mut self) -> Result<(), Box<dyn Error>> {
let target = self.window.pop_front();
if let Some(value) = target {
if value.is_tagged() {
self.summarize().await?;
}
}
Ok(())
}
//获取窗口大小
pub fn len(&self) -> usize {
self.window.len()
}
//获取窗口容量
pub fn get_capacity(&self) -> usize {
self.capacity
}
//获取窗口容量(可变)
pub fn get_mut_capacity(&mut self) -> &mut usize {
&mut self.capacity
}
//获取窗口中指定索引的信息
pub fn get(&self, index: usize) -> Option<&Information> {
self.window.get(index)
}

//判断窗口是否为空
pub fn is_empty(&self) -> bool {
self.window.is_empty()
}
//清空窗口内容
pub fn clear(&mut self) {
self.window.clear();
self.tag_count = 0;
}
//标记用
pub fn tag_information(&mut self, index: usize) {
if index < self.capacity {
self.window[index].tag_information();
}
}
//取消标记用
pub fn untag_information(&mut self, index: usize) {
if index < self.capacity {
self.window[index].untag_information();
}
}
//每滑入capacity次信息时进行一次标记
fn auto_tag(&mut self, mut value: Information) -> Information {
self.tag_count += 1;
if self.tag_count >= self.capacity {
value.tag_information();
self.tag_count = 0;
}
value
}

//将摘要记忆和当前滑动窗口信息合并提供LLM
async fn summarize(&self) -> Result<Arc<RwLock<String>>, Box<dyn Error>> {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

返回类型Arc不必要,同时此处的返回类型将写权限暴露给外部

let mut summary_text = match self.summary.write() {
Ok(mut value) => value.clone(),
Err(e) => {
eprintln!("Summary Error: {}", e);
String::new()
}
};
Comment on lines +101 to +107
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

此处错误为锁中毒,应该提前返回


for (index, i) in self.window.iter().enumerate() {
summary_text.push_str(&index.to_string());
summary_text.push_str(&i.to_string());
}

let summary_arc = self.summary.clone();

match call_llm(&summary_text).await {
Ok(response) => {
match summary_arc.write() {
Ok(mut value) => {
println!("Summary updated in background.");
*value = response
},
Err(e) => eprintln!("LLM Error: {}", e),
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

此处错误为锁中毒

}
}
Err(e) => eprintln!("LLM Error: {}", e),
}

Ok(summary_arc)
}

}

pub struct Information {
pub text: String,
pub tag: bool,
}

impl Information {
pub fn new(text: String) -> Self {
Self { text, tag: false }
}
pub fn tag_information(&mut self) {
self.tag = true
}
pub fn untag_information(&mut self) {
self.tag = false
}
pub fn is_tagged(&self) -> bool {
self.tag
}
pub fn to_string(&self) -> String {
self.text.clone()
}
}

fn test_summary(summary: String) -> String {
println!("{}", summary.clone());
summary
}


async fn call_llm(summary: &String) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
// let client = Client::new();

// let request = CreateChatCompletionRequest {
// model: "unknown".to_string(),
// messages: vec![ChatCompletionRequestMessage {
// role: "user".to_string(),
// content: summary,
// ..Default::default()
// }],
// ..Default::default()
// };
// let response = client.chat().create(request).await?;
// let output = response
// .choices
// .first()
// .and_then(|c| c.message.content.clone())
// .unwrap_or_default();
sleep(Duration::from_millis(500)).await;
let output = summary.clone();
Ok(output)
}


#[cfg(test)]
mod slidingwindow_test{
use super::*;

#[tokio::test]
async fn sliding_window_test_push(){
let mut window = SlidingWindow::new(10);
let info = Information::new("test1".to_string());
window.push(info).await;
let info2 = Information::new("test2".to_string());
window.push(info2).await;
assert_eq!(window.get(0).expect("not found this information").text, "test1");
assert_eq!(window.get(1).expect("not found this information").text, "test2");
}
#[tokio::test]
async fn sliding_window_test_pop(){
let mut window = SlidingWindow::new(10);
let info = Information::new("test1".to_string());
window.push(info).await;
let info2 = Information::new("test2".to_string());
window.push(info2).await;
window.pop().await;
assert_eq!(window.get(0).expect("not found this information").text, "test2");
}
#[tokio::test]
async fn sliding_window_test_summary_and_tag(){
let mut window = SlidingWindow::new(2);
let info = Information::new("test1".to_string());
window.push(info).await;
let info2 = Information::new("test2".to_string());
window.push(info2).await;
let info3 = Information::new("test3".to_string());
window.push(info3).await;
assert_eq!(window.summary.read().unwrap().as_str(), "0test21test3");
let test = window.get(1);
if let Some(value) = test {
assert!(value.is_tagged());
}
}
}