diff --git a/akd_ext/artifacts/stores/github.py b/akd_ext/artifacts/stores/github.py
new file mode 100644
index 0000000..acbd4cb
--- /dev/null
+++ b/akd_ext/artifacts/stores/github.py
@@ -0,0 +1,232 @@
+import os
+import re
+from contextlib import nullcontext
+from datetime import datetime, timezone
+from pathlib import PurePosixPath
+from typing import Self
+
+from github import Auth, Github, UnknownObjectException
+from loguru import logger
+
+from akd_ext.artifacts._base import Artifact, ArtifactStore
+
+
+class GitHubArtifactStore(ArtifactStore[str]):
+    """GitHub-backed artifact store.
+
+    Reads and writes artifacts inside a sub-tree of a GitHub repo via
+    the GitHub REST API (PyGithub). `root` can be either
+    "owner/repo[/path/within/repo]" or a github.com URL; in both cases
+    the first two segments identify the repo and any remaining segments
+    scope the store to that sub-tree. Slugs returned by the store are
+    relative to the sub-tree, matching `LocalArtifactStore`'s form — so
+    an agent's tool calls don't depend on which backend is in use.
+    """
+
+    def __init__(
+        self,
+        root: str,
+        *,
+        branch: str = "main",
+        access_token: str | None = None,
+        github_client: Github | None = None,
+        index_file: str | None = "README.md",
+        supported_extensions: tuple[str, ...] = (".md",),
+        debug: bool = False,
+    ) -> None:
+        """Construct a GitHub-backed artifact store.
+
+        Args:
+            root: "owner/repo[/path]" or a github.com URL (e.g.
+                "https://github.com/NASA-IMPACT/akd/tree/main/agents/x/artifacts").
+            branch: Branch to read/write. Defaults to "main". A branch
+                embedded in a `.../tree/<branch>/...` URL takes precedence
+                over this kwarg.
+            access_token: GitHub personal access token. Falls back to
+                the `GITHUB_ACCESS_TOKEN` env var. Anonymous (public-read
+                only) if neither is set. Ignored if `github_client` is set.
+            github_client: Optional pre-built PyGithub client. When set,
+                the store uses it for all API calls and does NOT close
+                it on method exit (caller owns the lifetime). Useful for
+                bulk writes (reuses the HTTP session) and for tests
+                (inject a MagicMock). When None (default), a fresh
+                client is constructed per method call.
+            index_file: Directory-overview filename (defaults to README.md,
+                which GitHub renders at the directory level).
+            supported_extensions: Extensions to include; see ArtifactStore.
+            debug: Enable debug logging.
+
+        Raises:
+            ValueError: If `root` does not contain at least "owner/repo".
+        """
+        super().__init__(
+            root=root,
+            index_file=index_file,
+            supported_extensions=supported_extensions,
+            debug=debug,
+        )
+        self.repo_name, self.path_prefix, url_branch = self._parse_root(root)
+        self.branch = url_branch or branch
+        self._token = access_token or os.getenv("GITHUB_ACCESS_TOKEN")
+        self._auth = Auth.Token(self._token) if self._token else None
+        self._github_client = github_client
+
+    def _client(self):
+        """Context manager yielding a Github client.
+
+        If a `github_client` was injected via `__init__`, returns it
+        wrapped in `nullcontext` (no close on exit — caller owns it).
+        Otherwise constructs a fresh client that closes on context exit.
+        """
+        if self._github_client is not None:
+            return nullcontext(self._github_client)
+        return Github(auth=self._auth, retry=None)
+
+    @staticmethod
+    def _parse_root(root: str) -> tuple[str, str, str | None]:
+        """Parse `root` into (repo_name, path_prefix, branch_hint).
+
+        Accepts plain ``owner/repo[/path]`` or a github.com URL (with or
+        without scheme; with or without a ``/tree/<branch>/...`` or
+        ``/blob/<branch>/...`` segment). A trailing ``.git`` on the repo
+        segment (clone-URL form) is dropped. `branch_hint` is the
+        URL-extracted branch if present, else None. Only one path segment
+        is read as the branch, so a branch name containing "/" cannot be
+        expressed in the URL form.
+
+        Raises:
+            ValueError: If `root` has fewer than two path segments.
+        """
+        s = re.sub(r"^(?:https?://)?(?:www\.)?github\.com/?", "", root.strip()).rstrip("/")
+        parts = PurePosixPath(s).parts
+        if len(parts) < 2:
+            raise ValueError(f"GitHub root must be 'owner/repo[/path]' or a github.com URL, got {root!r}")
+        # Clone URLs end in ".git", which is not part of the repo name.
+        repo_name = f"{parts[0]}/{parts[1].removesuffix('.git')}"
+        if len(parts) >= 4 and parts[2] in ("tree", "blob"):
+            return repo_name, "/".join(parts[4:]), parts[3]
+        return repo_name, "/".join(parts[2:]), None
+
+    async def load_artifacts(self) -> Self:
+        """Load all available artifacts into the cache.
+
+        Returns:
+            Self, for fluent chaining.
+        """
+        prefix = PurePosixPath(self.path_prefix) if self.path_prefix else None
+        with self._client() as gh:
+            repo = gh.get_repo(self.repo_name)
+            # NOTE(review): GitHub truncates very large recursive tree
+            # listings; that case is not detected here — confirm repo size.
+            tree = repo.get_git_tree(self.branch, recursive=True)
+            for entry in tree.tree:
+                if entry.type != "blob":
+                    continue
+                full = PurePosixPath(entry.path)
+                if prefix and not full.is_relative_to(prefix):
+                    continue
+                slug = str(full.relative_to(prefix)) if prefix else entry.path
+                if not self._is_supported(slug):
+                    continue
+                content_file = repo.get_contents(entry.path, ref=self.branch)
+                self[slug] = Artifact[str](
+                    path=slug,
+                    content=content_file.decoded_content.decode("utf-8"),
+                    metadata={"sha": entry.sha},
+                    updated_at=content_file.last_modified_datetime,
+                )
+        logger.info(
+            "[GitHubArtifactStore] loaded {} artifacts from {}",
+            len(self),
+            self.root,
+        )
+        return self
+
+    async def read_artifact(self, path: str) -> Artifact[str]:
+        """Fetch an artifact by path from the cache.
+
+        Cache-only — does not hit the GitHub API. Call `refresh()` to
+        re-sync after external changes to the repo.
+
+        Args:
+            path: Path of the artifact to load (e.g. "contexts/role.md").
+
+        Returns:
+            The artifact including its content.
+
+        Raises:
+            KeyError: If the artifact is not in the cache — call
+                `load_artifacts()` or `refresh()` first.
+        """
+        return self[path]
+
+    async def write_artifact(self, artifact: Artifact[str]) -> Artifact[str]:
+        """Persist an artifact to the repo as a commit.
+
+        Args:
+            artifact: Artifact to write. `artifact.path` must be relative
+                and must not contain ".." so the write stays inside the
+                store's sub-tree. Commit message comes from
+                `artifact.metadata["commit_message"]` if set, else
+                "Update {path}".
+
+        Returns:
+            Stored artifact with refreshed `metadata["sha"]` and
+            `updated_at` set to the current UTC time. If the content
+            matches the cached copy, the cached artifact is returned
+            unchanged and no commit is made.
+
+        Raises:
+            ValueError: If `artifact.path` is absolute or contains "..".
+        """
+        rel_path = PurePosixPath(artifact.path)
+        # Joining an absolute path onto the prefix discards the prefix, and
+        # PurePosixPath keeps ".." segments verbatim; reject both so the
+        # target stays inside the store's sub-tree.
+        if rel_path.is_absolute() or ".." in rel_path.parts:
+            raise ValueError(f"Artifact path must be relative to the store root, got {artifact.path!r}")
+        full_path = str(PurePosixPath(self.path_prefix) / artifact.path) if self.path_prefix else artifact.path
+        message = artifact.metadata.get("commit_message") or f"Update {artifact.path}"
+
+        cached = self._artifacts.get(artifact.path)
+        # Fast-path: skip write if content is unchanged vs. cache — avoids a
+        # spurious commit with identical tree.
+        if cached is not None and cached.content == artifact.content:
+            logger.debug(
+                "[GitHubArtifactStore] no-op write (unchanged): {}",
+                artifact.path,
+            )
+            return cached
+        sha = cached.metadata.get("sha") if cached is not None else None
+
+        with self._client() as gh:
+            repo = gh.get_repo(self.repo_name)
+
+            # Probe remote if we don't already know the sha
+            if sha is None:
+                try:
+                    sha = repo.get_contents(full_path, ref=self.branch).sha
+                except UnknownObjectException:
+                    pass  # stays None → will create
+
+            common = dict(
+                path=full_path,
+                message=message,
+                content=artifact.content,
+                branch=self.branch,
+            )
+            result = repo.update_file(sha=sha, **common) if sha else repo.create_file(**common)
+
+        stored = artifact.model_copy(
+            update={
+                "metadata": {**artifact.metadata, "sha": result["content"].sha},
+                "updated_at": datetime.now(timezone.utc),
+            }
+        )
+        self[artifact.path] = stored
+        logger.info(
+            "[GitHubArtifactStore] wrote: {} (sha={})",
+            stored.path,
+            result["content"].sha,
+        )
+        return stored
diff --git a/tests/artifacts/test_github.py b/tests/artifacts/test_github.py
new file mode 100644
index 0000000..aac02c8
--- /dev/null
+++ b/tests/artifacts/test_github.py
@@ -0,0 +1,79 @@
+"""Barebone mocked tests for GitHubArtifactStore: load, read, write.
+
+Uses the `github_client` injection kwarg rather than patching — see
+`GitHubArtifactStore.__init__` docstring.
+"""
+
+from unittest.mock import MagicMock
+
+import pytest
+from github import UnknownObjectException
+
+from akd_ext.artifacts import Artifact
+from akd_ext.artifacts.stores.github import GitHubArtifactStore
+
+
+def _tree_entry(path: str, sha: str, type_: str = "blob"):
+    e = MagicMock()
+    e.path = path
+    e.sha = sha
+    e.type = type_
+    return e
+
+
+def _content_file(content: str, sha: str):
+    cf = MagicMock()
+    cf.decoded_content = content.encode()
+    cf.sha = sha
+    cf.last_modified_datetime = None
+    return cf
+
+
+@pytest.fixture
+def mock_github():
+    """Return (github_client_mock, repo_mock) for tests to configure."""
+    gh = MagicMock()
+    repo = gh.get_repo.return_value
+    return gh, repo
+
+
+async def test_load(mock_github):
+    gh, repo = mock_github
+    repo.get_git_tree.return_value.tree = [
+        _tree_entry("index.md", "sha1"),
+        _tree_entry("data.json", "sha2"),  # filtered by supported_extensions
+    ]
+    repo.get_contents.return_value = _content_file("root", "sha1")
+
+    store = await GitHubArtifactStore("akd/test", github_client=gh).load_artifacts()
+
+    assert "index.md" in store
+    assert "data.json" not in store
+    assert store["index.md"].metadata["sha"] == "sha1"
+
+
+async def test_read(mock_github):
+    gh, repo = mock_github
+    repo.get_git_tree.return_value.tree = []
+    store = await GitHubArtifactStore("akd/test", github_client=gh).load_artifacts()
+    store["foo.md"] = Artifact(path="foo.md", content="hi")
+
+    got = await store.read_artifact("foo.md")
+    assert got.content == "hi"
+
+
+async def test_write(mock_github):
+    gh, repo = mock_github
+    repo.get_git_tree.return_value.tree = []
+    repo.get_contents.side_effect = UnknownObjectException(404, {}, {})
+    repo.create_file.return_value = {
+        "content": _content_file("new", "new_sha"),
+        "commit": MagicMock(),
+    }
+
+    store = await GitHubArtifactStore("akd/test", github_client=gh).load_artifacts()
+    stored = await store.write_artifact(Artifact(path="new.md", content="hello"))
+
+    repo.create_file.assert_called_once()
+    assert stored.metadata["sha"] == "new_sha"
+    assert "new.md" in store