From a140a278e71fc346ee062029627a8f4998375edc Mon Sep 17 00:00:00 2001
From: aandersen2323 <132087791+aandersen2323@users.noreply.github.com>
Date: Thu, 23 Oct 2025 14:53:53 -0500
Subject: [PATCH] Expose retry bridge and add token budgeting
---
mcp/mcp_http_server.py | 11 ++--
server/app.py | 105 +++++++++++++++++++++++++++++--------
server/requirements.txt | 1 +
server/tests/test_utils.py | 21 ++++++--
4 files changed, 106 insertions(+), 32 deletions(-)
diff --git a/mcp/mcp_http_server.py b/mcp/mcp_http_server.py
index d0ddb68..ec9842f 100644
--- a/mcp/mcp_http_server.py
+++ b/mcp/mcp_http_server.py
@@ -90,12 +90,12 @@ async def handle_mcp_request(request: Request) -> Dict[str, Any]:
try:
if method == "ping":
return mcp_respond(request_id, result={"ok": True, "ts": time.time()})
- if method == "flowdex.infer":
+ elif method == "flowdex.infer":
return mcp_respond(
request_id,
result=call_api("POST", "/infer", json_body=params),
)
- if method == "flowdex.infer.retry":
+ elif method == "flowdex.infer.retry":
run_id = params.get("run_id")
if not run_id:
raise BridgeError("run_id is required for flowdex.infer.retry")
@@ -112,17 +112,18 @@ async def handle_mcp_request(request: Request) -> Dict[str, Any]:
json_body=retry_payload,
),
)
- if method == "flowdex.memory.get":
+ elif method == "flowdex.memory.get":
return mcp_respond(
request_id,
result=call_api("GET", "/memory/get", params=params),
)
- if method == "flowdex.health":
+ elif method == "flowdex.health":
return mcp_respond(
request_id,
result=call_api("GET", "/health", params=params),
)
- raise BridgeError(f"Unknown method: {method}")
+ else:
+ raise BridgeError(f"Unknown method: {method}")
except BridgeError as exc:
return mcp_respond(request_id, error=str(exc))
except Exception as exc: # pragma: no cover - defensive programming
diff --git a/server/app.py b/server/app.py
index 0b7aed5..fcf0332 100644
--- a/server/app.py
+++ b/server/app.py
@@ -6,7 +6,7 @@
import time
from collections import Counter
from pathlib import Path
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Tuple
import redis
import requests
@@ -14,6 +14,15 @@
from pydantic import BaseModel, Field
from redis.exceptions import RedisError
+try:
+    import tiktoken
+    # get_encoding() may have to load (or fetch) encoding data and can fail for reasons other than ImportError.
+    _TOKEN_ENCODER = tiktoken.get_encoding("cl100k_base")
+    _HAS_TIKTOKEN = True
+except Exception:  # pragma: no cover - optional dependency missing or its encoding could not be loaded
+    _TOKEN_ENCODER = None  # callers fall back to the character-based estimate
+    _HAS_TIKTOKEN = False
+
CACHE_DIR = Path(os.environ.get("FLOWDEX_CACHE_DIR", ".flowdex_cache"))
CACHE_DIR.mkdir(parents=True, exist_ok=True)
@@ -130,10 +139,48 @@ def json_patch(old: str, new: str) -> Dict[str, Any]:
return {"common_prefix": prefix_len, "added": new[prefix_len:]}
-def truncate_to_budget(text: str, budget_chars: int) -> str:
- if len(text) <= budget_chars:
- return text
- return text[-budget_chars:]
+def count_tokens(text: str) -> int:
+    """Count tokens using ``tiktoken`` when available, otherwise estimate."""
+    if not text:
+        return 0
+    if _HAS_TIKTOKEN and _TOKEN_ENCODER is not None:
+        # disallowed_special=() encodes special-token text (e.g. "<|endoftext|>") as plain text instead of raising.
+        return len(_TOKEN_ENCODER.encode(text, disallowed_special=()))
+    return max(1, len(text) // 4)  # fallback estimate: one token per 4 characters, at least 1
+
+
+def truncate_to_token_budget(text: str, budget_tokens: int) -> Tuple[str, int]:
+    """Truncate ``text`` to fit the supplied token budget, keeping the tail of the text.
+
+    Returns a tuple of the truncated text and the number of tokens consumed.
+    """
+    # Nothing to keep for empty text or a non-positive budget.
+    if not text or budget_tokens <= 0:
+        return "", 0
+
+    if not (_HAS_TIKTOKEN and _TOKEN_ENCODER is not None):
+        char_budget = max(budget_tokens * 4, 0)  # no tokenizer: budget 4 characters per token
+        truncated_text = text[-char_budget:] if char_budget else ""
+        return truncated_text, count_tokens(truncated_text)
+    # disallowed_special=() encodes special-token text (e.g. "<|endoftext|>") as plain text instead of raising.
+    tokens = _TOKEN_ENCODER.encode(text, disallowed_special=())
+    if len(tokens) <= budget_tokens:
+        return text, len(tokens)
+
+    truncated_tokens = tokens[-budget_tokens:]  # keep only the last ``budget_tokens`` tokens
+    try:
+        truncated_text = _TOKEN_ENCODER.decode(truncated_tokens)
+    except Exception:  # pragma: no cover - defensive decoding fallback
+        start_char = max(
+            0,
+            len(text)
+            - int(len(text) * (budget_tokens / max(len(tokens), 1)) * 1.1),
+        )
+        truncated_text = text[start_char:]  # proportional character slice with ~10% slack, then re-trimmed
+        truncated_tokens = _TOKEN_ENCODER.encode(truncated_text, disallowed_special=())[-budget_tokens:]
+        truncated_text = _TOKEN_ENCODER.decode(truncated_tokens)
+
+    return truncated_text, len(truncated_tokens)
def _safe_json_dumps(value: Any) -> str:
@@ -147,7 +194,7 @@ def _safe_json_dumps(value: Any) -> str:
return repr(value)
-def build_tool_context(tool_specs: List[Dict[str, Any]], budget_chars: int) -> str:
+def build_tool_context(tool_specs: List[Dict[str, Any]], max_length: Optional[int] = None) -> str:
if not tool_specs:
return ""
serialized = []
@@ -167,7 +214,9 @@ def build_tool_context(tool_specs: List[Dict[str, Any]], budget_chars: int) -> s
)
)
blob = "\n\n".join(serialized)
- return truncate_to_budget(blob, budget_chars)
+ if max_length is None or len(blob) <= max_length:
+ return blob
+ return blob[-max_length:]
TOKEN_PATTERN = re.compile(r"\b\w+\b", re.UNICODE)
@@ -259,9 +308,14 @@ def call_llm(prompt: Dict[str, Any], request: InferRequest, tool_specs: List[Dic
if prompt.get("user"):
segments.append(f"\n{prompt['user']}\n")
- tool_context = build_tool_context(tool_specs, request.budget.tools)
+ tool_context_blob = build_tool_context(tool_specs, 100_000)
+ tool_context, tool_tokens = truncate_to_token_budget(
+ tool_context_blob, request.budget.tools
+ )
if tool_context:
segments.append(f"\n{tool_context}\n")
+ if "tokens_used" in prompt:
+ prompt["tokens_used"]["tools"] = tool_tokens
user_message = "\n\n".join(seg for seg in segments if seg)
if not user_message.strip():
@@ -417,9 +471,9 @@ def run_inference(req: InferRequest) -> Dict[str, Any]:
patch = json_patch(prior_text, context_blob)
prior_cache_path.write_text(context_blob)
- sys_b = truncate_to_budget(req.system_prompt, req.budget.system)
- ctx_b = truncate_to_budget(context_blob, req.budget.context)
- usr_b = truncate_to_budget(req.user_input, req.budget.user)
+ sys_b, sys_tokens = truncate_to_token_budget(req.system_prompt, req.budget.system)
+ ctx_b, ctx_tokens = truncate_to_token_budget(context_blob, req.budget.context)
+ usr_b, usr_tokens = truncate_to_token_budget(req.user_input, req.budget.user)
tool_names: List[str] = []
tool_specs: List[Dict[str, Any]] = []
@@ -445,6 +499,11 @@ def run_inference(req: InferRequest) -> Dict[str, Any]:
"model": req.model,
"max_tokens": req.max_tokens,
"retrievals": retrieved_contexts,
+ "tokens_used": {
+ "system": sys_tokens,
+ "context": ctx_tokens,
+ "user": usr_tokens,
+ },
}
output = call_llm(prompt, req, tool_specs, retrieved_contexts)
@@ -463,6 +522,7 @@ def run_inference(req: InferRequest) -> Dict[str, Any]:
"output": output,
"budgets": req.budget.dict(),
"used_context_chars": len(ctx_b),
+ "used_tokens": prompt["tokens_used"],
"retrievals": retrieved_contexts,
"tools_considered": tool_names,
}
@@ -501,18 +561,19 @@ def infer_retry(
except Exception as exc:
raise HTTPException(400, detail=f"Failed to parse original request: {exc}") from exc
- new_system_prompt = retry_req.system_prompt_override or original_req.system_prompt
- new_user_input = retry_req.user_input_override or original_req.user_input
-
- error_prefix = (
- "--- Previous Attempt Failed ---\n"
- "The previous attempt to solve this task failed with the following error:\n"
- f"\n{retry_req.error_context}\n\n"
- "Please analyze this error and provide a new solution.\n"
- "--- Original System Prompt ---\n"
+ original_req.system_prompt = (
+ retry_req.system_prompt_override or original_req.system_prompt
)
- original_req.system_prompt = error_prefix + new_system_prompt
- original_req.user_input = new_user_input
+ original_user_input = retry_req.user_input_override or original_req.user_input
+ error_context_block = (
+ "--- Context: Previous Attempt Failed ---\n"
+ "The previous attempt to solve the task failed. Analyze the following error "
+ "and provide a corrected solution based on the original request.\n"
+ f"\n{retry_req.error_context}\n\n"
+ "--- Original User Input ---\n"
+ )
+ original_req.user_input = error_context_block + original_user_input
+ print(f"Retrying run {run_id}. Added error context to user input.")
return run_inference(original_req)
diff --git a/server/requirements.txt b/server/requirements.txt
index aced120..546bd9d 100644
--- a/server/requirements.txt
+++ b/server/requirements.txt
@@ -4,3 +4,4 @@ pydantic
python-dotenv
redis
requests
+tiktoken
diff --git a/server/tests/test_utils.py b/server/tests/test_utils.py
index 1e0c356..209a9cb 100644
--- a/server/tests/test_utils.py
+++ b/server/tests/test_utils.py
@@ -10,8 +10,9 @@
_safe_json_dumps,
_tokenize,
build_tool_context,
+ count_tokens,
json_patch,
- truncate_to_budget,
+ truncate_to_token_budget,
)
@@ -27,9 +28,19 @@ def test_json_patch(old: str, new: str, expected: dict[str, object]) -> None:
assert json_patch(old, new) == expected
-def test_truncate_to_budget() -> None:
- assert truncate_to_budget("abcdef", 10) == "abcdef"
- assert truncate_to_budget("abcdef", 3) == "def"
+def test_truncate_to_token_budget() -> None:  # covers truncation, pass-through and empty input
+    text = "abcdef" * 5
+    truncated, tokens = truncate_to_token_budget(text, 3)
+    assert truncated.endswith("def")
+    assert tokens <= 3
+    # A generous budget must leave the text untouched and report its exact token count.
+    truncated_full, tokens_full = truncate_to_token_budget(text, 100)
+    assert truncated_full == text
+    assert tokens_full == count_tokens(text)
+    # Empty input yields an empty result and zero tokens.
+    empty_text, empty_tokens = truncate_to_token_budget("", 10)
+    assert empty_text == ""
+    assert empty_tokens == 0
@pytest.mark.parametrize(
@@ -51,7 +62,7 @@ def test_build_tool_context_truncates_and_serializes() -> None:
{"name": "beta", "description": "second", "cost": "low", "schema": complex_schema},
]
- context = build_tool_context(tool_specs, budget_chars=500)
+ context = build_tool_context(tool_specs, max_length=500)
assert "Tool: alpha" in context
assert "Schema: {\"x\": \"int\"}" in context
# Non-serializable value should fallback to repr