Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
209 changes: 209 additions & 0 deletions akd_ext/artifacts/stores/github.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
import os
import re
from contextlib import nullcontext
from datetime import datetime, timezone
from pathlib import PurePosixPath
from typing import Self

from github import Auth, Github, UnknownObjectException
from loguru import logger

from akd_ext.artifacts._base import Artifact, ArtifactStore


class GitHubArtifactStore(ArtifactStore[str]):
    """GitHub-backed artifact store.

    Reads and writes artifacts inside a sub-tree of a GitHub repo via
    the GitHub REST API (PyGithub). `root` can be either
    "owner/repo[/path/within/repo]" or a github.com URL; in both cases
    the first two segments identify the repo and any remaining segments
    scope the store to that sub-tree. Slugs returned by the store are
    relative to the sub-tree, matching `LocalArtifactStore`'s form — so
    an agent's tool calls don't depend on which backend is in use.

    NOTE(review): `load_artifacts` and `write_artifact` are declared
    ``async`` but call PyGithub inline without awaiting anything, so the
    HTTP round-trips run on the calling event loop's thread.
    """

    def __init__(
        self,
        root: str,
        *,
        branch: str = "main",
        access_token: str | None = None,
        github_client: Github | None = None,
        index_file: str | None = "README.md",
        supported_extensions: tuple[str, ...] = (".md",),
        debug: bool = False,
    ) -> None:
        """Construct a GitHub-backed artifact store.

        Args:
            root: "owner/repo[/path]" or a github.com URL (e.g.
                "https://github.com/NASA-IMPACT/akd/tree/main/agents/x/artifacts").
            branch: Branch to read/write. Defaults to "main". A branch
                embedded in a `.../tree/<branch>/...` URL takes precedence
                over this kwarg.
            access_token: GitHub personal access token. Falls back to
                the `GITHUB_ACCESS_TOKEN` env var. Anonymous (public-read
                only) if neither is set. Ignored if `github_client` is set.
            github_client: Optional pre-built PyGithub client. When set,
                the store uses it for all API calls and does NOT close
                it on method exit (caller owns the lifetime). Useful for
                bulk writes (reuses the HTTP session) and for tests
                (inject a MagicMock). When None (default), a fresh
                client is constructed per method call.
            index_file: Directory-overview filename (defaults to README.md,
                which GitHub renders at the directory level).
            supported_extensions: Extensions to include; see ArtifactStore.
            debug: Enable debug logging.

        Raises:
            ValueError: If `root` does not contain at least "owner/repo".
        """
        super().__init__(
            root=root,
            index_file=index_file,
            supported_extensions=supported_extensions,
            debug=debug,
        )
        self.repo_name, self.path_prefix, url_branch = self._parse_root(root)
        # A branch named in the URL wins over the `branch` keyword.
        self.branch = url_branch or branch
        self._token = access_token or os.getenv("GITHUB_ACCESS_TOKEN")
        self._auth = Auth.Token(self._token) if self._token else None
        self._github_client = github_client

    def _client(self) -> "nullcontext[Github] | Github":
        """Return a context manager yielding a Github client.

        If a `github_client` was injected via `__init__`, returns it
        wrapped in `nullcontext` (no close on exit — caller owns it).
        Otherwise constructs a fresh client (with retries disabled via
        ``retry=None``) that is used as its own context manager.
        """
        if self._github_client is not None:
            return nullcontext(self._github_client)
        return Github(auth=self._auth, retry=None)

    @staticmethod
    def _parse_root(root: str) -> tuple[str, str, str | None]:
        """Parse `root` into (repo_name, path_prefix, branch_hint).

        Accepts plain ``owner/repo[/path]`` or a github.com URL (with or
        without scheme; with or without a ``/tree/<branch>/...`` or
        ``/blob/<branch>/...`` segment). A trailing ``.git`` on the repo
        segment (clone-URL form) is dropped.

        Only a single path segment is taken as the branch, so branch
        names containing "/" cannot be recovered from a URL — pass them
        through the `branch` kwarg with a plain ``owner/repo/path`` root.

        Args:
            root: Store root as given by the caller.

        Returns:
            ``(repo_name, path_prefix, branch_hint)`` where `repo_name` is
            "owner/repo", `path_prefix` is the sub-tree path ("" for the
            repo root) and `branch_hint` is the URL-extracted branch if
            present, else None.

        Raises:
            ValueError: If fewer than two usable segments are present.
        """
        # Require "/" or end-of-string after the host so look-alike hosts
        # (e.g. "github.community") are not partially consumed.
        s = re.sub(r"^(?:https?://)?(?:www\.)?github\.com(?:/|$)", "", root.strip())
        # Strip slashes on both ends: a leading "/" would otherwise become
        # its own path component and be mistaken for the owner.
        parts = PurePosixPath(s.strip("/")).parts
        repo = parts[1].removesuffix(".git") if len(parts) >= 2 else ""
        if not repo:
            raise ValueError(f"GitHub root must be 'owner/repo[/path]' or a github.com URL, got {root!r}")
        repo_name = f"{parts[0]}/{repo}"
        if len(parts) >= 4 and parts[2] in ("tree", "blob"):
            return repo_name, "/".join(parts[4:]), parts[3]
        return repo_name, "/".join(parts[2:]), None

    def _full_path(self, path: str) -> str:
        """Map a store-relative slug to its path inside the repo.

        Args:
            path: Artifact path relative to the store's sub-tree.

        Returns:
            `path` joined under `path_prefix`, or `path` unchanged when the
            store is rooted at the top of the repo.

        Raises:
            ValueError: If `path` is absolute or contains a ".." segment.
        """
        rel = PurePosixPath(path)
        # Joining an absolute path discards the prefix, and ".." climbs out
        # of it; either would let a write land outside the scoped sub-tree.
        if rel.is_absolute() or ".." in rel.parts:
            raise ValueError(f"Artifact path must be relative to the store root, got {path!r}")
        if not self.path_prefix:
            return path
        return str(PurePosixPath(self.path_prefix) / path)

    async def load_artifacts(self) -> Self:
        """Load all available artifacts into the cache.

        Lists the branch's tree recursively, keeps blobs under
        `path_prefix` whose slug passes `_is_supported`, and fetches each
        one's content with a separate `get_contents` call.

        Returns:
            Self, for fluent chaining.
        """
        prefix = PurePosixPath(self.path_prefix) if self.path_prefix else None
        with self._client() as gh:
            repo = gh.get_repo(self.repo_name)
            tree = repo.get_git_tree(self.branch, recursive=True)

            # The Git Trees API flags listings it had to cut short; surface
            # that instead of silently caching a partial view. The dict
            # check keeps injected test doubles from tripping the warning.
            raw = getattr(tree, "raw_data", None)
            if isinstance(raw, dict) and raw.get("truncated"):
                logger.warning(
                    "[GitHubArtifactStore] tree listing for {} was truncated; some artifacts may be missing",
                    self.root,
                )

            for entry in tree.tree:
                if entry.type != "blob":
                    continue
                full = PurePosixPath(entry.path)
                if prefix and not full.is_relative_to(prefix):
                    continue
                slug = str(full.relative_to(prefix)) if prefix else entry.path
                if not self._is_supported(slug):
                    continue
                content_file = repo.get_contents(entry.path, ref=self.branch)
                self[slug] = Artifact[str](
                    path=slug,
                    content=content_file.decoded_content.decode("utf-8"),
                    # Kept so a later write can update the file in place.
                    metadata={"sha": entry.sha},
                    updated_at=content_file.last_modified_datetime,
                )
        logger.info(
            "[GitHubArtifactStore] loaded {} artifacts from {}",
            len(self),
            self.root,
        )
        return self

    async def read_artifact(self, path: str) -> Artifact[str]:
        """Fetch an artifact by path from the cache.

        Cache-only — does not hit the GitHub API. Call `refresh()` to
        re-sync after external changes to the repo.

        Args:
            path: Path of the artifact to load (e.g. "contexts/role.md").

        Returns:
            The artifact including its content.

        Raises:
            KeyError: If the artifact is not in the cache — call
                `load_artifacts()` or `refresh()` first.
        """
        return self[path]

    async def write_artifact(self, artifact: Artifact[str]) -> Artifact[str]:
        """Persist an artifact to the repo as a commit.

        Args:
            artifact: Artifact to write. Commit message comes from
                `artifact.metadata["commit_message"]` if set, else
                "Update {path}".

        Returns:
            Stored artifact with refreshed `metadata["sha"]` and
            `updated_at` set to the local UTC time of the write. If the
            content equals the cached copy, no commit is made and the
            cached artifact is returned as-is.

        Raises:
            ValueError: If `artifact.path` is absolute or contains "..".

        Errors raised by PyGithub other than the "file not found" probe
        (e.g. when the cached sha is stale) are not caught here.
        """
        full_path = self._full_path(artifact.path)
        message = artifact.metadata.get("commit_message") or f"Update {artifact.path}"

        cached = self._artifacts.get(artifact.path)
        # Fast-path: skip write if content is unchanged vs. cache — avoids a
        # spurious commit with identical tree.
        if cached is not None and cached.content == artifact.content:
            logger.debug(
                "[GitHubArtifactStore] no-op write (unchanged): {}",
                artifact.path,
            )
            return cached
        sha = cached.metadata.get("sha") if cached is not None else None

        with self._client() as gh:
            repo = gh.get_repo(self.repo_name)

            # Probe remote if we don't already know the sha
            if sha is None:
                try:
                    sha = repo.get_contents(full_path, ref=self.branch).sha
                except UnknownObjectException:
                    sha = None  # file does not exist yet → will create

            common = dict(
                path=full_path,
                message=message,
                content=artifact.content,
                branch=self.branch,
            )
            # Updating needs the current blob sha; creating must not pass one.
            result = repo.update_file(sha=sha, **common) if sha else repo.create_file(**common)

        new_sha = result["content"].sha
        stored = artifact.model_copy(
            update={
                "metadata": {**artifact.metadata, "sha": new_sha},
                "updated_at": datetime.now(timezone.utc),
            }
        )
        self[artifact.path] = stored
        logger.info(
            "[GitHubArtifactStore] wrote: {} (sha={})",
            stored.path,
            new_sha,
        )
        return stored
