Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,17 @@ requires-python = ">=3.14"
dependencies = [
"starlette>=1.0.0",
"granian[uvloop]>=2.7.2",
"httpx[brotli,zstd,http2]>=0.28.1",
"loguru>=0.7.3",
"pydantic-settings>=2.13.1",
"python-dxf>=12.1.1",
]

[dependency-groups]
dev = [
"httpx>=0.28.1",
"pytest>=9.0.2",
"pytest-asyncio>=1.3.0",
"respx>=0.22.0",
"responses>=0.26.0",
"ruff>=0.15.8",
]

Expand Down
2 changes: 1 addition & 1 deletion src/switchyard/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
def main() -> None:
server = Granian(
target="switchyard.app:app",
address="0.0.0.0",
address="::",
port=PORT,
interface=Interfaces.ASGI,
loop=Loops.uvloop,
Expand Down
179 changes: 88 additions & 91 deletions src/switchyard/upstream.py
Original file line number Diff line number Diff line change
@@ -1,131 +1,128 @@
# ABOUTME: HTTP client for communicating with the upstream Docker registry.
# ABOUTME: Handles pushing blobs/manifests and proxying pull requests.
# ABOUTME: Client for communicating with the upstream Docker registry.
# ABOUTME: Uses python-dxf for registry v2 operations, wrapped with asyncio.to_thread.
from __future__ import annotations

import asyncio
import json
from collections.abc import AsyncIterator

import httpx
import requests.exceptions
from dxf import DXF
from loguru import logger

log = logger.bind(component="upstream")

CHUNK_SIZE = 1024 * 1024 # 1MB
_SENTINEL = object()


class UpstreamClient:
def __init__(self, base_url: str) -> None:
self._base_url = base_url.rstrip("/")
self._client = httpx.AsyncClient(
base_url=self._base_url,
timeout=httpx.Timeout(connect=10, read=300, write=300, pool=10),
follow_redirects=True,
http2=True,
)
if "://" in self._base_url:
self._insecure = self._base_url.startswith("http://")
self._host = self._base_url.split("://", 1)[1]
else:
self._host = self._base_url
self._insecure = False
self._dxf_cache: dict[str, DXF] = {}

def _get_dxf(self, repo: str) -> DXF:
"""Get or create a DXF instance for the given repo.

Instances are cached per-repo in self._dxf_cache and kept open until
close() tears them down; callers must not exit them individually.
"""
if repo not in self._dxf_cache:
# python-dxf client for registry-v2 operations; no auth callback is
# passed, so this assumes an unauthenticated upstream — TODO confirm.
dxf = DXF(
host=self._host,
repo=repo,
# _insecure is True for http:// base URLs (set in __init__),
# switching dxf to plain HTTP instead of HTTPS.
insecure=self._insecure,
timeout=300,
)
# Enter the context manager manually to open the underlying
# requests session; the matching __exit__ happens in close().
dxf.__enter__()
self._dxf_cache[repo] = dxf
return self._dxf_cache[repo]

async def close(self) -> None:
await self._client.aclose()
for dxf in self._dxf_cache.values():
dxf.__exit__(None, None, None)
self._dxf_cache.clear()

# -- Blob operations --

async def check_blob(self, name: str, digest: str) -> bool:
resp = await self._client.head(f"/v2/{name}/blobs/{digest}")
return resp.status_code == 200
dxf = self._get_dxf(name)

def _check() -> bool:
try:
dxf.blob_size(digest)
return True
except requests.exceptions.HTTPError:
return False

return await asyncio.to_thread(_check)

async def pull_blob(self, name: str, digest: str) -> AsyncIterator[bytes]:
async with self._client.stream("GET", f"/v2/{name}/blobs/{digest}") as resp:
resp.raise_for_status()
async for chunk in resp.aiter_bytes(chunk_size=1024 * 1024):
yield chunk
dxf = self._get_dxf(name)
chunks = await asyncio.to_thread(dxf.pull_blob, digest, chunk_size=CHUNK_SIZE)
while True:
chunk = await asyncio.to_thread(next, chunks, _SENTINEL)
if chunk is _SENTINEL:
break
yield chunk

async def push_blob(self, name: str, digest: str, data: bytes) -> None:
"""Push a blob using monolithic upload (POST + PUT)."""
# Check if blob already exists
if await self.check_blob(name, digest):
log.debug("Blob {} already exists upstream, skipping", digest[:19])
return

# Initiate upload
resp = await self._client.post(f"/v2/{name}/blobs/uploads/")
resp.raise_for_status()
location = resp.headers["Location"]

# Complete with monolithic PUT
if location.startswith("/"):
upload_url = location
else:
upload_url = location

resp = await self._client.put(
upload_url,
content=data,
params={"digest": digest},
headers={"Content-Type": "application/octet-stream"},
)
resp.raise_for_status()
"""Push a blob using monolithic upload."""
dxf = self._get_dxf(name)
await asyncio.to_thread(dxf.push_blob, data=iter([data]), digest=digest)
log.debug("Pushed blob {} upstream", digest[:19])

async def push_blob_streaming(
self, name: str, digest: str, stream: AsyncIterator[bytes]
) -> None:
"""Push a blob by streaming from local storage."""
if await self.check_blob(name, digest):
log.debug("Blob {} already exists upstream, skipping", digest[:19])
return

# Initiate upload
resp = await self._client.post(f"/v2/{name}/blobs/uploads/")
resp.raise_for_status()
location = resp.headers["Location"]

# Stream the blob content as a monolithic PUT
async def _body() -> AsyncIterator[bytes]:
async for chunk in stream:
yield chunk

resp = await self._client.put(
location,
content=_body(),
params={"digest": digest},
headers={"Content-Type": "application/octet-stream"},
)
resp.raise_for_status()
"""Push a blob by collecting the stream and uploading."""
chunks = [chunk async for chunk in stream]
dxf = self._get_dxf(name)
await asyncio.to_thread(dxf.push_blob, data=iter(chunks), digest=digest)
log.debug("Pushed blob {} upstream (streamed)", digest[:19])

# -- Manifest operations --

async def check_manifest(self, name: str, reference: str) -> bool:
resp = await self._client.head(f"/v2/{name}/manifests/{reference}")
return resp.status_code == 200
dxf = self._get_dxf(name)

def _check() -> bool:
try:
dxf.head_manifest_and_response(reference)
return True
except requests.exceptions.HTTPError:
return False

return await asyncio.to_thread(_check)

async def pull_manifest(self, name: str, reference: str) -> tuple[bytes, str, str] | None:
"""Pull a manifest. Returns (body, content_type, digest) or None."""
resp = await self._client.get(
f"/v2/{name}/manifests/{reference}",
headers={
"Accept": ", ".join(
[
"application/vnd.docker.distribution.manifest.v2+json",
"application/vnd.oci.image.manifest.v1+json",
"application/vnd.docker.distribution.manifest.list.v2+json",
"application/vnd.oci.image.index.v1+json",
]
),
},
)
if resp.status_code == 404:
return None
resp.raise_for_status()

body = resp.content
content_type = resp.headers.get("content-type", "application/json")
digest = resp.headers.get("docker-content-digest", "")
return body, content_type, digest
dxf = self._get_dxf(name)

def _pull() -> tuple[bytes, str, str] | None:
try:
manifest_str, resp = dxf.get_manifest_and_response(reference)
body = manifest_str.encode() if isinstance(manifest_str, str) else manifest_str
content_type = resp.headers.get("content-type", "application/json")
digest = resp.headers.get("docker-content-digest", "")
return body, content_type, digest
except requests.exceptions.HTTPError as exc:
if exc.response is not None and exc.response.status_code == 404:
return None
raise

return await asyncio.to_thread(_pull)

async def push_manifest(
self, name: str, reference: str, data: bytes, content_type: str
) -> None:
resp = await self._client.put(
f"/v2/{name}/manifests/{reference}",
content=data,
headers={"Content-Type": content_type},
)
resp.raise_for_status()
dxf = self._get_dxf(name)
manifest_json = data.decode() if isinstance(data, bytes) else data
parsed = json.loads(manifest_json)
if "mediaType" not in parsed:
parsed["mediaType"] = content_type
manifest_json = json.dumps(parsed)
await asyncio.to_thread(dxf.set_manifest, reference, manifest_json)
log.debug("Pushed manifest {name}:{ref} upstream", name=name, ref=reference)
76 changes: 51 additions & 25 deletions tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@
# ABOUTME: Simulates a complete Docker image push, sync to upstream, and pull.
from __future__ import annotations

import asyncio
import hashlib
import json
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from pathlib import Path

import httpx
import respx
import responses
from starlette.applications import Starlette
from starlette.routing import Route
from starlette.testclient import TestClient
Expand Down Expand Up @@ -81,7 +81,7 @@ def _make_manifest(config_digest: str, layer_digests: list[str]) -> bytes:
class TestFullPushSyncCycle:
"""Push an image locally, sync it to a mock upstream, verify everything."""

@respx.mock
@responses.activate
def test_push_and_sync(self, tmp_path: Path) -> None:
storage = Storage(str(tmp_path))
queue = SyncQueue(str(tmp_path))
Expand Down Expand Up @@ -127,14 +127,41 @@ def test_push_and_sync(self, tmp_path: Path) -> None:
assert len(markers) == 1

# 5. Mock upstream and run sync
respx.head(url__regex=r".*/blobs/.*").mock(return_value=httpx.Response(404))
respx.post(url__regex=r".*/blobs/uploads/").mock(
return_value=httpx.Response(202, headers={"Location": "/v2/myapp/blobs/uploads/u1"})
# HEAD checks for both blobs (config + layer)
responses.add(
responses.HEAD,
url="https://central:5000/v2/myapp/blobs/" + config_digest,
status=404,
)
responses.add(
responses.HEAD,
url="https://central:5000/v2/myapp/blobs/" + layer_digest,
status=404,
)
# POST to initiate upload for each blob
responses.add(
responses.POST,
url="https://central:5000/v2/myapp/blobs/uploads/",
status=202,
headers={"Location": "https://central:5000/v2/myapp/blobs/uploads/u1"},
)
responses.add(
responses.POST,
url="https://central:5000/v2/myapp/blobs/uploads/",
status=202,
headers={"Location": "https://central:5000/v2/myapp/blobs/uploads/u2"},
)
# PUT to complete each blob upload
responses.add(
responses.PUT, url="https://central:5000/v2/myapp/blobs/uploads/u1", status=201
)
responses.add(
responses.PUT, url="https://central:5000/v2/myapp/blobs/uploads/u2", status=201
)
# PUT manifest
responses.add(
responses.PUT, url="https://central:5000/v2/myapp/manifests/latest", status=201
)
respx.put(url__regex=r".*/blobs/uploads/.*").mock(return_value=httpx.Response(201))
respx.put(url__regex=r".*/manifests/.*").mock(return_value=httpx.Response(201))

import asyncio

async def _run_sync() -> None:
await storage.init()
Expand All @@ -147,9 +174,9 @@ async def _run_sync() -> None:
asyncio.run(_run_sync())

# Verify: 2 blob HEAD checks + 2 blob uploads (POST+PUT each) + 1 manifest PUT
head_calls = [c for c in respx.calls if c.request.method == "HEAD"]
post_calls = [c for c in respx.calls if c.request.method == "POST"]
put_calls = [c for c in respx.calls if c.request.method == "PUT"]
head_calls = [c for c in responses.calls if c.request.method == "HEAD"]
post_calls = [c for c in responses.calls if c.request.method == "POST"]
put_calls = [c for c in responses.calls if c.request.method == "PUT"]
assert len(head_calls) == 2 # config + layer
assert len(post_calls) == 2 # config + layer upload initiation
assert len(put_calls) == 3 # config + layer upload completion + manifest
Expand All @@ -162,7 +189,7 @@ async def _run_sync() -> None:
class TestPullProxyFromUpstream:
"""Pull an image that only exists on the upstream registry."""

@respx.mock
@responses.activate
def test_pull_manifest_from_upstream(self, tmp_path: Path) -> None:
storage = Storage(str(tmp_path))
queue = SyncQueue(str(tmp_path))
Expand All @@ -174,15 +201,14 @@ def test_pull_manifest_from_upstream(self, tmp_path: Path) -> None:
manifest_ct = "application/vnd.docker.distribution.manifest.v2+json"
manifest_digest = f"sha256:{hashlib.sha256(manifest_body).hexdigest()}"

respx.get("https://central:5000/v2/remote-app/manifests/latest").mock(
return_value=httpx.Response(
200,
content=manifest_body,
headers={
"Content-Type": manifest_ct,
"Docker-Content-Digest": manifest_digest,
},
)
responses.get(
"https://central:5000/v2/remote-app/manifests/latest",
body=manifest_body,
status=200,
headers={
"Content-Type": manifest_ct,
"Docker-Content-Digest": manifest_digest,
},
)

with TestClient(app) as client:
Expand All @@ -193,8 +219,8 @@ def test_pull_manifest_from_upstream(self, tmp_path: Path) -> None:
assert resp.headers["Docker-Content-Digest"] == manifest_digest

# Second pull should be served from local cache (no more upstream calls)
respx.reset()
resp = client.get("/v2/remote-app/manifests/latest")
assert resp.status_code == 200
assert resp.content == manifest_body
assert len(respx.calls) == 0 # served from cache
# Only 1 call to upstream (first pull), second served from cache
assert len(responses.calls) == 1
Loading
Loading