"""Beaver session 子系统的检索能力。""" from __future__ import annotations import re import sqlite3 from typing import Any from .store import SessionStore class SessionSearchService: """围绕 `SessionStore` 提供 browsing / FTS / lineage 辅助能力。""" def __init__(self, store: SessionStore) -> None: self.store = store @staticmethod def _sanitize_fts5_query(query: str) -> str: quoted_parts: list[str] = [] def preserve(match: re.Match[str]) -> str: quoted_parts.append(match.group(0)) return f"\x00Q{len(quoted_parts) - 1}\x00" sanitized = re.sub(r'"[^"]*"', preserve, query) sanitized = re.sub(r'[+{}()\"^]', " ", sanitized) sanitized = re.sub(r"\*+", "*", sanitized) sanitized = re.sub(r"(^|\s)\*", r"\1", sanitized) sanitized = re.sub(r"(?i)^(AND|OR|NOT)\b\s*", "", sanitized.strip()) sanitized = re.sub(r"(?i)\s+(AND|OR|NOT)\s*$", "", sanitized.strip()) sanitized = re.sub(r"\b(\w+(?:[.-]\w+)+)\b", r'"\1"', sanitized) for index, quoted in enumerate(quoted_parts): sanitized = sanitized.replace(f"\x00Q{index}\x00", quoted) return sanitized.strip() def resolve_session_id(self, session_id_or_prefix: str) -> str | None: """用完整 ID 或唯一前缀解析出目标 session_id。""" exact = self.store.get_session_record(session_id_or_prefix) if exact is not None: return exact.session_id escaped = ( session_id_or_prefix .replace("\\", "\\\\") .replace("%", "\\%") .replace("_", "\\_") ) rows = self.store._fetchall( """ SELECT id FROM sessions WHERE id LIKE ? ESCAPE '\\' ORDER BY started_at DESC LIMIT 2 """, (f"{escaped}%",), ) if len(rows) == 1: return rows[0]["id"] return None def list_sessions_rich( self, *, limit: int = 20, offset: int = 0, include_children: bool = False, source: str | None = None, exclude_sources: list[str] | None = None, exclude_end_reasons: list[str] | None = None, ) -> list[dict[str, Any]]: """列出最近活跃的 session 及其摘要元数据。""" clauses: list[str] = [] params: list[Any] = [] if not include_children: clauses.append("parent_session_id IS NULL") if source: clauses.append("source = ?") params.append(source) if exclude_sources: placeholders = ",".join("?" for _ in exclude_sources) clauses.append(f"source NOT IN ({placeholders})") params.extend(exclude_sources) if exclude_end_reasons: placeholders = ",".join("?" for _ in exclude_end_reasons) clauses.append(f"(end_reason IS NULL OR end_reason NOT IN ({placeholders}))") params.extend(exclude_end_reasons) where = f"WHERE {' AND '.join(clauses)}" if clauses else "" params.extend([limit, offset]) rows = self.store._fetchall( f""" SELECT * FROM sessions {where} ORDER BY last_active DESC LIMIT ? OFFSET ? """, tuple(params), ) return rows def search_messages( self, *, query: str, role_filter: list[str] | None = None, exclude_sources: list[str] | None = None, limit: int = 20, offset: int = 0, ) -> list[dict[str, Any]]: """使用 FTS5 搜索 session transcript。""" query = self._sanitize_fts5_query(query) if not query: return [] clauses = ["messages_fts MATCH ?", "m.context_visible = 1"] params: list[Any] = [query] if exclude_sources: placeholders = ",".join("?" for _ in exclude_sources) clauses.append(f"s.source NOT IN ({placeholders})") params.extend(exclude_sources) if role_filter: placeholders = ",".join("?" for _ in role_filter) clauses.append(f"m.role IN ({placeholders})") params.extend(role_filter) params.extend([limit, offset]) sql = f""" SELECT m.id, m.session_id, m.role, s.source, s.model, s.started_at AS session_started, snippet(messages_fts, 0, '>>>', '<<<', '...', 40) AS snippet FROM messages_fts JOIN messages m ON m.id = messages_fts.rowid JOIN sessions s ON s.id = m.session_id WHERE {' AND '.join(clauses)} ORDER BY rank LIMIT ? OFFSET ? """ try: return self.store._fetchall(sql, tuple(params)) except sqlite3.Error as exc: raise RuntimeError(f"Session transcript search failed for query={query!r}") from exc