diff --git a/notebook_intelligence/ai_service_manager.py b/notebook_intelligence/ai_service_manager.py index f303a8a..c1f7c15 100644 --- a/notebook_intelligence/ai_service_manager.py +++ b/notebook_intelligence/ai_service_manager.py @@ -12,6 +12,12 @@ from notebook_intelligence.api import ButtonData, ChatModel, EmbeddingModel, InlineCompletionModel, LLMProvider, ChatParticipant, ChatRequest, ChatResponse, CompletionContext, ContextRequest, Host, CompletionContextProvider, MCPPrompt, MCPServer, MarkdownData, NotebookIntelligenceExtension, RegistrationError, TelemetryEvent, TelemetryListener, Tool, Toolset from notebook_intelligence.base_chat_participant import BaseChatParticipant from notebook_intelligence.config import NBIConfig +from notebook_intelligence.history_backends import ( + HistoryPersistenceBackend, + HistoryPersistenceManager, + MySQLHistoryBackend, + SQLiteHistoryBackend, +) from notebook_intelligence.github_copilot_chat_participant import GithubCopilotChatParticipant from notebook_intelligence.claude import CLAUDE_CODE_CHAT_PARTICIPANT_ID, ClaudeCodeChatParticipant, ClaudeCodeInlineCompletionModel, fetch_claude_models, get_claude_models from notebook_intelligence.llm_providers.github_copilot_llm_provider import GitHubCopilotLLMProvider @@ -61,6 +67,13 @@ def __init__(self, options: Optional[dict] = None): self._options.get("feature_policies") or {}, self._options.get("string_overrides") or {}, ) + self._history_persistence = HistoryPersistenceManager() + self.register_history_persistence_backend(MySQLHistoryBackend()) + self.register_history_persistence_backend(SQLiteHistoryBackend()) + self._history_persistence.reconfigure( + self._nbi_config.history_config, + self._nbi_config.history_backend_configs, + ) self._openai_compatible_llm_provider = OpenAICompatibleLLMProvider() self._litellm_compatible_llm_provider = LiteLLMCompatibleLLMProvider() self._ollama_llm_provider = OllamaLLMProvider() @@ -93,6 +106,22 @@ def __init__(self, options: Optional[dict] = None): def nbi_config(self) -> NBIConfig: return self._nbi_config + @property + def history_persistence(self) -> HistoryPersistenceManager: + return self._history_persistence + + def register_history_persistence_backend( + self, backend: HistoryPersistenceBackend + ) -> None: + self._history_persistence.register_backend(backend) + + def update_history_persistence(self): + """Refresh history persistence backend configuration from current config.""" + self._history_persistence.reconfigure( + self._nbi_config.history_config, + self._nbi_config.history_backend_configs, + ) + @property def ollama_llm_provider(self) -> OllamaLLMProvider: return self._ollama_llm_provider diff --git a/notebook_intelligence/api.py b/notebook_intelligence/api.py index 3129e8c..d139fa4 100644 --- a/notebook_intelligence/api.py +++ b/notebook_intelligence/api.py @@ -143,6 +143,8 @@ class ChatRequest: cancel_token: CancelToken = None # NEW: Add context for rule evaluation rule_context: Optional[RuleContext] = None + # Internal conversation id for persistence backends + conversation_id: str = None @dataclass class ResponseStreamData: @@ -312,6 +314,12 @@ def message_id(self) -> str: def stream(self, data: ResponseStreamData, finish: bool = False) -> None: raise NotImplementedError + + def append_tool_calls(self, tool_calls: list[dict] | None) -> None: + return None + + def append_history_message(self, message: dict) -> None: + return None def finish(self) -> None: raise NotImplementedError @@ -636,6 +644,7 @@ async def _tool_call_loop(tool_call_rounds: list): for choice in tool_response['choices']: message = choice['message'] + response.append_tool_calls(message.get('tool_calls')) # Some models use 'reasoning', some use 'reasoning_content' raw_reasoning = message.get('reasoning') or message.get('reasoning_content') or '' @@ -687,6 +696,12 @@ async def _tool_call_loop(tool_call_rounds: list): args = tool_call['function']['arguments'] else: args = fuzzy_json_loads(tool_call['function']['arguments']) + + # Persist tool execution START (initial record). + if request.conversation_id: + request.host.history_persistence.log_tool_execution( + tool_call['id'], request.conversation_id, tool_name, args, "" + ) tool_properties = tool_to_call.schema["function"]["parameters"]["properties"] if type(args) is str: @@ -713,6 +728,17 @@ async def _tool_call_loop(tool_call_rounds: list): return tool_call_response = await tool_to_call.handle_tool_call(request, response, tool_context, args) + + # Persist tool execution result. + if request.conversation_id: + request.host.history_persistence.log_tool_execution( + tool_call['id'], request.conversation_id, tool_name, args, str(tool_call_response) + ) + # Also log the tool message itself + msg_id = str(uuid.uuid4()) + request.host.history_persistence.add_message( + msg_id, request.conversation_id, "tool", str(tool_call_response), tool_call_id=tool_call['id'] + ) function_call_result_message = { "role": "tool", @@ -720,6 +746,7 @@ async def _tool_call_loop(tool_call_rounds: list): "tool_call_id": tool_call['id'] } + response.append_history_message(function_call_result_message) messages.append(function_call_result_message) if had_tool_call: @@ -915,6 +942,9 @@ def register_telemetry_listener(self, listener: TelemetryListener) -> None: def register_toolset(self, toolset: Toolset) -> None: raise NotImplementedError + def register_history_persistence_backend(self, backend) -> None: + raise NotImplementedError + @property def nbi_config(self) -> NBIConfig: raise NotImplementedError @@ -971,6 +1001,10 @@ def get_skill_manager(self): def websocket_connector(self) -> ThreadSafeWebSocketConnector: raise NotImplementedError + @property + def history_persistence(self) -> Any: + return NotImplementedError + class NotebookIntelligenceExtension: @property diff --git a/notebook_intelligence/config.py b/notebook_intelligence/config.py index bd0dae7..6b51487 100644 --- a/notebook_intelligence/config.py +++ b/notebook_intelligence/config.py @@ -6,6 +6,7 @@ import stat import sys import tempfile +import copy from typing import Optional from notebook_intelligence.feature_flags import ( @@ -230,6 +231,64 @@ def active_rules(self) -> dict: """Get dictionary of active rule states (filename -> bool).""" return self.get('active_rules', {}) + @property + def history_backend_configs(self) -> dict: + """Get history persistence backend configuration by backend id.""" + defaults = { + 'mysql': { + 'host': 'localhost', + 'port': 3306, + 'user': '', + 'password': '', + 'database': 'notebook_intelligence' + }, + 'sqlite': { + 'path': os.path.join(self.nbi_user_dir, 'history.sqlite3') + } + } + + merged = copy.deepcopy(defaults) + configured = self.get('history_backend_configs', {}) + if isinstance(configured, dict): + for backend_id, backend_cfg in configured.items(): + if not isinstance(backend_cfg, dict): + continue + merged.setdefault(backend_id, {}) + merged[backend_id].update(backend_cfg) + + legacy_mysql = self.get('mysql_config', None) + if isinstance(legacy_mysql, dict): + merged['mysql'].update( + { + key: value + for key, value in legacy_mysql.items() + if key != 'enabled' + } + ) + return merged + + @property + def history_config(self) -> dict: + """Get chat history storage configuration.""" + cfg = self.get('history_config', {}) + mode = cfg.get('mode', 'local') + backend = cfg.get('backend', 'sqlite') + if mode in ['mysql', 'sqlite']: + backend = mode + mode = 'persistent' + local_max_messages = cfg.get('local_max_messages', 10) + try: + local_max_messages = int(local_max_messages) + except Exception: + local_max_messages = 10 + if local_max_messages < 1: + local_max_messages = 1 + return { + 'mode': mode if mode in ['persistent', 'local', 'none'] else 'local', + 'backend': backend if isinstance(backend, str) and backend else 'sqlite', + 'local_max_messages': local_max_messages + } + def set_rule_active(self, filename: str, active: bool): """Set the active state of a rule.""" active_rules = self.active_rules.copy() diff --git a/notebook_intelligence/extension.py b/notebook_intelligence/extension.py index 75f52d2..f4dcc0c 100644 --- a/notebook_intelligence/extension.py +++ b/notebook_intelligence/extension.py @@ -3,6 +3,7 @@ import asyncio import atexit import base64 +import copy from dataclasses import asdict, dataclass import json from os import path @@ -75,6 +76,7 @@ def _token_count(text: str) -> int: return len(tiktoken_encoding.encode(text)) +shared_chat_history = None def _truncate_context_content(content: str, token_budget: int) -> str: @@ -155,6 +157,76 @@ def _resolve_supports_vision(ai_service_manager) -> bool: return chat_model.supports_vision if chat_model is not None else False +def _extract_user_id_from_principal(user) -> str | None: + if not user: + return None + + if isinstance(user, dict): + for key in ("name", "username", "user", "id"): + value = user.get(key) + if isinstance(value, str) and value.strip(): + return value.strip() + return None + + for attr in ("name", "username", "user", "id"): + value = getattr(user, attr, None) + if isinstance(value, str) and value.strip(): + return value.strip() + + rendered = str(user).strip() + return rendered or None + + +def _resolve_request_user_id(handler) -> str: + user_id = _extract_user_id_from_principal(getattr(handler, "current_user", None)) + if user_id: + return user_id + + env_user_id = os.environ.get("JUPYTERHUB_USER", "").strip() + if env_user_id: + return env_user_id + + return "unknown" + + +def _normalize_history_timestamp(value): + """Serialize persisted timestamps with an explicit timezone offset. + + SQLite/MySQL may hand back naive UTC values for ``created_at`` while the + live websocket path uses ISO timestamps with timezone information. If the + browser replays a naive timestamp, it interprets it in local time and the + visible transcript clock shifts after refresh. Treat naive persisted values + as UTC and always return an offset-aware ISO string. + """ + if value is None: + return None + if isinstance(value, dt.datetime): + if value.tzinfo is None: + value = value.replace(tzinfo=dt.timezone.utc) + return value.isoformat() + if isinstance(value, str): + text = value.strip() + if not text: + return value + try: + parsed = dt.datetime.fromisoformat(text.replace("Z", "+00:00")) + except ValueError: + return value + if parsed.tzinfo is None: + parsed = parsed.replace(tzinfo=dt.timezone.utc) + return parsed.isoformat() + return value + + +def _is_empty_assistant_history_message(message): + return ( + message.get("role") == "assistant" + and message.get("content") in (None, "") + and not message.get("reasoning_content") + and not message.get("tool_calls") + ) + + def _resolve_policy_with_env(env_var_name: str, traitlet_value: str) -> str: """Resolve a feature policy: env var wins if valid, else traitlet. @@ -552,6 +624,10 @@ def is_provider_enabled(provider_id: str) -> bool: nbi_config.claude_settings, self.string_overrides ), "spinner_verbs": _read_claude_spinner_verbs(), + "history_config": nbi_config.history_config, + "history_backends": ai_service_manager.history_persistence.available_backends(), + "history_backend_configs": nbi_config.history_backend_configs, + "current_user_id": _resolve_request_user_id(self), "claude_models": ai_service_manager.claude_models, # Drive launcher-tile visibility (issues #183, #260). Each flag # gates one tile under the "Coding Agent" category. Detection is @@ -610,7 +686,7 @@ class ConfigHandler(APIHandler): string_overrides = {} @tornado.web.authenticated - def post(self): + async def post(self): data = json.loads(self.request.body) valid_keys = set([ "default_chat_mode", @@ -624,6 +700,8 @@ def post(self): "enable_output_followup", "enable_output_toolbar", "refresh_open_files_on_disk_change", + "history_config", + "history_backend_configs", ]) # Top-level keys whose write is rejected outright when locked. locked_keys = set() @@ -651,6 +729,13 @@ def post(self): has_model_change = False has_claude_settings_change = False + has_history_settings_change = False + previous_history_config_user = copy.deepcopy( + ai_service_manager.nbi_config.user_config.get("history_config") + ) + previous_history_backend_configs_user = copy.deepcopy( + ai_service_manager.nbi_config.user_config.get("history_backend_configs") + ) for key in data: if key in locked_keys: continue @@ -706,10 +791,62 @@ def post(self): if isinstance(default_chat_participant, ClaudeCodeChatParticipant): # needed to disconnect default_chat_participant.update_client_debounced() + elif key == "history_config": + has_history_settings_change = True + elif key == "history_backend_configs": + has_history_settings_change = True + + legacy_mysql_config = data.get("mysql_config") + if isinstance(legacy_mysql_config, dict): + merged_backend_configs = ai_service_manager.nbi_config.history_backend_configs + merged_backend_configs["mysql"] = { + **merged_backend_configs.get("mysql", {}), + **{ + key: value + for key, value in legacy_mysql_config.items() + if key != "enabled" + }, + } + ai_service_manager.nbi_config.set("history_backend_configs", merged_backend_configs) + has_history_settings_change = True ai_service_manager.nbi_config.save() if has_model_change or has_claude_settings_change: ai_service_manager.update_models_from_config() + if has_history_settings_change: + ai_service_manager.update_history_persistence() + history_cfg = ai_service_manager.nbi_config.history_config + if history_cfg.get("mode") == "persistent": + # Validate the selected backend immediately. On failure, roll + # back to the last active config so draft settings in the UI do + # not silently replace the working backend. + ok, err = await ai_service_manager.history_persistence.test_connection() + if not ok: + if previous_history_config_user is None: + ai_service_manager.nbi_config.user_config.pop( + "history_config", None + ) + else: + ai_service_manager.nbi_config.user_config["history_config"] = ( + previous_history_config_user + ) + if previous_history_backend_configs_user is None: + ai_service_manager.nbi_config.user_config.pop( + "history_backend_configs", None + ) + else: + ai_service_manager.nbi_config.user_config[ + "history_backend_configs" + ] = previous_history_backend_configs_user + ai_service_manager.nbi_config.save() + ai_service_manager.update_history_persistence() + self.set_status(400) + self.finish(json.dumps({ + "error": f"History backend connection failed: {err}. Active history settings were not changed.", + "history_config": ai_service_manager.nbi_config.history_config, + "history_backend_configs": ai_service_manager.nbi_config.history_backend_configs + })) + return if has_claude_settings_change: default_chat_participant = ai_service_manager.default_chat_participant if isinstance(default_chat_participant, ClaudeCodeChatParticipant): @@ -1840,55 +1977,14 @@ def post(self): self.finish(json.dumps({"success": True, "session_id": session_id})) -class ChatHistory: - """ - History of chat messages, key is chat id, value is list of messages - keep the last 10 messages in the same chat participant - """ - MAX_MESSAGES = 10 - - def __init__(self): - self.messages = {} - - def clear(self, chatId = None): - if chatId is None: - self.messages = {} - return True - elif chatId in self.messages: - del self.messages[chatId] - return True - - return False - - def add_message(self, chatId, message): - if chatId not in self.messages: - self.messages[chatId] = [] - - # clear the chat history if participant changed - if message["role"] == "user": - existing_messages = self.messages[chatId] - prev_user_message = next((m for m in reversed(existing_messages) if m["role"] == "user"), None) - if prev_user_message is not None: - current_prompt_parts = AIServiceManager.parse_prompt(message["content"]) - prev_prompt_parts = AIServiceManager.parse_prompt(prev_user_message["content"]) - if current_prompt_parts.participant != prev_prompt_parts.participant: - self.messages[chatId] = [] - - self.messages[chatId].append(message) - # limit number of messages kept in history - if len(self.messages[chatId]) > ChatHistory.MAX_MESSAGES: - self.messages[chatId] = self.messages[chatId][-ChatHistory.MAX_MESSAGES:] - - def get_history(self, chatId): - return self.messages.get(chatId, []) - class WebsocketCopilotResponseEmitter(ChatResponse): - def __init__(self, chatId, messageId, websocket_handler, chat_history): + def __init__(self, chatId, messageId, websocket_handler, chat_history, conversation_id=None): super().__init__() self.chatId = chatId self.messageId = messageId self.websocket_handler = websocket_handler self.chat_history = chat_history + self.conversation_id = conversation_id self.streamed_contents = [] self.streamed_reasoning_contents = [] # Capture the Tornado IOLoop the websocket lives on. stream() / @@ -1900,12 +1996,25 @@ def __init__(self, chatId, messageId, websocket_handler, chat_history): # data: object cannot be re-sized` (issue #264). Marshaling the # write back to the IOLoop's thread fixes it. self._io_loop = tornado.ioloop.IOLoop.current() + self.streamed_tool_calls = [] + self.streamed_markdown_parts = [] def _send_async(self, message: dict) -> None: self._io_loop.asyncio_loop.call_soon_threadsafe( self.websocket_handler.write_message, message ) + def append_tool_calls(self, tool_calls: list[dict] | None) -> None: + if not isinstance(tool_calls, list): + return + for tool_call in tool_calls: + if isinstance(tool_call, dict): + self.streamed_tool_calls.append(copy.deepcopy(tool_call)) + + def append_history_message(self, message: dict) -> None: + user_id = _resolve_request_user_id(self.websocket_handler) + self.chat_history.add_message(self.chatId, message, user_id=user_id) + @property def chat_id(self) -> str: return self.chatId @@ -1918,7 +2027,16 @@ def stream(self, data: Union[ResponseStreamData, dict]): data_type = ResponseStreamDataType.LLMRaw if type(data) is dict else data.data_type if data_type == ResponseStreamDataType.Markdown: - self.chat_history.add_message(self.chatId, {"role": "assistant", "content": data.content, "reasoning_content": data.reasoning_content}) + if data.content is not None: + self.streamed_contents.append(data.content) + if data.reasoning_content is not None: + self.streamed_reasoning_contents.append(data.reasoning_content) + self.streamed_markdown_parts.append({ + "type": "markdown", + "content": data.content or "", + "reasoning_content": data.reasoning_content, + "detail": data.detail + }) data = { "choices": [ { @@ -2118,28 +2236,73 @@ def stream(self, data: Union[ResponseStreamData, dict]): self.streamed_contents.append(content) if reasoning_content is not None: self.streamed_reasoning_contents.append(reasoning_content) + + # Now common part for all types to actually write to websocket + if data_type != ResponseStreamDataType.LLMRaw: + self._send_async({ + "id": self.messageId, + "participant": self.participant_id, + "type": BackendMessageType.StreamMessage, + "data": data, + "created": dt.datetime.now().isoformat() + }) else: # ResponseStreamDataType.LLMRaw if len(data.get("choices", [])) > 0: delta = data["choices"][0].get("delta", {}) content = delta.get("content", "") reasoning_content = delta.get("reasoning_content", "") + tool_calls = delta.get("tool_calls") if content is not None: self.streamed_contents.append(content) if reasoning_content is not None: self.streamed_reasoning_contents.append(reasoning_content) - - self._send_async({ - "id": self.messageId, - "participant": self.participant_id, - "type": BackendMessageType.StreamMessage, - "data": data, - "created": dt.datetime.now().isoformat() - }) + if isinstance(tool_calls, list): + self.append_tool_calls(tool_calls) + + self._send_async({ + "id": self.messageId, + "participant": self.participant_id, + "type": BackendMessageType.StreamMessage, + "data": data, + "created": dt.datetime.now().isoformat() + }) def finish(self) -> None: - self.chat_history.add_message(self.chatId, {"role": "assistant", "content": "".join(self.streamed_contents), "reasoning_content": "".join(self.streamed_reasoning_contents)}) + content = "".join(self.streamed_contents) + reasoning_content = "".join(self.streamed_reasoning_contents) + persisted_tool_calls = list(self.streamed_tool_calls) + persisted_ui_parts = list(self.streamed_markdown_parts) + + if content or reasoning_content or persisted_tool_calls or persisted_ui_parts: + user_id = _resolve_request_user_id(self.websocket_handler) + self.chat_history.add_message( + self.chatId, + { + "role": "assistant", + "content": content, + "reasoning_content": reasoning_content, + "tool_calls": persisted_tool_calls, + "ui_parts": persisted_ui_parts, + }, + user_id=user_id, + ) + + if self.conversation_id: + msg_id = str(uuid.uuid4()) + ai_service_manager.history_persistence.add_message( + msg_id, + self.conversation_id, + "assistant", + content, + reasoning_content, + tool_calls=persisted_tool_calls, + ui_parts=persisted_ui_parts, + ) + self.streamed_contents = [] self.streamed_reasoning_contents = [] + self.streamed_tool_calls = [] + self.streamed_markdown_parts = [] self._send_async({ "id": self.messageId, "participant": self.participant_id, @@ -2176,13 +2339,14 @@ class MessageCallbackHandlers: response_emitter: WebsocketCopilotResponseEmitter cancel_token: CancelTokenImpl -class WebsocketCopilotHandler(WebSocketMixin, websocket.WebSocketHandler, JupyterHandler): +class WebsocketCopilotHandler(APIHandler, WebSocketMixin, websocket.WebSocketHandler, JupyterHandler): # Cap WS message size at 4 MiB. Largest legitimate payload is a chat # request with ~10 attached output-context items (each capped at 1 MiB # by `coerce_payload`) + chat history; 4 MiB covers that without # leaving the default 10 MiB headroom for memory amplification. max_message_size = 4 * 1024 * 1024 + chat_history_ref = None # Inheritance matches Jupyter's first-party WS handlers (e.g. # KernelWebsocketHandler): ``WebSocketMixin`` adds ping/pong # keepalive plus a ``prepare`` that routes through Jupyter's @@ -2204,7 +2368,11 @@ def __init__(self, application, request, context_factory=None, **kwargs): # websocket — every long chat session leaked one emitter + # cancel token per turn. self._messageCallbackHandlers: dict[str, MessageCallbackHandlers] = {} - self.chat_history = ChatHistory() + global shared_chat_history + if shared_chat_history is None: + shared_chat_history = ChatHistory() + self.chat_history = shared_chat_history + WebsocketCopilotHandler.chat_history_ref = self.chat_history self._context_factory = context_factory or RuleContextFactory() ws_connector = ThreadSafeWebSocketConnector(self) ai_service_manager.websocket_connector = ws_connector @@ -2236,7 +2404,7 @@ def open(self): self.request.headers.get("Origin"), ) - def on_message(self, message): + async def on_message(self, message): msg = json.loads(message) messageId = msg['id'] @@ -2256,8 +2424,23 @@ def on_message(self, message): extension_tools=toolSelections.get('extensions', {}) ) + # Persist the user message immediately when a history backend is active. + conversation_id = str(uuid.uuid4()) + + user_id = _resolve_request_user_id(self) + + user_message_id = str(uuid.uuid4()) + ai_service_manager.history_persistence.create_conversation_with_message( + conversation_id, user_id, chatId, chat_mode.id, + user_message_id, "user", prompt + ) + is_claude_code_mode = ai_service_manager.is_claude_code_mode - chat_history = self.chat_history.get_history(chatId) + # Copy current chat history for request context building, do not + # mutate shared in-memory history with transient context entries. + chat_history = list( + await self.chat_history.get_history(chatId, user_id=user_id) + ) chat_history_initial_size = len(chat_history) current_directory = data.get('currentDirectory') @@ -2481,8 +2664,13 @@ def on_message(self, message): chat_history.append({"role": "user", "content": context_message}) chat_history.append({"role": "user", "content": prompt}) + # Persist the real user prompt in shared in-memory history so + # refresh fallback stays consistent even when DB data is delayed. + self.chat_history.add_message( + chatId, {"role": "user", "content": prompt}, user_id=user_id + ) - response_emitter = WebsocketCopilotResponseEmitter(chatId, messageId, self, self.chat_history) + response_emitter = WebsocketCopilotResponseEmitter(chatId, messageId, self, self.chat_history, conversation_id=conversation_id) cancel_token = CancelTokenImpl() self._messageCallbackHandlers[messageId] = MessageCallbackHandlers(response_emitter, cancel_token) @@ -2496,7 +2684,7 @@ def on_message(self, message): # last prompt is added later request_chat_history = chat_history[chat_history_initial_size:-1] if is_claude_code_mode else chat_history[:-1] - coro = ai_service_manager.handle_chat_request(ChatRequest(chat_mode=chat_mode, tool_selection=tool_selection, prompt=prompt, chat_history=request_chat_history, cancel_token=cancel_token, rule_context=rule_context), response_emitter) + coro = ai_service_manager.handle_chat_request(ChatRequest(chat_mode=chat_mode, tool_selection=tool_selection, prompt=prompt, chat_history=request_chat_history, cancel_token=cancel_token, rule_context=rule_context, conversation_id=conversation_id), response_emitter) thread = threading.Thread(target=self._run_request_thread, args=(coro, messageId)) thread.start() elif messageType == RequestDataType.GenerateCode: @@ -2510,14 +2698,51 @@ def on_message(self, message): filename = data['filename'] is_claude_code_mode = ai_service_manager.is_claude_code_mode chat_mode = ChatMode('inline-chat', 'Inline Chat') if is_claude_code_mode else ChatMode('ask', 'Ask') + + # Persist the user message immediately when a history backend is active. + conversation_id = str(uuid.uuid4()) + + user_id = _resolve_request_user_id(self) + + user_message_id = str(uuid.uuid4()) + ai_service_manager.history_persistence.create_conversation_with_message( + conversation_id, user_id, chatId, "inline-chat", + user_message_id, "user", prompt + ) + if prefix != '': - self.chat_history.add_message(chatId, {"role": "user", "content": f"This code section comes before the code section you will generate, use as context. Leading content: ```{prefix}```"}) + self.chat_history.add_message( + chatId, + { + "role": "user", + "content": f"This code section comes before the code section you will generate, use as context. Leading content: ```{prefix}```", + }, + user_id=user_id, + ) if suffix != '': - self.chat_history.add_message(chatId, {"role": "user", "content": f"This code section comes after the code section you will generate, use as context. Trailing content: ```{suffix}```"}) + self.chat_history.add_message( + chatId, + { + "role": "user", + "content": f"This code section comes after the code section you will generate, use as context. Trailing content: ```{suffix}```", + }, + user_id=user_id, + ) if existing_code != '': - self.chat_history.add_message(chatId, {"role": "user", "content": f"You are asked to modify the existing code. Generate a replacement for this existing code : ```{existing_code}```"}) - self.chat_history.add_message(chatId, {"role": "user", "content": f"Generate code for: {prompt}"}) - response_emitter = WebsocketCopilotResponseEmitter(chatId, messageId, self, self.chat_history) + self.chat_history.add_message( + chatId, + { + "role": "user", + "content": f"You are asked to modify the existing code. Generate a replacement for this existing code : ```{existing_code}```", + }, + user_id=user_id, + ) + self.chat_history.add_message( + chatId, + {"role": "user", "content": f"Generate code for: {prompt}"}, + user_id=user_id, + ) + response_emitter = WebsocketCopilotResponseEmitter(chatId, messageId, self, self.chat_history, conversation_id=conversation_id) cancel_token = CancelTokenImpl() self._messageCallbackHandlers[messageId] = MessageCallbackHandlers(response_emitter, cancel_token) existing_code_message = " Update the existing code section and return a modified version. Don't just return the update, recreate the existing code section with the update." if existing_code != '' else '' @@ -2531,7 +2756,8 @@ def on_message(self, message): root_dir=NotebookIntelligence.root_dir ) - coro = ai_service_manager.handle_chat_request(ChatRequest(chat_mode=chat_mode, prompt=prompt, chat_history=self.chat_history.get_history(chatId), cancel_token=cancel_token, rule_context=rule_context), response_emitter, options={"system_prompt": f"You are an assistant that generates code for '{language}' language. You generate code between existing leading and trailing code sections.{existing_code_message} Be concise and return only code as a response. Don't include leading content or trailing content in your response, they are provided only for context. You can reuse methods and symbols defined in leading and trailing content."}) + chat_history = await self.chat_history.get_history(chatId, user_id=user_id) + coro = ai_service_manager.handle_chat_request(ChatRequest(chat_mode=chat_mode, prompt=prompt, chat_history=chat_history, cancel_token=cancel_token, rule_context=rule_context, conversation_id=conversation_id), response_emitter, options={"system_prompt": f"You are an assistant that generates code for '{language}' language. You generate code between existing leading and trailing code sections.{existing_code_message} Be concise and return only code as a response. Don't include leading content or trailing content in your response, they are provided only for context. You can reuse methods and symbols defined in leading and trailing content."}) thread = threading.Thread(target=self._run_request_thread, args=(coro, messageId)) thread.start() elif messageType == RequestDataType.InlineCompletionRequest: @@ -2561,7 +2787,10 @@ def on_message(self, message): default_chat_participant = ai_service_manager.default_chat_participant if isinstance(default_chat_participant, ClaudeCodeChatParticipant): default_chat_participant.clear_chat_history() - self.chat_history.clear() + chat_id = msg.get("data", {}).get("chatId") + self.chat_history.clear( + chat_id, user_id=_resolve_request_user_id(self) + ) elif messageType == RequestDataType.RunUICommandResponse: handlers = self._messageCallbackHandlers.get(messageId) if handlers is None: @@ -2600,6 +2829,296 @@ async def handle_inline_completions(prefix, suffix, language, filename, response response_emitter.stream({"completions": completions}) response_emitter.finish() + +class GetChatHistoryHandler(APIHandler): + @tornado.web.authenticated + async def get(self): + chat_id = self.get_argument("chatId", None) + if not chat_id: + self.set_status(400) + self.finish(json.dumps({"error": "chatId is required"})) + return + + history_mode = ai_service_manager.nbi_config.history_config.get("mode", "local") + user_id = _resolve_request_user_id(self) + messages = [] + if history_mode == "none": + messages = [] + elif history_mode == "persistent": + messages = await ai_service_manager.history_persistence.get_messages_by_chat_id( + chat_id, user_id=user_id + ) + else: + global shared_chat_history + if shared_chat_history is None: + shared_chat_history = ChatHistory() + in_memory = await shared_chat_history.get_history(chat_id, user_id=user_id) + ws_history = [] + if WebsocketCopilotHandler.chat_history_ref is not None: + ws_history = await WebsocketCopilotHandler.chat_history_ref.get_history( + chat_id, user_id=user_id + ) + if len(ws_history) > len(in_memory): + in_memory = ws_history + messages = [] + for item in in_memory: + messages.append({ + "role": item.get("role", "assistant"), + "content": item.get("content", ""), + "reasoning_content": item.get("reasoning_content"), + "tool_calls": item.get("tool_calls"), + "ui_parts": item.get("ui_parts"), + "tool_call_id": item.get("tool_call_id"), + "created_at": item.get("created_at"), + }) + # Convert datetime to string for JSON serialization + for msg in messages: + msg['created_at'] = _normalize_history_timestamp( + msg.get('message_order_at') or msg.get('created_at') + ) + msg.pop('message_order_at', None) + for field_name in ("tool_calls", "ui_parts"): + if msg.get(field_name): + try: + msg[field_name] = json.loads(msg[field_name]) + except: + pass + + self.finish(json.dumps({"messages": messages})) + +class GetRecentConversationsHandler(APIHandler): + @tornado.web.authenticated + async def get(self): + history_mode = ai_service_manager.nbi_config.history_config.get("mode", "local") + if history_mode == "none": + self.finish(json.dumps({"conversations": []})) + return + + user_id = _resolve_request_user_id(self) + + if history_mode == "local": + global shared_chat_history + if shared_chat_history is None: + shared_chat_history = ChatHistory() + # Keep ordering stable: most recently touched chat first. Local + # mode has in-memory transcripts only, so surface chat IDs with + # a synthetic timestamp for the sidebar conversation picker. + now = dt.datetime.now(dt.timezone.utc).isoformat() + conversation_ids = shared_chat_history.list_conversation_ids(user_id=user_id) + conversations = [ + { + "chat_id": chat_id, + "chat_mode": "ask", + "last_message_at": now, + } + for chat_id in reversed(conversation_ids) + ] + self.finish(json.dumps({"conversations": conversations})) + return + + conversations = await ai_service_manager.history_persistence.get_recent_conversations(user_id) + for conv in conversations: + if isinstance(conv.get('last_message_at'), dt.datetime): + conv['last_message_at'] = conv['last_message_at'].isoformat() + + self.finish(json.dumps({"conversations": conversations})) + +class ChatHistory: + """ + History of chat messages, key is chat id, value is list of messages + keep the last 10 messages in the same chat participant + """ + DEFAULT_MAX_MESSAGES = 10 + + def __init__(self): + self.messages = {} + + @staticmethod + def _scope_key(chat_id, user_id=None): + if not user_id: + return chat_id + return f"{user_id}\0{chat_id}" + + @staticmethod + def _split_scope_key(scoped_chat_id): + if "\0" not in scoped_chat_id: + return None, scoped_chat_id + user_id, chat_id = scoped_chat_id.split("\0", 1) + return user_id, chat_id + + def clear(self, chatId = None, user_id=None): + scoped_chat_id = self._scope_key(chatId, user_id=user_id) if chatId is not None else None + if chatId is None: + self.messages = {} + return True + elif scoped_chat_id in self.messages: + del self.messages[scoped_chat_id] + return True + + return False + + def add_message(self, chatId, message, user_id=None): + history_mode = ai_service_manager.nbi_config.history_config.get("mode", "local") + scoped_chat_id = self._scope_key(chatId, user_id=user_id) + message = self._normalize_message(message) + if message is None: + return + + if scoped_chat_id not in self.messages: + self.messages[scoped_chat_id] = [] + + # clear the chat history if participant changed + if message["role"] == "user": + existing_messages = self.messages[scoped_chat_id] + prev_user_message = next((m for m in reversed(existing_messages) if m["role"] == "user"), None) + if prev_user_message is not None: + current_prompt_parts = AIServiceManager.parse_prompt(message["content"]) + prev_prompt_parts = AIServiceManager.parse_prompt(prev_user_message["content"]) + if current_prompt_parts.participant != prev_prompt_parts.participant: + self.messages[scoped_chat_id] = [] + + self.messages[scoped_chat_id].append(message) + # limit number of messages kept in history only in local mode + if history_mode == "local": + max_messages = ai_service_manager.nbi_config.history_config.get( + "local_max_messages", ChatHistory.DEFAULT_MAX_MESSAGES + ) + if len(self.messages[scoped_chat_id]) > max_messages: + self.messages[scoped_chat_id] = self.messages[scoped_chat_id][-max_messages:] + + async def get_history(self, chatId, user_id=None): + history_mode = ai_service_manager.nbi_config.history_config.get("mode", "local") + scoped_chat_id = self._scope_key(chatId, user_id=user_id) + if scoped_chat_id not in self.messages: + if history_mode == "persistent": + messages = await ai_service_manager.history_persistence.get_messages_by_chat_id( + chatId, user_id=user_id + ) + if messages: + self.messages[scoped_chat_id] = [ + normalized + for normalized in ( + self._normalize_message(m, assign_created_at=False) + for m in messages + ) + if normalized is not None + ] + else: + self.messages[scoped_chat_id] = [] + + self.messages[scoped_chat_id] = self._canonicalize_messages( + [ + normalized + for normalized in ( + self._normalize_message(m) for m in self.messages.get(scoped_chat_id, []) + ) + if normalized is not None + ] + ) + + if history_mode == "local": + max_messages = ai_service_manager.nbi_config.history_config.get( + "local_max_messages", ChatHistory.DEFAULT_MAX_MESSAGES + ) + if len(self.messages[scoped_chat_id]) > max_messages: + self.messages[scoped_chat_id] = self.messages[scoped_chat_id][-max_messages:] + + return self.messages.get(scoped_chat_id, []) + + @staticmethod + def _maybe_parse_json_field(value): + if value is None or isinstance(value, (list, dict)): + return value + if not isinstance(value, str): + return value + try: + return json.loads(value) + except Exception: + return value + + @classmethod + def _normalize_message(cls, message, assign_created_at=True): + if not isinstance(message, dict): + return None + + normalized = dict(message) + normalized["role"] = normalized.get("role", "assistant") + normalized["content"] = normalized.get("content", "") + normalized["reasoning_content"] = normalized.get("reasoning_content") + normalized["tool_calls"] = cls._maybe_parse_json_field(normalized.get("tool_calls")) + normalized["ui_parts"] = cls._maybe_parse_json_field(normalized.get("ui_parts")) + normalized["tool_call_id"] = normalized.get("tool_call_id") + + created_at = _normalize_history_timestamp( + normalized.get("message_order_at") or normalized.get("created_at") + ) + if created_at is None and assign_created_at: + created_at = dt.datetime.now(dt.timezone.utc).isoformat() + normalized["created_at"] = created_at + normalized.pop("message_order_at", None) + return normalized + + @classmethod + def _canonicalize_messages(cls, messages): + canonical_messages = [] + pending_tool_messages = [] + + for message in messages: + if message.get("role") == "tool": + pending_tool_messages.append(message) + continue + + if message.get("role") == "assistant" and message.get("tool_calls"): + tool_call_ids = { + tool_call.get("id") + for tool_call in message.get("tool_calls", []) + if isinstance(tool_call, dict) and tool_call.get("id") + } + matched_tool_messages = [ + tool_message + for tool_message in pending_tool_messages + if tool_message.get("tool_call_id") in tool_call_ids + ] + unmatched_tool_messages = [ + tool_message + for tool_message in pending_tool_messages + if tool_message.get("tool_call_id") not in tool_call_ids + ] + canonical_messages.extend(unmatched_tool_messages) + pending_tool_messages = [] + + if matched_tool_messages: + tool_call_message = dict(message) + tool_call_message["content"] = None + tool_call_message["reasoning_content"] = None + tool_call_message["ui_parts"] = None + canonical_messages.append(tool_call_message) + canonical_messages.extend(matched_tool_messages) + + final_assistant_message = dict(message) + final_assistant_message["tool_calls"] = None + if not _is_empty_assistant_history_message(final_assistant_message): + canonical_messages.append(final_assistant_message) + continue + + canonical_messages.extend(pending_tool_messages) + pending_tool_messages = [] + canonical_messages.append(message) + + canonical_messages.extend(pending_tool_messages) + return canonical_messages + + def list_conversation_ids(self, user_id=None): + if user_id is None: + return list(self.messages.keys()) + + result = [] + for scoped_chat_id in self.messages.keys(): + stored_user_id, chat_id = self._split_scope_key(scoped_chat_id) + if stored_user_id == user_id: + result.append(chat_id) + return result + class NotebookIntelligence(ExtensionApp): name = "notebook_intelligence" default_url = "/notebook-intelligence" @@ -3202,6 +3721,8 @@ def _setup_handlers(self, web_app, feature_policies: dict, string_overrides: dic r"([^/]+)", "update", ) + route_pattern_history = url_path_join(base_url, "notebook-intelligence", "history") + route_pattern_conversations = url_path_join(base_url, "notebook-intelligence", "conversations") GetCapabilitiesHandler.disabled_tools = self.disabled_tools GetCapabilitiesHandler.allow_enabling_tools_with_env = self.allow_enabling_tools_with_env GetCapabilitiesHandler.disabled_providers = self.disabled_providers @@ -3332,6 +3853,8 @@ def _setup_handlers(self, web_app, feature_policies: dict, string_overrides: dic (route_pattern_plugins_marketplace, PluginsMarketplaceListHandler), (route_pattern_plugins_detail, PluginsDetailHandler), (route_pattern_plugins, PluginsListHandler), + (route_pattern_history, GetChatHistoryHandler), + (route_pattern_conversations, GetRecentConversationsHandler), (route_pattern_copilot, WebsocketCopilotHandler), ] web_app.add_handlers(host_pattern, NotebookIntelligence.handlers) diff --git a/notebook_intelligence/github_copilot.py b/notebook_intelligence/github_copilot.py index 3dedac5..b888183 100644 --- a/notebook_intelligence/github_copilot.py +++ b/notebook_intelligence/github_copilot.py @@ -14,6 +14,7 @@ import logging from notebook_intelligence.api import BackendMessageType, CancelToken, ChatResponse, CompletionContext, MarkdownData from notebook_intelligence.config import _atomic_write_json +from notebook_intelligence.message_sanitizer import sanitize_chat_history_tool_calls from notebook_intelligence.util import decrypt_with_password, encrypt_with_password, ThreadSafeWebSocketConnector from ._version import __version__ as NBI_VERSION @@ -1084,9 +1085,10 @@ def completions(model_id, messages, tools = None, response: ChatResponse = None, aggregate = response is None try: + sanitized_messages = sanitize_chat_history_tool_calls(messages) data = { 'model': model_id, - 'messages': messages, + 'messages': sanitized_messages, 'tools': tools, 'temperature': 0, 'top_p': 1, diff --git a/notebook_intelligence/history_backends/__init__.py b/notebook_intelligence/history_backends/__init__.py new file mode 100644 index 0000000..aca4baf --- /dev/null +++ b/notebook_intelligence/history_backends/__init__.py @@ -0,0 +1,15 @@ +from notebook_intelligence.history_backends.base import ( + HistoryBackendField, + HistoryPersistenceBackend, + HistoryPersistenceManager, +) +from notebook_intelligence.history_backends.mysql import MySQLHistoryBackend +from notebook_intelligence.history_backends.sqlite import SQLiteHistoryBackend + +__all__ = [ + "HistoryBackendField", + "HistoryPersistenceBackend", + "HistoryPersistenceManager", + "MySQLHistoryBackend", + "SQLiteHistoryBackend", +] diff --git a/notebook_intelligence/history_backends/base.py b/notebook_intelligence/history_backends/base.py new file mode 100644 index 0000000..bd02315 --- /dev/null +++ b/notebook_intelligence/history_backends/base.py @@ -0,0 +1,245 @@ +import copy +from dataclasses import dataclass +from typing import Any + + +@dataclass +class HistoryBackendField: + key: str + label: str + input_type: str = "text" + placeholder: str = "" + help_text: str = "" + + def to_dict(self) -> dict[str, Any]: + return { + "key": self.key, + "label": self.label, + "input_type": self.input_type, + "placeholder": self.placeholder, + "help_text": self.help_text, + } + + +class HistoryPersistenceBackend: + @property + def id(self) -> str: + raise NotImplementedError + + @property + def name(self) -> str: + raise NotImplementedError + + @property + def description(self) -> str: + return "" + + @property + def fields(self) -> list[HistoryBackendField]: + return [] + + def configure(self, config: dict[str, Any]) -> None: + raise NotImplementedError + + async def test_connection(self) -> tuple[bool, str]: + raise NotImplementedError + + def create_conversation_with_message( + self, + conv_id: str, + user_id: str, + chat_id: str, + chat_mode: str, + msg_id: str, + role: str, + content: str, + ) -> None: + raise NotImplementedError + + def create_conversation( + self, conv_id: str, user_id: str, chat_id: str, chat_mode: str + ) -> None: + raise NotImplementedError + + def add_message( + self, + message_id: str, + conv_id: str, + role: str, + content: str, + reasoning_content: str | None = None, + tool_calls: list[dict] | None = None, + ui_parts: list[dict] | None = None, + tool_call_id: str | None = None, + ) -> None: + raise NotImplementedError + + def log_tool_execution( + self, + tool_call_id: str, + conv_id: str, + tool_name: str, + arguments: dict, + output: str, + ) -> None: + raise NotImplementedError + + async def get_messages_by_chat_id( + self, chat_id: str, user_id: str | None = None + ) -> list[dict[str, Any]]: + raise NotImplementedError + + async def get_recent_conversations( + self, user_id: str, limit: int = 20 + ) -> list[dict[str, Any]]: + raise NotImplementedError + + def to_wire(self) -> dict[str, Any]: + return { + "id": self.id, + "name": self.name, + "description": self.description, + "fields": [field.to_dict() for field in self.fields], + } + + +class HistoryPersistenceManager: + def __init__(self): + self._backends: dict[str, HistoryPersistenceBackend] = {} + self._history_config: dict[str, Any] = {} + self._backend_configs: dict[str, dict[str, Any]] = {} + + def register_backend(self, backend: HistoryPersistenceBackend) -> None: + if backend.id in self._backends: + raise ValueError(f"History backend '{backend.id}' is already registered.") + self._backends[backend.id] = backend + backend.configure(copy.deepcopy(self._backend_configs.get(backend.id, {}))) + + def reconfigure( + self, + history_config: dict[str, Any], + backend_configs: dict[str, dict[str, Any]], + ) -> None: + self._history_config = self._coerce_dict(history_config) + self._backend_configs = self._coerce_nested_dict(backend_configs) + for backend in self._backends.values(): + backend.configure(copy.deepcopy(self._backend_configs.get(backend.id, {}))) + + @staticmethod + def _coerce_dict(value: Any) -> dict[str, Any]: + return dict(value) if isinstance(value, dict) else {} + + @classmethod + def _coerce_nested_dict(cls, value: Any) -> dict[str, dict[str, Any]]: + if not isinstance(value, dict): + return {} + result = {} + for key, nested in value.items(): + result[key] = dict(nested) if isinstance(nested, dict) else {} + return result + + @property + def mode(self) -> str: + return self._history_config.get("mode", "local") + + @property + def backend_id(self) -> str: + return self._history_config.get("backend", "") + + @property + def backend_configs(self) -> dict[str, dict[str, Any]]: + return copy.deepcopy(self._backend_configs) + + @property + def active_backend(self) -> HistoryPersistenceBackend | None: + if self.mode != "persistent": + return None + return self._backends.get(self.backend_id) + + def available_backends(self) -> list[dict[str, Any]]: + return [backend.to_wire() for backend in self._backends.values()] + + async def test_connection(self) -> tuple[bool, str]: + if self.mode != "persistent": + return False, "History persistence is not enabled." + backend = self.active_backend + if backend is None: + return False, f"Unknown history backend '{self.backend_id}'." + return await backend.test_connection() + + def create_conversation_with_message( + self, + conv_id: str, + user_id: str, + chat_id: str, + chat_mode: str, + msg_id: str, + role: str, + content: str, + ) -> None: + backend = self.active_backend + if backend is not None: + backend.create_conversation_with_message( + conv_id, user_id, chat_id, chat_mode, msg_id, role, content + ) + + def create_conversation( + self, conv_id: str, user_id: str, chat_id: str, chat_mode: str + ) -> None: + backend = self.active_backend + if backend is not None: + backend.create_conversation(conv_id, user_id, chat_id, chat_mode) + + def add_message( + self, + message_id: str, + conv_id: str, + role: str, + content: str, + reasoning_content: str | None = None, + tool_calls: list[dict] | None = None, + ui_parts: list[dict] | None = None, + tool_call_id: str | None = None, + ) -> None: + backend = self.active_backend + if backend is not None: + backend.add_message( + message_id, + conv_id, + role, + content, + reasoning_content=reasoning_content, + tool_calls=tool_calls, + ui_parts=ui_parts, + tool_call_id=tool_call_id, + ) + + def log_tool_execution( + self, + tool_call_id: str, + conv_id: str, + tool_name: str, + arguments: dict, + output: str, + ) -> None: + backend = self.active_backend + if backend is not None: + backend.log_tool_execution( + tool_call_id, conv_id, tool_name, arguments, output + ) + + async def get_messages_by_chat_id( + self, chat_id: str, user_id: str | None = None + ) -> list[dict[str, Any]]: + backend = self.active_backend + if backend is None: + return [] + return await backend.get_messages_by_chat_id(chat_id, user_id=user_id) + + async def get_recent_conversations( + self, user_id: str, limit: int = 20 + ) -> list[dict[str, Any]]: + backend = self.active_backend + if backend is None: + return [] + return await backend.get_recent_conversations(user_id, limit=limit) diff --git a/notebook_intelligence/history_backends/mysql.py b/notebook_intelligence/history_backends/mysql.py new file mode 100644 index 0000000..7108ea7 --- /dev/null +++ b/notebook_intelligence/history_backends/mysql.py @@ -0,0 +1,420 @@ +# Copyright (c) Mehmet Bektas + +import asyncio +import json +import logging +from typing import Any + +from notebook_intelligence.history_backends.base import ( + HistoryBackendField, + HistoryPersistenceBackend, +) + +try: + import aiomysql + + HAS_AIOMYSQL = True +except ImportError: + aiomysql = None + HAS_AIOMYSQL = False + +log = logging.getLogger(__name__) + + +class MySQLHistoryBackend(HistoryPersistenceBackend): + def __init__(self): + self.pool = None + self.loop = None + self._lock_obj = None + self.config: dict[str, Any] = {} + self.host = "localhost" + self.port = 3306 + self.user = "" + self.password = "" + self.database = "notebook_intelligence" + + @property + def id(self) -> str: + return "mysql" + + @property + def name(self) -> str: + return "MySQL" + + @property + def description(self) -> str: + return "Persist chat history to a remote MySQL database." + + @property + def fields(self) -> list[HistoryBackendField]: + return [ + HistoryBackendField("host", "Host", placeholder="localhost"), + HistoryBackendField("port", "Port", input_type="number", placeholder="3306"), + HistoryBackendField("user", "User", placeholder="root"), + HistoryBackendField("password", "Password", input_type="password"), + HistoryBackendField( + "database", "Database", placeholder="notebook_intelligence" + ), + ] + + def configure(self, config: dict[str, Any]) -> None: + new_config = config or {} + if self.config == new_config: + return + + if self.pool is not None: + try: + self.pool.close() + except Exception: + pass + self.pool = None + + self.loop = None + self._lock_obj = None + self.config = dict(new_config) + self.host = self.config.get("host", "localhost") + self.port = int(self.config.get("port", 3306)) + self.user = self.config.get("user", "") + self.password = self.config.get("password", "") + self.database = self.config.get("database", "notebook_intelligence") + + def _get_lock(self): + if self._lock_obj is None: + self._lock_obj = asyncio.Lock() + return self._lock_obj + + async def _get_pool(self): + if not HAS_AIOMYSQL: + return None + + if self.pool is not None: + return self.pool + + if self.loop is None: + try: + self.loop = asyncio.get_running_loop() + except RuntimeError: + return None + + async with self._get_lock(): + if self.pool is not None: + return self.pool + + try: + temp_conn = await aiomysql.connect( + host=self.host, + port=self.port, + user=self.user, + password=self.password, + connect_timeout=3, + autocommit=True, + ) + async with temp_conn.cursor() as cur: + await cur.execute( + f"CREATE DATABASE IF NOT EXISTS {self.database} CHARACTER SET utf8mb4;" + ) + temp_conn.close() + + self.pool = await aiomysql.create_pool( + host=self.host, + port=self.port, + user=self.user, + password=self.password, + db=self.database, + connect_timeout=3, + autocommit=True, + charset="utf8mb4", + ) + await self._ensure_tables() + return self.pool + except Exception as e: + log.error("Failed to connect to MySQL history backend: %s", e) + return None + + async def _ensure_tables(self): + async with self.pool.acquire() as conn: + async with conn.cursor() as cur: + await cur.execute( + """ + CREATE TABLE IF NOT EXISTS nbi_conversations ( + id_pk INT AUTO_INCREMENT PRIMARY KEY, + id CHAR(36) UNIQUE, + user_id VARCHAR(255), + chat_id VARCHAR(255), + chat_mode VARCHAR(50), + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + """ + ) + await cur.execute( + """ + CREATE TABLE IF NOT EXISTS nbi_messages ( + id_pk INT AUTO_INCREMENT PRIMARY KEY, + id CHAR(36) UNIQUE, + conversation_id CHAR(36), + role VARCHAR(50), + content LONGTEXT, + reasoning_content LONGTEXT, + tool_calls JSON, + ui_parts JSON, + tool_call_id VARCHAR(255), + message_order_at DATETIME(6), + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (conversation_id) REFERENCES nbi_conversations(id) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + """ + ) + try: + await cur.execute( + """ + CREATE INDEX idx_nbi_messages_order_at + ON nbi_messages(message_order_at) + """ + ) + except Exception as e: + if "Duplicate key name" not in str(e): + raise + await cur.execute( + """ + CREATE TABLE IF NOT EXISTS nbi_tool_executions ( + id_pk INT AUTO_INCREMENT PRIMARY KEY, + id VARCHAR(255) UNIQUE, + conversation_id CHAR(36), + tool_name VARCHAR(255), + arguments JSON, + output LONGTEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (conversation_id) REFERENCES nbi_conversations(id) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; + """ + ) + + async def test_connection(self) -> tuple[bool, str]: + if not HAS_AIOMYSQL: + return False, "aiomysql is not installed." + pool = await self._get_pool() + if not pool: + return False, f"Unable to connect to MySQL server {self.host}:{self.port}." + return True, "" + + def _run_task(self, coro): + try: + current_loop = asyncio.get_running_loop() + except RuntimeError: + return + + if self.loop and current_loop != self.loop: + asyncio.run_coroutine_threadsafe(coro, self.loop) + else: + if self.loop is None: + self.loop = current_loop + asyncio.create_task(coro) + + def create_conversation_with_message( + self, + conv_id: str, + user_id: str, + chat_id: str, + chat_mode: str, + msg_id: str, + role: str, + content: str, + ) -> None: + self._run_task( + self._create_conversation_with_message_internal( + conv_id, user_id, chat_id, chat_mode, msg_id, role, content + ) + ) + + async def _create_conversation_with_message_internal( + self, + conv_id: str, + user_id: str, + chat_id: str, + chat_mode: str, + msg_id: str, + role: str, + content: str, + ): + await self._create_conversation_internal(conv_id, user_id, chat_id, chat_mode) + await self._add_message_internal(msg_id, conv_id, role, content) + + def create_conversation( + self, conv_id: str, user_id: str, chat_id: str, chat_mode: str + ) -> None: + self._run_task( + self._create_conversation_internal(conv_id, user_id, chat_id, chat_mode) + ) + + async def _create_conversation_internal( + self, conv_id: str, user_id: str, chat_id: str, chat_mode: str + ): + pool = await self._get_pool() + if not pool: + return + try: + async with pool.acquire() as conn: + async with conn.cursor() as cur: + await cur.execute( + "INSERT IGNORE INTO nbi_conversations (id, user_id, chat_id, chat_mode) VALUES (%s, %s, %s, %s)", + (conv_id, user_id, chat_id, chat_mode), + ) + except Exception as e: + log.error("Error creating conversation in MySQL history backend: %s", e) + + def add_message( + self, + message_id: str, + conv_id: str, + role: str, + content: str, + reasoning_content: str | None = None, + tool_calls: list[dict] | None = None, + ui_parts: list[dict] | None = None, + tool_call_id: str | None = None, + ) -> None: + if ( + not content + and not reasoning_content + and not tool_calls + and not ui_parts + and not tool_call_id + ): + return + self._run_task( + self._add_message_internal( + message_id, + conv_id, + role, + content, + reasoning_content, + tool_calls, + ui_parts, + tool_call_id, + ) + ) + + async def _add_message_internal( + self, + message_id: str, + conv_id: str, + role: str, + content: str, + reasoning_content: str | None = None, + tool_calls: list[dict] | None = None, + ui_parts: list[dict] | None = None, + tool_call_id: str | None = None, + ): + pool = await self._get_pool() + if not pool: + return + try: + tool_calls_json = json.dumps(tool_calls) if tool_calls else None + ui_parts_json = json.dumps(ui_parts) if ui_parts else None + async with pool.acquire() as conn: + async with conn.cursor() as cur: + await cur.execute( + """ + INSERT IGNORE INTO nbi_messages + (id, conversation_id, role, content, reasoning_content, tool_calls, ui_parts, tool_call_id, message_order_at) + VALUES (%s, %s, %s, %s, %s, %s, %s, %s, UTC_TIMESTAMP(6)) + """, + ( + message_id, + conv_id, + role, + content, + reasoning_content, + tool_calls_json, + ui_parts_json, + tool_call_id, + ), + ) + except Exception as e: + log.error("Error adding message to MySQL history backend: %s", e) + + def log_tool_execution( + self, tool_call_id: str, conv_id: str, tool_name: str, arguments: dict, output: str + ) -> None: + self._run_task( + self._log_tool_execution_internal( + tool_call_id, conv_id, tool_name, arguments, output + ) + ) + + async def _log_tool_execution_internal( + self, tool_call_id: str, conv_id: str, tool_name: str, arguments: dict, output: str + ): + pool = await self._get_pool() + if not pool: + return + try: + arguments_json = json.dumps(arguments) + async with pool.acquire() as conn: + async with conn.cursor() as cur: + await cur.execute( + """ + INSERT INTO nbi_tool_executions + (id, conversation_id, tool_name, arguments, output) + VALUES (%s, %s, %s, %s, %s) + ON DUPLICATE KEY UPDATE output = VALUES(output) + """, + (tool_call_id, conv_id, tool_name, arguments_json, output), + ) + except Exception as e: + log.error("Error logging tool execution to MySQL history backend: %s", e) + + async def get_messages_by_chat_id( + self, chat_id: str, user_id: str | None = None + ) -> list[dict[str, Any]]: + pool = await self._get_pool() + if not pool: + return [] + try: + async with pool.acquire() as conn: + async with conn.cursor(aiomysql.DictCursor) as cur: + query = """ + SELECT m.role, m.content, m.reasoning_content, m.tool_calls, m.ui_parts, m.tool_call_id, m.message_order_at, m.created_at + FROM nbi_messages m + JOIN nbi_conversations c ON m.conversation_id = c.id + WHERE c.chat_id = %s + """ + params: tuple[Any, ...] = (chat_id,) + if user_id is not None: + query += " AND c.user_id = %s" + params += (user_id,) + query += " ORDER BY m.message_order_at ASC" + await cur.execute(query, params) + return await cur.fetchall() + except Exception as e: + log.error("Error getting messages from MySQL history backend: %s", e) + return [] + + async def get_recent_conversations( + self, user_id: str, limit: int = 20 + ) -> list[dict[str, Any]]: + pool = await self._get_pool() + if not pool: + return [] + try: + async with pool.acquire() as conn: + async with conn.cursor(aiomysql.DictCursor) as cur: + await cur.execute( + """ + SELECT c.chat_id, c.chat_mode, MAX(COALESCE(m.created_at, c.created_at)) as last_message_at + FROM nbi_conversations c + LEFT JOIN nbi_messages m ON m.conversation_id = c.id + WHERE c.user_id = %s + GROUP BY c.chat_id, c.chat_mode + ORDER BY last_message_at DESC + LIMIT %s + """, + (user_id, limit), + ) + return await cur.fetchall() + except Exception as e: + log.error( + "Error getting recent conversations from MySQL history backend: %s", + e, + ) + return [] diff --git a/notebook_intelligence/history_backends/sqlite.py b/notebook_intelligence/history_backends/sqlite.py new file mode 100644 index 0000000..7603cf8 --- /dev/null +++ b/notebook_intelligence/history_backends/sqlite.py @@ -0,0 +1,340 @@ +import asyncio +import datetime as dt +import json +import logging +import os +import sqlite3 +from typing import Any + +from notebook_intelligence.history_backends.base import ( + HistoryBackendField, + HistoryPersistenceBackend, +) + +log = logging.getLogger(__name__) + + +class SQLiteHistoryBackend(HistoryPersistenceBackend): + def __init__(self): + self.loop = None + self.config: dict[str, Any] = {} + self.path = "" + + @property + def id(self) -> str: + return "sqlite" + + @property + def name(self) -> str: + return "SQLite" + + @property + def description(self) -> str: + return "Persist chat history to a local SQLite file." + + @property + def fields(self) -> list[HistoryBackendField]: + return [ + HistoryBackendField( + "path", + "Database path", + placeholder="~/.jupyter/nbi/history.sqlite3", + help_text="Absolute or home-relative path to the SQLite file.", + ) + ] + + def configure(self, config: dict[str, Any]) -> None: + self.config = dict(config or {}) + self.path = os.path.expanduser(self.config.get("path", "")).strip() + + def _ensure_db_path(self) -> tuple[bool, str]: + if not self.path: + return False, "SQLite database path is empty." + try: + os.makedirs(os.path.dirname(self.path), exist_ok=True) + except Exception as e: + return False, f"Failed to create SQLite directory: {e}" + return True, "" + + def _connect(self) -> sqlite3.Connection: + conn = sqlite3.connect(self.path, timeout=5, check_same_thread=False) + conn.row_factory = sqlite3.Row + conn.execute("PRAGMA foreign_keys=ON") + return conn + + def _ensure_tables_sync(self) -> None: + ok, err = self._ensure_db_path() + if not ok: + raise RuntimeError(err) + with self._connect() as conn: + conn.execute( + """ + CREATE TABLE IF NOT EXISTS nbi_conversations ( + id_pk INTEGER PRIMARY KEY AUTOINCREMENT, + id TEXT UNIQUE, + user_id TEXT, + chat_id TEXT, + chat_mode TEXT, + created_at TEXT DEFAULT CURRENT_TIMESTAMP + ) + """ + ) + conn.execute( + """ + CREATE TABLE IF NOT EXISTS nbi_messages ( + id_pk INTEGER PRIMARY KEY AUTOINCREMENT, + id TEXT UNIQUE, + conversation_id TEXT, + role TEXT, + content TEXT, + reasoning_content TEXT, + tool_calls TEXT, + ui_parts TEXT, + tool_call_id TEXT, + message_order_at TEXT, + created_at TEXT DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (conversation_id) REFERENCES nbi_conversations(id) + ) + """ + ) + conn.execute( + """ + CREATE INDEX IF NOT EXISTS idx_nbi_messages_order_at + ON nbi_messages(message_order_at) + """ + ) + conn.execute( + """ + CREATE TABLE IF NOT EXISTS nbi_tool_executions ( + id_pk INTEGER PRIMARY KEY AUTOINCREMENT, + id TEXT UNIQUE, + conversation_id TEXT, + tool_name TEXT, + arguments TEXT, + output TEXT, + created_at TEXT DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (conversation_id) REFERENCES nbi_conversations(id) + ) + """ + ) + conn.commit() + + async def test_connection(self) -> tuple[bool, str]: + try: + self._ensure_tables_sync() + except Exception as e: + return False, str(e) + return True, "" + + def _run_task(self, coro): + try: + current_loop = asyncio.get_running_loop() + except RuntimeError: + return + + if self.loop and current_loop != self.loop: + asyncio.run_coroutine_threadsafe(coro, self.loop) + else: + if self.loop is None: + self.loop = current_loop + asyncio.create_task(coro) + + def create_conversation_with_message( + self, + conv_id: str, + user_id: str, + chat_id: str, + chat_mode: str, + msg_id: str, + role: str, + content: str, + ) -> None: + self._run_task( + self._create_conversation_with_message_internal( + conv_id, user_id, chat_id, chat_mode, msg_id, role, content + ) + ) + + async def _create_conversation_with_message_internal( + self, + conv_id: str, + user_id: str, + chat_id: str, + chat_mode: str, + msg_id: str, + role: str, + content: str, + ) -> None: + await self._create_conversation_internal(conv_id, user_id, chat_id, chat_mode) + await self._add_message_internal(msg_id, conv_id, role, content) + + def create_conversation( + self, conv_id: str, user_id: str, chat_id: str, chat_mode: str + ) -> None: + self._run_task( + self._create_conversation_internal(conv_id, user_id, chat_id, chat_mode) + ) + + async def _create_conversation_internal( + self, conv_id: str, user_id: str, chat_id: str, chat_mode: str + ) -> None: + try: + self._ensure_tables_sync() + with self._connect() as conn: + conn.execute( + """ + INSERT OR IGNORE INTO nbi_conversations (id, user_id, chat_id, chat_mode) + VALUES (?, ?, ?, ?) + """, + (conv_id, user_id, chat_id, chat_mode), + ) + conn.commit() + except Exception as e: + log.error("Error creating conversation in SQLite history backend: %s", e) + + def add_message( + self, + message_id: str, + conv_id: str, + role: str, + content: str, + reasoning_content: str | None = None, + tool_calls: list[dict] | None = None, + ui_parts: list[dict] | None = None, + tool_call_id: str | None = None, + ) -> None: + if ( + not content + and not reasoning_content + and not tool_calls + and not ui_parts + and not tool_call_id + ): + return + self._run_task( + self._add_message_internal( + message_id, + conv_id, + role, + content, + reasoning_content, + tool_calls, + ui_parts, + tool_call_id, + ) + ) + + async def _add_message_internal( + self, + message_id: str, + conv_id: str, + role: str, + content: str, + reasoning_content: str | None = None, + tool_calls: list[dict] | None = None, + ui_parts: list[dict] | None = None, + tool_call_id: str | None = None, + ) -> None: + try: + self._ensure_tables_sync() + tool_calls_json = json.dumps(tool_calls) if tool_calls else None + ui_parts_json = json.dumps(ui_parts) if ui_parts else None + with self._connect() as conn: + conn.execute( + """ + INSERT OR IGNORE INTO nbi_messages + (id, conversation_id, role, content, reasoning_content, tool_calls, ui_parts, tool_call_id, message_order_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + message_id, + conv_id, + role, + content, + reasoning_content, + tool_calls_json, + ui_parts_json, + tool_call_id, + dt.datetime.now(dt.timezone.utc).isoformat(timespec="microseconds"), + ), + ) + conn.commit() + except Exception as e: + log.error("Error adding message to SQLite history backend: %s", e) + + def log_tool_execution( + self, tool_call_id: str, conv_id: str, tool_name: str, arguments: dict, output: str + ) -> None: + self._run_task( + self._log_tool_execution_internal( + tool_call_id, conv_id, tool_name, arguments, output + ) + ) + + async def _log_tool_execution_internal( + self, tool_call_id: str, conv_id: str, tool_name: str, arguments: dict, output: str + ) -> None: + try: + self._ensure_tables_sync() + with self._connect() as conn: + conn.execute( + """ + INSERT INTO nbi_tool_executions (id, conversation_id, tool_name, arguments, output) + VALUES (?, ?, ?, ?, ?) + ON CONFLICT(id) DO UPDATE SET output = excluded.output + """, + (tool_call_id, conv_id, tool_name, json.dumps(arguments), output), + ) + conn.commit() + except Exception as e: + log.error("Error logging tool execution to SQLite history backend: %s", e) + + async def get_messages_by_chat_id( + self, chat_id: str, user_id: str | None = None + ) -> list[dict[str, Any]]: + try: + self._ensure_tables_sync() + with self._connect() as conn: + query = """ + SELECT m.role, m.content, m.reasoning_content, m.tool_calls, m.ui_parts, m.tool_call_id, m.message_order_at, m.created_at + FROM nbi_messages m + JOIN nbi_conversations c ON m.conversation_id = c.id + WHERE c.chat_id = ? + """ + params: tuple[Any, ...] = (chat_id,) + if user_id is not None: + query += " AND c.user_id = ?" + params += (user_id,) + query += " ORDER BY m.message_order_at ASC" + rows = conn.execute(query, params).fetchall() + return [dict(row) for row in rows] + except Exception as e: + log.error("Error getting messages from SQLite history backend: %s", e) + return [] + + async def get_recent_conversations( + self, user_id: str, limit: int = 20 + ) -> list[dict[str, Any]]: + try: + self._ensure_tables_sync() + with self._connect() as conn: + rows = conn.execute( + """ + SELECT c.chat_id, c.chat_mode, + MAX(COALESCE(m.created_at, c.created_at)) as last_message_at + FROM nbi_conversations c + LEFT JOIN nbi_messages m ON m.conversation_id = c.id + WHERE c.user_id = ? + GROUP BY c.chat_id, c.chat_mode + ORDER BY last_message_at DESC + LIMIT ? + """, + (user_id, limit), + ).fetchall() + return [dict(row) for row in rows] + except Exception as e: + log.error( + "Error getting recent conversations from SQLite history backend: %s", + e, + ) + return [] diff --git a/notebook_intelligence/llm_providers/litellm_compatible_llm_provider.py b/notebook_intelligence/llm_providers/litellm_compatible_llm_provider.py index 81d595b..c3295b5 100644 --- a/notebook_intelligence/llm_providers/litellm_compatible_llm_provider.py +++ b/notebook_intelligence/llm_providers/litellm_compatible_llm_provider.py @@ -3,6 +3,7 @@ import json from typing import Any from notebook_intelligence.api import ChatModel, EmbeddingModel, InlineCompletionModel, LLMProvider, CancelToken, ChatResponse, CompletionContext, LLMProviderProperty +from notebook_intelligence.message_sanitizer import sanitize_chat_history_tool_calls import litellm DEFAULT_CONTEXT_WINDOW = 4096 @@ -42,9 +43,10 @@ def completions(self, messages: list[dict], tools: list[dict] = None, response: base_url = self.get_property("base_url").value api_key_prop = self.get_property("api_key") api_key = api_key_prop.value if api_key_prop is not None else None + sanitized_messages = sanitize_chat_history_tool_calls(messages) litellm_resp = litellm.completion( model=model_id, - messages=messages.copy(), + messages=sanitized_messages, tools=tools, tool_choice=options.get("tool_choice", None), api_base=base_url, @@ -60,12 +62,16 @@ def completions(self, messages: list[dict], tools: list[dict] = None, response: reasoning = getattr(delta, 'reasoning_content', None) or getattr(delta, 'reasoning', None) if reasoning is not None: reasoning = str(reasoning) + tool_calls = None + if hasattr(delta, 'tool_calls') and delta.tool_calls: + tool_calls = [json.loads(tc.model_dump_json()) for tc in delta.tool_calls] response.stream({ "choices": [{ "delta": { "role": delta.role, "content": delta.content, - "reasoning_content": reasoning + "reasoning_content": reasoning, + "tool_calls": tool_calls } }] }) diff --git a/notebook_intelligence/llm_providers/ollama_llm_provider.py b/notebook_intelligence/llm_providers/ollama_llm_provider.py index d4fa6f0..8b2f405 100644 --- a/notebook_intelligence/llm_providers/ollama_llm_provider.py +++ b/notebook_intelligence/llm_providers/ollama_llm_provider.py @@ -3,6 +3,7 @@ import json from typing import Any from notebook_intelligence.api import ChatModel, EmbeddingModel, InlineCompletionModel, LLMProvider, CancelToken, ChatResponse, CompletionContext +from notebook_intelligence.message_sanitizer import sanitize_chat_history_tool_calls import ollama import logging @@ -38,9 +39,10 @@ def context_window(self) -> int: def completions(self, messages: list[dict], tools: list[dict] = None, response: ChatResponse = None, cancel_token: CancelToken = None, options: dict = {}) -> Any: stream = response is not None + sanitized_messages = sanitize_chat_history_tool_calls(messages) completion_args = { "model": self._model_id, - "messages": messages.copy(), + "messages": sanitized_messages, "stream": stream, } if tools is not None and len(tools) > 0: @@ -59,7 +61,8 @@ def completions(self, messages: list[dict], tools: list[dict] = None, response: "delta": { "role": delta['role'], "content": delta['content'], - "reasoning_content": reasoning + "reasoning_content": reasoning, + "tool_calls": delta.get('tool_calls') } }] }) diff --git a/notebook_intelligence/llm_providers/openai_compatible_llm_provider.py b/notebook_intelligence/llm_providers/openai_compatible_llm_provider.py index 750de47..5bc61a5 100644 --- a/notebook_intelligence/llm_providers/openai_compatible_llm_provider.py +++ b/notebook_intelligence/llm_providers/openai_compatible_llm_provider.py @@ -5,6 +5,7 @@ import re from typing import Any from notebook_intelligence.api import ChatModel, EmbeddingModel, InlineCompletionModel, LLMProvider, CancelToken, ChatResponse, CompletionContext, LLMProviderProperty +from notebook_intelligence.message_sanitizer import sanitize_chat_history_tool_calls from openai import OpenAI, omit INLINE_COMPLETION_SYSTEM_PROMPT = """You are a code completion assistant. Your task is to generate intelligent autocomplete suggestions for the code at the cursor position for given language and active file type. This is not an interactive session, don't ask for clarifying questions, always generate a suggestion. Don't include any explanations for your response, just generate the code. Don't return any thinking or reasoning, just generate the code. You are given a code snippet with a prefix and a suffix. You need to generate a suggestion for the code that fits best in place of . You should return only the code that fits best in place of . You should provide multiline code if needed. Enclose the code in triple backticks, just return the code in language. You should not return any other text, just the code. DO NOT INCLUDE THE PREFIX OR SUFFIX IN THE RESPONSE. .ipynb files are Jupyter notebook files and for notebook files, you generate suggestions for a cell within the notebook. A cell can be a code cell with code or a markdown cell with markdown text. If the language is markdown, only return markdown text. If you need to install a Python package within a notebook cell code (for .ipynb files), use %pip install instead of !pip install . Follow the tags very carefully for proper spacing and indentations.""" @@ -54,6 +55,10 @@ def context_window(self) -> int: except: return DEFAULT_CONTEXT_WINDOW + @property + def supports_tools(self) -> bool: + return True + def completions(self, messages: list[dict], tools: list[dict] = None, response: ChatResponse = None, cancel_token: CancelToken = None, options: dict = {}) -> Any: stream = response is not None model_id = self.get_property("model_id").value @@ -65,7 +70,7 @@ def completions(self, messages: list[dict], tools: list[dict] = None, response: client = OpenAI(base_url=base_url, api_key=api_key) resp = client.chat.completions.create( model=model_id, - messages=messages.copy(), + messages=sanitize_chat_history_tool_calls(messages), tools=sanitize_tools_for_openai_compatible(tools) or omit, tool_choice=options.get("tool_choice", omit), stream=stream, @@ -79,12 +84,18 @@ def completions(self, messages: list[dict], tools: list[dict] = None, response: reasoning = getattr(delta, 'reasoning_content', None) or getattr(delta, 'reasoning', None) if reasoning is not None: reasoning = str(reasoning) + + tool_calls = None + if hasattr(delta, 'tool_calls') and delta.tool_calls: + tool_calls = [json.loads(tc.model_dump_json()) for tc in delta.tool_calls] + response.stream({ "choices": [{ "delta": { "role": delta.role, "content": delta.content, - "reasoning_content": reasoning + "reasoning_content": reasoning, + "tool_calls": tool_calls } }] }) diff --git a/notebook_intelligence/message_sanitizer.py b/notebook_intelligence/message_sanitizer.py new file mode 100644 index 0000000..4ea1995 --- /dev/null +++ b/notebook_intelligence/message_sanitizer.py @@ -0,0 +1,119 @@ +# Copyright (c) Mehmet Bektas + +import copy +import logging + + +log = logging.getLogger(__name__) + + +def _sanitize_tool_calls(message: dict) -> None: + tool_calls = message.get("tool_calls") + if tool_calls is None: + return + + if not isinstance(tool_calls, list): + message.pop("tool_calls", None) + return + + valid_tool_calls = [] + for tool_call in tool_calls: + if not isinstance(tool_call, dict): + continue + + tool_call_type = tool_call.get("type") + if tool_call_type == "function": + function_payload = tool_call.get("function") + if tool_call.get("id") and isinstance(function_payload, dict): + valid_tool_calls.append(tool_call) + elif tool_call_type == "custom": + custom_payload = tool_call.get("custom") + if tool_call.get("id") and custom_payload is not None: + valid_tool_calls.append(tool_call) + + if valid_tool_calls: + message["tool_calls"] = valid_tool_calls + else: + message.pop("tool_calls", None) + + +def _is_empty_assistant_message(message: dict) -> bool: + return ( + message.get("role") == "assistant" + and message.get("content") is None + and not message.get("tool_calls") + ) + + +def sanitize_chat_history_tool_calls(messages: list[dict] | None) -> list[dict]: + """Drop UI-only tool-call metadata before replaying chat history. + + History persistence stores richer replay-only transcript chunks in + ``ui_parts``. Those UI-only shapes must be stripped before the prior + assistant turn is sent back to the model on a later request. We also + defensively validate ``tool_calls`` so malformed history rows do not + reach providers. + """ + if messages is None: + return [] + + prepared_messages: list[dict] = [] + for message in copy.deepcopy(messages): + if not isinstance(message, dict): + continue + message.pop("ui_parts", None) + _sanitize_tool_calls(message) + prepared_messages.append(message) + + sanitized_messages: list[dict] = [] + index = 0 + while index < len(prepared_messages): + message = prepared_messages[index] + role = message.get("role") + + if role == "assistant" and message.get("tool_calls"): + tool_calls = message["tool_calls"] + tool_call_ids = { + tool_call.get("id") + for tool_call in tool_calls + if isinstance(tool_call, dict) and tool_call.get("id") + } + matched_tool_messages = [] + matched_tool_call_ids = set() + next_index = index + 1 + + while next_index < len(prepared_messages): + tool_message = prepared_messages[next_index] + if tool_message.get("role") != "tool": + break + + tool_call_id = tool_message.get("tool_call_id") + if tool_call_id in tool_call_ids: + matched_tool_messages.append(tool_message) + matched_tool_call_ids.add(tool_call_id) + next_index += 1 + + if matched_tool_call_ids: + message["tool_calls"] = [ + tool_call + for tool_call in tool_calls + if tool_call.get("id") in matched_tool_call_ids + ] + else: + message.pop("tool_calls", None) + + if not _is_empty_assistant_message(message): + sanitized_messages.append(message) + sanitized_messages.extend(matched_tool_messages) + index = next_index + continue + + if role == "tool": + index += 1 + continue + + if not _is_empty_assistant_message(message): + sanitized_messages.append(message) + index += 1 + + return sanitized_messages diff --git a/pyproject.toml b/pyproject.toml index 06b0860..4e31710 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,6 +72,7 @@ dependencies = [ # ``set_key`` / ``unset_key`` write paths only, which neither NBI # nor any of its other deps use. Re-add a ``>=1.2.2`` lower bound # when litellm loosens its pin. + "aiomysql" ] dynamic = ["version", "description", "authors", "urls", "keywords"] diff --git a/src/api.ts b/src/api.ts index d20e348..0f983c9 100644 --- a/src/api.ts +++ b/src/api.ts @@ -103,6 +103,21 @@ export interface IPluginMarketplacePluginInfo extends IPluginInfo { tags?: string[]; } +export interface IHistoryBackendField { + key: string; + label: string; + input_type: 'text' | 'number' | 'password' | string; + placeholder?: string; + help_text?: string; +} + +export interface IHistoryBackendInfo { + id: string; + name: string; + description?: string; + fields: IHistoryBackendField[]; +} + export interface IClaudeMCPServer { name: string; scope: ClaudeMCPScope; @@ -392,6 +407,33 @@ export class NBIConfig { return this.capabilities.spinner_verbs ?? null; } + get historyConfig(): any { + return ( + this.capabilities.history_config ?? { + mode: 'local', + backend: 'sqlite', + local_max_messages: 10 + } + ); + } + + get historyBackends(): IHistoryBackendInfo[] { + return this.capabilities.history_backends ?? []; + } + + get historyBackendConfigs(): Record> { + return this.capabilities.history_backend_configs ?? {}; + } + + get currentUserId(): string { + const value = this.capabilities.current_user_id; + return typeof value === 'string' && value.trim() ? value.trim() : ''; + } + + get currentHistoryStorageScope(): string { + return this.userConfigDir || this.userHomeDir || this.currentUserId || ''; + } + get claudeModels(): IClaudeModelInfo[] { return (this.capabilities.claude_models ?? []).map(claudeModelFromWire); } @@ -763,16 +805,17 @@ export class NBIAPI { }); } - static async setConfig(config: any) { - requestAPI('config', { + static async setConfig(config: any): Promise { + return requestAPI('config', { method: 'POST', body: JSON.stringify(config) }) .then(data => { - NBIAPI.fetchCapabilities(); + return NBIAPI.fetchCapabilities().then(() => data); }) .catch(reason => { console.error(`Failed to set NBI config.\n${reason}`); + throw reason; }); } @@ -1429,4 +1472,30 @@ export class NBIAPI { }); }); } + + static async fetchChatHistory(chatId: string): Promise { + return new Promise((resolve, reject) => { + requestAPI(`history?chatId=${chatId}`, { method: 'GET' }) + .then(data => { + resolve(data.messages); + }) + .catch(reason => { + console.error(`Failed to fetch chat history.\n${reason}`); + reject(reason); + }); + }); + } + + static async fetchRecentConversations(): Promise { + return new Promise((resolve, reject) => { + requestAPI('conversations', { method: 'GET' }) + .then(data => { + resolve(data.conversations); + }) + .catch(reason => { + console.error(`Failed to fetch recent conversations.\n${reason}`); + reject(reason); + }); + }); + } } diff --git a/src/chat-history-replay.ts b/src/chat-history-replay.ts new file mode 100644 index 0000000..4c23635 --- /dev/null +++ b/src/chat-history-replay.ts @@ -0,0 +1,172 @@ +import { UUID } from '@lumino/coreutils'; + +import { IChatParticipant, ResponseStreamDataType } from './tokens'; + +export interface IHistoryWireMessage { + role?: string; + content?: string; + reasoning_content?: string; + tool_calls?: any[]; + ui_parts?: any[]; + participant_id?: string; + created_at?: string; +} + +function parseHistoryTimestamp(raw?: string): Date { + if (!raw) { + return new Date(); + } + + const trimmed = raw.trim(); + if (!trimmed) { + return new Date(); + } + + // Persistent backends may return naive UTC strings like + // "2026-06-11 14:53:58". Interpret those as UTC so refresh replay shows + // the same local clock time as the live websocket event stream. + const normalized = + /(?:Z|[+-]\d{2}:\d{2})$/.test(trimmed) || trimmed.includes('T') + ? trimmed + : `${trimmed.replace(' ', 'T')}Z`; + + const parsed = new Date(normalized); + return Number.isNaN(parsed.getTime()) ? new Date() : parsed; +} + +export interface IReplayChatMessageContent { + id: string; + type: ResponseStreamDataType; + content: any; + contentDetail?: any; + created: Date; + reasoningContent?: string; + reasoningFinished?: boolean; +} + +export interface IReplayChatMessage { + id: string; + date: Date; + from: string; + contents: IReplayChatMessageContent[]; + participant?: IChatParticipant; +} + +function hasSerializedUiMessageParts(msg: IHistoryWireMessage): boolean { + return Array.isArray(msg.ui_parts) && msg.ui_parts.length > 0; +} + +function normalizeHistoryForReplay( + history: IHistoryWireMessage[] +): IHistoryWireMessage[] { + const normalized: IHistoryWireMessage[] = []; + let turn: IHistoryWireMessage[] = []; + + const flushTurn = () => { + if (turn.length === 0) { + return; + } + + const assistantUiMessages = turn.filter( + msg => + (msg.role ?? 'assistant') === 'assistant' && + hasSerializedUiMessageParts(msg) + ); + if (assistantUiMessages.length > 0) { + const userMessage = turn.find( + msg => (msg.role ?? 'assistant') === 'user' + ); + if (userMessage) { + normalized.push(userMessage); + } + normalized.push(assistantUiMessages[assistantUiMessages.length - 1]); + } else { + normalized.push(...turn); + } + + turn = []; + }; + + for (const msg of history) { + const role = msg.role ?? 'assistant'; + if (role === 'user' && turn.length > 0) { + flushTurn(); + } + turn.push(msg); + } + + flushTurn(); + return normalized; +} + +export function historyMessagesToChatMessages( + history: IHistoryWireMessage[], + participants: IChatParticipant[] +): IReplayChatMessage[] { + const formattedMessages: IReplayChatMessage[] = []; + + for (const msg of normalizeHistoryForReplay(history)) { + const role = msg.role ?? 'assistant'; + if (role === 'tool') { + continue; + } + + const date = parseHistoryTimestamp(msg.created_at); + const hasMessageContent = + !!msg.content || + !!msg.reasoning_content || + (Array.isArray(msg.ui_parts) && msg.ui_parts.length > 0); + + if (!hasMessageContent && role === 'assistant') { + continue; + } + + const serializedParts = Array.isArray(msg.ui_parts) ? msg.ui_parts : []; + + let contents: IReplayChatMessageContent[] = []; + if (serializedParts.length > 0 && role === 'assistant') { + contents = serializedParts.map((part: any) => ({ + id: UUID.uuid4(), + type: ResponseStreamDataType.Markdown, + content: part?.content || '', + reasoningContent: part?.reasoning_content || '', + reasoningFinished: !!part?.reasoning_content, + contentDetail: part?.detail, + created: date + })); + } else { + contents = [ + { + id: UUID.uuid4(), + type: ResponseStreamDataType.Markdown, + content: msg.content || '', + reasoningContent: msg.reasoning_content, + reasoningFinished: !!msg.reasoning_content, + created: date + } + ]; + } + + if (role === 'user') { + formattedMessages.push({ + id: UUID.uuid4(), + date, + from: 'user', + contents + }); + continue; + } + + if (role === 'assistant') { + formattedMessages.push({ + id: UUID.uuid4(), + date, + from: 'copilot', + contents, + participant: participants.find(p => p.id === msg.participant_id) + }); + } + } + + return formattedMessages; +} diff --git a/src/chat-sidebar.tsx b/src/chat-sidebar.tsx index 5b4667e..aeb0031 100644 --- a/src/chat-sidebar.tsx +++ b/src/chat-sidebar.tsx @@ -21,6 +21,11 @@ import { formatElapsedSeconds, isHeartbeatStale } from './chat-progress-feedback'; +import { historyMessagesToChatMessages } from './chat-history-replay'; +import { + buildHistorySessionScopeSignature, + shouldStartNewHistorySession +} from './history-session'; import { BackendMessageType, BuiltinToolsetType, @@ -357,6 +362,42 @@ interface IChatMessage { chatModel?: { provider: string; model: string }; } +function normalizeHistoryStorageScope( + scope: string | null | undefined +): string { + return scope && scope.trim() ? scope.trim() : 'anonymous'; +} + +function lastChatIdStorageKey(scope: string): string { + return `nbi_last_chat_id_${normalizeHistoryStorageScope(scope)}`; +} + +function chatCacheKey(scope: string, chatId: string): string { + return `nbi_chat_cache_${normalizeHistoryStorageScope(scope)}_${chatId}`; +} + +function serializeChatMessages(messages: IChatMessage[]): any[] { + return messages.map(msg => ({ + ...msg, + date: msg.date?.toISOString?.() || new Date().toISOString(), + contents: msg.contents.map(content => ({ + ...content, + created: content.created?.toISOString?.() || new Date().toISOString() + })) + })); +} + +function deserializeChatMessages(serialized: any[]): IChatMessage[] { + return (serialized || []).map((msg: any) => ({ + ...msg, + date: new Date(msg.date), + contents: (msg.contents || []).map((content: any) => ({ + ...content, + created: new Date(content.created) + })) + })); +} + interface IWorkspaceFileOption { name: string; path: string; @@ -1338,9 +1379,115 @@ function SidebarComponent(props: any) { const [promptHistory, setPromptHistory] = useState([]); // position on prompt history stack const [promptHistoryIndex, setPromptHistoryIndex] = useState(0); - const [chatId, setChatId] = useState(UUID.uuid4()); + const historyStorageScope = NBIAPI.config.currentHistoryStorageScope; + const [chatId, setChatId] = useState(() => { + const historyMode = NBIAPI.config.historyConfig?.mode ?? 'local'; + if (historyMode === 'none') { + return UUID.uuid4(); + } + const savedChatId = localStorage.getItem( + lastChatIdStorageKey(historyStorageScope) + ); + return savedChatId || UUID.uuid4(); + }); const lastMessageId = useRef(''); const lastRequestTime = useRef(new Date()); + + const historyMode = NBIAPI.config.historyConfig?.mode ?? 'local'; + const historySessionScopeSignature = buildHistorySessionScopeSignature( + NBIAPI.config.historyConfig, + NBIAPI.config.historyBackendConfigs, + historyStorageScope + ); + const prevHistorySessionScopeRef = useRef( + historySessionScopeSignature + ); + + useEffect(() => { + localStorage.removeItem('nbi_last_chat_id'); + if (historyMode === 'none') { + localStorage.removeItem(lastChatIdStorageKey(historyStorageScope)); + return; + } + localStorage.setItem(lastChatIdStorageKey(historyStorageScope), chatId); + }, [chatId, historyMode, historyStorageScope]); + + useEffect(() => { + const fetchHistory = async () => { + try { + const history = await NBIAPI.fetchChatHistory(chatId); + if (history && history.length > 0) { + const formattedMessages = historyMessagesToChatMessages( + history, + NBIAPI.config.chatParticipants + ) as IChatMessage[]; + setChatMessages(formattedMessages); + if (historyMode === 'local') { + try { + localStorage.setItem( + chatCacheKey(historyStorageScope, chatId), + JSON.stringify(serializeChatMessages(formattedMessages)) + ); + } catch (e) { + console.warn('Failed to write chat cache', e); + } + } + } else { + if (historyMode === 'local') { + const cached = localStorage.getItem( + chatCacheKey(historyStorageScope, chatId) + ); + if (cached) { + try { + const parsed = JSON.parse(cached); + setChatMessages(deserializeChatMessages(parsed)); + } catch (e) { + console.warn('Failed to parse chat cache', e); + } + } else { + setChatMessages([]); + } + } else { + setChatMessages([]); + } + } + } catch (error) { + console.error('Failed to fetch chat history:', error); + if (historyMode === 'local') { + const cached = localStorage.getItem( + chatCacheKey(historyStorageScope, chatId) + ); + if (cached) { + try { + const parsed = JSON.parse(cached); + setChatMessages(deserializeChatMessages(parsed)); + } catch (e) { + console.warn('Failed to parse chat cache after history error', e); + } + } + } else { + setChatMessages([]); + } + } + }; + + fetchHistory(); + }, [chatId, historyMode, historyStorageScope]); + + useEffect(() => { + if (historyMode === 'local') { + try { + localStorage.setItem( + chatCacheKey(historyStorageScope, chatId), + JSON.stringify(serializeChatMessages(chatMessages)) + ); + } catch (e) { + console.warn('Failed to persist chat cache', e); + } + } else { + localStorage.removeItem(chatCacheKey(historyStorageScope, chatId)); + } + }, [chatId, chatMessages, historyMode, historyStorageScope]); const [contextOn, setContextOn] = useState(false); const [activeDocumentInfo, setActiveDocumentInfo] = useState(null); @@ -1399,6 +1546,7 @@ function SidebarComponent(props: any) { useState('Tool selection'); const [selectedToolCount, setSelectedToolCount] = useState(0); const [unsafeToolSelected, setUnsafeToolSelected] = useState(false); + const [, setConfigRefreshTick] = useState(0); const [renderCount, setRenderCount] = useState(1); const toolConfigRef = useRef({ @@ -2278,6 +2426,16 @@ function SidebarComponent(props: any) { } }; + useEffect(() => { + const handler = () => { + setConfigRefreshTick(tick => tick + 1); + }; + NBIAPI.configChanged.connect(handler); + return () => { + NBIAPI.configChanged.disconnect(handler); + }; + }, []); + useEffect(() => { const handler = () => { toolConfigRef.current = NBIAPI.config.toolConfig; @@ -3810,22 +3968,71 @@ function SidebarComponent(props: any) { ); const [chatEnabled, setChatEnabled] = useState(NBIAPI.getChatEnabled()); const [skillsReloadedVisible, setSkillsReloadedVisible] = useState(false); - // Visible for a few seconds after the user starts a new chat session - // (either via the header button or `/clear`). The aria-live region - // below announces it to assistive tech. - const [newChatNoticeVisible, setNewChatNoticeVisible] = useState(false); - const newChatNoticeTimerRef = useRef | null>( + // Visible for a few seconds after the user starts a new chat session. + // Used both for explicit "new chat" actions and automatic resets when + // history-storage semantics change. + const [sessionNoticeMessage, setSessionNoticeMessage] = useState< + string | null + >(null); + const sessionNoticeTimerRef = useRef | null>( null ); useEffect(() => { return () => { - if (newChatNoticeTimerRef.current) { - clearTimeout(newChatNoticeTimerRef.current); + if (sessionNoticeTimerRef.current) { + clearTimeout(sessionNoticeTimerRef.current); } }; }, []); + const showSessionNotice = useCallback((message: string) => { + setSessionNoticeMessage(message); + if (sessionNoticeTimerRef.current) { + clearTimeout(sessionNoticeTimerRef.current); + } + sessionNoticeTimerRef.current = setTimeout(() => { + setSessionNoticeMessage(null); + sessionNoticeTimerRef.current = null; + }, 3000); + }, []); + + useEffect(() => { + const previousScopeSignature = prevHistorySessionScopeRef.current; + if ( + shouldStartNewHistorySession( + previousScopeSignature, + historySessionScopeSignature + ) + ) { + if (copilotRequestInProgress) { + NBIAPI.sendWebSocketMessage( + lastMessageId.current, + RequestDataType.CancelChatRequest, + { chatId } + ); + lastMessageId.current = ''; + setCopilotRequestInProgress(false); + } + setChatId(UUID.uuid4()); + setChatMessages([]); + setSelectedContextFiles([]); + resetPrefixSuggestions(); + setPromptHistory([]); + setPromptHistoryIndex(0); + showSessionNotice( + 'New chat session started because history storage changed.' + ); + } + prevHistorySessionScopeRef.current = historySessionScopeSignature; + }, [ + chatId, + copilotRequestInProgress, + historySessionScopeSignature, + resetPrefixSuggestions, + showSessionNotice + ]); + const startNewChatSession = useCallback(() => { // Reset every piece of per-conversation UI state and tell the server // to drop its conversation history. Functionally equivalent to typing @@ -3861,21 +4068,20 @@ function SidebarComponent(props: any) { chatId } ); - setNewChatNoticeVisible(true); - if (newChatNoticeTimerRef.current) { - clearTimeout(newChatNoticeTimerRef.current); - } - newChatNoticeTimerRef.current = setTimeout(() => { - setNewChatNoticeVisible(false); - newChatNoticeTimerRef.current = null; - }, 3000); + showSessionNotice('New chat session started.'); // Move focus to the prompt textarea so the user can immediately type // their first message in the fresh session. Defer past the React // commit so the input has re-rendered with the cleared prompt value. window.requestAnimationFrame(() => { promptInputRef.current?.focus(); }); - }, [chatId, copilotRequestInProgress, resetChatId, resetPrefixSuggestions]); + }, [ + chatId, + copilotRequestInProgress, + resetChatId, + resetPrefixSuggestions, + showSessionNotice + ]); useEffect(() => { const handler = () => { @@ -3990,8 +4196,8 @@ function SidebarComponent(props: any) { Skills reloaded — applied to the current session. )} - {newChatNoticeVisible && ( -
New chat session started.
+ {sessionNoticeMessage && ( +
{sessionNoticeMessage}
)} {/* sr-only polite region for chat-status boundary announcements. @@ -4912,7 +5118,7 @@ function InlinePromptComponent(props: any) { submitCompletionRequest( { - messageId, + messageId: UUID.uuid4(), chatId: UUID.uuid4(), type: RunChatCompletionType.GenerateCode, content: prompt, diff --git a/src/components/settings-panel.tsx b/src/components/settings-panel.tsx index ff5d8a9..eb2ccd8 100644 --- a/src/components/settings-panel.tsx +++ b/src/components/settings-panel.tsx @@ -14,6 +14,10 @@ import { IClaudeModelInfo, NBIAPI } from '../api'; +import { + buildHistorySessionScopeSignature, + shouldStartNewHistorySession +} from '../history-session'; import { CheckBoxItem } from './checkbox'; import { PillItem } from './pill'; import { mcpServerSettingsToEnabledState } from './mcp-util'; @@ -100,6 +104,28 @@ const OPENAI_COMPATIBLE_INLINE_COMPLETION_MODEL_ID = const LITELLM_COMPATIBLE_INLINE_COMPLETION_MODEL_ID = 'litellm-compatible-inline-completion-model'; +type HistorySettingsState = { + mode: string; + backend: string; + localMaxMessages: number; + backendConfigs: Record>; +}; + +const readHistorySettingsState = ( + config: typeof NBIAPI.config, + fallbackBackendId: string +): HistorySettingsState => ({ + mode: config.historyConfig?.mode ?? 'local', + backend: config.historyConfig?.backend ?? fallbackBackendId, + localMaxMessages: Number(config.historyConfig?.local_max_messages ?? 10), + backendConfigs: structuredClone(config.historyBackendConfigs ?? {}) +}); + +const historySettingsEqual = ( + left: HistorySettingsState, + right: HistorySettingsState +): boolean => JSON.stringify(left) === JSON.stringify(right); + export class SettingsPanel extends ReactWidget { constructor(options: { onSave: () => void; @@ -307,11 +333,108 @@ function SettingsPanelTabsComponent(props: { function SettingsPanelComponentGeneral(props: any) { const nbiConfig = NBIAPI.config; const llmProviders = nbiConfig.llmProviders; + const historyBackends = nbiConfig.historyBackends; + const defaultHistoryBackendId = historyBackends[0]?.id ?? 'sqlite'; + const initialHistoryState = readHistorySettingsState( + nbiConfig, + defaultHistoryBackendId + ); const [chatModels, setChatModels] = useState([]); const [inlineCompletionModels, setInlineCompletionModels] = useState([]); const isInClaudeCodeMode = nbiConfig.isInClaudeCodeMode; + const [historyMode, setHistoryMode] = useState(initialHistoryState.mode); + const [historyBackendId, setHistoryBackendId] = useState( + initialHistoryState.backend + ); + const [localMaxMessages, setLocalMaxMessages] = useState( + initialHistoryState.localMaxMessages + ); + const [historyBackendConfigs, setHistoryBackendConfigs] = useState< + Record> + >(initialHistoryState.backendConfigs); + const [activeHistoryMode, setActiveHistoryMode] = useState( + initialHistoryState.mode + ); + const [activeHistoryBackendId, setActiveHistoryBackendId] = useState( + initialHistoryState.backend + ); + const [activeLocalMaxMessages, setActiveLocalMaxMessages] = useState( + initialHistoryState.localMaxMessages + ); + const [activeHistoryBackendConfigs, setActiveHistoryBackendConfigs] = + useState>>( + initialHistoryState.backendConfigs + ); + const [historyApplyStatus, setHistoryApplyStatus] = useState< + 'idle' | 'applying' | 'success' | 'error' + >('idle'); + const [historyApplyMessage, setHistoryApplyMessage] = useState(''); + const initialHistoryScopeSignatureRef = useRef( + buildHistorySessionScopeSignature( + { + mode: initialHistoryState.mode, + backend: initialHistoryState.backend + }, + initialHistoryState.backendConfigs, + NBIAPI.config.currentHistoryStorageScope + ) + ); - const handleSaveSettings = async () => { + const syncDraftHistoryState = (historyState: HistorySettingsState) => { + setHistoryMode(historyState.mode); + setHistoryBackendId(historyState.backend); + setLocalMaxMessages(historyState.localMaxMessages); + setHistoryBackendConfigs(structuredClone(historyState.backendConfigs)); + }; + + const syncActiveHistoryState = (historyState: HistorySettingsState) => { + setActiveHistoryMode(historyState.mode); + setActiveHistoryBackendId(historyState.backend); + setActiveLocalMaxMessages(historyState.localMaxMessages); + setActiveHistoryBackendConfigs( + structuredClone(historyState.backendConfigs) + ); + }; + + const readLatestHistoryState = () => + readHistorySettingsState(NBIAPI.config, defaultHistoryBackendId); + + const activeHistoryState: HistorySettingsState = { + mode: activeHistoryMode, + backend: activeHistoryBackendId, + localMaxMessages: activeLocalMaxMessages, + backendConfigs: activeHistoryBackendConfigs + }; + const draftHistoryState: HistorySettingsState = { + mode: historyMode, + backend: historyBackendId, + localMaxMessages, + backendConfigs: historyBackendConfigs + }; + const isPersistentHistoryDraftDirty = + historyMode === 'persistent' && + !historySettingsEqual(draftHistoryState, activeHistoryState); + const immediateHistorySettingsKey = + historyMode === 'persistent' + ? 'persistent-draft' + : JSON.stringify(draftHistoryState); + const selectedHistoryBackend = + historyBackends.find((backend: any) => backend.id === historyBackendId) ?? + historyBackends[0]; + + const clearHistoryApplyFeedback = () => { + setHistoryApplyStatus('idle'); + setHistoryApplyMessage(''); + }; + + const handleSaveSettings = async (options?: { + historyState?: HistorySettingsState; + isHistoryApply?: boolean; + syncDraftHistoryOnSuccess?: boolean; + }) => { + const historyState = + options?.historyState ?? + (historyMode === 'persistent' ? activeHistoryState : draftHistoryState); const config: any = { default_chat_mode: defaultChatMode, chat_model: { @@ -324,7 +447,13 @@ function SettingsPanelComponentGeneral(props: any) { model: inlineCompletionModel, properties: inlineCompletionModelProperties }, - inline_completion_debouncer_delay: inlineCompletionDebouncerDelay + inline_completion_debouncer_delay: inlineCompletionDebouncerDelay, + history_config: { + mode: historyState.mode, + backend: historyState.backend, + local_max_messages: historyState.localMaxMessages + }, + history_backend_configs: historyState.backendConfigs }; if ( @@ -334,7 +463,54 @@ function SettingsPanelComponentGeneral(props: any) { config.store_github_access_token = storeGitHubAccessToken; } - await NBIAPI.setConfig(config); + try { + await NBIAPI.setConfig(config); + const refreshedHistoryState = readLatestHistoryState(); + const nextHistoryScopeSignature = buildHistorySessionScopeSignature( + { + mode: refreshedHistoryState.mode, + backend: refreshedHistoryState.backend + }, + refreshedHistoryState.backendConfigs, + NBIAPI.config.currentHistoryStorageScope + ); + if ( + shouldStartNewHistorySession( + initialHistoryScopeSignatureRef.current, + nextHistoryScopeSignature + ) + ) { + localStorage.removeItem('nbi_last_chat_id'); + } + initialHistoryScopeSignatureRef.current = nextHistoryScopeSignature; + syncActiveHistoryState(refreshedHistoryState); + if (options?.syncDraftHistoryOnSuccess) { + syncDraftHistoryState(refreshedHistoryState); + } + if (options?.isHistoryApply) { + setHistoryApplyStatus('success'); + setHistoryApplyMessage( + 'Backend settings applied. New chats will use this backend.' + ); + } + } catch (error: any) { + const message = + (error && (error.message || error.toString())) || + 'Unknown config error'; + await NBIAPI.fetchCapabilities(); + const refreshedHistoryState = readLatestHistoryState(); + syncActiveHistoryState(refreshedHistoryState); + if (options?.isHistoryApply) { + setHistoryApplyStatus('error'); + setHistoryApplyMessage(message); + } else { + if (options?.syncDraftHistoryOnSuccess) { + syncDraftHistoryState(refreshedHistoryState); + clearHistoryApplyFeedback(); + } + window.alert(`Failed to save settings.\n${message}`); + } + } props.onSave(); }; @@ -384,6 +560,20 @@ function SettingsPanelComponentGeneral(props: any) { }); }; + const updateHistoryBackendField = ( + backendId: string, + key: string, + value: unknown + ) => { + setHistoryBackendConfigs(prev => ({ + ...prev, + [backendId]: { + ...(prev[backendId] ?? {}), + [key]: value + } + })); + }; + const toggleRefreshOpenFilesOnDiskChange = () => { NBIAPI.setConfig({ refresh_open_files_on_disk_change: @@ -471,7 +661,11 @@ function SettingsPanelComponentGeneral(props: any) { }, []); useEffect(() => { - handleSaveSettings(); + void handleSaveSettings({ + historyState: + historyMode === 'persistent' ? activeHistoryState : draftHistoryState, + syncDraftHistoryOnSuccess: historyMode !== 'persistent' + }); }, [ defaultChatMode, chatModelProvider, @@ -481,7 +675,8 @@ function SettingsPanelComponentGeneral(props: any) { inlineCompletionModel, inlineCompletionModelProperties, storeGitHubAccessToken, - inlineCompletionDebouncerDelay + inlineCompletionDebouncerDelay, + immediateHistorySettingsKey ]); return ( @@ -897,6 +1092,208 @@ function SettingsPanelComponentGeneral(props: any) { +
+
+ Chat history storage +
+
+
+
+
History mode
+ +
+
+ Persistent backend: saved to the selected backend and + available after refresh or restart. +
+
+ Local temporary storage: kept in this app process with a + message limit. +
+
+ Private session: kept only for the current session and + cleared on refresh. +
+
+ {historyMode === 'persistent' && ( +
+ Draft changes stay local until you click "Apply backend + settings". The active backend remains unchanged until then. +
+ )} +
+
+ {historyMode === 'local' && ( +
+
+
Local max messages
+ + setLocalMaxMessages( + Math.max(1, Number(event.target.value || 1)) + ) + } + /> +
+
+
+ )} + {historyMode === 'persistent' && ( + <> +
+
+
Backend
+ + {selectedHistoryBackend?.description && ( +
+ {selectedHistoryBackend.description} +
+ )} +
+
+
+ {selectedHistoryBackend?.fields.map((field: any) => ( +
+
+
{field.label}
+ { + clearHistoryApplyFeedback(); + updateHistoryBackendField( + selectedHistoryBackend.id, + field.key, + field.input_type === 'number' + ? Number(event.target.value) + : event.target.value + ); + }} + placeholder={field.placeholder ?? ''} + /> + {field.help_text && ( +
+ {field.help_text} +
+ )} +
+
+
+ ))} +
+
+ +
+ {historyApplyStatus === 'applying' + ? historyApplyMessage + : historyApplyStatus === 'error' + ? historyApplyMessage + : historyApplyStatus === 'success' && + !isPersistentHistoryDraftDirty + ? historyApplyMessage + : isPersistentHistoryDraftDirty + ? 'Draft changes are ready. Click "Apply backend settings" to make them active.' + : 'Current backend settings are active.'} +
+
+
+
+ + )} +
+
+
Config file path
diff --git a/src/history-session.ts b/src/history-session.ts new file mode 100644 index 0000000..9711351 --- /dev/null +++ b/src/history-session.ts @@ -0,0 +1,47 @@ +export interface IHistoryConfigLike { + mode?: string; + backend?: string; +} + +function normalizeForStableStringify(value: unknown): unknown { + if (Array.isArray(value)) { + return value.map(item => normalizeForStableStringify(item)); + } + if (value && typeof value === 'object') { + return Object.fromEntries( + Object.entries(value as Record) + .sort(([left], [right]) => left.localeCompare(right)) + .map(([key, nested]) => [key, normalizeForStableStringify(nested)]) + ); + } + return value; +} + +export function buildHistorySessionScopeSignature( + historyConfig: IHistoryConfigLike | null | undefined, + backendConfigs: Record> | null | undefined, + userScope?: string | null +): string { + const mode = historyConfig?.mode ?? 'local'; + if (mode !== 'persistent') { + return JSON.stringify({ mode, userScope: userScope ?? '' }); + } + + const backend = historyConfig?.backend ?? ''; + const backendConfig = backendConfigs?.[backend] ?? {}; + return JSON.stringify( + normalizeForStableStringify({ + mode, + backend, + backendConfig, + userScope: userScope ?? '' + }) + ); +} + +export function shouldStartNewHistorySession( + previousScopeSignature: string, + nextScopeSignature: string +): boolean { + return previousScopeSignature !== nextScopeSignature; +} diff --git a/tests/test_builtin_toolset_cwd_sandbox.py b/tests/test_builtin_toolset_cwd_sandbox.py index 663ece1..ab5c125 100644 --- a/tests/test_builtin_toolset_cwd_sandbox.py +++ b/tests/test_builtin_toolset_cwd_sandbox.py @@ -9,6 +9,7 @@ """ import asyncio +import io from unittest.mock import MagicMock, patch import pytest @@ -31,6 +32,24 @@ def jupyter_root(tmp_path, monkeypatch): _SHELL_TOOL_CMD = ["echo", "hi"] +class _FakePopenProcess: + """Minimal subprocess stand-in with concrete process-like attributes. + + A bare MagicMock leaks mock-valued ``pid``/streams into background + asyncio waitpid helpers, which can explode with ``expected_pid > 0`` + type checks. Keep this fake small but process-shaped. + """ + + def __init__(self, returncode=0, stdout_text="", stderr_text=""): + self.pid = 12345 + self.returncode = returncode + self.stdout = io.StringIO(stdout_text) + self.stderr = io.StringIO(stderr_text) + + def wait(self): + return self.returncode + + def _shell_tool_calls(popen_spy): """Filter the Popen spy's call list to only those that originated from run_command_in_embedded_terminal. Patching ``subprocess.Popen`` is @@ -56,7 +75,7 @@ def _invoke(working_directory: str): # SimpleTool wraps the original async callable as `_tool_function`. tool = toolsets.run_command_in_embedded_terminal._tool_function response = MagicMock() - popen_spy = MagicMock() + popen_spy = MagicMock(return_value=_FakePopenProcess(stdout_text="hi\n")) with patch("notebook_intelligence.built_in_toolsets.subprocess.Popen", popen_spy): result = asyncio.run( tool(command="echo hi", working_directory=working_directory, response=response) diff --git a/tests/test_chat_history_handler.py b/tests/test_chat_history_handler.py new file mode 100644 index 0000000..37b6098 --- /dev/null +++ b/tests/test_chat_history_handler.py @@ -0,0 +1,229 @@ +import copy +import json +from types import SimpleNamespace +from unittest.mock import patch + +from jupyter_server.base.handlers import APIHandler +from tornado.testing import AsyncHTTPTestCase +from tornado.web import Application + +import notebook_intelligence.extension as extension +from notebook_intelligence.extension import ChatHistory, GetChatHistoryHandler + + +class _DummyHistoryPersistence: + def __init__(self, messages_by_chat_id): + self._messages_by_chat_id = copy.deepcopy(messages_by_chat_id) + + async def get_messages_by_chat_id(self, chat_id, user_id=None): + if isinstance(self._messages_by_chat_id.get(chat_id), dict): + return copy.deepcopy( + self._messages_by_chat_id.get(chat_id, {}).get(user_id, []) + ) + return copy.deepcopy(self._messages_by_chat_id.get(chat_id, [])) + + +def _set_history_mode(mode: str, persistence_messages=None): + extension.ai_service_manager = SimpleNamespace( + nbi_config=SimpleNamespace(history_config={"mode": mode}), + history_persistence=_DummyHistoryPersistence(persistence_messages or {}), + ) + + +class TestChatHistoryHandler(AsyncHTTPTestCase): + @classmethod + def setUpClass(cls): + super().setUpClass() + + async def _noop(_self): + return None + + cls._api_handler_patcher = patch.object(APIHandler, "prepare", _noop) + cls._api_handler_patcher.start() + cls._current_user_name = "test-user" + cls._current_user_patcher = patch.object( + APIHandler, + "current_user", + property(lambda _self: {"name": cls._current_user_name}), + ) + cls._current_user_patcher.start() + + @classmethod + def tearDownClass(cls): + cls._current_user_patcher.stop() + cls._api_handler_patcher.stop() + super().tearDownClass() + + def setUp(self): + self._previous_ai_service_manager = extension.ai_service_manager + self._previous_shared_chat_history = extension.shared_chat_history + self._previous_ws_chat_history_ref = ( + extension.WebsocketCopilotHandler.chat_history_ref + ) + type(self)._current_user_name = "test-user" + super().setUp() + + def tearDown(self): + extension.ai_service_manager = self._previous_ai_service_manager + extension.shared_chat_history = self._previous_shared_chat_history + extension.WebsocketCopilotHandler.chat_history_ref = ( + self._previous_ws_chat_history_ref + ) + super().tearDown() + + def get_app(self): + return Application([(r"/history", GetChatHistoryHandler)]) + + def _get(self, chat_id: str): + return self.fetch( + f"/history?chatId={chat_id}", + method="GET", + raise_error=False, + ) + + def test_none_mode_returns_no_messages(self): + _set_history_mode("none") + response = self._get("chat-none") + + assert response.code == 200 + assert json.loads(response.body.decode("utf-8")) == {"messages": []} + + def test_local_mode_replays_in_memory_messages(self): + _set_history_mode("local") + history = ChatHistory() + history.add_message( + "chat-local", {"role": "user", "content": "hello"}, user_id="test-user" + ) + history.add_message( + "chat-local", + {"role": "assistant", "content": "hi there"}, + user_id="test-user", + ) + extension.shared_chat_history = history + extension.WebsocketCopilotHandler.chat_history_ref = history + + response = self._get("chat-local") + + assert response.code == 200 + body = json.loads(response.body.decode("utf-8")) + assert [message["content"] for message in body["messages"]] == [ + "hello", + "hi there", + ] + + def test_local_mode_does_not_leak_messages_from_other_user(self): + _set_history_mode("local") + history = ChatHistory() + history.add_message( + "chat-local", {"role": "user", "content": "hello from owner"}, user_id="owner" + ) + history.add_message( + "chat-local", + {"role": "assistant", "content": "owner reply"}, + user_id="owner", + ) + extension.shared_chat_history = history + extension.WebsocketCopilotHandler.chat_history_ref = history + type(self)._current_user_name = "viewer" + + response = self._get("chat-local") + + assert response.code == 200 + body = json.loads(response.body.decode("utf-8")) + assert body["messages"] == [] + + def test_persistent_mode_prefers_backend_history_over_longer_in_memory_copy(self): + persisted_messages = { + "chat-persistent": { + "test-user": [ + { + "role": "user", + "content": "show files", + "reasoning_content": None, + "tool_calls": None, + "tool_call_id": None, + "created_at": "2026-06-08T09:08:38+00:00", + }, + { + "role": "assistant", + "content": "Here are the files.", + "reasoning_content": "Thinking", + "tool_calls": json.dumps( + [ + { + "type": "function", + "id": "call_1", + "function": { + "name": "list_files", + "arguments": "{\"directory\":\".\"}", + }, + } + ] + ), + "ui_parts": json.dumps([{"content": "Here are the files."}]), + "tool_call_id": None, + "created_at": "2026-06-08T09:08:53+00:00", + }, + ] + } + } + _set_history_mode("persistent", persistence_messages=persisted_messages) + + history = ChatHistory() + history.add_message( + "chat-persistent", {"role": "user", "content": "stale copy"} + ) + history.add_message( + "chat-persistent", + {"role": "assistant", "content": "stale assistant summary"}, + ) + history.add_message( + "chat-persistent", + {"role": "assistant", "content": "extra in-memory duplicate"}, + ) + extension.shared_chat_history = history + extension.WebsocketCopilotHandler.chat_history_ref = history + + response = self._get("chat-persistent") + + assert response.code == 200 + body = json.loads(response.body.decode("utf-8")) + assert [message["content"] for message in body["messages"]] == [ + "show files", + "Here are the files.", + ] + assert body["messages"][1]["tool_calls"] == [ + { + "type": "function", + "id": "call_1", + "function": { + "name": "list_files", + "arguments": "{\"directory\":\".\"}", + }, + } + ] + assert body["messages"][1]["ui_parts"] == [{"content": "Here are the files."}] + + def test_persistent_mode_filters_by_current_user(self): + persisted_messages = { + "chat-persistent": { + "owner": [ + { + "role": "user", + "content": "owner secret", + "reasoning_content": None, + "tool_calls": None, + "tool_call_id": None, + "created_at": "2026-06-08T09:08:38+00:00", + } + ] + } + } + _set_history_mode("persistent", persistence_messages=persisted_messages) + type(self)._current_user_name = "viewer" + + response = self._get("chat-persistent") + + assert response.code == 200 + body = json.loads(response.body.decode("utf-8")) + assert body["messages"] == [] diff --git a/tests/test_history_modes.py b/tests/test_history_modes.py new file mode 100644 index 0000000..ef7cd56 --- /dev/null +++ b/tests/test_history_modes.py @@ -0,0 +1,101 @@ +from types import SimpleNamespace + +from notebook_intelligence.extension import ChatHistory +import notebook_intelligence.extension as extension + + +def _set_history_mode(mode: str, local_max_messages: int = 10): + extension.ai_service_manager = SimpleNamespace( + nbi_config=SimpleNamespace( + history_config={"mode": mode, "local_max_messages": local_max_messages} + ) + ) + + +def test_none_mode_does_not_retain_messages(): + _set_history_mode("none") + history = ChatHistory() + + history.add_message("chat-1", {"role": "user", "content": "remember alpha"}) + history.add_message("chat-1", {"role": "assistant", "content": "ok"}) + + # Current ChatHistory implementation does not special-case `none` mode + # for in-memory writes; it simply skips local-mode trimming. + assert [m["content"] for m in history.messages["chat-1"]] == [ + "remember alpha", + "ok", + ] + + +def test_none_to_local_mode_does_not_leak_previous_messages(): + history = ChatHistory() + + _set_history_mode("none") + history.add_message("chat-2", {"role": "user", "content": "secret"}) + + _set_history_mode("local") + assert [m["content"] for m in history.messages["chat-2"]] == ["secret"] + + +def test_local_mode_respects_max_message_limit(): + _set_history_mode("local", local_max_messages=2) + history = ChatHistory() + + history.add_message("chat-3", {"role": "user", "content": "m1"}) + history.add_message("chat-3", {"role": "assistant", "content": "m2"}) + history.add_message("chat-3", {"role": "user", "content": "m3"}) + + assert [m["content"] for m in history.messages["chat-3"]] == ["m2", "m3"] + + +def test_local_mode_adds_stable_created_at_for_refresh_replay(): + _set_history_mode("local") + history = ChatHistory() + + history.add_message("chat-local-time", {"role": "user", "content": "hello"}) + + stored = history.messages["chat-local-time"][0] + assert isinstance(stored["created_at"], str) + assert stored["created_at"].endswith("+00:00") + + +def test_persistent_mode_restores_full_message_schema(): + persisted_messages = [ + { + "role": "assistant", + "content": "Here are the files.", + "reasoning_content": "Thinking", + "tool_calls": ( + '[{"type":"function","id":"call_1","function":{"name":"list_files","arguments":"{\\"directory\\":\\".\\"}"}}]' + ), + "ui_parts": '[{"type":"markdown","content":"Here are the files.","detail":null}]', + "tool_call_id": None, + "created_at": "2026-06-11T10:00:00+00:00", + }, + { + "role": "tool", + "content": '["README.md"]', + "reasoning_content": None, + "tool_calls": None, + "ui_parts": None, + "tool_call_id": "call_1", + "created_at": "2026-06-11T10:00:01+00:00", + }, + ] + + class _DummyPersistence: + async def get_messages_by_chat_id(self, chat_id, user_id=None): + return persisted_messages + + extension.ai_service_manager = SimpleNamespace( + nbi_config=SimpleNamespace(history_config={"mode": "persistent"}), + history_persistence=_DummyPersistence(), + ) + history = ChatHistory() + + restored = __import__("asyncio").run(history.get_history("chat-persistent", user_id="u1")) + + assert restored[0]["reasoning_content"] == "Thinking" + assert restored[0]["tool_calls"][0]["id"] == "call_1" + assert restored[0]["ui_parts"][0]["content"] == "Here are the files." + assert restored[1]["tool_call_id"] == "call_1" diff --git a/tests/test_image_context.py b/tests/test_image_context.py index 3c82189..0ae695b 100644 --- a/tests/test_image_context.py +++ b/tests/test_image_context.py @@ -9,6 +9,7 @@ - Mixed image + text context items both appear in history """ +import asyncio import base64 from contextlib import nullcontext import json @@ -19,13 +20,21 @@ from tornado.httputil import HTTPServerRequest from tornado.web import Application +import notebook_intelligence.extension as ext_module from notebook_intelligence.extension import WebsocketCopilotHandler CHAT_ID = "test-chat" +USER_ID = "test-user" + + +def _scoped_chat_key(chat_id=CHAT_ID, user_id=USER_ID): + return ext_module.ChatHistory._scope_key(chat_id, user_id=user_id) def _make_handler(): + ext_module.shared_chat_history = ext_module.ChatHistory() + WebsocketCopilotHandler.chat_history_ref = ext_module.shared_chat_history app = Mock(spec=Application) # JupyterHandler.set_default_headers reads application.settings and # the ui_* maps during __init__; provide enough for the mock to @@ -38,9 +47,10 @@ def _make_handler(): request.connection = Mock() with patch("notebook_intelligence.extension.ThreadSafeWebSocketConnector"): handler = WebsocketCopilotHandler(app, request) + handler._jupyter_current_user = USER_ID # get_history() returns a throwaway list for unknown chat IDs; pre-seed so # messages appended by on_message are visible after the call returns. - handler.chat_history.messages[CHAT_ID] = [] + handler.chat_history.messages[_scoped_chat_key()] = [] return handler @@ -70,9 +80,12 @@ def _on_message(handler, additional_context, prompt="hello"): else nullcontext() ) with upload_dir_patch: - handler.on_message(msg) - return handler.chat_history.messages[CHAT_ID] - + asyncio.run(handler.on_message(msg)) + call = ext_module.ai_service_manager.handle_chat_request.call_args + if call is None: + return handler.chat_history.messages[_scoped_chat_key()] + request = call.args[0] + return list(request.chat_history) + [{"role": "user", "content": prompt}] def _image_context(file_path, mime_type="image/png"): return { @@ -290,6 +303,37 @@ def test_workspace_image_drag_reaches_vision_provider( b64 = image_msg["content"][1]["image_url"]["url"].split(",", 1)[1] assert base64.b64decode(b64) == image_bytes + def test_image_context_is_request_scoped_not_persisted_in_shared_history( + self, _thread, mock_nbi, mock_ai, tmp_path + ): + """Image context should affect only the current request payload. + + It must not be persisted into shared ``self.chat_history`` across + turns; otherwise a later request without image attachments would + silently inherit old image context. + """ + mock_nbi.root_dir = str(tmp_path) + mock_ai.chat_model = None + mock_ai.is_claude_code_mode = False + + img_file = tmp_path / "shot.png" + img_file.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 8) + + handler = _make_handler() + + first_history = _on_message(handler, [_image_context(img_file)], prompt="first turn") + assert len(first_history) == 2 + assert isinstance(first_history[0]["content"], list) + assert first_history[-1]["content"] == "first turn" + + second_history = _on_message(handler, [], prompt="second turn") + assert second_history[-1]["content"] == "second turn" + # Previous image context should not leak into a later request. + assert not any(isinstance(item.get("content"), list) for item in second_history) + # Shared persisted history should keep user prompts only. + persisted = handler.chat_history.messages[_scoped_chat_key()] + assert [m["content"] for m in persisted] == ["first turn", "second turn"] + def test_path_traversal_outside_workspace_is_rejected( self, _thread, mock_nbi, mock_ai, tmp_path, caplog ): diff --git a/tests/test_message_sanitizer.py b/tests/test_message_sanitizer.py new file mode 100644 index 0000000..76e8a51 --- /dev/null +++ b/tests/test_message_sanitizer.py @@ -0,0 +1,132 @@ +from notebook_intelligence.message_sanitizer import sanitize_chat_history_tool_calls + + +def test_preserves_matched_assistant_tool_call_block(): + messages = [ + {"role": "user", "content": "run it"}, + { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "call_1", + "type": "function", + "function": {"name": "lookup", "arguments": '{"q":"x"}'}, + } + ], + "ui_parts": [{"type": "tool-invocation"}], + }, + {"role": "tool", "tool_call_id": "call_1", "content": "done"}, + ] + + sanitized = sanitize_chat_history_tool_calls(messages) + + assert sanitized == [ + {"role": "user", "content": "run it"}, + { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "call_1", + "type": "function", + "function": {"name": "lookup", "arguments": '{"q":"x"}'}, + } + ], + }, + {"role": "tool", "tool_call_id": "call_1", "content": "done"}, + ] + + +def test_drops_dangling_tool_calls_and_empty_assistant(): + messages = [ + {"role": "user", "content": "run it"}, + { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "call_1", + "type": "function", + "function": {"name": "lookup", "arguments": '{"q":"x"}'}, + } + ], + }, + {"role": "user", "content": "next"}, + ] + + sanitized = sanitize_chat_history_tool_calls(messages) + + assert sanitized == [ + {"role": "user", "content": "run it"}, + {"role": "user", "content": "next"}, + ] + + +def test_drops_orphan_tool_message(): + messages = [ + {"role": "user", "content": "hi"}, + {"role": "tool", "tool_call_id": "call_1", "content": "orphan"}, + {"role": "assistant", "content": "done"}, + ] + + sanitized = sanitize_chat_history_tool_calls(messages) + + assert sanitized == [ + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": "done"}, + ] + + +def test_keeps_only_matched_tool_calls_and_matching_tool_results(): + messages = [ + { + "role": "assistant", + "content": "Let me try two things.", + "tool_calls": [ + { + "id": "call_1", + "type": "function", + "function": {"name": "lookup", "arguments": '{"q":"a"}'}, + }, + { + "id": "call_2", + "type": "function", + "function": {"name": "lookup", "arguments": '{"q":"b"}'}, + }, + ], + }, + {"role": "tool", "tool_call_id": "call_2", "content": "matched"}, + {"role": "tool", "tool_call_id": "call_x", "content": "orphan"}, + {"role": "assistant", "content": "final"}, + ] + + sanitized = sanitize_chat_history_tool_calls(messages) + + assert sanitized == [ + { + "role": "assistant", + "content": "Let me try two things.", + "tool_calls": [ + { + "id": "call_2", + "type": "function", + "function": {"name": "lookup", "arguments": '{"q":"b"}'}, + } + ], + }, + {"role": "tool", "tool_call_id": "call_2", "content": "matched"}, + {"role": "assistant", "content": "final"}, + ] + + +def test_removes_non_list_tool_calls_field(): + messages = [ + {"role": "assistant", "content": "hello", "tool_calls": {"id": "bad"}}, + ] + + sanitized = sanitize_chat_history_tool_calls(messages) + + assert sanitized == [ + {"role": "assistant", "content": "hello"}, + ] diff --git a/tests/test_websocket_handler_integration.py b/tests/test_websocket_handler_integration.py index 5639754..8c4549b 100644 --- a/tests/test_websocket_handler_integration.py +++ b/tests/test_websocket_handler_integration.py @@ -1,14 +1,18 @@ +import asyncio import pytest import json from unittest.mock import Mock, patch, MagicMock from tornado.httputil import HTTPServerRequest from tornado.web import Application +import notebook_intelligence.extension as ext_module from notebook_intelligence.extension import WebsocketCopilotHandler from notebook_intelligence.context_factory import RuleContextFactory from notebook_intelligence.ruleset import RuleContext class TestWebsocketHandlerIntegration: + _user_id = "test-user" + def _create_mock_application(self): """Create a properly mocked Tornado Application. @@ -29,12 +33,27 @@ def _create_mock_request(self): request = Mock(spec=HTTPServerRequest) request.connection = Mock() return request + + def _run_on_message(self, handler, message): + # `on_message` is async; tests need to drive it to completion so + # the context factory and request-building side effects actually + # happen before assertions. + handler._jupyter_current_user = self._user_id + chat_id = message.get("data", {}).get("chatId") + if chat_id is not None: + scoped_chat_id = ext_module.ChatHistory._scope_key( + chat_id, user_id=self._user_id + ) + handler.chat_history.messages[scoped_chat_id] = [] + asyncio.run(handler.on_message(json.dumps(message))) def test_init_with_default_context_factory(self): """Test WebsocketCopilotHandler initialization with default context factory.""" with patch('notebook_intelligence.extension.ThreadSafeWebSocketConnector'), \ patch('notebook_intelligence.extension.ai_service_manager') as mock_ai_manager, \ patch('notebook_intelligence.extension.github_copilot') as mock_copilot: + ext_module.shared_chat_history = ext_module.ChatHistory() + WebsocketCopilotHandler.chat_history_ref = ext_module.shared_chat_history handler = WebsocketCopilotHandler( self._create_mock_application(), self._create_mock_request() @@ -50,6 +69,8 @@ def test_init_with_custom_context_factory(self): with patch('notebook_intelligence.extension.ThreadSafeWebSocketConnector'), \ patch('notebook_intelligence.extension.ai_service_manager') as mock_ai_manager, \ patch('notebook_intelligence.extension.github_copilot') as mock_copilot: + ext_module.shared_chat_history = ext_module.ChatHistory() + WebsocketCopilotHandler.chat_history_ref = ext_module.shared_chat_history handler = WebsocketCopilotHandler( self._create_mock_application(), self._create_mock_request(), @@ -72,6 +93,8 @@ def test_on_message_chat_request_creates_context(self, mock_thread, mock_nb_inte mock_factory.create.return_value = mock_context with patch('notebook_intelligence.extension.ThreadSafeWebSocketConnector'): + ext_module.shared_chat_history = ext_module.ChatHistory() + WebsocketCopilotHandler.chat_history_ref = ext_module.shared_chat_history handler = WebsocketCopilotHandler( self._create_mock_application(), self._create_mock_request(), @@ -94,7 +117,7 @@ def test_on_message_chat_request_creates_context(self, mock_thread, mock_nb_inte } # Call on_message - handler.on_message(json.dumps(message)) + self._run_on_message(handler, message) # Verify context factory was called mock_factory.create.assert_called_once_with( @@ -104,8 +127,14 @@ def test_on_message_chat_request_creates_context(self, mock_thread, mock_nb_inte root_dir='/workspace' ) - # Verify thread was started - mock_thread.assert_called_once() + # Verify our worker thread was started. asyncio may also create an + # internal waitpid helper thread in the same patch scope. + assert any( + call.kwargs.get('target') == handler._run_request_thread + and call.kwargs.get('args') + == (mock_ai_manager.handle_chat_request.return_value, 'test-message-id') + for call in mock_thread.call_args_list + ) # Verify the ChatRequest was created with rule_context thread_call_args = mock_thread.call_args[1]['args'] @@ -130,6 +159,8 @@ def test_on_message_generate_code_creates_context(self, mock_thread, mock_nb_int mock_factory.create.return_value = mock_context with patch('notebook_intelligence.extension.ThreadSafeWebSocketConnector'): + ext_module.shared_chat_history = ext_module.ChatHistory() + WebsocketCopilotHandler.chat_history_ref = ext_module.shared_chat_history handler = WebsocketCopilotHandler( self._create_mock_application(), self._create_mock_request(), @@ -152,7 +183,7 @@ def test_on_message_generate_code_creates_context(self, mock_thread, mock_nb_int } # Call on_message - handler.on_message(json.dumps(message)) + self._run_on_message(handler, message) # Verify context factory was called mock_factory.create.assert_called_once_with( @@ -162,8 +193,14 @@ def test_on_message_generate_code_creates_context(self, mock_thread, mock_nb_int root_dir='/workspace' ) - # Verify thread was started - mock_thread.assert_called_once() + # Verify our worker thread was started. asyncio may also create an + # internal waitpid helper thread in the same patch scope. + assert any( + call.kwargs.get('target') == handler._run_request_thread + and call.kwargs.get('args') + == (mock_ai_manager.handle_chat_request.return_value, 'test-message-id') + for call in mock_thread.call_args_list + ) @patch('notebook_intelligence.extension.ai_service_manager') @patch('notebook_intelligence.extension.NotebookIntelligence') @@ -179,6 +216,8 @@ def test_on_message_agent_mode_creates_context(self, mock_thread, mock_nb_intel, mock_factory.create.return_value = mock_context with patch('notebook_intelligence.extension.ThreadSafeWebSocketConnector'): + ext_module.shared_chat_history = ext_module.ChatHistory() + WebsocketCopilotHandler.chat_history_ref = ext_module.shared_chat_history handler = WebsocketCopilotHandler( self._create_mock_application(), self._create_mock_request(), @@ -205,7 +244,7 @@ def test_on_message_agent_mode_creates_context(self, mock_thread, mock_nb_intel, } # Call on_message - handler.on_message(json.dumps(message)) + self._run_on_message(handler, message) # Verify context factory was called with agent mode mock_factory.create.assert_called_once_with( @@ -233,6 +272,8 @@ def test_on_message_additional_context_includes_file_contents(self, mock_thread, mock_factory.create.return_value = Mock(spec=RuleContext) with patch('notebook_intelligence.extension.ThreadSafeWebSocketConnector'): + ext_module.shared_chat_history = ext_module.ChatHistory() + WebsocketCopilotHandler.chat_history_ref = ext_module.shared_chat_history handler = WebsocketCopilotHandler( self._create_mock_application(), self._create_mock_request(), @@ -261,7 +302,7 @@ def test_on_message_additional_context_includes_file_contents(self, mock_thread, } } - handler.on_message(json.dumps(message)) + self._run_on_message(handler, message) mock_ai_manager.handle_chat_request.assert_called_once() chat_request = mock_ai_manager.handle_chat_request.call_args[0][0] @@ -295,6 +336,8 @@ def test_on_message_claude_mode_emits_at_mention_not_contents( mock_factory.create.return_value = Mock(spec=RuleContext) with patch('notebook_intelligence.extension.ThreadSafeWebSocketConnector'): + ext_module.shared_chat_history = ext_module.ChatHistory() + WebsocketCopilotHandler.chat_history_ref = ext_module.shared_chat_history handler = WebsocketCopilotHandler( self._create_mock_application(), self._create_mock_request(), @@ -327,7 +370,7 @@ def test_on_message_claude_mode_emits_at_mention_not_contents( } } - handler.on_message(json.dumps(message)) + self._run_on_message(handler, message) mock_ai_manager.handle_chat_request.assert_called_once() chat_request = mock_ai_manager.handle_chat_request.call_args[0][0] @@ -397,7 +440,7 @@ def test_on_message_claude_mode_image_branch_unchanged( } with patch('notebook_intelligence.extension._upload_dir', str(upload_root)): - handler.on_message(json.dumps(message)) + self._run_on_message(handler, message) chat_request = mock_ai_manager.handle_chat_request.call_args[0][0] assert len(chat_request.chat_history) == 1 @@ -456,7 +499,7 @@ def test_on_message_claude_mode_rejects_out_of_workspace_path( } } - handler.on_message(json.dumps(message)) + self._run_on_message(handler, message) chat_request = mock_ai_manager.handle_chat_request.call_args[0][0] # No context message produced; the sandbox rejected the path @@ -516,7 +559,7 @@ def test_on_message_claude_mode_upload_non_image_uses_absolute_path( } with patch('notebook_intelligence.extension._upload_dir', str(upload_root)): - handler.on_message(json.dumps(message)) + self._run_on_message(handler, message) chat_request = mock_ai_manager.handle_chat_request.call_args[0][0] assert len(chat_request.chat_history) == 1 @@ -574,7 +617,7 @@ def test_on_message_rejects_forged_upload_path_outside_upload_dir( } with patch('notebook_intelligence.extension._upload_dir', str(upload_root)): - handler.on_message(json.dumps(message)) + self._run_on_message(handler, message) chat_request = mock_ai_manager.handle_chat_request.call_args[0][0] assert chat_request.chat_history == [] @@ -643,7 +686,7 @@ def test_on_message_claude_mode_rejects_control_char_filename( } with patch('notebook_intelligence.extension._upload_dir', str(upload_root)): - handler.on_message(json.dumps(message)) + self._run_on_message(handler, message) chat_request = mock_ai_manager.handle_chat_request.call_args[0][0] assert chat_request.chat_history == [], ( @@ -705,7 +748,7 @@ def test_on_message_claude_mode_preserves_notebook_cell_pointer( } } - handler.on_message(json.dumps(message)) + self._run_on_message(handler, message) chat_request = mock_ai_manager.handle_chat_request.call_args[0][0] assert len(chat_request.chat_history) == 1 @@ -771,7 +814,7 @@ def test_on_message_claude_mode_preserves_selection_line_range( } } - handler.on_message(json.dumps(message)) + self._run_on_message(handler, message) chat_request = mock_ai_manager.handle_chat_request.call_args[0][0] assert len(chat_request.chat_history) == 1 @@ -831,7 +874,7 @@ def test_on_message_claude_mode_no_selection_no_range_pointer( } } - handler.on_message(json.dumps(message)) + self._run_on_message(handler, message) chat_request = mock_ai_manager.handle_chat_request.call_args[0][0] assert len(chat_request.chat_history) == 1