def strip_reasoning_fields(messages: list[dict]) -> list[dict]:
    """Build an outbound copy of ``messages`` without reasoning fields.

    ``reasoning_content`` and ``reasoning`` are output-only keys emitted by
    the model. They end up in the stored chat history, which is replayed in
    full on the following turn, and strict-validating endpoints (for example
    Databricks model serving, whose pydantic models use ``extra="forbid"``)
    reject requests that still carry them.

    Args:
        messages: Chat history entries. Dict entries are filtered; any entry
            that is not a dict is passed through as-is.

    Returns:
        A new list. Every dict entry is replaced by a new dict holding the
        same key/value pairs minus the two reasoning keys, in the original
        key order. Neither the input list nor its dicts are modified; the
        copy is shallow, so values are shared with the input.
    """
    output_only_keys = ("reasoning_content", "reasoning")

    sanitized = []
    for message in messages:
        if not isinstance(message, dict):
            # Only dict entries can be filtered by key; forward anything else
            # unchanged.
            sanitized.append(message)
            continue

        kept = {}
        for key, value in message.items():
            if key not in output_only_keys:
                kept[key] = value
        sanitized.append(kept)
    return sanitized
def strip_reasoning_fields(messages: list[dict]) -> list[dict]:
    """Return a request-safe copy of ``messages``.

    The model emits ``reasoning_content`` / ``reasoning`` as output-only
    fields. Because the stored chat history is replayed on the next turn,
    those keys would otherwise be sent back to the endpoint, and
    strict-validating servers (e.g. Databricks model serving, which uses
    pydantic ``extra="forbid"``) reject such requests.

    Args:
        messages: Chat history entries; dict entries are filtered by key and
            non-dict entries are returned untouched.

    Returns:
        A new list in which each dict entry is a fresh dict without the
        reasoning keys. The caller's list and dicts are left unmodified
        (shallow copy: the remaining values are the same objects).
    """
    output_only_keys = ("reasoning_content", "reasoning")

    def _without_reasoning(entry):
        # Entries that are not dicts cannot be filtered by key; hand them back.
        if isinstance(entry, dict):
            return {
                name: val
                for name, val in entry.items()
                if name not in output_only_keys
            }
        return entry

    return [_without_reasoning(entry) for entry in messages]
def test_strip_reasoning_fields_removes_reasoning_keys_without_mutating_input():
    """The helper drops both reasoning keys and leaves its input untouched."""
    assistant_message = {
        "role": "assistant",
        "content": "hi",
        "reasoning_content": "thinking...",
        "reasoning": "more",
    }
    messages = [
        {"role": "system", "content": "sys"},
        assistant_message,
        "not-a-dict",
    ]

    stripped = strip_reasoning_fields(messages)
    cleaned_assistant = stripped[1]

    for dropped_key in ("reasoning_content", "reasoning"):
        assert dropped_key not in cleaned_assistant
    # The remaining keys survive with their values.
    assert cleaned_assistant["content"] == "hi"
    assert cleaned_assistant["role"] == "assistant"
    # Non-dict entries are passed through.
    assert stripped[2] == "not-a-dict"
    # The original messages must not be mutated.
    assert messages[1]["reasoning_content"] == "thinking..."
    assert messages[1]["reasoning"] == "more"


@patch("notebook_intelligence.llm_providers.openai_compatible_llm_provider.OpenAI")
def test_openai_compatible_chat_model_strips_reasoning_before_request(mock_openai_cls):
    """``completions`` sends messages without reasoning keys to the client.

    The ``OpenAI`` class in the provider module is patched, so the request is
    captured from the mock client's ``chat.completions.create`` call.
    """
    provider = OpenAICompatibleLLMProvider()
    model = provider.chat_models[0]
    for prop_name, prop_value in (
        ("model_id", "test-model"),
        ("api_key", "test-key"),
        ("base_url", "https://example.com/v1"),
    ):
        model.set_property_value(prop_name, prop_value)

    # Canned response object returned by the mocked client.
    mock_response = MagicMock()
    mock_response.model_dump_json.return_value = '{"choices": [{"message": {"content": "ok"}}]}'
    mock_response.choices = [MagicMock(message=MagicMock(reasoning_content=None, reasoning=None))]

    mock_client = MagicMock()
    mock_client.chat.completions.create.return_value = mock_response
    mock_openai_cls.return_value = mock_client

    messages = [
        {"role": "user", "content": "hello"},
        {"role": "assistant", "content": "prev", "reasoning_content": "", "reasoning": "x"},
    ]

    model.completions(messages=messages)

    sent_kwargs = mock_client.chat.completions.create.call_args.kwargs
    outbound = sent_kwargs["messages"]
    for sent_message in outbound:
        assert "reasoning_content" not in sent_message
        assert "reasoning" not in sent_message
    # The caller's history must be left intact for the next turn.
    assert messages[1]["reasoning_content"] == ""
    assert messages[1]["reasoning"] == "x"