Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cecli/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -1718,7 +1718,7 @@ def get_default_notification_command(self):
"$toastXml = $template.GetXml(); "
"$toastXml.GetElementsByTagName('text')[0].AppendChild"
"($template.CreateTextNode('cecli')) > $null; "
f"$toastXml.GetElementsByTagName('text')[1].AppendChild"
"$toastXml.GetElementsByTagName('text')[1].AppendChild"
f"($template.CreateTextNode('{NOTIFICATION_MESSAGE}')) > $null; "
"$toast = [Windows.UI.Notifications.ToastNotification]::new($toastXml); "
"[Windows.UI.Notifications.ToastNotificationManager]::CreateToastNotifier('cecli')"
Expand Down
20 changes: 20 additions & 0 deletions cecli/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1172,6 +1172,26 @@ async def send_completion(
sorted_tools = sorted(
effective_tools, key=lambda x: x.get("function", {}).get("name", "Invalid Name")
)

try:
# Deep copy to avoid modifying original tool schemas
sorted_tools = json.loads(json.dumps(sorted_tools))

for tool in sorted_tools:
function_schema = tool.get("function")
if function_schema and "description" in function_schema:
desc = function_schema.get("description")
if isinstance(desc, str):
# Escape the description string for JSON, but without the outer quotes.
# This is a workaround for issues with special characters in descriptions.
function_schema["description"] = json.dumps(desc, ensure_ascii=False)[
1:-1
]
except (TypeError, json.JSONDecodeError):
# If deep copy fails, proceed with original tools.
# This is a safeguard.
pass

kwargs["tools"] = sorted_tools

if functions and len(functions) == 1:
Expand Down
52 changes: 52 additions & 0 deletions tests/basic/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -810,3 +810,55 @@ def test_print_matching_models_price_formatting(self):
output_found = any("$10.50/1m/output" in call for call in calls)
assert input_found, "Input pricing format incorrect"
assert output_found, "Output pricing format incorrect"

@patch("cecli.models.litellm.acompletion")
async def test_tool_description_escaping(self, mock_acompletion):
"""
Test that tool descriptions with special characters are properly escaped.
"""
model = Model("gpt-4")
messages = [{"role": "user", "content": "Hello"}]

# A complex description with various special characters
complex_description = (
'This is a "test" description with `special` characters like \\, \n, and *.'
)

# Mock tool with the complex description
mock_tool = {
"type": "function",
"function": {
"name": "test_tool",
"description": complex_description,
"parameters": {
"type": "object",
"properties": {},
"required": [],
},
},
}

await model.send_completion(messages, functions=None, stream=False, tools=[mock_tool])

# Verify that acompletion was called
mock_acompletion.assert_called_once()

# Get the keyword arguments passed to acompletion
call_kwargs = mock_acompletion.call_args.kwargs

# Check that the 'tools' argument is present and correctly formatted
assert "tools" in call_kwargs
sent_tools = call_kwargs["tools"]
assert isinstance(sent_tools, list)
assert len(sent_tools) == 1

# Verify the description of the sent tool
sent_tool_function = sent_tools[0].get("function", {})
sent_description = sent_tool_function.get("description")

# The description should be a JSON-escaped string
# Expected: 'This is a \\"test\\" description with `special` characters like \\\\, \\n, and *.'
expected_escaped_description = (
'This is a \\"test\\" description with `special` characters like \\\\, \\n, and *.'
)
assert sent_description == expected_escaped_description
Loading