Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,16 @@ requires-python = ">=3.14"
# Runtime dependencies: httpx is required at runtime (upstream.py performs
# async registry v2 calls with it); python-dxf was replaced by httpx.
dependencies = [
    "starlette>=1.0.0",
    "granian[uvloop]>=2.7.2",
    "httpx>=0.28.1",
    "loguru>=0.7.3",
    "pydantic-settings>=2.13.1",
]

[dependency-groups]
# Dev-only tooling: respx mocks httpx transports (replaces `responses`,
# which only mocks the `requests` library).
dev = [
    "pytest>=9.0.2",
    "pytest-asyncio>=1.3.0",
    "respx>=0.22.0",
    "ruff>=0.15.8",
]

Expand All @@ -27,7 +26,7 @@ asyncio_mode = "auto"
switchyard = "switchyard.__main__:main"

[build-system]
# Pin the backend to a compatible minor range so builds are reproducible.
requires = ["uv_build>=0.11.3,<0.12"]
build-backend = "uv_build"

[tool.ruff]
Expand Down
177 changes: 90 additions & 87 deletions src/switchyard/upstream.py
Original file line number Diff line number Diff line change
@@ -1,128 +1,131 @@
# ABOUTME: Client for communicating with the upstream Docker registry.
# ABOUTME: Uses httpx for async registry v2 operations with correct Content-Type handling.
from __future__ import annotations

from collections.abc import AsyncIterator

import httpx
from loguru import logger

# Component-tagged logger for this module.
log = logger.bind(component="upstream")

CHUNK_SIZE = 1024 * 1024  # 1MB read size used when streaming blobs

# Accept header for manifest pulls, covering both Docker and OCI formats.
_MANIFEST_ACCEPT = ", ".join([
    "application/vnd.docker.distribution.manifest.v2+json",
    "application/vnd.oci.image.manifest.v1+json",
    "application/vnd.docker.distribution.manifest.list.v2+json",
    "application/vnd.oci.image.index.v1+json",
])


class UpstreamClient:
    """Async client for an upstream Docker/OCI registry's v2 HTTP API."""

    def __init__(self, base_url: str) -> None:
        """Create a client rooted at *base_url*.

        Args:
            base_url: Registry base URL (scheme + host, e.g.
                ``http://registry:5000``); trailing slashes are stripped.
        """
        self._base_url = base_url.rstrip("/")
        # Don't set follow_redirects globally: HTTP spec allows 301/302
        # to change POST/PUT to GET, which breaks upload sessions.
        self._client = httpx.AsyncClient(
            base_url=self._base_url,
            # Generous read/write timeouts: blob transfers can be large.
            timeout=httpx.Timeout(connect=10, read=300, write=300, pool=10),
        )

async def close(self) -> None:
for dxf in self._dxf_cache.values():
dxf.__exit__(None, None, None)
self._dxf_cache.clear()
await self._client.aclose()

# -- Blob operations --

async def check_blob(self, name: str, digest: str) -> bool:
dxf = self._get_dxf(name)

def _check() -> bool:
try:
dxf.blob_size(digest)
return True
except requests.exceptions.HTTPError:
return False

return await asyncio.to_thread(_check)
resp = await self._client.head(
f"/v2/{name}/blobs/{digest}", follow_redirects=True
)
return resp.status_code == 200

async def pull_blob(self, name: str, digest: str) -> AsyncIterator[bytes]:
dxf = self._get_dxf(name)
chunks = await asyncio.to_thread(dxf.pull_blob, digest, chunk_size=CHUNK_SIZE)
while True:
chunk = await asyncio.to_thread(next, chunks, _SENTINEL)
if chunk is _SENTINEL:
break
yield chunk
async with self._client.stream(
"GET", f"/v2/{name}/blobs/{digest}", follow_redirects=True
) as resp:
resp.raise_for_status()
async for chunk in resp.aiter_bytes(chunk_size=CHUNK_SIZE):
yield chunk

