152 lines
4.7 KiB
Python
152 lines
4.7 KiB
Python
"""Beaver session 子系统的检索能力。"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import re
|
|
import sqlite3
|
|
from typing import Any
|
|
|
|
from .store import SessionStore
|
|
|
|
|
|
class SessionSearchService:
    """Read-side helpers layered on a `SessionStore`.

    Covers session ID resolution, browsing of the ``sessions`` table and
    FTS5 search over session transcripts. All database access goes through
    the wrapped store.
    """

    def __init__(self, store: SessionStore) -> None:
        """Keep a reference to the store that every query is sent through."""
        self.store = store

    @staticmethod
    def _sanitize_fts5_query(query: str) -> str:
        """Turn free-form user input into a string safe for an FTS5 ``MATCH``.

        The steps run in this order:

        1. Balanced double-quoted phrases are set aside and left untouched.
        2. The remaining characters ``+ { } ( ) ^`` and stray double quotes
           are replaced with spaces.
        3. Runs of ``*`` collapse to a single ``*``; a ``*`` that has no
           term in front of it is dropped.
        4. Dangling ``AND`` / ``OR`` / ``NOT`` operators at the start or the
           end of the query are removed, however many are stacked up.
        5. Dotted or hyphenated terms such as ``chat-send`` are wrapped in
           double quotes.
        6. The phrases set aside in step 1 are put back.

        Args:
            query: Raw search text.

        Returns:
            The sanitized query with surrounding whitespace stripped. The
            result may be empty, which callers should treat as "no query".
        """
        protected: list[str] = []

        def stash_phrase(match: re.Match[str]) -> str:
            # Swap each quoted phrase for a NUL-delimited placeholder so the
            # later substitutions cannot alter it.
            protected.append(match.group(0))
            return f"\x00Q{len(protected) - 1}\x00"

        text = re.sub(r'"[^"]*"', stash_phrase, query)

        # Whatever is left of these characters is unbalanced or unsupported.
        text = re.sub(r'[+{}()\"^]', " ", text)

        # "foo***" -> "foo*"; a leading or free-standing "*" is removed.
        text = re.sub(r"\*+", "*", text)
        text = re.sub(r"(^|\s)\*", r"\1", text)

        # FTS5 operators are case sensitive: only upper-case AND/OR/NOT are
        # operators, so lower-case "and"/"or"/"not" are ordinary search terms
        # and must be kept. The repetition strips stacked operators
        # ("hello AND OR") that would otherwise still be a syntax error.
        text = re.sub(r"^(?:(?:AND|OR|NOT)\b\s*)+", "", text.strip())
        text = re.sub(r"(?:\s+(?:AND|OR|NOT))+\s*$", "", text.strip())

        # "." and "-" are not valid inside an FTS5 bareword, so such terms
        # are quoted as a whole.
        text = re.sub(r"\b(\w+(?:[.-]\w+)+)\b", r'"\1"', text)

        for index, phrase in enumerate(protected):
            text = text.replace(f"\x00Q{index}\x00", phrase)
        return text.strip()

    def resolve_session_id(self, session_id_or_prefix: str) -> str | None:
        """Resolve a full session ID or an unambiguous ID prefix.

        Args:
            session_id_or_prefix: A complete session ID, or the leading part
                of one.

        Returns:
            The matching session ID. ``None`` when the input is empty, when
            nothing matches, or when the prefix matches more than one
            session.
        """
        if not session_id_or_prefix:
            # An empty prefix would become LIKE '%' and match every row, so a
            # database holding a single session would "resolve" to it.
            return None

        exact = self.store.get_session_record(session_id_or_prefix)
        if exact is not None:
            return exact.session_id

        # Escape LIKE metacharacters so the input is matched literally. The
        # backslash must be handled first because it is the escape character.
        escaped = (
            session_id_or_prefix
            .replace("\\", "\\\\")
            .replace("%", "\\%")
            .replace("_", "\\_")
        )
        # LIMIT 2 is enough to tell "unique" apart from "ambiguous".
        rows = self.store._fetchall(
            """
            SELECT id
            FROM sessions
            WHERE id LIKE ? ESCAPE '\\'
            ORDER BY started_at DESC
            LIMIT 2
            """,
            (f"{escaped}%",),
        )
        if len(rows) == 1:
            return rows[0]["id"]
        return None

    def list_sessions_rich(
        self,
        *,
        limit: int = 20,
        offset: int = 0,
        include_children: bool = False,
        source: str | None = None,
        exclude_sources: list[str] | None = None,
    ) -> list[dict[str, Any]]:
        """List rows of the ``sessions`` table, most recently active first.

        Args:
            limit: Maximum number of rows to return.
            offset: Number of rows to skip before the first returned row.
            include_children: When false, only rows whose
                ``parent_session_id`` is NULL are returned.
            source: When set to a non-empty value, keep only rows with this
                ``source``.
            exclude_sources: When non-empty, drop rows whose ``source`` is in
                this list.

        Returns:
            The rows produced by the store, ordered by ``last_active``
            descending.
        """
        clauses: list[str] = []
        params: list[Any] = []

        if not include_children:
            clauses.append("parent_session_id IS NULL")
        if source:
            clauses.append("source = ?")
            params.append(source)
        if exclude_sources:
            placeholders = ",".join("?" for _ in exclude_sources)
            clauses.append(f"source NOT IN ({placeholders})")
            params.extend(exclude_sources)

        # Only placeholder text is interpolated into the SQL; every value is
        # passed as a bound parameter.
        where = f"WHERE {' AND '.join(clauses)}" if clauses else ""
        params.extend([limit, offset])
        return self.store._fetchall(
            f"""
            SELECT *
            FROM sessions
            {where}
            ORDER BY last_active DESC
            LIMIT ? OFFSET ?
            """,
            tuple(params),
        )

    def search_messages(
        self,
        *,
        query: str,
        role_filter: list[str] | None = None,
        exclude_sources: list[str] | None = None,
        limit: int = 20,
        offset: int = 0,
    ) -> list[dict[str, Any]]:
        """Search session transcripts through the ``messages_fts`` FTS5 table.

        The query is sanitized first. Only messages with
        ``context_visible = 1`` are considered, and results are ordered by
        the FTS5 ``rank`` column.

        Args:
            query: Raw search text.
            role_filter: When non-empty, keep only messages whose ``role`` is
                in this list.
            exclude_sources: When non-empty, drop messages whose session
                ``source`` is in this list.
            limit: Maximum number of rows to return.
            offset: Number of rows to skip before the first returned row.

        Returns:
            Rows with ``id``, ``session_id``, ``role``, ``source``,
            ``model``, ``session_started`` and a highlighted ``snippet``.
            An empty list when the sanitized query is empty.

        Raises:
            RuntimeError: If the store raises ``sqlite3.Error`` while running
                the search; the original error is chained as the cause.
        """
        query = self._sanitize_fts5_query(query)
        if not query:
            return []

        clauses = ["messages_fts MATCH ?", "m.context_visible = 1"]
        params: list[Any] = [query]

        if exclude_sources:
            placeholders = ",".join("?" for _ in exclude_sources)
            clauses.append(f"s.source NOT IN ({placeholders})")
            params.extend(exclude_sources)
        if role_filter:
            placeholders = ",".join("?" for _ in role_filter)
            clauses.append(f"m.role IN ({placeholders})")
            params.extend(role_filter)

        params.extend([limit, offset])
        sql = f"""
            SELECT
                m.id,
                m.session_id,
                m.role,
                s.source,
                s.model,
                s.started_at AS session_started,
                snippet(messages_fts, 0, '>>>', '<<<', '...', 40) AS snippet
            FROM messages_fts
            JOIN messages m ON m.id = messages_fts.rowid
            JOIN sessions s ON s.id = m.session_id
            WHERE {' AND '.join(clauses)}
            ORDER BY rank
            LIMIT ? OFFSET ?
        """

        try:
            return self.store._fetchall(sql, tuple(params))
        except sqlite3.Error as exc:
            raise RuntimeError(f"Session transcript search failed for query={query!r}") from exc
|