Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions mcp/mcp_http_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,12 @@ async def handle_mcp_request(request: Request) -> Dict[str, Any]:
try:
if method == "ping":
return mcp_respond(request_id, result={"ok": True, "ts": time.time()})
if method == "flowdex.infer":
elif method == "flowdex.infer":
return mcp_respond(
request_id,
result=call_api("POST", "/infer", json_body=params),
)
if method == "flowdex.infer.retry":
elif method == "flowdex.infer.retry":
run_id = params.get("run_id")
if not run_id:
raise BridgeError("run_id is required for flowdex.infer.retry")
Expand All @@ -112,17 +112,18 @@ async def handle_mcp_request(request: Request) -> Dict[str, Any]:
json_body=retry_payload,
),
)
if method == "flowdex.memory.get":
elif method == "flowdex.memory.get":
return mcp_respond(
request_id,
result=call_api("GET", "/memory/get", params=params),
)
if method == "flowdex.health":
elif method == "flowdex.health":
return mcp_respond(
request_id,
result=call_api("GET", "/health", params=params),
)
raise BridgeError(f"Unknown method: {method}")
else:
raise BridgeError(f"Unknown method: {method}")
except BridgeError as exc:
return mcp_respond(request_id, error=str(exc))
except Exception as exc: # pragma: no cover - defensive programming
Expand Down
105 changes: 83 additions & 22 deletions server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,23 @@
import time
from collections import Counter
from pathlib import Path
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Tuple

import redis
import requests
from fastapi import Depends, FastAPI, Header, HTTPException
from pydantic import BaseModel, Field
from redis.exceptions import RedisError

try:
import tiktoken

_TOKEN_ENCODER = tiktoken.get_encoding("cl100k_base")
_HAS_TIKTOKEN = True
except ImportError: # pragma: no cover - optional dependency
_TOKEN_ENCODER = None
_HAS_TIKTOKEN = False

CACHE_DIR = Path(os.environ.get("FLOWDEX_CACHE_DIR", ".flowdex_cache"))
CACHE_DIR.mkdir(parents=True, exist_ok=True)

Expand Down Expand Up @@ -130,10 +139,48 @@ def json_patch(old: str, new: str) -> Dict[str, Any]:
return {"common_prefix": prefix_len, "added": new[prefix_len:]}


def truncate_to_budget(text: str, budget_chars: int) -> str:
if len(text) <= budget_chars:
return text
return text[-budget_chars:]
def count_tokens(text: str) -> int:
"""Count tokens using ``tiktoken`` when available, otherwise estimate."""

if not text:
return 0
if _HAS_TIKTOKEN and _TOKEN_ENCODER is not None:
return len(_TOKEN_ENCODER.encode(text))
return max(1, len(text) // 4)


def truncate_to_token_budget(text: str, budget_tokens: int) -> Tuple[str, int]:
"""Truncate ``text`` to fit the supplied token budget.

Returns a tuple of the truncated text and the number of tokens consumed.
"""

if not text or budget_tokens <= 0:
return "", 0

if not (_HAS_TIKTOKEN and _TOKEN_ENCODER is not None):
char_budget = max(budget_tokens * 4, 0)
truncated_text = text[-char_budget:] if char_budget else ""
return truncated_text, count_tokens(truncated_text)

tokens = _TOKEN_ENCODER.encode(text)
if len(tokens) <= budget_tokens:
return text, len(tokens)

truncated_tokens = tokens[-budget_tokens:]
try:
truncated_text = _TOKEN_ENCODER.decode(truncated_tokens)
except Exception: # pragma: no cover - defensive decoding fallback
start_char = max(
0,
len(text)
- int(len(text) * (budget_tokens / max(len(tokens), 1)) * 1.1),
)
truncated_text = text[start_char:]
truncated_tokens = _TOKEN_ENCODER.encode(truncated_text)[-budget_tokens:]
truncated_text = _TOKEN_ENCODER.decode(truncated_tokens)

return truncated_text, len(truncated_tokens)


def _safe_json_dumps(value: Any) -> str:
Expand All @@ -147,7 +194,7 @@ def _safe_json_dumps(value: Any) -> str:
return repr(value)


def build_tool_context(tool_specs: List[Dict[str, Any]], budget_chars: int) -> str:
def build_tool_context(tool_specs: List[Dict[str, Any]], max_length: Optional[int] = None) -> str:
if not tool_specs:
return ""
serialized = []
Expand All @@ -167,7 +214,9 @@ def build_tool_context(tool_specs: List[Dict[str, Any]], budget_chars: int) -> s
)
)
blob = "\n\n".join(serialized)
return truncate_to_budget(blob, budget_chars)
if max_length is None or len(blob) <= max_length:
return blob
return blob[-max_length:]


