class MultipartFileFieldFixHook(SDKInitHook):
    """
    Fixes multipart form serialization where file field names incorrectly have '[]' suffix.

    Speakeasy sometimes generates code that adds '[]' to file field names in multipart forms,
    but this is incorrect. File fields should not have the array suffix, only regular form
    fields should use this convention.

    This hook replaces ``forms.serialize_multipart_form`` with a wrapper that calls the
    original function and then strips the trailing '[]' from the names of entries in the
    returned files list whose payload looks like a file tuple. Entries without the suffix
    are passed through untouched, so the wrapper is a no-op once the generated code stops
    emitting the suffix.
    """

    # Name of the attribute set on the wrapper function. Hooks are registered per SDK
    # instance, so ``sdk_init`` runs once per instance; the marker lets later runs detect
    # that the module-level function is already wrapped instead of wrapping it again.
    _PATCH_MARKER = "_glean_multipart_file_field_fix"

    def sdk_init(self, base_url: str, client: HttpClient) -> Tuple[str, HttpClient]:
        """Patch the multipart serializer, then return ``base_url`` and ``client`` unchanged."""
        self._patch_multipart_serialization()
        return base_url, client

    def _patch_multipart_serialization(self) -> None:
        """
        Replace ``forms.serialize_multipart_form`` with the '[]'-stripping wrapper.

        The replacement is idempotent: if the function currently installed on ``forms``
        already carries the patch marker, nothing is done. Without this guard every SDK
        initialisation would wrap the previous wrapper, growing the call chain without
        bound.
        """
        original_serialize_multipart_form = forms.serialize_multipart_form

        # Compare with ``is True`` rather than truthiness: objects that fabricate
        # attributes on access (e.g. unittest.mock.Mock) would otherwise look patched.
        if getattr(original_serialize_multipart_form, self._PATCH_MARKER, False) is True:
            return

        def fixed_serialize_multipart_form(
            media_type: str, request: Any
        ) -> Tuple[str, Dict[str, Any], List[Tuple[str, Any]]]:
            """Call the original serializer and strip '[]' from file field names in its files list."""
            result_media_type, form_data, files_list = (
                original_serialize_multipart_form(media_type, request)
            )
            fixed_files = [self._fix_file_item(item) for item in files_list]
            return result_media_type, form_data, fixed_files

        setattr(fixed_serialize_multipart_form, self._PATCH_MARKER, True)
        forms.serialize_multipart_form = fixed_serialize_multipart_form

    def _fix_file_item(self, item: Any) -> Any:
        """
        Return ``item`` with a trailing '[]' removed from its field name when it is a file entry.

        ``item`` is expected to be a tuple whose first element is the field name and whose
        second element is the field payload; any further elements are preserved as-is.
        Anything that does not have that shape, whose name is not a string ending in '[]',
        or whose payload is not recognised by ``_is_file_field_data`` is returned unchanged.
        """
        if not (isinstance(item, tuple) and len(item) >= 2):
            return item

        field_name, file_data = item[0], item[1]
        if (
            isinstance(field_name, str)
            and field_name.endswith("[]")
            and self._is_file_field_data(file_data)
        ):
            return (field_name[:-2],) + item[1:]
        return item

    def _is_file_field_data(self, file_data: Any) -> bool:
        """
        Determine if the data represents file field content.

        A payload counts as file content when it is a tuple of at least two elements,
        ``(filename, content, ...)``, where:

        * ``filename`` is not the empty string (an empty filename is treated as
          non-file content such as an inline JSON part),
        * ``content`` is neither ``None`` nor the empty string, and
        * ``content`` is ``bytes``/``str``, has a ``read`` attribute (file-like object),
          or is iterable.

        Returns ``False`` for every other value.
        """
        if not (isinstance(file_data, tuple) and len(file_data) >= 2):
            return False

        filename, content = file_data[0], file_data[1]

        if filename == "":
            return False

        # Only compare against "" when content is really a string, so objects with a
        # custom ``__eq__`` are never asked to compare themselves to a str.
        if content is None or (isinstance(content, str) and content == ""):
            return False

        return (
            isinstance(content, (bytes, str))
            or hasattr(content, "read")
            or hasattr(content, "__iter__")
        )
