257 lines
9.9 KiB
Python
257 lines
9.9 KiB
Python
"""会话管理模块:负责对话历史的内存缓存与磁盘持久化。

存储设计:
1. 每个会话对应一个 `.jsonl` 文件(按行 JSON);
2. 第 1 行固定为 `_type=metadata` 的会话元数据;
3. 后续每行是按时间追加的消息对象(内存中 append-only;
   保存时由 `SessionManager.save()` 整文件重写)。

这样做的目的:
- 读写简单,便于排查;
- 追加友好,降低频繁改写历史内容的复杂度;
- 与记忆归档机制配合:归档只追加到 MEMORY/HISTORY,不回写旧消息。
"""
|
||
|
||
import json
|
||
import shutil
|
||
from pathlib import Path
|
||
from dataclasses import dataclass, field
|
||
from datetime import datetime
|
||
from typing import Any
|
||
|
||
from loguru import logger
|
||
|
||
from nanobot.utils.helpers import ensure_dir, safe_filename
|
||
|
||
|
||
@dataclass
class Session:
    """
    A single conversation session (typically one per ``channel:chat_id``).

    Messages grow append-only:

    - ``messages`` holds the full message sequence;
    - ``last_consolidated`` counts how many leading messages have already been
      archived into the memory files;
    - ``get_history()`` returns only the unarchived tail, for building the
      model context.

    Note:
        Memory consolidation never rewrites message content; it only advances
        the ``last_consolidated`` cursor.
    """

    key: str  # Unique session key, usually of the form `channel:chat_id`.
    messages: list[dict[str, Any]] = field(default_factory=list)
    created_at: datetime = field(default_factory=datetime.now)
    updated_at: datetime = field(default_factory=datetime.now)
    metadata: dict[str, Any] = field(default_factory=dict)
    # Number of messages already archived to MEMORY/HISTORY (a count, not a timestamp).
    last_consolidated: int = 0

    def add_message(self, role: str, content: str, **kwargs: Any) -> None:
        """Append a message to the session and bump ``updated_at``.

        Args:
            role: Message role (e.g. ``"user"`` or ``"tool"``).
            content: Message text.
            **kwargs: Extra fields stored on the message as-is, such as
                ``tool_calls`` / ``tool_call_id`` / ``name``.
        """
        msg = {
            "role": role,
            "content": content,
            "timestamp": datetime.now().isoformat(),
            **kwargs,
        }
        self.messages.append(msg)
        self.updated_at = datetime.now()

    def get_history(self, max_messages: int = 500) -> list[dict[str, Any]]:
        """Return the history to feed to the LLM (unarchived messages only).

        Args:
            max_messages: Maximum number of most-recent unarchived messages to
                consider. A value ``<= 0`` yields an empty history.

        Returns:
            A list of message dicts containing only ``role``, ``content`` and,
            when present, ``tool_calls`` / ``tool_call_id`` / ``name``.
        """
        # `xs[-0:]` is the whole list, so a non-positive limit must be handled
        # explicitly; otherwise `max_messages=0` would return everything.
        if max_messages <= 0:
            return []

        # Only the unarchived range, so the model does not re-read information
        # that has already been distilled into the memory files.
        unconsolidated = self.messages[self.last_consolidated:]
        # Then keep the most recent window to bound the context size.
        sliced = unconsolidated[-max_messages:]

        # Drop leading non-user messages so the window does not start with an
        # orphaned tool result. If the window has no user message at all, it
        # is kept unchanged.
        for i, m in enumerate(sliced):
            if m.get("role") == "user":
                sliced = sliced[i:]
                break

        # Emit only the fields the LLM needs; storage-only fields such as
        # `timestamp` are left out of the model context.
        out: list[dict[str, Any]] = []
        for m in sliced:
            entry: dict[str, Any] = {"role": m["role"], "content": m.get("content", "")}
            for k in ("tool_calls", "tool_call_id", "name"):
                if k in m:
                    entry[k] = m[k]
            out.append(entry)
        return out

    def clear(self) -> None:
        """Remove all messages and reset the consolidation cursor."""
        self.messages = []
        self.last_consolidated = 0
        self.updated_at = datetime.now()
