Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions notebook_intelligence/ai_service_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@
from notebook_intelligence.api import ButtonData, ChatModel, EmbeddingModel, InlineCompletionModel, LLMProvider, ChatParticipant, ChatRequest, ChatResponse, CompletionContext, ContextRequest, Host, CompletionContextProvider, MCPPrompt, MCPServer, MarkdownData, NotebookIntelligenceExtension, RegistrationError, TelemetryEvent, TelemetryListener, Tool, Toolset
from notebook_intelligence.base_chat_participant import BaseChatParticipant
from notebook_intelligence.config import NBIConfig
from notebook_intelligence.history_backends import (
HistoryPersistenceBackend,
HistoryPersistenceManager,
MySQLHistoryBackend,
SQLiteHistoryBackend,
)
from notebook_intelligence.github_copilot_chat_participant import GithubCopilotChatParticipant
from notebook_intelligence.claude import CLAUDE_CODE_CHAT_PARTICIPANT_ID, ClaudeCodeChatParticipant, ClaudeCodeInlineCompletionModel, fetch_claude_models, get_claude_models
from notebook_intelligence.llm_providers.github_copilot_llm_provider import GitHubCopilotLLMProvider
Expand Down Expand Up @@ -61,6 +67,13 @@ def __init__(self, options: Optional[dict] = None):
self._options.get("feature_policies") or {},
self._options.get("string_overrides") or {},
)
self._history_persistence = HistoryPersistenceManager()
self.register_history_persistence_backend(MySQLHistoryBackend())
self.register_history_persistence_backend(SQLiteHistoryBackend())
self._history_persistence.reconfigure(
self._nbi_config.history_config,
self._nbi_config.history_backend_configs,
)
self._openai_compatible_llm_provider = OpenAICompatibleLLMProvider()
self._litellm_compatible_llm_provider = LiteLLMCompatibleLLMProvider()
self._ollama_llm_provider = OllamaLLMProvider()
Expand Down Expand Up @@ -93,6 +106,22 @@ def __init__(self, options: Optional[dict] = None):
def nbi_config(self) -> NBIConfig:
return self._nbi_config

@property
def history_persistence(self) -> HistoryPersistenceManager:
return self._history_persistence

def register_history_persistence_backend(
self, backend: HistoryPersistenceBackend
) -> None:
self._history_persistence.register_backend(backend)

def update_history_persistence(self):
"""Refresh history persistence backend configuration from current config."""
self._history_persistence.reconfigure(
self._nbi_config.history_config,
self._nbi_config.history_backend_configs,
)

@property
def ollama_llm_provider(self) -> OllamaLLMProvider:
return self._ollama_llm_provider
Expand Down
34 changes: 34 additions & 0 deletions notebook_intelligence/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,8 @@ class ChatRequest:
cancel_token: CancelToken = None
# NEW: Add context for rule evaluation
rule_context: Optional[RuleContext] = None
# Internal conversation id for persistence backends
conversation_id: str = None

@dataclass
class ResponseStreamData:
Expand Down Expand Up @@ -312,6 +314,12 @@ def message_id(self) -> str:

def stream(self, data: ResponseStreamData, finish: bool = False) -> None:
raise NotImplementedError

def append_tool_calls(self, tool_calls: list[dict] | None) -> None:
return None

def append_history_message(self, message: dict) -> None:
return None

def finish(self) -> None:
raise NotImplementedError
Expand Down Expand Up @@ -636,6 +644,7 @@ async def _tool_call_loop(tool_call_rounds: list):

for choice in tool_response['choices']:
message = choice['message']
response.append_tool_calls(message.get('tool_calls'))
# Some models use 'reasoning', some use 'reasoning_content'
raw_reasoning = message.get('reasoning') or message.get('reasoning_content') or ''

Expand Down Expand Up @@ -687,6 +696,12 @@ async def _tool_call_loop(tool_call_rounds: list):
args = tool_call['function']['arguments']
else:
args = fuzzy_json_loads(tool_call['function']['arguments'])