async def push_blob(self, name: str, digest: str, data: bytes) -> None:
"""Push a blob using monolithic upload."""
dxf = self._get_dxf(name)
await asyncio.to_thread(dxf.push_blob, data=iter([data]), digest=digest)
"""Push a blob using monolithic upload (POST + PUT)."""
if await self.check_blob(name, digest):
log.debug("Blob {} already exists upstream, skipping", digest[:19])
return

resp = await self._client.post(f"/v2/{name}/blobs/uploads/")
resp.raise_for_status()
location = resp.headers["Location"]

resp = await self._client.put(
location,
content=data,
params={"digest": digest},
headers={"Content-Type": "application/octet-stream"},
)
resp.raise_for_status()
log.debug("Pushed blob {} upstream", digest[:19])

async def push_blob_streaming(
self, name: str, digest: str, stream: AsyncIterator[bytes]
) -> None:
"""Push a blob by collecting the stream and uploading."""
chunks = [chunk async for chunk in stream]
dxf = self._get_dxf(name)
await asyncio.to_thread(dxf.push_blob, data=iter(chunks), digest=digest)
"""Push a blob by streaming from local storage."""
if await self.check_blob(name, digest):
log.debug("Blob {} already exists upstream, skipping", digest[:19])
return

resp = await self._client.post(f"/v2/{name}/blobs/uploads/")
resp.raise_for_status()
location = resp.headers["Location"]

async def _body() -> AsyncIterator[bytes]:
async for chunk in stream:
yield chunk

resp = await self._client.put(
location,
content=_body(),
params={"digest": digest},
headers={"Content-Type": "application/octet-stream"},
)
resp.raise_for_status()
log.debug("Pushed blob {} upstream (streamed)", digest[:19])

# -- Manifest operations --

async def check_manifest(self, name: str, reference: str) -> bool:
dxf = self._get_dxf(name)

def _check() -> bool:
try:
dxf.head_manifest_and_response(reference)
return True
except requests.exceptions.HTTPError:
return False

return await asyncio.to_thread(_check)

async def pull_manifest(self, name: str, reference: str) -> tuple[bytes, str, str] | None:
resp = await self._client.head(
f"/v2/{name}/manifests/{reference}", follow_redirects=True
)
return resp.status_code == 200

async def pull_manifest(
self, name: str, reference: str
) -> tuple[bytes, str, str] | None:
"""Pull a manifest. Returns (body, content_type, digest) or None."""
dxf = self._get_dxf(name)

def _pull() -> tuple[bytes, str, str] | None:
try:
manifest_str, resp = dxf.get_manifest_and_response(reference)
body = manifest_str.encode() if isinstance(manifest_str, str) else manifest_str
content_type = resp.headers.get("content-type", "application/json")
digest = resp.headers.get("docker-content-digest", "")
return body, content_type, digest
except requests.exceptions.HTTPError as exc:
if exc.response is not None and exc.response.status_code == 404:
return None
raise

return await asyncio.to_thread(_pull)
resp = await self._client.get(
f"/v2/{name}/manifests/{reference}",
headers={"Accept": _MANIFEST_ACCEPT},
follow_redirects=True,
)
if resp.status_code == 404:
return None
resp.raise_for_status()

body = resp.content
content_type = resp.headers.get("content-type", "application/json")
digest = resp.headers.get("docker-content-digest", "")
return body, content_type, digest

async def push_manifest(
self, name: str, reference: str, data: bytes, content_type: str
) -> None:
dxf = self._get_dxf(name)
manifest_json = data.decode() if isinstance(data, bytes) else data
parsed = json.loads(manifest_json)
if "mediaType" not in parsed:
parsed["mediaType"] = content_type
manifest_json = json.dumps(parsed)
await asyncio.to_thread(dxf.set_manifest, reference, manifest_json)
resp = await self._client.put(
f"/v2/{name}/manifests/{reference}",
content=data,
headers={"Content-Type": content_type},
)
resp.raise_for_status()
log.debug("Pushed manifest {name}:{ref} upstream", name=name, ref=reference)
Loading
Loading