Files
beaver_project/app-instance/backend/beaver/tools/builtins/session_search.py

419 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Beaver 内置 session_search tool。
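This tool corresponds to Hermes-agent's cross-session retrieval capability. The goal is not to
stuff all historical content back into the main context, but to recall "what happened before"
from past sessions on demand.
The current implementation keeps several key behaviors:
1. An empty query enters recent/browse mode: it only lists recent sessions and makes no LLM call, so the overall cost is very low.
2. A non-empty query goes through the transcript DB search interface, which is expected to be FTS-style retrieval underneath.
3. The current session lineage is excluded automatically, so the current context is not returned a second time.
4. Long sessions get match-centered truncation instead of blindly keeping the first N characters.
5. The summarizer is an optional dependency; without it the tool degrades to a raw preview instead of failing outright.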
"""
from __future__ import annotations

import asyncio
import bisect
import copy
import json
import logging
import re
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Awaitable, Callable, Protocol
# Character budget for one session transcript; used as the default ``max_chars`` of
# _truncate_around_matches.
MAX_SESSION_CHARS = 100_000
class SessionSearchDB(Protocol):
    """Minimal database contract that session_search depends on.

    The module is written against this structural interface rather than a
    concrete SQLite implementation, so a Hermes-style state DB or Beaver's own
    transcript store can be plugged in as long as it offers these methods.
    """

    def list_sessions_rich(
        self,
        *,
        limit: int,
        exclude_sources: list[str] | None = None,
    ) -> list[dict[str, Any]]:
        """List session metadata rows for recent/browse mode.

        Code in this module reads the keys ``id``, ``title``, ``source``,
        ``started_at``, ``last_active``, ``message_count``, ``preview`` and
        ``parent_session_id`` from each row.
        NOTE(review): rows are presumably ordered newest first - confirm
        against the concrete store.
        """
        ...

    def get_session(self, session_id: str) -> dict[str, Any] | None:
        """Return one session's metadata, or ``None`` when there is none.

        Lineage resolution in this module reads ``parent_session_id`` from it.
        """
        ...

    def get_messages_as_conversation(self, session_id: str) -> list[dict[str, Any]]:
        """Return the messages of one session.

        The transcript formatter in this module reads ``role``, ``content``,
        ``tool_name`` and ``tool_calls`` from each message.
        """
        ...

    def search_messages(
        self,
        *,
        query: str,
        role_filter: list[str] | None = None,
        exclude_sources: list[str] | None = None,
        limit: int,
        offset: int = 0,
    ) -> list[dict[str, Any]]:
        """Search stored messages and return hit rows.

        Code in this module reads ``session_id``, ``session_started``,
        ``source`` and ``model`` from each hit.
        """
        ...
# Type of the optional async summarizer. session_search calls it as
# summarizer(transcript, query, session_meta) and awaits the result; a falsy result
# or a raised exception makes session_search fall back to a raw transcript preview.
SessionSummarizer = Callable[[str, str, dict[str, Any]], Awaitable[str | None]]
# Session sources passed as ``exclude_sources`` to the DB list/search calls in this module.
_HIDDEN_SESSION_SOURCES = ("tool",)
# Description text for the tool; default value of SessionSearchTool.description.
SESSION_SEARCH_TOOL_DESCRIPTION = (
"Search prior sessions for historical context, or browse recent sessions when "
"query is omitted. Use this when the user references past work, prior fixes, "
"or earlier decisions instead of asking them to repeat themselves."
)
# JSON-Schema-style description of the tool arguments; none of them is required.
# The 1..5 bounds on "limit" match the clamp applied inside session_search.
SESSION_SEARCH_TOOL_PARAMETERS: dict[str, Any] = {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "Keyword, phrase, or boolean FTS query. Omit to browse recent sessions.",
},
"role_filter": {
"type": "string",
"description": "Optional comma-separated roles to search, for example 'user,assistant'.",
},
"limit": {
"type": "integer",
"default": 3,
"minimum": 1,
"maximum": 5,
"description": "Maximum number of sessions to return.",
},
},
"required": [],
}
def _format_timestamp(value: int | float | str | None) -> str:
    """Render a timestamp-like value as human-readable display text.

    Args:
        value: Epoch seconds as a number or numeric string, an already
            formatted string, or ``None``.

    Returns:
        ``"unknown"`` for ``None``; a local-time string such as
        ``"January 05, 2024 at 03:07 PM"`` for numeric input; otherwise the
        input itself (converted with ``str`` when it is not a string).
    """
    if value is None:
        return "unknown"

    display_format = "%B %d, %Y at %I:%M %p"

    # Step 1: work out whether the value can be read as epoch seconds.
    if isinstance(value, str):
        stripped = value.replace(".", "").replace("-", "")
        if not stripped.isdigit():
            # Not numeric-looking: treat it as pre-formatted text.
            return value
        try:
            epoch_seconds = float(value)
        except ValueError:
            # Digits with stray separators, e.g. "2024-01-15" or "1.2.3".
            return str(value)
    elif isinstance(value, (int, float)):
        epoch_seconds = value
    else:
        return str(value)

    # Step 2: convert; values the platform cannot represent fall back to str().
    try:
        return datetime.fromtimestamp(epoch_seconds).strftime(display_format)
    except (OSError, OverflowError, ValueError):
        return str(value)
def _format_conversation(messages: list[dict[str, Any]]) -> str:
    """Render a message list as plain transcript text for the summarizer.

    Each message becomes one ``[ROLE]: content`` paragraph. The output keeps:
    - the role of every message;
    - the names of the tools an assistant message called, on a separate
      ``[ASSISTANT]: [Called: ...]`` line;
    - tool output, shortened to its first and last 250 characters when it is
      longer than 500, so that one huge tool result cannot dominate the
      summarization cost.

    Args:
        messages: Message dicts; the keys read are ``role``, ``content``,
            ``tool_name`` and ``tool_calls``.

    Returns:
        The paragraphs joined by blank lines (an empty string for no messages).
    """
    parts: list[str] = []
    for message in messages:
        role = str(message.get("role", "unknown")).upper()
        content = message.get("content") or ""
        if not isinstance(content, str):
            # Non-string content is rendered with str() up front so the
            # slicing and concatenation below cannot fail on it.
            content = str(content)

        if role == "TOOL":
            # Shorten every long tool output, including ones that carry no
            # tool_name; previously those bypassed the cap entirely.
            if len(content) > 500:
                content = content[:250] + "\n...[truncated]...\n" + content[-250:]
            tool_name = message.get("tool_name")
            label = f"TOOL:{tool_name}" if tool_name else "TOOL"
            parts.append(f"[{label}]: {content}")
            continue

        if role == "ASSISTANT":
            tool_calls = message.get("tool_calls")
            if isinstance(tool_calls, list):
                names: list[str] = []
                for tool_call in tool_calls:
                    if not isinstance(tool_call, dict):
                        continue
                    # The name is read from the call itself or from a nested
                    # "function" dict; either may be missing or None.
                    name = tool_call.get("name")
                    if not name:
                        function = tool_call.get("function")
                        name = function.get("name") if isinstance(function, dict) else None
                    # join() needs strings, so anything unusable becomes "?".
                    names.append(str(name) if name else "?")
                if names:
                    parts.append(f"[ASSISTANT]: [Called: {', '.join(names)}]")

        parts.append(f"[{role}]: {content}")
    return "\n\n".join(parts)
def _truncate_around_matches(full_text: str, query: str, *, max_chars: int = MAX_SESSION_CHARS) -> str:
    """Cut an over-long transcript down to a window centred on query matches.

    Text that already fits in ``max_chars`` is returned unchanged. Otherwise
    match positions are collected with the first strategy that finds any:

    1. the whole query as a case-insensitive phrase;
    2. for multi-word queries, occurrences of the rarest term that have every
       other term within 200 characters;
    3. every occurrence of any single term.

    The window kept is ``max_chars`` long, starts up to ``max_chars // 4``
    before a match, and is the one containing the most match positions (the
    earliest such window wins ties). If nothing matches at all, the head of
    the text is kept. The aim is to hand the summarizer the part of the
    conversation most relevant to the query rather than a fixed prefix.

    Args:
        full_text: The complete transcript text.
        query: The user's search query; may be blank.
        max_chars: Maximum number of transcript characters to keep (the
            truncation markers are added on top of that).

    Returns:
        The selected slice, prefixed and/or suffixed with a marker line when
        text was dropped before and/or after it.
    """
    if len(full_text) <= max_chars:
        return full_text

    later_marker = "\n\n...[later conversation truncated]..."
    proximity = 200
    text_lower = full_text.lower()
    query_lower = query.lower().strip()
    terms = query_lower.split()

    def find_all(needle: str) -> list[int]:
        # Offsets come out in ascending order, which the bisect calls rely on.
        return [match.start() for match in re.finditer(re.escape(needle), text_lower)]

    def has_neighbor(sorted_hits: list[int], position: int) -> bool:
        # True when some hit lies strictly within `proximity` of `position`.
        index = bisect.bisect_right(sorted_hits, position - proximity)
        return index < len(sorted_hits) and sorted_hits[index] < position + proximity

    # Phase 1: whole phrase. A blank query is skipped because an empty pattern
    # matches at every offset, which would only produce a huge list of
    # positions and the same head-of-text result.
    match_positions: list[int] = find_all(query_lower) if query_lower else []

    # Phase 2: all terms close together, anchored on the rarest term.
    if not match_positions and len(terms) > 1:
        positions = {term: find_all(term) for term in terms}
        rarest = min(terms, key=lambda term: len(positions[term]))
        other_hits = [hits for term, hits in positions.items() if term != rarest]
        match_positions = [
            position
            for position in positions[rarest]
            if all(has_neighbor(hits, position) for hits in other_hits)
        ]

    # Phase 3: any single term.
    if not match_positions:
        for term in terms:
            match_positions.extend(find_all(term))

    if not match_positions:
        return full_text[:max_chars] + later_marker

    # Score one candidate window per match. Counting through bisect on the
    # sorted positions avoids rescanning every position for every candidate.
    ordered = sorted(match_positions)
    lead = max_chars // 4
    last_start = len(full_text) - max_chars  # latest start that still fits
    best_start = 0
    best_count = 0
    for candidate in ordered:
        window_start = min(max(0, candidate - lead), last_start)
        window_end = window_start + max_chars
        count = bisect.bisect_left(ordered, window_end) - bisect.bisect_left(ordered, window_start)
        if count > best_count:
            best_count = count
            best_start = window_start

    start = best_start
    end = min(len(full_text), start + max_chars)
    prefix = "...[earlier conversation truncated]...\n\n" if start > 0 else ""
    suffix = later_marker if end < len(full_text) else ""
    return prefix + full_text[start:end] + suffix
def _resolve_to_parent(db: SessionSearchDB, session_id: str | None) -> str | None:
    """Follow ``parent_session_id`` links upward to the root of a lineage.

    Child sessions are folded back into the main session chain this way, so
    search results do not list several fragments of what is really one
    conversation.

    Args:
        db: Store used to look sessions up by id.
        session_id: Where to start; a falsy value is returned unchanged.

    Returns:
        The last id reached: a session with no parent, an id the store does
        not know, or - when the links form a cycle - the id at which the
        walk returned to an already visited session.
    """
    seen_ids: set[str] = set()
    cursor = session_id
    while cursor:
        if cursor in seen_ids:
            # Cycle in the parent links: stop instead of looping forever.
            return cursor
        seen_ids.add(cursor)
        record = db.get_session(cursor)
        parent_id = record.get("parent_session_id") if record else None
        if not parent_id:
            return cursor
        cursor = parent_id
    return cursor
def _list_recent_sessions(
    db: SessionSearchDB,
    *,
    limit: int,
    current_session_id: str | None = None,
) -> str:
    """Recent mode: list metadata of recent sessions without any summarizer call.

    Asks the store for ``limit + 5`` rows so that some can be filtered out,
    then drops the current session, the root of its lineage, and every
    session that has a parent.

    Args:
        db: Store to list sessions from.
        limit: Maximum number of sessions to report.
        current_session_id: Id of the session the caller is running in, if any.

    Returns:
        A JSON string. On success it has ``success``, ``mode`` (``"recent"``),
        ``results``, ``count`` and ``message``; if the store raises, it has
        ``success: false`` and ``error``, the same shape session_search uses
        for a failed search.
    """
    try:
        sessions = db.list_sessions_rich(
            limit=limit + 5,
            exclude_sources=list(_HIDDEN_SESSION_SOURCES),
        )
        current_root = _resolve_to_parent(db, current_session_id) if current_session_id else None
    except Exception as exc:
        # Report store failures as a tool result instead of letting them
        # escape, matching how the search path handles them.
        logging.error("Session search failed while listing recent sessions: %s", exc, exc_info=True)
        return json.dumps(
            {"success": False, "error": f"Failed to list recent sessions: {exc}"},
            ensure_ascii=False,
        )

    # Ids that belong to the caller's own conversation and must not be listed.
    own_ids = {sid for sid in (current_session_id, current_root) if sid}

    results: list[dict[str, Any]] = []
    for session in sessions or []:
        session_id = session.get("id", "")
        if session_id in own_ids:
            continue
        if session.get("parent_session_id"):
            # Child sessions are skipped; only sessions without a parent are listed.
            continue
        results.append(
            {
                "session_id": session_id,
                "title": session.get("title") or None,
                "source": session.get("source", ""),
                "started_at": session.get("started_at", ""),
                "last_active": session.get("last_active", ""),
                "message_count": session.get("message_count", 0),
                "preview": session.get("preview", ""),
            }
        )
        if len(results) >= limit:
            break

    return json.dumps(
        {
            "success": True,
            "mode": "recent",
            "results": results,
            "count": len(results),
            "message": f"Showing {len(results)} most recent sessions.",
        },
        ensure_ascii=False,
    )
async def session_search(
    *,
    query: str = "",
    role_filter: str | None = None,
    limit: int = 3,
    db: SessionSearchDB | None = None,
    current_session_id: str | None = None,
    summarizer: SessionSummarizer | None = None,
) -> str:
    """Search past sessions and return the outcome as a JSON string.

    Flow:
    1. A blank query switches to recent mode (see _list_recent_sessions).
    2. Otherwise the transcript store is searched.
    3. Hits are grouped by lineage root; the current session's lineage is dropped.
    4. The transcript of each remaining session is loaded.
    5. Each transcript is truncated around the query matches.
    6. With a summarizer, the transcripts are summarized concurrently; without
       one, or when a summary fails or is empty, a raw preview is returned.

    Args:
        query: Search text; blank means "browse recent sessions".
        role_filter: Comma-separated roles to restrict the search to.
        limit: Maximum number of sessions to return; clamped to 1..5, and a
            value that cannot be read as an integer falls back to 3.
        db: Session store; when ``None`` an error result is returned.
        current_session_id: Id of the calling session, excluded from results.
        summarizer: Optional async callable
            ``summarizer(transcript, query, session_meta)``.

    Returns:
        A JSON string with ``success`` plus either ``error`` or the result
        fields (``query``, ``results``, ``count``, ``sessions_searched``).
    """
    if db is None:
        return json.dumps({"success": False, "error": "Session database is not available."}, ensure_ascii=False)

    # Be lenient about the type of `limit` (e.g. "3") instead of raising.
    try:
        limit = int(limit)
    except (TypeError, ValueError, OverflowError):
        limit = 3
    limit = max(1, min(limit, 5))

    search_text = query.strip() if isinstance(query, str) else ""
    if not search_text:
        return _list_recent_sessions(db, limit=limit, current_session_id=current_session_id)

    role_list = [item.strip() for item in (role_filter or "").split(",") if item.strip()] or None

    try:
        raw_results = db.search_messages(
            query=search_text,
            role_filter=role_list,
            exclude_sources=list(_HIDDEN_SESSION_SOURCES),
            limit=50,
            offset=0,
        )
    except Exception as exc:
        logging.error("Session search failed during FTS lookup: %s", exc, exc_info=True)
        return json.dumps({"success": False, "error": f"Search failed: {exc}"}, ensure_ascii=False)

    if not raw_results:
        return json.dumps(
            {
                "success": True,
                "query": search_text,
                "results": [],
                "count": 0,
                "message": "No matching sessions found.",
            },
            ensure_ascii=False,
        )

    # Group hits by lineage root. The lookups go through the store, so a store
    # failure here is reported the same way as a failed search.
    seen_sessions: dict[str, dict[str, Any]] = {}
    try:
        current_root = _resolve_to_parent(db, current_session_id) if current_session_id else None
        for result in raw_results:
            raw_session_id = result.get("session_id")
            if not raw_session_id:
                # A hit without a session id cannot be attributed; skip it.
                continue
            resolved_session_id = _resolve_to_parent(db, raw_session_id) or raw_session_id
            if current_root and resolved_session_id == current_root:
                continue
            if current_session_id and raw_session_id == current_session_id:
                continue
            if resolved_session_id not in seen_sessions:
                # The first hit of a lineage supplies the metadata shown for it.
                entry = dict(result)
                entry["session_id"] = resolved_session_id
                seen_sessions[resolved_session_id] = entry
            if len(seen_sessions) >= limit:
                break
    except Exception as exc:
        logging.error("Session search failed while resolving session lineage: %s", exc, exc_info=True)
        return json.dumps({"success": False, "error": f"Search failed: {exc}"}, ensure_ascii=False)

    # Load and truncate each transcript; one broken session must not sink the rest.
    prepared: list[tuple[str, dict[str, Any], str, dict[str, Any]]] = []
    for session_id, match_info in seen_sessions.items():
        try:
            messages = db.get_messages_as_conversation(session_id)
            if not messages:
                continue
            session_meta = db.get_session(session_id) or {}
            transcript = _truncate_around_matches(_format_conversation(messages), search_text)
            prepared.append((session_id, match_info, transcript, session_meta))
        except Exception as exc:
            logging.warning("Failed to prepare session %s: %s", session_id, exc, exc_info=True)

    summaries: list[Any]
    if summarizer is not None:

        async def summarize_one(transcript: str, session_meta: dict[str, Any]) -> str | None:
            # Calling the summarizer inside a coroutine means that an error
            # raised before it returns an awaitable is also collected by gather.
            return await summarizer(transcript, search_text, session_meta)

        summaries = await asyncio.gather(
            *(summarize_one(transcript, session_meta) for _, _, transcript, session_meta in prepared),
            return_exceptions=True,
        )
    else:
        summaries = [None] * len(prepared)

    results: list[dict[str, Any]] = []
    for (session_id, match_info, transcript, _), summary in zip(prepared, summaries):
        resolved_summary: str | None
        # BaseException, not Exception: gather(return_exceptions=True) can hand
        # back asyncio.CancelledError, which is not an Exception subclass and
        # would otherwise end up in the JSON payload.
        if isinstance(summary, BaseException):
            # Pass the instance itself: there is no active exception here, so
            # exc_info=True would log no traceback.
            logging.warning("Failed to summarize session %s: %s", session_id, summary, exc_info=summary)
            resolved_summary = None
        else:
            resolved_summary = summary
        if not resolved_summary:
            preview = transcript[:500] + ("\n…[truncated]" if len(transcript) > 500 else "")
            resolved_summary = f"[Raw preview — summarization unavailable]\n{preview}"
        results.append(
            {
                "session_id": session_id,
                "when": _format_timestamp(match_info.get("session_started")),
                "source": match_info.get("source", "unknown"),
                "model": match_info.get("model"),
                "summary": resolved_summary,
            }
        )

    return json.dumps(
        {
            "success": True,
            "query": search_text,
            "results": results,
            "count": len(results),
            "sessions_searched": len(seen_sessions),
        },
        ensure_ascii=False,
    )
@dataclass(slots=True)
class SessionSearchTool:
    """Lightweight runtime-facing wrapper around session_search.

    Bundles the session store, the default current session id and the
    optional summarizer with the tool's name, description and parameter
    schema, and exposes a single async ``execute`` entry point.
    """

    # Store handed to session_search on every call.
    db: SessionSearchDB
    # Default for the session to exclude; a per-call keyword argument wins.
    current_session_id: str | None = None
    # Optional async summarizer handed to session_search.
    summarizer: SessionSummarizer | None = None
    name: str = "session_search"
    description: str = SESSION_SEARCH_TOOL_DESCRIPTION
    # Deep copy: the schema contains nested dicts, and a shallow copy would
    # let an edit on one instance leak into the module-level template and
    # every other instance.
    parameters: dict[str, Any] = field(
        default_factory=lambda: copy.deepcopy(SESSION_SEARCH_TOOL_PARAMETERS)
    )

    async def execute(self, **kwargs: Any) -> str:
        """Run session_search with this tool's bound dependencies.

        Args:
            **kwargs: Arguments forwarded to session_search (``query``,
                ``role_filter``, ``limit``). A ``current_session_id`` entry
                that is not ``None`` overrides the instance default.

        Returns:
            The JSON string produced by session_search.
        """
        session_override = kwargs.pop("current_session_id", None)
        if session_override is None:
            session_override = self.current_session_id
        return await session_search(
            db=self.db,
            current_session_id=session_override,
            summarizer=self.summarizer,
            **kwargs,
        )