419 lines
15 KiB
Python
419 lines
15 KiB
Python
"""Beaver 内置 session_search tool。
|
||
|
||
这个工具对应 Hermes-agent 的跨会话检索能力,目标不是把所有历史内容塞回主上下文,
|
||
而是按需从过去的 session 中找回“之前发生过什么”。
|
||
|
||
当前实现保留了几个关键行为:
|
||
1. query 为空时进入 recent/browse 模式,只列最近会话,不走 LLM,总成本很低
|
||
2. query 不为空时走 transcript DB 的搜索接口,预期底层是 FTS 风格检索
|
||
3. 自动排除当前 session lineage,避免把当前上下文又搜出来一遍
|
||
4. 对长会话做 match-centered truncation,而不是无脑截前 N 字符
|
||
5. summarizer 是可选依赖;没有时降级返回 raw preview,而不是整条工具失败
|
||
"""
|
||
|
||
from __future__ import annotations

import asyncio
import copy
import json
import logging
import re
from bisect import bisect_left, bisect_right
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Awaitable, Callable, Protocol
|
||
|
||
# Default character budget for a single formatted transcript; used as the
# default ``max_chars`` of ``_truncate_around_matches``.
MAX_SESSION_CHARS = 100_000
|
||
|
||
|
||
class SessionSearchDB(Protocol):
|
||
"""session_search 依赖的最小数据库契约。
|
||
|
||
这里没有直接绑定某个具体 SQLite 实现,而是先定义行为接口。
|
||
这样后面无论你接的是 Hermes 风格 state DB、还是 Beaver 自己的 transcript store,
|
||
只要满足这些方法就能工作。
|
||
"""
|
||
|
||
def list_sessions_rich(
|
||
self,
|
||
*,
|
||
limit: int,
|
||
exclude_sources: list[str] | None = None,
|
||
) -> list[dict[str, Any]]: ...
|
||
|
||
def get_session(self, session_id: str) -> dict[str, Any] | None: ...
|
||
|
||
def get_messages_as_conversation(self, session_id: str) -> list[dict[str, Any]]: ...
|
||
|
||
def search_messages(
|
||
self,
|
||
*,
|
||
query: str,
|
||
role_filter: list[str] | None = None,
|
||
exclude_sources: list[str] | None = None,
|
||
limit: int,
|
||
offset: int = 0,
|
||
) -> list[dict[str, Any]]: ...
|
||
|
||
|
||
# Signature of the optional async summarizer hook. session_search awaits it as
# summarizer(transcript, query, session_meta) and falls back to a raw preview
# when the result is falsy.
SessionSummarizer = Callable[[str, str, dict[str, Any]], Awaitable[str | None]]

# Passed as ``exclude_sources`` to the DB listing and search calls.
_HIDDEN_SESSION_SOURCES = ("tool",)

# Human/LLM-facing description of the tool.
SESSION_SEARCH_TOOL_DESCRIPTION = (
    "Search prior sessions for historical context, or browse recent sessions when "
    "query is omitted. Use this when the user references past work, prior fixes, "
    "or earlier decisions instead of asking them to repeat themselves."
)

# JSON-Schema-style description of the tool arguments. The ``limit`` bounds
# (1..5) match the clamp applied inside session_search.
SESSION_SEARCH_TOOL_PARAMETERS: dict[str, Any] = {
    "type": "object",
    "properties": {
        "query": {
            "type": "string",
            "description": "Keyword, phrase, or boolean FTS query. Omit to browse recent sessions.",
        },
        "role_filter": {
            "type": "string",
            "description": "Optional comma-separated roles to search, for example 'user,assistant'.",
        },
        "limit": {
            "type": "integer",
            "default": 3,
            "minimum": 1,
            "maximum": 5,
            "description": "Maximum number of sessions to return.",
        },
    },
    "required": [],
}
|
||
|
||
|
||
def _format_timestamp(value: int | float | str | None) -> str:
|
||
"""把时间戳或字符串格式化成更可读的展示文本。"""
|
||
if value is None:
|
||
return "unknown"
|
||
try:
|
||
if isinstance(value, (int, float)):
|
||
return datetime.fromtimestamp(value).strftime("%B %d, %Y at %I:%M %p")
|
||
if isinstance(value, str):
|
||
if value.replace(".", "").replace("-", "").isdigit():
|
||
return datetime.fromtimestamp(float(value)).strftime("%B %d, %Y at %I:%M %p")
|
||
return value
|
||
except (OSError, OverflowError, ValueError):
|
||
pass
|
||
return str(value)
|
||
|
||
|
||
def _message_text(raw_content: Any) -> str:
    """Return message content as text; non-string content is JSON-encoded."""
    if isinstance(raw_content, str):
        return raw_content
    return json.dumps(raw_content, ensure_ascii=False, default=str)


