修改了nanobot,往Hermes agent的风格走,进度1/3
This commit is contained in:
31
app-instance/backend/beaver/engine/__init__.py
Normal file
31
app-instance/backend/beaver/engine/__init__.py
Normal file
@ -0,0 +1,31 @@
|
||||
"""Unified Beaver agent engine.
|
||||
|
||||
这里不做顶层 eager import,避免子模块导入时触发循环依赖。
|
||||
对外仍然保留同样的导出名称,但改成按需加载。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
__all__ = ["AgentLoop", "AgentProfile", "AgentRunResult", "EngineLoader", "EngineLoadResult"]
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
if name == "EngineLoader":
|
||||
from .loader import EngineLoader
|
||||
|
||||
return EngineLoader
|
||||
if name == "EngineLoadResult":
|
||||
from .loader import EngineLoadResult
|
||||
|
||||
return EngineLoadResult
|
||||
if name in {"AgentLoop", "AgentProfile", "AgentRunResult"}:
|
||||
from .loop import AgentLoop, AgentProfile, AgentRunResult
|
||||
|
||||
return {
|
||||
"AgentLoop": AgentLoop,
|
||||
"AgentProfile": AgentProfile,
|
||||
"AgentRunResult": AgentRunResult,
|
||||
}[name]
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
17
app-instance/backend/beaver/engine/context/__init__.py
Normal file
17
app-instance/backend/beaver/engine/context/__init__.py
Normal file
@ -0,0 +1,17 @@
|
||||
"""Context assembly for agent runs."""
|
||||
|
||||
from .builder import (
|
||||
ContextBuildInput,
|
||||
ContextBuildResult,
|
||||
ContextBuilder,
|
||||
SessionContext,
|
||||
SkillContext,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ContextBuildInput",
|
||||
"ContextBuildResult",
|
||||
"ContextBuilder",
|
||||
"SessionContext",
|
||||
"SkillContext",
|
||||
]
|
||||
331
app-instance/backend/beaver/engine/context/builder.py
Normal file
331
app-instance/backend/beaver/engine/context/builder.py
Normal file
@ -0,0 +1,331 @@
|
||||
"""Beaver 运行时上下文装配器。
|
||||
|
||||
这个模块是 `session` 和 `provider` 之间的中间层,职责非常明确:
|
||||
|
||||
1. 把运行前已经准备好的静态/半静态上下文拼成一份稳定的 system prompt
|
||||
2. 把从 session 事件流里裁剪出的“可见历史”和当前用户输入整理成 provider 可直接消费的 messages
|
||||
3. 在 tool loop 中,持续把 assistant/tool 消息按统一格式追加回消息数组
|
||||
|
||||
为什么这层必须单独存在:
|
||||
|
||||
1. `AgentLoop` 不应该自己拼 prompt,否则很快又会长成一个大文件
|
||||
2. `memory`、`skills`、`session` 的注入顺序需要固定,否则模型行为会漂移
|
||||
3. tool loop 前后追加消息的格式必须统一,否则不同 provider 很容易出兼容问题
|
||||
|
||||
这一版 builder 的设计目标是“最小但稳定”:
|
||||
|
||||
1. 先服务单 agent 主链
|
||||
2. 先支持 frozen curated memory,而不是 live memory
|
||||
3. skills 按 Hermes 风格支持“显式激活消息注入”,不在这里做磁盘扫描
|
||||
4. 为后续 channel / gateway / team metadata 预留注入位,但不提前做复杂逻辑
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from beaver.memory.curated.snapshot import MemorySnapshot
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class SkillContext:
|
||||
"""单个已激活 skill 的最小表示。
|
||||
|
||||
这里故意不把 skill 设计成复杂对象,只保留 builder 真正关心的两部分:
|
||||
|
||||
- `name`:用于生成激活提示
|
||||
- `content`:skill 的完整正文
|
||||
|
||||
注意:按当前 Hermes 风格实现,skill 正文不再塞进 system prompt,而是转成显式消息注入。
|
||||
"""
|
||||
|
||||
name: str
|
||||
content: str
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class SessionContext:
|
||||
"""当前运行轮次的会话元数据。
|
||||
|
||||
这不是 session store 里的完整 record,而是 prompt builder 关心的那一小部分:
|
||||
- 哪个 session
|
||||
- 来源是什么
|
||||
- 当前使用什么 model
|
||||
- 是否有 channel/chat/user 这类运行路由信息
|
||||
|
||||
把它单独抽出来的原因是:
|
||||
1. builder 不应该知道 SQLite row 长什么样
|
||||
2. 不同入口(CLI/Web/Gateway)都可以把自己的 metadata 收敛成同一种结构
|
||||
"""
|
||||
|
||||
session_id: str | None = None
|
||||
source: str | None = None
|
||||
model: str | None = None
|
||||
user_id: str | None = None
|
||||
channel: str | None = None
|
||||
chat_id: str | None = None
|
||||
parent_session_id: str | None = None
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ContextBuildInput:
|
||||
"""一次上下文构建所需的全部输入。
|
||||
|
||||
这个对象的作用不是“炫技式封装”,而是把主链里零散的数据显式收口。
|
||||
这样一来,后面 `AgentLoop.process_direct()` 在组装参数时会更清晰,也更容易测试。
|
||||
|
||||
字段分组:
|
||||
- 身份/基础段:`base_system_prompt`
|
||||
- 会话可见历史:`history`
|
||||
- 当前输入:`current_user_input`
|
||||
- 冻结记忆:`memory_snapshot`
|
||||
- 技能:`activated_skills`
|
||||
- 运行元数据:`session_context` / `execution_context`
|
||||
- 额外扩展:`extra_sections`
|
||||
"""
|
||||
|
||||
base_system_prompt: str = ""
|
||||
history: list[dict[str, Any]] = field(default_factory=list)
|
||||
current_user_input: str | list[dict[str, Any]] | None = None
|
||||
memory_snapshot: MemorySnapshot | None = None
|
||||
activated_skills: list[SkillContext] = field(default_factory=list)
|
||||
session_context: SessionContext | None = None
|
||||
execution_context: str | None = None
|
||||
extra_sections: list[str] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ContextBuildResult:
|
||||
"""一次上下文构建后的结果。
|
||||
|
||||
保留 `system_prompt` 的原因:
|
||||
1. `SessionManager.update_system_prompt()` 需要把最终注入的 prompt snapshot 落盘
|
||||
2. 调试时经常需要区分“system prompt 长什么样”和“messages 长什么样”
|
||||
3. 后面如果做 prompt audit / replay,也会直接复用这个结果
|
||||
"""
|
||||
|
||||
system_prompt: str
|
||||
messages: list[dict[str, Any]]
|
||||
|
||||
|
||||
class ContextBuilder:
|
||||
"""负责把运行时输入装配成稳定上下文。
|
||||
|
||||
这一层故意保持“无 IO、无数据库、无网络”:
|
||||
- 不直接读 session store
|
||||
- 不直接读 memory store
|
||||
- 不直接扫描 skills 目录
|
||||
|
||||
这样 builder 的行为只由输入决定,便于单测,也便于后面并到真正的 AgentLoop 主链里。
|
||||
"""
|
||||
|
||||
def build_system_prompt(
|
||||
self,
|
||||
build_input: ContextBuildInput,
|
||||
) -> str:
|
||||
"""构建 system prompt。
|
||||
|
||||
顺序固定非常重要,当前约定是:
|
||||
|
||||
1. base system prompt
|
||||
2. session metadata
|
||||
3. execution context
|
||||
4. frozen memory snapshot
|
||||
5. extra sections
|
||||
|
||||
这样设计的原因:
|
||||
- 身份与总规则要最靠前
|
||||
- session/execution 是本轮运行语境,优先级高于长期记忆
|
||||
- memory 必须是 frozen snapshot,避免中途写 memory 后 prompt 失真
|
||||
- activated skill 正文按 Hermes 风格放到显式消息里,避免 system prompt 持续膨胀
|
||||
"""
|
||||
|
||||
sections: list[str] = []
|
||||
|
||||
base_system_prompt = (build_input.base_system_prompt or "").strip()
|
||||
if base_system_prompt:
|
||||
sections.append(base_system_prompt)
|
||||
|
||||
session_section = self._render_session_section(build_input.session_context)
|
||||
if session_section:
|
||||
sections.append(session_section)
|
||||
|
||||
execution_context = (build_input.execution_context or "").strip()
|
||||
if execution_context:
|
||||
sections.append(f"# Execution Context\n\n{execution_context}")
|
||||
|
||||
if build_input.memory_snapshot is not None:
|
||||
# 这里明确只读 frozen snapshot,而不是去读 live memory store。
|
||||
# 否则一旦当前会话中途写 memory,system prompt 语义就会和会话开头不一致。
|
||||
snapshot_sections = build_input.memory_snapshot.as_prompt_sections()
|
||||
if snapshot_sections:
|
||||
sections.extend(snapshot_sections)
|
||||
|
||||
for extra in build_input.extra_sections:
|
||||
cleaned = (extra or "").strip()
|
||||
if cleaned:
|
||||
sections.append(cleaned)
|
||||
|
||||
return "\n\n---\n\n".join(sections)
|
||||
|
||||
def build_messages(
|
||||
self,
|
||||
build_input: ContextBuildInput,
|
||||
) -> ContextBuildResult:
|
||||
"""构建一次模型调用的完整 messages。
|
||||
|
||||
这里做三件事:
|
||||
1. 先生成最终 system prompt
|
||||
2. 按 Hermes 风格,把已激活 skill 的完整正文作为显式消息注入
|
||||
3. 把历史消息按原顺序接到后面
|
||||
4. 如果存在当前用户输入,则把本轮输入追加为最后一条 user message
|
||||
|
||||
注意:
|
||||
- `history` 默认被视为“已经由 session/context 上游从完整事件流中裁剪好的可见结构”
|
||||
- builder 不负责裁剪历史窗口,这件事应由 session/loop 上层决定
|
||||
- builder 只做最小格式统一
|
||||
"""
|
||||
|
||||
system_prompt = self.build_system_prompt(build_input)
|
||||
messages: list[dict[str, Any]] = [{"role": "system", "content": system_prompt}]
|
||||
|
||||
messages.extend(self.build_skill_activation_messages(build_input.activated_skills))
|
||||
|
||||
for message in build_input.history:
|
||||
# 当前 builder 自己负责生成唯一的 system prompt。
|
||||
# 如果上游 history 已经混入 system 消息,这里要主动跳过,避免双 system。
|
||||
if message.get("role") == "system":
|
||||
continue
|
||||
messages.append(dict(message))
|
||||
|
||||
if build_input.current_user_input is not None:
|
||||
messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": build_input.current_user_input,
|
||||
}
|
||||
)
|
||||
|
||||
return ContextBuildResult(
|
||||
system_prompt=system_prompt,
|
||||
messages=messages,
|
||||
)
|
||||
|
||||
def add_tool_result(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
*,
|
||||
tool_call_id: str,
|
||||
tool_name: str,
|
||||
result: str,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""向消息数组追加一条 tool result。
|
||||
|
||||
为什么这个函数放在 builder,而不是塞回 `AgentLoop`:
|
||||
- tool message 的结构必须和 provider 兼容
|
||||
- 统一在这里追加,可以避免不同执行路径拼出不同字段名
|
||||
- 后面如果要兼容更多 provider 差异,也只改这一层
|
||||
|
||||
这里返回原 list 本身,保持旧项目的“可链式追加”习惯。
|
||||
"""
|
||||
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call_id,
|
||||
"name": tool_name,
|
||||
"content": result,
|
||||
}
|
||||
)
|
||||
return messages
|
||||
|
||||
def add_assistant_message(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
*,
|
||||
content: str | None,
|
||||
tool_calls: list[dict[str, Any]] | None = None,
|
||||
reasoning_content: str | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""向消息数组追加 assistant 消息。
|
||||
|
||||
这里有两个实现细节非常重要:
|
||||
|
||||
1. 无论 `content` 是否为空,都显式写入 `content` 键
|
||||
原因是部分 provider 在 assistant 带 `tool_calls` 时仍要求消息里存在 `content`
|
||||
|
||||
2. `reasoning_content` 只有在非空时才附带
|
||||
因为这属于思考模型扩展字段,不应污染普通 provider 路径
|
||||
"""
|
||||
|
||||
message: dict[str, Any] = {
|
||||
"role": "assistant",
|
||||
"content": content,
|
||||
}
|
||||
if tool_calls:
|
||||
message["tool_calls"] = tool_calls
|
||||
if reasoning_content is not None:
|
||||
message["reasoning_content"] = reasoning_content
|
||||
messages.append(message)
|
||||
return messages
|
||||
|
||||
def _render_session_section(self, session_context: SessionContext | None) -> str | None:
|
||||
"""把运行时 session metadata 渲染成一个可读 section。
|
||||
|
||||
这一段的目标不是让模型“记住所有数据库字段”,而是给它足够的当前运行语境。
|
||||
常见用途包括:
|
||||
- 知道当前来自 CLI 还是 Web/Gateway
|
||||
- 知道当前使用什么 model
|
||||
- 知道当前 channel/chat_id,便于后续多渠道行为约束
|
||||
"""
|
||||
|
||||
if session_context is None:
|
||||
return None
|
||||
|
||||
rows: list[str] = []
|
||||
if session_context.session_id:
|
||||
rows.append(f"Session ID: {session_context.session_id}")
|
||||
if session_context.source:
|
||||
rows.append(f"Source: {session_context.source}")
|
||||
if session_context.model:
|
||||
rows.append(f"Model: {session_context.model}")
|
||||
if session_context.user_id:
|
||||
rows.append(f"User ID: {session_context.user_id}")
|
||||
if session_context.channel:
|
||||
rows.append(f"Channel: {session_context.channel}")
|
||||
if session_context.chat_id:
|
||||
rows.append(f"Chat ID: {session_context.chat_id}")
|
||||
if session_context.parent_session_id:
|
||||
rows.append(f"Parent Session ID: {session_context.parent_session_id}")
|
||||
|
||||
if not rows:
|
||||
return None
|
||||
return "# Current Session\n\n" + "\n".join(rows)
|
||||
|
||||
def build_skill_activation_messages(self, activated_skills: list[SkillContext]) -> list[dict[str, str]]:
|
||||
"""按 Hermes 风格把已激活 skill 转成显式消息。
|
||||
|
||||
关键区别:
|
||||
- system prompt 只保留轻量 skills index
|
||||
- 真正生效的 skill 正文通过额外消息块显式加载
|
||||
|
||||
这样模型不需要“从摘要里猜怎么读到正文”,而是直接拿到完整指导内容。
|
||||
"""
|
||||
|
||||
messages: list[dict[str, str]] = []
|
||||
for skill in activated_skills:
|
||||
content = (skill.content or "").strip()
|
||||
if not content:
|
||||
continue
|
||||
messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": (
|
||||
f'[SYSTEM: The "{skill.name}" skill is active for this run. '
|
||||
"Follow its instructions as active guidance unless the user overrides them.]\n\n"
|
||||
f"{content}"
|
||||
),
|
||||
}
|
||||
)
|
||||
return messages
|
||||
154
app-instance/backend/beaver/engine/loader.py
Normal file
154
app-instance/backend/beaver/engine/loader.py
Normal file
@ -0,0 +1,154 @@
|
||||
"""Centralized runtime loading for Beaver agents."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
from beaver.engine.context import ContextBuilder
|
||||
from beaver.engine.session import SessionManager
|
||||
from beaver.memory.curated.store import MemoryStore
|
||||
from beaver.services.memory_service import MemoryService
|
||||
from beaver.skills import SkillAssembler, SkillsLoader
|
||||
from beaver.tools import ObjectBackedTool, ToolExecutor, ToolRegistry
|
||||
from beaver.tools.builtins import EchoTool, MemoryTool, SessionSearchTool, SkillViewTool
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class EngineLoadResult:
|
||||
"""描述当前 agent runtime 已经装好的依赖。
|
||||
|
||||
这里同时保留两类字段:
|
||||
1. `tools/skills/memory_stores/permissions`
|
||||
- 便于做状态展示、调试、轻量测试
|
||||
2. `session_manager/tool_registry/...`
|
||||
- 供真正的运行时主链直接使用
|
||||
"""
|
||||
|
||||
workspace: Path
|
||||
tools: list[str] = field(default_factory=list)
|
||||
skills: list[str] = field(default_factory=list)
|
||||
memory_stores: list[str] = field(default_factory=list)
|
||||
permissions: list[str] = field(default_factory=list)
|
||||
session_manager: SessionManager | None = None
|
||||
curated_memory_store: MemoryStore | None = None
|
||||
memory_service: MemoryService | None = None
|
||||
tool_registry: ToolRegistry | None = None
|
||||
tool_executor: ToolExecutor | None = None
|
||||
context_builder: ContextBuilder | None = None
|
||||
skills_loader: SkillsLoader | None = None
|
||||
skill_assembler: SkillAssembler | None = None
|
||||
closeables: list[tuple[str, Callable[[], None]]] = field(default_factory=list, repr=False)
|
||||
closed: bool = False
|
||||
|
||||
def register_closeable(self, name: str, close_fn: Callable[[], None]) -> None:
|
||||
"""登记一个由 runtime 统一关闭的资源。"""
|
||||
|
||||
self.closeables.append((name, close_fn))
|
||||
|
||||
def close(self) -> None:
|
||||
"""按后进先出顺序关闭 runtime 资源。
|
||||
|
||||
这一步先保持同步、最小、可组合:
|
||||
1. 只管理已经明确需要关闭的资源
|
||||
2. 暂不引入 async shutdown 协议
|
||||
3. 为后续 Web/Gateway lifespan 留统一入口
|
||||
"""
|
||||
|
||||
if self.closed:
|
||||
return
|
||||
|
||||
errors: list[tuple[str, BaseException]] = []
|
||||
for name, close_fn in reversed(self.closeables):
|
||||
try:
|
||||
close_fn()
|
||||
except BaseException as exc: # pragma: no cover - defensive cleanup path
|
||||
errors.append((name, exc))
|
||||
self.closed = True
|
||||
|
||||
if errors:
|
||||
parts = ", ".join(f"{name}: {exc}" for name, exc in errors)
|
||||
raise RuntimeError(f"Runtime shutdown failed for {parts}")
|
||||
|
||||
|
||||
class EngineLoader:
|
||||
"""为任意 Beaver agent 装载共享 runtime 能力。
|
||||
|
||||
当前先做“最小可运行主链”需要的装配:
|
||||
- session manager
|
||||
- curated memory store
|
||||
- context builder
|
||||
- built-in tools
|
||||
- tool executor
|
||||
|
||||
等主链跑稳后,再把 skills、权限、MCP、delegation 逐步加进来。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
workspace: str | Path | None = None,
|
||||
session_manager: SessionManager | None = None,
|
||||
curated_memory_store: MemoryStore | None = None,
|
||||
memory_service: MemoryService | None = None,
|
||||
tool_registry: ToolRegistry | None = None,
|
||||
context_builder: ContextBuilder | None = None,
|
||||
skills_loader: SkillsLoader | None = None,
|
||||
skill_assembler: SkillAssembler | None = None,
|
||||
) -> None:
|
||||
self.workspace = Path(workspace or Path.cwd())
|
||||
self._session_manager = session_manager
|
||||
self._curated_memory_store = curated_memory_store
|
||||
self._memory_service = memory_service
|
||||
self._tool_registry = tool_registry
|
||||
self._context_builder = context_builder
|
||||
self._skills_loader = skills_loader
|
||||
self._skill_assembler = skill_assembler
|
||||
|
||||
def load(self) -> EngineLoadResult:
|
||||
"""装配当前主链需要的最小 runtime 对象。"""
|
||||
|
||||
workspace = self.workspace
|
||||
session_manager = self._session_manager or SessionManager(workspace)
|
||||
|
||||
curated_root = workspace / "memory" / "curated"
|
||||
curated_memory_store = self._curated_memory_store or MemoryStore(curated_root)
|
||||
memory_service = self._memory_service or MemoryService(curated_root, store=curated_memory_store)
|
||||
memory_service.initialize()
|
||||
|
||||
tool_registry = self._tool_registry or ToolRegistry()
|
||||
skills_loader = self._skills_loader or SkillsLoader(workspace)
|
||||
if self._tool_registry is None:
|
||||
# 这里先注册最小工具集,满足主链的 tool loop。
|
||||
tool_registry.register_many(
|
||||
[
|
||||
ObjectBackedTool(EchoTool()),
|
||||
ObjectBackedTool(MemoryTool(store=memory_service.get_store())),
|
||||
ObjectBackedTool(SkillViewTool(loader=skills_loader)),
|
||||
ObjectBackedTool(SessionSearchTool(db=session_manager)),
|
||||
]
|
||||
)
|
||||
|
||||
context_builder = self._context_builder or ContextBuilder()
|
||||
tool_executor = ToolExecutor(tool_registry)
|
||||
skill_assembler = self._skill_assembler or SkillAssembler(skills_loader)
|
||||
|
||||
result = EngineLoadResult(
|
||||
workspace=workspace,
|
||||
tools=[spec.name for spec in tool_registry.list_specs()],
|
||||
skills=[record.name for record in skills_loader.list_skills(filter_unavailable=False)],
|
||||
memory_stores=["curated"],
|
||||
permissions=[],
|
||||
session_manager=session_manager,
|
||||
curated_memory_store=memory_service.get_store(),
|
||||
memory_service=memory_service,
|
||||
tool_registry=tool_registry,
|
||||
tool_executor=tool_executor,
|
||||
context_builder=context_builder,
|
||||
skills_loader=skills_loader,
|
||||
skill_assembler=skill_assembler,
|
||||
)
|
||||
if self._session_manager is None:
|
||||
result.register_closeable("session_manager", session_manager.close)
|
||||
return result
|
||||
689
app-instance/backend/beaver/engine/loop.py
Normal file
689
app-instance/backend/beaver/engine/loop.py
Normal file
@ -0,0 +1,689 @@
|
||||
"""Unified agent loop used by all Beaver agents."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from beaver.engine.context import ContextBuildInput, SessionContext
|
||||
from beaver.engine.providers import ProviderBundle, make_provider_bundle
|
||||
from beaver.tools import ToolContext
|
||||
|
||||
from .loader import EngineLoader, EngineLoadResult
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class AgentProfile:
|
||||
"""Runtime profile for a Beaver agent instance."""
|
||||
|
||||
name: str = "default"
|
||||
system_prompt: str = ""
|
||||
default_model: str = "gpt-4.1-mini"
|
||||
max_tokens: int = 4096
|
||||
temperature: float = 0.2
|
||||
max_tool_iterations: int = 8
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class AgentRunResult:
|
||||
"""一次 direct run 的最小结果结构。"""
|
||||
|
||||
session_id: str
|
||||
run_id: str
|
||||
output_text: str
|
||||
finish_reason: str
|
||||
tool_iterations: int
|
||||
provider_name: str | None = None
|
||||
model: str | None = None
|
||||
usage: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class _DirectRunRequest:
|
||||
"""运行循环中的单个 direct task。"""
|
||||
|
||||
task: str
|
||||
kwargs: dict[str, Any]
|
||||
future: asyncio.Future[AgentRunResult]
|
||||
|
||||
|
||||
class AgentLoop:
|
||||
"""Single execution kernel shared by root agents and delegated agents."""
|
||||
|
||||
def __init__(self, *, profile: AgentProfile | None = None, loader: EngineLoader | None = None) -> None:
|
||||
self.profile = profile or AgentProfile()
|
||||
self.loader = loader or EngineLoader()
|
||||
self.loaded: EngineLoadResult | None = None
|
||||
self._run_queue: asyncio.Queue[_DirectRunRequest | None] | None = None
|
||||
self._running = False
|
||||
self._stop_requested = False
|
||||
|
||||
def boot(self) -> EngineLoadResult:
|
||||
"""Load shared runtime capabilities once for this agent instance."""
|
||||
if self.loaded is None:
|
||||
self.loaded = self.loader.load()
|
||||
return self.loaded
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
return self._running
|
||||
|
||||
async def run(self) -> None:
|
||||
"""启动最小运行循环,顺序消费提交进来的 direct tasks。
|
||||
|
||||
第一版故意保持克制:
|
||||
1. 只做单消费者串行消费
|
||||
2. 真正执行仍复用 `process_direct()`
|
||||
3. 不引入 bus / worker / priority / retry
|
||||
"""
|
||||
|
||||
if self._running:
|
||||
raise RuntimeError("AgentLoop.run() is already active")
|
||||
|
||||
self.boot()
|
||||
self._run_queue = asyncio.Queue()
|
||||
self._running = True
|
||||
self._stop_requested = False
|
||||
|
||||
try:
|
||||
while True:
|
||||
item = await self._run_queue.get()
|
||||
if item is None:
|
||||
if self._stop_requested:
|
||||
break
|
||||
continue
|
||||
|
||||
if item.future.cancelled():
|
||||
continue
|
||||
|
||||
try:
|
||||
result = await self._process_direct_impl(item.task, **item.kwargs)
|
||||
except asyncio.CancelledError:
|
||||
if not item.future.done():
|
||||
item.future.cancel()
|
||||
raise
|
||||
except Exception as exc: # pragma: no cover - defensive queue path
|
||||
if not item.future.done():
|
||||
item.future.set_exception(exc)
|
||||
else:
|
||||
if not item.future.done():
|
||||
item.future.set_result(result)
|
||||
finally:
|
||||
if self._run_queue is not None:
|
||||
while True:
|
||||
try:
|
||||
pending = self._run_queue.get_nowait()
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
if isinstance(pending, _DirectRunRequest) and not pending.future.done():
|
||||
pending.future.set_exception(
|
||||
RuntimeError("AgentLoop.run() stopped before processing the queued task")
|
||||
)
|
||||
self._running = False
|
||||
self._stop_requested = False
|
||||
self._run_queue = None
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""停止运行循环。
|
||||
|
||||
第一版语义:
|
||||
- 不再接收新任务
|
||||
- 当前已经取出的任务允许收尾
|
||||
- 不自动 close runtime
|
||||
"""
|
||||
|
||||
if not self._running or self._run_queue is None:
|
||||
return
|
||||
self._stop_requested = True
|
||||
await self._run_queue.put(None)
|
||||
|
||||
async def submit_direct(
|
||||
self,
|
||||
task: str,
|
||||
**kwargs: Any,
|
||||
) -> AgentRunResult:
|
||||
"""向运行中的 loop 提交一个 direct task,并等待结果。"""
|
||||
|
||||
if not self._running or self._run_queue is None:
|
||||
raise RuntimeError("AgentLoop.submit_direct() requires an active run() loop")
|
||||
if self._stop_requested:
|
||||
raise RuntimeError("AgentLoop.submit_direct() is not accepting new tasks after stop()")
|
||||
|
||||
future: asyncio.Future[AgentRunResult] = asyncio.get_running_loop().create_future()
|
||||
await self._run_queue.put(_DirectRunRequest(task=task, kwargs=dict(kwargs), future=future))
|
||||
return await future
|
||||
|
||||
def close(self) -> None:
|
||||
"""关闭当前 loop 持有的 runtime。
|
||||
|
||||
第 6 阶段先把生命周期最小骨架立住:
|
||||
- `boot()` 负责建立 runtime
|
||||
- `close()` 负责释放由 runtime 持有的资源
|
||||
- 之后再在此基础上扩 `run()/stop()/shutdown hooks`
|
||||
"""
|
||||
|
||||
if self._running:
|
||||
raise RuntimeError("AgentLoop.close() requires the run loop to be stopped first")
|
||||
if self.loaded is None:
|
||||
return
|
||||
try:
|
||||
self.loaded.close()
|
||||
finally:
|
||||
self.loaded = None
|
||||
|
||||
async def process_direct(
|
||||
self,
|
||||
task: str,
|
||||
*,
|
||||
session_id: str | None = None,
|
||||
source: str = "direct",
|
||||
user_id: str | None = None,
|
||||
title: str | None = None,
|
||||
execution_context: str | None = None,
|
||||
model: str | None = None,
|
||||
provider_name: str | None = None,
|
||||
api_key: str | None = None,
|
||||
api_base: str | None = None,
|
||||
extra_headers: dict[str, str] | None = None,
|
||||
routing: Any = None,
|
||||
fallback_target: dict[str, Any] | None = None,
|
||||
auxiliary_target: dict[str, Any] | None = None,
|
||||
embedding_target: dict[str, Any] | None = None,
|
||||
embedding_model: str | None = None,
|
||||
max_tokens: int | None = None,
|
||||
temperature: float | None = None,
|
||||
max_tool_iterations: int | None = None,
|
||||
provider_bundle: ProviderBundle | None = None,
|
||||
) -> AgentRunResult:
|
||||
"""跑通最小 direct run 主链。
|
||||
|
||||
当前主链刻意保持克制,只解决这些事情:
|
||||
1. 确保 session 存在
|
||||
2. 用 frozen memory + history 组 prompt
|
||||
3. 调 provider
|
||||
4. 若有 tool calls,则进入最小 tool loop
|
||||
5. 把 user/assistant/tool 消息和 usage 写回 session
|
||||
"""
|
||||
|
||||
if self._running:
|
||||
raise RuntimeError(
|
||||
"AgentLoop.process_direct() is disabled while run() is active; "
|
||||
"submit tasks via submit_direct() instead."
|
||||
)
|
||||
return await self._process_direct_impl(
|
||||
task,
|
||||
session_id=session_id,
|
||||
source=source,
|
||||
user_id=user_id,
|
||||
title=title,
|
||||
execution_context=execution_context,
|
||||
model=model,
|
||||
provider_name=provider_name,
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
extra_headers=extra_headers,
|
||||
routing=routing,
|
||||
fallback_target=fallback_target,
|
||||
auxiliary_target=auxiliary_target,
|
||||
embedding_target=embedding_target,
|
||||
embedding_model=embedding_model,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
max_tool_iterations=max_tool_iterations,
|
||||
provider_bundle=provider_bundle,
|
||||
)
|
||||
|
||||
async def _process_direct_impl(
|
||||
self,
|
||||
task: str,
|
||||
*,
|
||||
session_id: str | None = None,
|
||||
source: str = "direct",
|
||||
user_id: str | None = None,
|
||||
title: str | None = None,
|
||||
execution_context: str | None = None,
|
||||
model: str | None = None,
|
||||
provider_name: str | None = None,
|
||||
api_key: str | None = None,
|
||||
api_base: str | None = None,
|
||||
extra_headers: dict[str, str] | None = None,
|
||||
routing: Any = None,
|
||||
fallback_target: dict[str, Any] | None = None,
|
||||
auxiliary_target: dict[str, Any] | None = None,
|
||||
embedding_target: dict[str, Any] | None = None,
|
||||
embedding_model: str | None = None,
|
||||
max_tokens: int | None = None,
|
||||
temperature: float | None = None,
|
||||
max_tool_iterations: int | None = None,
|
||||
provider_bundle: ProviderBundle | None = None,
|
||||
) -> AgentRunResult:
|
||||
"""真正执行一轮 direct run 的内部实现。
|
||||
|
||||
规则:
|
||||
- 外部直接调用时走 `process_direct()`
|
||||
- 运行循环内部消费时走 `_process_direct_impl()`
|
||||
- 这样才能保证 run 模式下外部不能绕过队列直接执行
|
||||
"""
|
||||
|
||||
loaded = self.boot()
|
||||
session_manager = self._require_loaded("session_manager")
|
||||
memory_service = self._require_loaded("memory_service")
|
||||
context_builder = self._require_loaded("context_builder")
|
||||
tool_registry = self._require_loaded("tool_registry")
|
||||
tool_executor = self._require_loaded("tool_executor")
|
||||
skill_assembler = self._require_loaded("skill_assembler")
|
||||
|
||||
resolved_session_id = session_id or uuid4().hex
|
||||
resolved_run_id = uuid4().hex
|
||||
resolved_model = model or self.profile.default_model
|
||||
resolved_max_tokens = max_tokens or self.profile.max_tokens
|
||||
resolved_temperature = self.profile.temperature if temperature is None else temperature
|
||||
resolved_max_tool_iterations = (
|
||||
self.profile.max_tool_iterations if max_tool_iterations is None else max_tool_iterations
|
||||
)
|
||||
|
||||
# 每次新运行开始前都通过 MemoryService 刷新 live state。
|
||||
# 这样 memory policy 会收口在 service,而不是散在 loop 里。
|
||||
memory_service.reload_for_new_run()
|
||||
|
||||
session_manager.ensure_session(
|
||||
resolved_session_id,
|
||||
source=source,
|
||||
model=resolved_model,
|
||||
title=title,
|
||||
user_id=user_id,
|
||||
)
|
||||
session_manager.append_message(
|
||||
resolved_session_id,
|
||||
run_id=resolved_run_id,
|
||||
role="system",
|
||||
event_type="run_started",
|
||||
event_payload={
|
||||
"source": source,
|
||||
"model": resolved_model,
|
||||
"agent_name": self.profile.name,
|
||||
},
|
||||
content=task,
|
||||
context_visible=False,
|
||||
source=source,
|
||||
title=title,
|
||||
model=resolved_model,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
user_message_recorded = False
|
||||
iterations = 0
|
||||
final_usage: dict[str, Any] = {}
|
||||
final_provider_name: str | None = provider_name
|
||||
final_model: str | None = resolved_model
|
||||
try:
|
||||
bundle = provider_bundle or make_provider_bundle(
|
||||
model=resolved_model,
|
||||
provider_name=provider_name,
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
extra_headers=extra_headers,
|
||||
routing=routing,
|
||||
fallback_target=fallback_target,
|
||||
auxiliary_target=auxiliary_target,
|
||||
embedding_target=embedding_target,
|
||||
embedding_model=embedding_model or "text-embedding-v4",
|
||||
)
|
||||
skill_selector_provider = bundle.auxiliary_provider or bundle.main_provider
|
||||
skill_selector_model = (
|
||||
bundle.auxiliary_runtime.model
|
||||
if bundle.auxiliary_runtime is not None
|
||||
else bundle.main_runtime.model
|
||||
)
|
||||
assembled_skills = await skill_assembler.assemble(
|
||||
task_description=task,
|
||||
provider=skill_selector_provider,
|
||||
model=skill_selector_model,
|
||||
embedding_runtime=bundle.embedding_runtime,
|
||||
)
|
||||
skill_activation_messages = context_builder.build_skill_activation_messages(
|
||||
assembled_skills.activated_skills
|
||||
)
|
||||
|
||||
if skill_activation_messages:
|
||||
session_manager.append_message(
|
||||
resolved_session_id,
|
||||
run_id=resolved_run_id,
|
||||
role="system",
|
||||
event_type="skill_activation_snapshotted",
|
||||
event_payload={
|
||||
"activation_messages": skill_activation_messages,
|
||||
},
|
||||
content="\n\n".join(message["content"] for message in skill_activation_messages) or None,
|
||||
context_visible=False,
|
||||
source=source,
|
||||
title=title,
|
||||
model=resolved_model,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
build_input = ContextBuildInput(
|
||||
base_system_prompt=self.profile.system_prompt,
|
||||
history=session_manager.get_history(resolved_session_id),
|
||||
current_user_input=task,
|
||||
memory_snapshot=memory_service.get_snapshot(),
|
||||
activated_skills=assembled_skills.activated_skills,
|
||||
session_context=SessionContext(
|
||||
session_id=resolved_session_id,
|
||||
source=source,
|
||||
model=resolved_model,
|
||||
user_id=user_id,
|
||||
),
|
||||
execution_context=execution_context,
|
||||
)
|
||||
context_result = context_builder.build_messages(build_input)
|
||||
session_manager.update_system_prompt(resolved_session_id, context_result.system_prompt)
|
||||
session_manager.append_message(
|
||||
resolved_session_id,
|
||||
run_id=resolved_run_id,
|
||||
role="system",
|
||||
event_type="system_prompt_snapshotted",
|
||||
event_payload={
|
||||
"source": source,
|
||||
"model": resolved_model,
|
||||
"system_prompt_length": len(context_result.system_prompt),
|
||||
},
|
||||
content=context_result.system_prompt,
|
||||
context_visible=False,
|
||||
source=source,
|
||||
title=title,
|
||||
model=resolved_model,
|
||||
user_id=user_id,
|
||||
)
|
||||
session_manager.append_message(
|
||||
resolved_session_id,
|
||||
run_id=resolved_run_id,
|
||||
role="user",
|
||||
event_type="user_message_added",
|
||||
content=task,
|
||||
source=source,
|
||||
title=title,
|
||||
model=resolved_model,
|
||||
user_id=user_id,
|
||||
)
|
||||
user_message_recorded = True
|
||||
|
||||
provider = bundle.main_provider
|
||||
messages = list(context_result.messages)
|
||||
tool_schemas = tool_registry.export_provider_schemas()
|
||||
tool_context = ToolContext(
|
||||
workspace=str(loaded.workspace),
|
||||
session_id=resolved_session_id,
|
||||
user_id=user_id,
|
||||
services={
|
||||
"session_manager": session_manager,
|
||||
"memory_service": memory_service,
|
||||
"memory_store": memory_service.get_store(),
|
||||
"tool_registry": tool_registry,
|
||||
},
|
||||
metadata={
|
||||
"source": source,
|
||||
"agent_name": self.profile.name,
|
||||
},
|
||||
)
|
||||
|
||||
final_text = ""
|
||||
final_finish_reason = "stop"
|
||||
final_provider_name = bundle.main_runtime.provider_name
|
||||
final_model = bundle.main_runtime.model
|
||||
|
||||
while True:
|
||||
response = await provider.chat(
|
||||
messages=messages,
|
||||
tools=tool_schemas,
|
||||
model=final_model,
|
||||
max_tokens=resolved_max_tokens,
|
||||
temperature=resolved_temperature,
|
||||
)
|
||||
final_provider_name = response.provider_name or final_provider_name
|
||||
final_model = response.model or final_model
|
||||
final_usage = self._merge_usage(final_usage, response.usage or {})
|
||||
self._record_usage(session_manager, resolved_session_id, response.usage or {})
|
||||
|
||||
assistant_tool_calls = self._serialize_tool_calls(response.tool_calls)
|
||||
session_manager.append_message(
|
||||
resolved_session_id,
|
||||
run_id=resolved_run_id,
|
||||
role="assistant",
|
||||
event_type="assistant_message_added",
|
||||
content=response.content,
|
||||
tool_calls=assistant_tool_calls or None,
|
||||
finish_reason=response.finish_reason,
|
||||
reasoning=response.reasoning_content,
|
||||
source=source,
|
||||
title=title,
|
||||
model=final_model,
|
||||
user_id=user_id,
|
||||
)
|
||||
context_builder.add_assistant_message(
|
||||
messages,
|
||||
content=response.content,
|
||||
tool_calls=assistant_tool_calls or None,
|
||||
reasoning_content=response.reasoning_content,
|
||||
)
|
||||
|
||||
if not response.has_tool_calls:
|
||||
final_text = response.content or ""
|
||||
final_finish_reason = response.finish_reason or "stop"
|
||||
break
|
||||
|
||||
if iterations >= resolved_max_tool_iterations:
|
||||
final_text = response.content or "Tool loop stopped after reaching the configured iteration limit."
|
||||
final_finish_reason = "max_tool_iterations"
|
||||
session_manager.append_message(
|
||||
resolved_session_id,
|
||||
run_id=resolved_run_id,
|
||||
role="assistant",
|
||||
event_type="assistant_message_added",
|
||||
content=final_text,
|
||||
finish_reason=final_finish_reason,
|
||||
source=source,
|
||||
title=title,
|
||||
model=final_model,
|
||||
user_id=user_id,
|
||||
)
|
||||
context_builder.add_assistant_message(
|
||||
messages,
|
||||
content=final_text,
|
||||
)
|
||||
break
|
||||
|
||||
iterations += 1
|
||||
for tool_call in response.tool_calls:
|
||||
result = await tool_executor.execute_tool_call(tool_call, context=tool_context)
|
||||
session_manager.append_message(
|
||||
resolved_session_id,
|
||||
run_id=resolved_run_id,
|
||||
role="tool",
|
||||
event_type="tool_result_recorded",
|
||||
event_payload={
|
||||
"success": result.success,
|
||||
"error": result.error,
|
||||
},
|
||||
content=result.content,
|
||||
tool_name=result.tool_name,
|
||||
tool_call_id=tool_call.id,
|
||||
source=source,
|
||||
title=title,
|
||||
model=final_model,
|
||||
user_id=user_id,
|
||||
)
|
||||
context_builder.add_tool_result(
|
||||
messages,
|
||||
tool_call_id=tool_call.id,
|
||||
tool_name=result.tool_name,
|
||||
result=result.content,
|
||||
)
|
||||
|
||||
session_manager.append_message(
|
||||
resolved_session_id,
|
||||
run_id=resolved_run_id,
|
||||
role="system",
|
||||
event_type="run_completed",
|
||||
event_payload={
|
||||
"finish_reason": final_finish_reason,
|
||||
"tool_iterations": iterations,
|
||||
},
|
||||
content=final_text,
|
||||
finish_reason=final_finish_reason,
|
||||
context_visible=False,
|
||||
source=source,
|
||||
title=title,
|
||||
model=final_model,
|
||||
user_id=user_id,
|
||||
)
|
||||
return AgentRunResult(
|
||||
session_id=resolved_session_id,
|
||||
run_id=resolved_run_id,
|
||||
output_text=final_text,
|
||||
finish_reason=final_finish_reason,
|
||||
tool_iterations=iterations,
|
||||
provider_name=final_provider_name,
|
||||
model=final_model,
|
||||
usage=final_usage,
|
||||
)
|
||||
except Exception as exc:
|
||||
if not user_message_recorded:
|
||||
session_manager.append_message(
|
||||
resolved_session_id,
|
||||
run_id=resolved_run_id,
|
||||
role="user",
|
||||
event_type="user_message_added",
|
||||
content=task,
|
||||
source=source,
|
||||
title=title,
|
||||
model=resolved_model,
|
||||
user_id=user_id,
|
||||
)
|
||||
return self._build_error_result(
|
||||
session_manager=session_manager,
|
||||
session_id=resolved_session_id,
|
||||
run_id=resolved_run_id,
|
||||
source=source,
|
||||
title=title,
|
||||
user_id=user_id,
|
||||
model=final_model or resolved_model,
|
||||
message=f"Run failed before completion: {exc}",
|
||||
tool_iterations=iterations,
|
||||
provider_name=final_provider_name,
|
||||
usage=final_usage,
|
||||
)
|
||||
|
||||
def _require_loaded(self, field_name: str) -> Any:
|
||||
loaded = self.boot()
|
||||
value = getattr(loaded, field_name)
|
||||
if value is None:
|
||||
raise RuntimeError(f"Engine loader did not provide required dependency {field_name!r}")
|
||||
return value
|
||||
|
||||
@staticmethod
|
||||
def _serialize_tool_calls(tool_calls: list[Any]) -> list[dict[str, Any]]:
|
||||
payload: list[dict[str, Any]] = []
|
||||
for tool_call in tool_calls:
|
||||
payload.append(
|
||||
{
|
||||
"id": tool_call.id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_call.name,
|
||||
"arguments": tool_call.arguments,
|
||||
},
|
||||
}
|
||||
)
|
||||
return payload
|
||||
|
||||
@staticmethod
|
||||
def _record_usage(session_manager: Any, session_id: str, usage: dict[str, Any]) -> None:
|
||||
"""把 provider usage 映射到 session usage 字段。
|
||||
|
||||
这里先做最常见字段的最小映射:
|
||||
- prompt_tokens -> input_tokens
|
||||
- completion_tokens -> output_tokens
|
||||
|
||||
后面如果 provider 层补了更细的 cache/reasoning/cost,再往这里扩。
|
||||
"""
|
||||
|
||||
if not usage:
|
||||
return
|
||||
session_manager.update_usage(
|
||||
session_id,
|
||||
input_tokens=int(usage.get("input_tokens", usage.get("prompt_tokens", 0)) or 0),
|
||||
output_tokens=int(usage.get("output_tokens", usage.get("completion_tokens", 0)) or 0),
|
||||
reasoning_tokens=int(usage.get("reasoning_tokens", 0) or 0),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _merge_usage(total: dict[str, Any], delta: dict[str, Any]) -> dict[str, Any]:
|
||||
"""把多轮 provider usage 合并成一次 run 的累计 usage。"""
|
||||
|
||||
merged = dict(total)
|
||||
for key, value in delta.items():
|
||||
if isinstance(value, (int, float)) and isinstance(merged.get(key, 0), (int, float)):
|
||||
merged[key] = merged.get(key, 0) + value
|
||||
else:
|
||||
merged[key] = value
|
||||
return merged
|
||||
|
||||
@staticmethod
|
||||
def _build_error_result(
|
||||
*,
|
||||
session_manager: Any,
|
||||
session_id: str,
|
||||
run_id: str,
|
||||
source: str,
|
||||
title: str | None,
|
||||
user_id: str | None,
|
||||
model: str | None,
|
||||
message: str,
|
||||
tool_iterations: int,
|
||||
provider_name: str | None,
|
||||
usage: dict[str, Any],
|
||||
) -> AgentRunResult:
|
||||
"""把主链中的未处理异常收口成可追踪的 assistant error turn。"""
|
||||
|
||||
session_manager.append_message(
|
||||
session_id,
|
||||
run_id=run_id,
|
||||
role="assistant",
|
||||
event_type="assistant_message_added",
|
||||
content=message,
|
||||
finish_reason="error",
|
||||
source=source,
|
||||
title=title,
|
||||
model=model,
|
||||
user_id=user_id,
|
||||
)
|
||||
session_manager.append_message(
|
||||
session_id,
|
||||
run_id=run_id,
|
||||
role="system",
|
||||
event_type="run_failed",
|
||||
event_payload={
|
||||
"tool_iterations": tool_iterations,
|
||||
"provider_name": provider_name,
|
||||
},
|
||||
content=message,
|
||||
finish_reason="error",
|
||||
context_visible=False,
|
||||
source=source,
|
||||
title=title,
|
||||
model=model,
|
||||
user_id=user_id,
|
||||
)
|
||||
return AgentRunResult(
|
||||
session_id=session_id,
|
||||
run_id=run_id,
|
||||
output_text=message,
|
||||
finish_reason="error",
|
||||
tool_iterations=tool_iterations,
|
||||
provider_name=provider_name,
|
||||
model=model,
|
||||
usage=usage,
|
||||
)
|
||||
33
app-instance/backend/beaver/engine/providers/__init__.py
Normal file
33
app-instance/backend/beaver/engine/providers/__init__.py
Normal file
@ -0,0 +1,33 @@
|
||||
"""LLM provider adapters."""
|
||||
|
||||
from .base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
from .chain import FallbackProviderChain
|
||||
from .factory import (
|
||||
ProviderBundle,
|
||||
ProviderRoutingConfig,
|
||||
ProviderRuntime,
|
||||
ProviderTarget,
|
||||
build_provider_runtime,
|
||||
make_aux_provider,
|
||||
make_fallback_provider,
|
||||
make_main_provider,
|
||||
make_provider_bundle,
|
||||
make_provider_from_runtime,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"FallbackProviderChain",
|
||||
"LLMProvider",
|
||||
"LLMResponse",
|
||||
"ProviderBundle",
|
||||
"ProviderRoutingConfig",
|
||||
"ProviderRuntime",
|
||||
"ProviderTarget",
|
||||
"ToolCallRequest",
|
||||
"build_provider_runtime",
|
||||
"make_aux_provider",
|
||||
"make_fallback_provider",
|
||||
"make_main_provider",
|
||||
"make_provider_bundle",
|
||||
"make_provider_from_runtime",
|
||||
]
|
||||
173
app-instance/backend/beaver/engine/providers/anthropic.py
Normal file
173
app-instance/backend/beaver/engine/providers/anthropic.py
Normal file
@ -0,0 +1,173 @@
|
||||
"""Native Anthropic Messages API provider."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from .base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
|
||||
try: # pragma: no cover - optional dependency
|
||||
import anthropic
|
||||
except ModuleNotFoundError: # pragma: no cover
|
||||
anthropic = None # type: ignore[assignment]
|
||||
|
||||
|
||||
class AnthropicProvider(LLMProvider):
|
||||
"""使用 Anthropic 原生 Messages API,而不是强行走 OpenAI-compatible path。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str | None = None,
|
||||
default_model: str = "claude-sonnet-4-5",
|
||||
api_base: str | None = None,
|
||||
request_timeout_seconds: float | None = None,
|
||||
) -> None:
|
||||
super().__init__(api_key, api_base, request_timeout_seconds=request_timeout_seconds)
|
||||
self.default_model = default_model
|
||||
self._client = None
|
||||
|
||||
def _client_or_raise(self):
|
||||
if anthropic is None:
|
||||
raise RuntimeError("anthropic package is not installed")
|
||||
if self._client is None:
|
||||
self._client = anthropic.AsyncAnthropic(
|
||||
api_key=self.api_key,
|
||||
base_url=self.api_base,
|
||||
timeout=self.request_timeout_seconds,
|
||||
)
|
||||
return self._client
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
model: str | None = None,
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.7,
|
||||
) -> LLMResponse:
|
||||
try:
|
||||
client = self._client_or_raise()
|
||||
except Exception as exc:
|
||||
return LLMResponse(content=f"Error: {exc}", finish_reason="error", provider_name="anthropic")
|
||||
|
||||
system_prompt, anthropic_messages = _convert_messages(messages)
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": model or self.default_model,
|
||||
"system": system_prompt or "",
|
||||
"messages": anthropic_messages,
|
||||
"max_tokens": max(1, max_tokens),
|
||||
"temperature": temperature,
|
||||
}
|
||||
if tools:
|
||||
kwargs["tools"] = _convert_tools(tools)
|
||||
|
||||
try:
|
||||
response = await client.messages.create(**kwargs)
|
||||
except Exception as exc:
|
||||
return LLMResponse(content=f"Error: {exc}", finish_reason="error", provider_name="anthropic")
|
||||
|
||||
content_parts: list[str] = []
|
||||
tool_calls: list[ToolCallRequest] = []
|
||||
for block in response.content:
|
||||
if block.type == "text":
|
||||
content_parts.append(block.text)
|
||||
elif block.type == "tool_use":
|
||||
tool_calls.append(
|
||||
ToolCallRequest(
|
||||
id=block.id,
|
||||
name=block.name,
|
||||
arguments=block.input,
|
||||
)
|
||||
)
|
||||
usage_payload = {}
|
||||
if getattr(response, "usage", None):
|
||||
usage_payload = {
|
||||
"input_tokens": getattr(response.usage, "input_tokens", 0),
|
||||
"output_tokens": getattr(response.usage, "output_tokens", 0),
|
||||
}
|
||||
return LLMResponse(
|
||||
content="".join(content_parts) or None,
|
||||
tool_calls=tool_calls,
|
||||
finish_reason=getattr(response, "stop_reason", "stop") or "stop",
|
||||
usage=usage_payload,
|
||||
provider_name="anthropic",
|
||||
model=model or self.default_model,
|
||||
)
|
||||
|
||||
def get_default_model(self) -> str:
|
||||
return self.default_model
|
||||
|
||||
|
||||
def _convert_messages(messages: list[dict[str, Any]]) -> tuple[str, list[dict[str, Any]]]:
|
||||
system_prompt = ""
|
||||
converted: list[dict[str, Any]] = []
|
||||
for message in messages:
|
||||
role = message.get("role")
|
||||
if role == "system":
|
||||
content = message.get("content")
|
||||
system_prompt = content if isinstance(content, str) else ""
|
||||
continue
|
||||
if role == "tool":
|
||||
converted.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": message.get("tool_call_id"),
|
||||
"content": message.get("content") or "",
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
continue
|
||||
if role == "assistant" and message.get("tool_calls"):
|
||||
content_blocks: list[dict[str, Any]] = []
|
||||
if message.get("content"):
|
||||
content_blocks.append({"type": "text", "text": message["content"]})
|
||||
for tool_call in message.get("tool_calls", []):
|
||||
function = tool_call.get("function", tool_call)
|
||||
arguments = function.get("arguments")
|
||||
if isinstance(arguments, str):
|
||||
try:
|
||||
arguments = json.loads(arguments)
|
||||
except json.JSONDecodeError:
|
||||
arguments = {}
|
||||
content_blocks.append(
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": tool_call.get("id"),
|
||||
"name": function.get("name"),
|
||||
"input": arguments or {},
|
||||
}
|
||||
)
|
||||
converted.append({"role": "assistant", "content": content_blocks})
|
||||
continue
|
||||
|
||||
content = message.get("content")
|
||||
if isinstance(content, list):
|
||||
blocks = []
|
||||
for item in content:
|
||||
if isinstance(item, dict) and item.get("type") == "text":
|
||||
blocks.append({"type": "text", "text": item.get("text", "")})
|
||||
converted.append({"role": role, "content": blocks or [{"type": "text", "text": ""}]})
|
||||
else:
|
||||
converted.append({"role": role, "content": content or ""})
|
||||
return system_prompt, converted
|
||||
|
||||
|
||||
def _convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
converted: list[dict[str, Any]] = []
|
||||
for tool in tools:
|
||||
fn = (tool.get("function") or {}) if tool.get("type") == "function" else tool
|
||||
if not fn.get("name"):
|
||||
continue
|
||||
converted.append(
|
||||
{
|
||||
"name": fn["name"],
|
||||
"description": fn.get("description") or "",
|
||||
"input_schema": fn.get("parameters") or {"type": "object", "properties": {}},
|
||||
}
|
||||
)
|
||||
return converted
|
||||
98
app-instance/backend/beaver/engine/providers/base.py
Normal file
98
app-instance/backend/beaver/engine/providers/base.py
Normal file
@ -0,0 +1,98 @@
|
||||
"""Beaver provider 子系统的统一契约。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ToolCallRequest:
|
||||
"""模型返回的一次工具调用请求。"""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
arguments: dict[str, Any]
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class LLMResponse:
|
||||
"""统一的模型响应结构。"""
|
||||
|
||||
content: str | None
|
||||
tool_calls: list[ToolCallRequest] = field(default_factory=list)
|
||||
finish_reason: str = "stop"
|
||||
usage: dict[str, Any] = field(default_factory=dict)
|
||||
reasoning_content: str | None = None
|
||||
provider_name: str | None = None
|
||||
model: str | None = None
|
||||
|
||||
@property
|
||||
def has_tool_calls(self) -> bool:
|
||||
return bool(self.tool_calls)
|
||||
|
||||
|
||||
class LLMProvider(ABC):
|
||||
"""所有 provider 实现必须遵守的统一接口。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str | None = None,
|
||||
api_base: str | None = None,
|
||||
request_timeout_seconds: float | None = None,
|
||||
) -> None:
|
||||
self.api_key = api_key
|
||||
self.api_base = api_base
|
||||
self.request_timeout_seconds = (
|
||||
max(1.0, float(request_timeout_seconds))
|
||||
if request_timeout_seconds is not None
|
||||
else None
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def sanitize_empty_content(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""清理 provider 普遍不接受的空 content。"""
|
||||
|
||||
result: list[dict[str, Any]] = []
|
||||
for message in messages:
|
||||
content = message.get("content")
|
||||
if isinstance(content, str) and content == "":
|
||||
clean = dict(message)
|
||||
clean["content"] = None if (message.get("role") == "assistant" and message.get("tool_calls")) else "(empty)"
|
||||
result.append(clean)
|
||||
continue
|
||||
if isinstance(content, list):
|
||||
filtered = [
|
||||
item
|
||||
for item in content
|
||||
if not (
|
||||
isinstance(item, dict)
|
||||
and item.get("type") in ("text", "input_text", "output_text")
|
||||
and not item.get("text")
|
||||
)
|
||||
]
|
||||
if len(filtered) != len(content):
|
||||
clean = dict(message)
|
||||
clean["content"] = filtered or "(empty)"
|
||||
if message.get("role") == "assistant" and message.get("tool_calls") and not filtered:
|
||||
clean["content"] = None
|
||||
result.append(clean)
|
||||
continue
|
||||
result.append(message)
|
||||
return result
|
||||
|
||||
@abstractmethod
|
||||
async def chat(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
model: str | None = None,
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.7,
|
||||
) -> LLMResponse:
|
||||
"""统一聊天接口。"""
|
||||
|
||||
@abstractmethod
|
||||
def get_default_model(self) -> str:
|
||||
"""返回 provider 的默认模型名。"""
|
||||
145
app-instance/backend/beaver/engine/providers/chain.py
Normal file
145
app-instance/backend/beaver/engine/providers/chain.py
Normal file
@ -0,0 +1,145 @@
|
||||
"""Provider chain helpers.
|
||||
|
||||
这里先实现最小可用的 fallback chain:
|
||||
- 每次调用都先尝试主 provider
|
||||
- 本次调用主 provider 返回 `finish_reason=error` 时,再切到 fallback
|
||||
- fallback 只影响当前这一次调用,不会污染下一次 run 的首选链路
|
||||
|
||||
这样后面 `AgentLoop` 不需要自己处理“主模型挂了再换一个 provider”。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .base import LLMProvider, LLMResponse
|
||||
from .runtime import ProviderRuntime
|
||||
|
||||
|
||||
class FallbackProviderChain(LLMProvider):
|
||||
"""把 primary/fallback provider 封装成一个统一的 LLMProvider。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
primary_runtime: ProviderRuntime,
|
||||
primary_provider: LLMProvider,
|
||||
fallback_runtime: ProviderRuntime | None = None,
|
||||
fallback_provider: LLMProvider | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
api_key=primary_runtime.api_key,
|
||||
api_base=primary_runtime.api_base,
|
||||
request_timeout_seconds=primary_runtime.request_timeout_seconds,
|
||||
)
|
||||
self.primary_runtime = primary_runtime
|
||||
self.primary_provider = primary_provider
|
||||
self.fallback_runtime = fallback_runtime
|
||||
self.fallback_provider = fallback_provider
|
||||
# 这里只记录“最近一次 chat 实际用了哪条链”,用于调试和测试。
|
||||
# 真正的选路决策必须按调用粒度重新从 primary 开始,不能跨调用粘住 fallback。
|
||||
self._last_runtime = primary_runtime
|
||||
self._last_provider = primary_provider
|
||||
self._last_call_used_fallback = False
|
||||
|
||||
@property
|
||||
def fallback_activated(self) -> bool:
|
||||
"""最近一次 chat 是否实际用到了 fallback。"""
|
||||
|
||||
return self._last_call_used_fallback
|
||||
|
||||
@property
|
||||
def active_runtime(self) -> ProviderRuntime:
|
||||
"""最近一次 chat 实际使用的 runtime。"""
|
||||
|
||||
return self._last_runtime
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
messages: list[dict],
|
||||
tools: list[dict] | None = None,
|
||||
model: str | None = None,
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.7,
|
||||
) -> LLMResponse:
|
||||
self._last_provider = self.primary_provider
|
||||
self._last_runtime = self.primary_runtime
|
||||
self._last_call_used_fallback = False
|
||||
|
||||
response = await self._safe_chat(
|
||||
self.primary_provider,
|
||||
self.primary_runtime,
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
model=model or self.primary_runtime.model,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
)
|
||||
response = self._decorate_response(response, self.primary_runtime)
|
||||
if not self._should_activate_fallback(response):
|
||||
return response
|
||||
|
||||
assert self.fallback_provider is not None
|
||||
assert self.fallback_runtime is not None
|
||||
|
||||
self._last_provider = self.fallback_provider
|
||||
self._last_runtime = self.fallback_runtime
|
||||
self._last_call_used_fallback = True
|
||||
|
||||
response = await self._safe_chat(
|
||||
self.fallback_provider,
|
||||
self.fallback_runtime,
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
model=self.fallback_runtime.model,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
)
|
||||
return self._decorate_response(response, self.fallback_runtime)
|
||||
|
||||
def get_default_model(self) -> str:
|
||||
return self.primary_runtime.model
|
||||
|
||||
def _should_activate_fallback(self, response: LLMResponse) -> bool:
|
||||
return (
|
||||
self.fallback_provider is not None
|
||||
and self.fallback_runtime is not None
|
||||
and response.finish_reason == "error"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _safe_chat(
|
||||
provider: LLMProvider,
|
||||
runtime: ProviderRuntime,
|
||||
*,
|
||||
messages: list[dict],
|
||||
tools: list[dict] | None,
|
||||
model: str,
|
||||
max_tokens: int,
|
||||
temperature: float,
|
||||
) -> LLMResponse:
|
||||
"""把 provider 抛出的异常也收敛成统一 error response。
|
||||
|
||||
这样 fallback 的触发条件就不依赖“每个 provider 都记得自己 catch 异常”。
|
||||
"""
|
||||
|
||||
try:
|
||||
return await provider.chat(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
model=model,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
)
|
||||
except Exception as exc:
|
||||
return LLMResponse(
|
||||
content=f"Error: {exc}",
|
||||
finish_reason="error",
|
||||
provider_name=runtime.provider_name,
|
||||
model=runtime.model,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _decorate_response(response: LLMResponse, runtime: ProviderRuntime) -> LLMResponse:
|
||||
if response.provider_name is None:
|
||||
response.provider_name = runtime.provider_name
|
||||
if response.model is None:
|
||||
response.model = runtime.model
|
||||
return response
|
||||
274
app-instance/backend/beaver/engine/providers/codex.py
Normal file
274
app-instance/backend/beaver/engine/providers/codex.py
Normal file
@ -0,0 +1,274 @@
|
||||
"""OpenAI Codex Responses provider."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
from typing import Any, AsyncGenerator
|
||||
|
||||
from .base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
|
||||
try: # pragma: no cover - optional dependency
|
||||
import httpx
|
||||
except ModuleNotFoundError: # pragma: no cover
|
||||
httpx = None # type: ignore[assignment]
|
||||
|
||||
try: # pragma: no cover - optional dependency
|
||||
from oauth_cli_kit import get_token as get_codex_token
|
||||
except ModuleNotFoundError: # pragma: no cover
|
||||
get_codex_token = None # type: ignore[assignment]
|
||||
|
||||
DEFAULT_CODEX_URL = "https://chatgpt.com/backend-api/codex/responses"
|
||||
DEFAULT_ORIGINATOR = "beaver"
|
||||
|
||||
|
||||
class OpenAICodexProvider(LLMProvider):
|
||||
"""使用 Codex OAuth 调用 Responses API。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
default_model: str = "openai-codex/gpt-5.1-codex",
|
||||
request_timeout_seconds: float | None = None,
|
||||
) -> None:
|
||||
super().__init__(api_key=None, api_base=None, request_timeout_seconds=request_timeout_seconds)
|
||||
self.default_model = default_model
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
model: str | None = None,
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.7,
|
||||
) -> LLMResponse:
|
||||
if httpx is None or get_codex_token is None:
|
||||
return LLMResponse(content="Error: codex dependencies are not installed", finish_reason="error", provider_name="openai_codex")
|
||||
|
||||
resolved_model = model or self.default_model
|
||||
system_prompt, input_items = _convert_messages(messages)
|
||||
token = await asyncio.to_thread(get_codex_token)
|
||||
headers = _build_headers(token.account_id, token.access)
|
||||
body: dict[str, Any] = {
|
||||
"model": _strip_model_prefix(resolved_model),
|
||||
"store": False,
|
||||
"stream": True,
|
||||
"instructions": system_prompt,
|
||||
"input": input_items,
|
||||
"text": {"verbosity": "medium"},
|
||||
"include": ["reasoning.encrypted_content"],
|
||||
"prompt_cache_key": _prompt_cache_key(messages),
|
||||
"tool_choice": "auto",
|
||||
"parallel_tool_calls": True,
|
||||
}
|
||||
if tools:
|
||||
body["tools"] = _convert_tools(tools)
|
||||
|
||||
try:
|
||||
content, tool_calls, finish_reason = await _request_codex(
|
||||
DEFAULT_CODEX_URL,
|
||||
headers,
|
||||
body,
|
||||
verify=True,
|
||||
timeout_seconds=self.request_timeout_seconds or 600.0,
|
||||
)
|
||||
except Exception as exc:
|
||||
return LLMResponse(content=f"Error calling Codex: {exc}", finish_reason="error", provider_name="openai_codex")
|
||||
|
||||
return LLMResponse(
|
||||
content=content,
|
||||
tool_calls=tool_calls,
|
||||
finish_reason=finish_reason,
|
||||
provider_name="openai_codex",
|
||||
model=resolved_model,
|
||||
)
|
||||
|
||||
def get_default_model(self) -> str:
|
||||
return self.default_model
|
||||
|
||||
|
||||
def _strip_model_prefix(model: str) -> str:
|
||||
if model.startswith("openai-codex/") or model.startswith("openai_codex/"):
|
||||
return model.split("/", 1)[1]
|
||||
return model
|
||||
|
||||
|
||||
def _build_headers(account_id: str, token: str) -> dict[str, str]:
|
||||
return {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"chatgpt-account-id": account_id,
|
||||
"OpenAI-Beta": "responses=experimental",
|
||||
"originator": DEFAULT_ORIGINATOR,
|
||||
"User-Agent": "beaver (python)",
|
||||
"accept": "text/event-stream",
|
||||
"content-type": "application/json",
|
||||
}
|
||||
|
||||
|
||||
async def _request_codex(
|
||||
url: str,
|
||||
headers: dict[str, str],
|
||||
body: dict[str, Any],
|
||||
verify: bool,
|
||||
timeout_seconds: float,
|
||||
) -> tuple[str, list[ToolCallRequest], str]:
|
||||
async with httpx.AsyncClient(timeout=timeout_seconds, verify=verify) as client:
|
||||
async with client.stream("POST", url, headers=headers, json=body) as response:
|
||||
if response.status_code != 200:
|
||||
text = await response.aread()
|
||||
raise RuntimeError(_friendly_error(response.status_code, text.decode("utf-8", "ignore")))
|
||||
return await _consume_sse(response)
|
||||
|
||||
|
||||
def _convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
converted: list[dict[str, Any]] = []
|
||||
for tool in tools:
|
||||
fn = (tool.get("function") or {}) if tool.get("type") == "function" else tool
|
||||
name = fn.get("name")
|
||||
if not name:
|
||||
continue
|
||||
params = fn.get("parameters") or {}
|
||||
converted.append(
|
||||
{
|
||||
"type": "function",
|
||||
"name": name,
|
||||
"description": fn.get("description") or "",
|
||||
"parameters": params if isinstance(params, dict) else {},
|
||||
}
|
||||
)
|
||||
return converted
|
||||
|
||||
|
||||
def _convert_messages(messages: list[dict[str, Any]]) -> tuple[str, list[dict[str, Any]]]:
|
||||
system_prompt = ""
|
||||
input_items: list[dict[str, Any]] = []
|
||||
for index, message in enumerate(messages):
|
||||
role = message.get("role")
|
||||
content = message.get("content")
|
||||
if role == "system":
|
||||
system_prompt = content if isinstance(content, str) else ""
|
||||
continue
|
||||
if role == "user":
|
||||
input_items.append(_convert_user_message(content))
|
||||
continue
|
||||
if role == "assistant":
|
||||
if isinstance(content, str) and content:
|
||||
input_items.append(
|
||||
{
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": [{"type": "output_text", "text": content}],
|
||||
"status": "completed",
|
||||
"id": f"msg_{index}",
|
||||
}
|
||||
)
|
||||
for tool_call in message.get("tool_calls", []) or []:
|
||||
fn = tool_call.get("function") or {}
|
||||
call_id, item_id = _split_tool_call_id(tool_call.get("id"))
|
||||
input_items.append(
|
||||
{
|
||||
"type": "function_call",
|
||||
"id": item_id or f"fc_{index}",
|
||||
"call_id": call_id or f"call_{index}",
|
||||
"name": fn.get("name"),
|
||||
"arguments": fn.get("arguments") or "{}",
|
||||
}
|
||||
)
|
||||
continue
|
||||
if role == "tool":
|
||||
call_id, _ = _split_tool_call_id(message.get("tool_call_id"))
|
||||
output_text = content if isinstance(content, str) else json.dumps(content, ensure_ascii=False)
|
||||
input_items.append(
|
||||
{
|
||||
"type": "function_call_output",
|
||||
"call_id": call_id,
|
||||
"output": output_text,
|
||||
}
|
||||
)
|
||||
return system_prompt, input_items
|
||||
|
||||
|
||||
def _convert_user_message(content: Any) -> dict[str, Any]:
|
||||
if isinstance(content, str):
|
||||
return {"role": "user", "content": [{"type": "input_text", "text": content}]}
|
||||
if isinstance(content, list):
|
||||
converted: list[dict[str, Any]] = []
|
||||
for item in content:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
if item.get("type") == "text":
|
||||
converted.append({"type": "input_text", "text": item.get("text", "")})
|
||||
elif item.get("type") == "image_url":
|
||||
url = (item.get("image_url") or {}).get("url")
|
||||
if url:
|
||||
converted.append({"type": "input_image", "image_url": url, "detail": "auto"})
|
||||
if converted:
|
||||
return {"role": "user", "content": converted}
|
||||
return {"role": "user", "content": [{"type": "input_text", "text": ""}]}
|
||||
|
||||
|
||||
def _split_tool_call_id(tool_call_id: Any) -> tuple[str, str | None]:
|
||||
if isinstance(tool_call_id, str) and tool_call_id:
|
||||
if "|" in tool_call_id:
|
||||
call_id, item_id = tool_call_id.split("|", 1)
|
||||
return call_id, item_id or None
|
||||
return tool_call_id, None
|
||||
return "call_0", None
|
||||
|
||||
|
||||
def _prompt_cache_key(messages: list[dict[str, Any]]) -> str:
|
||||
raw = json.dumps(messages, ensure_ascii=True, sort_keys=True)
|
||||
return hashlib.sha256(raw.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
async def _iter_sse(response: Any) -> AsyncGenerator[dict[str, Any], None]:
|
||||
buffer: list[str] = []
|
||||
async for line in response.aiter_lines():
|
||||
if line == "":
|
||||
if buffer:
|
||||
data_lines = [item[5:].strip() for item in buffer if item.startswith("data:")]
|
||||
buffer = []
|
||||
if not data_lines:
|
||||
continue
|
||||
data = "\n".join(data_lines).strip()
|
||||
if not data or data == "[DONE]":
|
||||
continue
|
||||
try:
|
||||
yield json.loads(data)
|
||||
except Exception:
|
||||
continue
|
||||
continue
|
||||
buffer.append(line)
|
||||
|
||||
|
||||
async def _consume_sse(response: Any) -> tuple[str, list[ToolCallRequest], str]:
|
||||
content_parts: list[str] = []
|
||||
tool_calls: list[ToolCallRequest] = []
|
||||
finish_reason = "stop"
|
||||
async for event in _iter_sse(response):
|
||||
event_type = event.get("type")
|
||||
if event_type == "response.output_text.delta":
|
||||
delta = event.get("delta") or ""
|
||||
content_parts.append(delta)
|
||||
elif event_type == "response.output_item.added":
|
||||
item = event.get("item") or {}
|
||||
if item.get("type") == "function_call":
|
||||
raw_arguments = item.get("arguments") or "{}"
|
||||
try:
|
||||
arguments = json.loads(raw_arguments) if isinstance(raw_arguments, str) else raw_arguments
|
||||
except json.JSONDecodeError:
|
||||
arguments = {}
|
||||
tool_calls.append(
|
||||
ToolCallRequest(
|
||||
id=f"{item.get('call_id', 'call')}|{item.get('id', '')}",
|
||||
name=item.get("name", ""),
|
||||
arguments=arguments,
|
||||
)
|
||||
)
|
||||
elif event_type == "response.completed":
|
||||
finish_reason = event.get("response", {}).get("status", "completed")
|
||||
return "".join(content_parts) or None, tool_calls, finish_reason
|
||||
|
||||
|
||||
def _friendly_error(status_code: int, body: str) -> str:
|
||||
return f"Codex API error ({status_code}): {body[:400]}"
|
||||
106
app-instance/backend/beaver/engine/providers/custom.py
Normal file
106
app-instance/backend/beaver/engine/providers/custom.py
Normal file
@ -0,0 +1,106 @@
|
||||
"""Direct OpenAI-compatible provider — bypasses LiteLLM."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
|
||||
try: # pragma: no cover - optional dependency
|
||||
import json_repair
|
||||
except ModuleNotFoundError: # pragma: no cover
|
||||
json_repair = None # type: ignore[assignment]
|
||||
|
||||
try: # pragma: no cover - optional dependency
|
||||
from openai import AsyncOpenAI
|
||||
except ModuleNotFoundError: # pragma: no cover
|
||||
AsyncOpenAI = None # type: ignore[assignment]
|
||||
|
||||
|
||||
class CustomProvider(LLMProvider):
|
||||
"""直接连接任意 OpenAI-compatible endpoint。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str = "no-key",
|
||||
api_base: str = "http://localhost:8000/v1",
|
||||
default_model: str = "default",
|
||||
request_timeout_seconds: float | None = None,
|
||||
) -> None:
|
||||
super().__init__(api_key, api_base, request_timeout_seconds=request_timeout_seconds)
|
||||
self.default_model = default_model
|
||||
self._client = None
|
||||
|
||||
def _client_or_raise(self):
|
||||
if AsyncOpenAI is None:
|
||||
raise RuntimeError("openai package is not installed")
|
||||
if self._client is None:
|
||||
self._client = AsyncOpenAI(
|
||||
api_key=self.api_key,
|
||||
base_url=self.api_base,
|
||||
timeout=self.request_timeout_seconds,
|
||||
)
|
||||
return self._client
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
model: str | None = None,
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.7,
|
||||
) -> LLMResponse:
|
||||
client = self._client_or_raise()
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": model or self.default_model,
|
||||
"messages": self.sanitize_empty_content(messages),
|
||||
"max_tokens": max(1, max_tokens),
|
||||
"temperature": temperature,
|
||||
}
|
||||
if tools:
|
||||
kwargs.update(tools=tools, tool_choice="auto")
|
||||
try:
|
||||
response = await client.chat.completions.create(**kwargs)
|
||||
except Exception as exc:
|
||||
return LLMResponse(content=f"Error: {exc}", finish_reason="error", provider_name="custom")
|
||||
|
||||
choice = response.choices[0]
|
||||
message = choice.message
|
||||
parsed_tool_calls: list[ToolCallRequest] = []
|
||||
for tool_call in message.tool_calls or []:
|
||||
raw_arguments = tool_call.function.arguments
|
||||
if isinstance(raw_arguments, str):
|
||||
if json_repair is not None:
|
||||
arguments = json_repair.loads(raw_arguments)
|
||||
else:
|
||||
import json
|
||||
arguments = json.loads(raw_arguments)
|
||||
else:
|
||||
arguments = raw_arguments
|
||||
parsed_tool_calls.append(
|
||||
ToolCallRequest(
|
||||
id=tool_call.id,
|
||||
name=tool_call.function.name,
|
||||
arguments=arguments,
|
||||
)
|
||||
)
|
||||
usage = getattr(response, "usage", None)
|
||||
usage_payload = {}
|
||||
if usage is not None:
|
||||
usage_payload = {
|
||||
"prompt_tokens": getattr(usage, "prompt_tokens", 0),
|
||||
"completion_tokens": getattr(usage, "completion_tokens", 0),
|
||||
"total_tokens": getattr(usage, "total_tokens", 0),
|
||||
}
|
||||
return LLMResponse(
|
||||
content=message.content,
|
||||
tool_calls=parsed_tool_calls,
|
||||
finish_reason=choice.finish_reason or "stop",
|
||||
usage=usage_payload,
|
||||
reasoning_content=getattr(message, "reasoning_content", None),
|
||||
provider_name="custom",
|
||||
model=model or self.default_model,
|
||||
)
|
||||
|
||||
def get_default_model(self) -> str:
|
||||
return self.default_model
|
||||
235
app-instance/backend/beaver/engine/providers/factory.py
Normal file
235
app-instance/backend/beaver/engine/providers/factory.py
Normal file
@ -0,0 +1,235 @@
|
||||
"""Provider runtime 的统一工厂入口。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from .anthropic import AnthropicProvider
|
||||
from .base import LLMProvider
|
||||
from .chain import FallbackProviderChain
|
||||
from .codex import OpenAICodexProvider
|
||||
from .custom import CustomProvider
|
||||
from .litellm import LiteLLMProvider
|
||||
from .runtime import (
|
||||
ProviderRoutingConfig,
|
||||
ProviderRuntime,
|
||||
ProviderTarget,
|
||||
normalize_provider_target,
|
||||
resolve_auxiliary_runtime,
|
||||
resolve_embedding_runtime,
|
||||
resolve_fallback_runtime,
|
||||
resolve_provider_runtime,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ProviderBundle:
|
||||
"""一次运行所需的 provider 组合。
|
||||
|
||||
这里把三条常见链路收口到一起:
|
||||
- `main`:主对话
|
||||
- `fallback`:主链失败后的备用 provider
|
||||
- `auxiliary`:搜索摘要、压缩、memory flush 等辅助任务
|
||||
"""
|
||||
|
||||
main_runtime: ProviderRuntime
|
||||
main_provider: LLMProvider
|
||||
fallback_runtime: ProviderRuntime | None = None
|
||||
fallback_provider: LLMProvider | None = None
|
||||
auxiliary_runtime: ProviderRuntime | None = None
|
||||
auxiliary_provider: LLMProvider | None = None
|
||||
embedding_runtime: ProviderRuntime | None = None
|
||||
|
||||
|
||||
def build_provider_runtime(**kwargs: Any) -> ProviderRuntime:
|
||||
"""构建统一 provider runtime。"""
|
||||
|
||||
return resolve_provider_runtime(**kwargs)
|
||||
|
||||
|
||||
def make_provider_from_runtime(runtime: ProviderRuntime) -> LLMProvider:
|
||||
"""根据 runtime 创建具体 provider 实例。"""
|
||||
|
||||
if runtime.spec.provider_impl == "custom":
|
||||
return CustomProvider(
|
||||
api_key=runtime.api_key or "no-key",
|
||||
api_base=runtime.api_base or "http://localhost:8000/v1",
|
||||
default_model=runtime.default_model or runtime.model,
|
||||
request_timeout_seconds=runtime.request_timeout_seconds,
|
||||
)
|
||||
|
||||
if runtime.spec.provider_impl == "codex":
|
||||
return OpenAICodexProvider(
|
||||
default_model=runtime.default_model or runtime.model,
|
||||
request_timeout_seconds=runtime.request_timeout_seconds,
|
||||
)
|
||||
|
||||
if runtime.spec.provider_impl == "anthropic":
|
||||
return AnthropicProvider(
|
||||
api_key=runtime.api_key,
|
||||
default_model=runtime.default_model or runtime.model,
|
||||
api_base=runtime.api_base,
|
||||
request_timeout_seconds=runtime.request_timeout_seconds,
|
||||
)
|
||||
|
||||
return LiteLLMProvider(
|
||||
api_key=runtime.api_key,
|
||||
api_base=runtime.api_base,
|
||||
default_model=runtime.default_model or runtime.model,
|
||||
provider_name=runtime.provider_name,
|
||||
extra_headers=runtime.extra_headers,
|
||||
request_timeout_seconds=runtime.request_timeout_seconds,
|
||||
routing=runtime.routing,
|
||||
)
|
||||
|
||||
|
||||
def make_main_provider(**kwargs: Any) -> tuple[ProviderRuntime, LLMProvider]:
|
||||
"""构建主对话 provider。"""
|
||||
|
||||
fallback_target = kwargs.pop("fallback_target", None)
|
||||
if fallback_target is None and "fallback_model" in kwargs:
|
||||
fallback_target = kwargs.pop("fallback_model")
|
||||
|
||||
runtime = build_provider_runtime(
|
||||
auxiliary=False,
|
||||
fallback_target=fallback_target,
|
||||
role="main",
|
||||
source="main_config",
|
||||
**kwargs,
|
||||
)
|
||||
provider = make_provider_from_runtime(runtime)
|
||||
fallback_pair = make_fallback_provider(runtime, fallback_target)
|
||||
if fallback_pair is None:
|
||||
return runtime, provider
|
||||
fallback_runtime, fallback_provider = fallback_pair
|
||||
return runtime, FallbackProviderChain(runtime, provider, fallback_runtime, fallback_provider)
|
||||
|
||||
|
||||
def make_fallback_provider(
|
||||
primary_runtime: ProviderRuntime,
|
||||
fallback_target: ProviderTarget | dict[str, Any] | None = None,
|
||||
) -> tuple[ProviderRuntime, LLMProvider] | None:
|
||||
"""构建 fallback provider。"""
|
||||
|
||||
runtime = resolve_fallback_runtime(primary_runtime, fallback_target or primary_runtime.fallback_target)
|
||||
if runtime is None:
|
||||
return None
|
||||
return runtime, make_provider_from_runtime(runtime)
|
||||
|
||||
|
||||
def make_aux_provider(
|
||||
main_runtime: ProviderRuntime | None = None,
|
||||
*,
|
||||
target: ProviderTarget | dict[str, Any] | None = None,
|
||||
task_name: str = "auxiliary",
|
||||
**kwargs: Any,
|
||||
) -> tuple[ProviderRuntime, LLMProvider]:
|
||||
"""构建辅助任务 provider。"""
|
||||
|
||||
if target is None and kwargs:
|
||||
target = kwargs
|
||||
|
||||
if main_runtime is not None:
|
||||
runtime = resolve_auxiliary_runtime(main_runtime, target, task_name=task_name)
|
||||
else:
|
||||
normalized = normalize_provider_target(target)
|
||||
if normalized is None or not normalized.model:
|
||||
raise ValueError("Auxiliary provider without main_runtime requires at least a model")
|
||||
runtime = build_provider_runtime(
|
||||
model=normalized.model,
|
||||
provider_name=normalized.provider_name,
|
||||
api_key=normalized.api_key,
|
||||
api_base=normalized.api_base,
|
||||
request_timeout_seconds=normalized.request_timeout_seconds,
|
||||
extra_headers=normalized.extra_headers,
|
||||
routing=normalized.routing,
|
||||
auxiliary=True,
|
||||
role=task_name,
|
||||
source="auxiliary_config",
|
||||
)
|
||||
return runtime, make_provider_from_runtime(runtime)
|
||||
|
||||
|
||||
def make_embedding_runtime(
|
||||
main_runtime: ProviderRuntime,
|
||||
*,
|
||||
target: ProviderTarget | dict[str, Any] | None = None,
|
||||
default_model: str = "text-embedding-v4",
|
||||
) -> ProviderRuntime | None:
|
||||
"""构建 embedding 专用 runtime。"""
|
||||
|
||||
return resolve_embedding_runtime(main_runtime, target=target, default_model=default_model)
|
||||
|
||||
|
||||
def make_provider_bundle(
|
||||
*,
|
||||
auxiliary_target: ProviderTarget | dict[str, Any] | None = None,
|
||||
auxiliary_task_name: str = "auxiliary",
|
||||
embedding_target: ProviderTarget | dict[str, Any] | None = None,
|
||||
embedding_model: str = "text-embedding-v4",
|
||||
**kwargs: Any,
|
||||
) -> ProviderBundle:
|
||||
"""一次性构建 main/fallback/aux 三条 provider 链。"""
|
||||
|
||||
runtime_kwargs = dict(kwargs)
|
||||
fallback_target = runtime_kwargs.pop("fallback_target", None)
|
||||
if fallback_target is None and "fallback_model" in kwargs:
|
||||
fallback_target = runtime_kwargs.pop("fallback_model")
|
||||
|
||||
main_runtime = build_provider_runtime(
|
||||
auxiliary=False,
|
||||
fallback_target=fallback_target,
|
||||
role="main",
|
||||
source="main_config",
|
||||
**runtime_kwargs,
|
||||
)
|
||||
primary_provider = make_provider_from_runtime(main_runtime)
|
||||
fallback_pair = make_fallback_provider(main_runtime, fallback_target)
|
||||
if fallback_pair is None:
|
||||
main_provider: LLMProvider = primary_provider
|
||||
fallback_runtime = None
|
||||
fallback_provider = None
|
||||
else:
|
||||
fallback_runtime, fallback_provider = fallback_pair
|
||||
main_provider = FallbackProviderChain(main_runtime, primary_provider, fallback_runtime, fallback_provider)
|
||||
|
||||
auxiliary_runtime = None
|
||||
auxiliary_provider = None
|
||||
if auxiliary_target is not None:
|
||||
auxiliary_runtime, auxiliary_provider = make_aux_provider(
|
||||
main_runtime,
|
||||
target=auxiliary_target,
|
||||
task_name=auxiliary_task_name,
|
||||
)
|
||||
|
||||
embedding_runtime = make_embedding_runtime(
|
||||
main_runtime,
|
||||
target=embedding_target,
|
||||
default_model=embedding_model,
|
||||
)
|
||||
|
||||
return ProviderBundle(
|
||||
main_runtime=main_runtime,
|
||||
main_provider=main_provider,
|
||||
fallback_runtime=fallback_runtime,
|
||||
fallback_provider=fallback_provider,
|
||||
auxiliary_runtime=auxiliary_runtime,
|
||||
auxiliary_provider=auxiliary_provider,
|
||||
embedding_runtime=embedding_runtime,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ProviderBundle",
|
||||
"ProviderRoutingConfig",
|
||||
"ProviderRuntime",
|
||||
"ProviderTarget",
|
||||
"build_provider_runtime",
|
||||
"make_aux_provider",
|
||||
"make_embedding_runtime",
|
||||
"make_fallback_provider",
|
||||
"make_main_provider",
|
||||
"make_provider_bundle",
|
||||
"make_provider_from_runtime",
|
||||
]
|
||||
230
app-instance/backend/beaver/engine/providers/litellm.py
Normal file
230
app-instance/backend/beaver/engine/providers/litellm.py
Normal file
@ -0,0 +1,230 @@
|
||||
"""LiteLLM provider implementation for multi-provider support."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
import json
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from .base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
from .registry import find_by_model, find_gateway
|
||||
from .runtime import ProviderRoutingConfig
|
||||
|
||||
try: # pragma: no cover - optional dependency
|
||||
import json_repair
|
||||
except ModuleNotFoundError: # pragma: no cover
|
||||
json_repair = None # type: ignore[assignment]
|
||||
|
||||
try: # pragma: no cover - optional dependency
|
||||
import litellm
|
||||
from litellm import acompletion
|
||||
except ModuleNotFoundError: # pragma: no cover
|
||||
litellm = None # type: ignore[assignment]
|
||||
acompletion = None # type: ignore[assignment]
|
||||
|
||||
_ALLOWED_MSG_KEYS = frozenset({"role", "content", "tool_calls", "tool_call_id", "name"})
|
||||
|
||||
|
||||
class LiteLLMProvider(LLMProvider):
|
||||
"""通过 LiteLLM 统一访问大多数 provider。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str | None = None,
|
||||
api_base: str | None = None,
|
||||
default_model: str = "anthropic/claude-opus-4-5",
|
||||
extra_headers: dict[str, str] | None = None,
|
||||
provider_name: str | None = None,
|
||||
request_timeout_seconds: float | None = None,
|
||||
routing: ProviderRoutingConfig | None = None,
|
||||
) -> None:
|
||||
super().__init__(api_key, api_base, request_timeout_seconds=request_timeout_seconds)
|
||||
self.default_model = default_model
|
||||
self.extra_headers = extra_headers or {}
|
||||
self.routing = routing
|
||||
self.provider_name = provider_name
|
||||
self._gateway = find_gateway(provider_name, api_key, api_base)
|
||||
if litellm is not None:
|
||||
litellm.suppress_debug_info = True
|
||||
litellm.drop_params = True
|
||||
|
||||
def _build_env_overrides(self, api_key: str | None, api_base: str | None, model: str) -> dict[str, str]:
|
||||
"""为当前请求生成 LiteLLM 依赖的临时环境变量。
|
||||
|
||||
LiteLLM 对部分 provider 仍然优先读取环境变量。为了避免不同 runtime
|
||||
之间互相污染,这里只生成“本次请求需要的 env 覆盖”,真正调用时再临时注入。
|
||||
"""
|
||||
|
||||
if not api_key:
|
||||
return {}
|
||||
spec = self._gateway or find_by_model(model)
|
||||
if spec is None or not spec.env_key:
|
||||
return {}
|
||||
overrides: dict[str, str] = {spec.env_key: api_key}
|
||||
effective_base = api_base or spec.default_api_base
|
||||
for env_name, env_value in spec.env_extras:
|
||||
resolved = env_value.replace("{api_key}", api_key).replace("{api_base}", effective_base)
|
||||
overrides[env_name] = resolved
|
||||
return overrides
|
||||
|
||||
@contextmanager
|
||||
def _temporary_env(self, overrides: dict[str, str]):
|
||||
"""只在当前请求期间注入 provider 需要的环境变量。"""
|
||||
|
||||
if not overrides:
|
||||
yield
|
||||
return
|
||||
|
||||
sentinel = object()
|
||||
previous: dict[str, object] = {}
|
||||
for key, value in overrides.items():
|
||||
previous[key] = os.environ.get(key, sentinel)
|
||||
os.environ[key] = value
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
for key, old_value in previous.items():
|
||||
if old_value is sentinel:
|
||||
os.environ.pop(key, None)
|
||||
else:
|
||||
os.environ[key] = str(old_value)
|
||||
|
||||
def _resolve_model(self, model: str) -> str:
|
||||
if self._gateway:
|
||||
prefix = self._gateway.litellm_prefix
|
||||
resolved = model.split("/")[-1] if self._gateway.strip_model_prefix else model
|
||||
if prefix and not resolved.startswith(f"{prefix}/"):
|
||||
resolved = f"{prefix}/{resolved}"
|
||||
return resolved
|
||||
spec = find_by_model(model)
|
||||
if spec and spec.litellm_prefix:
|
||||
if not any(model.startswith(prefix) for prefix in spec.skip_prefixes):
|
||||
model = f"{spec.litellm_prefix}/{model}"
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def _sanitize_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
sanitized = []
|
||||
for message in messages:
|
||||
clean = {key: value for key, value in message.items() if key in _ALLOWED_MSG_KEYS}
|
||||
if clean.get("role") == "assistant" and "content" not in clean:
|
||||
clean["content"] = None
|
||||
sanitized.append(clean)
|
||||
return sanitized
|
||||
|
||||
def _apply_model_overrides(self, original_model: str, kwargs: dict[str, Any]) -> None:
|
||||
spec = find_by_model(original_model)
|
||||
if spec is None:
|
||||
return
|
||||
model_lower = original_model.lower()
|
||||
for pattern, overrides in spec.model_overrides:
|
||||
if pattern in model_lower:
|
||||
kwargs.update(overrides)
|
||||
return
|
||||
|
||||
def _apply_openrouter_routing(self, kwargs: dict[str, Any]) -> None:
|
||||
if self.provider_name != "openrouter" or self.routing is None:
|
||||
return
|
||||
provider_payload: dict[str, Any] = {}
|
||||
if self.routing.sort:
|
||||
provider_payload["sort"] = self.routing.sort
|
||||
if self.routing.only:
|
||||
provider_payload["only"] = self.routing.only
|
||||
if self.routing.ignore:
|
||||
provider_payload["ignore"] = self.routing.ignore
|
||||
if self.routing.order:
|
||||
provider_payload["order"] = self.routing.order
|
||||
if self.routing.require_parameters:
|
||||
provider_payload["require_parameters"] = True
|
||||
if self.routing.data_collection:
|
||||
provider_payload["data_collection"] = self.routing.data_collection
|
||||
if provider_payload:
|
||||
kwargs["provider"] = provider_payload
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
model: str | None = None,
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.7,
|
||||
) -> LLMResponse:
|
||||
if acompletion is None:
|
||||
return LLMResponse(content="Error: litellm is not installed", finish_reason="error", provider_name=self.provider_name)
|
||||
|
||||
original_model = model or self.default_model
|
||||
resolved_model = self._resolve_model(original_model)
|
||||
sanitized_messages = self._sanitize_messages(self.sanitize_empty_content(messages))
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": resolved_model,
|
||||
"messages": sanitized_messages,
|
||||
"max_tokens": max(1, max_tokens),
|
||||
"temperature": temperature,
|
||||
}
|
||||
if self.api_key:
|
||||
kwargs["api_key"] = self.api_key
|
||||
if self.api_base:
|
||||
kwargs["api_base"] = self.api_base
|
||||
if self.extra_headers:
|
||||
kwargs["extra_headers"] = self.extra_headers
|
||||
if tools:
|
||||
kwargs["tools"] = tools
|
||||
kwargs["tool_choice"] = "auto"
|
||||
self._apply_model_overrides(original_model, kwargs)
|
||||
self._apply_openrouter_routing(kwargs)
|
||||
env_overrides = self._build_env_overrides(self.api_key, self.api_base, original_model)
|
||||
|
||||
try:
|
||||
with self._temporary_env(env_overrides):
|
||||
response = await acompletion(**kwargs)
|
||||
except Exception as exc:
|
||||
return LLMResponse(content=f"Error: {exc}", finish_reason="error", provider_name=self.provider_name, model=resolved_model)
|
||||
|
||||
choice = response.choices[0]
|
||||
message = choice.message
|
||||
tool_calls: list[ToolCallRequest] = []
|
||||
for tool_call in message.tool_calls or []:
|
||||
raw_arguments = tool_call.function.arguments
|
||||
if isinstance(raw_arguments, str):
|
||||
try:
|
||||
if json_repair is not None:
|
||||
arguments = json_repair.loads(raw_arguments)
|
||||
else:
|
||||
arguments = json.loads(raw_arguments)
|
||||
except Exception as exc:
|
||||
# 这里不要因为单个 tool_call 参数坏掉而直接炸掉整轮请求。
|
||||
# 后面的 ToolExecutor 会把这个标记转换成一条标准 tool failure。
|
||||
arguments = {
|
||||
"__beaver_tool_argument_parse_error__": str(exc),
|
||||
"__raw_arguments__": raw_arguments,
|
||||
}
|
||||
else:
|
||||
arguments = raw_arguments
|
||||
tool_calls.append(
|
||||
ToolCallRequest(
|
||||
id=tool_call.id,
|
||||
name=tool_call.function.name,
|
||||
arguments=arguments,
|
||||
)
|
||||
)
|
||||
usage = getattr(response, "usage", None)
|
||||
usage_payload = {}
|
||||
if usage is not None:
|
||||
usage_payload = {
|
||||
"prompt_tokens": getattr(usage, "prompt_tokens", 0),
|
||||
"completion_tokens": getattr(usage, "completion_tokens", 0),
|
||||
"total_tokens": getattr(usage, "total_tokens", 0),
|
||||
}
|
||||
return LLMResponse(
|
||||
content=getattr(message, "content", None),
|
||||
tool_calls=tool_calls,
|
||||
finish_reason=getattr(choice, "finish_reason", "stop") or "stop",
|
||||
usage=usage_payload,
|
||||
reasoning_content=getattr(message, "reasoning_content", None),
|
||||
provider_name=self.provider_name or "litellm",
|
||||
model=resolved_model,
|
||||
)
|
||||
|
||||
def get_default_model(self) -> str:
|
||||
return self.default_model
|
||||
249
app-instance/backend/beaver/engine/providers/registry.py
Normal file
249
app-instance/backend/beaver/engine/providers/registry.py
Normal file
@ -0,0 +1,249 @@
|
||||
"""Provider registry: 统一维护 provider 元数据与匹配规则。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class ProviderSpec:
|
||||
"""单个 provider 的元数据定义。"""
|
||||
|
||||
name: str
|
||||
keywords: tuple[str, ...]
|
||||
env_key: str
|
||||
display_name: str = ""
|
||||
litellm_prefix: str = ""
|
||||
skip_prefixes: tuple[str, ...] = ()
|
||||
env_extras: tuple[tuple[str, str], ...] = ()
|
||||
is_gateway: bool = False
|
||||
is_local: bool = False
|
||||
detect_by_key_prefix: str = ""
|
||||
detect_by_base_keyword: str = ""
|
||||
default_api_base: str = ""
|
||||
strip_model_prefix: bool = False
|
||||
model_overrides: tuple[tuple[str, dict[str, Any]], ...] = ()
|
||||
is_oauth: bool = False
|
||||
is_direct: bool = False
|
||||
supports_prompt_caching: bool = False
|
||||
api_mode: str = "chat_completions"
|
||||
provider_impl: str = "litellm"
|
||||
|
||||
@property
|
||||
def label(self) -> str:
|
||||
return self.display_name or self.name.title()
|
||||
|
||||
|
||||
PROVIDERS: tuple[ProviderSpec, ...] = (
|
||||
ProviderSpec(
|
||||
name="custom",
|
||||
keywords=(),
|
||||
env_key="",
|
||||
display_name="Custom",
|
||||
is_direct=True,
|
||||
provider_impl="custom",
|
||||
api_mode="chat_completions",
|
||||
),
|
||||
ProviderSpec(
|
||||
name="openrouter",
|
||||
keywords=("openrouter",),
|
||||
env_key="OPENROUTER_API_KEY",
|
||||
display_name="OpenRouter",
|
||||
litellm_prefix="openrouter",
|
||||
is_gateway=True,
|
||||
detect_by_key_prefix="sk-or-",
|
||||
detect_by_base_keyword="openrouter",
|
||||
default_api_base="https://openrouter.ai/api/v1",
|
||||
supports_prompt_caching=True,
|
||||
),
|
||||
ProviderSpec(
|
||||
name="aihubmix",
|
||||
keywords=("aihubmix",),
|
||||
env_key="OPENAI_API_KEY",
|
||||
display_name="AiHubMix",
|
||||
litellm_prefix="openai",
|
||||
is_gateway=True,
|
||||
detect_by_base_keyword="aihubmix",
|
||||
default_api_base="https://aihubmix.com/v1",
|
||||
strip_model_prefix=True,
|
||||
),
|
||||
ProviderSpec(
|
||||
name="siliconflow",
|
||||
keywords=("siliconflow",),
|
||||
env_key="OPENAI_API_KEY",
|
||||
display_name="SiliconFlow",
|
||||
litellm_prefix="openai",
|
||||
is_gateway=True,
|
||||
detect_by_base_keyword="siliconflow",
|
||||
default_api_base="https://api.siliconflow.cn/v1",
|
||||
),
|
||||
ProviderSpec(
|
||||
name="volcengine",
|
||||
keywords=("volcengine", "volces", "ark"),
|
||||
env_key="OPENAI_API_KEY",
|
||||
display_name="VolcEngine",
|
||||
litellm_prefix="volcengine",
|
||||
is_gateway=True,
|
||||
detect_by_base_keyword="volces",
|
||||
default_api_base="https://ark.cn-beijing.volces.com/api/v3",
|
||||
),
|
||||
ProviderSpec(
|
||||
name="anthropic",
|
||||
keywords=("anthropic", "claude"),
|
||||
env_key="ANTHROPIC_API_KEY",
|
||||
display_name="Anthropic",
|
||||
supports_prompt_caching=True,
|
||||
api_mode="anthropic_messages",
|
||||
provider_impl="anthropic",
|
||||
),
|
||||
ProviderSpec(
|
||||
name="openai",
|
||||
keywords=("openai", "gpt"),
|
||||
env_key="OPENAI_API_KEY",
|
||||
display_name="OpenAI",
|
||||
),
|
||||
ProviderSpec(
|
||||
name="openai_codex",
|
||||
keywords=("openai-codex", "codex"),
|
||||
env_key="",
|
||||
display_name="OpenAI Codex",
|
||||
is_oauth=True,
|
||||
detect_by_base_keyword="codex",
|
||||
default_api_base="https://chatgpt.com/backend-api",
|
||||
api_mode="codex_responses",
|
||||
provider_impl="codex",
|
||||
),
|
||||
ProviderSpec(
|
||||
name="github_copilot",
|
||||
keywords=("github_copilot", "copilot"),
|
||||
env_key="",
|
||||
display_name="Github Copilot",
|
||||
litellm_prefix="github_copilot",
|
||||
skip_prefixes=("github_copilot/",),
|
||||
is_oauth=True,
|
||||
),
|
||||
ProviderSpec(
|
||||
name="deepseek",
|
||||
keywords=("deepseek",),
|
||||
env_key="DEEPSEEK_API_KEY",
|
||||
display_name="DeepSeek",
|
||||
litellm_prefix="deepseek",
|
||||
skip_prefixes=("deepseek/",),
|
||||
),
|
||||
ProviderSpec(
|
||||
name="gemini",
|
||||
keywords=("gemini",),
|
||||
env_key="GEMINI_API_KEY",
|
||||
display_name="Gemini",
|
||||
litellm_prefix="gemini",
|
||||
skip_prefixes=("gemini/",),
|
||||
),
|
||||
ProviderSpec(
|
||||
name="zhipu",
|
||||
keywords=("zhipu", "glm", "zai"),
|
||||
env_key="ZAI_API_KEY",
|
||||
display_name="Zhipu AI",
|
||||
litellm_prefix="zai",
|
||||
skip_prefixes=("zhipu/", "zai/", "openrouter/", "hosted_vllm/"),
|
||||
env_extras=(("ZHIPUAI_API_KEY", "{api_key}"),),
|
||||
),
|
||||
ProviderSpec(
|
||||
name="dashscope",
|
||||
keywords=("qwen", "dashscope"),
|
||||
env_key="DASHSCOPE_API_KEY",
|
||||
display_name="DashScope",
|
||||
litellm_prefix="dashscope",
|
||||
skip_prefixes=("dashscope/", "openrouter/"),
|
||||
),
|
||||
ProviderSpec(
|
||||
name="moonshot",
|
||||
keywords=("moonshot", "kimi"),
|
||||
env_key="MOONSHOT_API_KEY",
|
||||
display_name="Moonshot",
|
||||
litellm_prefix="moonshot",
|
||||
skip_prefixes=("moonshot/", "openrouter/"),
|
||||
env_extras=(("MOONSHOT_API_BASE", "{api_base}"),),
|
||||
default_api_base="https://api.moonshot.ai/v1",
|
||||
model_overrides=(("kimi-k2.5", {"temperature": 1.0}),),
|
||||
),
|
||||
ProviderSpec(
|
||||
name="minimax",
|
||||
keywords=("minimax",),
|
||||
env_key="MINIMAX_API_KEY",
|
||||
display_name="MiniMax",
|
||||
litellm_prefix="minimax",
|
||||
skip_prefixes=("minimax/", "openrouter/"),
|
||||
default_api_base="https://api.minimax.io/v1",
|
||||
),
|
||||
ProviderSpec(
|
||||
name="vllm",
|
||||
keywords=("vllm",),
|
||||
env_key="HOSTED_VLLM_API_KEY",
|
||||
display_name="vLLM/Local",
|
||||
litellm_prefix="hosted_vllm",
|
||||
is_local=True,
|
||||
),
|
||||
ProviderSpec(
|
||||
name="groq",
|
||||
keywords=("groq",),
|
||||
env_key="GROQ_API_KEY",
|
||||
display_name="Groq",
|
||||
litellm_prefix="groq",
|
||||
skip_prefixes=("groq/",),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def find_by_name(name: str) -> ProviderSpec | None:
|
||||
for spec in PROVIDERS:
|
||||
if spec.name == name:
|
||||
return spec
|
||||
return None
|
||||
|
||||
|
||||
def find_by_model(model: str) -> ProviderSpec | None:
|
||||
"""按模型名关键词匹配标准 provider。"""
|
||||
|
||||
model_lower = model.lower()
|
||||
model_normalized = model_lower.replace("-", "_")
|
||||
model_prefix = model_lower.split("/", 1)[0] if "/" in model_lower else ""
|
||||
normalized_prefix = model_prefix.replace("-", "_")
|
||||
standard_specs = [spec for spec in PROVIDERS if not spec.is_gateway and not spec.is_local]
|
||||
|
||||
# 显式前缀优先级最高。
|
||||
# 这里不能只看 standard provider:
|
||||
# - `openrouter/...` 应该直接命中 openrouter
|
||||
# - `hosted_vllm/...` 应该能回到 vllm 这个本地 provider
|
||||
# - `github_copilot/...codex` 也不应被误判成 openai_codex
|
||||
for spec in PROVIDERS:
|
||||
aliases = {spec.name}
|
||||
if spec.litellm_prefix:
|
||||
aliases.add(spec.litellm_prefix.replace("-", "_"))
|
||||
if model_prefix and normalized_prefix in aliases:
|
||||
return spec
|
||||
|
||||
for spec in standard_specs:
|
||||
if any(keyword in model_lower or keyword.replace("-", "_") in model_normalized for keyword in spec.keywords):
|
||||
return spec
|
||||
return None
|
||||
|
||||
|
||||
def find_gateway(
|
||||
provider_name: str | None = None,
|
||||
api_key: str | None = None,
|
||||
api_base: str | None = None,
|
||||
) -> ProviderSpec | None:
|
||||
"""按 config key / api_key / api_base 识别 gateway 或 local provider。"""
|
||||
|
||||
if provider_name:
|
||||
spec = find_by_name(provider_name)
|
||||
if spec and (spec.is_gateway or spec.is_local):
|
||||
return spec
|
||||
|
||||
for spec in PROVIDERS:
|
||||
if spec.detect_by_key_prefix and api_key and api_key.startswith(spec.detect_by_key_prefix):
|
||||
return spec
|
||||
if spec.detect_by_base_keyword and api_base and spec.detect_by_base_keyword in api_base:
|
||||
return spec
|
||||
return None
|
||||
408
app-instance/backend/beaver/engine/providers/runtime.py
Normal file
408
app-instance/backend/beaver/engine/providers/runtime.py
Normal file
@ -0,0 +1,408 @@
|
||||
"""Hermes 风格的 provider runtime resolution。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field, replace
|
||||
from typing import Any
|
||||
|
||||
from .registry import ProviderSpec, find_by_model, find_by_name, find_gateway
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ProviderRoutingConfig:
|
||||
"""OpenRouter provider routing 配置。"""
|
||||
|
||||
sort: str | None = None
|
||||
only: list[str] = field(default_factory=list)
|
||||
ignore: list[str] = field(default_factory=list)
|
||||
order: list[str] = field(default_factory=list)
|
||||
require_parameters: bool = False
|
||||
data_collection: str | None = None
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ProviderTarget:
|
||||
"""一次 provider 选路请求的标准化配置。
|
||||
|
||||
这层不是具体 runtime,而是“调用方想要什么”:
|
||||
- 用哪个 provider
|
||||
- 跑哪个 model
|
||||
- 是否指定自定义 base_url
|
||||
- 是否带额外 headers / routing
|
||||
|
||||
后面 `resolve_provider_runtime()` 会把它真正解析成可实例化的 runtime。
|
||||
"""
|
||||
|
||||
provider_name: str | None = None
|
||||
model: str | None = None
|
||||
api_key: str | None = None
|
||||
api_base: str | None = None
|
||||
extra_headers: dict[str, str] = field(default_factory=dict)
|
||||
request_timeout_seconds: float | None = None
|
||||
routing: ProviderRoutingConfig | None = None
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ProviderRuntime:
|
||||
"""运行时真正使用的 provider 解析结果。"""
|
||||
|
||||
spec: ProviderSpec
|
||||
model: str
|
||||
provider_name: str
|
||||
api_mode: str
|
||||
api_key: str | None = None
|
||||
api_base: str | None = None
|
||||
default_model: str | None = None
|
||||
request_timeout_seconds: float | None = None
|
||||
extra_headers: dict[str, str] = field(default_factory=dict)
|
||||
routing: ProviderRoutingConfig | None = None
|
||||
fallback_target: ProviderTarget | None = None
|
||||
auxiliary: bool = False
|
||||
role: str = "main"
|
||||
source: str = "runtime"
|
||||
|
||||
|
||||
def resolve_provider_runtime(
|
||||
*,
|
||||
model: str,
|
||||
provider_name: str | None = None,
|
||||
api_key: str | None = None,
|
||||
api_base: str | None = None,
|
||||
request_timeout_seconds: float | None = None,
|
||||
extra_headers: dict[str, str] | None = None,
|
||||
routing: ProviderRoutingConfig | None = None,
|
||||
fallback_target: ProviderTarget | dict[str, Any] | None = None,
|
||||
auxiliary: bool = False,
|
||||
role: str = "main",
|
||||
source: str = "runtime",
|
||||
) -> ProviderRuntime:
|
||||
"""把调用侧传入的配置解析成统一 runtime。"""
|
||||
|
||||
gateway = find_gateway(provider_name, api_key, api_base)
|
||||
if gateway is not None:
|
||||
spec = gateway
|
||||
elif provider_name:
|
||||
spec = find_by_name(provider_name)
|
||||
else:
|
||||
spec = find_by_model(model)
|
||||
|
||||
if spec is None:
|
||||
if api_base:
|
||||
spec = find_by_name("custom")
|
||||
else:
|
||||
raise ValueError(f"Unable to resolve provider for model={model!r} provider_name={provider_name!r}")
|
||||
|
||||
resolved_model = _resolve_model_name(spec, model, gateway_mode=(gateway is not None))
|
||||
resolved_api_base = api_base or spec.default_api_base or None
|
||||
|
||||
return ProviderRuntime(
|
||||
spec=spec,
|
||||
model=resolved_model,
|
||||
provider_name=spec.name,
|
||||
api_mode=spec.api_mode,
|
||||
api_key=api_key,
|
||||
api_base=resolved_api_base,
|
||||
default_model=resolved_model,
|
||||
request_timeout_seconds=request_timeout_seconds,
|
||||
extra_headers=extra_headers or {},
|
||||
routing=routing,
|
||||
fallback_target=normalize_provider_target(fallback_target),
|
||||
auxiliary=auxiliary,
|
||||
role=role,
|
||||
source=source,
|
||||
)
|
||||
|
||||
|
||||
def normalize_provider_target(target: ProviderTarget | dict[str, Any] | None) -> ProviderTarget | None:
|
||||
"""把 dict/对象形式的 provider 配置收敛成统一结构。
|
||||
|
||||
这里兼容几种常见写法,便于后续接 CLI / config / gateway:
|
||||
- `provider` 或 `provider_name`
|
||||
- `base_url` 或 `api_base`
|
||||
- `headers` 或 `extra_headers`
|
||||
- `timeout` 或 `request_timeout_seconds`
|
||||
"""
|
||||
|
||||
if target is None:
|
||||
return None
|
||||
if isinstance(target, ProviderTarget):
|
||||
return target
|
||||
|
||||
provider_name = target.get("provider_name")
|
||||
if provider_name is None:
|
||||
provider_name = target.get("provider")
|
||||
|
||||
api_base = target.get("api_base")
|
||||
if api_base is None:
|
||||
api_base = target.get("base_url")
|
||||
|
||||
extra_headers = target.get("extra_headers")
|
||||
if extra_headers is None:
|
||||
extra_headers = target.get("headers")
|
||||
|
||||
request_timeout_seconds = target.get("request_timeout_seconds")
|
||||
if request_timeout_seconds is None:
|
||||
request_timeout_seconds = target.get("timeout")
|
||||
|
||||
routing = target.get("routing")
|
||||
if isinstance(routing, dict):
|
||||
routing = ProviderRoutingConfig(**routing)
|
||||
|
||||
return ProviderTarget(
|
||||
provider_name=provider_name,
|
||||
model=target.get("model"),
|
||||
api_key=target.get("api_key"),
|
||||
api_base=api_base,
|
||||
extra_headers=dict(extra_headers or {}),
|
||||
request_timeout_seconds=request_timeout_seconds,
|
||||
routing=routing,
|
||||
)
|
||||
|
||||
|
||||
def resolve_fallback_runtime(
|
||||
primary_runtime: ProviderRuntime,
|
||||
fallback_target: ProviderTarget | dict[str, Any] | None,
|
||||
) -> ProviderRuntime | None:
|
||||
"""把 fallback 配置解析成独立 runtime。
|
||||
|
||||
Hermes 的 fallback 是“主 provider 失败后切换到另一个 provider:model”。
|
||||
这里先把 fallback 解析独立出来,具体何时激活交给上层 chain/factory。
|
||||
"""
|
||||
|
||||
target = normalize_provider_target(fallback_target)
|
||||
if target is None or not target.model:
|
||||
return None
|
||||
|
||||
inferred_provider = target.provider_name
|
||||
if inferred_provider in {None, "", "main"}:
|
||||
inferred_provider = primary_runtime.provider_name
|
||||
|
||||
api_key = target.api_key
|
||||
api_base = target.api_base
|
||||
extra_headers = dict(target.extra_headers)
|
||||
|
||||
# 只有在 fallback 没明确切换 provider/base 时,才继承主链的凭据与 headers。
|
||||
if inferred_provider == primary_runtime.provider_name and not api_base:
|
||||
api_key = api_key or primary_runtime.api_key
|
||||
api_base = api_base or primary_runtime.api_base
|
||||
if not extra_headers:
|
||||
extra_headers = dict(primary_runtime.extra_headers)
|
||||
|
||||
return resolve_provider_runtime(
|
||||
model=target.model,
|
||||
provider_name=inferred_provider,
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
request_timeout_seconds=target.request_timeout_seconds or primary_runtime.request_timeout_seconds,
|
||||
extra_headers=extra_headers,
|
||||
routing=target.routing,
|
||||
auxiliary=False,
|
||||
role="fallback",
|
||||
source="fallback_config",
|
||||
)
|
||||
|
||||
|
||||
def resolve_auxiliary_runtime(
|
||||
primary_runtime: ProviderRuntime,
|
||||
target: ProviderTarget | dict[str, Any] | None = None,
|
||||
*,
|
||||
task_name: str = "auxiliary",
|
||||
) -> ProviderRuntime:
|
||||
"""解析辅助任务专用 runtime。
|
||||
|
||||
支持三类输入:
|
||||
- `None` / `provider=main`:直接复用主链 provider
|
||||
- 显式 `provider + model`:走独立 provider
|
||||
- 仅给 `model`:按模型名自动匹配 provider
|
||||
"""
|
||||
|
||||
normalized = normalize_provider_target(target)
|
||||
if normalized is None:
|
||||
return _clone_runtime(
|
||||
primary_runtime,
|
||||
auxiliary=True,
|
||||
role=task_name,
|
||||
source="main_runtime",
|
||||
)
|
||||
|
||||
provider_name = normalized.provider_name
|
||||
if provider_name in {None, "", "main"} and not normalized.api_base and not normalized.model:
|
||||
return _clone_runtime(
|
||||
primary_runtime,
|
||||
auxiliary=True,
|
||||
role=task_name,
|
||||
source="main_runtime",
|
||||
routing=normalized.routing or primary_runtime.routing,
|
||||
extra_headers=normalized.extra_headers or primary_runtime.extra_headers,
|
||||
request_timeout_seconds=normalized.request_timeout_seconds or primary_runtime.request_timeout_seconds,
|
||||
)
|
||||
|
||||
if provider_name == "main":
|
||||
return resolve_provider_runtime(
|
||||
model=normalized.model or primary_runtime.model,
|
||||
provider_name=primary_runtime.provider_name,
|
||||
api_key=normalized.api_key or primary_runtime.api_key,
|
||||
api_base=normalized.api_base or primary_runtime.api_base,
|
||||
request_timeout_seconds=normalized.request_timeout_seconds or primary_runtime.request_timeout_seconds,
|
||||
extra_headers=normalized.extra_headers or primary_runtime.extra_headers,
|
||||
routing=normalized.routing or primary_runtime.routing,
|
||||
auxiliary=True,
|
||||
role=task_name,
|
||||
source="main_runtime",
|
||||
)
|
||||
|
||||
if provider_name in {"auto", None, ""} and not normalized.api_base and normalized.model is None:
|
||||
return _clone_runtime(
|
||||
primary_runtime,
|
||||
auxiliary=True,
|
||||
role=task_name,
|
||||
source="auto->main",
|
||||
)
|
||||
|
||||
resolved_model = normalized.model or primary_runtime.model
|
||||
resolved_provider = normalized.provider_name
|
||||
if resolved_provider in {"auto", "", None} and not normalized.api_base:
|
||||
# `auto` 的第一阶段实现保持保守:
|
||||
# - 有显式 model 时按 model 匹配 provider
|
||||
# - 匹配不到则回退主链 provider
|
||||
spec = find_by_model(resolved_model)
|
||||
resolved_provider = spec.name if spec is not None else primary_runtime.provider_name
|
||||
|
||||
api_key = normalized.api_key
|
||||
api_base = normalized.api_base
|
||||
extra_headers = dict(normalized.extra_headers)
|
||||
|
||||
if resolved_provider == primary_runtime.provider_name and not api_base:
|
||||
api_key = api_key or primary_runtime.api_key
|
||||
api_base = api_base or primary_runtime.api_base
|
||||
if not extra_headers:
|
||||
extra_headers = dict(primary_runtime.extra_headers)
|
||||
|
||||
return resolve_provider_runtime(
|
||||
model=resolved_model,
|
||||
provider_name=resolved_provider,
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
request_timeout_seconds=normalized.request_timeout_seconds or primary_runtime.request_timeout_seconds,
|
||||
extra_headers=extra_headers,
|
||||
routing=normalized.routing or primary_runtime.routing,
|
||||
auxiliary=True,
|
||||
role=task_name,
|
||||
source="auxiliary_config",
|
||||
)
|
||||
|
||||
|
||||
def resolve_embedding_runtime(
|
||||
primary_runtime: ProviderRuntime,
|
||||
target: ProviderTarget | dict[str, Any] | None = None,
|
||||
*,
|
||||
default_model: str = "text-embedding-v4",
|
||||
) -> ProviderRuntime | None:
|
||||
"""解析 embedding 专用 runtime。
|
||||
|
||||
目标是把“embedding 用哪个 model / api_base / api_key”也收进 provider 层,
|
||||
避免上层检索逻辑直接偷拿 main/aux provider 的凭据。
|
||||
"""
|
||||
|
||||
normalized = normalize_provider_target(target)
|
||||
|
||||
if normalized is None:
|
||||
# 没有显式 embedding 配置时,只允许在主链本身就是 OpenAI-compatible
|
||||
# 的情况下,继承它的 api_base/api_key。否则不做模糊猜测。
|
||||
if not _supports_openai_embeddings(primary_runtime):
|
||||
return None
|
||||
return resolve_provider_runtime(
|
||||
model=default_model,
|
||||
provider_name="openai",
|
||||
api_key=primary_runtime.api_key,
|
||||
api_base=primary_runtime.api_base,
|
||||
request_timeout_seconds=primary_runtime.request_timeout_seconds,
|
||||
extra_headers=dict(primary_runtime.extra_headers),
|
||||
routing=primary_runtime.routing,
|
||||
auxiliary=False,
|
||||
role="embedding",
|
||||
source="embedding_inherited",
|
||||
)
|
||||
|
||||
resolved_model = normalized.model or default_model
|
||||
resolved_provider = normalized.provider_name
|
||||
if resolved_provider in {None, "", "main", "auto"}:
|
||||
resolved_provider = "custom" if normalized.api_base else "openai"
|
||||
|
||||
api_key = normalized.api_key
|
||||
api_base = normalized.api_base
|
||||
extra_headers = dict(normalized.extra_headers)
|
||||
|
||||
if not api_base and _supports_openai_embeddings(primary_runtime):
|
||||
api_key = api_key or primary_runtime.api_key
|
||||
api_base = api_base or primary_runtime.api_base
|
||||
if not extra_headers:
|
||||
extra_headers = dict(primary_runtime.extra_headers)
|
||||
|
||||
runtime = resolve_provider_runtime(
|
||||
model=resolved_model,
|
||||
provider_name=resolved_provider,
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
request_timeout_seconds=normalized.request_timeout_seconds or primary_runtime.request_timeout_seconds,
|
||||
extra_headers=extra_headers,
|
||||
routing=normalized.routing,
|
||||
auxiliary=False,
|
||||
role="embedding",
|
||||
source="embedding_config",
|
||||
)
|
||||
if not _supports_openai_embeddings(runtime):
|
||||
raise ValueError("Embedding runtime currently requires an OpenAI-compatible provider")
|
||||
return runtime
|
||||
|
||||
|
||||
def _supports_openai_embeddings(runtime: ProviderRuntime) -> bool:
|
||||
"""当前 embedding retriever 只支持 OpenAI-compatible `/v1/embeddings`。"""
|
||||
|
||||
return runtime.api_mode == "chat_completions" and runtime.spec.provider_impl in {"litellm", "custom"}
|
||||
|
||||
|
||||
def _clone_runtime(
|
||||
runtime: ProviderRuntime,
|
||||
**changes: Any,
|
||||
) -> ProviderRuntime:
|
||||
"""基于现有 runtime 复制一个轻量变体。
|
||||
|
||||
用在 `provider=main` 这类场景,避免重复跑一次 registry 解析。
|
||||
"""
|
||||
|
||||
payload = {
|
||||
"extra_headers": dict(runtime.extra_headers),
|
||||
"routing": runtime.routing,
|
||||
"fallback_target": runtime.fallback_target,
|
||||
}
|
||||
payload.update(changes)
|
||||
return replace(runtime, **payload)
|
||||
|
||||
|
||||
def _resolve_model_name(spec: ProviderSpec, model: str, *, gateway_mode: bool) -> str:
|
||||
"""根据 registry 规则应用必要前缀。"""
|
||||
|
||||
resolved = model
|
||||
if gateway_mode:
|
||||
prefix = spec.litellm_prefix
|
||||
if spec.strip_model_prefix:
|
||||
resolved = resolved.split("/")[-1]
|
||||
if prefix and not resolved.startswith(f"{prefix}/"):
|
||||
resolved = f"{prefix}/{resolved}"
|
||||
return resolved
|
||||
|
||||
if spec.litellm_prefix:
|
||||
resolved = _canonicalize_explicit_prefix(resolved, spec.name, spec.litellm_prefix)
|
||||
if not any(resolved.startswith(item) for item in spec.skip_prefixes):
|
||||
resolved = f"{spec.litellm_prefix}/{resolved}"
|
||||
return resolved
|
||||
|
||||
|
||||
def _canonicalize_explicit_prefix(model: str, spec_name: str, canonical_prefix: str) -> str:
|
||||
if "/" not in model:
|
||||
return model
|
||||
prefix, remainder = model.split("/", 1)
|
||||
if prefix.lower().replace("-", "_") != spec_name:
|
||||
return model
|
||||
return f"{canonical_prefix}/{remainder}"
|
||||
2
app-instance/backend/beaver/engine/runtime/__init__.py
Normal file
2
app-instance/backend/beaver/engine/runtime/__init__.py
Normal file
@ -0,0 +1,2 @@
|
||||
"""Runtime helper objects and execution context."""
|
||||
|
||||
15
app-instance/backend/beaver/engine/session/__init__.py
Normal file
15
app-instance/backend/beaver/engine/session/__init__.py
Normal file
@ -0,0 +1,15 @@
|
||||
"""Session state and persistence."""
|
||||
|
||||
from .manager import SessionManager
|
||||
from .models import MessageRecord, SessionRecord, SessionUsage
|
||||
from .search import SessionSearchService
|
||||
from .store import SessionStore
|
||||
|
||||
__all__ = [
|
||||
"MessageRecord",
|
||||
"SessionManager",
|
||||
"SessionRecord",
|
||||
"SessionSearchService",
|
||||
"SessionStore",
|
||||
"SessionUsage",
|
||||
]
|
||||
143
app-instance/backend/beaver/engine/session/manager.py
Normal file
143
app-instance/backend/beaver/engine/session/manager.py
Normal file
@ -0,0 +1,143 @@
|
||||
"""Beaver session 子系统对 runtime 暴露的统一门面。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from .models import MessageRecord
|
||||
from .search import SessionSearchService
|
||||
from .store import SessionStore
|
||||
|
||||
|
||||
class SessionManager:
|
||||
"""供 AgentLoop / services / MCP tools 使用的统一 session facade。"""
|
||||
|
||||
def __init__(self, workspace: str | Path, db_path: str | Path | None = None) -> None:
|
||||
self.workspace = Path(workspace)
|
||||
self.sessions_dir = self.workspace / "sessions"
|
||||
self.sessions_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.db_path = Path(db_path) if db_path is not None else self.sessions_dir / "state.db"
|
||||
self.store = SessionStore(self.db_path)
|
||||
self.search = SessionSearchService(self.store)
|
||||
|
||||
def close(self) -> None:
|
||||
self.store.close()
|
||||
|
||||
def ensure_session(
|
||||
self,
|
||||
session_id: str,
|
||||
*,
|
||||
source: str = "unknown",
|
||||
model: str | None = None,
|
||||
title: str | None = None,
|
||||
user_id: str | None = None,
|
||||
parent_session_id: str | None = None,
|
||||
) -> str:
|
||||
return self.store.ensure_session(
|
||||
session_id,
|
||||
source=source,
|
||||
model=model,
|
||||
title=title,
|
||||
user_id=user_id,
|
||||
parent_session_id=parent_session_id,
|
||||
)
|
||||
|
||||
def get_session(self, session_id: str) -> dict[str, Any] | None:
|
||||
record = self.store.get_session_record(session_id)
|
||||
return record.to_dict() if record is not None else None
|
||||
|
||||
def get_or_create(
|
||||
self,
|
||||
session_id: str,
|
||||
*,
|
||||
source: str = "unknown",
|
||||
model: str | None = None,
|
||||
title: str | None = None,
|
||||
user_id: str | None = None,
|
||||
parent_session_id: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
self.ensure_session(
|
||||
session_id,
|
||||
source=source,
|
||||
model=model,
|
||||
title=title,
|
||||
user_id=user_id,
|
||||
parent_session_id=parent_session_id,
|
||||
)
|
||||
session = self.get_session(session_id)
|
||||
if session is None:
|
||||
raise RuntimeError(f"Failed to create session {session_id!r}")
|
||||
return session
|
||||
|
||||
def append_message(self, session_id: str, **kwargs: Any) -> int:
|
||||
return self.store.append_message(session_id, **kwargs)
|
||||
|
||||
def get_event_records(self, session_id: str) -> list[MessageRecord]:
|
||||
"""返回当前 session 的完整事件流。
|
||||
|
||||
这里和 `get_messages_as_conversation()` 的区别很关键:
|
||||
- `get_event_records()` 面向 runtime / replay / audit,保留隐藏系统事件
|
||||
- `get_messages_as_conversation()` 面向 prompt builder,只暴露可进上下文的事件
|
||||
|
||||
第 6 阶段开始后,session 已不再只是“聊天消息存储”,而是在逐步收敛成
|
||||
“外部事件流 + 上层投影视图”。
|
||||
"""
|
||||
|
||||
return self.store.get_event_records(session_id)
|
||||
|
||||
def get_run_event_records(self, session_id: str, run_id: str) -> list[MessageRecord]:
|
||||
"""返回某一次 direct run / future bus run 对应的事件片段。"""
|
||||
|
||||
return self.store.get_run_event_records(session_id, run_id)
|
||||
|
||||
def list_run_ids(self, session_id: str) -> list[str]:
|
||||
"""按出现顺序列出当前 session 的所有 run_id。"""
|
||||
|
||||
return self.store.list_run_ids(session_id)
|
||||
|
||||
def get_messages_as_conversation(self, session_id: str) -> list[dict[str, Any]]:
|
||||
return self.store.get_messages_as_conversation(session_id)
|
||||
|
||||
def get_visible_history(self, session_id: str, max_messages: int = 500) -> list[dict[str, Any]]:
|
||||
"""返回适合注入 prompt 的可见历史切片。
|
||||
|
||||
这里故意不直接暴露完整事件流,而是继续提供“模型可消费历史”这个投影视图:
|
||||
1. 只包含 `context_visible=True` 的事件
|
||||
2. 继续保留旧式窗口裁剪逻辑,避免当前主链行为突然变化
|
||||
3. 让 `ContextBuilder` 明确消费的是“上游裁剪后的可见片段”
|
||||
"""
|
||||
|
||||
history = self.get_messages_as_conversation(session_id)
|
||||
sliced = history[-max_messages:]
|
||||
for index, message in enumerate(sliced):
|
||||
if message.get("role") == "user":
|
||||
sliced = sliced[index:]
|
||||
break
|
||||
return sliced
|
||||
|
||||
def get_history(self, session_id: str, max_messages: int = 500) -> list[dict[str, Any]]:
|
||||
"""兼容旧名称,实际返回可见历史切片。"""
|
||||
|
||||
return self.get_visible_history(session_id, max_messages=max_messages)
|
||||
|
||||
def update_system_prompt(self, session_id: str, system_prompt: str) -> None:
|
||||
self.store.update_system_prompt(session_id, system_prompt)
|
||||
|
||||
def update_usage(self, session_id: str, **kwargs: Any) -> None:
|
||||
self.store.update_usage(session_id, **kwargs)
|
||||
|
||||
def end_session(self, session_id: str, end_reason: str) -> None:
|
||||
self.store.end_session(session_id, end_reason)
|
||||
|
||||
def reopen_session(self, session_id: str) -> None:
|
||||
self.store.reopen_session(session_id)
|
||||
|
||||
def list_sessions_rich(self, **kwargs: Any) -> list[dict[str, Any]]:
|
||||
return self.search.list_sessions_rich(**kwargs)
|
||||
|
||||
def search_messages(self, **kwargs: Any) -> list[dict[str, Any]]:
|
||||
return self.search.search_messages(**kwargs)
|
||||
|
||||
def resolve_session_id(self, session_id_or_prefix: str) -> str | None:
|
||||
return self.search.resolve_session_id(session_id_or_prefix)
|
||||
211
app-instance/backend/beaver/engine/session/models.py
Normal file
211
app-instance/backend/beaver/engine/session/models.py
Normal file
@ -0,0 +1,211 @@
|
||||
"""Beaver session 子系统的数据模型。
|
||||
|
||||
这层只定义数据结构,不放数据库读写逻辑。目的是把:
|
||||
1. SQLite 行结构
|
||||
2. 运行时会话对象
|
||||
3. 对外暴露的 conversation message
|
||||
|
||||
三件事分开,避免后续所有地方都直接和裸字典耦合。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class SessionUsage:
|
||||
"""会话维度的 usage/cost 统计。"""
|
||||
|
||||
input_tokens: int = 0
|
||||
output_tokens: int = 0
|
||||
cache_read_tokens: int = 0
|
||||
cache_write_tokens: int = 0
|
||||
reasoning_tokens: int = 0
|
||||
estimated_cost_usd: float = 0.0
|
||||
actual_cost_usd: float | None = None
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"input_tokens": self.input_tokens,
|
||||
"output_tokens": self.output_tokens,
|
||||
"cache_read_tokens": self.cache_read_tokens,
|
||||
"cache_write_tokens": self.cache_write_tokens,
|
||||
"reasoning_tokens": self.reasoning_tokens,
|
||||
"estimated_cost_usd": self.estimated_cost_usd,
|
||||
"actual_cost_usd": self.actual_cost_usd,
|
||||
}
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class MessageRecord:
|
||||
"""单条会话事件的结构化表示。
|
||||
|
||||
当前仍然沿用 `messages` 这张表名,但语义已经开始向 event stream 收拢:
|
||||
1. 普通 user/assistant/tool 消息本身就是事件
|
||||
2. 运行时的 system snapshot / run lifecycle 也可写成隐藏事件
|
||||
3. 是否进入模型上下文由 `context_visible` 决定,而不是简单看 role
|
||||
"""
|
||||
|
||||
role: str
|
||||
content: str | None = None
|
||||
timestamp: float | None = None
|
||||
message_id: int | None = None
|
||||
run_id: str | None = None
|
||||
event_type: str | None = None
|
||||
event_payload: dict[str, Any] | None = None
|
||||
context_visible: bool = True
|
||||
tool_name: str | None = None
|
||||
tool_calls: list[dict[str, Any]] | None = None
|
||||
tool_call_id: str | None = None
|
||||
finish_reason: str | None = None
|
||||
reasoning: str | None = None
|
||||
reasoning_details: Any | None = None
|
||||
codex_reasoning_items: Any | None = None
|
||||
|
||||
def to_conversation_message(self) -> dict[str, Any]:
|
||||
"""转成 provider / context builder 可直接消费的消息格式。"""
|
||||
|
||||
if not self.context_visible:
|
||||
raise ValueError("Hidden session events cannot be converted into conversation messages")
|
||||
|
||||
payload: dict[str, Any] = {
|
||||
"role": self.role,
|
||||
"content": self.content,
|
||||
}
|
||||
if self.tool_name:
|
||||
payload["tool_name"] = self.tool_name
|
||||
if self.tool_calls:
|
||||
payload["tool_calls"] = self.tool_calls
|
||||
if self.tool_call_id:
|
||||
payload["tool_call_id"] = self.tool_call_id
|
||||
if self.finish_reason:
|
||||
payload["finish_reason"] = self.finish_reason
|
||||
if self.reasoning:
|
||||
payload["reasoning"] = self.reasoning
|
||||
if self.reasoning_details is not None:
|
||||
payload["reasoning_details"] = self.reasoning_details
|
||||
if self.codex_reasoning_items is not None:
|
||||
payload["codex_reasoning_items"] = self.codex_reasoning_items
|
||||
return payload
|
||||
|
||||
@classmethod
|
||||
def from_row(cls, row: dict[str, Any]) -> "MessageRecord":
|
||||
"""从 SQLite row/dict 恢复消息模型。"""
|
||||
|
||||
tool_calls = row.get("tool_calls")
|
||||
if isinstance(tool_calls, str):
|
||||
try:
|
||||
tool_calls = json.loads(tool_calls)
|
||||
except json.JSONDecodeError:
|
||||
tool_calls = []
|
||||
|
||||
reasoning_details = row.get("reasoning_details")
|
||||
if isinstance(reasoning_details, str):
|
||||
try:
|
||||
reasoning_details = json.loads(reasoning_details)
|
||||
except json.JSONDecodeError:
|
||||
reasoning_details = None
|
||||
|
||||
codex_reasoning_items = row.get("codex_reasoning_items")
|
||||
if isinstance(codex_reasoning_items, str):
|
||||
try:
|
||||
codex_reasoning_items = json.loads(codex_reasoning_items)
|
||||
except json.JSONDecodeError:
|
||||
codex_reasoning_items = None
|
||||
|
||||
event_payload = row.get("event_payload")
|
||||
if isinstance(event_payload, str):
|
||||
try:
|
||||
event_payload = json.loads(event_payload)
|
||||
except json.JSONDecodeError:
|
||||
event_payload = None
|
||||
|
||||
return cls(
|
||||
message_id=row.get("id"),
|
||||
run_id=row.get("run_id"),
|
||||
role=row["role"],
|
||||
content=row.get("content"),
|
||||
event_type=row.get("event_type") or row.get("role"),
|
||||
event_payload=event_payload,
|
||||
context_visible=bool(row.get("context_visible", 1)),
|
||||
tool_name=row.get("tool_name"),
|
||||
tool_calls=tool_calls,
|
||||
tool_call_id=row.get("tool_call_id"),
|
||||
timestamp=row.get("timestamp"),
|
||||
finish_reason=row.get("finish_reason"),
|
||||
reasoning=row.get("reasoning"),
|
||||
reasoning_details=reasoning_details,
|
||||
codex_reasoning_items=codex_reasoning_items,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class SessionRecord:
|
||||
"""单个 session 的结构化表示。"""
|
||||
|
||||
session_id: str
|
||||
source: str
|
||||
started_at: float
|
||||
last_active: float
|
||||
user_id: str | None = None
|
||||
title: str | None = None
|
||||
model: str | None = None
|
||||
system_prompt: str | None = None
|
||||
parent_session_id: str | None = None
|
||||
ended_at: float | None = None
|
||||
end_reason: str | None = None
|
||||
message_count: int = 0
|
||||
tool_call_count: int = 0
|
||||
preview: str | None = None
|
||||
usage: SessionUsage = field(default_factory=SessionUsage)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
payload = {
|
||||
"id": self.session_id,
|
||||
"source": self.source,
|
||||
"user_id": self.user_id,
|
||||
"title": self.title,
|
||||
"model": self.model,
|
||||
"system_prompt": self.system_prompt,
|
||||
"parent_session_id": self.parent_session_id,
|
||||
"started_at": self.started_at,
|
||||
"last_active": self.last_active,
|
||||
"ended_at": self.ended_at,
|
||||
"end_reason": self.end_reason,
|
||||
"message_count": self.message_count,
|
||||
"tool_call_count": self.tool_call_count,
|
||||
"preview": self.preview,
|
||||
}
|
||||
payload.update(self.usage.to_dict())
|
||||
return payload
|
||||
|
||||
@classmethod
|
||||
def from_row(cls, row: dict[str, Any]) -> "SessionRecord":
|
||||
return cls(
|
||||
session_id=row["id"],
|
||||
source=row["source"],
|
||||
user_id=row.get("user_id"),
|
||||
title=row.get("title"),
|
||||
model=row.get("model"),
|
||||
system_prompt=row.get("system_prompt"),
|
||||
parent_session_id=row.get("parent_session_id"),
|
||||
started_at=row["started_at"],
|
||||
last_active=row["last_active"],
|
||||
ended_at=row.get("ended_at"),
|
||||
end_reason=row.get("end_reason"),
|
||||
message_count=row.get("message_count", 0),
|
||||
tool_call_count=row.get("tool_call_count", 0),
|
||||
preview=row.get("preview"),
|
||||
usage=SessionUsage(
|
||||
input_tokens=row.get("input_tokens", 0),
|
||||
output_tokens=row.get("output_tokens", 0),
|
||||
cache_read_tokens=row.get("cache_read_tokens", 0),
|
||||
cache_write_tokens=row.get("cache_write_tokens", 0),
|
||||
reasoning_tokens=row.get("reasoning_tokens", 0),
|
||||
estimated_cost_usd=row.get("estimated_cost_usd", 0.0) or 0.0,
|
||||
actual_cost_usd=row.get("actual_cost_usd"),
|
||||
),
|
||||
)
|
||||
151
app-instance/backend/beaver/engine/session/search.py
Normal file
151
app-instance/backend/beaver/engine/session/search.py
Normal file
@ -0,0 +1,151 @@
|
||||
"""Beaver session 子系统的检索能力。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import sqlite3
|
||||
from typing import Any
|
||||
|
||||
from .store import SessionStore
|
||||
|
||||
|
||||
class SessionSearchService:
|
||||
"""围绕 `SessionStore` 提供 browsing / FTS / lineage 辅助能力。"""
|
||||
|
||||
def __init__(self, store: SessionStore) -> None:
|
||||
self.store = store
|
||||
|
||||
@staticmethod
|
||||
def _sanitize_fts5_query(query: str) -> str:
|
||||
quoted_parts: list[str] = []
|
||||
|
||||
def preserve(match: re.Match[str]) -> str:
|
||||
quoted_parts.append(match.group(0))
|
||||
return f"\x00Q{len(quoted_parts) - 1}\x00"
|
||||
|
||||
sanitized = re.sub(r'"[^"]*"', preserve, query)
|
||||
sanitized = re.sub(r'[+{}()\"^]', " ", sanitized)
|
||||
sanitized = re.sub(r"\*+", "*", sanitized)
|
||||
sanitized = re.sub(r"(^|\s)\*", r"\1", sanitized)
|
||||
sanitized = re.sub(r"(?i)^(AND|OR|NOT)\b\s*", "", sanitized.strip())
|
||||
sanitized = re.sub(r"(?i)\s+(AND|OR|NOT)\s*$", "", sanitized.strip())
|
||||
sanitized = re.sub(r"\b(\w+(?:[.-]\w+)+)\b", r'"\1"', sanitized)
|
||||
|
||||
for index, quoted in enumerate(quoted_parts):
|
||||
sanitized = sanitized.replace(f"\x00Q{index}\x00", quoted)
|
||||
return sanitized.strip()
|
||||
|
||||
def resolve_session_id(self, session_id_or_prefix: str) -> str | None:
|
||||
"""用完整 ID 或唯一前缀解析出目标 session_id。"""
|
||||
|
||||
exact = self.store.get_session_record(session_id_or_prefix)
|
||||
if exact is not None:
|
||||
return exact.session_id
|
||||
|
||||
escaped = (
|
||||
session_id_or_prefix
|
||||
.replace("\\", "\\\\")
|
||||
.replace("%", "\\%")
|
||||
.replace("_", "\\_")
|
||||
)
|
||||
rows = self.store._fetchall(
|
||||
"""
|
||||
SELECT id
|
||||
FROM sessions
|
||||
WHERE id LIKE ? ESCAPE '\\'
|
||||
ORDER BY started_at DESC
|
||||
LIMIT 2
|
||||
""",
|
||||
(f"{escaped}%",),
|
||||
)
|
||||
if len(rows) == 1:
|
||||
return rows[0]["id"]
|
||||
return None
|
||||
|
||||
def list_sessions_rich(
|
||||
self,
|
||||
*,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
include_children: bool = False,
|
||||
source: str | None = None,
|
||||
exclude_sources: list[str] | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""列出最近活跃的 session 及其摘要元数据。"""
|
||||
|
||||
clauses: list[str] = []
|
||||
params: list[Any] = []
|
||||
|
||||
if not include_children:
|
||||
clauses.append("parent_session_id IS NULL")
|
||||
if source:
|
||||
clauses.append("source = ?")
|
||||
params.append(source)
|
||||
if exclude_sources:
|
||||
placeholders = ",".join("?" for _ in exclude_sources)
|
||||
clauses.append(f"source NOT IN ({placeholders})")
|
||||
params.extend(exclude_sources)
|
||||
|
||||
where = f"WHERE {' AND '.join(clauses)}" if clauses else ""
|
||||
params.extend([limit, offset])
|
||||
rows = self.store._fetchall(
|
||||
f"""
|
||||
SELECT *
|
||||
FROM sessions
|
||||
{where}
|
||||
ORDER BY last_active DESC
|
||||
LIMIT ? OFFSET ?
|
||||
""",
|
||||
tuple(params),
|
||||
)
|
||||
return rows
|
||||
|
||||
def search_messages(
|
||||
self,
|
||||
*,
|
||||
query: str,
|
||||
role_filter: list[str] | None = None,
|
||||
exclude_sources: list[str] | None = None,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""使用 FTS5 搜索 session transcript。"""
|
||||
|
||||
query = self._sanitize_fts5_query(query)
|
||||
if not query:
|
||||
return []
|
||||
|
||||
clauses = ["messages_fts MATCH ?", "m.context_visible = 1"]
|
||||
params: list[Any] = [query]
|
||||
|
||||
if exclude_sources:
|
||||
placeholders = ",".join("?" for _ in exclude_sources)
|
||||
clauses.append(f"s.source NOT IN ({placeholders})")
|
||||
params.extend(exclude_sources)
|
||||
if role_filter:
|
||||
placeholders = ",".join("?" for _ in role_filter)
|
||||
clauses.append(f"m.role IN ({placeholders})")
|
||||
params.extend(role_filter)
|
||||
|
||||
params.extend([limit, offset])
|
||||
sql = f"""
|
||||
SELECT
|
||||
m.id,
|
||||
m.session_id,
|
||||
m.role,
|
||||
s.source,
|
||||
s.model,
|
||||
s.started_at AS session_started,
|
||||
snippet(messages_fts, 0, '>>>', '<<<', '...', 40) AS snippet
|
||||
FROM messages_fts
|
||||
JOIN messages m ON m.id = messages_fts.rowid
|
||||
JOIN sessions s ON s.id = m.session_id
|
||||
WHERE {' AND '.join(clauses)}
|
||||
ORDER BY rank
|
||||
LIMIT ? OFFSET ?
|
||||
"""
|
||||
|
||||
try:
|
||||
return self.store._fetchall(sql, tuple(params))
|
||||
except sqlite3.Error as exc:
|
||||
raise RuntimeError(f"Session transcript search failed for query={query!r}") from exc
|
||||
467
app-instance/backend/beaver/engine/session/store.py
Normal file
467
app-instance/backend/beaver/engine/session/store.py
Normal file
@ -0,0 +1,467 @@
|
||||
"""Beaver session 子系统的 SQLite 存储实现。
|
||||
|
||||
设计来源主要参考 Hermes-agent:
|
||||
1. SQLite 作为统一 session/transcript backend
|
||||
2. WAL 模式支持多读单写
|
||||
3. FTS5 支持跨 session 文本检索
|
||||
4. `parent_session_id` 支持 lineage
|
||||
|
||||
这层只负责“存”和“取”,复杂检索逻辑由 `search.py` 承担。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sqlite3
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, TypeVar
|
||||
|
||||
from .models import MessageRecord, SessionRecord
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
SCHEMA_SQL = """
|
||||
CREATE TABLE IF NOT EXISTS sessions (
|
||||
id TEXT PRIMARY KEY,
|
||||
source TEXT NOT NULL,
|
||||
user_id TEXT,
|
||||
title TEXT,
|
||||
model TEXT,
|
||||
system_prompt TEXT,
|
||||
parent_session_id TEXT,
|
||||
started_at REAL NOT NULL,
|
||||
last_active REAL NOT NULL,
|
||||
ended_at REAL,
|
||||
end_reason TEXT,
|
||||
message_count INTEGER DEFAULT 0,
|
||||
tool_call_count INTEGER DEFAULT 0,
|
||||
input_tokens INTEGER DEFAULT 0,
|
||||
output_tokens INTEGER DEFAULT 0,
|
||||
cache_read_tokens INTEGER DEFAULT 0,
|
||||
cache_write_tokens INTEGER DEFAULT 0,
|
||||
reasoning_tokens INTEGER DEFAULT 0,
|
||||
estimated_cost_usd REAL DEFAULT 0,
|
||||
actual_cost_usd REAL,
|
||||
preview TEXT,
|
||||
FOREIGN KEY (parent_session_id) REFERENCES sessions(id)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS messages (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
session_id TEXT NOT NULL REFERENCES sessions(id),
|
||||
run_id TEXT,
|
||||
role TEXT NOT NULL,
|
||||
event_type TEXT,
|
||||
event_payload TEXT,
|
||||
context_visible INTEGER NOT NULL DEFAULT 1,
|
||||
content TEXT,
|
||||
tool_name TEXT,
|
||||
tool_calls TEXT,
|
||||
tool_call_id TEXT,
|
||||
timestamp REAL NOT NULL,
|
||||
finish_reason TEXT,
|
||||
reasoning TEXT,
|
||||
reasoning_details TEXT,
|
||||
codex_reasoning_items TEXT
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_sessions_started ON sessions(started_at DESC);
|
||||
CREATE INDEX IF NOT EXISTS idx_sessions_last_active ON sessions(last_active DESC);
|
||||
CREATE INDEX IF NOT EXISTS idx_sessions_parent ON sessions(parent_session_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_messages_session ON messages(session_id, timestamp, id);
|
||||
CREATE INDEX IF NOT EXISTS idx_messages_run ON messages(session_id, run_id, timestamp, id);
|
||||
"""
|
||||
|
||||
FTS_TABLE_SQL = """
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS messages_fts USING fts5(
|
||||
content,
|
||||
content=messages,
|
||||
content_rowid=id
|
||||
);
|
||||
"""
|
||||
|
||||
FTS_TRIGGER_SQL = """
|
||||
DROP TRIGGER IF EXISTS messages_fts_insert;
|
||||
DROP TRIGGER IF EXISTS messages_fts_delete;
|
||||
DROP TRIGGER IF EXISTS messages_fts_update;
|
||||
|
||||
CREATE TRIGGER IF NOT EXISTS messages_fts_insert AFTER INSERT ON messages BEGIN
|
||||
INSERT INTO messages_fts(rowid, content)
|
||||
SELECT new.id, new.content
|
||||
WHERE new.context_visible = 1 AND new.content IS NOT NULL;
|
||||
END;
|
||||
|
||||
CREATE TRIGGER IF NOT EXISTS messages_fts_delete AFTER DELETE ON messages BEGIN
|
||||
INSERT INTO messages_fts(messages_fts, rowid, content)
|
||||
SELECT 'delete', old.id, old.content
|
||||
WHERE old.context_visible = 1 AND old.content IS NOT NULL;
|
||||
END;
|
||||
|
||||
CREATE TRIGGER IF NOT EXISTS messages_fts_update AFTER UPDATE ON messages BEGIN
|
||||
INSERT INTO messages_fts(messages_fts, rowid, content)
|
||||
SELECT 'delete', old.id, old.content
|
||||
WHERE old.context_visible = 1 AND old.content IS NOT NULL;
|
||||
INSERT INTO messages_fts(rowid, content)
|
||||
SELECT new.id, new.content
|
||||
WHERE new.context_visible = 1 AND new.content IS NOT NULL;
|
||||
END;
|
||||
"""
|
||||
|
||||
|
||||
class SessionStore:
|
||||
"""SQLite-backed session store."""
|
||||
|
||||
def __init__(self, db_path: str | Path) -> None:
|
||||
self.db_path = Path(db_path)
|
||||
self.db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self._lock = threading.Lock()
|
||||
self._conn = sqlite3.connect(str(self.db_path), check_same_thread=False, isolation_level=None)
|
||||
self._conn.row_factory = sqlite3.Row
|
||||
self._conn.execute("PRAGMA journal_mode=WAL")
|
||||
self._conn.execute("PRAGMA foreign_keys=ON")
|
||||
self._init_schema()
|
||||
|
||||
def _init_schema(self) -> None:
|
||||
with self._lock:
|
||||
self._conn.executescript(SCHEMA_SQL)
|
||||
try:
|
||||
self._conn.execute("SELECT * FROM messages_fts LIMIT 0")
|
||||
except sqlite3.OperationalError:
|
||||
self._conn.executescript(FTS_TABLE_SQL)
|
||||
self._conn.executescript(FTS_TRIGGER_SQL)
|
||||
# 旧版本可能把 hidden 事件也写进了 FTS;初始化时顺手清掉这些噪声项。
|
||||
self._conn.execute(
|
||||
"""
|
||||
INSERT INTO messages_fts(messages_fts, rowid, content)
|
||||
SELECT 'delete', id, content
|
||||
FROM messages
|
||||
WHERE context_visible = 0 AND content IS NOT NULL
|
||||
"""
|
||||
)
|
||||
self._conn.commit()
|
||||
|
||||
def close(self) -> None:
|
||||
with self._lock:
|
||||
self._conn.close()
|
||||
|
||||
def _execute_write(self, fn: Callable[[sqlite3.Connection], T]) -> T:
|
||||
with self._lock:
|
||||
self._conn.execute("BEGIN IMMEDIATE")
|
||||
try:
|
||||
result = fn(self._conn)
|
||||
self._conn.commit()
|
||||
return result
|
||||
except BaseException:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
|
||||
def _fetchone(self, sql: str, params: tuple[Any, ...] = ()) -> dict[str, Any] | None:
|
||||
with self._lock:
|
||||
row = self._conn.execute(sql, params).fetchone()
|
||||
return dict(row) if row else None
|
||||
|
||||
def _fetchall(self, sql: str, params: tuple[Any, ...] = ()) -> list[dict[str, Any]]:
|
||||
with self._lock:
|
||||
rows = self._conn.execute(sql, params).fetchall()
|
||||
return [dict(row) for row in rows]
|
||||
|
||||
def ensure_session(
|
||||
self,
|
||||
session_id: str,
|
||||
*,
|
||||
source: str = "unknown",
|
||||
model: str | None = None,
|
||||
title: str | None = None,
|
||||
user_id: str | None = None,
|
||||
parent_session_id: str | None = None,
|
||||
) -> str:
|
||||
"""确保 session 行存在;若不存在则创建,若存在则尽量补全缺失元数据。"""
|
||||
|
||||
now = time.time()
|
||||
|
||||
def _do(conn: sqlite3.Connection) -> str:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO sessions (
|
||||
id, source, user_id, title, model, parent_session_id, started_at, last_active
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(id) DO UPDATE SET
|
||||
source = CASE
|
||||
WHEN sessions.source = 'unknown' AND excluded.source != 'unknown' THEN excluded.source
|
||||
ELSE sessions.source
|
||||
END,
|
||||
user_id = COALESCE(sessions.user_id, excluded.user_id),
|
||||
title = COALESCE(sessions.title, excluded.title),
|
||||
model = COALESCE(sessions.model, excluded.model),
|
||||
parent_session_id = COALESCE(sessions.parent_session_id, excluded.parent_session_id)
|
||||
""",
|
||||
(session_id, source, user_id, title, model, parent_session_id, now, now),
|
||||
)
|
||||
return session_id
|
||||
|
||||
return self._execute_write(_do)
|
||||
|
||||
def get_session_record(self, session_id: str) -> SessionRecord | None:
|
||||
row = self._fetchone("SELECT * FROM sessions WHERE id = ?", (session_id,))
|
||||
return SessionRecord.from_row(row) if row else None
|
||||
|
||||
def update_system_prompt(self, session_id: str, system_prompt: str) -> None:
|
||||
"""保存本 session 组装后的完整 system prompt snapshot。"""
|
||||
|
||||
def _do(conn: sqlite3.Connection) -> None:
|
||||
conn.execute(
|
||||
"""
|
||||
UPDATE sessions
|
||||
SET system_prompt = ?, last_active = ?
|
||||
WHERE id = ?
|
||||
""",
|
||||
(system_prompt, time.time(), session_id),
|
||||
)
|
||||
|
||||
self._execute_write(_do)
|
||||
|
||||
def update_usage(
|
||||
self,
|
||||
session_id: str,
|
||||
*,
|
||||
input_tokens: int = 0,
|
||||
output_tokens: int = 0,
|
||||
cache_read_tokens: int = 0,
|
||||
cache_write_tokens: int = 0,
|
||||
reasoning_tokens: int = 0,
|
||||
estimated_cost_usd: float = 0.0,
|
||||
actual_cost_usd: float | None = None,
|
||||
absolute: bool = False,
|
||||
) -> None:
|
||||
"""更新会话 usage。默认按增量累加。"""
|
||||
|
||||
if absolute:
|
||||
sql = """
|
||||
UPDATE sessions
|
||||
SET input_tokens = ?,
|
||||
output_tokens = ?,
|
||||
cache_read_tokens = ?,
|
||||
cache_write_tokens = ?,
|
||||
reasoning_tokens = ?,
|
||||
estimated_cost_usd = ?,
|
||||
actual_cost_usd = ?,
|
||||
last_active = ?
|
||||
WHERE id = ?
|
||||
"""
|
||||
params = (
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
cache_read_tokens,
|
||||
cache_write_tokens,
|
||||
reasoning_tokens,
|
||||
estimated_cost_usd,
|
||||
actual_cost_usd,
|
||||
time.time(),
|
||||
session_id,
|
||||
)
|
||||
else:
|
||||
sql = """
|
||||
UPDATE sessions
|
||||
SET input_tokens = input_tokens + ?,
|
||||
output_tokens = output_tokens + ?,
|
||||
cache_read_tokens = cache_read_tokens + ?,
|
||||
cache_write_tokens = cache_write_tokens + ?,
|
||||
reasoning_tokens = reasoning_tokens + ?,
|
||||
estimated_cost_usd = estimated_cost_usd + ?,
|
||||
actual_cost_usd = CASE
|
||||
WHEN ? IS NULL THEN actual_cost_usd
|
||||
ELSE COALESCE(actual_cost_usd, 0) + ?
|
||||
END,
|
||||
last_active = ?
|
||||
WHERE id = ?
|
||||
"""
|
||||
params = (
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
cache_read_tokens,
|
||||
cache_write_tokens,
|
||||
reasoning_tokens,
|
||||
estimated_cost_usd,
|
||||
actual_cost_usd,
|
||||
actual_cost_usd,
|
||||
time.time(),
|
||||
session_id,
|
||||
)
|
||||
|
||||
def _do(conn: sqlite3.Connection) -> None:
|
||||
conn.execute(sql, params)
|
||||
|
||||
self._execute_write(_do)
|
||||
|
||||
def append_message(
|
||||
self,
|
||||
session_id: str,
|
||||
*,
|
||||
run_id: str | None = None,
|
||||
role: str,
|
||||
event_type: str | None = None,
|
||||
event_payload: dict[str, Any] | None = None,
|
||||
context_visible: bool = True,
|
||||
content: str | None = None,
|
||||
tool_name: str | None = None,
|
||||
tool_calls: list[dict[str, Any]] | None = None,
|
||||
tool_call_id: str | None = None,
|
||||
finish_reason: str | None = None,
|
||||
reasoning: str | None = None,
|
||||
reasoning_details: Any | None = None,
|
||||
codex_reasoning_items: Any | None = None,
|
||||
source: str = "unknown",
|
||||
title: str | None = None,
|
||||
model: str | None = None,
|
||||
user_id: str | None = None,
|
||||
parent_session_id: str | None = None,
|
||||
) -> int:
|
||||
"""向指定 session 追加一条消息。"""
|
||||
|
||||
self.ensure_session(
|
||||
session_id,
|
||||
source=source,
|
||||
model=model,
|
||||
title=title,
|
||||
user_id=user_id,
|
||||
parent_session_id=parent_session_id,
|
||||
)
|
||||
now = time.time()
|
||||
tool_calls_json = json.dumps(tool_calls) if tool_calls is not None else None
|
||||
event_payload_json = json.dumps(event_payload) if event_payload is not None else None
|
||||
reasoning_details_json = json.dumps(reasoning_details) if reasoning_details is not None else None
|
||||
codex_items_json = json.dumps(codex_reasoning_items) if codex_reasoning_items is not None else None
|
||||
preview = (content or "")[:120] if role == "user" and content else None
|
||||
tool_call_count = len(tool_calls) if isinstance(tool_calls, list) else (1 if tool_calls else 0)
|
||||
|
||||
def _do(conn: sqlite3.Connection) -> int:
|
||||
cursor = conn.execute(
|
||||
"""
|
||||
INSERT INTO messages (
|
||||
session_id, run_id, role, event_type, event_payload, context_visible, content,
|
||||
tool_name, tool_calls, tool_call_id, timestamp, finish_reason, reasoning,
|
||||
reasoning_details, codex_reasoning_items
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
session_id,
|
||||
run_id,
|
||||
role,
|
||||
event_type or role,
|
||||
event_payload_json,
|
||||
1 if context_visible else 0,
|
||||
content,
|
||||
tool_name,
|
||||
tool_calls_json,
|
||||
tool_call_id,
|
||||
now,
|
||||
finish_reason,
|
||||
reasoning,
|
||||
reasoning_details_json,
|
||||
codex_items_json,
|
||||
),
|
||||
)
|
||||
conn.execute(
|
||||
"""
|
||||
UPDATE sessions
|
||||
SET last_active = ?,
|
||||
message_count = message_count + 1,
|
||||
tool_call_count = tool_call_count + ?,
|
||||
model = COALESCE(model, ?),
|
||||
preview = CASE
|
||||
WHEN preview IS NULL AND ? IS NOT NULL THEN ?
|
||||
ELSE preview
|
||||
END
|
||||
WHERE id = ?
|
||||
""",
|
||||
(now, tool_call_count, model, preview, preview, session_id),
|
||||
)
|
||||
return int(cursor.lastrowid)
|
||||
|
||||
return self._execute_write(_do)
|
||||
|
||||
def get_message_records(self, session_id: str) -> list[MessageRecord]:
|
||||
rows = self._fetchall(
|
||||
"""
|
||||
SELECT *
|
||||
FROM messages
|
||||
WHERE session_id = ?
|
||||
ORDER BY timestamp, id
|
||||
""",
|
||||
(session_id,),
|
||||
)
|
||||
return [MessageRecord.from_row(row) for row in rows]
|
||||
|
||||
def get_event_records(self, session_id: str) -> list[MessageRecord]:
|
||||
"""返回当前 session 的完整事件流。
|
||||
|
||||
当前阶段里,事件流仍复用 `messages` 表承载,所以这里等价于读取全部 message records。
|
||||
后面如果单独拆出 run/checkpoint/system event 表,上层 manager 仍可以继续保持这个接口不变。
|
||||
"""
|
||||
|
||||
return self.get_message_records(session_id)
|
||||
|
||||
def list_run_ids(self, session_id: str) -> list[str]:
|
||||
"""按时间顺序列出当前 session 中出现过的 run_id。"""
|
||||
|
||||
rows = self._fetchall(
|
||||
"""
|
||||
SELECT run_id
|
||||
FROM messages
|
||||
WHERE session_id = ? AND run_id IS NOT NULL
|
||||
GROUP BY run_id
|
||||
ORDER BY MIN(timestamp), MIN(id)
|
||||
""",
|
||||
(session_id,),
|
||||
)
|
||||
return [str(row["run_id"]) for row in rows if row.get("run_id")]
|
||||
|
||||
def get_run_event_records(self, session_id: str, run_id: str) -> list[MessageRecord]:
|
||||
"""返回某一次 run 对应的事件片段。"""
|
||||
|
||||
rows = self._fetchall(
|
||||
"""
|
||||
SELECT *
|
||||
FROM messages
|
||||
WHERE session_id = ? AND run_id = ?
|
||||
ORDER BY timestamp, id
|
||||
""",
|
||||
(session_id, run_id),
|
||||
)
|
||||
return [MessageRecord.from_row(row) for row in rows]
|
||||
|
||||
def get_messages_as_conversation(self, session_id: str) -> list[dict[str, Any]]:
|
||||
messages: list[dict[str, Any]] = []
|
||||
for record in self.get_event_records(session_id):
|
||||
if not record.context_visible:
|
||||
continue
|
||||
messages.append(record.to_conversation_message())
|
||||
return messages
|
||||
|
||||
def end_session(self, session_id: str, end_reason: str) -> None:
|
||||
def _do(conn: sqlite3.Connection) -> None:
|
||||
conn.execute(
|
||||
"""
|
||||
UPDATE sessions
|
||||
SET ended_at = ?, end_reason = ?, last_active = ?
|
||||
WHERE id = ?
|
||||
""",
|
||||
(time.time(), end_reason, time.time(), session_id),
|
||||
)
|
||||
|
||||
self._execute_write(_do)
|
||||
|
||||
def reopen_session(self, session_id: str) -> None:
|
||||
def _do(conn: sqlite3.Connection) -> None:
|
||||
conn.execute(
|
||||
"""
|
||||
UPDATE sessions
|
||||
SET ended_at = NULL, end_reason = NULL, last_active = ?
|
||||
WHERE id = ?
|
||||
""",
|
||||
(time.time(), session_id),
|
||||
)
|
||||
|
||||
self._execute_write(_do)
|
||||
Reference in New Issue
Block a user