diff --git a/api/routes_chat.py b/api/routes_chat.py index 977964a..89f1a85 100644 --- a/api/routes_chat.py +++ b/api/routes_chat.py @@ -51,7 +51,26 @@ def chat_message(req: ChatRequest, user=Depends(get_current_user)): model_type=model_type, model_name=model_name, ) - reply = generator.generate_response(req.message, history) + + # Determine which sources contributed to the response + db_text = generator._lookup_db(req.message) + rag_text = generator._search_rag(req.message) + if db_text and rag_text: + source_info = "DB + RAG" + elif db_text: + source_info = "DB" + elif rag_text: + source_info = "RAG" + else: + source_info = "None" + + context = generator._merge_sources(db_text, rag_text) + if not context: + reply = "No information" + else: + prompt = generator._build_prompt(req.message, history, context) + reply = generator.model.generate(prompt) + if model_type == "ollama" and not reply.startswith("[Ollama"): reply = f"[Ollama] {reply}" @@ -65,6 +84,7 @@ def chat_message(req: ChatRequest, user=Depends(get_current_user)): "history": conversation_repository.get_history( req.session_id, user["username"] ), + "source": source_info, } diff --git a/frontend/src/pages/Chat.jsx b/frontend/src/pages/Chat.jsx index 29f6543..96daad7 100644 --- a/frontend/src/pages/Chat.jsx +++ b/frontend/src/pages/Chat.jsx @@ -87,7 +87,15 @@ export default function Chat() { const sess = await api.get('/chat/sessions'); setSessions(sess.data.sessions || []); - setHistory(res.data.history || []); + + const updatedHistory = res.data.history || []; + if (updatedHistory.length > 0) { + const last = updatedHistory[updatedHistory.length - 1]; + if (last.sender === 'assistant') { + last.source = res.data.source; + } + } + setHistory(updatedHistory); } catch (err) { const detail = err.response?.data?.detail || 'Unknown error'; console.error('API Error:', detail); @@ -148,6 +156,9 @@ export default function Chat() {

{msg.message}

+

+ מקור התשובה: {msg.source || 'None'} +

diff --git a/tests/test_api.py b/tests/test_api.py index 63e0aa4..d1f54f1 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -83,6 +83,7 @@ def test_admin_chat_and_history(tmp_path, monkeypatch): data = resp.json() assert data["reply"].startswith("[Ollama") assert len(data["history"]) == 2 + assert data["source"] in {"None", "DB", "RAG", "DB + RAG"} resp = client.get("/chat/history", params={"session_id": "s1"}, headers=headers) assert resp.status_code == 200