79 changes: 79 additions & 0 deletions tests/artifacts/test_github.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
"""Barebone mocked tests for GitHubArtifactStore: load, read, write.

Uses the `github_client` injection kwarg rather than patching — see
`GitHubArtifactStore.__init__` docstring.
"""

from unittest.mock import MagicMock

import pytest
from github import UnknownObjectException

from akd_ext.artifacts import Artifact
from akd_ext.artifacts.stores.github import GitHubArtifactStore


def _tree_entry(path: str, sha: str, type_: str = "blob"):
e = MagicMock()
e.path = path
e.sha = sha
e.type = type_
return e


def _content_file(content: str, sha: str):
cf = MagicMock()
cf.decoded_content = content.encode()
cf.sha = sha
cf.last_modified_datetime = None
return cf


@pytest.fixture
def mock_github():
    """Provide a `(client, repo)` pair of mocks for tests to configure.

    `repo` is whatever `client.get_repo(...)` returns, so behaviour set on
    it is what the store sees after looking up the repository.
    """
    client = MagicMock()
    return client, client.get_repo.return_value


async def test_load(mock_github):
    """Loading caches supported files and skips unsupported extensions."""
    client, repo = mock_github
    entries = [
        _tree_entry("index.md", "sha1"),
        _tree_entry("data.json", "sha2"),  # filtered by supported_extensions
    ]
    repo.get_git_tree.return_value.tree = entries
    repo.get_contents.return_value = _content_file("root", "sha1")

    store = GitHubArtifactStore("akd/test", github_client=client)
    loaded = await store.load_artifacts()

    assert "index.md" in loaded
    assert "data.json" not in loaded
    assert loaded["index.md"].metadata["sha"] == "sha1"