class TestMultipartFileFieldFixHook:
    """Checks for MultipartFileFieldFixHook: init wiring, file-tuple detection, patched serializer."""

    def setup_method(self):
        """Create a fresh hook and a spec'd HttpClient mock before each test."""
        self.mock_client = Mock(spec=HttpClient)
        self.hook = MultipartFileFieldFixHook()

    def _install_and_patch(self, forms_stub, form_fields, file_entries):
        """Stub the upstream serializer on ``forms_stub``, apply the hook's patch, return the wrapper."""
        upstream = Mock()
        upstream.return_value = ("multipart/form-data", form_fields, file_entries)
        forms_stub.serialize_multipart_form = upstream
        self.hook._patch_multipart_serialization()
        return forms_stub.serialize_multipart_form

    def test_sdk_init_returns_unchanged_params(self):
        """sdk_init hands back exactly the base_url and client it was given."""
        url_in = "https://api.example.com"

        with patch.object(self.hook, "_patch_multipart_serialization"):
            url_out, client_out = self.hook.sdk_init(url_in, self.mock_client)

        assert url_out == url_in
        assert client_out == self.mock_client

    def test_sdk_init_calls_patch_function(self):
        """sdk_init triggers the serializer patch exactly once."""
        url_in = "https://api.example.com"

        with patch.object(self.hook, "_patch_multipart_serialization") as patcher:
            self.hook.sdk_init(url_in, self.mock_client)
            patcher.assert_called_once()

    def test_is_file_field_data_identifies_file_content(self):
        """File-shaped tuples are accepted; scalars and malformed tuples are rejected."""
        looks_like_file = self.hook._is_file_field_data

        # (filename, content[, content_type]) with bytes or str content
        assert looks_like_file(("test.txt", b"content"))
        assert looks_like_file(("test.txt", b"content", "text/plain"))
        assert looks_like_file(("test.txt", "string content"))

        # Content exposing ``read`` counts as file-like
        file_like = Mock()
        file_like.read = Mock()
        assert looks_like_file(("test.txt", file_like))

        # Anything that is not a 2+ tuple with real content is not a file
        assert not looks_like_file("regular_value")
        assert not looks_like_file(123)
        assert not looks_like_file(("single_item",))
        assert not looks_like_file((None, None))

    @patch("src.glean.api_client._hooks.multipart_fix_hook.forms")
    def test_patch_multipart_serialization_replaces_function(self, mock_forms_module):
        """After patching, the forms module no longer exposes the original callable."""
        upstream = Mock()
        mock_forms_module.serialize_multipart_form = upstream

        self.hook._patch_multipart_serialization()

        assert mock_forms_module.serialize_multipart_form != upstream

    @patch("src.glean.api_client._hooks.multipart_fix_hook.forms")
    def test_patched_function_fixes_file_field_names(self, mock_forms_module):
        """'[]' is stripped from file entries while non-file entries keep it."""
        serializer = self._install_and_patch(
            mock_forms_module,
            {"regular_field": "value"},
            [
                ("file[]", ("test.txt", b"file content", "text/plain")),
                ("documents[]", ("doc.pdf", b"pdf content", "application/pdf")),
                ("regular_array[]", "regular_value"),  # plain value: must stay as-is
            ],
        )

        media_type, form_data, files_list = serializer("multipart/form-data", Mock())

        # Media type and form dict pass straight through
        assert media_type == "multipart/form-data"
        assert form_data == {"regular_field": "value"}

        # Only the two file tuples lose their suffix
        assert files_list == [
            ("file", ("test.txt", b"file content", "text/plain")),
            ("documents", ("doc.pdf", b"pdf content", "application/pdf")),
            ("regular_array[]", "regular_value"),
        ]

    @patch("src.glean.api_client._hooks.multipart_fix_hook.forms")
    def test_patched_function_preserves_correct_names(self, mock_forms_module):
        """Entries that never had the '[]' suffix come back untouched."""
        serializer = self._install_and_patch(
            mock_forms_module,
            {},
            [
                ("file", ("test.txt", b"file content", "text/plain")),
                ("document", ("doc.pdf", b"pdf content", "application/pdf")),
            ],
        )

        media_type, form_data, files_list = serializer("multipart/form-data", Mock())

        assert files_list == [
            ("file", ("test.txt", b"file content", "text/plain")),
            ("document", ("doc.pdf", b"pdf content", "application/pdf")),
        ]

    @patch("src.glean.api_client._hooks.multipart_fix_hook.forms")
    def test_patched_function_handles_mixed_fields(self, mock_forms_module):
        """A mix of correct files, suffixed files, plain values and JSON parts is handled per entry."""
        serializer = self._install_and_patch(
            mock_forms_module,
            {"form_field": "value"},
            [
                ("correct_file", ("test1.txt", b"content1", "text/plain")),
                ("wrong_file[]", ("test2.txt", b"content2", "text/plain")),
                ("form_array[]", "form_value"),  # plain form value keeps []
                (
                    "json_field[]",
                    ("", '{"key": "value"}', "application/json"),
                ),  # empty filename: treated as JSON, keeps []
            ],
        )

        media_type, form_data, files_list = serializer("multipart/form-data", Mock())

        # Only the genuine file entry with the suffix is rewritten
        assert files_list == [
            ("correct_file", ("test1.txt", b"content1", "text/plain")),
            ("wrong_file", ("test2.txt", b"content2", "text/plain")),
            ("form_array[]", "form_value"),
            (
                "json_field[]",
                ("", '{"key": "value"}', "application/json"),
            ),
        ]

    def test_file_field_detection_edge_cases(self):
        """Boundary inputs for the file-tuple detector."""
        looks_like_file = self.hook._is_file_field_data

        # Empty-string and None content are rejected
        assert not looks_like_file(("test.txt", ""))
        assert not looks_like_file(("test.txt", None))

        # Iterable content (list / tuple) is accepted
        assert looks_like_file(("test.txt", [1, 2, 3]))
        assert looks_like_file(("test.txt", (1, 2, 3)))

        # Non-empty string content is accepted
        assert looks_like_file(("test.txt", "string content"))

        # A one-element tuple has no content slot and is rejected
        assert not looks_like_file(("string content",))