def _tool_call_name(tool_call: dict[str, Any]) -> str:
    """Return the tool name from a flat or ``function``-nested call, or ``"?"``."""
    name = tool_call.get("name")
    if not name:
        function = tool_call.get("function")
        if isinstance(function, dict):
            name = function.get("name")
    return str(name) if name else "?"


def _format_conversation(messages: list[dict[str, Any]]) -> str:
    """Flatten a message list into transcript text for the summarizer.

    Each message becomes ``[ROLE]: content``; entries are joined by blank
    lines. The output keeps:

    - the role (upper-cased),
    - the names of an assistant's tool calls, emitted as an extra
      ``[ASSISTANT]: [Called: ...]`` line,
    - a shortened form of tool output.

    Tool output longer than 500 characters is reduced to its first and last
    250 characters so a single tool result cannot dominate the summarization
    budget. This applies to every ``tool`` message, whether or not it carries a
    ``tool_name``.

    Args:
        messages: Message dicts; the keys ``role``, ``content``, ``tool_name``
            and ``tool_calls`` are read when present.

    Returns:
        The formatted transcript, or ``""`` for an empty list.
    """
    parts: list[str] = []
    for message in messages:
        role = str(message.get("role", "unknown")).upper()
        # Non-string content (e.g. a list of parts) is serialized so that the
        # slicing/concatenation below always operates on str.
        content = _message_text(message.get("content") or "")

        if role == "TOOL":
            if len(content) > 500:
                content = content[:250] + "\n...[truncated]...\n" + content[-250:]
            tool_name = message.get("tool_name")
            label = f"TOOL:{tool_name}" if tool_name else "TOOL"
            parts.append(f"[{label}]: {content}")
            continue

        if role == "ASSISTANT":
            tool_calls = message.get("tool_calls")
            if isinstance(tool_calls, list):
                names = [_tool_call_name(call) for call in tool_calls if isinstance(call, dict)]
                if names:
                    parts.append(f"[ASSISTANT]: [Called: {', '.join(names)}]")

        parts.append(f"[{role}]: {content}")

    return "\n\n".join(parts)
