Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
292 changes: 292 additions & 0 deletions brev/welcome-ui/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,45 @@
POLICY_FILE = os.path.join(SANDBOX_DIR, "policy.yaml")

# Where the sandbox-creation steps append their progress output.
LOG_FILE = "/tmp/nemoclaw-sandbox-create.log"
# On-disk JSON cache of provider config values we last sent to the CLI
# (read/written by the _read_config_cache/_write_config_cache helpers below).
PROVIDER_CONFIG_CACHE = "/tmp/nemoclaw-provider-config-cache.json"
# Brev environment id injected by the platform, if any.
BREV_ENV_ID = os.environ.get("BREV_ENV_ID", "")
# Runtime-detected fallback when BREV_ENV_ID is unset.
# NOTE(review): the detection logic lives outside this view — confirm.
_detected_brev_id = ""

# Port the sandboxed workload listens on.
SANDBOX_PORT = 18789
_ANSI_RE = re.compile(r"\x1b\[[0-9;]*[a-zA-Z]")

def _strip_ansi(text: str) -> str:
return _ANSI_RE.sub("", text)

def _read_config_cache() -> dict:
    """Load the provider-config cache from disk.

    Returns ``{}`` when the file is missing, unreadable, corrupt, or does
    not contain a JSON object, so callers always get a usable dict.
    """
    try:
        with open(PROVIDER_CONFIG_CACHE) as f:
            data = json.load(f)
    except (OSError, json.JSONDecodeError):
        # OSError subsumes FileNotFoundError; previously a PermissionError
        # escaped here and crashed the request handler.
        return {}
    # Guard against a corrupt cache holding a JSON array/scalar — callers
    # index and mutate the result as a dict.
    return data if isinstance(data, dict) else {}


def _write_config_cache(cache: dict) -> None:
    """Persist *cache* to PROVIDER_CONFIG_CACHE as JSON (best effort).

    Serializes *before* opening the file: the old code opened with "w"
    (truncating the existing cache) and only then called ``json.dump``,
    so a non-serializable value destroyed the cache on disk.  A TypeError
    from serialization still propagates, as before; plain I/O errors are
    swallowed because the cache is only advisory.
    """
    payload = json.dumps(cache)  # may raise TypeError, same as before
    try:
        with open(PROVIDER_CONFIG_CACHE, "w") as f:
            f.write(payload)
    except OSError:
        pass


def _cache_provider_config(name: str, config: dict) -> None:
    """Record *config* for provider *name* in the on-disk config cache."""
    current = _read_config_cache()
    current[name] = config
    _write_config_cache(current)


def _remove_cached_provider(name: str) -> None:
    """Drop provider *name* from the on-disk config cache (no-op if absent)."""
    current = _read_config_cache()
    if name in current:
        del current[name]
    _write_config_cache(current)


