395 lines
16 KiB
Python
395 lines
16 KiB
Python
"""统一 agent 注册表。
|
||
|
||
这个模块把当前工作区里“可被委派”的执行体统一抽象成 `AgentDescriptor`:
|
||
1. workspace 手工登记的远端 A2A agent;
|
||
2. plugin 提供的本地 prompt agent;
|
||
3. skill 元数据里声明的 agent cards;
|
||
4. 内置 local fallback agent。
|
||
|
||
上层委派逻辑只和 `AgentDescriptor` 打交道,不需要关心来源细节。
|
||
"""
|
||
|
||
from __future__ import annotations

import json
import os
import re
import tempfile
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Any
from xml.sax.saxutils import escape as xml_escape

from nanobot.agent.plugins import PluginLoader
from nanobot.agent.skills import SkillsLoader
|
||
|
||
# Word tokenizer used by `AgentRegistry.suggest_agents`: runs of lowercase ASCII
# letters, digits, "_" and "-". The query is lowercased before matching, so
# uppercase input still tokenizes; non-ASCII text (e.g. Chinese) yields no tokens.
_TOKEN_RE = re.compile(r"[a-z0-9_-]+")
|
||
|
||
|
||
@dataclass
class AgentDescriptor:
    """Uniform description of one delegatable agent, as used by the delegation layer.

    `AgentRegistry` builds these from every supported source (workspace
    records, plugins, skill cards, the built-in fallback), so consumers never
    have to know where an agent came from.
    """

    # Stable id used for routing, persistence and exact matching.
    id: str
    # Human-readable name for the UI and logs.
    name: str
    # Short summary, mainly shown to the model and the frontend.
    description: str
    # Origin of the agent: builtin / plugin / skill / workspace.
    source: str
    # How the agent runs: local_prompt / local_fallback / a2a_remote, etc.
    kind: str
    # Underlying protocol; currently "a2a" or None.
    protocol: str | None = None
    plugin_name: str | None = None
    skill_name: str | None = None
    model: str | None = None
    # Internal prompt text; removed by `public_dict` before reaching the frontend.
    system_prompt: str | None = None
    endpoint: str | None = None
    base_url: str | None = None
    card_url: str | None = None
    # NOTE(review): presumably the name of an env var holding a credential — confirm with the A2A client.
    auth_env: str | None = None
    auth_mode: str = "none"
    auth_audience: str | None = None
    auth_scopes: list[str] = field(default_factory=list)
    enabled: bool = True
    tags: list[str] = field(default_factory=list)
    aliases: list[str] = field(default_factory=list)
    capabilities: dict[str, Any] = field(default_factory=dict)
    metadata: dict[str, Any] = field(default_factory=dict)
    support_group: bool = True
    support_streaming: bool = False

    def matches(self, target: str) -> bool:
        """Return True when *target* equals this agent's id, name or an alias.

        The comparison is case-insensitive and ignores surrounding whitespace
        in *target*; an empty or None target never matches. Accepting names
        and aliases lets the model refer to agents in near-natural language.
        """
        needle = (target or "").strip().lower()
        if not needle:
            return False
        known_names = {
            self.id.lower(),
            self.name.lower(),
            *(alias.lower() for alias in self.aliases if alias),
        }
        return needle in known_names

    def searchable_text(self) -> str:
        """Return a lowercased blob of identifying text for naive relevance matching."""
        pieces = (
            self.id,
            self.name,
            self.description,
            " ".join(self.tags),
            " ".join(self.aliases),
            self.plugin_name or "",
            self.skill_name or "",
        )
        # filter(None, ...) drops empty pieces so they don't add extra separators.
        return " ".join(filter(None, pieces)).lower()

    def public_dict(self) -> dict[str, Any]:
        """Return a dict copy of this descriptor that is safe to send to the frontend.

        Every field is included except ``system_prompt``, which is an internal
        implementation detail.
        """
        return {key: value for key, value in asdict(self).items() if key != "system_prompt"}
