Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 71 additions & 26 deletions backend/services/chat/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from services.knowledgebase.knowledgebase_service import KBase
from fastapi import HTTPException
import datetime
import logging
from typing import List
from core.utils.utils import GetDeafultLLM_Chat,GetDefaultEmbedding
class Chat:
Expand All @@ -19,6 +20,9 @@ def __init__(self, conversation_id,user_id,rag:RAG_Pipeline):
self.mysql_session = MysqlClient().SessionLocal()
self.rag:RAG_Pipeline = rag
self.default_llm_config = GetDeafultLLM_Chat()
if self.default_llm_config is None:
logging.error("GetDeafultLLM_Chat() returned None. No default LLM configuration found.")
raise HTTPException(status_code=500, detail="No default LLM configuration found. Please configure a default LLM.")
self.default_embedding_config = GetDefaultEmbedding()
def __del__(self):
self.mysql_session.close()
Expand All @@ -32,7 +36,9 @@ def create_conversation(self,knowledgeBaseId:str,username)->Conversation:
# print("Conversation created successfully")
return new_conversation
except Exception as e:
print(e)
logging.error(f"Error creating conversation: {e}")
self.mysql_session.rollback()
raise


#匹配对话
Expand All @@ -56,9 +62,14 @@ def match_knowledgebase(self, conversation_id,username)->KnowledgeBase:
return knowledgebase

except Exception as e:
print(f"Erreor:{e}")
logging.error(f"Error matching knowledgebase: {e}")
raise
# 生成对话标题
def generate_conversation_title(self, conversation_id,username:str)->str:
if self.default_llm_config is None:
logging.error("default_llm_config is None in generate_conversation_title")
raise HTTPException(status_code=500, detail="No default LLM configuration available")

llm = LLM_Manager().creatLLM(mode_provider=self.default_llm_config.vendor_type,model=self.default_llm_config.model)

conversation_messages = self.load_conversation(conversation_id,username=username)
Expand All @@ -76,8 +87,10 @@ def generate_conversation_title(self, conversation_id,username:str)->str:
#########################
"""
prompt = f"""
请根据以下对话内容生成一个对话标题,对话内容如下:\n
{messageLogs_txt}\n
请根据以下对话内容生成一个对话标题,对话内容如下:

{messageLogs_txt}

