Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions src/switchyard/upstream.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

from collections.abc import AsyncIterator
from urllib.parse import urlencode

import httpx
from loguru import logger
Expand All @@ -20,6 +21,12 @@
])


def _append_digest(location: str, digest: str) -> str:
"""Append digest query param, preserving any existing query string (e.g. _state)."""
sep = "&" if "?" in location else "?"
return f"{location}{sep}{urlencode({'digest': digest})}"


class UpstreamClient:
def __init__(self, base_url: str) -> None:
self._base_url = base_url.rstrip("/")
Expand Down Expand Up @@ -57,12 +64,11 @@ async def push_blob(self, name: str, digest: str, data: bytes) -> None:

resp = await self._client.post(f"/v2/{name}/blobs/uploads/")
resp.raise_for_status()
location = resp.headers["Location"]
location = _append_digest(resp.headers["Location"], digest)

resp = await self._client.put(
location,
content=data,
params={"digest": digest},
headers={"Content-Type": "application/octet-stream"},
)
resp.raise_for_status()
Expand All @@ -78,7 +84,7 @@ async def push_blob_streaming(

resp = await self._client.post(f"/v2/{name}/blobs/uploads/")
resp.raise_for_status()
location = resp.headers["Location"]
location = _append_digest(resp.headers["Location"], digest)

async def _body() -> AsyncIterator[bytes]:
async for chunk in stream:
Expand All @@ -87,7 +93,6 @@ async def _body() -> AsyncIterator[bytes]:
resp = await self._client.put(
location,
content=_body(),
params={"digest": digest},
headers={"Content-Type": "application/octet-stream"},
)
resp.raise_for_status()
Expand Down
72 changes: 72 additions & 0 deletions tests/test_upstream.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import json
from collections.abc import AsyncIterator

import respx
from httpx import Response
Expand Down Expand Up @@ -62,6 +63,77 @@ async def test_push_blob_uploads() -> None:
await client.close()


@respx.mock
async def test_push_blob_preserves_location_query_params() -> None:
"""Registry Location may include a _state token; digest must be appended, not replace it."""
location = f"{BASE}/v2/myapp/blobs/uploads/uuid-1?_state=signed-token"
respx.head(f"{BASE}/v2/myapp/blobs/sha256:abc123").mock(
return_value=Response(404)
)
respx.post(f"{BASE}/v2/myapp/blobs/uploads/").mock(
return_value=Response(202, headers={"Location": location})
)
put_route = respx.put(url__regex=r".*/v2/myapp/blobs/uploads/uuid-1.*").mock(
return_value=Response(201)
)

client = UpstreamClient(BASE)
await client.push_blob("myapp", "sha256:abc123", b"blob data")

put_url = str(put_route.calls[0].request.url)
assert "_state=signed-token" in put_url, f"_state param lost: {put_url}"
assert "digest=sha256" in put_url, f"digest param missing: {put_url}"
await client.close()


@respx.mock
async def test_push_blob_streaming_uploads() -> None:
respx.head(f"{BASE}/v2/myapp/blobs/sha256:abc123").mock(
return_value=Response(404)
)
respx.post(f"{BASE}/v2/myapp/blobs/uploads/").mock(
return_value=Response(202, headers={"Location": f"{BASE}/v2/myapp/blobs/uploads/uuid-1"})
)
put_route = respx.put(f"{BASE}/v2/myapp/blobs/uploads/uuid-1").mock(
return_value=Response(201)
)

async def blob_stream() -> AsyncIterator[bytes]:
yield b"blob data"

client = UpstreamClient(BASE)
await client.push_blob_streaming("myapp", "sha256:abc123", blob_stream())
assert respx.calls.call_count == 3 # HEAD + POST + PUT
assert put_route.calls[0].request.headers["content-type"] == "application/octet-stream"
await client.close()


@respx.mock
async def test_push_blob_streaming_preserves_location_query_params() -> None:
"""Registry Location may include a _state token; digest must be appended, not replace it."""
location = f"{BASE}/v2/myapp/blobs/uploads/uuid-1?_state=signed-token"
respx.head(f"{BASE}/v2/myapp/blobs/sha256:abc123").mock(
return_value=Response(404)
)
respx.post(f"{BASE}/v2/myapp/blobs/uploads/").mock(
return_value=Response(202, headers={"Location": location})
)
put_route = respx.put(url__regex=r".*/v2/myapp/blobs/uploads/uuid-1.*").mock(
return_value=Response(201)
)

async def blob_stream() -> AsyncIterator[bytes]:
yield b"blob data"

client = UpstreamClient(BASE)
await client.push_blob_streaming("myapp", "sha256:abc123", blob_stream())

put_url = str(put_route.calls[0].request.url)
assert "_state=signed-token" in put_url, f"_state param lost: {put_url}"
assert "digest=sha256" in put_url, f"digest param missing: {put_url}"
await client.close()


@respx.mock
async def test_push_manifest() -> None:
manifest = json.dumps({
Expand Down
Loading