TOKEN_PATTERN = re.compile(r"\b\w+\b", re.UNICODE)
Expand Down Expand Up @@ -259,9 +308,14 @@ def call_llm(prompt: Dict[str, Any], request: InferRequest, tool_specs: List[Dic
if prompt.get("user"):
segments.append(f"<user_input>\n{prompt['user']}\n</user_input>")

tool_context = build_tool_context(tool_specs, request.budget.tools)
tool_context_blob = build_tool_context(tool_specs, 100_000)
tool_context, tool_tokens = truncate_to_token_budget(
tool_context_blob, request.budget.tools
)
if tool_context:
segments.append(f"<tools>\n{tool_context}\n</tools>")
if "tokens_used" in prompt:
prompt["tokens_used"]["tools"] = tool_tokens

user_message = "\n\n".join(seg for seg in segments if seg)
if not user_message.strip():
Expand Down Expand Up @@ -417,9 +471,9 @@ def run_inference(req: InferRequest) -> Dict[str, Any]:
patch = json_patch(prior_text, context_blob)
prior_cache_path.write_text(context_blob)

sys_b = truncate_to_budget(req.system_prompt, req.budget.system)
ctx_b = truncate_to_budget(context_blob, req.budget.context)
usr_b = truncate_to_budget(req.user_input, req.budget.user)
sys_b, sys_tokens = truncate_to_token_budget(req.system_prompt, req.budget.system)
ctx_b, ctx_tokens = truncate_to_token_budget(context_blob, req.budget.context)
usr_b, usr_tokens = truncate_to_token_budget(req.user_input, req.budget.user)

tool_names: List[str] = []
tool_specs: List[Dict[str, Any]] = []
Expand All @@ -445,6 +499,11 @@ def run_inference(req: InferRequest) -> Dict[str, Any]:
"model": req.model,
"max_tokens": req.max_tokens,
"retrievals": retrieved_contexts,
"tokens_used": {
"system": sys_tokens,
"context": ctx_tokens,
"user": usr_tokens,
},
}

output = call_llm(prompt, req, tool_specs, retrieved_contexts)
Expand All @@ -463,6 +522,7 @@ def run_inference(req: InferRequest) -> Dict[str, Any]:
"output": output,
"budgets": req.budget.dict(),
"used_context_chars": len(ctx_b),
"used_tokens": prompt["tokens_used"],
"retrievals": retrieved_contexts,
"tools_considered": tool_names,
}
Expand Down Expand Up @@ -501,18 +561,19 @@ def infer_retry(
except Exception as exc:
raise HTTPException(400, detail=f"Failed to parse original request: {exc}") from exc

new_system_prompt = retry_req.system_prompt_override or original_req.system_prompt
new_user_input = retry_req.user_input_override or original_req.user_input

error_prefix = (
"--- Previous Attempt Failed ---\n"
"The previous attempt to solve this task failed with the following error:\n"
f"<error>\n{retry_req.error_context}\n</error>\n"
"Please analyze this error and provide a new solution.\n"
"--- Original System Prompt ---\n"
original_req.system_prompt = (
retry_req.system_prompt_override or original_req.system_prompt
)

original_req.system_prompt = error_prefix + new_system_prompt
original_req.user_input = new_user_input
original_user_input = retry_req.user_input_override or original_req.user_input
error_context_block = (
"--- Context: Previous Attempt Failed ---\n"
"The previous attempt to solve the task failed. Analyze the following error "
"and provide a corrected solution based on the original request.\n"
f"<error_details>\n{retry_req.error_context}\n</error_details>\n"
"--- Original User Input ---\n"
)
original_req.user_input = error_context_block + original_user_input

print(f"Retrying run {run_id}. Added error context to user input.")
return run_inference(original_req)
1 change: 1 addition & 0 deletions server/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ pydantic
python-dotenv
redis
requests
tiktoken
21 changes: 16 additions & 5 deletions server/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@
_safe_json_dumps,
_tokenize,
build_tool_context,
count_tokens,
json_patch,
truncate_to_budget,
truncate_to_token_budget,
)


Expand All @@ -27,9 +28,19 @@ def test_json_patch(old: str, new: str, expected: dict[str, object]) -> None:
assert json_patch(old, new) == expected


def test_truncate_to_budget() -> None:
assert truncate_to_budget("abcdef", 10) == "abcdef"
assert truncate_to_budget("abcdef", 3) == "def"
def test_truncate_to_token_budget() -> None:
text = "abcdef" * 5
truncated, tokens = truncate_to_token_budget(text, 3)
assert truncated.endswith("def")
assert tokens <= 3

truncated_full, tokens_full = truncate_to_token_budget(text, 100)
assert truncated_full == text
assert tokens_full >= count_tokens(text)

empty_text, empty_tokens = truncate_to_token_budget("", 10)
assert empty_text == ""
assert empty_tokens == 0


@pytest.mark.parametrize(
Expand All @@ -51,7 +62,7 @@ def test_build_tool_context_truncates_and_serializes() -> None:
{"name": "beta", "description": "second", "cost": "low", "schema": complex_schema},
]

context = build_tool_context(tool_specs, budget_chars=500)
context = build_tool_context(tool_specs, max_length=500)
assert "Tool: alpha" in context
assert "Schema: {\"x\": \"int\"}" in context
# Non-serializable value should fallback to repr
Expand Down
Loading