# beaver_project/app-instance/backend/beaver/engine/session/search.py
# (file-viewer header captured with this listing: 152 lines, 4.7 KiB, Python)
"""Beaver session 子系统的检索能力。"""
from __future__ import annotations
import re
import sqlite3
from typing import Any
from .store import SessionStore
class SessionSearchService:
    """Read-only query helpers layered on top of a ``SessionStore``.

    Provides session browsing (:meth:`list_sessions_rich`), full-ID /
    unique-prefix resolution (:meth:`resolve_session_id`) and FTS5 transcript
    search (:meth:`search_messages`).

    NOTE: all SQL goes through ``store.get_session_record`` and the store's
    private ``_fetchall`` helper.
    """

    def __init__(self, store: SessionStore) -> None:
        self.store = store

    @staticmethod
    def _sanitize_fts5_query(query: str) -> str:
        """Rewrite free-form user input into a query FTS5 ``MATCH`` can parse.

        FTS5 barewords may contain only ASCII letters, digits and ``_`` (plus
        non-ASCII characters). Any other ASCII punctuation is either query
        syntax or a parse error. The rewrite therefore:

        1. keeps balanced ``"quoted phrases"`` verbatim;
        2. wraps dotted / hyphenated terms (``chat-send``, ``P2.2``,
           ``my-app.config.ts``) in double quotes, so they are searched as a
           phrase instead of breaking the parser;
        3. replaces every other ASCII punctuation mark except ``*`` and ``_``
           with a space;
        4. collapses ``**`` runs and drops a ``*`` that starts a term (a
           prefix query needs at least one character before the ``*``);
        5. drops runs of ``AND`` / ``OR`` / ``NOT`` (case-insensitive) at the
           start or end of the query.

        Returns the rewritten query, or ``""`` when nothing searchable is left.
        """
        protected: list[str] = []

        def protect(text: str) -> str:
            # Park already-safe text behind a NUL-delimited placeholder so the
            # punctuation passes below cannot modify it.
            protected.append(text)
            return f"\x00Q{len(protected) - 1}\x00"

        # NUL delimits placeholders; remove it from user input so it cannot
        # forge one.
        sanitized = query.replace("\x00", " ")
        # Step 1: balanced quoted phrases are already valid FTS5 strings.
        sanitized = re.sub(r'"[^"]*"', lambda m: protect(m.group(0)), sanitized)
        # Step 2: quote dotted / hyphenated terms before '.' and '-' are
        # stripped in step 3.
        sanitized = re.sub(
            r"\b(\w+(?:[.-]\w+)+)\b",
            lambda m: protect(f'"{m.group(1)}"'),
            sanitized,
        )
        # Step 3: everything here is FTS5 syntax ("+(){}^:,-) or not allowed
        # in a bareword at all (' ? ! / ...); unmatched '"' lands here too.
        sanitized = re.sub(
            r"""[!"#$%&'()+,\-./:;<=>?@\[\\\]^`{|}~]""",
            " ",
            sanitized,
        )
        # Step 4: prefix-star cleanup.
        sanitized = re.sub(r"\*+", "*", sanitized)
        sanitized = re.sub(r"(^|\s)\*", r"\1", sanitized)
        # Step 5: runs after step 3 so that e.g. "foo AND ?" is trimmed too;
        # a whole run ("AND OR foo") is removed, not just one operator.
        sanitized = re.sub(r"(?i)^(?:(?:AND|OR|NOT)\b\s*)+", "", sanitized.strip())
        sanitized = re.sub(r"(?i)(?:\s+(?:AND|OR|NOT))+\s*$", "", sanitized.strip())
        for index, text in enumerate(protected):
            sanitized = sanitized.replace(f"\x00Q{index}\x00", text)
        return sanitized.strip()

    def resolve_session_id(self, session_id_or_prefix: str) -> str | None:
        """Resolve a full session ID, or a unique ID prefix, to a session_id.

        An exact hit from ``store.get_session_record`` wins. Otherwise the
        input is used as a literal prefix (``%``, ``_`` and backslash are
        escaped for ``LIKE``) and resolves only when exactly one session ID
        starts with it.

        NOTE: SQLite ``LIKE`` is ASCII case-insensitive by default, so the
        prefix match is too unless the connection enables
        ``case_sensitive_like``.

        Returns:
            The matching session_id, or ``None`` for empty input, no match,
            or an ambiguous prefix.
        """
        if not session_id_or_prefix:
            # An empty prefix would become ``LIKE '%'`` and match every row,
            # so a store holding a single session would "resolve" "" to it.
            return None
        exact = self.store.get_session_record(session_id_or_prefix)
        if exact is not None:
            return exact.session_id
        # Escape the backslash first so the escapes added for % and _ are
        # not themselves doubled.
        escaped = (
            session_id_or_prefix.replace("\\", "\\\\")
            .replace("%", "\\%")
            .replace("_", "\\_")
        )
        # Two rows are enough to tell "unique" from "ambiguous".
        rows = self.store._fetchall(
            """
            SELECT id
            FROM sessions
            WHERE id LIKE ? ESCAPE '\\'
            ORDER BY started_at DESC
            LIMIT 2
            """,
            (f"{escaped}%",),
        )
        if len(rows) == 1:
            return rows[0]["id"]
        return None

    def list_sessions_rich(
        self,
        *,
        limit: int = 20,
        offset: int = 0,
        include_children: bool = False,
        source: str | None = None,
        exclude_sources: list[str] | None = None,
    ) -> list[dict[str, Any]]:
        """Page through ``sessions`` rows, most recently active first.

        Args:
            limit: Maximum number of rows to return.
            offset: Number of rows to skip, for paging.
            include_children: When false, only rows whose
                ``parent_session_id`` is NULL are returned.
            source: If given, keep only sessions with this ``source``.
            exclude_sources: If non-empty, drop sessions whose ``source`` is
                listed. Under SQL ``NOT IN`` semantics, rows whose ``source``
                is NULL are dropped as well.

        Returns:
            All columns of each matching ``sessions`` row, as produced by
            ``store._fetchall``, ordered by ``last_active`` descending.
        """
        clauses: list[str] = []
        params: list[Any] = []
        if not include_children:
            clauses.append("parent_session_id IS NULL")
        if source:
            clauses.append("source = ?")
            params.append(source)
        if exclude_sources:
            placeholders = ",".join("?" for _ in exclude_sources)
            clauses.append(f"source NOT IN ({placeholders})")
            params.extend(exclude_sources)
        where = f"WHERE {' AND '.join(clauses)}" if clauses else ""
        params.extend([limit, offset])
        return self.store._fetchall(
            f"""
            SELECT *
            FROM sessions
            {where}
            ORDER BY last_active DESC
            LIMIT ? OFFSET ?
            """,
            tuple(params),
        )

    def search_messages(
        self,
        *,
        query: str,
        role_filter: list[str] | None = None,
        exclude_sources: list[str] | None = None,
        limit: int = 20,
        offset: int = 0,
    ) -> list[dict[str, Any]]:
        """Full-text search session transcripts through ``messages_fts``.

        ``query`` is passed through :meth:`_sanitize_fts5_query` first. If
        nothing searchable remains, ``[]`` is returned without touching the
        database. Only messages with ``context_visible = 1`` are considered.

        Args:
            query: Free-form user search text.
            role_filter: If non-empty, keep only messages whose ``role`` is
                listed.
            exclude_sources: If non-empty, drop hits from sessions whose
                ``source`` is listed.
            limit: Maximum number of hits to return.
            offset: Number of hits to skip, for paging.

        Returns:
            One row per hit with ``id``, ``session_id``, ``role``, ``source``,
            ``model``, ``session_started`` and ``snippet``. The snippet is an
            FTS5 ``snippet()`` of column 0 with matches wrapped in ``>>>`` /
            ``<<<``, up to 40 tokens. Rows are ordered by FTS5 ``rank``, best
            match first.

        Raises:
            RuntimeError: The search SQL failed; the underlying
                ``sqlite3.Error`` is chained as ``__cause__``.
        """
        match_expr = self._sanitize_fts5_query(query)
        if not match_expr:
            return []
        clauses = ["messages_fts MATCH ?", "m.context_visible = 1"]
        params: list[Any] = [match_expr]
        if exclude_sources:
            placeholders = ",".join("?" for _ in exclude_sources)
            clauses.append(f"s.source NOT IN ({placeholders})")
            params.extend(exclude_sources)
        if role_filter:
            placeholders = ",".join("?" for _ in role_filter)
            clauses.append(f"m.role IN ({placeholders})")
            params.extend(role_filter)
        params.extend([limit, offset])
        sql = f"""
            SELECT
                m.id,
                m.session_id,
                m.role,
                s.source,
                s.model,
                s.started_at AS session_started,
                snippet(messages_fts, 0, '>>>', '<<<', '...', 40) AS snippet
            FROM messages_fts
            JOIN messages m ON m.id = messages_fts.rowid
            JOIN sessions s ON s.id = m.session_id
            WHERE {' AND '.join(clauses)}
            ORDER BY rank
            LIMIT ? OFFSET ?
        """
        try:
            return self.store._fetchall(sql, tuple(params))
        except sqlite3.Error as exc:
            raise RuntimeError(
                f"Session transcript search failed for query={match_expr!r}"
            ) from exc