diff --git a/cecli/io.py b/cecli/io.py index cbea56e2109..d3cdf0b04d6 100644 --- a/cecli/io.py +++ b/cecli/io.py @@ -1718,7 +1718,7 @@ def get_default_notification_command(self): "$toastXml = $template.GetXml(); " "$toastXml.GetElementsByTagName('text')[0].AppendChild" "($template.CreateTextNode('cecli')) > $null; " - f"$toastXml.GetElementsByTagName('text')[1].AppendChild" + "$toastXml.GetElementsByTagName('text')[1].AppendChild" f"($template.CreateTextNode('{NOTIFICATION_MESSAGE}')) > $null; " "$toast = [Windows.UI.Notifications.ToastNotification]::new($toastXml); " "[Windows.UI.Notifications.ToastNotificationManager]::CreateToastNotifier('cecli')" diff --git a/cecli/models.py b/cecli/models.py index 495895bda12..19a6f8cff35 100644 --- a/cecli/models.py +++ b/cecli/models.py @@ -1172,6 +1172,26 @@ async def send_completion( sorted_tools = sorted( effective_tools, key=lambda x: x.get("function", {}).get("name", "Invalid Name") ) + + try: + # Deep copy to avoid modifying original tool schemas + sorted_tools = json.loads(json.dumps(sorted_tools)) + + for tool in sorted_tools: + function_schema = tool.get("function") + if function_schema and "description" in function_schema: + desc = function_schema.get("description") + if isinstance(desc, str): + # Escape the description string for JSON, but without the outer quotes. + # This is a workaround for issues with special characters in descriptions. + function_schema["description"] = json.dumps(desc, ensure_ascii=False)[ + 1:-1 + ] + except (TypeError, json.JSONDecodeError): + # If deep copy fails, proceed with original tools. + # This is a safeguard. + pass + kwargs["tools"] = sorted_tools if functions and len(functions) == 1: diff --git a/tests/basic/test_models.py b/tests/basic/test_models.py index 82a765171ad..5a9e5171d36 100644 --- a/tests/basic/test_models.py +++ b/tests/basic/test_models.py @@ -810,3 +810,55 @@ def test_print_matching_models_price_formatting(self): output_found = any("$10.50/1m/output" in call for call in calls) assert input_found, "Input pricing format incorrect" assert output_found, "Output pricing format incorrect" + + @patch("cecli.models.litellm.acompletion") + async def test_tool_description_escaping(self, mock_acompletion): + """ + Test that tool descriptions with special characters are properly escaped. + """ + model = Model("gpt-4") + messages = [{"role": "user", "content": "Hello"}] + + # A complex description with various special characters + complex_description = ( + 'This is a "test" description with `special` characters like \\, \n, and *.' + ) + + # Mock tool with the complex description + mock_tool = { + "type": "function", + "function": { + "name": "test_tool", + "description": complex_description, + "parameters": { + "type": "object", + "properties": {}, + "required": [], + }, + }, + } + + await model.send_completion(messages, functions=None, stream=False, tools=[mock_tool]) + + # Verify that acompletion was called + mock_acompletion.assert_called_once() + + # Get the keyword arguments passed to acompletion + call_kwargs = mock_acompletion.call_args.kwargs + + # Check that the 'tools' argument is present and correctly formatted + assert "tools" in call_kwargs + sent_tools = call_kwargs["tools"] + assert isinstance(sent_tools, list) + assert len(sent_tools) == 1 + + # Verify the description of the sent tool + sent_tool_function = sent_tools[0].get("function", {}) + sent_description = sent_tool_function.get("description") + + # The description should be a JSON-escaped string + # Expected: 'This is a \\"test\\" description with `special` characters like \\\\, \\n, and *.' + expected_escaped_description = ( + 'This is a \\"test\\" description with `special` characters like \\\\, \\n, and *.' + ) + assert sent_description == expected_escaped_description