From 660205b7bbcc3b35ec9d31b146545c5fd7000391 Mon Sep 17 00:00:00 2001 From: frij Date: Thu, 21 Aug 2025 18:51:55 +0000 Subject: [PATCH 1/2] Correct header filtering --- src/clients/s3_client.py | 10 ++++++---- test/unit/test_s3_client.py | 34 +++++++++++++++++++++++++++------- 2 files changed, 33 insertions(+), 11 deletions(-) diff --git a/src/clients/s3_client.py b/src/clients/s3_client.py index 52e4514..d1981e4 100644 --- a/src/clients/s3_client.py +++ b/src/clients/s3_client.py @@ -1,4 +1,5 @@ """Client wrapper over aws services.""" +import json import re import time import urllib @@ -65,7 +66,7 @@ class S3Client: } # Restricted http headers that can't be sent to s3 as part of downloading object using preseigned url # Adding these headers can causes a mismatch with Sigv4 signature - BLOCKED_REQUEST_HEADERS = ("Host") + BLOCKED_REQUEST_HEADERS = ("host") def __init__(self, s3ol_access_point: str, max_file_supported=DOCUMENT_MAX_SIZE): self.max_file_supported = max_file_supported @@ -128,12 +129,13 @@ def _filter_request_headers(self, presigned_url, headers={}): filtered_headers = {} parsed_url = urllib.parse.urlparse(presigned_url) parsed_query_params = urllib.parse.parse_qs(parsed_url.query) - signed_headers = set(parsed_query_params.get('X-Amz-SignedHeaders', [])) + signed_headers_params = parsed_query_params.get('X-Amz-SignedHeaders', []) + signed_headers = {h.lower() for p in signed_headers_params for h in p.split(";")} for header in headers: - if header in self.BLOCKED_REQUEST_HEADERS: + if header.lower() in self.BLOCKED_REQUEST_HEADERS: continue - if str(header).lower().startswith('x-amz-') and header not in signed_headers: + if str(header).lower().startswith('x-amz-') and header.lower() not in signed_headers: continue filtered_headers[header] = headers[header] return filtered_headers diff --git a/test/unit/test_s3_client.py b/test/unit/test_s3_client.py index e494ef6..64b552a 100644 --- a/test/unit/test_s3_client.py +++ b/test/unit/test_s3_client.py @@ -40,7 +40,8 @@ def test_s3_client_respond_back_with_data_default_status_code(self, mocked_boto3 mocked_boto3.client.return_value = mocked_client s3_client = S3Client(s3ol_access_point="Random_access_point") s3_client.respond_back_with_data(data='SomeData', - headers={"ContentRange": "0-100", "SomeRandomHeader": '0123', "Content-Length": "101"}, + headers={"ContentRange": "0-100", "SomeRandomHeader": '0123', + "Content-Length": "101"}, request_route="Route", request_token="q2334") mocked_client.write_get_object_response.assert_called_once_with(Body='SomeData', ContentLength=101, @@ -52,8 +53,10 @@ def test_s3_client_respond_back_with_data_partial_data(self, mocked_boto3): mocked_client = MagicMock() mocked_boto3.client.return_value = mocked_client s3_client = S3Client(s3ol_access_point="Random_access_point") - s3_client.respond_back_with_data(data='SomeData', headers={"Content-Range": "0-1200", "SomeRandomHeader": '0123'}, - request_route="Route", request_token="q2334", status_code=S3_STATUS_CODES.PARTIAL_CONTENT_206) + s3_client.respond_back_with_data(data='SomeData', + headers={"Content-Range": "0-1200", "SomeRandomHeader": '0123'}, + request_route="Route", request_token="q2334", + status_code=S3_STATUS_CODES.PARTIAL_CONTENT_206) mocked_client.write_get_object_response.assert_called_once_with(Body='SomeData', ContentRange="0-1200", RequestRoute='Route', RequestToken="q2334", @@ -64,7 +67,8 @@ def test_s3_client_respond_back_with_data_partial_data(self, mocked_boto3): def test_s3_client_download_file_from_presigned_url_200_ok(self, mocked_get): s3_client = S3Client(s3ol_access_point="Random_access_point") http_header = {'some-header': 'header-value'} - text, response_http_headers, status_code = s3_client.download_file_from_presigned_url(PRESIGNED_URL_TEST, http_header) + text, response_http_headers, status_code = s3_client.download_file_from_presigned_url(PRESIGNED_URL_TEST, + http_header) assert text == 'Test' assert response_http_headers == {'Content-Length': '4'} assert status_code == S3_STATUS_CODES.OK_200 @@ -75,7 +79,8 @@ def test_s3_client_download_file_from_presigned_url_200_ok(self, mocked_get): def test_s3_client_download_partial_file_from_presigned_url(self, mocked_get): s3_client = S3Client(s3ol_access_point="Random_access_point") http_header = {'some-header': 'header-value'} - text, response_http_headers, status_code = s3_client.download_file_from_presigned_url(PRESIGNED_URL_TEST, http_header) + text, response_http_headers, status_code = s3_client.download_file_from_presigned_url(PRESIGNED_URL_TEST, + http_header) assert text == 'Test' assert response_http_headers == {'Content-Length': '100'} assert status_code == S3_STATUS_CODES.PARTIAL_CONTENT_206 @@ -90,10 +95,12 @@ def test_s3_client_download_file_from_presigned_url_400_from_get(self, mocked_ge assert mocked_get.call_count == 5 @patch('clients.s3_client.requests.Session.get', - side_effect=lambda *args, **kwargs: MockResponse(b'A' * (11 * 1024 * 1024), 200, {'Content-Length': str(11 * 1024 * 1024)})) + side_effect=lambda *args, **kwargs: MockResponse(b'A' * (11 * 1024 * 1024), 200, + {'Content-Length': str(11 * 1024 * 1024)})) def test_s3_client_download_file_from_presigned_url_file_size_limit_exceeded(self, mocked_get): s3_client = S3Client(s3ol_access_point="Random_access_point") - self.assertRaises(FileSizeLimitExceededException, s3_client.download_file_from_presigned_url, PRESIGNED_URL_TEST, {}) + self.assertRaises(FileSizeLimitExceededException, s3_client.download_file_from_presigned_url, + PRESIGNED_URL_TEST, {}) mocked_get.assert_called_once() @@ -139,3 +146,16 @@ def test_s3_client_download_file_from_presigned_retry(self, mocked_get): self.assertRaises(S3DownloadException, s3_client.download_file_from_presigned_url, PRESIGNED_URL_TEST, {}) assert mocked_get.call_count == 5 + + def test_s3_client_filter_request_headers(self): + s3_client = S3Client(s3ol_access_point="Random_access_point") + url = "https://dummy/myfile.txt?X-Amz-SignedHeaders=host%3Bx-amz-to-include" + + filtered_headers = s3_client._filter_request_headers( + url, + { + "Host":"otherhost", # will be excluded since Host is on the denylist + "X-Amz-To-Include":"foo", # will be included since it is signed + "X-Amz-To-Exclude":"bar", # will be excluded since it starts with X-Amz- and is not signed + "X-Irrelevant":"baz"}) # will be included by default + assert filtered_headers == {"X-Amz-To-Include":"foo", "X-Irrelevant":"baz"} From 920ef08e85f8ba76f34257d25646cf6a22d910a1 Mon Sep 17 00:00:00 2001 From: frij Date: Thu, 21 Aug 2025 19:04:21 +0000 Subject: [PATCH 2/2] remove unnecessary import --- src/clients/s3_client.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/clients/s3_client.py b/src/clients/s3_client.py index d1981e4..08d9789 100644 --- a/src/clients/s3_client.py +++ b/src/clients/s3_client.py @@ -1,5 +1,4 @@ """Client wrapper over aws services.""" -import json import re import time import urllib