请生成一个简洁明了的对话标题,不超过10个字,且仅需要输出标题,如果没有内容则输出:新对话
"""

Expand All @@ -100,7 +113,8 @@ def get_knowledgebases(self,username:str):
all_kbse = self.mysql_session.query(KnowledgeBase).filter((KnowledgeBase.created_by == username) | (KnowledgeBase.is_public == True),KnowledgeBase.delete_sign == False)
return all_kbse
except Exception as e:
print(f"获取知识库失败:{e}")
logging.error(f"Error getting knowledgebases: {e}")
raise
# 更改知识库
def change_knowledgebase(self, conversation_id,knowledgeBaseId,username):
try:
Expand All @@ -110,7 +124,9 @@ def change_knowledgebase(self, conversation_id,knowledgeBaseId,username):
self.mysql_session.refresh(conversation)
return conversation
except Exception as e:
print(e)
logging.error(f"Error changing knowledgebase: {e}")
self.mysql_session.rollback()
raise
# 删除对话
def delete_conversation(self, conversation_id,username:str)->Conversation:
try:
Expand All @@ -124,7 +140,9 @@ def delete_conversation(self, conversation_id,username:str)->Conversation:
return conversation

except Exception as e:
print(e)
logging.error(f"Error deleting conversation: {e}")
self.mysql_session.rollback()
raise
# 重命名对话
def rename_conversation(self, conversation_id,username, new_name):
try:
Expand All @@ -139,7 +157,9 @@ def rename_conversation(self, conversation_id,username, new_name):

return conversation
except Exception as e:
print(e)
logging.error(f"Error renaming conversation: {e}")
self.mysql_session.rollback()
raise
# 加载对话
def load_conversation(self, conversation_id,username:str)->List[ChatMessageHistory]:
# 检查对话是否属于用户
Expand Down Expand Up @@ -174,7 +194,8 @@ def load_conversation(self, conversation_id,username:str)->List[ChatMessageHisto

return messages_history
except Exception as e:
print(e)
logging.error(f"Error loading conversation: {e}")
raise

# 格式化对话记录
def format_conversation_Log(self, messages:List[Chat_Messages]):
Expand All @@ -199,12 +220,18 @@ def save_conversation(self, message:Chat_Messages):
self.mysql_session.commit()
self.mysql_session.refresh(message)
except Exception as e:
print(e)
logging.error(f"Error saving conversation: {e}")
self.mysql_session.rollback()
raise

# 回答问题
def answer_question(self, resultByDoc: ResultByDoc,history_message=[],streaming=False):
# rAG_Pipeline = RAG_Pipeline()

if self.default_llm_config is None:
logging.error("default_llm_config is None in answer_question")
raise HTTPException(status_code=500, detail="No default LLM configuration available")

answer = self.rag.generate_answer_by_knowledgebase(resultByDoc=resultByDoc,history_messages=history_message,streaming=streaming,LLM_Provider=self.default_llm_config.vendor_type,LLM_Model=self.default_llm_config.model)
if isinstance(answer, str):
print(answer)
Expand All @@ -226,7 +253,8 @@ def get_retrieve_documents(self, question,knowledgebase_id,rag_model:int,is_rera
resultByDoc:ResultByDoc = self.rag.retrieve_documents(question=question,knowledge_base_id=knowledgebase_id,rag_model=rag_model,is_rerank=is_rerank)
return resultByDoc
except Exception as e:
print(e)
logging.error(f"Error retrieving documents: {e}")
raise
# 处理用户输入
def run(self, chatMessageRequest: ChatMessageRequest,username, streaming=False):
# 获取用户输入
Expand Down Expand Up @@ -262,7 +290,8 @@ def run(self, chatMessageRequest: ChatMessageRequest,username, streaming=False):
knowledgeBaseId=knowledgebase.knowledgeBaseId,
timeStamp=datetime.datetime.now()
)
self.save_conversation(new_message)
self.mysql_session.add(new_message)
self.mysql_session.flush()

# 保存检索文档
for doc in resultByDoc.source:
Expand All @@ -273,23 +302,29 @@ def run(self, chatMessageRequest: ChatMessageRequest,username, streaming=False):
messageId=new_message.id
)
self.mysql_session.add(retriever_doc)
self.mysql_session.commit()

answer = ""
# 生成回答
if streaming:
answer_generator = self.answer_question(resultByDoc=resultByDoc, history_message=messageLog, streaming=True)
if isinstance(answer_generator, str):
pass
else:

for item in answer_generator :
yield item
answer+=item
# 流式输出完成后更新答案到数据库
answer = answer_generator
new_message.answer = answer
self.mysql_session.commit() # 更新数据库
self.mysql_session.commit()
self.mysql_session.refresh(new_message)
else:
try:
for item in answer_generator :
yield item
answer+=item
# 流式输出完成后更新答案到数据库
new_message.answer = answer
self.mysql_session.commit() # 更新数据库
self.mysql_session.refresh(new_message)
except Exception as e:
logging.error(f"Error during streaming chat: {e}")
self.mysql_session.rollback()
raise
else:
answer = self.answer_question(resultByDoc=resultByDoc, history_message=messageLog, streaming=False)
if isinstance(answer, str):
Expand All @@ -300,17 +335,27 @@ def run(self, chatMessageRequest: ChatMessageRequest,username, streaming=False):

return answer
else:
for item in answer:
yield item
try:
for item in answer:
yield item
answer+=item
new_message.answer = answer
self.mysql_session.commit()
self.mysql_session.refresh(new_message)
except Exception as e:
logging.error(f"Error during non-streaming chat: {e}")
self.mysql_session.rollback()
raise

except Exception as e:
print(f"Error Chat Run: {e}")
logging.error(f"Error Chat Run: {e}")
self.mysql_session.rollback()
raise






"""
临时梳理逻辑

Expand All @@ -326,4 +371,4 @@ def run(self, chatMessageRequest: ChatMessageRequest,username, streaming=False):

所以需要有一个创建对话的接口,不能根据判断对话id是否为空来创建对话

"""
"""