# Persist tool execution START (initial record).
if request.conversation_id:
request.host.history_persistence.log_tool_execution(
tool_call['id'], request.conversation_id, tool_name, args, ""
)

tool_properties = tool_to_call.schema["function"]["parameters"]["properties"]
if type(args) is str:
Expand All @@ -713,13 +728,25 @@ async def _tool_call_loop(tool_call_rounds: list):
return

tool_call_response = await tool_to_call.handle_tool_call(request, response, tool_context, args)

# Persist tool execution result.
if request.conversation_id:
request.host.history_persistence.log_tool_execution(
tool_call['id'], request.conversation_id, tool_name, args, str(tool_call_response)
)
# Also log the tool message itself
msg_id = str(uuid.uuid4())
request.host.history_persistence.add_message(
msg_id, request.conversation_id, "tool", str(tool_call_response), tool_call_id=tool_call['id']
)

function_call_result_message = {
"role": "tool",
"content": str(tool_call_response),
"tool_call_id": tool_call['id']
}

response.append_history_message(function_call_result_message)
messages.append(function_call_result_message)

if had_tool_call:
Expand Down Expand Up @@ -915,6 +942,9 @@ def register_telemetry_listener(self, listener: TelemetryListener) -> None:
def register_toolset(self, toolset: Toolset) -> None:
raise NotImplementedError

def register_history_persistence_backend(self, backend) -> None:
raise NotImplementedError

@property
def nbi_config(self) -> NBIConfig:
raise NotImplementedError
Expand Down Expand Up @@ -971,6 +1001,10 @@ def get_skill_manager(self):
def websocket_connector(self) -> ThreadSafeWebSocketConnector:
raise NotImplementedError

@property
def history_persistence(self) -> Any:
return NotImplementedError


class NotebookIntelligenceExtension:
@property
Expand Down
59 changes: 59 additions & 0 deletions notebook_intelligence/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import stat
import sys
import tempfile
import copy
from typing import Optional

from notebook_intelligence.feature_flags import (
Expand Down Expand Up @@ -230,6 +231,64 @@ def active_rules(self) -> dict:
"""Get dictionary of active rule states (filename -> bool)."""
return self.get('active_rules', {})

@property
def history_backend_configs(self) -> dict:
"""Get history persistence backend configuration by backend id."""
defaults = {
'mysql': {
'host': 'localhost',
'port': 3306,
'user': '',
'password': '',
'database': 'notebook_intelligence'
},
'sqlite': {
'path': os.path.join(self.nbi_user_dir, 'history.sqlite3')
}
}

merged = copy.deepcopy(defaults)
configured = self.get('history_backend_configs', {})
if isinstance(configured, dict):
for backend_id, backend_cfg in configured.items():
if not isinstance(backend_cfg, dict):
continue
merged.setdefault(backend_id, {})
merged[backend_id].update(backend_cfg)

legacy_mysql = self.get('mysql_config', None)
if isinstance(legacy_mysql, dict):
merged['mysql'].update(
{
key: value
for key, value in legacy_mysql.items()
if key != 'enabled'
}
)
return merged

@property
def history_config(self) -> dict:
"""Get chat history storage configuration."""
cfg = self.get('history_config', {})
mode = cfg.get('mode', 'local')
backend = cfg.get('backend', 'sqlite')
if mode in ['mysql', 'sqlite']:
backend = mode
mode = 'persistent'
local_max_messages = cfg.get('local_max_messages', 10)
try:
local_max_messages = int(local_max_messages)
except Exception:
local_max_messages = 10
if local_max_messages < 1:
local_max_messages = 1
return {
'mode': mode if mode in ['persistent', 'local', 'none'] else 'local',
'backend': backend if isinstance(backend, str) and backend else 'sqlite',
'local_max_messages': local_max_messages
}

def set_rule_active(self, filename: str, active: bool):
"""Set the active state of a rule."""
active_rules = self.active_rules.copy()
Expand Down
Loading
Loading