|
||
|
||
|
||
class WorkspaceAgentStore:
    """Workspace-level store for manually registered agents.

    Holds agents that users registered by hand (via the Web UI or local
    config). The file lives at the fixed path `<workspace>/agents/registry.json`
    and contains a JSON list of agent record dicts.
    """

    def __init__(self, workspace: Path):
        self.workspace = workspace
        # Kept in a dedicated `agents/` directory, separate from skills / memory / files.
        self.directory = workspace / "agents"
        self.path = self.directory / "registry.json"

    @staticmethod
    def _record_id(record: dict[str, Any]) -> str:
        """Return the normalized id of *record*: stringified and stripped, "" when missing/None.

        Hand-edited registry files may hold non-string ids (e.g. numbers), so
        every id comparison and the sort key go through this helper.
        """
        raw = record.get("id")
        return "" if raw is None else str(raw).strip()

    def list_agents(self) -> list[dict[str, Any]]:
        """Read and return every manually registered agent record.

        Returns:
            The stored records, in file order. The result is an empty list when
            the file is missing, unreadable, not valid JSON, or not a JSON list.
            Items that are not dicts or lack a truthy ``id`` are skipped, so each
            returned record has at least a usable primary key.
        """
        if not self.path.exists():
            return []
        try:
            raw = json.loads(self.path.read_text(encoding="utf-8"))
        except (OSError, ValueError):
            # json.JSONDecodeError and UnicodeDecodeError are both ValueError
            # subclasses. A corrupt store must not take down the main flow, so
            # it is treated as empty.
            return []
        if not isinstance(raw, list):
            return []
        return [item for item in raw if isinstance(item, dict) and item.get("id")]

    def save_agents(self, agents: list[dict[str, Any]]) -> None:
        """Overwrite the registry file with the complete *agents* list.

        The JSON is written to a temporary file in the same directory and then
        atomically moved into place with ``os.replace``. A crash or disk error
        mid-write therefore cannot leave a truncated registry behind. Such a
        file would be read back as empty by ``list_agents``, and the next
        upsert would persist that, silently dropping every other agent.
        Note: ``tempfile.mkstemp`` creates the file with owner-only permissions.

        Raises:
            TypeError / ValueError: if *agents* is not JSON-serializable
                (raised before anything is written).
            OSError: on filesystem errors; the existing registry is left intact.
        """
        self.directory.mkdir(parents=True, exist_ok=True)
        payload = json.dumps(agents, indent=2, ensure_ascii=False)
        fd, tmp_name = tempfile.mkstemp(dir=self.directory, prefix=".registry-", suffix=".tmp")
        try:
            with os.fdopen(fd, "w", encoding="utf-8") as handle:
                handle.write(payload)
            os.replace(tmp_name, self.path)
        except BaseException:
            # Don't leave stray temp files behind; re-raise the original error.
            try:
                os.unlink(tmp_name)
            except OSError:
                pass
            raise

    def upsert_agent(self, agent: dict[str, Any]) -> dict[str, Any]:
        """Insert or replace one agent record, keyed by ``id``.

        Args:
            agent: Raw record. It must contain a non-blank ``id``. The dict
                itself is not mutated; a shallow copy is stored.

        Returns:
            The normalized record that was written.

        Raises:
            ValueError: If ``id`` is missing, None, or blank.
        """
        record = dict(agent)
        agent_id = self._record_id(record)
        if not agent_id:
            raise ValueError("Agent id is required")
        record["id"] = agent_id
        # Minimal fallbacks for display fields so the UI and prompts never see
        # empty values. This covers keys that are present but empty/None, not
        # only missing keys.
        if not record.get("name"):
            record["name"] = agent_id
        if not record.get("description"):
            record["description"] = record["name"]
        record.setdefault("protocol", "a2a")
        record.setdefault("enabled", True)
        record.setdefault("tags", [])
        # Drop any previous record with the same id, append, then sort by id
        # so the stored file stays stable and readable.
        agents = [a for a in self.list_agents() if self._record_id(a) != agent_id]
        agents.append(record)
        agents.sort(key=lambda item: self._record_id(item).lower())
        self.save_agents(agents)
        return record

    def delete_agent(self, agent_id: str) -> bool:
        """Delete the agent whose id equals *agent_id* (after stripping).

        Returns:
            True if a record was removed and the file rewritten. False if the
            id is blank/None or no record matched; the file is untouched in
            that case.
        """
        target = "" if agent_id is None else str(agent_id).strip()
        if not target:
            return False
        agents = self.list_agents()
        remaining = [a for a in agents if self._record_id(a) != target]
        if len(remaining) == len(agents):
            return False
        self.save_agents(remaining)
        return True
