diff --git a/astrbot/core/agent/context/compressor.py b/astrbot/core/agent/context/compressor.py new file mode 100644 index 000000000..792835181 --- /dev/null +++ b/astrbot/core/agent/context/compressor.py @@ -0,0 +1,243 @@ +from typing import TYPE_CHECKING, Protocol, runtime_checkable + +from ..message import Message + +if TYPE_CHECKING: + from astrbot import logger +else: + try: + from astrbot import logger + except ImportError: + import logging + + logger = logging.getLogger("astrbot") + +if TYPE_CHECKING: + from astrbot.core.provider.provider import Provider + +from ..context.truncator import ContextTruncator + + +@runtime_checkable +class ContextCompressor(Protocol): + """ + Protocol for context compressors. + Provides an interface for compressing message lists. + """ + + def should_compress( + self, messages: list[Message], current_tokens: int, max_tokens: int + ) -> bool: + """Check if compression is needed. + + Args: + messages: The message list to evaluate. + current_tokens: The current token count. + max_tokens: The maximum allowed tokens for the model. + + Returns: + True if compression is needed, False otherwise. + """ + ... + + async def __call__(self, messages: list[Message]) -> list[Message]: + """Compress the message list. + + Args: + messages: The original message list. + + Returns: + The compressed message list. + """ + ... + + +class TruncateByTurnsCompressor: + """Truncate by turns compressor implementation. + Truncates the message list by removing older turns. + """ + + def __init__(self, truncate_turns: int = 1, compression_threshold: float = 0.82): + """Initialize the truncate by turns compressor. + + Args: + truncate_turns: The number of turns to remove when truncating (default: 1). + compression_threshold: The compression trigger threshold (default: 0.82). + """ + self.truncate_turns = truncate_turns + self.compression_threshold = compression_threshold + + def should_compress( + self, messages: list[Message], current_tokens: int, max_tokens: int + ) -> bool: + """Check if compression is needed. + + Args: + messages: The message list to evaluate. + current_tokens: The current token count. + max_tokens: The maximum allowed tokens. + + Returns: + True if compression is needed, False otherwise. + """ + if max_tokens <= 0 or current_tokens <= 0: + return False + usage_rate = current_tokens / max_tokens + return usage_rate > self.compression_threshold + + async def __call__(self, messages: list[Message]) -> list[Message]: + truncator = ContextTruncator() + truncated_messages = truncator.truncate_by_dropping_oldest_turns( + messages, + drop_turns=self.truncate_turns, + ) + return truncated_messages + + +def split_history( + messages: list[Message], keep_recent: int +) -> tuple[list[Message], list[Message], list[Message]]: + """Split the message list into system messages, messages to summarize, and recent messages. + + Ensures that the split point is between complete user-assistant pairs to maintain conversation flow. + + Args: + messages: The original message list. + keep_recent: The number of latest messages to keep. 
+ + Returns: + tuple: (system_messages, messages_to_summarize, recent_messages) + """ + # keep the system messages + first_non_system = 0 + for i, msg in enumerate(messages): + if msg.role != "system": + first_non_system = i + break + + system_messages = messages[:first_non_system] + non_system_messages = messages[first_non_system:] + + if len(non_system_messages) <= keep_recent: + return system_messages, [], non_system_messages + + # Find the split point, ensuring recent_messages starts with a user message + # This maintains complete conversation turns + split_index = len(non_system_messages) - keep_recent + + # Search backward from split_index to find the first user message + # This ensures recent_messages starts with a user message (complete turn) + while split_index > 0 and non_system_messages[split_index].role != "user": + # TODO: +=1 or -=1 ? calculate by tokens + split_index -= 1 + + # If we couldn't find a user message, keep all messages as recent + if split_index == 0: + return system_messages, [], non_system_messages + + messages_to_summarize = non_system_messages[:split_index] + recent_messages = non_system_messages[split_index:] + + return system_messages, messages_to_summarize, recent_messages + + +class LLMSummaryCompressor: + """LLM-based summary compressor. + Uses LLM to summarize the old conversation history, keeping the latest messages. + """ + + def __init__( + self, + provider: "Provider", + keep_recent: int = 4, + instruction_text: str | None = None, + compression_threshold: float = 0.82, + ): + """Initialize the LLM summary compressor. + + Args: + provider: The LLM provider instance. + keep_recent: The number of latest messages to keep (default: 4). + instruction_text: Custom instruction for summary generation. + compression_threshold: The compression trigger threshold (default: 0.82). + """ + self.provider = provider + self.keep_recent = keep_recent + self.compression_threshold = compression_threshold + + self.instruction_text = instruction_text or ( + "Based on our full conversation history, produce a concise summary of key takeaways and/or project progress.\n" + "1. Systematically cover all core topics discussed and the final conclusion/outcome for each; clearly highlight the latest primary focus.\n" + "2. If any tools were used, summarize tool usage (total call count) and extract the most valuable insights from tool outputs.\n" + "3. If there was an initial user goal, state it first and describe the current progress/status.\n" + "4. Write the summary in the user's language.\n" + ) + + def should_compress( + self, messages: list[Message], current_tokens: int, max_tokens: int + ) -> bool: + """Check if compression is needed. + + Args: + messages: The message list to evaluate. + current_tokens: The current token count. + max_tokens: The maximum allowed tokens. + + Returns: + True if compression is needed, False otherwise. + """ + if max_tokens <= 0 or current_tokens <= 0: + return False + usage_rate = current_tokens / max_tokens + return usage_rate > self.compression_threshold + + async def __call__(self, messages: list[Message]) -> list[Message]: + """Use LLM to generate a summary of the conversation history. + + Process: + 1. Divide messages: keep the system message and the latest N messages. + 2. Send the old messages + the instruction message to the LLM. + 3. Reconstruct the message list: [system message, summary message, latest messages]. 
+ """ + if len(messages) <= self.keep_recent + 1: + return messages + + system_messages, messages_to_summarize, recent_messages = split_history( + messages, self.keep_recent + ) + + if not messages_to_summarize: + return messages + + # build payload + instruction_message = Message(role="user", content=self.instruction_text) + llm_payload = messages_to_summarize + [instruction_message] + + # generate summary + try: + response = await self.provider.text_chat(contexts=llm_payload) + summary_content = response.completion_text + except Exception as e: + logger.error(f"Failed to generate summary: {e}") + return messages + + # build result + result = [] + result.extend(system_messages) + + result.append( + Message( + role="user", + content=f"Our previous history conversation summary: {summary_content}", + ) + ) + result.append( + Message( + role="assistant", + content="Acknowledged the summary of our previous conversation history.", + ) + ) + + result.extend(recent_messages) + + return result diff --git a/astrbot/core/agent/context/config.py b/astrbot/core/agent/context/config.py new file mode 100644 index 000000000..b8fd8eb96 --- /dev/null +++ b/astrbot/core/agent/context/config.py @@ -0,0 +1,35 @@ +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from .compressor import ContextCompressor +from .token_counter import TokenCounter + +if TYPE_CHECKING: + from astrbot.core.provider.provider import Provider + + +@dataclass +class ContextConfig: + """Context configuration class.""" + + max_context_tokens: int = 0 + """Maximum number of context tokens. <= 0 means no limit.""" + enforce_max_turns: int = -1 # -1 means no limit + """Maximum number of conversation turns to keep. -1 means no limit. Executed before compression.""" + truncate_turns: int = 1 + """Number of conversation turns to discard at once when truncation is triggered. + Two processes will use this value: + + 1. Enforce max turns truncation. + 2. Truncation by turns compression strategy. + """ + llm_compress_instruction: str | None = None + """Instruction prompt for LLM-based compression.""" + llm_compress_keep_recent: int = 0 + """Number of recent messages to keep during LLM-based compression.""" + llm_compress_provider: "Provider | None" = None + """LLM provider used for compression tasks. If None, truncation strategy is used.""" + custom_token_counter: TokenCounter | None = None + """Custom token counting method. If None, the default method is used.""" + custom_compressor: ContextCompressor | None = None + """Custom context compression method. If None, the default method is used.""" diff --git a/astrbot/core/agent/context/manager.py b/astrbot/core/agent/context/manager.py new file mode 100644 index 000000000..b8e131d98 --- /dev/null +++ b/astrbot/core/agent/context/manager.py @@ -0,0 +1,120 @@ +from astrbot import logger + +from ..message import Message +from .compressor import LLMSummaryCompressor, TruncateByTurnsCompressor +from .config import ContextConfig +from .token_counter import EstimateTokenCounter +from .truncator import ContextTruncator + + +class ContextManager: + """Context compression manager.""" + + def __init__( + self, + config: ContextConfig, + ): + """Initialize the context manager. + + There are two strategies to handle context limit reached: + 1. Truncate by turns: remove older messages by turns. + 2. LLM-based compression: use LLM to summarize old messages. + + Args: + config: The context configuration. 
+ """ + self.config = config + + self.token_counter = config.custom_token_counter or EstimateTokenCounter() + self.truncator = ContextTruncator() + + if config.custom_compressor: + self.compressor = config.custom_compressor + elif config.llm_compress_provider: + self.compressor = LLMSummaryCompressor( + provider=config.llm_compress_provider, + keep_recent=config.llm_compress_keep_recent, + instruction_text=config.llm_compress_instruction, + ) + else: + self.compressor = TruncateByTurnsCompressor( + truncate_turns=config.truncate_turns + ) + + async def process( + self, messages: list[Message], trusted_token_usage: int = 0 + ) -> list[Message]: + """Process the messages. + + Args: + messages: The original message list. + + Returns: + The processed message list. + """ + try: + result = messages + + # 1. 基于轮次的截断 (Enforce max turns) + if self.config.enforce_max_turns != -1: + result = self.truncator.truncate_by_turns( + result, + keep_most_recent_turns=self.config.enforce_max_turns, + drop_turns=self.config.truncate_turns, + ) + + # 2. 基于 token 的压缩 + if self.config.max_context_tokens > 0: + total_tokens = self.token_counter.count_tokens( + result, trusted_token_usage + ) + + if self.compressor.should_compress( + result, total_tokens, self.config.max_context_tokens + ): + result = await self._run_compression(result, total_tokens) + + return result + except Exception as e: + logger.error(f"Error during context processing: {e}", exc_info=True) + return messages + + async def _run_compression( + self, messages: list[Message], prev_tokens: int + ) -> list[Message]: + """ + Compress/truncate the messages. + + Args: + messages: The original message list. + prev_tokens: The token count before compression. + + Returns: + The compressed/truncated message list. + """ + logger.debug("Compress triggered, starting compression...") + + messages = await self.compressor(messages) + + # double check + tokens_after_summary = self.token_counter.count_tokens(messages) + + # calculate compress rate + compress_rate = (tokens_after_summary / self.config.max_context_tokens) * 100 + logger.info( + f"Compress completed." + f" {prev_tokens} -> {tokens_after_summary} tokens," + f" compression rate: {compress_rate:.2f}%.", + ) + + # last check + if self.compressor.should_compress( + messages, tokens_after_summary, self.config.max_context_tokens + ): + logger.info( + "Context still exceeds max tokens after compression, applying halving truncation..." + ) + # still need compress, truncate by half + messages = self.truncator.truncate_by_halving(messages) + + return messages diff --git a/astrbot/core/agent/context/token_counter.py b/astrbot/core/agent/context/token_counter.py new file mode 100644 index 000000000..1d4efbe8d --- /dev/null +++ b/astrbot/core/agent/context/token_counter.py @@ -0,0 +1,64 @@ +import json +from typing import Protocol, runtime_checkable + +from ..message import Message, TextPart + + +@runtime_checkable +class TokenCounter(Protocol): + """ + Protocol for token counters. + Provides an interface for counting tokens in message lists. + """ + + def count_tokens( + self, messages: list[Message], trusted_token_usage: int = 0 + ) -> int: + """Count the total tokens in the message list. + + Args: + messages: The message list. + trusted_token_usage: The total token usage that LLM API returned. + For some cases, this value is more accurate. + But some API does not return it, so the value defaults to 0. + + Returns: + The total token count. + """ + ... 
+ + +class EstimateTokenCounter: + """Estimate token counter implementation. + Provides a simple estimation of token count based on character types. + """ + + def count_tokens( + self, messages: list[Message], trusted_token_usage: int = 0 + ) -> int: + if trusted_token_usage > 0: + return trusted_token_usage + + total = 0 + for msg in messages: + content = msg.content + if isinstance(content, str): + total += self._estimate_tokens(content) + elif isinstance(content, list): + # 处理多模态内容 + for part in content: + if isinstance(part, TextPart): + total += self._estimate_tokens(part.text) + + # 处理 Tool Calls + if msg.tool_calls: + for tc in msg.tool_calls: + tc_str = json.dumps(tc if isinstance(tc, dict) else tc.model_dump()) + total += self._estimate_tokens(tc_str) + + return total + + def _estimate_tokens(self, text: str) -> int: + chinese_count = len([c for c in text if "\u4e00" <= c <= "\u9fff"]) + other_count = len(text) - chinese_count + return int(chinese_count * 0.6 + other_count * 0.3) diff --git a/astrbot/core/agent/context/truncator.py b/astrbot/core/agent/context/truncator.py new file mode 100644 index 000000000..8d1da6f56 --- /dev/null +++ b/astrbot/core/agent/context/truncator.py @@ -0,0 +1,141 @@ +from ..message import Message + + +class ContextTruncator: + """Context truncator.""" + + def fix_messages(self, messages: list[Message]) -> list[Message]: + fixed_messages = [] + for message in messages: + if message.role == "tool": + # tool block 前面必须要有 user 和 assistant block + if len(fixed_messages) < 2: + # 这种情况可能是上下文被截断导致的 + # 我们直接将之前的上下文都清空 + fixed_messages = [] + else: + fixed_messages.append(message) + else: + fixed_messages.append(message) + return fixed_messages + + def truncate_by_turns( + self, + messages: list[Message], + keep_most_recent_turns: int, + drop_turns: int = 1, + ) -> list[Message]: + """截断上下文列表,确保不超过最大长度。 + 一个 turn 包含一个 user 消息和一个 assistant 消息。 + 这个方法会保证截断后的上下文列表符合 OpenAI 的上下文格式。 + + Args: + messages: 上下文列表 + keep_most_recent_turns: 保留最近的对话轮数 + drop_turns: 一次性丢弃的对话轮数 + + Returns: + 截断后的上下文列表 + """ + if keep_most_recent_turns == -1: + return messages + + first_non_system = 0 + for i, msg in enumerate(messages): + if msg.role != "system": + first_non_system = i + break + + system_messages = messages[:first_non_system] + non_system_messages = messages[first_non_system:] + + if len(non_system_messages) // 2 <= keep_most_recent_turns: + return messages + + num_to_keep = keep_most_recent_turns - drop_turns + 1 + if num_to_keep <= 0: + truncated_contexts = [] + else: + truncated_contexts = non_system_messages[-num_to_keep * 2 :] + + # 找到第一个 role 为 user 的索引,确保上下文格式正确 + index = next( + (i for i, item in enumerate(truncated_contexts) if item.role == "user"), + None, + ) + if index is not None and index > 0: + truncated_contexts = truncated_contexts[index:] + + result = system_messages + truncated_contexts + + return self.fix_messages(result) + + def truncate_by_dropping_oldest_turns( + self, + messages: list[Message], + drop_turns: int = 1, + ) -> list[Message]: + """丢弃最旧的 N 个对话轮次。""" + if drop_turns <= 0: + return messages + + first_non_system = 0 + for i, msg in enumerate(messages): + if msg.role != "system": + first_non_system = i + break + + system_messages = messages[:first_non_system] + non_system_messages = messages[first_non_system:] + + if len(non_system_messages) // 2 <= drop_turns: + truncated_non_system = [] + else: + truncated_non_system = non_system_messages[drop_turns * 2 :] + + index = next( + (i for i, item in enumerate(truncated_non_system) if item.role 
== "user"), + None, + ) + if index is not None: + truncated_non_system = truncated_non_system[index:] + elif truncated_non_system: + truncated_non_system = [] + + result = system_messages + truncated_non_system + + return self.fix_messages(result) + + def truncate_by_halving( + self, + messages: list[Message], + ) -> list[Message]: + """对半砍策略,删除 50% 的消息""" + if len(messages) <= 2: + return messages + + first_non_system = 0 + for i, msg in enumerate(messages): + if msg.role != "system": + first_non_system = i + break + + system_messages = messages[:first_non_system] + non_system_messages = messages[first_non_system:] + + messages_to_delete = len(non_system_messages) // 2 + if messages_to_delete == 0: + return messages + + truncated_non_system = non_system_messages[messages_to_delete:] + + index = next( + (i for i, item in enumerate(truncated_non_system) if item.role == "user"), + None, + ) + if index is not None: + truncated_non_system = truncated_non_system[index:] + + result = system_messages + truncated_non_system + + return self.fix_messages(result) diff --git a/astrbot/core/agent/runners/tool_loop_agent_runner.py b/astrbot/core/agent/runners/tool_loop_agent_runner.py index 4b0c601b4..606163685 100644 --- a/astrbot/core/agent/runners/tool_loop_agent_runner.py +++ b/astrbot/core/agent/runners/tool_loop_agent_runner.py @@ -25,6 +25,10 @@ ) from astrbot.core.provider.provider import Provider +from ..context.compressor import ContextCompressor +from ..context.config import ContextConfig +from ..context.manager import ContextManager +from ..context.token_counter import TokenCounter from ..hooks import BaseAgentRunHooks from ..message import AssistantMessageSegment, Message, ToolCallMessageSegment from ..response import AgentResponseData, AgentStats @@ -47,10 +51,47 @@ async def reset( run_context: ContextWrapper[TContext], tool_executor: BaseFunctionToolExecutor[TContext], agent_hooks: BaseAgentRunHooks[TContext], + streaming: bool = False, + # enforce max turns, will discard older turns when exceeded BEFORE compression + # -1 means no limit + enforce_max_turns: int = -1, + # llm compressor + llm_compress_instruction: str | None = None, + llm_compress_keep_recent: int = 0, + llm_compress_provider: Provider | None = None, + # truncate by turns compressor + truncate_turns: int = 1, + # customize + custom_token_counter: TokenCounter | None = None, + custom_compressor: ContextCompressor | None = None, **kwargs: T.Any, ) -> None: self.req = request - self.streaming = kwargs.get("streaming", False) + self.streaming = streaming + self.enforce_max_turns = enforce_max_turns + self.llm_compress_instruction = llm_compress_instruction + self.llm_compress_keep_recent = llm_compress_keep_recent + self.llm_compress_provider = llm_compress_provider + self.truncate_turns = truncate_turns + self.custom_token_counter = custom_token_counter + self.custom_compressor = custom_compressor + # we will do compress when: + # 1. before requesting LLM + # TODO: 2. 
after LLM output a tool call + self.context_config = ContextConfig( + # <=0 will never do compress + max_context_tokens=provider.provider_config.get("max_context_tokens", 0), + # enforce max turns before compression + enforce_max_turns=self.enforce_max_turns, + truncate_turns=self.truncate_turns, + llm_compress_instruction=self.llm_compress_instruction, + llm_compress_keep_recent=self.llm_compress_keep_recent, + llm_compress_provider=self.llm_compress_provider, + custom_token_counter=self.custom_token_counter, + custom_compressor=self.custom_compressor, + ) + self.context_manager = ContextManager(self.context_config) + self.provider = provider self.final_llm_resp = None self._state = AgentState.IDLE @@ -110,6 +151,12 @@ async def step(self): self._transition_state(AgentState.RUNNING) llm_resp_result = None + # do truncate and compress + token_usage = self.req.conversation.token_usage if self.req.conversation else 0 + self.run_context.messages = await self.context_manager.process( + self.run_context.messages, trusted_token_usage=token_usage + ) + async for llm_response in self._iter_llm_responses(): if llm_response.is_chunk: # update ttft diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 0d7495b68..38d3eb0d0 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -83,6 +83,16 @@ "default_personality": "default", "persona_pool": ["*"], "prompt_prefix": "{{prompt}}", + "context_limit_reached_strategy": "truncate_by_turns", # or llm_compress + "llm_compress_instruction": ( + "Based on our full conversation history, produce a concise summary of key takeaways and/or project progress.\n" + "1. Systematically cover all core topics discussed and the final conclusion/outcome for each; clearly highlight the latest primary focus.\n" + "2. If any tools were used, summarize tool usage (total call count) and extract the most valuable insights from tool outputs.\n" + "3. If there was an initial user goal, state it first and describe the current progress/status.\n" + "4. 
Write the summary in the user's language.\n" + ), + "llm_compress_keep_recent": 4, + "llm_compress_provider_id": "", "max_context_length": -1, "dequeue_context_length": 1, "streaming_response": False, @@ -179,6 +189,7 @@ class ChatProviderTemplate(TypedDict): model: str modalities: list custom_extra_body: dict[str, Any] + max_context_tokens: int CHAT_PROVIDER_TEMPLATE = { @@ -187,6 +198,7 @@ class ChatProviderTemplate(TypedDict): "model": "", "modalities": [], "custom_extra_body": {}, + "max_context_tokens": 0, } """ @@ -2033,6 +2045,11 @@ class ChatProviderTemplate(TypedDict): "type": "string", "hint": "模型名称,如 gpt-4o-mini, deepseek-chat。", }, + "max_context_tokens": { + "description": "模型上下文窗口大小", + "type": "int", + "hint": "模型最大上下文 Token 大小。如果为 0,则会自动从模型元数据填充(如有),也可手动修改。", + }, "dify_api_key": { "description": "API Key", "type": "string", @@ -2540,6 +2557,66 @@ class ChatProviderTemplate(TypedDict): # "provider_settings.enable": True, # }, # }, + "truncate_and_compress": { + "description": "上下文管理策略", + "type": "object", + "items": { + "provider_settings.max_context_length": { + "description": "最多携带对话轮数", + "type": "int", + "hint": "超出这个数量时丢弃最旧的部分,一轮聊天记为 1 条,-1 为不限制", + "condition": { + "provider_settings.agent_runner_type": "local", + }, + }, + "provider_settings.dequeue_context_length": { + "description": "丢弃对话轮数", + "type": "int", + "hint": "超出最多携带对话轮数时, 一次丢弃的聊天轮数", + "condition": { + "provider_settings.agent_runner_type": "local", + }, + }, + "provider_settings.context_limit_reached_strategy": { + "description": "超出模型上下文窗口时的处理方式", + "type": "string", + "options": ["truncate_by_turns", "llm_compress"], + "labels": ["按对话轮数截断", "由 LLM 压缩上下文"], + "condition": { + "provider_settings.agent_runner_type": "local", + }, + "hint": "", + }, + "provider_settings.llm_compress_instruction": { + "description": "上下文压缩提示词", + "type": "text", + "hint": "如果为空则使用默认提示词。", + "condition": { + "provider_settings.context_limit_reached_strategy": "llm_compress", + "provider_settings.agent_runner_type": "local", + }, + }, + "provider_settings.llm_compress_keep_recent": { + "description": "压缩时保留最近对话轮数", + "type": "int", + "hint": "始终保留的最近 N 轮对话。", + "condition": { + "provider_settings.context_limit_reached_strategy": "llm_compress", + "provider_settings.agent_runner_type": "local", + }, + }, + "provider_settings.llm_compress_provider_id": { + "description": "用于上下文压缩的模型提供商 ID", + "type": "string", + "_special": "select_provider", + "hint": "留空时将降级为“按对话轮数截断”的策略。", + "condition": { + "provider_settings.context_limit_reached_strategy": "llm_compress", + "provider_settings.agent_runner_type": "local", + }, + }, + }, + }, "others": { "description": "其他配置", "type": "object", @@ -2604,22 +2681,6 @@ class ChatProviderTemplate(TypedDict): "provider_settings.streaming_response": True, }, }, - "provider_settings.max_context_length": { - "description": "最多携带对话轮数", - "type": "int", - "hint": "超出这个数量时丢弃最旧的部分,一轮聊天记为 1 条,-1 为不限制", - "condition": { - "provider_settings.agent_runner_type": "local", - }, - }, - "provider_settings.dequeue_context_length": { - "description": "丢弃对话轮数", - "type": "int", - "hint": "超出最多携带对话轮数时, 一次丢弃的聊天轮数", - "condition": { - "provider_settings.agent_runner_type": "local", - }, - }, "provider_settings.wake_prefix": { "description": "LLM 聊天额外唤醒前缀 ", "type": "string", diff --git a/astrbot/core/conversation_mgr.py b/astrbot/core/conversation_mgr.py index 287fe03c4..a0a0c0e2f 100644 --- a/astrbot/core/conversation_mgr.py +++ b/astrbot/core/conversation_mgr.py @@ -69,6 +69,7 @@ def 
_convert_conv_from_v2_to_v1(self, conv_v2: ConversationV2) -> Conversation: persona_id=conv_v2.persona_id, created_at=created_at, updated_at=updated_at, + token_usage=conv_v2.token_usage, ) async def new_conversation( @@ -256,6 +257,7 @@ async def update_conversation( history: list[dict] | None = None, title: str | None = None, persona_id: str | None = None, + token_usage: int | None = None, ) -> None: """更新会话的对话. @@ -263,6 +265,7 @@ async def update_conversation( unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id conversation_id (str): 对话 ID, 是 uuid 格式的字符串 history (List[Dict]): 对话历史记录, 是一个字典列表, 每个字典包含 role 和 content 字段 + token_usage (int | None): token 使用量。None 表示不更新 """ if not conversation_id: @@ -274,6 +277,7 @@ async def update_conversation( title=title, persona_id=persona_id, content=history, + token_usage=token_usage, ) async def update_conversation_title( diff --git a/astrbot/core/db/__init__.py b/astrbot/core/db/__init__.py index 192c7b263..3a79e41c2 100644 --- a/astrbot/core/db/__init__.py +++ b/astrbot/core/db/__init__.py @@ -152,6 +152,7 @@ async def update_conversation( title: str | None = None, persona_id: str | None = None, content: list[dict] | None = None, + token_usage: int | None = None, ) -> None: """Update a conversation's history.""" ... diff --git a/astrbot/core/db/migration/migra_token_usage.py b/astrbot/core/db/migration/migra_token_usage.py new file mode 100644 index 000000000..07938301d --- /dev/null +++ b/astrbot/core/db/migration/migra_token_usage.py @@ -0,0 +1,61 @@ +"""Migration script to add token_usage column to conversations table. + +This migration adds the token_usage field to track token consumption for each conversation. + +Changes: +- Adds token_usage column to conversations table (default: 0) +""" + +from sqlalchemy import text + +from astrbot.api import logger, sp +from astrbot.core.db import BaseDatabase + + +async def migrate_token_usage(db_helper: BaseDatabase): + """Add token_usage column to conversations table. + + This migration adds a new column to track token consumption in conversations. 
+ """ + # 检查是否已经完成迁移 + migration_done = await db_helper.get_preference( + "global", "global", "migration_done_token_usage_1" + ) + if migration_done: + return + + logger.info("开始执行数据库迁移(添加 conversations.token_usage 列)...") + + # 这里只适配了 SQLite。因为截止至这一版本,AstrBot 仅支持 SQLite。 + + try: + async with db_helper.get_db() as session: + # 检查列是否已存在 + result = await session.execute(text("PRAGMA table_info(conversations)")) + columns = result.fetchall() + column_names = [col[1] for col in columns] + + if "token_usage" in column_names: + logger.info("token_usage 列已存在,跳过迁移") + await sp.put_async( + "global", "global", "migration_done_token_usage_1", True + ) + return + + # 添加 token_usage 列 + await session.execute( + text( + "ALTER TABLE conversations ADD COLUMN token_usage INTEGER NOT NULL DEFAULT 0" + ) + ) + await session.commit() + + logger.info("token_usage 列添加成功") + + # 标记迁移完成 + await sp.put_async("global", "global", "migration_done_token_usage_1", True) + logger.info("token_usage 迁移完成") + + except Exception as e: + logger.error(f"迁移过程中发生错误: {e}", exc_info=True) + raise diff --git a/astrbot/core/db/po.py b/astrbot/core/db/po.py index 64bcf4ce3..fdbf4aff3 100644 --- a/astrbot/core/db/po.py +++ b/astrbot/core/db/po.py @@ -54,6 +54,11 @@ class ConversationV2(SQLModel, table=True): ) title: str | None = Field(default=None, max_length=255) persona_id: str | None = Field(default=None) + token_usage: int = Field(default=0, nullable=False) + """content is a list of OpenAI-formated messages in list[dict] format. + token_usage is the total token value of the messages. + when 0, will use estimated token counter. + """ __table_args__ = ( UniqueConstraint( @@ -313,6 +318,8 @@ class Conversation: persona_id: str | None = "" created_at: int = 0 updated_at: int = 0 + token_usage: int = 0 + """对话的总 token 数量。AstrBot 会保留最近一次 LLM 请求返回的总 token 数,方便统计。token_usage 可能为 0,表示未知。""" class Personality(TypedDict): diff --git a/astrbot/core/db/sqlite.py b/astrbot/core/db/sqlite.py index fa3ca9a76..7422a5cc2 100644 --- a/astrbot/core/db/sqlite.py +++ b/astrbot/core/db/sqlite.py @@ -241,7 +241,9 @@ async def create_conversation( session.add(new_conversation) return new_conversation - async def update_conversation(self, cid, title=None, persona_id=None, content=None): + async def update_conversation( + self, cid, title=None, persona_id=None, content=None, token_usage=None + ): async with self.get_db() as session: session: AsyncSession async with session.begin(): @@ -255,6 +257,8 @@ async def update_conversation(self, cid, title=None, persona_id=None, content=No values["persona_id"] = persona_id if content is not None: values["content"] = content + if token_usage is not None: + values["token_usage"] = token_usage if not values: return None query = query.values(**values) diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py index ed6dc32cf..69bd04314 100644 --- a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py +++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py @@ -1,12 +1,12 @@ """本地 Agent 模式的 LLM 调用 Stage""" import asyncio -import copy import json from collections.abc import AsyncGenerator from astrbot.core import logger from astrbot.core.agent.message import Message +from astrbot.core.agent.response import AgentStats from astrbot.core.agent.tool import ToolSet from astrbot.core.astr_agent_context import AstrAgentContext from astrbot.core.conversation_mgr import Conversation @@ 
-24,6 +24,7 @@ ) from astrbot.core.star.star_handler import EventType, star_map from astrbot.core.utils.file_extract import extract_file_moonshotai +from astrbot.core.utils.llm_metadata import LLM_METADATAS from astrbot.core.utils.metrics import Metric from astrbot.core.utils.session_lock import session_lock_manager @@ -41,11 +42,6 @@ async def initialize(self, ctx: PipelineContext) -> None: self.ctx = ctx conf = ctx.astrbot_config settings = conf["provider_settings"] - self.max_context_length = settings["max_context_length"] # int - self.dequeue_context_length: int = min( - max(1, settings["dequeue_context_length"]), - self.max_context_length - 1, - ) self.streaming_response: bool = settings["streaming_response"] self.unsupported_streaming_strategy: str = settings[ "unsupported_streaming_strategy" @@ -65,6 +61,25 @@ async def initialize(self, ctx: PipelineContext) -> None: "moonshotai_api_key", "" ) + # 上下文管理相关 + self.context_limit_reached_strategy: str = settings.get( + "context_limit_reached_strategy", "truncate_by_turns" + ) + self.llm_compress_instruction: str = settings.get( + "llm_compress_instruction", "" + ) + self.llm_compress_keep_recent: int = settings.get("llm_compress_keep_recent", 4) + self.llm_compress_provider_id: str = settings.get( + "llm_compress_provider_id", "" + ) + self.max_context_length = settings["max_context_length"] # int + self.dequeue_context_length: int = min( + max(1, settings["dequeue_context_length"]), + self.max_context_length - 1, + ) + if self.dequeue_context_length <= 0: + self.dequeue_context_length = 1 + self.conv_manager = ctx.plugin_manager.context.conversation_manager def _select_provider(self, event: AstrMessageEvent): @@ -167,34 +182,6 @@ async def _apply_file_extract( }, ) - def _truncate_contexts( - self, - contexts: list[dict], - ) -> list[dict]: - """截断上下文列表,确保不超过最大长度""" - if self.max_context_length == -1: - return contexts - - if len(contexts) // 2 <= self.max_context_length: - return contexts - - truncated_contexts = contexts[ - -(self.max_context_length - self.dequeue_context_length + 1) * 2 : - ] - # 找到第一个role 为 user 的索引,确保上下文格式正确 - index = next( - ( - i - for i, item in enumerate(truncated_contexts) - if item.get("role") == "user" - ), - None, - ) - if index is not None and index > 0: - truncated_contexts = truncated_contexts[index:] - - return truncated_contexts - def _modalities_fix( self, provider: Provider, @@ -296,6 +283,7 @@ async def _save_to_history( req: ProviderRequest, llm_response: LLMResponse | None, all_messages: list[Message], + runner_stats: AgentStats | None, ): if ( not req @@ -322,27 +310,37 @@ async def _save_to_history( continue message_to_save.append(message.model_dump()) + # get token usage from agent runner stats + token_usage = None + if runner_stats: + token_usage = runner_stats.token_usage.total + await self.conv_manager.update_conversation( event.unified_msg_origin, req.conversation.cid, history=message_to_save, + token_usage=token_usage, ) - def _fix_messages(self, messages: list[dict]) -> list[dict]: - """验证并且修复上下文""" - fixed_messages = [] - for message in messages: - if message.get("role") == "tool": - # tool block 前面必须要有 user 和 assistant block - if len(fixed_messages) < 2: - # 这种情况可能是上下文被截断导致的 - # 我们直接将之前的上下文都清空 - fixed_messages = [] - else: - fixed_messages.append(message) - else: - fixed_messages.append(message) - return fixed_messages + def _get_compress_provider(self) -> Provider | None: + if not self.llm_compress_provider_id: + return None + if self.context_limit_reached_strategy != "llm_compress": 
+ return None + provider = self.ctx.plugin_manager.context.get_provider_by_id( + self.llm_compress_provider_id, + ) + if provider is None: + logger.warning( + f"未找到指定的上下文压缩模型 {self.llm_compress_provider_id},将跳过压缩。", + ) + return None + if not isinstance(provider, Provider): + logger.warning( + f"指定的上下文压缩模型 {self.llm_compress_provider_id} 不是对话模型,将跳过压缩。" + ) + return None + return provider async def process( self, event: AstrMessageEvent, provider_wake_prefix: str @@ -426,9 +424,10 @@ async def process( await self._apply_kb(event, req) # truncate contexts to fit max length - if req.contexts: - req.contexts = self._truncate_contexts(req.contexts) - self._fix_messages(req.contexts) + # NOW moved to ContextManager inside ToolLoopAgentRunner + # if req.contexts: + # req.contexts = self._truncate_contexts(req.contexts) + # self._fix_messages(req.contexts) # session_id if not req.session_id: @@ -444,8 +443,6 @@ async def process( self.unsupported_streaming_strategy == "turn_off" and not event.platform_meta.support_streaming_message ) - # 备份 req.contexts - backup_contexts = copy.deepcopy(req.contexts) # run agent agent_runner = AgentRunner() @@ -456,6 +453,15 @@ async def process( context=self.ctx.plugin_manager.context, event=event, ) + + # inject model context length limit + if provider.provider_config.get("max_context_tokens", 0) <= 0: + model = provider.get_model() + if model_info := LLM_METADATAS.get(model): + provider.provider_config["max_context_tokens"] = model_info[ + "limit" + ]["context"] + await agent_runner.reset( provider=provider, request=req, @@ -466,6 +472,11 @@ async def process( tool_executor=FunctionToolExecutor(), agent_hooks=MAIN_AGENT_HOOKS, streaming=streaming_response, + llm_compress_instruction=self.llm_compress_instruction, + llm_compress_keep_recent=self.llm_compress_keep_recent, + llm_compress_provider=self._get_compress_provider(), + truncate_turns=self.dequeue_context_length, + enforce_max_turns=self.max_context_length, ) if streaming_response and not stream_to_general: @@ -511,14 +522,12 @@ async def process( ): yield - # 恢复备份的 contexts - req.contexts = backup_contexts - await self._save_to_history( event, req, agent_runner.get_final_llm_resp(), agent_runner.run_context.messages, + agent_runner.stats, ) # 异步处理 WebChat 特殊情况 diff --git a/astrbot/core/star/context.py b/astrbot/core/star/context.py index c8e470632..a64d2a9ee 100644 --- a/astrbot/core/star/context.py +++ b/astrbot/core/star/context.py @@ -149,9 +149,12 @@ async def tool_loop_agent( contexts: context messages for the LLM max_steps: Maximum number of tool calls before stopping the loop **kwargs: Additional keyword arguments. The kwargs will not be passed to the LLM directly for now, but can include: + stream: bool - whether to stream the LLM response agent_hooks: BaseAgentRunHooks[AstrAgentContext] - hooks to run during agent execution agent_context: AstrAgentContext - context to use for the agent + other kwargs will be DIRECTLY passed to the runner.reset() method + Returns: The final LLMResponse after tool calls are completed. 
@@ -194,6 +197,15 @@ async def tool_loop_agent( ) agent_runner = ToolLoopAgentRunner() tool_executor = FunctionToolExecutor() + + streaming = kwargs.get("stream", False) + + other_kwargs = { + k: v + for k, v in kwargs.items() + if k not in ["stream", "agent_hooks", "agent_context"] + } + await agent_runner.reset( provider=prov, request=request, @@ -203,7 +215,8 @@ async def tool_loop_agent( ), tool_executor=tool_executor, agent_hooks=agent_hooks, - streaming=kwargs.get("stream", False), + streaming=streaming, + **other_kwargs, ) async for _ in agent_runner.step_until_done(max_steps): pass diff --git a/astrbot/core/utils/migra_helper.py b/astrbot/core/utils/migra_helper.py index b8ff677e1..6a300302d 100644 --- a/astrbot/core/utils/migra_helper.py +++ b/astrbot/core/utils/migra_helper.py @@ -3,6 +3,7 @@ from astrbot.core import astrbot_config, logger from astrbot.core.astrbot_config_mgr import AstrBotConfig, AstrBotConfigManager from astrbot.core.db.migration.migra_45_to_46 import migrate_45_to_46 +from astrbot.core.db.migration.migra_token_usage import migrate_token_usage from astrbot.core.db.migration.migra_webchat_session import migrate_webchat_session @@ -139,6 +140,13 @@ async def migra( logger.error(f"Migration for webchat session failed: {e!s}") logger.error(traceback.format_exc()) + # migration for token_usage column + try: + await migrate_token_usage(db) + except Exception as e: + logger.error(f"Migration for token_usage column failed: {e!s}") + logger.error(traceback.format_exc()) + # migra third party agent runner configs _c = False providers = astrbot_config["provider"] diff --git a/dashboard/src/components/shared/ConfigItemRenderer.vue b/dashboard/src/components/shared/ConfigItemRenderer.vue index 88674eb0a..24ea8f9ce 100644 --- a/dashboard/src/components/shared/ConfigItemRenderer.vue +++ b/dashboard/src/components/shared/ConfigItemRenderer.vue @@ -144,7 +144,7 @@ color="primary" density="compact" hide-details - class="flex-grow-1" + style="flex: 1" > @@ -325,4 +325,8 @@ function getSpecialSubtype(value) { .gap-20 { gap: 20px; } + +:deep(.v-field__input) { + font-size: 14px; +} diff --git a/dashboard/src/composables/useProviderSources.ts b/dashboard/src/composables/useProviderSources.ts index e8bf58f45..dc0059b04 100644 --- a/dashboard/src/composables/useProviderSources.ts +++ b/dashboard/src/composables/useProviderSources.ts @@ -510,7 +510,7 @@ export function useProviderSources(options: UseProviderSourcesOptions) { const metadata = getModelMetadata(modelName) let modalities: string[] - + if (!metadata) { modalities = ['text', 'image', 'tool_use'] } else { @@ -523,13 +523,19 @@ export function useProviderSources(options: UseProviderSourcesOptions) { } } + let max_context_tokens = 0 + if (metadata?.limit?.context && typeof metadata.limit.context === 'number') { + max_context_tokens = metadata.limit.context + } + const newProvider = { id: newId, enable: false, provider_source_id: sourceId, model: modelName, modalities, - custom_extra_body: {} + custom_extra_body: {}, + max_context_tokens: max_context_tokens } try { diff --git a/dashboard/src/i18n/locales/en-US/features/config-metadata.json b/dashboard/src/i18n/locales/en-US/features/config-metadata.json index ff9a68256..e0f694c33 100644 --- a/dashboard/src/i18n/locales/en-US/features/config-metadata.json +++ b/dashboard/src/i18n/locales/en-US/features/config-metadata.json @@ -11,7 +11,12 @@ }, "agent_runner_type": { "description": "Runner", - "labels": ["Built-in Agent", "Dify", "Coze", "Alibaba Cloud Bailian Application"] + 
"labels": [ + "Built-in Agent", + "Dify", + "Coze", + "Alibaba Cloud Bailian Application" + ] }, "coze_agent_runner_provider_id": { "description": "Coze Agent Runner Provider ID" @@ -128,6 +133,39 @@ } } }, + "truncate_and_compress": { + "description": "Context Management Strategy", + "provider_settings": { + "max_context_length": { + "description": "Maximum Conversation Turns", + "hint": "Discards the oldest parts when this count is exceeded. One conversation round counts as 1, -1 means unlimited" + }, + "dequeue_context_length": { + "description": "Dequeue Conversation Turns", + "hint": "Number of conversation turns to discard at once when maximum context length is exceeded" + }, + "context_limit_reached_strategy": { + "description": "Handling When Model Context Window is Exceeded", + "labels": [ + "Truncate by Turns", + "Compress by LLM" + ], + "hint": "When 'Truncate by Turns' is selected, the oldest N conversation turns will be discarded based on the 'Dequeue Conversation Turns' setting above. When 'Compress by LLM' is selected, the specified model will be used for context compression." + }, + "llm_compress_instruction": { + "description": "Context Compression Instruction", + "hint": "If empty, the default prompt will be used." + }, + "llm_compress_keep_recent": { + "description": "Keep Recent Turns When Compressing", + "hint": "Always keep the most recent N turns of conversation when compressing context." + }, + "llm_compress_provider_id": { + "description": "Model Provider ID for Context Compression", + "hint": "When left empty, will fall back to the 'Truncate by Turns' strategy." + } + } + }, "others": { "description": "Other Settings", "provider_settings": { @@ -161,15 +199,10 @@ "unsupported_streaming_strategy": { "description": "Platforms Without Streaming Support", "hint": "Select the handling method for platforms that don't support streaming responses. Real-time segmented reply sends content immediately when the system detects segment points like punctuation during streaming reception", - "labels": ["Real-time Segmented Reply", "Disable Streaming Response"] - }, - "max_context_length": { - "description": "Maximum Conversation Rounds", - "hint": "Discards the oldest parts when this count is exceeded. 
One conversation round counts as 1, -1 means unlimited" - }, - "dequeue_context_length": { - "description": "Dequeue Conversation Rounds", - "hint": "Number of conversation rounds to discard at once when maximum context length is exceeded" + "labels": [ + "Real-time Segmented Reply", + "Disable Streaming Response" + ] }, "wake_prefix": { "description": "Additional LLM Chat Wake Prefix", @@ -387,7 +420,10 @@ }, "split_mode": { "description": "Split Mode", - "labels": ["Regex", "Words List"] + "labels": [ + "Regex", + "Words List" + ] }, "regex": { "description": "Segmentation Regular Expression" @@ -488,4 +524,4 @@ } } } -} +} \ No newline at end of file diff --git a/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json b/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json index aba9cfd35..589aa54a0 100644 --- a/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json +++ b/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json @@ -133,6 +133,36 @@ } } }, + "truncate_and_compress": { + "description": "上下文管理策略", + "provider_settings": { + "max_context_length": { + "description": "最多携带对话轮数", + "hint": "超出这个数量时丢弃最旧的部分,一轮聊天记为 1 条,-1 为不限制" + }, + "dequeue_context_length": { + "description": "丢弃对话轮数", + "hint": "超出最多携带对话轮数时, 一次丢弃的聊天轮数" + }, + "context_limit_reached_strategy": { + "description": "超出模型上下文窗口时的处理方式", + "labels": ["按对话轮数截断", "由 LLM 压缩上下文"], + "hint": "当按对话轮数截断时,会根据上面\"丢弃对话轮数\"的配置丢弃最旧的 N 轮对话。当由 LLM 压缩上下文时,会使用指定的模型进行上下文压缩。" + }, + "llm_compress_instruction": { + "description": "上下文压缩提示词", + "hint": "如果为空则使用默认提示词。" + }, + "llm_compress_keep_recent": { + "description": "压缩时保留最近对话轮数", + "hint": "始终保留的最近 N 轮对话。" + }, + "llm_compress_provider_id": { + "description": "用于上下文压缩的模型提供商 ID", + "hint": "留空时将降级为\"按对话轮数截断\"的策略。" + } + } + }, "others": { "description": "其他配置", "provider_settings": { @@ -171,14 +201,7 @@ "关闭流式回复" ] }, - "max_context_length": { - "description": "最多携带对话轮数", - "hint": "超出这个数量时丢弃最旧的部分,一轮聊天记为 1 条,-1 为不限制" - }, - "dequeue_context_length": { - "description": "丢弃对话轮数", - "hint": "超出最多携带对话轮数时, 一次丢弃的聊天轮数" - }, + "wake_prefix": { "description": "LLM 聊天额外唤醒前缀", "hint": "如果唤醒前缀为 /, 额外聊天唤醒前缀为 chat,则需要 /chat 才会触发 LLM 请求" diff --git a/tests/agent/test_context_manager.py b/tests/agent/test_context_manager.py new file mode 100644 index 000000000..0b955ff40 --- /dev/null +++ b/tests/agent/test_context_manager.py @@ -0,0 +1,774 @@ +"""Comprehensive tests for ContextManager.""" + +import sys +from pathlib import Path +from typing import Literal +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +# Add parent directory to path to avoid circular import issues +sys.path.insert(0, str(Path(__file__).parent.parent.parent)) + +from astrbot.core.agent.context.config import ContextConfig +from astrbot.core.agent.context.manager import ContextManager +from astrbot.core.agent.message import Message, TextPart +from astrbot.core.provider.entities import LLMResponse + + +class MockProvider: + """模拟 Provider""" + + def __init__(self): + self.provider_config = { + "id": "test_provider", + "model": "gpt-4", + "modalities": ["text", "image", "tool_use"], + } + + async def text_chat(self, **kwargs): + """模拟 LLM 调用,返回摘要""" + messages = kwargs.get("messages", []) + # 简单的摘要逻辑:返回消息数量统计 + return LLMResponse( + role="assistant", + completion_text=f"历史对话包含 {len(messages) - 1} 条消息,主要讨论了技术话题。", + ) + + def get_model(self): + return "gpt-4" + + def meta(self): + return MagicMock(id="test_provider", type="openai") + + +class TestContextManager: + """Test suite 
for ContextManager.""" + + def create_message( + self, role: Literal["system", "user", "assistant", "tool"], content: str + ) -> Message: + """Helper to create a simple text message.""" + return Message(role=role, content=content) + + def create_messages(self, count: int) -> list[Message]: + """Helper to create alternating user/assistant messages.""" + messages = [] + for i in range(count): + role = "user" if i % 2 == 0 else "assistant" + messages.append(self.create_message(role, f"Message {i}")) + return messages + + # ==================== Basic Initialization Tests ==================== + + def test_init_with_minimal_config(self): + """Test initialization with minimal configuration.""" + config = ContextConfig() + manager = ContextManager(config) + + assert manager.config == config + assert manager.token_counter is not None + assert manager.truncator is not None + assert manager.compressor is not None + + def test_init_with_llm_compressor(self): + """Test initialization with LLM-based compression.""" + mock_provider = MockProvider() + config = ContextConfig( + llm_compress_provider=mock_provider, # type: ignore + llm_compress_keep_recent=5, + llm_compress_instruction="Summarize the conversation", + ) + manager = ContextManager(config) + + from astrbot.core.agent.context.compressor import LLMSummaryCompressor + + assert isinstance(manager.compressor, LLMSummaryCompressor) + + def test_init_with_truncate_compressor(self): + """Test initialization with truncate-based compression (default).""" + config = ContextConfig(truncate_turns=3) + manager = ContextManager(config) + + from astrbot.core.agent.context.compressor import TruncateByTurnsCompressor + + assert isinstance(manager.compressor, TruncateByTurnsCompressor) + + # ==================== Empty and Edge Cases ==================== + + @pytest.mark.asyncio + async def test_process_empty_messages(self): + """Test processing an empty message list.""" + config = ContextConfig() + manager = ContextManager(config) + + result = await manager.process([]) + + assert result == [] + + @pytest.mark.asyncio + async def test_process_single_message(self): + """Test processing a single message.""" + config = ContextConfig() + manager = ContextManager(config) + + messages = [self.create_message("user", "Hello")] + result = await manager.process(messages) + + assert len(result) == 1 + assert result[0].content == "Hello" + + @pytest.mark.asyncio + async def test_process_with_no_limits(self): + """Test processing when no limits are set (no truncation or compression).""" + config = ContextConfig(max_context_tokens=0, enforce_max_turns=-1) + manager = ContextManager(config) + + messages = self.create_messages(20) + result = await manager.process(messages) + + assert len(result) == 20 + assert result == messages + + # ==================== Enforce Max Turns Tests ==================== + + @pytest.mark.asyncio + async def test_enforce_max_turns_basic(self): + """Test basic enforce_max_turns functionality.""" + config = ContextConfig(enforce_max_turns=3, truncate_turns=1) + manager = ContextManager(config) + + # Create 10 turns (20 messages) + messages = self.create_messages(20) + result = await manager.process(messages) + + # Should keep only 3 most recent turns (6 messages) + assert len(result) <= 8 # May vary due to truncation logic + + @pytest.mark.asyncio + async def test_enforce_max_turns_zero(self): + """Test enforce_max_turns with value 0 (should keep nothing).""" + config = ContextConfig(enforce_max_turns=0, truncate_turns=1) + manager = 
ContextManager(config) + + messages = self.create_messages(10) + result = await manager.process(messages) + + # Should result in empty or minimal message list + assert len(result) <= 2 + + @pytest.mark.asyncio + async def test_enforce_max_turns_negative(self): + """Test enforce_max_turns with -1 (no limit).""" + config = ContextConfig(enforce_max_turns=-1) + manager = ContextManager(config) + + messages = self.create_messages(20) + result = await manager.process(messages) + + assert len(result) == 20 + + @pytest.mark.asyncio + async def test_enforce_max_turns_with_system_messages(self): + """Test enforce_max_turns preserves system messages.""" + config = ContextConfig(enforce_max_turns=2, truncate_turns=1) + manager = ContextManager(config) + + messages = [ + self.create_message("system", "System instruction"), + *self.create_messages(10), + ] + result = await manager.process(messages) + + # System message should be preserved + system_msgs = [m for m in result if m.role == "system"] + assert len(system_msgs) >= 1 + assert system_msgs[0].content == "System instruction" + + # ==================== Token-based Compression Tests ==================== + + @pytest.mark.asyncio + async def test_token_compression_not_triggered_below_threshold(self): + """Test that compression is not triggered below threshold.""" + config = ContextConfig(max_context_tokens=1000) + manager = ContextManager(config) + + # Create messages that total less than threshold + messages = [self.create_message("user", "Hi" * 50)] # ~100 tokens + + with patch.object( + manager.compressor, "should_compress", return_value=False + ) as mock_should_compress: + with patch.object( + manager.compressor, "__call__", new_callable=AsyncMock + ) as mock_compress: + result = await manager.process(messages) + + # should_compress should be called + mock_should_compress.assert_called_once() + # Compressor should not be called + mock_compress.assert_not_called() + assert result == messages + + @pytest.mark.asyncio + async def test_token_compression_triggered_above_threshold(self): + """Test that compression is triggered above threshold.""" + config = ContextConfig(max_context_tokens=100, truncate_turns=1) + manager = ContextManager(config) + + # Create messages that exceed threshold (0.82 * 100 = 82 tokens) + # 300 chars * 0.3 = 90 tokens > 82 threshold + long_text = "x" * 300 # ~90 tokens, above threshold + messages = [self.create_message("user", long_text)] + + # Mock compressor to return smaller result + compressed = [self.create_message("user", "short")] + + # Create a mock compressor + mock_compressor = AsyncMock() + mock_compressor.compression_threshold = 0.82 + mock_compressor.return_value = compressed + + # Mock should_compress to return True first time, False after + call_count = 0 + + def mock_should_compress(*args, **kwargs): + nonlocal call_count + call_count += 1 + return call_count == 1 + + mock_compressor.should_compress = mock_should_compress + manager.compressor = mock_compressor + + result = await manager.process(messages) + + # Compressor should be called + mock_compressor.assert_called_once() + # Result should be the compressed version + assert len(result) <= len(messages) + + @pytest.mark.asyncio + async def test_token_compression_with_zero_max_tokens(self): + """Test that compression is skipped when max_context_tokens is 0.""" + config = ContextConfig(max_context_tokens=0) + manager = ContextManager(config) + + messages = [self.create_message("user", "x" * 10000)] + + with patch.object( + manager.compressor, "__call__", 
new_callable=AsyncMock + ) as mock_compress: + result = await manager.process(messages) + + # Compressor should not be called when max_context_tokens is 0 + mock_compress.assert_not_called() + assert result == messages + + @pytest.mark.asyncio + async def test_token_compression_with_negative_max_tokens(self): + """Test that compression is skipped when max_context_tokens is negative.""" + config = ContextConfig(max_context_tokens=-100) + manager = ContextManager(config) + + messages = [self.create_message("user", "x" * 10000)] + + with patch.object( + manager.compressor, "__call__", new_callable=AsyncMock + ) as mock_compress: + result = await manager.process(messages) + + # Compressor should not be called + mock_compress.assert_not_called() + assert result == messages + + @pytest.mark.asyncio + async def test_double_check_after_compression(self): + """Test that halving is applied if still over threshold after compression.""" + config = ContextConfig(max_context_tokens=100) + manager = ContextManager(config) + + # Create messages that would still be over threshold after compression + long_messages = [self.create_message("user", "x" * 200) for _ in range(10)] + + # Mock compressor to return messages still over threshold + async def mock_compress(msgs): + return msgs # Return same messages (still over limit) + + # Mock should_compress to return True twice (before and after compression) + with patch.object(manager.compressor, "should_compress", return_value=True): + with patch.object(manager.compressor, "__call__", new=mock_compress): + with patch.object( + manager.truncator, + "truncate_by_halving", + return_value=long_messages[:5], + ) as mock_halving: + _ = await manager.process(long_messages) + + # Halving should be called + mock_halving.assert_called_once() + + # ==================== Combined Truncation and Compression Tests ==================== + + @pytest.mark.asyncio + async def test_combined_enforce_turns_and_token_limit(self): + """Test combining enforce_max_turns and token limit.""" + config = ContextConfig( + enforce_max_turns=5, max_context_tokens=500, truncate_turns=1 + ) + manager = ContextManager(config) + + # Create many messages + messages = self.create_messages(30) + + result = await manager.process(messages) + + # Should be truncated by both mechanisms + assert len(result) < 30 + + @pytest.mark.asyncio + async def test_sequential_processing_order(self): + """Test that enforce_max_turns happens before token compression.""" + config = ContextConfig(enforce_max_turns=5, max_context_tokens=1000) + manager = ContextManager(config) + + messages = self.create_messages(20) + + # Mock the truncator to track calls + with patch.object( + manager.truncator, + "truncate_by_turns", + wraps=manager.truncator.truncate_by_turns, + ) as mock_truncate: + await manager.process(messages) + + # Truncator should be called first + mock_truncate.assert_called_once() + + # ==================== Error Handling Tests ==================== + + @pytest.mark.asyncio + async def test_error_handling_returns_original_messages(self): + """Test that errors during processing return original messages.""" + config = ContextConfig(max_context_tokens=100) + manager = ContextManager(config) + + messages = self.create_messages(5) + + # Make compressor raise an exception + with patch.object( + manager.compressor, "__call__", side_effect=Exception("Test error") + ): + result = await manager.process(messages) + + # Should return original messages despite error + assert result == messages + + @pytest.mark.asyncio + async 
def test_error_handling_logs_exception(self): + """Test that errors are logged.""" + config = ContextConfig(max_context_tokens=100) + manager = ContextManager(config) + + # Create messages that will trigger compression (> 82 tokens) + messages = [self.create_message("user", "x" * 300)] # ~90 tokens + + # Replace compressor with one that raises an exception + mock_compressor = AsyncMock(side_effect=Exception("Test error")) + mock_compressor.compression_threshold = 0.82 + mock_compressor.should_compress = MagicMock(return_value=True) + manager.compressor = mock_compressor + + with patch("astrbot.core.agent.context.manager.logger") as mock_logger: + result = await manager.process(messages) + + # Logger error method should be called + assert mock_logger.error.called + # Should return original messages on error + assert result == messages + + # ==================== Multi-modal Content Tests ==================== + + @pytest.mark.asyncio + async def test_process_messages_with_textpart_content(self): + """Test processing messages with TextPart content.""" + config = ContextConfig() + manager = ContextManager(config) + + messages = [ + Message(role="user", content=[TextPart(text="Hello")]), + Message(role="assistant", content=[TextPart(text="Hi there")]), + ] + + result = await manager.process(messages) + + assert len(result) == 2 + assert result == messages + + @pytest.mark.asyncio + async def test_token_counting_with_multimodal_content(self): + """Test token counting works with multi-modal content.""" + config = ContextConfig(max_context_tokens=50) + manager = ContextManager(config) + + # Need enough tokens to exceed threshold: 50 * 0.82 = 41 tokens + # 150 chars * 0.3 = 45 tokens > 41 + messages = [ + Message(role="user", content=[TextPart(text="x" * 150)]), + ] + + # Should trigger compression due to token count + tokens = manager.token_counter.count_tokens(messages) + needs_compression = manager.compressor.should_compress(messages, tokens, 50) + + assert tokens > 0 # Tokens should be counted + assert needs_compression # Should trigger compression + + # ==================== Tool Calls Tests ==================== + + @pytest.mark.asyncio + async def test_process_messages_with_tool_calls(self): + """Test processing messages with tool calls.""" + config = ContextConfig() + manager = ContextManager(config) + + messages = [ + Message( + role="assistant", + content="Let me search for that", + tool_calls=[ + { + "id": "call_1", + "type": "function", + "function": {"name": "search", "arguments": "{}"}, + } + ], + ), + Message(role="tool", content="Search result", tool_call_id="call_1"), + ] + + result = await manager.process(messages) + + assert len(result) == 2 + + # ==================== Compressor should_compress Tests ==================== + + @pytest.mark.asyncio + async def test_should_compress_empty_messages(self): + """Test should_compress with empty messages.""" + config = ContextConfig(max_context_tokens=100) + manager = ContextManager(config) + + # Compressor's should_compress should handle empty gracefully + needs_compression = manager.compressor.should_compress([], 0, 100) + assert not needs_compression + + @pytest.mark.asyncio + async def test_should_compress_below_threshold(self): + """Test should_compress when below compression threshold.""" + config = ContextConfig(max_context_tokens=1000) + manager = ContextManager(config) + + messages = [self.create_message("user", "Hello")] + tokens = manager.token_counter.count_tokens(messages) + + needs_compression = 
manager.compressor.should_compress(messages, tokens, 1000) + assert not needs_compression + + @pytest.mark.asyncio + async def test_should_compress_above_threshold(self): + """Test should_compress when above compression threshold.""" + config = ContextConfig(max_context_tokens=100) + manager = ContextManager(config) + + # Create message with many tokens + messages = [self.create_message("user", "这是测试" * 50)] + tokens = manager.token_counter.count_tokens(messages) + + needs_compression = manager.compressor.should_compress(messages, tokens, 100) + # Should need compression if tokens > 82 (0.82 * 100) + assert needs_compression == (tokens > 82) + + # ==================== Truncator Halving Tests ==================== + + def test_truncate_by_halving_basic(self): + """Test truncate_by_halving removes middle 50%.""" + config = ContextConfig() + manager = ContextManager(config) + + messages = self.create_messages(10) + result = manager.truncator.truncate_by_halving(messages) + + # Should keep roughly half + assert len(result) < len(messages) + + def test_truncate_by_halving_empty_list(self): + """Test truncate_by_halving with empty list.""" + config = ContextConfig() + manager = ContextManager(config) + + result = manager.truncator.truncate_by_halving([]) + + assert result == [] + + def test_truncate_by_halving_single_message(self): + """Test truncate_by_halving with single message.""" + config = ContextConfig() + manager = ContextManager(config) + + messages = [self.create_message("user", "Hello")] + result = manager.truncator.truncate_by_halving(messages) + + assert len(result) <= 1 + + # ==================== Complex Scenarios ==================== + + @pytest.mark.asyncio + async def test_multiple_compression_cycles(self): + """Test that compression can be triggered multiple times in sequence.""" + config = ContextConfig(max_context_tokens=50, truncate_turns=1) + manager = ContextManager(config) + + # Process messages multiple times + messages = self.create_messages(10) + + result1 = await manager.process(messages) + result2 = await manager.process(result1) + result3 = await manager.process(result2) + + # Each cycle should maintain or reduce message count + assert len(result3) <= len(result2) <= len(result1) + + @pytest.mark.asyncio + async def test_alternating_roles_preserved(self): + """Test that user/assistant alternation is preserved after processing.""" + config = ContextConfig(enforce_max_turns=3, truncate_turns=1) + manager = ContextManager(config) + + messages = self.create_messages(20) + result = await manager.process(messages) + + # Check that roles still alternate (excluding system messages) + non_system = [m for m in result if m.role != "system"] + if len(non_system) >= 2: + # Should start with user + assert non_system[0].role == "user" + + @pytest.mark.asyncio + async def test_compression_threshold_default(self): + """Test that compression threshold is used correctly.""" + config = ContextConfig(max_context_tokens=100) + manager = ContextManager(config) + + # Verify the default threshold is 0.82 + assert manager.compressor.compression_threshold == 0.82 + + # Test threshold logic + messages = [self.create_message("user", "x" * 81)] # ~24 tokens + tokens = manager.token_counter.count_tokens(messages) + + needs_compression = manager.compressor.should_compress(messages, tokens, 100) + # Should not compress if below threshold + assert needs_compression == (tokens > 82) + + @pytest.mark.asyncio + async def test_large_batch_processing(self): + """Test processing a large batch of 
messages.""" + config = ContextConfig( + enforce_max_turns=10, max_context_tokens=1000, truncate_turns=2 + ) + manager = ContextManager(config) + + # Create 100 messages (50 turns) + messages = self.create_messages(100) + + result = await manager.process(messages) + + # Should be significantly reduced + assert len(result) < 100 + assert len(result) > 0 + + @pytest.mark.asyncio + async def test_config_persistence(self): + """Test that config settings are respected throughout processing.""" + config = ContextConfig( + max_context_tokens=500, + enforce_max_turns=5, + truncate_turns=2, + llm_compress_keep_recent=3, + ) + manager = ContextManager(config) + + # Verify config is stored + assert manager.config.max_context_tokens == 500 + assert manager.config.enforce_max_turns == 5 + assert manager.config.truncate_turns == 2 + assert manager.config.llm_compress_keep_recent == 3 + + # ==================== Run Compression Tests ==================== + + @pytest.mark.asyncio + async def test_run_compression_calls_compressor(self): + """Test _run_compression calls compressor.""" + config = ContextConfig(max_context_tokens=100) + manager = ContextManager(config) + + messages = self.create_messages(5) + compressed = self.create_messages(3) + + # Create a mock compressor + mock_compressor = AsyncMock() + mock_compressor.compression_threshold = 0.82 + mock_compressor.return_value = compressed + mock_compressor.should_compress = MagicMock(return_value=False) + manager.compressor = mock_compressor + + result = await manager._run_compression(messages, prev_tokens=100) + + # Compressor __call__ should be invoked + mock_compressor.assert_called_once_with(messages) + assert result == compressed + + @pytest.mark.asyncio + async def test_run_compression_applies_compressor_through_process(self): + """Test _run_compression calls compressor when needed through process().""" + config = ContextConfig(max_context_tokens=100, truncate_turns=1) + manager = ContextManager(config) + + # Create messages that will trigger compression + messages = [self.create_message("user", "x" * 300)] # ~90 tokens > 82 threshold + compressed = [self.create_message("user", "short")] # Much smaller + + # Create a mock compressor + mock_compressor = AsyncMock() + mock_compressor.compression_threshold = 0.82 + mock_compressor.return_value = compressed + + # Mock should_compress to return True first time, False after + call_count = 0 + + def mock_should_compress(*args, **kwargs): + nonlocal call_count + call_count += 1 + return call_count == 1 + + mock_compressor.should_compress = mock_should_compress + manager.compressor = mock_compressor + + result = await manager.process(messages) + + # Compressor should have been called + mock_compressor.assert_called_once() + assert len(result) <= len(messages) + + @pytest.mark.asyncio + async def test_llm_compression_with_mock_provider(self): + """Test LLM compression using MockProvider.""" + mock_provider = MockProvider() + config = ContextConfig( + llm_compress_provider=mock_provider, # type: ignore + llm_compress_keep_recent=3, + llm_compress_instruction="请总结对话内容", + max_context_tokens=100, + ) + manager = ContextManager(config) + + # Create messages that will trigger compression + messages = [ + self.create_message("user", "x" * 100), + self.create_message("assistant", "y" * 100), + self.create_message("user", "z" * 100), + ] + + result = await manager.process(messages) + + # Should have been compressed + assert len(result) <= len(messages) + + # ==================== split_history Tests 
==================== + + def test_split_history_ensures_user_start(self): + """Test split_history ensures recent_messages starts with user message.""" + from astrbot.core.agent.context.compressor import split_history + + # Create alternating messages: user, assistant, user, assistant, user, assistant + messages = [ + self.create_message("system", "System prompt"), + self.create_message("user", "msg1"), + self.create_message("assistant", "msg2"), + self.create_message("user", "msg3"), + self.create_message("assistant", "msg4"), + self.create_message("user", "msg5"), + self.create_message("assistant", "msg6"), + ] + + # Keep recent 3 messages - should adjust to start with user + system, to_summarize, recent = split_history(messages, keep_recent=3) + + # recent_messages should start with user message + assert len(recent) > 0 + assert recent[0].role == "user" + + # messages_to_summarize should end with assistant (complete turn) + if len(to_summarize) > 0: + assert to_summarize[-1].role == "assistant" + + def test_split_history_handles_assistant_at_split_point(self): + """Test split_history when assistant message is at the intended split point.""" + from astrbot.core.agent.context.compressor import split_history + + messages = [ + self.create_message("user", "msg1"), + self.create_message("assistant", "msg2"), + self.create_message("user", "msg3"), + self.create_message("assistant", "msg4"), # <- intended split here + self.create_message("user", "msg5"), + self.create_message("assistant", "msg6"), + ] + + # keep_recent=3 puts the raw split point at index 3 (assistant msg4) + # split_history should walk back so recent starts at msg3 (user) + system, to_summarize, recent = split_history(messages, keep_recent=3) + + # recent should start with a user message + assert recent[0].role == "user" + assert recent[0].content == "msg3" + + def test_split_history_all_assistant_messages(self): + """Test split_history when there are consecutive assistant messages.""" + from astrbot.core.agent.context.compressor import split_history + + messages = [ + self.create_message("user", "msg1"), + self.create_message("assistant", "msg2"), + self.create_message("assistant", "msg3"), + self.create_message("assistant", "msg4"), + ] + + system, to_summarize, recent = split_history(messages, keep_recent=2) + + # split_history cannot split inside the run of assistant messages, + # so everything stays in recent and starts with the user message + assert len(recent) > 0 + assert recent[0].role == "user" + + def test_split_history_with_system_messages(self): + """Test split_history preserves system messages separately.""" + from astrbot.core.agent.context.compressor import split_history + + messages = [ + self.create_message("system", "System 1"), + self.create_message("system", "System 2"), + self.create_message("user", "msg1"), + self.create_message("assistant", "msg2"), + self.create_message("user", "msg3"), + ] + + system, to_summarize, recent = split_history(messages, keep_recent=2) + + # System messages should be separate + assert len(system) == 2 + assert all(m.role == "system" for m in system) + + # Recent should start with user + if len(recent) > 0: + assert recent[0].role == "user" diff --git a/tests/agent/test_truncator.py b/tests/agent/test_truncator.py new file mode 100644 index 000000000..1027643bb --- /dev/null +++ b/tests/agent/test_truncator.py @@ -0,0 +1,423 @@ +"""Tests for ContextTruncator.""" + +from astrbot.core.agent.context.truncator import ContextTruncator +from astrbot.core.agent.message import Message + + +class TestContextTruncator: + """Test suite for 
ContextTruncator.""" + + def create_message(self, role: str, content: str = "test content") -> Message: + """Helper to create a simple test message.""" + return Message(role=role, content=content) + + def create_messages( + self, count: int, include_system: bool = False + ) -> list[Message]: + """Helper to create alternating user/assistant messages. + + Args: + count: Number of messages to create + include_system: Whether to include a system message at the start + + Returns: + List of messages + """ + messages = [] + if include_system: + messages.append(self.create_message("system", "System prompt")) + + for i in range(count): + role = "user" if i % 2 == 0 else "assistant" + messages.append(self.create_message(role, f"Message {i}")) + return messages + + # ==================== fix_messages Tests ==================== + + def test_fix_messages_empty_list(self): + """Test fix_messages with an empty list.""" + truncator = ContextTruncator() + result = truncator.fix_messages([]) + assert result == [] + + def test_fix_messages_normal_messages(self): + """Test fix_messages with normal user/assistant messages.""" + truncator = ContextTruncator() + messages = [ + self.create_message("user", "Hello"), + self.create_message("assistant", "Hi"), + self.create_message("user", "How are you?"), + ] + result = truncator.fix_messages(messages) + assert len(result) == 3 + assert result == messages + + def test_fix_messages_tool_with_valid_context(self): + """Test fix_messages with tool message after user+assistant.""" + truncator = ContextTruncator() + messages = [ + self.create_message("user", "Run tool"), + self.create_message("assistant", "Running..."), + self.create_message("tool", "Tool result"), + ] + result = truncator.fix_messages(messages) + assert len(result) == 3 + assert result == messages + + def test_fix_messages_tool_without_context(self): + """Test fix_messages with tool message without enough context.""" + truncator = ContextTruncator() + messages = [ + self.create_message("tool", "Tool result"), + ] + result = truncator.fix_messages(messages) + # Tool message without context should be removed + assert len(result) == 0 + + def test_fix_messages_tool_with_only_one_message(self): + """Test fix_messages with tool message after only one message.""" + truncator = ContextTruncator() + messages = [ + self.create_message("user", "Hello"), + self.create_message("tool", "Tool result"), + ] + result = truncator.fix_messages(messages) + # Tool message without enough context should be removed + assert len(result) == 0 + + def test_fix_messages_multiple_tools(self): + """Test fix_messages with multiple tool messages.""" + truncator = ContextTruncator() + messages = [ + self.create_message("user", "Run tool"), + self.create_message("assistant", "Running..."), + self.create_message("tool", "Tool 1 result"), + self.create_message("tool", "Tool 2 result"), + ] + result = truncator.fix_messages(messages) + assert len(result) == 4 + assert result == messages + + def test_fix_messages_mixed_system_tool(self): + """Test fix_messages with system message and tool messages.""" + truncator = ContextTruncator() + messages = [ + self.create_message("system", "System prompt"), + self.create_message("user", "Run tool"), + self.create_message("assistant", "Running..."), + self.create_message("tool", "Tool result"), + ] + result = truncator.fix_messages(messages) + assert len(result) == 4 + assert result == messages + + # ==================== truncate_by_turns Tests ==================== + + def 
test_truncate_by_turns_no_limit(self): + """Test truncate_by_turns with -1 (no limit).""" + truncator = ContextTruncator() + messages = self.create_messages(20) + result = truncator.truncate_by_turns(messages, keep_most_recent_turns=-1) + assert len(result) == 20 + assert result == messages + + def test_truncate_by_turns_basic(self): + """Test basic truncate_by_turns functionality.""" + truncator = ContextTruncator() + # Create 10 messages = 5 turns (user/assistant pairs) + messages = self.create_messages(10) + result = truncator.truncate_by_turns( + messages, keep_most_recent_turns=3, drop_turns=1 + ) + + # 5 turns exceed the limit of 3; drop_turns=1 drops at least one turn + # (2 messages), so at most 8 messages should remain + assert len(result) <= 8 + + def test_truncate_by_turns_with_system_message(self): + """Test truncate_by_turns preserves system messages.""" + truncator = ContextTruncator() + messages = self.create_messages(10, include_system=True) + result = truncator.truncate_by_turns( + messages, keep_most_recent_turns=2, drop_turns=1 + ) + + # System message should always be preserved + assert result[0].role == "system" + assert result[0].content == "System prompt" + + def test_truncate_by_turns_zero_keep(self): + """Test truncate_by_turns with keep_most_recent_turns=0.""" + truncator = ContextTruncator() + messages = self.create_messages(10) + result = truncator.truncate_by_turns( + messages, keep_most_recent_turns=0, drop_turns=1 + ) + + # With keep_most_recent_turns=0 nothing should be kept + assert len(result) == 0 + + def test_truncate_by_turns_below_threshold(self): + """Test truncate_by_turns when messages are below threshold.""" + truncator = ContextTruncator() + # Create 4 messages = 2 turns + messages = self.create_messages(4) + result = truncator.truncate_by_turns( + messages, keep_most_recent_turns=5, drop_turns=1 + ) + + # No truncation should happen + assert len(result) == 4 + assert result == messages + + def test_truncate_by_turns_exact_threshold(self): + """Test truncate_by_turns when messages exactly match threshold.""" + truncator = ContextTruncator() + # Create 6 messages = 3 turns + messages = self.create_messages(6) + result = truncator.truncate_by_turns( + messages, keep_most_recent_turns=3, drop_turns=1 + ) + + # No truncation should happen + assert len(result) == 6 + assert result == messages + + def test_truncate_by_turns_ensures_user_first(self): + """Test that truncate_by_turns ensures user message comes first.""" + truncator = ContextTruncator() + # Create scenario where truncation might start with assistant + messages = self.create_messages(20) + result = truncator.truncate_by_turns( + messages, keep_most_recent_turns=3, drop_turns=1 + ) + + # First non-system message should be user + assert result[0].role == "user" + + def test_truncate_by_turns_multiple_drop(self): + """Test truncate_by_turns with multiple turns dropped at once.""" + truncator = ContextTruncator() + messages = self.create_messages(20) + result = truncator.truncate_by_turns( + messages, keep_most_recent_turns=5, drop_turns=3 + ) + + # Should drop 3 turns when limit exceeded + assert len(result) < len(messages) + + # ==================== truncate_by_dropping_oldest_turns Tests ==================== + + def test_truncate_by_dropping_oldest_turns_zero(self): + """Test truncate_by_dropping_oldest_turns with drop_turns=0.""" + truncator = ContextTruncator() + messages = self.create_messages(10) + result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=0) + assert result == messages + + def 
test_truncate_by_dropping_oldest_turns_negative(self): + """Test truncate_by_dropping_oldest_turns with negative drop_turns.""" + truncator = ContextTruncator() + messages = self.create_messages(10) + result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=-1) + assert result == messages + + def test_truncate_by_dropping_oldest_turns_basic(self): + """Test basic truncate_by_dropping_oldest_turns functionality.""" + truncator = ContextTruncator() + # Create 10 messages = 5 turns + messages = self.create_messages(10) + result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=2) + + # Should drop 2 oldest turns (4 messages) + assert len(result) == 6 + # Should start with user message + assert result[0].role == "user" + + def test_truncate_by_dropping_oldest_turns_with_system(self): + """Test truncate_by_dropping_oldest_turns preserves system messages.""" + truncator = ContextTruncator() + messages = self.create_messages(10, include_system=True) + result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=2) + + # System message should be preserved + assert result[0].role == "system" + assert result[0].content == "System prompt" + + def test_truncate_by_dropping_oldest_turns_drop_all(self): + """Test truncate_by_dropping_oldest_turns dropping all turns.""" + truncator = ContextTruncator() + # Create 4 messages = 2 turns + messages = self.create_messages(4) + result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=2) + + # Should drop all turns + assert len(result) == 0 + + def test_truncate_by_dropping_oldest_turns_drop_more_than_available(self): + """Test truncate_by_dropping_oldest_turns with drop_turns > available turns.""" + truncator = ContextTruncator() + # Create 4 messages = 2 turns + messages = self.create_messages(4) + result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=5) + + # Should result in empty list + assert len(result) == 0 + + def test_truncate_by_dropping_oldest_turns_ensures_user_first(self): + """Test that result starts with user message after dropping.""" + truncator = ContextTruncator() + messages = self.create_messages(20) + result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=3) + + # First message should be user + if len(result) > 0: + assert result[0].role == "user" + + # ==================== truncate_by_halving Tests ==================== + + def test_truncate_by_halving_empty(self): + """Test truncate_by_halving with empty list.""" + truncator = ContextTruncator() + result = truncator.truncate_by_halving([]) + assert result == [] + + def test_truncate_by_halving_single_message(self): + """Test truncate_by_halving with single message.""" + truncator = ContextTruncator() + messages = [self.create_message("user", "Hello")] + result = truncator.truncate_by_halving(messages) + # Should not truncate if <= 2 messages + assert result == messages + + def test_truncate_by_halving_two_messages(self): + """Test truncate_by_halving with two messages.""" + truncator = ContextTruncator() + messages = self.create_messages(2) + result = truncator.truncate_by_halving(messages) + # Should not truncate if <= 2 messages + assert result == messages + + def test_truncate_by_halving_basic(self): + """Test basic truncate_by_halving functionality.""" + truncator = ContextTruncator() + # Create 20 messages + messages = self.create_messages(20) + result = truncator.truncate_by_halving(messages) + + # Should delete 50% = 10 messages, keep 10 + assert len(result) == 10 + # First 
message should be user + assert result[0].role == "user" + + def test_truncate_by_halving_with_system_message(self): + """Test truncate_by_halving preserves system messages.""" + truncator = ContextTruncator() + messages = self.create_messages(20, include_system=True) + result = truncator.truncate_by_halving(messages) + + # System message should be preserved + assert result[0].role == "system" + assert result[0].content == "System prompt" + + def test_truncate_by_halving_odd_count(self): + """Test truncate_by_halving with odd number of messages.""" + truncator = ContextTruncator() + messages = self.create_messages(11) + result = truncator.truncate_by_halving(messages) + + # Should delete floor(11/2) = 5 messages, keep 6 + # But after ensuring user first, may be 5 + assert len(result) >= 5 + assert result[0].role == "user" + + def test_truncate_by_halving_ensures_user_first(self): + """Test that result starts with user message.""" + truncator = ContextTruncator() + # Create messages starting with user + messages = self.create_messages(30) + result = truncator.truncate_by_halving(messages) + + # First message should be user + assert result[0].role == "user" + + def test_truncate_by_halving_preserves_recent_messages(self): + """Test that truncate_by_halving keeps the most recent 50%.""" + truncator = ContextTruncator() + messages = [ + self.create_message("user", "Message 0"), + self.create_message("assistant", "Message 1"), + self.create_message("user", "Message 2"), + self.create_message("assistant", "Message 3"), + ] + result = truncator.truncate_by_halving(messages) + + # Should keep last 2 messages + assert len(result) == 2 + assert result[0].content == "Message 2" + assert result[1].content == "Message 3" + + # ==================== Integration Tests ==================== + + def test_truncate_with_tool_messages(self): + """Test truncation with tool messages.""" + truncator = ContextTruncator() + messages = [ + self.create_message("user", "Run tool"), + self.create_message("assistant", "Running..."), + self.create_message("tool", "Tool result"), + self.create_message("user", "Thanks"), + self.create_message("assistant", "Welcome"), + ] + + result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=1) + + # First turn (user+assistant+tool) should be dropped + # Tool message should be cleaned up by fix_messages + assert len(result) <= 2 + + def test_chain_multiple_truncations(self): + """Test chaining multiple truncation methods.""" + truncator = ContextTruncator() + messages = self.create_messages(40, include_system=True) + + # First: truncate by turns + result = truncator.truncate_by_turns( + messages, keep_most_recent_turns=10, drop_turns=2 + ) + # Then: halve + result = truncator.truncate_by_halving(result) + + # Should have system message + truncated content + assert result[0].role == "system" + assert len(result) < len(messages) + + def test_empty_after_system_message(self): + """Test truncation when only system message exists.""" + truncator = ContextTruncator() + messages = [self.create_message("system", "System prompt")] + + result = truncator.truncate_by_turns( + messages, keep_most_recent_turns=5, drop_turns=1 + ) + + # Should keep system message + assert len(result) == 1 + assert result[0].role == "system" + + def test_all_system_messages(self): + """Test truncation with only system messages.""" + truncator = ContextTruncator() + messages = [ + self.create_message("system", "System 1"), + self.create_message("system", "System 2"), + ] + + result = 
truncator.truncate_by_turns( + messages, keep_most_recent_turns=0, drop_turns=1 + ) + + # There are no non-system messages and keep_most_recent_turns=0, so the + # truncator may keep the system messages or return an empty list; whatever + # survives must consist of system messages only + if len(result) > 0: + assert all(msg.role == "system" for msg in result)
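
A minimal usage sketch of the pipeline these tests exercise, for orientation only and not part of the diff. It assumes ContextConfig and ContextManager are importable from astrbot.core.agent.context.manager (the module whose logger the tests patch); the import path is an assumption, and only constructor arguments and methods that appear in the tests above are used.

import asyncio

# Assumed import path; adjust to wherever ContextConfig/ContextManager actually live.
from astrbot.core.agent.context.manager import ContextConfig, ContextManager
from astrbot.core.agent.message import Message


async def main() -> None:
    # Keep at most 5 turns and roughly 1000 tokens of history;
    # drop one old turn per truncation pass when the turn limit is exceeded.
    config = ContextConfig(
        enforce_max_turns=5,
        max_context_tokens=1000,
        truncate_turns=1,
    )
    manager = ContextManager(config)

    history = [
        Message(role="system", content="You are a helpful assistant."),
        Message(role="user", content="Hello"),
        Message(role="assistant", content="Hi there"),
    ]

    # process() enforces the turn limit first, then token-based compression,
    # and falls back to the original messages if anything raises.
    trimmed = await manager.process(history)
    print(f"{len(trimmed)} messages kept")


asyncio.run(main())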