From a140a278e71fc346ee062029627a8f4998375edc Mon Sep 17 00:00:00 2001 From: aandersen2323 <132087791+aandersen2323@users.noreply.github.com> Date: Thu, 23 Oct 2025 14:53:53 -0500 Subject: [PATCH] Expose retry bridge and add token budgeting --- mcp/mcp_http_server.py | 11 ++-- server/app.py | 105 +++++++++++++++++++++++++++++-------- server/requirements.txt | 1 + server/tests/test_utils.py | 21 ++++++-- 4 files changed, 106 insertions(+), 32 deletions(-) diff --git a/mcp/mcp_http_server.py b/mcp/mcp_http_server.py index d0ddb68..ec9842f 100644 --- a/mcp/mcp_http_server.py +++ b/mcp/mcp_http_server.py @@ -90,12 +90,12 @@ async def handle_mcp_request(request: Request) -> Dict[str, Any]: try: if method == "ping": return mcp_respond(request_id, result={"ok": True, "ts": time.time()}) - if method == "flowdex.infer": + elif method == "flowdex.infer": return mcp_respond( request_id, result=call_api("POST", "/infer", json_body=params), ) - if method == "flowdex.infer.retry": + elif method == "flowdex.infer.retry": run_id = params.get("run_id") if not run_id: raise BridgeError("run_id is required for flowdex.infer.retry") @@ -112,17 +112,18 @@ async def handle_mcp_request(request: Request) -> Dict[str, Any]: json_body=retry_payload, ), ) - if method == "flowdex.memory.get": + elif method == "flowdex.memory.get": return mcp_respond( request_id, result=call_api("GET", "/memory/get", params=params), ) - if method == "flowdex.health": + elif method == "flowdex.health": return mcp_respond( request_id, result=call_api("GET", "/health", params=params), ) - raise BridgeError(f"Unknown method: {method}") + else: + raise BridgeError(f"Unknown method: {method}") except BridgeError as exc: return mcp_respond(request_id, error=str(exc)) except Exception as exc: # pragma: no cover - defensive programming diff --git a/server/app.py b/server/app.py index 0b7aed5..fcf0332 100644 --- a/server/app.py +++ b/server/app.py @@ -6,7 +6,7 @@ import time from collections import Counter from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple import redis import requests @@ -14,6 +14,15 @@ from pydantic import BaseModel, Field from redis.exceptions import RedisError +try: + import tiktoken + + _TOKEN_ENCODER = tiktoken.get_encoding("cl100k_base") + _HAS_TIKTOKEN = True +except ImportError: # pragma: no cover - optional dependency + _TOKEN_ENCODER = None + _HAS_TIKTOKEN = False + CACHE_DIR = Path(os.environ.get("FLOWDEX_CACHE_DIR", ".flowdex_cache")) CACHE_DIR.mkdir(parents=True, exist_ok=True) @@ -130,10 +139,48 @@ def json_patch(old: str, new: str) -> Dict[str, Any]: return {"common_prefix": prefix_len, "added": new[prefix_len:]} -def truncate_to_budget(text: str, budget_chars: int) -> str: - if len(text) <= budget_chars: - return text - return text[-budget_chars:] +def count_tokens(text: str) -> int: + """Count tokens using ``tiktoken`` when available, otherwise estimate.""" + + if not text: + return 0 + if _HAS_TIKTOKEN and _TOKEN_ENCODER is not None: + return len(_TOKEN_ENCODER.encode(text)) + return max(1, len(text) // 4) + + +def truncate_to_token_budget(text: str, budget_tokens: int) -> Tuple[str, int]: + """Truncate ``text`` to fit the supplied token budget. + + Returns a tuple of the truncated text and the number of tokens consumed. + """ + + if not text or budget_tokens <= 0: + return "", 0 + + if not (_HAS_TIKTOKEN and _TOKEN_ENCODER is not None): + char_budget = max(budget_tokens * 4, 0) + truncated_text = text[-char_budget:] if char_budget else "" + return truncated_text, count_tokens(truncated_text) + + tokens = _TOKEN_ENCODER.encode(text) + if len(tokens) <= budget_tokens: + return text, len(tokens) + + truncated_tokens = tokens[-budget_tokens:] + try: + truncated_text = _TOKEN_ENCODER.decode(truncated_tokens) + except Exception: # pragma: no cover - defensive decoding fallback + start_char = max( + 0, + len(text) + - int(len(text) * (budget_tokens / max(len(tokens), 1)) * 1.1), + ) + truncated_text = text[start_char:] + truncated_tokens = _TOKEN_ENCODER.encode(truncated_text)[-budget_tokens:] + truncated_text = _TOKEN_ENCODER.decode(truncated_tokens) + + return truncated_text, len(truncated_tokens) def _safe_json_dumps(value: Any) -> str: @@ -147,7 +194,7 @@ def _safe_json_dumps(value: Any) -> str: return repr(value) -def build_tool_context(tool_specs: List[Dict[str, Any]], budget_chars: int) -> str: +def build_tool_context(tool_specs: List[Dict[str, Any]], max_length: Optional[int] = None) -> str: if not tool_specs: return "" serialized = [] @@ -167,7 +214,9 @@ def build_tool_context(tool_specs: List[Dict[str, Any]], budget_chars: int) -> s ) ) blob = "\n\n".join(serialized) - return truncate_to_budget(blob, budget_chars) + if max_length is None or len(blob) <= max_length: + return blob + return blob[-max_length:] TOKEN_PATTERN = re.compile(r"\b\w+\b", re.UNICODE) @@ -259,9 +308,14 @@ def call_llm(prompt: Dict[str, Any], request: InferRequest, tool_specs: List[Dic if prompt.get("user"): segments.append(f"\n{prompt['user']}\n") - tool_context = build_tool_context(tool_specs, request.budget.tools) + tool_context_blob = build_tool_context(tool_specs, 100_000) + tool_context, tool_tokens = truncate_to_token_budget( + tool_context_blob, request.budget.tools + ) if tool_context: segments.append(f"\n{tool_context}\n") + if "tokens_used" in prompt: + prompt["tokens_used"]["tools"] = tool_tokens user_message = "\n\n".join(seg for seg in segments if seg) if not user_message.strip(): @@ -417,9 +471,9 @@ def run_inference(req: InferRequest) -> Dict[str, Any]: patch = json_patch(prior_text, context_blob) prior_cache_path.write_text(context_blob) - sys_b = truncate_to_budget(req.system_prompt, req.budget.system) - ctx_b = truncate_to_budget(context_blob, req.budget.context) - usr_b = truncate_to_budget(req.user_input, req.budget.user) + sys_b, sys_tokens = truncate_to_token_budget(req.system_prompt, req.budget.system) + ctx_b, ctx_tokens = truncate_to_token_budget(context_blob, req.budget.context) + usr_b, usr_tokens = truncate_to_token_budget(req.user_input, req.budget.user) tool_names: List[str] = [] tool_specs: List[Dict[str, Any]] = [] @@ -445,6 +499,11 @@ def run_inference(req: InferRequest) -> Dict[str, Any]: "model": req.model, "max_tokens": req.max_tokens, "retrievals": retrieved_contexts, + "tokens_used": { + "system": sys_tokens, + "context": ctx_tokens, + "user": usr_tokens, + }, } output = call_llm(prompt, req, tool_specs, retrieved_contexts) @@ -463,6 +522,7 @@ def run_inference(req: InferRequest) -> Dict[str, Any]: "output": output, "budgets": req.budget.dict(), "used_context_chars": len(ctx_b), + "used_tokens": prompt["tokens_used"], "retrievals": retrieved_contexts, "tools_considered": tool_names, } @@ -501,18 +561,19 @@ def infer_retry( except Exception as exc: raise HTTPException(400, detail=f"Failed to parse original request: {exc}") from exc - new_system_prompt = retry_req.system_prompt_override or original_req.system_prompt - new_user_input = retry_req.user_input_override or original_req.user_input - - error_prefix = ( - "--- Previous Attempt Failed ---\n" - "The previous attempt to solve this task failed with the following error:\n" - f"\n{retry_req.error_context}\n\n" - "Please analyze this error and provide a new solution.\n" - "--- Original System Prompt ---\n" + original_req.system_prompt = ( + retry_req.system_prompt_override or original_req.system_prompt ) - original_req.system_prompt = error_prefix + new_system_prompt - original_req.user_input = new_user_input + original_user_input = retry_req.user_input_override or original_req.user_input + error_context_block = ( + "--- Context: Previous Attempt Failed ---\n" + "The previous attempt to solve the task failed. Analyze the following error " + "and provide a corrected solution based on the original request.\n" + f"\n{retry_req.error_context}\n\n" + "--- Original User Input ---\n" + ) + original_req.user_input = error_context_block + original_user_input + print(f"Retrying run {run_id}. Added error context to user input.") return run_inference(original_req) diff --git a/server/requirements.txt b/server/requirements.txt index aced120..546bd9d 100644 --- a/server/requirements.txt +++ b/server/requirements.txt @@ -4,3 +4,4 @@ pydantic python-dotenv redis requests +tiktoken diff --git a/server/tests/test_utils.py b/server/tests/test_utils.py index 1e0c356..209a9cb 100644 --- a/server/tests/test_utils.py +++ b/server/tests/test_utils.py @@ -10,8 +10,9 @@ _safe_json_dumps, _tokenize, build_tool_context, + count_tokens, json_patch, - truncate_to_budget, + truncate_to_token_budget, ) @@ -27,9 +28,19 @@ def test_json_patch(old: str, new: str, expected: dict[str, object]) -> None: assert json_patch(old, new) == expected -def test_truncate_to_budget() -> None: - assert truncate_to_budget("abcdef", 10) == "abcdef" - assert truncate_to_budget("abcdef", 3) == "def" +def test_truncate_to_token_budget() -> None: + text = "abcdef" * 5 + truncated, tokens = truncate_to_token_budget(text, 3) + assert truncated.endswith("def") + assert tokens <= 3 + + truncated_full, tokens_full = truncate_to_token_budget(text, 100) + assert truncated_full == text + assert tokens_full >= count_tokens(text) + + empty_text, empty_tokens = truncate_to_token_budget("", 10) + assert empty_text == "" + assert empty_tokens == 0 @pytest.mark.parametrize( @@ -51,7 +62,7 @@ def test_build_tool_context_truncates_and_serializes() -> None: {"name": "beta", "description": "second", "cost": "low", "schema": complex_schema}, ] - context = build_tool_context(tool_specs, budget_chars=500) + context = build_tool_context(tool_specs, max_length=500) assert "Tool: alpha" in context assert "Schema: {\"x\": \"int\"}" in context # Non-serializable value should fallback to repr