Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions src/clients/s3_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ class S3Client:
}
# Restricted HTTP headers that can't be sent to S3 when downloading an object using a presigned URL.
# Adding these headers can cause a mismatch with the SigV4 signature.
BLOCKED_REQUEST_HEADERS = ("Host")
# Lower-cased names of request headers that must never be forwarded.
# The trailing comma is required: ("host") is just the string "host", which
# would turn `name in BLOCKED_REQUEST_HEADERS` into a substring match
# (blocking e.g. "ho", "ost" or "") instead of an exact membership test.
BLOCKED_REQUEST_HEADERS = ("host",)

def __init__(self, s3ol_access_point: str, max_file_supported=DOCUMENT_MAX_SIZE):
self.max_file_supported = max_file_supported
Expand Down Expand Up @@ -128,12 +128,13 @@ def _filter_request_headers(self, presigned_url, headers={}):
filtered_headers = {}
parsed_url = urllib.parse.urlparse(presigned_url)
parsed_query_params = urllib.parse.parse_qs(parsed_url.query)
signed_headers = set(parsed_query_params.get('X-Amz-SignedHeaders', []))
signed_headers_params = parsed_query_params.get('X-Amz-SignedHeaders', [])
signed_headers = {h.lower() for p in signed_headers_params for h in p.split(";")}

for header in headers:
if header in self.BLOCKED_REQUEST_HEADERS:
if header.lower() in self.BLOCKED_REQUEST_HEADERS:
continue
if str(header).lower().startswith('x-amz-') and header not in signed_headers:
if str(header).lower().startswith('x-amz-') and header.lower() not in signed_headers:
continue
filtered_headers[header] = headers[header]
return filtered_headers
Expand Down
34 changes: 27 additions & 7 deletions test/unit/test_s3_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ def test_s3_client_respond_back_with_data_default_status_code(self, mocked_boto3
mocked_boto3.client.return_value = mocked_client
s3_client = S3Client(s3ol_access_point="Random_access_point")
s3_client.respond_back_with_data(data='SomeData',
headers={"ContentRange": "0-100", "SomeRandomHeader": '0123', "Content-Length": "101"},
headers={"ContentRange": "0-100", "SomeRandomHeader": '0123',
"Content-Length": "101"},
request_route="Route", request_token="q2334")

mocked_client.write_get_object_response.assert_called_once_with(Body='SomeData', ContentLength=101,
Expand All @@ -52,8 +53,10 @@ def test_s3_client_respond_back_with_data_partial_data(self, mocked_boto3):
mocked_client = MagicMock()
mocked_boto3.client.return_value = mocked_client
s3_client = S3Client(s3ol_access_point="Random_access_point")
s3_client.respond_back_with_data(data='SomeData', headers={"Content-Range": "0-1200", "SomeRandomHeader": '0123'},
request_route="Route", request_token="q2334", status_code=S3_STATUS_CODES.PARTIAL_CONTENT_206)
s3_client.respond_back_with_data(data='SomeData',
headers={"Content-Range": "0-1200", "SomeRandomHeader": '0123'},
request_route="Route", request_token="q2334",
status_code=S3_STATUS_CODES.PARTIAL_CONTENT_206)

mocked_client.write_get_object_response.assert_called_once_with(Body='SomeData', ContentRange="0-1200",
RequestRoute='Route', RequestToken="q2334",
Expand All @@ -64,7 +67,8 @@ def test_s3_client_respond_back_with_data_partial_data(self, mocked_boto3):
def test_s3_client_download_file_from_presigned_url_200_ok(self, mocked_get):
s3_client = S3Client(s3ol_access_point="Random_access_point")
http_header = {'some-header': 'header-value'}
text, response_http_headers, status_code = s3_client.download_file_from_presigned_url(PRESIGNED_URL_TEST, http_header)
text, response_http_headers, status_code = s3_client.download_file_from_presigned_url(PRESIGNED_URL_TEST,
http_header)
assert text == 'Test'
assert response_http_headers == {'Content-Length': '4'}
assert status_code == S3_STATUS_CODES.OK_200
Expand All @@ -75,7 +79,8 @@ def test_s3_client_download_file_from_presigned_url_200_ok(self, mocked_get):
def test_s3_client_download_partial_file_from_presigned_url(self, mocked_get):
s3_client = S3Client(s3ol_access_point="Random_access_point")
http_header = {'some-header': 'header-value'}
text, response_http_headers, status_code = s3_client.download_file_from_presigned_url(PRESIGNED_URL_TEST, http_header)
text, response_http_headers, status_code = s3_client.download_file_from_presigned_url(PRESIGNED_URL_TEST,
http_header)
assert text == 'Test'
assert response_http_headers == {'Content-Length': '100'}
assert status_code == S3_STATUS_CODES.PARTIAL_CONTENT_206
Expand All @@ -90,10 +95,12 @@ def test_s3_client_download_file_from_presigned_url_400_from_get(self, mocked_ge
assert mocked_get.call_count == 5

@patch('clients.s3_client.requests.Session.get',
side_effect=lambda *args, **kwargs: MockResponse(b'A' * (11 * 1024 * 1024), 200, {'Content-Length': str(11 * 1024 * 1024)}))
side_effect=lambda *args, **kwargs: MockResponse(b'A' * (11 * 1024 * 1024), 200,
{'Content-Length': str(11 * 1024 * 1024)}))
def test_s3_client_download_file_from_presigned_url_file_size_limit_exceeded(self, mocked_get):
s3_client = S3Client(s3ol_access_point="Random_access_point")
self.assertRaises(FileSizeLimitExceededException, s3_client.download_file_from_presigned_url, PRESIGNED_URL_TEST, {})
self.assertRaises(FileSizeLimitExceededException, s3_client.download_file_from_presigned_url,
PRESIGNED_URL_TEST, {})

mocked_get.assert_called_once()

Expand Down Expand Up @@ -139,3 +146,16 @@ def test_s3_client_download_file_from_presigned_retry(self, mocked_get):
self.assertRaises(S3DownloadException, s3_client.download_file_from_presigned_url, PRESIGNED_URL_TEST, {})

assert mocked_get.call_count == 5

def test_s3_client_filter_request_headers(self):
    # Exercises _filter_request_headers with one header of each category and
    # checks which of them survive the filtering.
    client = S3Client(s3ol_access_point="Random_access_point")
    presigned_url = "https://dummy/myfile.txt?X-Amz-SignedHeaders=host%3Bx-amz-to-include"

    request_headers = {
        # Expected to be dropped: Host is on the denylist.
        "Host": "otherhost",
        # Expected to be kept: it is listed in X-Amz-SignedHeaders.
        "X-Amz-To-Include": "foo",
        # Expected to be dropped: starts with X-Amz- but is not signed.
        "X-Amz-To-Exclude": "bar",
        # Expected to be kept: neither denylisted nor an X-Amz- header.
        "X-Irrelevant": "baz",
    }
    expected_headers = {"X-Amz-To-Include": "foo", "X-Irrelevant": "baz"}

    actual_headers = client._filter_request_headers(presigned_url, request_headers)

    assert actual_headers == expected_headers