|
||
|
||
|
||
class AgentRegistry:
    """Build and query the set of agents that can currently be delegated to.

    Agents are gathered in this order:

    1. workspace registry records (when ``allow_workspace_agents``);
    2. plugin agents;
    3. skill agent cards (when ``allow_skill_cards``);
    4. optionally, the built-in local fallback.

    When several agents share an id (compared case-insensitively), the first
    one wins.
    """

    def __init__(
        self,
        workspace: Path,
        plugins: PluginLoader | None = None,
        skills: SkillsLoader | None = None,
        allow_skill_cards: bool = True,
        allow_workspace_agents: bool = True,
    ):
        self.workspace = workspace
        # Callers may pass shared loader instances to avoid rescanning the disk.
        self.plugins = plugins or PluginLoader(workspace)
        self.skills = skills or SkillsLoader(workspace, extra_dirs=self.plugins.get_skill_dirs())
        self.allow_skill_cards = allow_skill_cards
        self.allow_workspace_agents = allow_workspace_agents
        self.workspace_store = WorkspaceAgentStore(workspace)

    def list_agents(self, include_local_fallback: bool = True) -> list[AgentDescriptor]:
        """List every currently visible agent as an `AgentDescriptor`.

        Args:
            include_local_fallback: Whether to append the built-in ``local-subagent``.

        Returns:
            Agents de-duplicated by lower-cased id, keeping the first
            occurrence (workspace > plugin > skill > builtin).
        """
        agents: list[AgentDescriptor] = []

        if self.allow_workspace_agents:
            for record in self.workspace_store.list_agents():
                # Disabled workspace records are hidden entirely.
                if not record.get("enabled", True):
                    continue
                agent = self._workspace_record_to_descriptor(record)
                if agent:
                    agents.append(agent)

        # Plugin agents are local executors that carry their own system prompt.
        for plugin in self.plugins.plugins.values():
            for plugin_agent in plugin.agents.values():
                agents.append(
                    AgentDescriptor(
                        id=f"plugin:{plugin_agent.name}",
                        name=plugin_agent.name,
                        description=plugin_agent.description or plugin_agent.name,
                        source="plugin",
                        kind="local_prompt",
                        protocol=None,
                        plugin_name=plugin_agent.plugin_name,
                        model=plugin_agent.model,
                        system_prompt=plugin_agent.system_prompt,
                        aliases=[plugin_agent.name],
                        metadata={"plugin_name": plugin_agent.plugin_name},
                    )
                )

        if self.allow_skill_cards:
            # Cards declared by skills are static entry points to remote A2A agents.
            for card in self.skills.list_skill_agent_cards():
                agent = self._skill_card_to_descriptor(card)
                if agent:
                    agents.append(agent)

        if include_local_fallback:
            # Always keep a local fallback so auto-routing has at least one executable target.
            agents.append(
                AgentDescriptor(
                    id="local-subagent",
                    name="Local Subagent",
                    description="Local fallback agent that can use files, shell, and web tools.",
                    source="builtin",
                    kind="local_fallback",
                    protocol=None,
                    aliases=["subagent", "local"],
                    support_group=True,
                )
            )

        seen: set[str] = set()
        result: list[AgentDescriptor] = []
        for agent in agents:
            # De-duplicate on lower-cased id; the earliest source wins.
            key = agent.id.lower()
            if key in seen:
                continue
            seen.add(key)
            result.append(agent)
        return result

    def get_agent(self, target: str) -> AgentDescriptor | None:
        """Return the first visible agent whose id, name or alias matches *target*.

        Matching is case-insensitive (see `AgentDescriptor.matches`). Returns
        None for an empty target or when nothing matches.
        """
        probe = (target or "").strip()
        if not probe:
            return None
        for agent in self.list_agents():
            if agent.matches(probe):
                return agent
        return None

    def suggest_agents(self, query: str, limit: int = 5) -> list[AgentDescriptor]:
        """Recommend agents for a task description using simple term scoring.

        Scoring:

        - Each distinct query token longer than 2 characters that appears
          (as a substring) in the agent's searchable text adds 2.
        - A further 5 is added when the query contains the agent's name or id.

        The local fallback is never suggested.

        Returns:
            Up to *limit* agents with a positive score, best first, ties broken
            by name. Returns an empty list when the query has no usable tokens
            or *limit* <= 0.
        """
        query_lower = (query or "").lower()
        tokens = {token for token in _TOKEN_RE.findall(query_lower) if len(token) > 2}
        if not tokens:
            return []

        scored: list[tuple[int, AgentDescriptor]] = []
        for agent in self.list_agents(include_local_fallback=False):
            haystack = agent.searchable_text()
            score = 2 * sum(1 for token in tokens if token in haystack)
            name_lower = agent.name.lower()
            id_lower = agent.id.lower()
            # Extra weight when the query names the agent directly. The
            # emptiness guards matter because "" is a substring of every query.
            if (name_lower and name_lower in query_lower) or (id_lower and id_lower in query_lower):
                score += 5
            if score > 0:
                scored.append((score, agent))

        scored.sort(key=lambda item: (-item[0], item[1].name.lower()))
        return [agent for _, agent in scored[: max(limit, 0)]]

    def build_agents_summary(self) -> str:
        """Render the visible agents as an XML fragment for embedding in a prompt.

        Returns:
            An ``<agents>...</agents>`` string, or "" when there are no agents.
            Text values are XML-escaped so free-form descriptions cannot break
            the structure.
        """
        agents = self.list_agents()
        if not agents:
            return ""

        def esc(value: object) -> str:
            # Escapes &, < and > (the characters that matter in element content).
            return xml_escape(str(value))

        lines = ["<agents>"]
        for agent in agents:
            lines.append("  <agent>")
            lines.append(f"    <id>{esc(agent.id)}</id>")
            lines.append(f"    <name>{esc(agent.name)}</name>")
            lines.append(f"    <source>{esc(agent.source)}</source>")
            lines.append(f"    <kind>{esc(agent.kind)}</kind>")
            lines.append(f"    <description>{esc(agent.description)}</description>")
            if agent.protocol:
                lines.append(f"    <protocol>{esc(agent.protocol)}</protocol>")
            if agent.tags:
                lines.append(f"    <tags>{esc(', '.join(agent.tags))}</tags>")
            lines.append(
                f"    <supports-group>{str(agent.support_group).lower()}</supports-group>"
            )
            lines.append("  </agent>")
        lines.append("</agents>")
        return "\n".join(lines)

    def list_public_agents(self) -> list[dict[str, Any]]:
        """Return the visible agents as sanitized dicts (no ``system_prompt``) for the Web API."""
        return [agent.public_dict() for agent in self.list_agents()]

    @staticmethod
    def _string_list(value: Any) -> list[str]:
        """Coerce a loosely typed list field into a list of stripped, non-empty strings.

        Workspace records are hand-edited JSON, so ``None``, a bare string, or
        another scalar may appear where a list is expected. A bare string
        becomes a one-element list (instead of being iterated character by
        character). Any other non-list value yields ``[]`` instead of raising.
        """
        if isinstance(value, str):
            items: list[Any] = [value]
        elif isinstance(value, (list, tuple)):
            items = list(value)
        else:
            return []
        return [text for text in (str(item).strip() for item in items) if text]

    def _a2a_common_fields(self, data: dict[str, Any]) -> dict[str, Any]:
        """Extract the `AgentDescriptor` kwargs shared by workspace records and skill cards."""
        raw_aliases = data.get("aliases")
        if isinstance(raw_aliases, str):
            raw_aliases = [raw_aliases]
        elif not isinstance(raw_aliases, (list, tuple)):
            raw_aliases = []
        # The declared name is also an alias; only real strings are accepted.
        aliases = [
            alias.strip()
            for alias in (data.get("name"), *raw_aliases)
            if isinstance(alias, str) and alias.strip()
        ]
        capabilities = data.get("capabilities")
        metadata = data.get("metadata")
        return {
            "auth_env": data.get("auth_env"),
            "auth_mode": str(data.get("auth_mode") or "none").strip().lower() or "none",
            "auth_audience": str(data.get("auth_audience") or "").strip() or None,
            "auth_scopes": self._string_list(data.get("auth_scopes")),
            "tags": self._string_list(data.get("tags")),
            "aliases": aliases,
            "capabilities": capabilities if isinstance(capabilities, dict) else {},
            "metadata": metadata if isinstance(metadata, dict) else {},
            "support_group": bool(data.get("support_group", True)),
            "support_streaming": bool(data.get("support_streaming", False)),
        }

    def _workspace_record_to_descriptor(self, record: dict[str, Any]) -> AgentDescriptor | None:
        """Convert a raw workspace registry record into an `AgentDescriptor`.

        Returns:
            The descriptor, or None when the record's protocol is not A2A (the
            only protocol supported for workspace records) or its id is blank.
        """
        protocol = str(record.get("protocol") or "a2a").strip().lower()
        if protocol != "a2a":
            return None
        agent_id = str(record.get("id", "")).strip()
        if not agent_id:
            return None
        name = str(record.get("name") or agent_id)
        return AgentDescriptor(
            id=agent_id,
            name=name,
            description=str(record.get("description") or name),
            source="workspace",
            kind="a2a_remote",
            protocol="a2a",
            # endpoint and base_url fall back to each other.
            endpoint=record.get("endpoint") or record.get("base_url"),
            base_url=record.get("base_url") or record.get("endpoint"),
            card_url=record.get("card_url"),
            enabled=bool(record.get("enabled", True)),
            **self._a2a_common_fields(record),
        )

    def _skill_card_to_descriptor(self, card: dict[str, Any]) -> AgentDescriptor | None:
        """Convert an agent card from skill frontmatter into an `AgentDescriptor`.

        Returns:
            The descriptor, or None when the card has no id.
        """
        card_id = str(card.get("id") or "").strip()
        if not card_id:
            return None
        skill_name = str(card.get("skill_name") or "").strip()
        name = str(card.get("name") or card_id)
        return AgentDescriptor(
            id=card_id,
            name=name,
            description=str(card.get("description") or name),
            source="skill",
            kind="a2a_remote",
            protocol="a2a",
            skill_name=skill_name or None,
            # endpoint and base_url fall back to each other.
            endpoint=card.get("endpoint") or card.get("base_url"),
            base_url=card.get("base_url") or card.get("endpoint"),
            # Cards may spell the card location as "url" or "card_url".
            card_url=card.get("url") or card.get("card_url"),
            **self._a2a_common_fields(card),
        )
|