"""Beaver 内置 session_search tool。 这个工具对应 Hermes-agent 的跨会话检索能力,目标不是把所有历史内容塞回主上下文, 而是按需从过去的 session 中找回“之前发生过什么”。 当前实现保留了几个关键行为: 1. query 为空时进入 recent/browse 模式,只列最近会话,不走 LLM,总成本很低 2. query 不为空时走 transcript DB 的搜索接口,预期底层是 FTS 风格检索 3. 自动排除当前 session lineage,避免把当前上下文又搜出来一遍 4. 对长会话做 match-centered truncation,而不是无脑截前 N 字符 5. summarizer 是可选依赖;没有时降级返回 raw preview,而不是整条工具失败 """ from __future__ import annotations import asyncio import json import logging import re from dataclasses import dataclass, field from datetime import datetime from typing import Any, Awaitable, Callable, Protocol MAX_SESSION_CHARS = 100_000 class SessionSearchDB(Protocol): """session_search 依赖的最小数据库契约。 这里没有直接绑定某个具体 SQLite 实现,而是先定义行为接口。 这样后面无论你接的是 Hermes 风格 state DB、还是 Beaver 自己的 transcript store, 只要满足这些方法就能工作。 """ def list_sessions_rich( self, *, limit: int, exclude_sources: list[str] | None = None, ) -> list[dict[str, Any]]: ... def get_session(self, session_id: str) -> dict[str, Any] | None: ... def get_messages_as_conversation(self, session_id: str) -> list[dict[str, Any]]: ... def search_messages( self, *, query: str, role_filter: list[str] | None = None, exclude_sources: list[str] | None = None, limit: int, offset: int = 0, ) -> list[dict[str, Any]]: ... SessionSummarizer = Callable[[str, str, dict[str, Any]], Awaitable[str | None]] _HIDDEN_SESSION_SOURCES = ("tool",) SESSION_SEARCH_TOOL_DESCRIPTION = ( "Search prior sessions for historical context, or browse recent sessions when " "query is omitted. Use this when the user references past work, prior fixes, " "or earlier decisions instead of asking them to repeat themselves." ) SESSION_SEARCH_TOOL_PARAMETERS: dict[str, Any] = { "type": "object", "properties": { "query": { "type": "string", "description": "Keyword, phrase, or boolean FTS query. Omit to browse recent sessions.", }, "role_filter": { "type": "string", "description": "Optional comma-separated roles to search, for example 'user,assistant'.", }, "limit": { "type": "integer", "default": 3, "minimum": 1, "maximum": 5, "description": "Maximum number of sessions to return.", }, }, "required": [], } def _format_timestamp(value: int | float | str | None) -> str: """把时间戳或字符串格式化成更可读的展示文本。""" if value is None: return "unknown" try: if isinstance(value, (int, float)): return datetime.fromtimestamp(value).strftime("%B %d, %Y at %I:%M %p") if isinstance(value, str): if value.replace(".", "").replace("-", "").isdigit(): return datetime.fromtimestamp(float(value)).strftime("%B %d, %Y at %I:%M %p") return value except (OSError, OverflowError, ValueError): pass return str(value) def _format_conversation(messages: list[dict[str, Any]]) -> str: """把消息列表整理成适合摘要模型消费的 transcript 文本。 这里会保留: - role - assistant 的 tool calls 名称 - tool 输出的简短内容 但不会原样塞入超长工具输出,否则摘要成本会被单个工具结果拉爆。 """ parts: list[str] = [] for message in messages: role = str(message.get("role", "unknown")).upper() content = message.get("content") or "" tool_name = message.get("tool_name") if role == "TOOL" and tool_name: if len(content) > 500: content = content[:250] + "\n...[truncated]...\n" + content[-250:] parts.append(f"[TOOL:{tool_name}]: {content}") continue if role == "ASSISTANT": tool_calls = message.get("tool_calls") if isinstance(tool_calls, list) and tool_calls: names: list[str] = [] for tool_call in tool_calls: if isinstance(tool_call, dict): names.append( tool_call.get("name") or tool_call.get("function", {}).get("name", "?") ) if names: parts.append(f"[ASSISTANT]: [Called: {', '.join(names)}]") parts.append(f"[ASSISTANT]: {content}") continue parts.append(f"[{role}]: {content}") return "\n\n".join(parts) def _truncate_around_matches(full_text: str, query: str, *, max_chars: int = MAX_SESSION_CHARS) -> str: """围绕匹配位置截取上下文,而不是固定截头。 优先级: 1. 先找整句 query 2. 找不到再找多词近邻共现 3. 再退化到逐词匹配 这样做的目的,是尽量把与 query 最相关的对话片段保留下来,提高 summarizer 的命中率。 """ if len(full_text) <= max_chars: return full_text text_lower = full_text.lower() query_lower = query.lower().strip() match_positions = [match.start() for match in re.finditer(re.escape(query_lower), text_lower)] if not match_positions: terms = query_lower.split() if len(terms) > 1: positions: dict[str, list[int]] = { term: [match.start() for match in re.finditer(re.escape(term), text_lower)] for term in terms } rarest = min(terms, key=lambda term: len(positions.get(term, []))) for position in positions.get(rarest, []): if all( any(abs(candidate - position) < 200 for candidate in positions.get(term, [])) for term in terms if term != rarest ): match_positions.append(position) if not match_positions: for term in query_lower.split(): match_positions.extend(match.start() for match in re.finditer(re.escape(term), text_lower)) if not match_positions: head = full_text[:max_chars] suffix = "\n\n...[later conversation truncated]..." if max_chars < len(full_text) else "" return head + suffix best_start = 0 best_count = 0 for candidate in sorted(match_positions): window_start = max(0, candidate - max_chars // 4) window_end = window_start + max_chars if window_end > len(full_text): window_start = max(0, len(full_text) - max_chars) window_end = len(full_text) count = sum(1 for position in match_positions if window_start <= position < window_end) if count > best_count: best_count = count best_start = window_start start = best_start end = min(len(full_text), start + max_chars) prefix = "...[earlier conversation truncated]...\n\n" if start > 0 else "" suffix = "\n\n...[later conversation truncated]..." if end < len(full_text) else "" return prefix + full_text[start:end] + suffix def _resolve_to_parent(db: SessionSearchDB, session_id: str | None) -> str | None: """沿 parent_session_id 向上追溯到 lineage root。 这样可以把 delegation/compression 形成的子 session 归并回同一条主会话链, 避免检索结果里出现多个其实属于同一轮上下文的碎片 session。 """ visited: set[str] = set() current = session_id while current and current not in visited: visited.add(current) session = db.get_session(current) if not session: break parent = session.get("parent_session_id") if not parent: break current = parent return current def _list_recent_sessions( db: SessionSearchDB, *, limit: int, current_session_id: str | None = None, ) -> str: """recent mode:仅列出最近 session 的元数据,不做摘要调用。""" sessions = db.list_sessions_rich( limit=limit + 5, exclude_sources=list(_HIDDEN_SESSION_SOURCES), ) current_root = _resolve_to_parent(db, current_session_id) if current_session_id else None results: list[dict[str, Any]] = [] for session in sessions: session_id = session.get("id", "") if current_root and session_id == current_root: continue if current_session_id and session_id == current_session_id: continue if session.get("parent_session_id"): continue results.append( { "session_id": session_id, "title": session.get("title") or None, "source": session.get("source", ""), "started_at": session.get("started_at", ""), "last_active": session.get("last_active", ""), "message_count": session.get("message_count", 0), "preview": session.get("preview", ""), } ) if len(results) >= limit: break return json.dumps( { "success": True, "mode": "recent", "results": results, "count": len(results), "message": f"Showing {len(results)} most recent sessions.", }, ensure_ascii=False, ) async def session_search( *, query: str = "", role_filter: str | None = None, limit: int = 3, db: SessionSearchDB | None = None, current_session_id: str | None = None, summarizer: SessionSummarizer | None = None, ) -> str: """搜索过去的会话并返回结构化 JSON 结果。 运行流程: 1. 空 query -> recent mode 2. 有 query -> 调 transcript DB 搜索 3. 去掉当前会话链 4. 拉取命中的 session transcript 5. 对 transcript 做 match-centered truncation 6. 如果提供 summarizer,就并发摘要;否则回退 raw preview """ if db is None: return json.dumps({"success": False, "error": "Session database is not available."}, ensure_ascii=False) limit = max(1, min(limit, 5)) if not query or not query.strip(): return _list_recent_sessions(db, limit=limit, current_session_id=current_session_id) role_list = [item.strip() for item in (role_filter or "").split(",") if item.strip()] or None try: raw_results = db.search_messages( query=query.strip(), role_filter=role_list, exclude_sources=list(_HIDDEN_SESSION_SOURCES), limit=50, offset=0, ) except Exception as exc: logging.error("Session search failed during FTS lookup: %s", exc, exc_info=True) return json.dumps({"success": False, "error": f"Search failed: {exc}"}, ensure_ascii=False) if not raw_results: return json.dumps( { "success": True, "query": query.strip(), "results": [], "count": 0, "message": "No matching sessions found.", }, ensure_ascii=False, ) current_root = _resolve_to_parent(db, current_session_id) if current_session_id else None seen_sessions: dict[str, dict[str, Any]] = {} for result in raw_results: raw_session_id = result["session_id"] resolved_session_id = _resolve_to_parent(db, raw_session_id) or raw_session_id if current_root and resolved_session_id == current_root: continue if current_session_id and raw_session_id == current_session_id: continue if resolved_session_id not in seen_sessions: entry = dict(result) entry["session_id"] = resolved_session_id seen_sessions[resolved_session_id] = entry if len(seen_sessions) >= limit: break prepared: list[tuple[str, dict[str, Any], str, dict[str, Any]]] = [] for session_id, match_info in seen_sessions.items(): try: messages = db.get_messages_as_conversation(session_id) if not messages: continue session_meta = db.get_session(session_id) or {} transcript = _truncate_around_matches(_format_conversation(messages), query.strip()) prepared.append((session_id, match_info, transcript, session_meta)) except Exception as exc: logging.warning("Failed to prepare session %s: %s", session_id, exc, exc_info=True) if summarizer is not None: summaries = await asyncio.gather( *(summarizer(transcript, query.strip(), session_meta) for _, _, transcript, session_meta in prepared), return_exceptions=True, ) else: summaries = [None] * len(prepared) results: list[dict[str, Any]] = [] for (session_id, match_info, transcript, _), summary in zip(prepared, summaries): resolved_summary: str | None if isinstance(summary, Exception): logging.warning("Failed to summarize session %s: %s", session_id, summary, exc_info=True) resolved_summary = None else: resolved_summary = summary if not resolved_summary: preview = transcript[:500] + ("\n…[truncated]" if len(transcript) > 500 else "") resolved_summary = f"[Raw preview — summarization unavailable]\n{preview}" results.append( { "session_id": session_id, "when": _format_timestamp(match_info.get("session_started")), "source": match_info.get("source", "unknown"), "model": match_info.get("model"), "summary": resolved_summary, } ) return json.dumps( { "success": True, "query": query.strip(), "results": results, "count": len(results), "sessions_searched": len(seen_sessions), }, ensure_ascii=False, ) @dataclass(slots=True) class SessionSearchTool: """面向 runtime 的轻量 session_search 工具封装。""" db: SessionSearchDB current_session_id: str | None = None summarizer: SessionSummarizer | None = None name: str = "session_search" description: str = SESSION_SEARCH_TOOL_DESCRIPTION parameters: dict[str, Any] = field(default_factory=lambda: dict(SESSION_SEARCH_TOOL_PARAMETERS)) async def execute(self, **kwargs: Any) -> str: current_session_id = kwargs.pop("current_session_id", None) return await session_search( db=self.db, current_session_id=current_session_id if current_session_id is not None else self.current_session_id, summarizer=self.summarizer, **kwargs, )