_sandbox_lock = threading.Lock()
_sandbox_state = {
"status": "idle", # idle | creating | running | error
Expand Down Expand Up @@ -93,6 +127,9 @@ def _run_inject_key(key: str, key_hash: str) -> None:
return

_inject_log(f"step 3/3: SUCCESS — provider nvidia-inference updated")
_cache_provider_config("nvidia-inference", {
"OPENAI_BASE_URL": "https://inference-api.nvidia.com/v1",
})
with _inject_key_lock:
_inject_key_state["status"] = "done"
_inject_key_state["error"] = None
Expand Down Expand Up @@ -455,6 +492,20 @@ def _route(self):
if path == "/api/inject-key" and self.command == "POST":
return self._handle_inject_key()

if path == "/api/providers" and self.command == "GET":
return self._handle_providers_list()
if path == "/api/providers" and self.command == "POST":
return self._handle_provider_create()
if re.match(r"^/api/providers/[\w-]+$", path) and self.command == "PUT":
return self._handle_provider_update(path.split("/")[-1])
if re.match(r"^/api/providers/[\w-]+$", path) and self.command == "DELETE":
return self._handle_provider_delete(path.split("/")[-1])

if path == "/api/cluster-inference" and self.command == "GET":
return self._handle_cluster_inference_get()
if path == "/api/cluster-inference" and self.command == "POST":
return self._handle_cluster_inference_set()

if _sandbox_ready():
return self._proxy_to_sandbox()

Expand Down Expand Up @@ -658,6 +709,234 @@ def _handle_inject_key(self):

return self._json_response(202, {"ok": True, "started": True})

# -- Provider CRUD --------------------------------------------------

@staticmethod
def _parse_provider_detail(stdout: str) -> dict | None:
    """Parse the text output of ``nemoclaw provider get <name>``.

    Returns a dict with whichever of ``id``/``name``/``type``/
    ``credentialKeys``/``configKeys`` were present, or ``None`` when no
    ``Name:`` line was found (output unrecognized).
    """
    # Label -> result key for single-valued and comma-separated fields.
    scalar_fields = {"Id:": "id", "Name:": "name", "Type:": "type"}
    list_fields = {"Credential keys:": "credentialKeys", "Config keys:": "configKeys"}
    info: dict = {}
    for raw_line in stdout.splitlines():
        text = _strip_ansi(raw_line).strip()
        matched = False
        for label, key in scalar_fields.items():
            if text.startswith(label):
                info[key] = text.split(":", 1)[1].strip()
                matched = True
                break
        if matched:
            continue
        for label, key in list_fields.items():
            if text.startswith(label):
                raw = text.split(":", 1)[1].strip()
                # "<none>" is the CLI's marker for an empty key list.
                if raw and raw != "<none>":
                    info[key] = [k.strip() for k in raw.split(",") if k.strip()]
                else:
                    info[key] = []
                break
    return info if "name" in info else None

def _handle_providers_list(self):
try:
result = subprocess.run(
["nemoclaw", "provider", "list", "--names"],
capture_output=True, text=True, timeout=30,
)
if result.returncode != 0:
return self._json_response(502, {
"ok": False,
"error": (result.stderr or result.stdout or "provider list failed").strip(),
})
names = [n.strip() for n in result.stdout.strip().splitlines() if n.strip()]
except Exception as exc:
return self._json_response(502, {"ok": False, "error": str(exc)})

providers = []
config_cache = _read_config_cache()
for name in names:
try:
detail = subprocess.run(
["nemoclaw", "provider", "get", name],
capture_output=True, text=True, timeout=30,
)
if detail.returncode == 0:
parsed = self._parse_provider_detail(detail.stdout)
if parsed:
cached = config_cache.get(name, {})
if cached:
parsed["configValues"] = cached
providers.append(parsed)
except Exception:
pass

return self._json_response(200, {"ok": True, "providers": providers})

def _read_json_body(self) -> dict | None:
content_length = int(self.headers.get("Content-Length", 0))
if content_length == 0:
return None
raw = self.rfile.read(content_length).decode("utf-8", errors="replace")
try:
return json.loads(raw)
except json.JSONDecodeError:
return None

def _handle_provider_create(self):
data = self._read_json_body()
if not data:
return self._json_response(400, {"ok": False, "error": "invalid or empty JSON body"})

name = data.get("name", "").strip()
ptype = data.get("type", "").strip()
if not name or not ptype:
return self._json_response(400, {"ok": False, "error": "name and type are required"})

cmd = ["nemoclaw", "provider", "create", "--name", name, "--type", ptype]
creds = data.get("credentials", {})
configs = data.get("config", {})
if not creds:
cmd += ["--credential", "PLACEHOLDER=unused"]
for k, v in creds.items():
cmd += ["--credential", f"{k}={v}"]
for k, v in configs.items():
cmd += ["--config", f"{k}={v}"]

try:
result = subprocess.run(cmd, capture_output=True, text=True, timeout=30)
if result.returncode != 0:
err = (result.stderr or result.stdout or "create failed").strip()
return self._json_response(400, {"ok": False, "error": err})
if configs:
_cache_provider_config(name, configs)
return self._json_response(200, {"ok": True})
except Exception as exc:
return self._json_response(502, {"ok": False, "error": str(exc)})

def _handle_provider_update(self, name: str):
data = self._read_json_body()
if not data:
return self._json_response(400, {"ok": False, "error": "invalid or empty JSON body"})

ptype = data.get("type", "").strip()
if not ptype:
return self._json_response(400, {"ok": False, "error": "type is required"})

cmd = ["nemoclaw", "provider", "update", name, "--type", ptype]
for k, v in data.get("credentials", {}).items():
cmd += ["--credential", f"{k}={v}"]
configs = data.get("config", {})
for k, v in configs.items():
cmd += ["--config", f"{k}={v}"]

try:
result = subprocess.run(cmd, capture_output=True, text=True, timeout=30)
if result.returncode != 0:
err = (result.stderr or result.stdout or "update failed").strip()
return self._json_response(400, {"ok": False, "error": err})
if configs:
_cache_provider_config(name, configs)
return self._json_response(200, {"ok": True})
except Exception as exc:
return self._json_response(502, {"ok": False, "error": str(exc)})

def _handle_provider_delete(self, name: str):
try:
result = subprocess.run(
["nemoclaw", "provider", "delete", name],
capture_output=True, text=True, timeout=30,
)
if result.returncode != 0:
err = (result.stderr or result.stdout or "delete failed").strip()
return self._json_response(400, {"ok": False, "error": err})
_remove_cached_provider(name)
return self._json_response(200, {"ok": True})
except Exception as exc:
return self._json_response(502, {"ok": False, "error": str(exc)})

# -- GET /api/cluster-inference ------------------------------------

@staticmethod
def _parse_cluster_inference(stdout: str) -> dict | None:
    """Parse ``nemoclaw cluster inference get/set`` output.

    Returns ``{"providerName", "modelId", "version"}`` or ``None`` when
    no ``Provider:`` line is present in the output.
    """
    provider = None
    model = None
    version_text = "0"
    for raw_line in stdout.splitlines():
        text = _strip_ansi(raw_line).strip()
        if text.startswith("Provider:"):
            provider = text[len("Provider:"):].strip()
        elif text.startswith("Model:"):
            model = text[len("Model:"):].strip()
        elif text.startswith("Version:"):
            version_text = text[len("Version:"):].strip()
    if provider is None:
        return None
    try:
        version = int(version_text)
    except ValueError:
        # Non-numeric (or empty) version field — fall back to 0.
        version = 0
    return {
        "providerName": provider,
        "modelId": model or "",
        "version": version,
    }

def _handle_cluster_inference_get(self):
try:
result = subprocess.run(
["nemoclaw", "cluster", "inference", "get"],
capture_output=True, text=True, timeout=30,
)
if result.returncode != 0:
stderr = (result.stderr or "").strip()
if "not configured" in stderr.lower() or "not found" in stderr.lower():
return self._json_response(200, {
"ok": True,
"providerName": None,
"modelId": "",
"version": 0,
})
err = stderr or (result.stdout or "get failed").strip()
return self._json_response(400, {"ok": False, "error": err})
parsed = self._parse_cluster_inference(result.stdout)
if not parsed:
return self._json_response(200, {
"ok": True,
"providerName": None,
"modelId": "",
"version": 0,
})
return self._json_response(200, {"ok": True, **parsed})
except Exception as exc:
return self._json_response(502, {"ok": False, "error": str(exc)})

# -- POST /api/cluster-inference -----------------------------------

def _handle_cluster_inference_set(self):
body = self._read_json_body()
if body is None:
return self._json_response(400, {"ok": False, "error": "invalid JSON body"})
provider_name = (body.get("providerName") or "").strip()
model_id = (body.get("modelId") or "").strip()
if not provider_name:
return self._json_response(400, {"ok": False, "error": "providerName is required"})
if not model_id:
return self._json_response(400, {"ok": False, "error": "modelId is required"})
try:
result = subprocess.run(
["nemoclaw", "cluster", "inference", "set",
"--provider", provider_name,
"--model", model_id],
capture_output=True, text=True, timeout=30,
)
if result.returncode != 0:
err = (result.stderr or result.stdout or "set failed").strip()
return self._json_response(400, {"ok": False, "error": err})
parsed = self._parse_cluster_inference(result.stdout)
resp = {"ok": True}
if parsed:
resp.update(parsed)
return self._json_response(200, resp)
except Exception as exc:
return self._json_response(502, {"ok": False, "error": str(exc)})

# -- GET /api/sandbox-status ----------------------------------------

def _handle_sandbox_status(self):
Expand Down Expand Up @@ -716,7 +995,20 @@ def log_message(self, fmt, *args):
sys.stderr.write(f"[welcome-ui] {fmt % args}\n")


def _bootstrap_config_cache() -> None:
    """Seed the config cache for providers created before caching existed."""
    if os.path.isfile(PROVIDER_CONFIG_CACHE):
        # Cache already on disk (seeded earlier or written at runtime).
        return
    seed = {
        "nvidia-inference": {
            "OPENAI_BASE_URL": "https://inference-api.nvidia.com/v1",
        },
    }
    _write_config_cache(seed)
    sys.stderr.write("[welcome-ui] Bootstrapped provider config cache\n")


if __name__ == "__main__":
    # Seed the on-disk provider config cache, then serve until killed.
    _bootstrap_config_cache()
    httpd = http.server.ThreadingHTTPServer(("", PORT), Handler)
    print(f"NemoClaw Welcome UI → http://localhost:{PORT}")
    httpd.serve_forever()
Loading