diff --git a/backend/services/chat/chat.py b/backend/services/chat/chat.py index 75222ba..d28222e 100644 --- a/backend/services/chat/chat.py +++ b/backend/services/chat/chat.py @@ -10,6 +10,7 @@ from services.knowledgebase.knowledgebase_service import KBase from fastapi import HTTPException import datetime +import logging from typing import List from core.utils.utils import GetDeafultLLM_Chat,GetDefaultEmbedding class Chat: @@ -19,6 +20,9 @@ def __init__(self, conversation_id,user_id,rag:RAG_Pipeline): self.mysql_session = MysqlClient().SessionLocal() self.rag:RAG_Pipeline = rag self.default_llm_config = GetDeafultLLM_Chat() + if self.default_llm_config is None: + logging.error("GetDeafultLLM_Chat() returned None. No default LLM configuration found.") + raise HTTPException(status_code=500, detail="No default LLM configuration found. Please configure a default LLM.") self.default_embedding_config = GetDefaultEmbedding() def __del__(self): self.mysql_session.close() @@ -32,7 +36,9 @@ def create_conversation(self,knowledgeBaseId:str,username)->Conversation: # print("Conversation created successfully") return new_conversation except Exception as e: - print(e) + logging.error(f"Error creating conversation: {e}") + self.mysql_session.rollback() + raise #匹配对话 @@ -56,9 +62,14 @@ def match_knowledgebase(self, conversation_id,username)->KnowledgeBase: return knowledgebase except Exception as e: - print(f"Erreor:{e}") + logging.error(f"Error matching knowledgebase: {e}") + raise # 生成对话标题 def generate_conversation_title(self, conversation_id,username:str)->str: + if self.default_llm_config is None: + logging.error("default_llm_config is None in generate_conversation_title") + raise HTTPException(status_code=500, detail="No default LLM configuration available") + llm = LLM_Manager().creatLLM(mode_provider=self.default_llm_config.vendor_type,model=self.default_llm_config.model) conversation_messages = self.load_conversation(conversation_id,username=username) @@ -76,8 +87,10 @@ def generate_conversation_title(self, conversation_id,username:str)->str: ######################### """ prompt = f""" -请根据以下对话内容生成一个对话标题,对话内容如下:\n -{messageLogs_txt}\n +请根据以下对话内容生成一个对话标题,对话内容如下: + +{messageLogs_txt} + 请生成一个简洁明了的对话标题,不超过10个字,且仅需要输出标题,如果没有内容则输出:新对话 """ @@ -100,7 +113,8 @@ def get_knowledgebases(self,username:str): all_kbse = self.mysql_session.query(KnowledgeBase).filter((KnowledgeBase.created_by == username) | (KnowledgeBase.is_public == True),KnowledgeBase.delete_sign == False) return all_kbse except Exception as e: - print(f"获取知识库失败:{e}") + logging.error(f"Error getting knowledgebases: {e}") + raise # 更改知识库 def change_knowledgebase(self, conversation_id,knowledgeBaseId,username): try: @@ -110,7 +124,9 @@ def change_knowledgebase(self, conversation_id,knowledgeBaseId,username): self.mysql_session.refresh(conversation) return conversation except Exception as e: - print(e) + logging.error(f"Error changing knowledgebase: {e}") + self.mysql_session.rollback() + raise # 删除对话 def delete_conversation(self, conversation_id,username:str)->Conversation: try: @@ -124,7 +140,9 @@ def delete_conversation(self, conversation_id,username:str)->Conversation: return conversation except Exception as e: - print(e) + logging.error(f"Error deleting conversation: {e}") + self.mysql_session.rollback() + raise # 重命名对话 def rename_conversation(self, conversation_id,username, new_name): try: @@ -139,7 +157,9 @@ def rename_conversation(self, conversation_id,username, new_name): return conversation except Exception as e: - print(e) + logging.error(f"Error renaming conversation: {e}") + self.mysql_session.rollback() + raise # 加载对话 def load_conversation(self, conversation_id,username:str)->List[ChatMessageHistory]: # 检查对话是否属于用户 @@ -174,7 +194,8 @@ def load_conversation(self, conversation_id,username:str)->List[ChatMessageHisto return messages_history except Exception as e: - print(e) + logging.error(f"Error loading conversation: {e}") + raise # 格式化对话记录 def format_conversation_Log(self, messages:List[Chat_Messages]): @@ -199,12 +220,18 @@ def save_conversation(self, message:Chat_Messages): self.mysql_session.commit() self.mysql_session.refresh(message) except Exception as e: - print(e) + logging.error(f"Error saving conversation: {e}") + self.mysql_session.rollback() + raise # 回答问题 def answer_question(self, resultByDoc: ResultByDoc,history_message=[],streaming=False): # rAG_Pipeline = RAG_Pipeline() + if self.default_llm_config is None: + logging.error("default_llm_config is None in answer_question") + raise HTTPException(status_code=500, detail="No default LLM configuration available") + answer = self.rag.generate_answer_by_knowledgebase(resultByDoc=resultByDoc,history_messages=history_message,streaming=streaming,LLM_Provider=self.default_llm_config.vendor_type,LLM_Model=self.default_llm_config.model) if isinstance(answer, str): print(answer) @@ -226,7 +253,8 @@ def get_retrieve_documents(self, question,knowledgebase_id,rag_model:int,is_rera resultByDoc:ResultByDoc = self.rag.retrieve_documents(question=question,knowledge_base_id=knowledgebase_id,rag_model=rag_model,is_rerank=is_rerank) return resultByDoc except Exception as e: - print(e) + logging.error(f"Error retrieving documents: {e}") + raise # 处理用户输入 def run(self, chatMessageRequest: ChatMessageRequest,username, streaming=False): # 获取用户输入 @@ -262,7 +290,8 @@ def run(self, chatMessageRequest: ChatMessageRequest,username, streaming=False): knowledgeBaseId=knowledgebase.knowledgeBaseId, timeStamp=datetime.datetime.now() ) - self.save_conversation(new_message) + self.mysql_session.add(new_message) + self.mysql_session.flush() # 保存检索文档 for doc in resultByDoc.source: @@ -273,23 +302,29 @@ def run(self, chatMessageRequest: ChatMessageRequest,username, streaming=False): messageId=new_message.id ) self.mysql_session.add(retriever_doc) - self.mysql_session.commit() answer = "" # 生成回答 if streaming: answer_generator = self.answer_question(resultByDoc=resultByDoc, history_message=messageLog, streaming=True) if isinstance(answer_generator, str): - pass - else: - - for item in answer_generator : - yield item - answer+=item - # 流式输出完成后更新答案到数据库 + answer = answer_generator new_message.answer = answer - self.mysql_session.commit() # 更新数据库 + self.mysql_session.commit() self.mysql_session.refresh(new_message) + else: + try: + for item in answer_generator : + yield item + answer+=item + # 流式输出完成后更新答案到数据库 + new_message.answer = answer + self.mysql_session.commit() # 更新数据库 + self.mysql_session.refresh(new_message) + except Exception as e: + logging.error(f"Error during streaming chat: {e}") + self.mysql_session.rollback() + raise else: answer = self.answer_question(resultByDoc=resultByDoc, history_message=messageLog, streaming=False) if isinstance(answer, str): @@ -300,17 +335,27 @@ def run(self, chatMessageRequest: ChatMessageRequest,username, streaming=False): return answer else: - for item in answer: - yield item + try: + for item in answer: + yield item + answer+=item + new_message.answer = answer + self.mysql_session.commit() + self.mysql_session.refresh(new_message) + except Exception as e: + logging.error(f"Error during non-streaming chat: {e}") + self.mysql_session.rollback() + raise except Exception as e: - print(f"Error Chat Run: {e}") + logging.error(f"Error Chat Run: {e}") + self.mysql_session.rollback() + raise - """ 临时梳理逻辑 @@ -326,4 +371,4 @@ def run(self, chatMessageRequest: ChatMessageRequest,username, streaming=False): 所以需要有一个创建对话的接口,不能根据判断对话id是否为空来创建对话 -""" \ No newline at end of file +"""