diff --git a/src/switchyard/upstream.py b/src/switchyard/upstream.py index c6cee9d..987fdd6 100644 --- a/src/switchyard/upstream.py +++ b/src/switchyard/upstream.py @@ -3,6 +3,7 @@ from __future__ import annotations from collections.abc import AsyncIterator +from urllib.parse import urlencode import httpx from loguru import logger @@ -20,6 +21,12 @@ ]) +def _append_digest(location: str, digest: str) -> str: + """Append digest query param, preserving any existing query string (e.g. _state).""" + sep = "&" if "?" in location else "?" + return f"{location}{sep}{urlencode({'digest': digest})}" + + class UpstreamClient: def __init__(self, base_url: str) -> None: self._base_url = base_url.rstrip("/") @@ -57,12 +64,11 @@ async def push_blob(self, name: str, digest: str, data: bytes) -> None: resp = await self._client.post(f"/v2/{name}/blobs/uploads/") resp.raise_for_status() - location = resp.headers["Location"] + location = _append_digest(resp.headers["Location"], digest) resp = await self._client.put( location, content=data, - params={"digest": digest}, headers={"Content-Type": "application/octet-stream"}, ) resp.raise_for_status() @@ -78,7 +84,7 @@ async def push_blob_streaming( resp = await self._client.post(f"/v2/{name}/blobs/uploads/") resp.raise_for_status() - location = resp.headers["Location"] + location = _append_digest(resp.headers["Location"], digest) async def _body() -> AsyncIterator[bytes]: async for chunk in stream: @@ -87,7 +93,6 @@ async def _body() -> AsyncIterator[bytes]: resp = await self._client.put( location, content=_body(), - params={"digest": digest}, headers={"Content-Type": "application/octet-stream"}, ) resp.raise_for_status() diff --git a/tests/test_upstream.py b/tests/test_upstream.py index a2a4997..34e1d30 100644 --- a/tests/test_upstream.py +++ b/tests/test_upstream.py @@ -3,6 +3,7 @@ from __future__ import annotations import json +from collections.abc import AsyncIterator import respx from httpx import Response @@ -62,6 +63,77 @@ async def test_push_blob_uploads() -> None: await client.close() +@respx.mock +async def test_push_blob_preserves_location_query_params() -> None: + """Registry Location may include a _state token; digest must be appended, not replace it.""" + location = f"{BASE}/v2/myapp/blobs/uploads/uuid-1?_state=signed-token" + respx.head(f"{BASE}/v2/myapp/blobs/sha256:abc123").mock( + return_value=Response(404) + ) + respx.post(f"{BASE}/v2/myapp/blobs/uploads/").mock( + return_value=Response(202, headers={"Location": location}) + ) + put_route = respx.put(url__regex=r".*/v2/myapp/blobs/uploads/uuid-1.*").mock( + return_value=Response(201) + ) + + client = UpstreamClient(BASE) + await client.push_blob("myapp", "sha256:abc123", b"blob data") + + put_url = str(put_route.calls[0].request.url) + assert "_state=signed-token" in put_url, f"_state param lost: {put_url}" + assert "digest=sha256" in put_url, f"digest param missing: {put_url}" + await client.close() + + +@respx.mock +async def test_push_blob_streaming_uploads() -> None: + respx.head(f"{BASE}/v2/myapp/blobs/sha256:abc123").mock( + return_value=Response(404) + ) + respx.post(f"{BASE}/v2/myapp/blobs/uploads/").mock( + return_value=Response(202, headers={"Location": f"{BASE}/v2/myapp/blobs/uploads/uuid-1"}) + ) + put_route = respx.put(f"{BASE}/v2/myapp/blobs/uploads/uuid-1").mock( + return_value=Response(201) + ) + + async def blob_stream() -> AsyncIterator[bytes]: + yield b"blob data" + + client = UpstreamClient(BASE) + await client.push_blob_streaming("myapp", "sha256:abc123", blob_stream()) + assert respx.calls.call_count == 3 # HEAD + POST + PUT + assert put_route.calls[0].request.headers["content-type"] == "application/octet-stream" + await client.close() + + +@respx.mock +async def test_push_blob_streaming_preserves_location_query_params() -> None: + """Registry Location may include a _state token; digest must be appended, not replace it.""" + location = f"{BASE}/v2/myapp/blobs/uploads/uuid-1?_state=signed-token" + respx.head(f"{BASE}/v2/myapp/blobs/sha256:abc123").mock( + return_value=Response(404) + ) + respx.post(f"{BASE}/v2/myapp/blobs/uploads/").mock( + return_value=Response(202, headers={"Location": location}) + ) + put_route = respx.put(url__regex=r".*/v2/myapp/blobs/uploads/uuid-1.*").mock( + return_value=Response(201) + ) + + async def blob_stream() -> AsyncIterator[bytes]: + yield b"blob data" + + client = UpstreamClient(BASE) + await client.push_blob_streaming("myapp", "sha256:abc123", blob_stream()) + + put_url = str(put_route.calls[0].request.url) + assert "_state=signed-token" in put_url, f"_state param lost: {put_url}" + assert "digest=sha256" in put_url, f"digest param missing: {put_url}" + await client.close() + + @respx.mock async def test_push_manifest() -> None: manifest = json.dumps({