|
||
|
||
|
||
def _truncate_around_matches(full_text: str, query: str, *, max_chars: int = MAX_SESSION_CHARS) -> str:
    """Cut ``full_text`` down to ``max_chars`` around the query matches.

    Instead of always keeping the head of the text, the window is placed where
    the query is most densely matched, so the most relevant part of the
    conversation reaches the summarizer. Match positions are collected with
    this priority:

    1. the whole query as a phrase,
    2. otherwise, positions of the rarest term that have every other term
       within 200 characters,
    3. otherwise, every occurrence of every individual term.

    Matching is case-insensitive and runs on ``full_text`` itself, so the
    offsets always index the text that is sliced (lower-casing a copy can
    change its length for some Unicode characters).

    Args:
        full_text: The formatted transcript.
        query: The user query; surrounding whitespace is ignored.
        max_chars: Maximum number of transcript characters to keep (the
            truncation markers are added on top of this budget).

    Returns:
        ``full_text`` unchanged when it fits, otherwise a slice of at most
        ``max_chars`` characters with ``...[earlier/later conversation
        truncated]...`` markers on the trimmed side(s). With no match at all
        (including an empty query) the head of the text is kept.
    """
    if len(full_text) <= max_chars:
        return full_text

    def find_all(needle: str) -> list[int]:
        # finditer yields matches left to right, so the list is sorted.
        return [m.start() for m in re.finditer(re.escape(needle), full_text, re.IGNORECASE)]

    phrase = query.strip()
    terms = phrase.split()

    # An empty pattern would match at every offset; skip straight to the
    # head-truncation fallback instead.
    match_positions: list[int] = find_all(phrase) if phrase else []

    if not match_positions and len(terms) > 1:
        proximity = 200
        positions = {term: find_all(term) for term in terms}
        rarest = min(terms, key=lambda term: len(positions[term]))
        others = [positions[term] for term in terms if term != rarest]

        def has_neighbor(sorted_hits: list[int], anchor: int) -> bool:
            # Smallest hit strictly greater than anchor - proximity, if any,
            # must also be strictly below anchor + proximity.
            index = bisect_right(sorted_hits, anchor - proximity)
            return index < len(sorted_hits) and sorted_hits[index] < anchor + proximity

        match_positions = [
            anchor
            for anchor in positions[rarest]
            if all(has_neighbor(hits, anchor) for hits in others)
        ]

    if not match_positions:
        for term in terms:
            match_positions.extend(find_all(term))

    if not match_positions:
        # The text is known to be longer than max_chars at this point.
        return full_text[:max_chars] + "\n\n...[later conversation truncated]..."

    ordered = sorted(match_positions)
    text_length = len(full_text)
    best_start = 0
    best_count = 0
    for candidate in ordered:
        # Keep a quarter of the budget as leading context before the match.
        window_start = max(0, candidate - max_chars // 4)
        window_end = window_start + max_chars
        if window_end > text_length:
            window_start = max(0, text_length - max_chars)
            window_end = text_length
        # Number of matches with window_start <= position < window_end.
        count = bisect_left(ordered, window_end) - bisect_left(ordered, window_start)
        if count > best_count:
            best_count = count
            best_start = window_start

    start = best_start
    end = min(text_length, start + max_chars)
    prefix = "...[earlier conversation truncated]...\n\n" if start > 0 else ""
    suffix = "\n\n...[later conversation truncated]..." if end < text_length else ""
    return prefix + full_text[start:end] + suffix
|
||
|
||
|
||
def _resolve_to_parent(db: SessionSearchDB, session_id: str | None) -> str | None:
    """Follow ``parent_session_id`` links upward and return the lineage root.

    Child sessions are folded back into the main session they descend from, so
    that search results do not show several fragments of the same
    conversation.

    The walk stops at the first id that has no stored record, has no parent,
    or was already visited (which guards against cyclic parent links). A falsy
    ``session_id`` is returned unchanged.
    """
    seen_ids: set[str] = set()
    node = session_id
    while node and node not in seen_ids:
        seen_ids.add(node)
        record = db.get_session(node)
        parent_id = record.get("parent_session_id") if record else None
        if not parent_id:
            break
        node = parent_id
    return node
|
||
|
||
|
||
def _list_recent_sessions(
    db: SessionSearchDB,
    *,
    limit: int,
    current_session_id: str | None = None,
) -> str:
    """Recent mode: list metadata of recent sessions without any summarization.

    Requests ``limit + 5`` sessions from the DB to leave room for the entries
    that are filtered out, then skips the current session, the root of the
    current session's lineage, and every session that has a parent.

    Args:
        db: Session store to read from.
        limit: Maximum number of sessions to include in the result.
        current_session_id: Session to exclude (together with its lineage
            root), if any.

    Returns:
        A JSON string. On success it holds ``success``, ``mode`` (``"recent"``),
        ``results``, ``count`` and ``message``. If the DB calls raise, it holds
        ``success: false`` and an ``error`` message, matching the failure shape
        used by the search path of ``session_search``.
    """
    try:
        sessions = db.list_sessions_rich(
            limit=limit + 5,
            exclude_sources=list(_HIDDEN_SESSION_SOURCES),
        )
        current_root = _resolve_to_parent(db, current_session_id) if current_session_id else None
    except Exception as exc:
        # Tool boundary: report the failure as data instead of raising.
        logging.error("Session search failed while listing recent sessions: %s", exc, exc_info=True)
        return json.dumps(
            {"success": False, "error": f"Failed to list recent sessions: {exc}"},
            ensure_ascii=False,
        )

    results: list[dict[str, Any]] = []
    for session in sessions or []:
        session_id = session.get("id", "")
        if current_root and session_id == current_root:
            continue
        if current_session_id and session_id == current_session_id:
            continue
        if session.get("parent_session_id"):
            # Child sessions are represented by their root.
            continue
        results.append(
            {
                "session_id": session_id,
                "title": session.get("title") or None,
                "source": session.get("source", ""),
                "started_at": session.get("started_at", ""),
                "last_active": session.get("last_active", ""),
                "message_count": session.get("message_count", 0),
                "preview": session.get("preview", ""),
            }
        )
        if len(results) >= limit:
            break

    return json.dumps(
        {
            "success": True,
            "mode": "recent",
            "results": results,
            "count": len(results),
            "message": f"Showing {len(results)} most recent sessions.",
        },
        ensure_ascii=False,
    )
|
||
|
||
|
||
async def session_search(
    *,
    query: str = "",
    role_filter: str | None = None,
    limit: int = 3,
    db: SessionSearchDB | None = None,
    current_session_id: str | None = None,
    summarizer: SessionSummarizer | None = None,
) -> str:
    """Search past sessions and return the result as a JSON string.

    Flow:
    1. Empty query -> recent mode (metadata listing only).
    2. Non-empty query -> search through the transcript DB.
    3. Drop hits that belong to the current session lineage and collapse the
       rest to one entry per lineage root.
    4. Load the transcript of each remaining session.
    5. Apply match-centered truncation to each transcript.
    6. If a summarizer is given, summarize concurrently; otherwise, or when a
       summary fails or is empty, fall back to a raw preview.

    Args:
        query: Search text; blank means "browse recent sessions".
        role_filter: Optional comma-separated list of roles to search.
        limit: Maximum number of sessions to return; clamped to 1..5. A value
            that cannot be converted to ``int`` falls back to 3.
        db: Session store. When ``None`` an error payload is returned.
        current_session_id: Session whose lineage is excluded from results.
        summarizer: Optional async callable
            ``(transcript, query, session_meta) -> str | None``.

    Returns:
        A JSON string. Failures are reported as
        ``{"success": false, "error": ...}`` rather than raised.
    """
    if db is None:
        return json.dumps({"success": False, "error": "Session database is not available."}, ensure_ascii=False)

    try:
        limit = int(limit)
    except (TypeError, ValueError, OverflowError):
        limit = 3
    limit = max(1, min(limit, 5))

    cleaned_query = query.strip() if query else ""
    if not cleaned_query:
        return _list_recent_sessions(db, limit=limit, current_session_id=current_session_id)

    role_list = [item.strip() for item in (role_filter or "").split(",") if item.strip()] or None
    try:
        raw_results = db.search_messages(
            query=cleaned_query,
            role_filter=role_list,
            exclude_sources=list(_HIDDEN_SESSION_SOURCES),
            limit=50,
            offset=0,
        )
    except Exception as exc:
        logging.error("Session search failed during FTS lookup: %s", exc, exc_info=True)
        return json.dumps({"success": False, "error": f"Search failed: {exc}"}, ensure_ascii=False)

    if not raw_results:
        return json.dumps(
            {
                "success": True,
                "query": cleaned_query,
                "results": [],
                "count": 0,
                "message": "No matching sessions found.",
            },
            ensure_ascii=False,
        )

    # Lineage resolution hits the DB as well, so it gets the same failure
    # handling as the search call above.
    seen_sessions: dict[str, dict[str, Any]] = {}
    try:
        current_root = _resolve_to_parent(db, current_session_id) if current_session_id else None
        for result in raw_results:
            raw_session_id = result.get("session_id")
            if not raw_session_id:
                # A hit without a session id cannot be attributed; skip it.
                continue
            resolved_session_id = _resolve_to_parent(db, raw_session_id) or raw_session_id
            if current_root and resolved_session_id == current_root:
                continue
            if current_session_id and raw_session_id == current_session_id:
                continue
            if resolved_session_id not in seen_sessions:
                # Keep the first hit's metadata, re-keyed to the lineage root.
                entry = dict(result)
                entry["session_id"] = resolved_session_id
                seen_sessions[resolved_session_id] = entry
            if len(seen_sessions) >= limit:
                break
    except Exception as exc:
        logging.error("Session search failed during lineage resolution: %s", exc, exc_info=True)
        return json.dumps({"success": False, "error": f"Search failed: {exc}"}, ensure_ascii=False)

    prepared: list[tuple[str, dict[str, Any], str, dict[str, Any]]] = []
    for session_id, match_info in seen_sessions.items():
        try:
            messages = db.get_messages_as_conversation(session_id)
            if not messages:
                continue
            session_meta = db.get_session(session_id) or {}
            transcript = _truncate_around_matches(_format_conversation(messages), cleaned_query)
            prepared.append((session_id, match_info, transcript, session_meta))
        except Exception as exc:
            # One unreadable session must not sink the whole search.
            logging.warning("Failed to prepare session %s: %s", session_id, exc, exc_info=True)

    if summarizer is not None:
        summaries = await asyncio.gather(
            *(summarizer(transcript, cleaned_query, session_meta) for _, _, transcript, session_meta in prepared),
            return_exceptions=True,
        )
    else:
        summaries = [None] * len(prepared)

    results: list[dict[str, Any]] = []
    for (session_id, match_info, transcript, _), summary in zip(prepared, summaries):
        resolved_summary: str | None
        # gather(return_exceptions=True) can also hand back CancelledError,
        # which derives from BaseException rather than Exception.
        if isinstance(summary, BaseException):
            # There is no active exception here, so pass the instance itself
            # for the traceback instead of exc_info=True.
            logging.warning("Failed to summarize session %s: %s", session_id, summary, exc_info=summary)
            resolved_summary = None
        else:
            resolved_summary = summary

        if not resolved_summary:
            preview = transcript[:500] + ("\n…[truncated]" if len(transcript) > 500 else "")
            resolved_summary = f"[Raw preview — summarization unavailable]\n{preview}"

        results.append(
            {
                "session_id": session_id,
                "when": _format_timestamp(match_info.get("session_started")),
                "source": match_info.get("source", "unknown"),
                "model": match_info.get("model"),
                "summary": resolved_summary,
            }
        )

    return json.dumps(
        {
            "success": True,
            "query": cleaned_query,
            "results": results,
            "count": len(results),
            "sessions_searched": len(seen_sessions),
        },
        ensure_ascii=False,
    )
|
||
|
||
|
||
@dataclass(slots=True)
class SessionSearchTool:
    """Lightweight runtime-facing wrapper around ``session_search``."""

    # Session store handed to session_search on every call.
    db: SessionSearchDB
    # Default session whose lineage is excluded; may be overridden per call.
    current_session_id: str | None = None
    # Optional async summarizer forwarded to session_search.
    summarizer: SessionSummarizer | None = None
    # Tool metadata.
    name: str = "session_search"
    description: str = SESSION_SEARCH_TOOL_DESCRIPTION
    # Deep copy: the schema contains nested dicts, and a shallow copy would let
    # one instance's edits leak into the module constant and other instances.
    parameters: dict[str, Any] = field(
        default_factory=lambda: copy.deepcopy(SESSION_SEARCH_TOOL_PARAMETERS)
    )

    async def execute(self, **kwargs: Any) -> str:
        """Run ``session_search`` with this tool's DB and summarizer.

        Args:
            **kwargs: Arguments forwarded to ``session_search`` (``query``,
                ``role_filter``, ``limit``). A non-``None``
                ``current_session_id`` overrides the instance default.

        Returns:
            The JSON string produced by ``session_search``.
        """
        current_session_id = kwargs.pop("current_session_id", None)
        return await session_search(
            db=self.db,
            current_session_id=current_session_id if current_session_id is not None else self.current_session_id,
            summarizer=self.summarizer,
            **kwargs,
        )
|