diff --git a/notebook_intelligence/api.py b/notebook_intelligence/api.py index e2fda1eb..76c45fcc 100644 --- a/notebook_intelligence/api.py +++ b/notebook_intelligence/api.py @@ -140,6 +140,8 @@ class ChatRequest: tool_selection: RequestToolSelection = None command: str = '' prompt: str = '' + language: str = '' + kernel_name: str = '' chat_history: list[dict] = None cancel_token: CancelToken = None # NEW: Add context for rule evaluation diff --git a/notebook_intelligence/base_chat_participant.py b/notebook_intelligence/base_chat_participant.py index 2398d3e9..913658b9 100644 --- a/notebook_intelligence/base_chat_participant.py +++ b/notebook_intelligence/base_chat_participant.py @@ -102,6 +102,14 @@ def schema(self) -> dict: } } } + }, + "language": { + "type": "string", + "description": "Programming language for the notebook kernel, e.g. python or r" + }, + "kernel_name": { + "type": "string", + "description": "Jupyter kernel name to use when creating the notebook" } }, "required": [], @@ -120,8 +128,13 @@ def pre_invoke(self, request: ChatRequest, tool_args: dict) -> Union[ToolPreInvo async def handle_tool_call(self, request: ChatRequest, response: ChatResponse, tool_context: dict, tool_args: dict) -> str: cell_sources = tool_args.get('cell_sources', []) + language = tool_args.get('language') or request.language or 'python' + kernel_name = tool_args.get('kernel_name') or request.kernel_name or '' - ui_cmd_response = await response.run_ui_command('notebook-intelligence:create-new-notebook-from-py', {'code': ''}) + ui_cmd_response = await response.run_ui_command( + 'notebook-intelligence:create-new-notebook', + {'code': '', 'language': language, 'kernelName': kernel_name} + ) file_path = ui_cmd_response['path'] for cell_source in cell_sources: @@ -133,7 +146,55 @@ async def handle_tool_call(self, request: ChatRequest, response: ChatResponse, t source = cell_source.get('source', '') ui_cmd_response = await response.run_ui_command('notebook-intelligence:add-code-cell-to-notebook', {'code': source, 'path': file_path}) - return "Notebook created successfully at {file_path}" + return f"Notebook created successfully at {file_path}" + + +class ListAvailableNotebookKernelsTool(Tool): + @property + def name(self) -> str: + return "list_available_notebook_kernels" + + @property + def title(self) -> str: + return "List available notebook kernels" + + @property + def tags(self) -> list[str]: + return ["default-participant-tool"] + + @property + def description(self) -> str: + return "Lists Jupyter kernels available in the current frontend environment" + + @property + def schema(self) -> dict: + return { + "type": "function", + "function": { + "name": self.name, + "description": self.description, + "strict": True, + "parameters": { + "type": "object", + "properties": {}, + "required": [], + "additionalProperties": False, + }, + }, + } + + def pre_invoke(self, request: ChatRequest, tool_args: dict) -> Union[ToolPreInvokeResponse, None]: + return ToolPreInvokeResponse( + f"Calling tool '{self.name}'", + detail={"title": "Parameters", "content": json.dumps(tool_args)}, + ) + + async def handle_tool_call(self, request: ChatRequest, response: ChatResponse, tool_context: dict, tool_args: dict) -> str: + ui_cmd_response = await response.run_ui_command( + 'notebook-intelligence:list-available-notebook-kernels', + {} + ) + return json.dumps(ui_cmd_response) class AddMarkdownCellToNotebookTool(Tool): def __init__(self, auto_approve: bool = False): @@ -379,7 +440,8 @@ async def generate_code_cell(self, request: ChatRequest) -> str: chat_model = request.host.chat_model messages = request.chat_history.copy() messages.pop() - messages.insert(0, {"role": "system", "content": f"You are an assistant that creates Python code which will be used in a Jupyter notebook. Generate only Python code and some comments for the code. You should return the code directly, without wrapping it inside ```."}) + language = request.language or 'python' + messages.insert(0, {"role": "system", "content": f"You are an assistant that creates {language} code which will be used in a Jupyter notebook. Generate only {language} code and some comments for the code. You should return the code directly, without wrapping it inside ```."}) messages.append({"role": "user", "content": f"Generate code for: {request.prompt}"}) generated = chat_model.completions(messages) code = generated['choices'][0]['message']['content'] @@ -440,8 +502,17 @@ async def handle_chat_request(self, request: ChatRequest, response: ChatResponse async def handle_ask_mode_chat_request(self, request: ChatRequest, response: ChatResponse, options: dict = {}) -> None: chat_model = request.host.chat_model if request.command == 'newNotebook': + language = request.language or 'python' + kernel_name = request.kernel_name or '' # create a new notebook - ui_cmd_response = await response.run_ui_command('notebook-intelligence:create-new-notebook-from-py', {'code': ''}) + ui_cmd_response = await response.run_ui_command( + 'notebook-intelligence:create-new-notebook', + { + 'code': '', + 'language': language, + 'kernelName': kernel_name, + } + ) file_path = ui_cmd_response['path'] code = await self.generate_code_cell(request) @@ -494,6 +565,8 @@ async def handle_ask_mode_chat_request(self, request: ChatRequest, response: Cha def get_tool_by_name(name: str) -> Tool: if name == "create_new_notebook": return CreateNewNotebookTool() + elif name == "list_available_notebook_kernels": + return ListAvailableNotebookKernelsTool() elif name == "add_markdown_cell_to_notebook": return AddMarkdownCellToNotebookTool() elif name == "add_code_cell_to_notebook": diff --git a/notebook_intelligence/built_in_toolsets.py b/notebook_intelligence/built_in_toolsets.py index af1a9f34..b5174d55 100644 --- a/notebook_intelligence/built_in_toolsets.py +++ b/notebook_intelligence/built_in_toolsets.py @@ -1,6 +1,7 @@ # Copyright (c) Mehmet Bektas from time import time +import json from notebook_intelligence.api import ChatResponse, MarkdownPartData, Toolset import logging import notebook_intelligence.api as nbapi @@ -96,11 +97,49 @@ def _truncate_read_file_output( @nbapi.auto_approve @nbapi.tool -async def create_new_notebook(**args) -> str: +async def list_available_notebook_kernels(**args) -> str: + """Lists Jupyter kernels available in the current frontend environment. + + Use this before creating a notebook when you need a kernel that may differ + from the current notebook context. + """ + response = args["response"] + ui_cmd_response = await response.run_ui_command( + 'notebook-intelligence:list-available-notebook-kernels', + {} + ) + return json.dumps(ui_cmd_response) + + +@nbapi.auto_approve +@nbapi.tool +async def create_new_notebook( + language: str = "python", + kernel_name: str = "", + **args, +) -> str: """Creates a new empty notebook. + + Args: + language: Programming language for the notebook kernel, e.g. python or r. + kernel_name: Explicit Jupyter kernel name to use when creating the notebook. """ response = args["response"] - ui_cmd_response = await response.run_ui_command('notebook-intelligence:create-new-notebook-from-py', {'code': ''}) + request = args.get("request") + effective_language = language or getattr(request, "language", "") or "python" + effective_kernel_name = ( + kernel_name + or getattr(request, "kernel_name", "") + or "" + ) + ui_cmd_response = await response.run_ui_command( + 'notebook-intelligence:create-new-notebook', + { + 'code': '', + 'language': effective_language, + 'kernelName': effective_kernel_name, + } + ) file_path = ui_cmd_response['path'] return f"Created new notebook at {file_path}" @@ -135,7 +174,7 @@ async def add_markdown_cell(source: str, **args) -> str: async def add_code_cell(source: str, **args) -> str: """Adds a code cell to notebook. Args: - source: Python code source + source: Code source for the notebook's current language """ response = args["response"] ui_cmd_response = await response.run_ui_command('notebook-intelligence:add-code-cell-to-active-notebook', {'source': source}) @@ -189,7 +228,7 @@ async def set_cell_type_and_source(cell_index: int, cell_type: str, source: str, Args: cell_index: Zero based cell index cell_type: Cell type (code or markdown) - source: Markdown or Python code source + source: Markdown or code source """ response = args["response"] ui_cmd_response = await response.run_ui_command('notebook-intelligence:set-cell-type-and-source', {"cellIndex": cell_index, "cellType": cell_type, "source": source}) @@ -218,7 +257,7 @@ async def insert_cell(cell_index: int, cell_type: str, source: str, **args) -> s Args: cell_index: Zero based cell index cell_type: Cell type (code or markdown) - source: Markdown or Python code source + source: Markdown or code source """ response = args["response"] ui_cmd_response = await response.run_ui_command('notebook-intelligence:insert-cell-at-index', {"cellIndex": cell_index, "cellType": cell_type, "source": source}) @@ -711,10 +750,12 @@ async def run_command_in_embedded_terminal(command: str, working_directory: str return f"Error running command in embedded terminal: {str(e)}" NOTEBOOK_EDIT_INSTRUCTIONS = """ -You are an assistant that creates and edits Jupyter notebooks. Notebooks are made up of source code cells and markdown cells. Markdown cells have source in markdown format and code cells have source in a specified programming language. If no programming language is specified, then use Python for the language of the code. +You are an assistant that creates and edits Jupyter notebooks. Notebooks are made up of source code cells and markdown cells. Markdown cells have source in markdown format and code cells have source in a specified programming language. If no programming language is specified, then use Python for the language of the code. If the context specifies a kernel or language for the current notebook, keep that kernel and language. Do not silently switch kernels or rewrite the workflow in a different language. If you need to create a notebook use the create_new_notebook tool. If you need to add a code cell to the notebook use the add_code_cell tool. If you need to add a markdown cell to the notebook use the add_markdown_cell tool. +If you need to create a notebook in a language or kernel that is not already established by the current notebook context, call the list_available_notebook_kernels tool first and choose only from the kernels it returns. Do not guess kernel names. + If you need to rename a notebook use the rename_notebook tool. You can refer to cells in notebooks by their index. The first cell in the notebook has index 0, the second cell has index 1, and so on. You can get the number of cells in the notebook using the get_number_of_cells tool. You can get the type and source of a cell using the get_cell_type_and_source tool. You can get the output of a cell using the get_cell_output tool. @@ -792,6 +833,7 @@ async def run_command_in_embedded_terminal(command: str, working_directory: str description="Edit notebook using the JupyterLab notebook editor", provider=None, tools=[ + list_available_notebook_kernels, create_new_notebook, rename_notebook, add_markdown_cell, diff --git a/notebook_intelligence/claude.py b/notebook_intelligence/claude.py index 956702bd..f7c9c592 100644 --- a/notebook_intelligence/claude.py +++ b/notebook_intelligence/claude.py @@ -42,6 +42,7 @@ def _extract_text_from_content(content) -> str: CLAUDE_CODE_MAX_BUFFER_SIZE = 20 * 1024 * 1024 # 20MB JUPYTER_UI_TOOLS_SYSTEM_PROMPT = """You can interact with the JupyterLab UI (notebook / file editor, terminal, etc.) using the tools provided in 'nbi' MCP server. Tools in 'nbi' MCP server, directly interact with the JupyterLab UI, accessing notebooks and files open in the UI. When interacting with JupyterLab UI, use relative file paths for file paths. If the user has asked you to create a notebook, save it afterward. +If you need to create a notebook in a language or kernel that is not already established by the current notebook context, first call the list-available-notebook-kernels tool and choose only from the kernels it returns. Do not guess kernel names. """ @@ -100,6 +101,7 @@ class ClaudeAgentClientStatus(str, Enum): # label and a keyword-heuristic kind rather than masking the raw name. _CLAUDE_TOOLS: dict[str, tuple[str, str]] = { # NBI's MCP toolset (defined in this file via @tool(...)) + "list-available-notebook-kernels": ("Listing notebook kernels", "read"), "create-new-notebook": ("Creating notebook", "edit"), "rename-notebook": ("Renaming notebook", "edit"), "add-markdown-cell": ("Adding markdown cell", "edit"), @@ -1316,12 +1318,36 @@ def resume_session(self, session_id: str) -> None: self.reconnect() -@tool("create-new-notebook", "Creates a new empty notebook.", {}) +@tool( + "list-available-notebook-kernels", + "Lists Jupyter kernels available in the current frontend environment.", + {}, +) +async def list_available_notebook_kernels(args) -> str: + response = get_current_response() + ui_cmd_response = await response.run_ui_command( + 'notebook-intelligence:list-available-notebook-kernels', + {} + ) + return tool_text_response(json.dumps(ui_cmd_response)) + + +@tool( + "create-new-notebook", + "Creates a new empty notebook.", + {"language": str, "kernel_name": str}, +) async def create_new_notebook(args) -> str: """Creates a new empty notebook. """ response = get_current_response() - ui_cmd_response = await response.run_ui_command('notebook-intelligence:create-new-notebook-from-py', {'code': ''}) + request = get_current_request() + language = args.get("language") or getattr(request, "language", "") or "python" + kernel_name = args.get("kernel_name") or getattr(request, "kernel_name", "") or "" + ui_cmd_response = await response.run_ui_command( + 'notebook-intelligence:create-new-notebook', + {'code': '', 'language': language, 'kernelName': kernel_name} + ) file_path = ui_cmd_response['path'] return tool_text_response(f"Created new notebook at {file_path}") @@ -1353,7 +1379,7 @@ async def add_markdown_cell(args) -> str: async def add_code_cell(args) -> str: """Adds a code cell to notebook. Args: - source: Python code source + source: Code source for the notebook's current language """ response = get_current_response() ui_cmd_response = await response.run_ui_command('notebook-intelligence:add-code-cell-to-active-notebook', {'source': args['source']}) @@ -1783,7 +1809,7 @@ def _create_client_options(self) -> ClaudeAgentOptions: self._jupyter_ui_tools_mcp_server = create_sdk_mcp_server( name="nbi", version="1.0.0", - tools=[create_new_notebook, add_markdown_cell, add_code_cell, get_number_of_cells, get_cell_type_and_source, get_cell_output, set_cell_type_and_source, delete_cell, insert_cell, run_cell, save_notebook, rename_notebook, run_command_in_jupyter_terminal, open_file_in_jupyter_ui] + tools=[list_available_notebook_kernels, create_new_notebook, add_markdown_cell, add_code_cell, get_number_of_cells, get_cell_type_and_source, get_cell_output, set_cell_type_and_source, delete_cell, insert_cell, run_cell, save_notebook, rename_notebook, run_command_in_jupyter_terminal, open_file_in_jupyter_ui] ) mcp_servers = {} jupyter_ui_tools_enabled = ClaudeToolType.JupyterUITools in claude_settings.get('tools', []) @@ -1791,7 +1817,7 @@ def _create_client_options(self) -> ClaudeAgentOptions: mcp_servers["nbi"] = self._jupyter_ui_tools_mcp_server allowed_tools = [] if jupyter_ui_tools_enabled: - allowed_tools.extend(["mcp__nbi__create-new-notebook", "mcp__nbi__add-markdown-cell", "mcp__nbi__add-code-cell", "mcp__nbi__get-number-of-cells", "mcp__nbi__get-cell-type-and-source", "mcp__nbi__get-cell-output", "mcp__nbi__set-cell-type-and-source", "mcp__nbi__insert-cell", "mcp__nbi__save-notebook", "mcp__nbi__rename-notebook", "mcp__nbi__open-file-in-jupyter-ui"]) + allowed_tools.extend(["mcp__nbi__list-available-notebook-kernels", "mcp__nbi__create-new-notebook", "mcp__nbi__add-markdown-cell", "mcp__nbi__add-code-cell", "mcp__nbi__get-number-of-cells", "mcp__nbi__get-cell-type-and-source", "mcp__nbi__get-cell-output", "mcp__nbi__set-cell-type-and-source", "mcp__nbi__insert-cell", "mcp__nbi__save-notebook", "mcp__nbi__rename-notebook", "mcp__nbi__open-file-in-jupyter-ui"]) setting_sources = claude_settings.get('setting_sources') chat_model_id = claude_settings.get('chat_model', '').strip() if chat_model_id == "": diff --git a/notebook_intelligence/context_factory.py b/notebook_intelligence/context_factory.py index a33203e7..101f06c1 100644 --- a/notebook_intelligence/context_factory.py +++ b/notebook_intelligence/context_factory.py @@ -8,11 +8,18 @@ class RuleContextFactory: """Factory for creating RuleContext from various sources.""" @staticmethod - def create(filename: str, language: str, chat_mode_id: str, root_dir: str) -> RuleContext: + def create( + filename: str, + language: str, + chat_mode_id: str, + root_dir: str, + kernel_name: str | None = None, + ) -> RuleContext: """Create RuleContext from WebSocket message data.""" return RuleContext( filename=filename, - kernel=language, + language=language, + kernel_name=kernel_name or None, mode=chat_mode_id, directory=os.path.dirname(os.path.join(root_dir, filename)) ) diff --git a/notebook_intelligence/extension.py b/notebook_intelligence/extension.py index d733ebee..d693c1d3 100644 --- a/notebook_intelligence/extension.py +++ b/notebook_intelligence/extension.py @@ -2298,6 +2298,8 @@ def on_message(self, message): chatId = data['chatId'] prompt = data['prompt'] language = data['language'] + kernel_name = data.get('kernelName', '') + kernel_display_name = data.get('kernelDisplayName', '') filename = data['filename'] additionalContext = data.get('additionalContext', []) chat_mode = ChatMode('agent', 'Agent') if data.get('chatMode', 'ask') == 'agent' else ChatMode('ask', 'Ask') @@ -2323,6 +2325,18 @@ def on_message(self, message): current_directory_file_msg = f"{NBI_CONTEXT_PREFIX} '{current_directory}'" if filename != '': current_directory_file_msg += f" and current file is: '{filename}'" + if language: + current_directory_file_msg += ( + f" and active programming language is: '{language}'" + ) + if kernel_name: + current_directory_file_msg += ( + f" with active kernel name: '{kernel_name}'" + ) + if kernel_display_name: + current_directory_file_msg += ( + f" ({kernel_display_name})" + ) chat_history.append({"role": "user", "content": current_directory_file_msg}) token_limit = 100 if ai_service_manager.chat_model is None else ai_service_manager.chat_model.context_window @@ -2548,13 +2562,14 @@ def on_message(self, message): rule_context = self._context_factory.create( filename=filename, language=language, + kernel_name=kernel_name, chat_mode_id=chat_mode.id, root_dir=NotebookIntelligence.root_dir ) # last prompt is added later request_chat_history = chat_history[chat_history_initial_size:-1] if is_claude_code_mode else chat_history[:-1] - coro = ai_service_manager.handle_chat_request(ChatRequest(chat_mode=chat_mode, tool_selection=tool_selection, prompt=prompt, chat_history=request_chat_history, cancel_token=cancel_token, rule_context=rule_context, permission_mode=permission_mode), response_emitter) + coro = ai_service_manager.handle_chat_request(ChatRequest(chat_mode=chat_mode, tool_selection=tool_selection, prompt=prompt, language=language, kernel_name=kernel_name, chat_history=request_chat_history, cancel_token=cancel_token, rule_context=rule_context, permission_mode=permission_mode), response_emitter) thread = threading.Thread(target=self._run_request_thread, args=(coro, messageId)) thread.start() elif messageType == RequestDataType.GenerateCode: @@ -2565,6 +2580,7 @@ def on_message(self, message): suffix = data['suffix'] existing_code = data['existingCode'] language = data['language'] + kernel_name = data.get('kernelName', '') filename = data['filename'] is_claude_code_mode = ai_service_manager.is_claude_code_mode chat_mode = ChatMode('inline-chat', 'Inline Chat') if is_claude_code_mode else ChatMode('ask', 'Ask') @@ -2589,7 +2605,7 @@ def on_message(self, message): root_dir=NotebookIntelligence.root_dir ) - coro = ai_service_manager.handle_chat_request(ChatRequest(chat_mode=chat_mode, prompt=prompt, chat_history=self.chat_history.get_history(chatId), cancel_token=cancel_token, rule_context=rule_context), response_emitter, options={"system_prompt": f"You are an assistant that generates code for '{language}' language. You generate code between existing leading and trailing code sections.{existing_code_message} Be concise and return only code as a response. Don't include leading content or trailing content in your response, they are provided only for context. You can reuse methods and symbols defined in leading and trailing content."}) + coro = ai_service_manager.handle_chat_request(ChatRequest(chat_mode=chat_mode, prompt=prompt, language=language, kernel_name=kernel_name, chat_history=self.chat_history.get_history(chatId), cancel_token=cancel_token, rule_context=rule_context), response_emitter, options={"system_prompt": f"You are an assistant that generates code for '{language}' language. You generate code between existing leading and trailing code sections.{existing_code_message} Be concise and return only code as a response. Don't include leading content or trailing content in your response, they are provided only for context. You can reuse methods and symbols defined in leading and trailing content."}) thread = threading.Thread(target=self._run_request_thread, args=(coro, messageId)) thread.start() elif messageType == RequestDataType.InlineCompletionRequest: diff --git a/notebook_intelligence/rule_manager.py b/notebook_intelligence/rule_manager.py index 3a35211f..3b6328b2 100644 --- a/notebook_intelligence/rule_manager.py +++ b/notebook_intelligence/rule_manager.py @@ -164,13 +164,14 @@ def get_applicable_rules(self, context: RuleContext) -> List[Rule]: applicable_rules = self.ruleset.get_applicable_rules( filename=context.basename, - kernel=context.kernel, + language=context.language, + kernel_name=context.kernel_name, mode=context.mode, directory=context.directory ) log.debug(f"Found {len(applicable_rules)} applicable rules for context: " - f"file={context.basename}, kernel={context.kernel}, mode={context.mode}, dir={context.directory}") + f"file={context.basename}, language={context.language}, kernel_name={context.kernel_name}, mode={context.mode}, dir={context.directory}") return applicable_rules @@ -195,7 +196,11 @@ def validate_rule_file(self, filepath: str) -> Dict[str, Any]: if rule.apply not in ['always', 'auto', 'manual']: validation_result['warnings'].append(f"Unknown apply mode: {rule.apply}") - if not rule.scope.file_patterns and not rule.scope.kernels: + if ( + not rule.scope.file_patterns + and not rule.scope.languages + and not rule.scope.kernel_names + ): validation_result['warnings'].append("Rule has no scope restrictions, will apply to all contexts") except FileNotFoundError: diff --git a/notebook_intelligence/ruleset.py b/notebook_intelligence/ruleset.py index ff97ca06..8835b876 100644 --- a/notebook_intelligence/ruleset.py +++ b/notebook_intelligence/ruleset.py @@ -12,7 +12,8 @@ class RuleScope: """Defines the scope where a rule applies.""" file_patterns: List[str] = field(default_factory=list) - kernels: List[str] = field(default_factory=list) + languages: List[str] = field(default_factory=list) + kernel_names: List[str] = field(default_factory=list) directory_patterns: List[str] = field(default_factory=list) cell_types: Optional[List[str]] = None @@ -42,12 +43,17 @@ def matches_directory(self, directory: str) -> bool: return True return False - def matches_kernel(self, kernel_name: str) -> bool: - """Check if the kernel matches any of the specified kernels.""" - if not self.kernels: - return True # No kernels specified means matches all kernels - - return kernel_name in self.kernels + def matches_language(self, language: str) -> bool: + """Check if the language matches any of the specified languages.""" + if not self.languages: + return True + return language in self.languages + + def matches_kernel_name(self, kernel_name: str) -> bool: + """Check if the kernel name matches any of the specified kernel names.""" + if not self.kernel_names: + return True + return kernel_name in self.kernel_names @dataclass class Rule: @@ -100,9 +106,16 @@ def from_file(cls, filepath: str, mode: Optional[str] = None) -> 'Rule': apply_mode = 'always' # Create scope object + if 'kernels' in scope_data: + raise ValueError( + f"Invalid rule frontmatter in {filepath}: " + "scope.kernels is no longer supported; use scope.languages or scope.kernel_names" + ) + scope = RuleScope( file_patterns=scope_data.get('file_patterns', []), - kernels=scope_data.get('kernels', []), + languages=scope_data.get('languages', []), + kernel_names=scope_data.get('kernel_names', []), directory_patterns=scope_data.get('directory_patterns', []), cell_types=scope_data.get('cell_types') ) @@ -117,7 +130,15 @@ def from_file(cls, filepath: str, mode: Optional[str] = None) -> 'Rule': priority=priority ) - def matches_context(self, filename: str, kernel: str = None, cell_type: str = None, mode: str = None, directory: str = None) -> bool: + def matches_context( + self, + filename: str, + language: str = None, + kernel_name: str = None, + cell_type: str = None, + mode: str = None, + directory: str = None, + ) -> bool: """Check if this rule applies to the given context.""" if not self.active: return False @@ -130,7 +151,10 @@ def matches_context(self, filename: str, kernel: str = None, cell_type: str = No if not self.scope.matches_file(filename): return False - if kernel and not self.scope.matches_kernel(kernel): + if language and not self.scope.matches_language(language): + return False + + if kernel_name and not self.scope.matches_kernel_name(kernel_name): return False if directory and not self.scope.matches_directory(directory): @@ -145,7 +169,8 @@ def to_dict(self) -> Dict[str, Any]: 'apply': self.apply, 'scope': { 'file_patterns': self.scope.file_patterns, - 'kernels': self.scope.kernels, + 'languages': self.scope.languages, + 'kernel_names': self.scope.kernel_names, 'directory_patterns': self.scope.directory_patterns, 'cell_types': self.scope.cell_types }, @@ -159,9 +184,15 @@ def to_dict(self) -> Dict[str, Any]: def from_dict(cls, data: Dict[str, Any]) -> 'Rule': """Create rule from dictionary.""" scope_data = data.get('scope', {}) + if 'kernels' in scope_data: + raise ValueError( + "Invalid rule data: scope.kernels is no longer supported; " + "use scope.languages or scope.kernel_names" + ) scope = RuleScope( file_patterns=scope_data.get('file_patterns', []), - kernels=scope_data.get('kernels', []), + languages=scope_data.get('languages', []), + kernel_names=scope_data.get('kernel_names', []), directory_patterns=scope_data.get('directory_patterns', []), cell_types=scope_data.get('cell_types') ) @@ -191,20 +222,41 @@ def add_rule(self, rule: Rule) -> None: else: self.global_rules.append(rule) - def get_applicable_rules(self, filename: str, kernel: str = None, - cell_type: str = None, mode: str = None, directory: str = None) -> List[Rule]: + def get_applicable_rules( + self, + filename: str, + language: str = None, + kernel_name: str = None, + cell_type: str = None, + mode: str = None, + directory: str = None, + ) -> List[Rule]: """Get all rules that apply to the given context.""" applicable_rules = [] # Add applicable global rules for rule in self.global_rules: - if rule.matches_context(filename, kernel, cell_type, mode, directory): + if rule.matches_context( + filename, + language, + kernel_name, + cell_type, + mode, + directory, + ): applicable_rules.append(rule) # Add applicable mode-specific rules if mode and mode in self.mode_rules: for rule in self.mode_rules[mode]: - if rule.matches_context(filename, kernel, cell_type, mode, directory): + if rule.matches_context( + filename, + language, + kernel_name, + cell_type, + mode, + directory, + ): applicable_rules.append(rule) # Sort by priority (lower number = higher priority), then by filename @@ -267,7 +319,8 @@ def from_dict(cls, data: Dict[str, Any]) -> 'RuleSet': class RuleContext: """Context information for rule matching.""" filename: str - kernel: Optional[str] = None + language: Optional[str] = None + kernel_name: Optional[str] = None mode: Optional[str] = None directory: Optional[str] = None diff --git a/src/api.ts b/src/api.ts index 203dd3a0..4b7a4416 100644 --- a/src/api.ts +++ b/src/api.ts @@ -1255,6 +1255,8 @@ export class NBIAPI { chatId: string, prompt: string, language: string, + kernelName: string, + kernelDisplayName: string, currentDirectory: string, filename: string, additionalContext: IContextItem[], @@ -1272,6 +1274,8 @@ export class NBIAPI { chatId, prompt, language, + kernelName, + kernelDisplayName, currentDirectory, filename, additionalContext, diff --git a/src/chat-sidebar.tsx b/src/chat-sidebar.tsx index 7e766434..66dcad85 100644 --- a/src/chat-sidebar.tsx +++ b/src/chat-sidebar.tsx @@ -111,6 +111,8 @@ export interface IRunChatCompletionRequest { type: RunChatCompletionType; content: string; language?: string; + kernelName?: string; + kernelDisplayName?: string; currentDirectory?: string; filename?: string; prefix?: string; @@ -1188,6 +1190,8 @@ async function submitCompletionRequest( request.chatId, request.content, request.language || 'python', + request.kernelName || '', + request.kernelDisplayName || '', request.currentDirectory || '', request.filename || '', request.additionalContext || [], @@ -1203,6 +1207,8 @@ async function submitCompletionRequest( request.chatId, request.content, request.language || 'python', + request.kernelName || '', + request.kernelDisplayName || '', request.currentDirectory || '', request.filename || '', [], @@ -3089,6 +3095,8 @@ function SidebarComponent(props: any) { type: RunChatCompletionType.Chat, content: extractedPrompt, language: activeDocInfo.language, + kernelName: activeDocInfo.kernelName, + kernelDisplayName: activeDocInfo.kernelDisplayName, currentDirectory: props.getCurrentDirectory(), filename: activeDocInfo.filePath, additionalContext, @@ -3525,6 +3533,11 @@ function SidebarComponent(props: any) { externalActiveDocInfo?.filePath?.endsWith('.ipynb') ? externalActiveDocInfo.filePath : null; + request.language = request.language || externalActiveDocInfo?.language; + request.kernelName = + request.kernelName || externalActiveDocInfo?.kernelName; + request.kernelDisplayName = + request.kernelDisplayName || externalActiveDocInfo?.kernelDisplayName; const hideInChat = !!request.hideInChat; const newList = hideInChat ? chatMessages diff --git a/src/command-ids.ts b/src/command-ids.ts index ef6913e4..851a3412 100644 --- a/src/command-ids.ts +++ b/src/command-ids.ts @@ -9,8 +9,9 @@ export namespace CommandIDs { export const insertAtCursor = 'notebook-intelligence:insert-at-cursor'; export const addCodeAsNewCell = 'notebook-intelligence:add-code-as-new-cell'; export const createNewFile = 'notebook-intelligence:create-new-file'; - export const createNewNotebookFromPython = - 'notebook-intelligence:create-new-notebook-from-py'; + export const createNewNotebook = 'notebook-intelligence:create-new-notebook'; + export const listAvailableNotebookKernels = + 'notebook-intelligence:list-available-notebook-kernels'; export const renameNotebook = 'notebook-intelligence:rename-notebook'; export const addCodeCellToNotebook = 'notebook-intelligence:add-code-cell-to-notebook'; diff --git a/src/index.ts b/src/index.ts index 37760b33..a83c1132 100644 --- a/src/index.ts +++ b/src/index.ts @@ -126,6 +126,13 @@ import { ITerminalTracker } from '@jupyterlab/terminal'; import { Token } from '@lumino/coreutils'; import { NotebookGenerationToolbarExtension } from './notebook-generation-toolbar'; import { attachTerminalDragDrop } from './terminal-drag'; +import { + DEFAULT_NOTEBOOK_KERNEL, + NotebookKernelNotFoundError, + findKernelProfile, + listKernelProfiles, + normalizeNotebookLanguage +} from './notebook-kernels'; import { CommandIDs } from './command-ids'; @@ -389,9 +396,20 @@ class ActiveDocumentWatcher { const np = activeWidget as NotebookPanel; activeDocumentInfo.filename = np.sessionContext.name; activeDocumentInfo.filePath = np.sessionContext.path; - activeDocumentInfo.language = - (np.model?.sharedModel?.metadata?.kernelspec?.language as string) || - 'python'; + const kernelspec = np.model?.sharedModel?.metadata?.kernelspec as + | { + language?: string; + name?: string; + display_name?: string; + } + | undefined; + activeDocumentInfo.language = normalizeNotebookLanguage( + kernelspec?.language + ); + activeDocumentInfo.kernelName = + kernelspec?.name || DEFAULT_NOTEBOOK_KERNEL.kernelName; + activeDocumentInfo.kernelDisplayName = + kernelspec?.display_name || DEFAULT_NOTEBOOK_KERNEL.displayName; const { activeCellIndex, activeCell } = np.content; activeDocumentInfo.activeCellIndex = activeCellIndex; activeDocumentInfo.selection = activeCell?.editor?.getSelection(); @@ -406,6 +424,8 @@ class ActiveDocumentWatcher { contentsModel.mimetype ) || ActiveDocumentWatcher._languageRegistry.findByFileName(fileName); activeDocumentInfo.language = language?.name || 'unknown'; + activeDocumentInfo.kernelName = undefined; + activeDocumentInfo.kernelDisplayName = undefined; activeDocumentInfo.filename = fileName; activeDocumentInfo.filePath = filePath; if (activeWidget instanceof FileEditorWidget) { @@ -418,6 +438,8 @@ class ActiveDocumentWatcher { activeDocumentInfo.filename = ''; activeDocumentInfo.filePath = ''; activeDocumentInfo.language = ''; + activeDocumentInfo.kernelName = undefined; + activeDocumentInfo.kernelDisplayName = undefined; } } @@ -443,6 +465,8 @@ class ActiveDocumentWatcher { lhs.filename !== rhs.filename || lhs.filePath !== rhs.filePath || lhs.language !== rhs.language || + lhs.kernelName !== rhs.kernelName || + lhs.kernelDisplayName !== rhs.kernelDisplayName || lhs.activeCellIndex !== rhs.activeCellIndex || !compareSelections(lhs.selection, rhs.selection) ); @@ -509,6 +533,8 @@ class ActiveDocumentWatcher { static activeDocumentInfo: IActiveDocumentInfo = { language: 'python', + kernelName: DEFAULT_NOTEBOOK_KERNEL.kernelName, + kernelDisplayName: DEFAULT_NOTEBOOK_KERNEL.displayName, filename: 'nb-doesnt-exist.ipynb', filePath: 'nb-doesnt-exist.ipynb', activeWidget: null, @@ -1215,21 +1241,26 @@ const plugin: JupyterFrontEndPlugin = { } }); - app.commands.addCommand(CommandIDs.createNewNotebookFromPython, { + app.commands.addCommand(CommandIDs.createNewNotebook, { execute: async args => { - let pythonKernelSpec = null; const contents = new ContentsManager(); const kernels = new KernelSpecManager(); await kernels.ready; - const kernelspecs = kernels.specs?.kernelspecs; - if (kernelspecs) { - for (const key in kernelspecs) { - const kernelspec = kernelspecs[key]; - if (kernelspec?.language === 'python') { - pythonKernelSpec = kernelspec; - break; - } + let profile; + try { + profile = findKernelProfile(kernels.specs?.kernelspecs, { + language: args.language as string | undefined, + kernelName: args.kernelName as string | undefined + }); + } catch (error) { + if (error instanceof NotebookKernelNotFoundError) { + app.commands.execute('apputils:notify', { + message: error.message, + type: 'error', + options: { autoClose: true } + }); } + throw error; } const newNBFile = await contents.newUntitled({ @@ -1237,15 +1268,16 @@ const plugin: JupyterFrontEndPlugin = { path: defaultBrowser?.model.path }); const nbFileContent = structuredClone(emptyNotebookContent); - if (pythonKernelSpec) { - nbFileContent.metadata = { - kernelspec: { - language: 'python', - name: pythonKernelSpec.name, - display_name: pythonKernelSpec.display_name - } - }; - } + nbFileContent.metadata = { + kernelspec: { + language: profile.language, + name: profile.kernelName, + display_name: profile.displayName + }, + language_info: { + name: profile.language + } + }; if (args.code) { nbFileContent.cells.push({ @@ -1269,6 +1301,16 @@ const plugin: JupyterFrontEndPlugin = { } }); + app.commands.addCommand(CommandIDs.listAvailableNotebookKernels, { + execute: async () => { + const kernels = new KernelSpecManager(); + await kernels.ready; + return { + kernels: listKernelProfiles(kernels.specs?.kernelspecs) + }; + } + }); + app.commands.addCommand(CommandIDs.renameNotebook, { execute: async args => { const activeWidget = app.shell.currentWidget; diff --git a/src/markdown-renderer.tsx b/src/markdown-renderer.tsx index e39f64ae..c89e1c3a 100644 --- a/src/markdown-renderer.tsx +++ b/src/markdown-renderer.tsx @@ -95,10 +95,10 @@ export function MarkdownRenderer({ }; const handleCreateNewNotebookClick = () => { - app.commands.execute( - 'notebook-intelligence:create-new-notebook-from-py', - { language, code: codeString } - ); + app.commands.execute('notebook-intelligence:create-new-notebook', { + language, + code: codeString + }); }; if (inline || !match) { diff --git a/src/notebook-kernels.ts b/src/notebook-kernels.ts new file mode 100644 index 00000000..019a6b8b --- /dev/null +++ b/src/notebook-kernels.ts @@ -0,0 +1,97 @@ +// Copyright (c) Mehmet Bektas + +import { KernelSpec } from '@jupyterlab/services'; + +export interface INotebookKernelProfile { + language: string; + kernelName: string; + displayName: string; +} + +export class NotebookKernelNotFoundError extends Error { + readonly requestedLanguage: string; + readonly requestedKernelName: string; + + constructor(options?: { language?: string; kernelName?: string }) { + const requestedLanguage = normalizeNotebookLanguage(options?.language); + const requestedKernelName = (options?.kernelName ?? '').trim(); + const detail = requestedKernelName + ? `kernel "${requestedKernelName}"` + : `language "${requestedLanguage}"`; + super(`No installed Jupyter kernel matches ${detail}.`); + this.name = 'NotebookKernelNotFoundError'; + this.requestedLanguage = requestedLanguage; + this.requestedKernelName = requestedKernelName; + } +} + +export const DEFAULT_NOTEBOOK_KERNEL: INotebookKernelProfile = Object.freeze({ + language: 'python', + kernelName: 'python3', + displayName: 'Python 3 (ipykernel)' +}); + +export function normalizeNotebookLanguage(raw: string | undefined): string { + const language = (raw ?? '').trim().toLowerCase(); + if (!language) { + return DEFAULT_NOTEBOOK_KERNEL.language; + } + if (language === 'py') { + return 'python'; + } + return language; +} + +export function findKernelProfile( + specs: Record | undefined, + options?: { language?: string; kernelName?: string } +): INotebookKernelProfile { + const requestedKernelName = (options?.kernelName ?? '').trim(); + if (requestedKernelName && specs?.[requestedKernelName]) { + const spec = specs[requestedKernelName]; + return { + language: normalizeNotebookLanguage(spec.language), + kernelName: spec.name, + displayName: spec.display_name + }; + } + + const requestedLanguage = normalizeNotebookLanguage(options?.language); + if (specs) { + for (const key of Object.keys(specs)) { + const spec = specs[key]; + if (normalizeNotebookLanguage(spec.language) === requestedLanguage) { + return { + language: normalizeNotebookLanguage(spec.language), + kernelName: spec.name, + displayName: spec.display_name + }; + } + } + } + + if (requestedLanguage === DEFAULT_NOTEBOOK_KERNEL.language) { + return DEFAULT_NOTEBOOK_KERNEL; + } + + throw new NotebookKernelNotFoundError(options); +} + +export function listKernelProfiles( + specs: Record | undefined +): INotebookKernelProfile[] { + if (!specs) { + return []; + } + + return Object.keys(specs) + .sort((lhs, rhs) => lhs.localeCompare(rhs)) + .map(kernelName => { + const spec = specs[kernelName]; + return { + language: normalizeNotebookLanguage(spec.language), + kernelName: spec.name, + displayName: spec.display_name + }; + }); +} diff --git a/src/tokens.ts b/src/tokens.ts index 8998ca45..8b962105 100644 --- a/src/tokens.ts +++ b/src/tokens.ts @@ -7,6 +7,8 @@ import { Token } from '@lumino/coreutils'; export interface IActiveDocumentInfo { activeWidget: Widget | null; language: string; + kernelName?: string; + kernelDisplayName?: string; filename: string; filePath: string; activeCellIndex: number; diff --git a/tests/conftest.py b/tests/conftest.py index c75600e7..272e574c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -91,7 +91,8 @@ def sample_rule_context(): """Mock rule context for testing.""" return RuleContext( filename="test.ipynb", - kernel="python3", + language="python", + kernel_name="python3", mode="ask" ) @@ -101,7 +102,8 @@ def python_file_context(): """Mock Python file context for testing.""" return RuleContext( filename="test.py", - kernel="python3", + language="python", + kernel_name="python3", mode="agent" ) @@ -141,7 +143,9 @@ def populated_rules_directory(temp_rules_directory): file_patterns: - "*.ipynb" - "*.py" - kernels: + languages: + - python + kernel_names: - python3 active: true priority: 0 diff --git a/tests/fixtures/rules/01-test-global.md b/tests/fixtures/rules/01-test-global.md index f52fd34d..b7c0e8d6 100644 --- a/tests/fixtures/rules/01-test-global.md +++ b/tests/fixtures/rules/01-test-global.md @@ -4,7 +4,9 @@ scope: file_patterns: - '*.ipynb' - '*.py' - kernels: + languages: + - python + kernel_names: - python3 active: true priority: 0 diff --git a/tests/test_base_chat_participant_integration.py b/tests/test_base_chat_participant_integration.py index 99d1efbc..03cbb5b3 100644 --- a/tests/test_base_chat_participant_integration.py +++ b/tests/test_base_chat_participant_integration.py @@ -1,6 +1,8 @@ import asyncio from unittest.mock import Mock, AsyncMock from notebook_intelligence.base_chat_participant import BaseChatParticipant +from notebook_intelligence.base_chat_participant import CreateNewNotebookTool +from notebook_intelligence.base_chat_participant import ListAvailableNotebookKernelsTool from notebook_intelligence.api import ChatRequest, ChatResponse, ChatMode, CancelToken from notebook_intelligence.ruleset import RuleContext from notebook_intelligence.rule_injector import RuleInjector @@ -55,7 +57,8 @@ def test_handle_ask_mode_chat_request_with_rules(self): # Create request with rule context rule_context = RuleContext( filename="test.ipynb", - kernel="python3", + language="python", + kernel_name="python3", mode="ask" ) @@ -103,7 +106,8 @@ def test_handle_chat_request_agent_mode_with_rules(self): # Create request for agent mode rule_context = RuleContext( filename="test.py", - kernel="python3", + language="python", + kernel_name="python3", mode="agent" ) @@ -149,3 +153,83 @@ def test_handle_chat_request_agent_mode_with_rules(self): assert "system_prompt" in options assert options["system_prompt"] == "Enhanced agent prompt" + + def test_handle_ask_mode_new_notebook_uses_request_language_and_kernel(self): + participant = BaseChatParticipant() + + mock_chat_model = Mock() + mock_host = Mock() + mock_host.chat_model = mock_chat_model + + request = ChatRequest( + host=mock_host, + chat_mode=ChatMode("ask", "Ask"), + command="newNotebook", + prompt="Create a notebook in lang-x", + language="lang-x", + kernel_name="kernel-x", + chat_history=[{"role": "user", "content": "Create a notebook in lang-x"}], + cancel_token=Mock(spec=CancelToken), + rule_context=RuleContext( + filename="test.ipynb", + language="lang-x", + kernel_name="kernel-x", + mode="ask" + ) + ) + + response = Mock(spec=ChatResponse) + response.run_ui_command = AsyncMock( + side_effect=[ + {"path": "Untitled.ipynb"}, + {"ok": True}, + {"ok": True}, + ] + ) + response.stream = Mock() + response.finish = Mock() + + participant.generate_code_cell = AsyncMock(return_value='emit("hi")') + participant.generate_markdown_for_code = AsyncMock(return_value="# Lang X") + + asyncio.run(participant.handle_ask_mode_chat_request(request, response)) + + first_call = response.run_ui_command.await_args_list[0] + assert first_call.args == ( + 'notebook-intelligence:create-new-notebook', + {'code': '', 'language': 'lang-x', 'kernelName': 'kernel-x'} + ) + + def test_list_available_notebook_kernels_tool_reads_frontend_environment(self): + tool = ListAvailableNotebookKernelsTool() + request = ChatRequest() + response = Mock(spec=ChatResponse) + response.run_ui_command = AsyncMock( + return_value={ + "kernels": [ + { + "language": "lang-a", + "kernelName": "kernel-a", + "displayName": "Kernel A", + }, + { + "language": "lang-b", + "kernelName": "kernel-b", + "displayName": "Kernel B", + }, + ] + } + ) + + result = asyncio.run(tool.handle_tool_call(request, response, {}, {})) + + response.run_ui_command.assert_awaited_once_with( + "notebook-intelligence:list-available-notebook-kernels", + {}, + ) + assert '"kernelName": "kernel-a"' in result + assert '"kernelName": "kernel-b"' in result + + def test_get_tool_by_name_returns_kernel_listing_tool(self): + tool = BaseChatParticipant.get_tool_by_name("list_available_notebook_kernels") + assert isinstance(tool, ListAvailableNotebookKernelsTool) diff --git a/tests/test_context_factory.py b/tests/test_context_factory.py index 91d3a3be..ac379b19 100644 --- a/tests/test_context_factory.py +++ b/tests/test_context_factory.py @@ -15,11 +15,13 @@ def test_create(self): filename=filename, language=language, chat_mode_id=chat_mode_id, - root_dir=root_dir + root_dir=root_dir, + kernel_name="python3", ) assert context.filename == filename - assert context.kernel == language + assert context.language == "python" + assert context.kernel_name == "python3" assert context.mode == chat_mode_id assert context.directory == "/workspace" @@ -34,10 +36,12 @@ def test_create_with_subdirectory(self): filename=filename, language=language, chat_mode_id=chat_mode_id, - root_dir=root_dir + root_dir=root_dir, + kernel_name="python3", ) assert context.filename == filename - assert context.kernel == language + assert context.language == "python" + assert context.kernel_name == "python3" assert context.mode == chat_mode_id assert context.directory == "/workspace/notebooks" diff --git a/tests/test_end_to_end_rule_integration.py b/tests/test_end_to_end_rule_integration.py index 75d0c9c3..f9d6cd5b 100644 --- a/tests/test_end_to_end_rule_integration.py +++ b/tests/test_end_to_end_rule_integration.py @@ -108,7 +108,8 @@ def test_rule_application_for_notebook_ask_mode(self, temp_rules_dir): # Create context for notebook in ask mode context = RuleContext( filename="test.ipynb", - kernel="python3", + language="python", + kernel_name="python3", mode="ask" ) @@ -129,7 +130,8 @@ def test_rule_application_for_python_agent_mode(self, temp_rules_dir): # Create context for Python file in agent mode context = RuleContext( filename="script.py", - kernel="python3", + language="python", + kernel_name="python3", mode="agent" ) @@ -149,7 +151,8 @@ def test_rule_formatting_for_llm(self, temp_rules_dir): context = RuleContext( filename="test.ipynb", - kernel="python3", + language="python", + kernel_name="python3", mode="ask" ) @@ -171,12 +174,14 @@ def test_context_factory_creates_correct_context(self): context = factory.create( filename="notebooks/analysis.ipynb", language="python", + kernel_name="python3", chat_mode_id="ask", root_dir="/workspace" ) assert context.filename == "notebooks/analysis.ipynb" - assert context.kernel == "python" + assert context.language == "python" + assert context.kernel_name == "python3" assert context.mode == "ask" assert context.directory == "/workspace/notebooks" assert context.basename == "analysis.ipynb" @@ -191,7 +196,8 @@ def test_rule_injector_end_to_end(self, temp_rules_dir): # Create mock request with context context = RuleContext( filename="test.ipynb", - kernel="python3", + language="python", + kernel_name="python3", mode="ask" ) @@ -240,7 +246,8 @@ def test_rule_priority_ordering(self, temp_rules_dir): context = RuleContext( filename="test.ipynb", - kernel="python3", + language="python", + kernel_name="python3", mode="ask" ) @@ -275,7 +282,8 @@ def test_inactive_rules_not_applied(self, temp_rules_dir): context = RuleContext( filename="test.ipynb", - kernel="python3", + language="python", + kernel_name="python3", mode="ask" ) @@ -293,7 +301,8 @@ def test_file_pattern_matching(self, temp_rules_dir): # Test with .txt file (should not match any rules) context = RuleContext( filename="document.txt", - kernel="text", + language="text", + kernel_name="text", mode="ask" ) diff --git a/tests/test_models.py b/tests/test_models.py index 5a3374d5..e9e47004 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -21,20 +21,35 @@ def test_matches_file_no_patterns_matches_all(self): assert scope.matches_file("notebook.ipynb") is True assert scope.matches_file("data.csv") is True - def test_matches_kernel(self): - scope = RuleScope(kernels=["python3", "python"]) + def test_matches_language(self): + scope = RuleScope(languages=["python", "r"]) - assert scope.matches_kernel("python3") is True - assert scope.matches_kernel("python") is True - assert scope.matches_kernel("r") is False - assert scope.matches_kernel("julia") is False - - def test_matches_kernel_no_kernels_matches_all(self): - scope = RuleScope() # No kernels specified + assert scope.matches_language("python") is True + assert scope.matches_language("r") is True + assert scope.matches_language("python3") is False + assert scope.matches_language("julia") is False + + def test_matches_language_no_languages_matches_all(self): + scope = RuleScope() + + assert scope.matches_language("python") is True + assert scope.matches_language("r") is True + assert scope.matches_language("julia") is True + + def test_matches_kernel_name(self): + scope = RuleScope(kernel_names=["python3", "ir"]) + + assert scope.matches_kernel_name("python3") is True + assert scope.matches_kernel_name("ir") is True + assert scope.matches_kernel_name("python") is False + assert scope.matches_kernel_name("r") is False + + def test_matches_kernel_name_no_kernel_names_matches_all(self): + scope = RuleScope() - assert scope.matches_kernel("python3") is True - assert scope.matches_kernel("r") is True - assert scope.matches_kernel("julia") is True + assert scope.matches_kernel_name("python3") is True + assert scope.matches_kernel_name("r") is True + assert scope.matches_kernel_name("julia-1.10") is True class TestRule: @@ -45,7 +60,9 @@ def test_from_file_with_valid_frontmatter(self, tmp_path): file_patterns: - "*.py" - "*.ipynb" - kernels: + languages: + - python + kernel_names: - python3 cell_types: - code @@ -63,7 +80,8 @@ def test_from_file_with_valid_frontmatter(self, tmp_path): assert rule.filename == "test_rule.md" assert rule.apply == "always" assert rule.scope.file_patterns == ["*.py", "*.ipynb"] - assert rule.scope.kernels == ["python3"] + assert rule.scope.languages == ["python"] + assert rule.scope.kernel_names == ["python3"] assert rule.scope.cell_types == ["code"] assert rule.active is True assert rule.priority == 5 @@ -122,7 +140,7 @@ def test_from_file_nonexistent_file_raises_error(self): Rule.from_file("nonexistent_file.md") def test_matches_context_active_rule(self): - scope = RuleScope(file_patterns=["*.py"], kernels=["python3"]) + scope = RuleScope(file_patterns=["*.py"], languages=["python"], kernel_names=["python3"]) rule = Rule( filename="test.md", apply="always", @@ -131,9 +149,9 @@ def test_matches_context_active_rule(self): content="Test content" ) - assert rule.matches_context("test.py", "python3") is True - assert rule.matches_context("test.ipynb", "python3") is False - assert rule.matches_context("test.py", "r") is False + assert rule.matches_context("test.py", "python", "python3") is True + assert rule.matches_context("test.ipynb", "python", "python3") is False + assert rule.matches_context("test.py", "r", "ir") is False def test_matches_context_inactive_rule(self): scope = RuleScope(file_patterns=["*.py"]) @@ -163,7 +181,7 @@ def test_matches_context_with_mode(self): assert rule.matches_context("test.py") is True # No mode specified def test_to_dict(self): - scope = RuleScope(file_patterns=["*.py"], kernels=["python3"]) + scope = RuleScope(file_patterns=["*.py"], languages=["python"], kernel_names=["python3"]) rule = Rule( filename="test.md", apply="always", @@ -181,7 +199,8 @@ def test_to_dict(self): 'apply': 'always', 'scope': { 'file_patterns': ['*.py'], - 'kernels': ['python3'], + 'languages': ['python'], + 'kernel_names': ['python3'], 'cell_types': None, 'directory_patterns': [] }, @@ -199,7 +218,8 @@ def test_from_dict(self): 'apply': 'auto', 'scope': { 'file_patterns': ['*.ipynb'], - 'kernels': ['python3'], + 'languages': ['python'], + 'kernel_names': ['python3'], 'cell_types': ['code'] }, 'active': False, @@ -213,7 +233,8 @@ def test_from_dict(self): assert rule.filename == 'test.md' assert rule.apply == 'auto' assert rule.scope.file_patterns == ['*.ipynb'] - assert rule.scope.kernels == ['python3'] + assert rule.scope.languages == ['python'] + assert rule.scope.kernel_names == ['python3'] assert rule.scope.cell_types == ['code'] assert rule.active is False assert rule.content == 'Test content' diff --git a/tests/test_rule_auto_reload.py b/tests/test_rule_auto_reload.py index 874e6a48..d93f530f 100644 --- a/tests/test_rule_auto_reload.py +++ b/tests/test_rule_auto_reload.py @@ -157,7 +157,7 @@ def test_get_applicable_rules_auto_reloads(self, tmp_path): manager = RuleManager(str(rules_dir)) # First call - loads rules - context = RuleContext(filename="test.ipynb", kernel="python", mode="ask") + context = RuleContext(filename="test.ipynb", language="python", kernel_name="python3", mode="ask") rules1 = manager.get_applicable_rules(context) assert len(rules1) == 1 assert rules1[0].content == "Original content" @@ -180,7 +180,7 @@ def test_auto_reload_updates_last_modified_time(self, tmp_path): with patch.dict(os.environ, {'NBI_RULES_AUTO_RELOAD': 'true'}): manager = RuleManager(str(rules_dir)) - context = RuleContext(filename="test.ipynb", kernel="python", mode="ask") + context = RuleContext(filename="test.ipynb", language="python", kernel_name="python3", mode="ask") # Initial load manager.get_applicable_rules(context) @@ -208,7 +208,7 @@ def test_auto_reload_with_nested_directories(self, tmp_path): with patch.dict(os.environ, {'NBI_RULES_AUTO_RELOAD': 'true'}): manager = RuleManager(str(rules_dir)) - context = RuleContext(filename="test.ipynb", kernel="python", mode="ask") + context = RuleContext(filename="test.ipynb", language="python", kernel_name="python3", mode="ask") # Initial load rules1 = manager.get_applicable_rules(context) @@ -231,7 +231,7 @@ def test_auto_reload_doesnt_trigger_without_changes(self, tmp_path): with patch.dict(os.environ, {'NBI_RULES_AUTO_RELOAD': 'true'}): manager = RuleManager(str(rules_dir)) - context = RuleContext(filename="test.ipynb", kernel="python", mode="ask") + context = RuleContext(filename="test.ipynb", language="python", kernel_name="python3", mode="ask") # First call manager.get_applicable_rules(context) diff --git a/tests/test_websocket_handler_integration.py b/tests/test_websocket_handler_integration.py index f3ba3c03..0da21867 100644 --- a/tests/test_websocket_handler_integration.py +++ b/tests/test_websocket_handler_integration.py @@ -100,6 +100,7 @@ def test_on_message_chat_request_creates_context(self, mock_thread, mock_nb_inte mock_factory.create.assert_called_once_with( filename='test.ipynb', language='python', + kernel_name='', chat_mode_id='ask', root_dir='/workspace' ) @@ -115,6 +116,9 @@ def test_on_message_chat_request_creates_context(self, mock_thread, mock_nb_inte # We can't easily inspect the ChatRequest object, but we can verify # that the thread was created with the right target assert mock_thread.call_args[1]['target'] is not None + chat_request = mock_ai_manager.handle_chat_request.call_args[0][0] + assert chat_request.language == 'python' + assert chat_request.kernel_name == '' @patch('notebook_intelligence.extension.ai_service_manager') @patch('notebook_intelligence.extension.NotebookIntelligence') @@ -164,6 +168,8 @@ def test_on_message_generate_code_creates_context(self, mock_thread, mock_nb_int # Verify thread was started mock_thread.assert_called_once() + chat_request = mock_ai_manager.handle_chat_request.call_args[0][0] + assert chat_request.language == 'python' @patch('notebook_intelligence.extension.ai_service_manager') @patch('notebook_intelligence.extension.NotebookIntelligence') @@ -211,6 +217,7 @@ def test_on_message_agent_mode_creates_context(self, mock_thread, mock_nb_intel, mock_factory.create.assert_called_once_with( filename='notebook.ipynb', language='python', + kernel_name='', chat_mode_id='agent', root_dir='/workspace' )