|
||
|
||
|
||
class SessionManager:
    """
    Session manager: loads, saves, caches and enumerates sessions.

    Directory layout:

    - current location: ``<workspace>/sessions/*.jsonl``
    - legacy location: ``~/.nanobot/sessions/*.jsonl`` (migrated lazily on load)
    """

    def __init__(self, workspace: Path):
        self.workspace = workspace
        # Session file directory; created if it does not exist.
        self.sessions_dir = ensure_dir(self.workspace / "sessions")
        # Legacy global directory, used only for lazy migration.
        self.legacy_sessions_dir = Path.home() / ".nanobot" / "sessions"
        # In-process cache: key -> Session, to avoid re-reading from disk.
        self._cache: dict[str, Session] = {}

    def _get_session_path(self, key: str) -> Path:
        """Return the session file path under the workspace."""
        # Map `:` to `_`, then make the name filesystem-safe.
        safe_key = safe_filename(key.replace(":", "_"))
        return self.sessions_dir / f"{safe_key}.jsonl"

    def _get_legacy_session_path(self, key: str) -> Path:
        """Return the legacy global session file path (~/.nanobot/sessions/)."""
        safe_key = safe_filename(key.replace(":", "_"))
        return self.legacy_sessions_dir / f"{safe_key}.jsonl"

    def get_or_create(self, key: str) -> Session:
        """Return the session for ``key``, loading it from disk or creating it."""
        # Serve from the in-memory cache first to avoid repeated file reads.
        if key in self._cache:
            return self._cache[key]

        # Cache miss: try disk, and fall back to a fresh session.
        session = self._load(key)
        if session is None:
            session = Session(key=key)

        self._cache[key] = session
        return session

    def _load(self, key: str) -> Session | None:
        """Load a session from disk.

        If no file exists under the workspace, a file at the legacy location
        is moved over first.

        Returns:
            The loaded Session, or None if no file exists or parsing fails.
        """
        path = self._get_session_path(key)
        if not path.exists():
            # Nothing at the new location: try migrating from the legacy one.
            legacy_path = self._get_legacy_session_path(key)
            if legacy_path.exists():
                try:
                    shutil.move(str(legacy_path), str(path))
                    logger.info("Migrated session {} from legacy path", key)
                except Exception:
                    logger.exception("Failed to migrate session {}", key)

        if not path.exists():
            return None

        try:
            # JSONL parsing state.
            messages: list[dict[str, Any]] = []
            metadata: dict[str, Any] = {}
            created_at: datetime | None = None
            updated_at: datetime | None = None
            last_consolidated = 0

            with open(path, encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue

                    data = json.loads(line)

                    # Convention: exactly one metadata line (normally the first).
                    if data.get("_type") == "metadata":
                        metadata = data.get("metadata", {})
                        created_at = datetime.fromisoformat(data["created_at"]) if data.get("created_at") else None
                        # Restore the persisted timestamp; without this every
                        # reload (and the following save) reset it to "now".
                        updated_at = datetime.fromisoformat(data["updated_at"]) if data.get("updated_at") else None
                        last_consolidated = data.get("last_consolidated", 0)
                    else:
                        # Every other line is a message, restored in file order.
                        messages.append(data)

            now = datetime.now()
            return Session(
                key=key,
                messages=messages,
                created_at=created_at or now,
                updated_at=updated_at or now,
                metadata=metadata,
                last_consolidated=last_consolidated,
            )
        except Exception as e:
            # Tolerate corrupt/malformed files so callers are not blocked.
            logger.warning("Failed to load session {}: {}", key, e)
            return None

    def save(self, session: Session) -> None:
        """Rewrite the whole session file on disk and refresh the cache.

        The metadata line is written first, then one line per message. Output
        goes to a sibling ``.tmp`` file that is then moved over the real file
        with ``Path.replace``. A failure mid-write therefore leaves the
        previous file intact instead of truncated. Serialization errors still
        propagate to the caller, and in that case the cache is not updated.
        """
        path = self._get_session_path(session.key)
        # `<name>.jsonl.tmp` does not match the `*.jsonl` glob in list_sessions.
        tmp_path = path.with_name(path.name + ".tmp")

        metadata_line = {
            "_type": "metadata",
            "key": session.key,
            "created_at": session.created_at.isoformat(),
            "updated_at": session.updated_at.isoformat(),
            "metadata": session.metadata,
            "last_consolidated": session.last_consolidated,
        }

        try:
            with open(tmp_path, "w", encoding="utf-8") as f:
                f.write(json.dumps(metadata_line, ensure_ascii=False) + "\n")
                for msg in session.messages:
                    f.write(json.dumps(msg, ensure_ascii=False) + "\n")
            tmp_path.replace(path)
        except BaseException:
            # Do not leave a partially written temp file behind.
            tmp_path.unlink(missing_ok=True)
            raise

        self._cache[session.key] = session

    def invalidate(self, key: str) -> None:
        """Drop the cached session for ``key`` (the file on disk is kept)."""
        self._cache.pop(key, None)

    def delete(self, key: str) -> bool:
        """Delete the cached session and its files (current and legacy paths).

        Returns:
            True if at least one file was removed.
        """
        self.invalidate(key)

        deleted = False
        # Each location is handled independently, so a failure on one path
        # does not prevent removing the other.
        for candidate in (self._get_session_path(key), self._get_legacy_session_path(key)):
            try:
                if candidate.exists():
                    candidate.unlink()
                    deleted = True
            except Exception as e:
                logger.error("Error deleting session {}: {}", key, e)

        return deleted

    def list_sessions(self) -> list[dict[str, Any]]:
        """List metadata for every session file in the sessions directory.

        Returns:
            Dicts with ``key``, ``created_at``, ``updated_at`` (ISO strings as
            stored, possibly None) and ``path``, most recently updated first.
        """
        sessions = []

        for path in self.sessions_dir.glob("*.jsonl"):
            try:
                # Read only the first (metadata) line to avoid scanning large files.
                with open(path, encoding="utf-8") as f:
                    first_line = f.readline().strip()
                if first_line:
                    data = json.loads(first_line)
                    if data.get("_type") == "metadata":
                        # Older files may lack `key`; derive it from the filename.
                        key = data.get("key") or path.stem.replace("_", ":", 1)
                        sessions.append({
                            "key": key,
                            "created_at": data.get("created_at"),
                            "updated_at": data.get("updated_at"),
                            "path": str(path),
                        })
            except Exception as e:
                # A single broken file must not break the whole listing.
                logger.warning("Skipping unreadable session file {}: {}", path, e)
                continue

        # Most recently active first. `updated_at` is always present as a key
        # but may be None; mixing None and str in a sort raises TypeError.
        return sorted(sessions, key=lambda x: str(x.get("updated_at") or ""), reverse=True)
|