async def test_read(mock_github):
    """Reading returns whatever is in the cache under the given path."""
    client, repo = mock_github
    repo.get_git_tree.return_value.tree = []  # nothing to load from the remote

    store = GitHubArtifactStore("akd/test", github_client=client)
    store = await store.load_artifacts()
    store["foo.md"] = Artifact(path="foo.md", content="hi")

    fetched = await store.read_artifact("foo.md")
    assert fetched.content == "hi"


async def test_write(mock_github):
    """Writing a path that is absent remotely creates it and caches the result."""
    gh, repo = mock_github
    repo.get_git_tree.return_value.tree = []
    # The existence probe reports "not found", so the store must create.
    repo.get_contents.side_effect = UnknownObjectException(404, {}, {})
    repo.create_file.return_value = {
        "content": _content_file("new", "new_sha"),
        "commit": MagicMock(),
    }

    store = await GitHubArtifactStore("akd/test", github_client=gh).load_artifacts()
    stored = await store.write_artifact(Artifact(path="new.md", content="hello"))

    # Pin the exact commit request: default message, default branch, and
    # the path unchanged because this store has no sub-tree prefix.
    repo.create_file.assert_called_once_with(
        path="new.md",
        message="Update new.md",
        content="hello",
        branch="main",
    )
    repo.update_file.assert_not_called()
    assert stored.content == "hello"
    assert stored.metadata["sha"] == "new_sha"
    assert "new.md" in store