Compare commits
69 Commits
personal-u
...
KEYCLOAK_L
| Author | SHA1 | Date | |
|---|---|---|---|
| 461d1300ad | |||
| fc9fd93c36 | |||
| 9cc3334ea7 | |||
| dc4c6f313d | |||
| 9e2c02a333 | |||
| b9171998b9 | |||
| 64d789a3d0 | |||
| cc1bf85517 | |||
| 4c8bc53d33 | |||
| 70014c0f70 | |||
| eb69bb168a | |||
| 7287e93f87 | |||
| a925f0e77f | |||
| 6dc580ab26 | |||
| 3a16dc283d | |||
| 0fd4df3c17 | |||
| f46a435bab | |||
| a28254c6b8 | |||
| 0cf4f44346 | |||
| e0bc6c55b0 | |||
| 2c5205b06e | |||
| 236ac19789 | |||
| ba2417c7f2 | |||
| f3655c86c9 | |||
| c3d84b904a | |||
| ee972441f5 | |||
| d335199a64 | |||
| feeaccc0e3 | |||
| cf35edb4ca | |||
| e0a4862af8 | |||
| c3a0aef104 | |||
| b25713a141 | |||
| d74a1c9c12 | |||
| 834d4e1e2f | |||
| 6a6ddc21c0 | |||
| 826db8ec2e | |||
| 33a9845566 | |||
| 55b39563a0 | |||
| 41ac87e322 | |||
| 542b23ef6e | |||
| 9002d1206f | |||
| dd9f40b38c | |||
| 96562877cc | |||
| f58a57e5b8 | |||
| 362aae9b12 | |||
| 29d175222d | |||
| 2e4f8541ee | |||
| a1164dc49a | |||
| 7b638b083a | |||
| 6e9e74d1ee | |||
| 16347caf5e | |||
| 030bce8a60 | |||
| c671b66043 | |||
| e061961a79 | |||
| 8068d86760 | |||
| 4022db8887 | |||
| c53e221117 | |||
| b808f5cbc2 | |||
| 0adc04806c | |||
| 60605a74e0 | |||
| 3ff2e2ce11 | |||
| 0ace09b984 | |||
| c3c4df306b | |||
| 5446614828 | |||
| 2fd618da9c | |||
| 28a2627b1f | |||
| 249087e943 | |||
| 8bff282892 | |||
| 3a3e848a78 |
40
.env.example
40
.env.example
@ -1,30 +1,24 @@
|
|||||||
# Shared values used by the root deployment flow in README.md
|
# Standalone app-instance defaults
|
||||||
|
|
||||||
PROJECT_ROOT=/home/ivan/xuan/beaver_project
|
BEAVER_PUBLIC_URL=http://172.19.0.245:18080
|
||||||
BEAVER_NET=beaver-instance-edge
|
BEAVER_HOST_IP=172.19.0.245
|
||||||
BEAVER_PROXY_CONTAINER_NAME=beaver-router-proxy
|
BEAVER_HOST_PORT=18080
|
||||||
|
|
||||||
BEAVER_DEPLOY_TOKEN=change-me
|
|
||||||
BEAVER_AUTHZ_INTERNAL_TOKEN=change-me
|
|
||||||
|
|
||||||
BEAVER_SERVER_IP=127.0.0.1
|
|
||||||
BEAVER_BASE_DOMAIN=localhost
|
|
||||||
|
|
||||||
BEAVER_PROVIDER=openai
|
|
||||||
BEAVER_MODEL=openai/gpt-5
|
|
||||||
BEAVER_API_KEY=sk-xxxxxxxx
|
|
||||||
BEAVER_API_BASE=
|
|
||||||
|
|
||||||
# Per-instance Beaver backend config. In Docker app-instance this should point
|
|
||||||
# to the mounted single-user sandbox config, not to frontend env.
|
|
||||||
BEAVER_HOME=/root/.beaver
|
BEAVER_HOME=/root/.beaver
|
||||||
BEAVER_CONFIG_PATH=/root/.beaver/config.json
|
BEAVER_CONFIG_PATH=/root/.beaver/config.json
|
||||||
BEAVER_WORKSPACE=/root/.beaver/workspace
|
BEAVER_WORKSPACE=/root/.beaver/workspace
|
||||||
|
|
||||||
# Must be reachable from app-instance containers.
|
# Keycloak SSO
|
||||||
BEAVER_AUTHZ_URL=http://beaver-authz-service:19090
|
BEAVER_KEYCLOAK_ISSUER=https://keycloak.bwgdi.com/realms/beaver
|
||||||
BEAVER_OUTLOOK_MCP_URL=
|
BEAVER_KEYCLOAK_CLIENT_ID=beaver-agnet
|
||||||
BEAVER_OUTLOOK_MCP_SERVER_ID=outlook_mcp
|
BEAVER_KEYCLOAK_REDIRECT_URI=http://172.19.0.245:18080/auth/callback
|
||||||
|
BEAVER_KEYCLOAK_POST_LOGOUT_REDIRECT_URI=http://172.19.0.245:18080/logout/callback
|
||||||
|
|
||||||
# Must be reachable from auth-portal and authz-service containers.
|
# Model provider settings used when generating config.json with run-standalone.sh
|
||||||
BEAVER_DEPLOY_URL=http://beaver-deploy-control:8090
|
BEAVER_PROVIDER=openai
|
||||||
|
BEAVER_MODEL=openai/gpt-5
|
||||||
|
BEAVER_API_KEY=
|
||||||
|
BEAVER_API_BASE=
|
||||||
|
|
||||||
|
# Optional runtime limits
|
||||||
|
BEAVER_USER_FILES_MAX_UPLOAD_BYTES=5368709120
|
||||||
|
|||||||
1
.gitignore
vendored
1
.gitignore
vendored
@ -21,6 +21,7 @@ sessions/
|
|||||||
**/.ruff_cache/
|
**/.ruff_cache/
|
||||||
**/.mypy_cache/
|
**/.mypy_cache/
|
||||||
**/.cache/
|
**/.cache/
|
||||||
|
**/.codegraph/
|
||||||
**/.venv/
|
**/.venv/
|
||||||
**/dist/
|
**/dist/
|
||||||
**/build/
|
**/build/
|
||||||
|
|||||||
12
AGENTS.md
Normal file
12
AGENTS.md
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
# Repository Instructions
|
||||||
|
|
||||||
|
This branch is scoped to the standalone `app-instance` package.
|
||||||
|
|
||||||
|
Useful entry points:
|
||||||
|
|
||||||
|
- `app-instance/README.md`
|
||||||
|
- `app-instance/run-standalone.sh`
|
||||||
|
- `app-instance/backend`
|
||||||
|
- `app-instance/frontend`
|
||||||
|
|
||||||
|
The repository no longer owns the old auth portal, authz service, deploy control plane, router proxy, or per-user instance orchestration.
|
||||||
411
DESIGN.md
411
DESIGN.md
@ -1,411 +0,0 @@
|
|||||||
# DESIGN.md
|
|
||||||
|
|
||||||
## Brand
|
|
||||||
|
|
||||||
**Beaver — Taupe**
|
|
||||||
|
|
||||||
A calm editorial UI system focused on rhythm, hierarchy, and soft neutral contrast.
|
|
||||||
Designed for AI-native tools, dashboards, and minimalist productivity software.
|
|
||||||
|
|
||||||
The interface emphasizes:
|
|
||||||
|
|
||||||
- Spacious layouts
|
|
||||||
- Soft grayscale surfaces
|
|
||||||
- Typography-first hierarchy
|
|
||||||
- Rounded geometry
|
|
||||||
- Quiet shadows
|
|
||||||
- Dense information with low visual noise
|
|
||||||
|
|
||||||
The visual tone should feel:
|
|
||||||
|
|
||||||
- thoughtful
|
|
||||||
- mature
|
|
||||||
- calm
|
|
||||||
- premium
|
|
||||||
- architectural
|
|
||||||
- editorial
|
|
||||||
|
|
||||||
Avoid:
|
|
||||||
|
|
||||||
- saturated colors
|
|
||||||
- hard borders
|
|
||||||
- sharp corners
|
|
||||||
- excessive gradients
|
|
||||||
- loud shadows
|
|
||||||
- playful illustration-heavy UI
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
# Colors
|
|
||||||
|
|
||||||
## Core Palette
|
|
||||||
|
|
||||||
| Token | Hex | Usage |
|
|
||||||
|---|---|---|
|
|
||||||
| background | `#F5F3F1` | Main app background |
|
|
||||||
| foreground | `#0B0B0B` | Primary text |
|
|
||||||
| primary | `#1D1715` | Primary actions |
|
|
||||||
| secondary | `#E5E2DF` | Secondary surfaces |
|
|
||||||
| muted | `#DDD9D6` | Muted backgrounds |
|
|
||||||
| accent | `#CAC5C0` | Borders / subtle emphasis |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Neutral Scale
|
|
||||||
|
|
||||||
| Token | Hex |
|
|
||||||
|---|---|
|
|
||||||
| zinc-50 | `#F7F5F4` |
|
|
||||||
| zinc-100 | `#ECE8E5` |
|
|
||||||
| zinc-200 | `#D8D2CE` |
|
|
||||||
| zinc-300 | `#B8AEA8` |
|
|
||||||
| zinc-400 | `#8B7E77` |
|
|
||||||
| zinc-500 | `#6A5E58` |
|
|
||||||
| zinc-600 | `#4F4642` |
|
|
||||||
| zinc-700 | `#342E2B` |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Semantic Colors
|
|
||||||
|
|
||||||
### Taupe
|
|
||||||
|
|
||||||
| Step | Hex |
|
|
||||||
|---|---|
|
|
||||||
| taupe-100 | `#E7E2DE` |
|
|
||||||
| taupe-300 | `#B8AEA8` |
|
|
||||||
| taupe-500 | `#8B7E77` |
|
|
||||||
| taupe-700 | `#5F5550` |
|
|
||||||
|
|
||||||
### Sage
|
|
||||||
|
|
||||||
| Step | Hex |
|
|
||||||
|---|---|
|
|
||||||
| sage-100 | `#E3E8E2` |
|
|
||||||
| sage-300 | `#B7C2B5` |
|
|
||||||
| sage-500 | `#869683` |
|
|
||||||
| sage-700 | `#657162` |
|
|
||||||
|
|
||||||
### Slate
|
|
||||||
|
|
||||||
| Step | Hex |
|
|
||||||
|---|---|
|
|
||||||
| slate-100 | `#E4E7EB` |
|
|
||||||
| slate-300 | `#BCC4CE` |
|
|
||||||
| slate-500 | `#8C96A3` |
|
|
||||||
| slate-700 | `#697281` |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
# Typography
|
|
||||||
|
|
||||||
## Philosophy
|
|
||||||
|
|
||||||
Typography drives hierarchy.
|
|
||||||
|
|
||||||
The system should feel like a modern editorial publication mixed with a productivity dashboard.
|
|
||||||
|
|
||||||
Large headings use elegant serif typography.
|
|
||||||
UI and body copy use neutral grotesk sans-serif typography.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Font Stack
|
|
||||||
|
|
||||||
### Serif
|
|
||||||
|
|
||||||
```css
|
|
||||||
font-family: "Lora", Georgia, serif;
|
|
||||||
```
|
|
||||||
|
|
||||||
Used for:
|
|
||||||
|
|
||||||
- hero titles
|
|
||||||
- article headings
|
|
||||||
- marketing emphasis
|
|
||||||
- editorial sections
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### Sans
|
|
||||||
|
|
||||||
```css
|
|
||||||
font-family: "Public Sans", Inter, sans-serif;
|
|
||||||
```
|
|
||||||
|
|
||||||
Used for:
|
|
||||||
|
|
||||||
- UI
|
|
||||||
- labels
|
|
||||||
- forms
|
|
||||||
- dashboards
|
|
||||||
- buttons
|
|
||||||
- navigation
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Type Scale
|
|
||||||
|
|
||||||
| Style | Size | Weight | Line Height |
|
|
||||||
|---|---|---|---|
|
|
||||||
| h1 | 48px | 600 | 1.1 |
|
|
||||||
| h2 | 36px | 600 | 1.15 |
|
|
||||||
| h3 | 28px | 500 | 1.2 |
|
|
||||||
| body-lg | 18px | 400 | 1.7 |
|
|
||||||
| body | 16px | 400 | 1.6 |
|
|
||||||
| small | 14px | 400 | 1.5 |
|
|
||||||
| mono | 13px | 500 | 1.4 |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
# Radius
|
|
||||||
|
|
||||||
Rounded geometry should feel soft but architectural.
|
|
||||||
|
|
||||||
| Token | Radius |
|
|
||||||
|---|---|
|
|
||||||
| xs | 4px |
|
|
||||||
| sm | 8px |
|
|
||||||
| md | 12px |
|
|
||||||
| lg | 16px |
|
|
||||||
| xl | 24px |
|
|
||||||
| full | 999px |
|
|
||||||
|
|
||||||
Cards should primarily use:
|
|
||||||
|
|
||||||
```css
|
|
||||||
border-radius: 16px;
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
# Shadows
|
|
||||||
|
|
||||||
Shadows should be subtle and diffused.
|
|
||||||
|
|
||||||
Avoid strong elevation.
|
|
||||||
|
|
||||||
## Soft
|
|
||||||
|
|
||||||
```css
|
|
||||||
box-shadow:
|
|
||||||
0 1px 2px rgba(0,0,0,0.04),
|
|
||||||
0 6px 24px rgba(0,0,0,0.03);
|
|
||||||
```
|
|
||||||
|
|
||||||
## Floating
|
|
||||||
|
|
||||||
```css
|
|
||||||
box-shadow:
|
|
||||||
0 12px 40px rgba(0,0,0,0.06);
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
# Grid
|
|
||||||
|
|
||||||
## Layout
|
|
||||||
|
|
||||||
- 12-column grid
|
|
||||||
- Max width: `1280px`
|
|
||||||
- Horizontal padding: `32px`
|
|
||||||
- Large whitespace between sections
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Content Widths
|
|
||||||
|
|
||||||
| Type | Width |
|
|
||||||
|---|---|
|
|
||||||
| reading | 720px |
|
|
||||||
| dashboard | 1280px |
|
|
||||||
| modal | 480px |
|
|
||||||
| form | 560px |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
# Spacing
|
|
||||||
|
|
||||||
Base unit:
|
|
||||||
|
|
||||||
```txt
|
|
||||||
4px
|
|
||||||
```
|
|
||||||
|
|
||||||
Spacing scale:
|
|
||||||
|
|
||||||
| Token | Value |
|
|
||||||
|---|---|
|
|
||||||
| 1 | 4px |
|
|
||||||
| 2 | 8px |
|
|
||||||
| 3 | 12px |
|
|
||||||
| 4 | 16px |
|
|
||||||
| 5 | 20px |
|
|
||||||
| 6 | 24px |
|
|
||||||
| 8 | 32px |
|
|
||||||
| 10 | 40px |
|
|
||||||
| 12 | 48px |
|
|
||||||
| 16 | 64px |
|
|
||||||
|
|
||||||
Use generous vertical rhythm.
|
|
||||||
|
|
||||||
Sections should breathe.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
# Components
|
|
||||||
|
|
||||||
## Buttons
|
|
||||||
|
|
||||||
### Primary
|
|
||||||
|
|
||||||
- Dark background
|
|
||||||
- White text
|
|
||||||
- Pill radius
|
|
||||||
- Minimal shadow
|
|
||||||
|
|
||||||
```css
|
|
||||||
background: #1D1715;
|
|
||||||
color: white;
|
|
||||||
border-radius: 999px;
|
|
||||||
height: 40px;
|
|
||||||
padding: 0 16px;
|
|
||||||
```
|
|
||||||
|
|
||||||
### Secondary
|
|
||||||
|
|
||||||
```css
|
|
||||||
background: #ECE8E5;
|
|
||||||
color: #1D1715;
|
|
||||||
```
|
|
||||||
|
|
||||||
### Ghost
|
|
||||||
|
|
||||||
Transparent background with subtle hover fill.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Cards
|
|
||||||
|
|
||||||
Cards are soft containers with quiet separation.
|
|
||||||
|
|
||||||
```css
|
|
||||||
background: rgba(255,255,255,0.7);
|
|
||||||
border: 1px solid rgba(0,0,0,0.04);
|
|
||||||
border-radius: 16px;
|
|
||||||
```
|
|
||||||
|
|
||||||
Avoid heavy borders.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Inputs
|
|
||||||
|
|
||||||
Inputs should feel invisible until focused.
|
|
||||||
|
|
||||||
```css
|
|
||||||
background: #F7F5F4;
|
|
||||||
border: 1px solid transparent;
|
|
||||||
```
|
|
||||||
|
|
||||||
Focus:
|
|
||||||
|
|
||||||
```css
|
|
||||||
border-color: #8B7E77;
|
|
||||||
box-shadow: 0 0 0 3px rgba(139,126,119,0.12);
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Charts
|
|
||||||
|
|
||||||
Charts should use muted earthy tones.
|
|
||||||
|
|
||||||
Preferred palette:
|
|
||||||
|
|
||||||
- taupe
|
|
||||||
- sage
|
|
||||||
- slate
|
|
||||||
|
|
||||||
Avoid:
|
|
||||||
|
|
||||||
- neon colors
|
|
||||||
- bright blue dashboards
|
|
||||||
- rainbow charts
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
# Motion
|
|
||||||
|
|
||||||
Motion should be restrained and smooth.
|
|
||||||
|
|
||||||
Preferred easing:
|
|
||||||
|
|
||||||
```css
|
|
||||||
cubic-bezier(0.22, 1, 0.36, 1)
|
|
||||||
```
|
|
||||||
|
|
||||||
Preferred duration:
|
|
||||||
|
|
||||||
| Type | Duration |
|
|
||||||
|---|---|
|
|
||||||
| hover | 150ms |
|
|
||||||
| panel | 250ms |
|
|
||||||
| modal | 350ms |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
# Layout Skeleton
|
|
||||||
|
|
||||||
Application layout:
|
|
||||||
|
|
||||||
- Left sidebar
|
|
||||||
- Large content canvas
|
|
||||||
- Floating top toolbar
|
|
||||||
- Soft dashboard cards
|
|
||||||
- Spacious internal padding
|
|
||||||
|
|
||||||
The UI should always feel:
|
|
||||||
|
|
||||||
- breathable
|
|
||||||
- editorial
|
|
||||||
- premium
|
|
||||||
- calm
|
|
||||||
|
|
||||||
Never dense or overly enterprise-looking.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
# Design Keywords
|
|
||||||
|
|
||||||
Use these words when generating UI:
|
|
||||||
|
|
||||||
- editorial
|
|
||||||
- taupe
|
|
||||||
- soft neutral
|
|
||||||
- premium minimal
|
|
||||||
- typography-first
|
|
||||||
- architectural spacing
|
|
||||||
- calm dashboard
|
|
||||||
- quiet luxury
|
|
||||||
- modern serif
|
|
||||||
- subtle shadows
|
|
||||||
- muted grayscale
|
|
||||||
- sophisticated SaaS
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
# AI Agent Instructions
|
|
||||||
|
|
||||||
When generating UI:
|
|
||||||
|
|
||||||
1. Prioritize whitespace over density
|
|
||||||
2. Typography should create hierarchy
|
|
||||||
3. Use muted neutral palettes
|
|
||||||
4. Prefer soft cards over hard sections
|
|
||||||
5. Avoid excessive color usage
|
|
||||||
6. Keep interactions subtle
|
|
||||||
7. Use serif fonts sparingly for emphasis
|
|
||||||
8. Maintain premium visual restraint
|
|
||||||
9. Design should feel timeless rather than trendy
|
|
||||||
10. Every screen should feel breathable
|
|
||||||
198
README.md
198
README.md
@ -1,183 +1,75 @@
|
|||||||
# Beaver Project
|
# Beaver App Instance
|
||||||
|
|
||||||
`Beaver Project` 是一套单机 Docker 部署的多实例运行环境:
|
This branch contains a standalone Beaver app instance for an external orchestrator to deploy.
|
||||||
|
|
||||||
- 用户先进入独立的 `auth-portal` 完成注册或登录。
|
The app instance is a single Docker image that runs:
|
||||||
- 注册会触发 `authz-service` 调用 `deploy-control`。
|
|
||||||
- `deploy-control` 在同一台机器上创建一个独立的 `app-instance` 容器。
|
|
||||||
- `router-proxy` 按实例域名把流量转发到对应容器。
|
|
||||||
|
|
||||||
当前推荐的最小部署方式是一台 Linux / WSL2 Ubuntu 机器加 Docker。生产域名和 HTTPS 可以放在项目外层的 Nginx、Caddy、Traefik 或云负载均衡上。
|
- Next.js frontend
|
||||||
|
- FastAPI backend
|
||||||
|
- Nginx reverse proxy
|
||||||
|
|
||||||
## 组件
|
Authentication is delegated to Keycloak. This repository does not maintain local user registration, passwords, per-user instance routing, or deployment orchestration.
|
||||||
|
|
||||||
| 目录 | 职责 | 默认端口 |
|
## Runtime Contract
|
||||||
| --- | --- | --- |
|
|
||||||
| `auth-portal/` | 用户登录、注册、模型配置引导入口 | `3081` |
|
|
||||||
| `authz-service/` | AuthZ 服务,负责账号和 backend 身份编排 | `19090` |
|
|
||||||
| `deploy-control/` | 部署控制面,调用 Docker 创建和管理实例 | `8090` |
|
|
||||||
| `router-proxy/` | 统一实例入口代理,按 Host 分发到实例容器 | `8088` |
|
|
||||||
| `app-instance/` | 单用户运行实例,容器内包含前端、后端和 Nginx | 容器内 `8080` |
|
|
||||||
|
|
||||||
公网环境通常只暴露:
|
The external orchestrator is responsible for:
|
||||||
|
|
||||||
- `auth-portal`: `3081`,或外层代理后的 `https://portal.example.com`
|
- Building or pulling the app image
|
||||||
- `router-proxy`: `8088`,或外层代理后的 `https://<slug>.apps.example.com`
|
- Providing a mounted Beaver home directory
|
||||||
|
- Providing `config.json`
|
||||||
|
- Choosing the public host and port
|
||||||
|
- Registering the public callback URL and post-logout callback URL in Keycloak
|
||||||
|
- Managing lifecycle, routing, TLS, and scaling
|
||||||
|
|
||||||
不要直接把 `deploy-control:8090` 和 `authz-service:19090` 暴露到公网。
|
The app instance exposes container port `8080`:
|
||||||
|
|
||||||
## 请求链路
|
|
||||||
|
|
||||||
注册:
|
|
||||||
|
|
||||||
```text
|
```text
|
||||||
Browser
|
/ frontend
|
||||||
-> auth-portal
|
/api backend API
|
||||||
-> authz-service POST /portal/register
|
/ws backend WebSocket
|
||||||
-> deploy-control POST /api/instances/register
|
|
||||||
-> app-instance/create-instance.sh
|
|
||||||
-> app-instance POST /api/auth/register
|
|
||||||
-> authz-service /oauth/register or /backends/register
|
|
||||||
-> auth-portal provider onboarding
|
|
||||||
-> deploy-control POST /api/instances/configure-provider
|
|
||||||
```
|
```
|
||||||
|
|
||||||
登录:
|
## Keycloak
|
||||||
|
|
||||||
|
Default IdP settings:
|
||||||
|
|
||||||
```text
|
```text
|
||||||
Browser
|
issuer: https://keycloak.bwgdi.com/realms/beaver
|
||||||
-> auth-portal
|
client_id: beaver-agnet
|
||||||
-> deploy-control POST /api/instances/resolve
|
|
||||||
-> app-instance POST /api/auth/login
|
|
||||||
-> app-instance frontend URL
|
|
||||||
```
|
```
|
||||||
|
|
||||||
## 快速开始
|
For the current plain-IP deployment, register these values in Keycloak:
|
||||||
|
|
||||||
本机完整流程见:
|
```text
|
||||||
|
web_origin: http://172.19.0.245:18080
|
||||||
|
redirect_uri: http://172.19.0.245:18080/auth/callback
|
||||||
|
post_logout_redirect_uri: http://172.19.0.245:18080/logout/callback
|
||||||
|
```
|
||||||
|
|
||||||
- [部署指南.md](./部署指南.md)
|
The app validates Keycloak JWT issuer, signature, expiry, audience/azp, and nonce where applicable. The Keycloak `sub` claim is used as the application user ID.
|
||||||
|
|
||||||
域名、HTTPS、公网反向代理说明见:
|
## Build
|
||||||
|
|
||||||
- [域名配置指引.md](./域名配置指引.md)
|
|
||||||
|
|
||||||
最小配置变量:
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
export PROJECT_ROOT=/home/ivan/xuan/beaver_project
|
cd app-instance
|
||||||
export BEAVER_NET=beaver-instance-edge
|
docker build -t beaver/app-instance:keycloak-login .
|
||||||
export BEAVER_PROXY_CONTAINER_NAME=beaver-router-proxy
|
|
||||||
|
|
||||||
export BEAVER_DEPLOY_TOKEN="$(openssl rand -hex 32)"
|
|
||||||
export BEAVER_AUTHZ_INTERNAL_TOKEN="$(openssl rand -hex 32)"
|
|
||||||
|
|
||||||
export BEAVER_BASE_DOMAIN=localhost
|
|
||||||
export BEAVER_AUTHZ_URL='http://beaver-authz-service:19090'
|
|
||||||
export BEAVER_DEPLOY_URL='http://beaver-deploy-control:8090'
|
|
||||||
|
|
||||||
export BEAVER_OUTLOOK_MCP_URL=''
|
|
||||||
export BEAVER_OUTLOOK_MCP_SERVER_ID='outlook_mcp'
|
|
||||||
```
|
```
|
||||||
|
|
||||||
启动顺序:
|
## Run Locally
|
||||||
|
|
||||||
1. 创建运行目录。
|
|
||||||
2. 构建四个镜像。
|
|
||||||
3. 创建共享 Docker network。
|
|
||||||
4. 启动 `router-proxy`。
|
|
||||||
5. 启动 `authz-service`。
|
|
||||||
6. 启动 `deploy-control`。
|
|
||||||
7. 启动 `auth-portal`。
|
|
||||||
8. 打开 `http://127.0.0.1:3081/register` 测试注册。
|
|
||||||
|
|
||||||
## 关键配置关系
|
|
||||||
|
|
||||||
`DEPLOY_API_TOKEN` 和 `DEPLOY_CONTROL_API_TOKEN` 必须相等:
|
|
||||||
|
|
||||||
- `auth-portal` / `authz-service` 用 `DEPLOY_API_TOKEN` 请求 `deploy-control`。
|
|
||||||
- `deploy-control` 用 `DEPLOY_CONTROL_API_TOKEN` 校验请求。
|
|
||||||
|
|
||||||
`AUTHZ_ISSUER` 在这套单机部署里要写容器网络地址:
|
|
||||||
|
|
||||||
```text
|
|
||||||
http://beaver-authz-service:19090
|
|
||||||
```
|
|
||||||
|
|
||||||
不要写成 `http://127.0.0.1:19090`,因为新创建的 `app-instance` 容器里的 `127.0.0.1` 指向它自己,不是 AuthZ 容器。
|
|
||||||
|
|
||||||
`DEPLOY_PUBLIC_*` 决定新实例展示给用户的 URL:
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
DEPLOY_PUBLIC_SCHEME=http
|
cd app-instance
|
||||||
DEPLOY_PUBLIC_BASE_DOMAIN=localhost
|
./run-standalone.sh --build --replace
|
||||||
DEPLOY_PUBLIC_PORT=8088
|
|
||||||
```
|
```
|
||||||
|
|
||||||
本机测试时实例 URL 形如:
|
The script prints the `web_origin`, `redirect_uri`, and `post_logout_redirect_uri` values that must be configured in Keycloak.
|
||||||
|
|
||||||
```text
|
## Useful Paths
|
||||||
http://alice.localhost:8088
|
|
||||||
```
|
|
||||||
|
|
||||||
正式 HTTPS 域名通常改成:
|
- `app-instance/` - app image, frontend, backend, runtime helper script, app tests
|
||||||
|
- `skills/` - reusable skill content retained for product/runtime packaging
|
||||||
|
- `agents/` - agent registry retained for runtime packaging
|
||||||
|
- `docs/` - product and architecture documentation
|
||||||
|
- `AGENTS.md` - local coding-agent instructions
|
||||||
|
|
||||||
```bash
|
See `app-instance/README.md` for the app-specific contract and environment variables.
|
||||||
DEPLOY_PUBLIC_SCHEME=https
|
|
||||||
DEPLOY_PUBLIC_BASE_DOMAIN=apps.example.com
|
|
||||||
DEPLOY_PUBLIC_PORT=443
|
|
||||||
```
|
|
||||||
|
|
||||||
实例 URL 形如:
|
|
||||||
|
|
||||||
```text
|
|
||||||
https://alice.apps.example.com
|
|
||||||
```
|
|
||||||
|
|
||||||
前提是你已经在项目外层把 `*.apps.example.com` 的 `80/443` 流量转发到 `router-proxy:8088`。
|
|
||||||
|
|
||||||
## 模型配置方式
|
|
||||||
|
|
||||||
当前版本不会在注册创建实例时写入模型 provider、model 或 API key。
|
|
||||||
|
|
||||||
流程是:
|
|
||||||
|
|
||||||
1. 注册先创建一个不含模型凭证的实例。
|
|
||||||
2. `auth-portal` 进入模型配置引导页。
|
|
||||||
3. 用户确认后,Portal 调用 `deploy-control /api/instances/configure-provider`。
|
|
||||||
4. `deploy-control` 写入该实例的 `config.json` 并重启对应容器。
|
|
||||||
|
|
||||||
如果用户跳过引导,实例仍会创建成功,但后续需要在实例内补齐 provider 配置后才能正常调用模型。
|
|
||||||
|
|
||||||
## 持久化目录
|
|
||||||
|
|
||||||
至少保留:
|
|
||||||
|
|
||||||
```text
|
|
||||||
authz-service/runtime/data
|
|
||||||
app-instance/runtime/instances
|
|
||||||
app-instance/runtime/registry
|
|
||||||
router-proxy/runtime/conf.d
|
|
||||||
```
|
|
||||||
|
|
||||||
不要在需要保留账号、实例或配置时删除这些目录。
|
|
||||||
|
|
||||||
## 模板文件
|
|
||||||
|
|
||||||
可参考这些环境变量模板:
|
|
||||||
|
|
||||||
- [`.env.example`](./.env.example)
|
|
||||||
- [`auth-portal/src/.env.example`](./auth-portal/src/.env.example)
|
|
||||||
- [`authz-service/.env.example`](./authz-service/.env.example)
|
|
||||||
- [`deploy-control/.env.example`](./deploy-control/.env.example)
|
|
||||||
- [`router-proxy/.env.example`](./router-proxy/.env.example)
|
|
||||||
|
|
||||||
这些模板不会被脚本自动加载。你可以手动 `export`,也可以在 `docker run` 时使用 `--env-file`。
|
|
||||||
|
|
||||||
## 子项目文档
|
|
||||||
|
|
||||||
- [`app-instance/README.md`](./app-instance/README.md)
|
|
||||||
- [`auth-portal/src/README.md`](./auth-portal/src/README.md)
|
|
||||||
- [`authz-service/README.md`](./authz-service/README.md)
|
|
||||||
- [`deploy-control/README.md`](./deploy-control/README.md)
|
|
||||||
- [`router-proxy/README.md`](./router-proxy/README.md)
|
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
runtime/
|
runtime/
|
||||||
|
.env
|
||||||
backend/.git/
|
backend/.git/
|
||||||
backend/.venv/
|
backend/.venv/
|
||||||
backend/.pytest_cache/
|
backend/.pytest_cache/
|
||||||
@ -13,4 +14,3 @@ frontend/node_modules/
|
|||||||
frontend/.next/
|
frontend/.next/
|
||||||
frontend/.next-dev/
|
frontend/.next-dev/
|
||||||
frontend/tsconfig.tsbuildinfo
|
frontend/tsconfig.tsbuildinfo
|
||||||
|
|
||||||
|
|||||||
28
app-instance/.env.example
Normal file
28
app-instance/.env.example
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
# Docker image/runtime
|
||||||
|
IMAGE_NAME=beaver/app-instance:keycloak
|
||||||
|
CONTAINER_NAME=beaver-app-instance
|
||||||
|
HOST_IP=172.19.0.245
|
||||||
|
HOST_PORT=18080
|
||||||
|
HOST_BIND_IP=0.0.0.0
|
||||||
|
DATA_DIR=runtime/standalone
|
||||||
|
DOCKER_NETWORK=
|
||||||
|
|
||||||
|
# Beaver model config rendered into runtime/standalone/beaver-home/config.json
|
||||||
|
BEAVER_PROVIDER=custom
|
||||||
|
BEAVER_MODEL=
|
||||||
|
BEAVER_API_KEY=
|
||||||
|
BEAVER_API_BASE=
|
||||||
|
|
||||||
|
# Keycloak public client
|
||||||
|
KEYCLOAK_ISSUER=https://keycloak.bwgdi.com/realms/beaver
|
||||||
|
KEYCLOAK_CLIENT_ID=beaver-agnet
|
||||||
|
KEYCLOAK_REDIRECT_URI=
|
||||||
|
KEYCLOAK_POST_LOGOUT_REDIRECT_URI=
|
||||||
|
|
||||||
|
# Build mirrors
|
||||||
|
NPM_REGISTRY=https://registry.npmmirror.com
|
||||||
|
NPM_FETCH_RETRIES=5
|
||||||
|
NPM_FETCH_RETRY_MIN_TIMEOUT=20000
|
||||||
|
NPM_FETCH_RETRY_MAX_TIMEOUT=120000
|
||||||
|
APT_MIRROR=https://mirrors.tuna.tsinghua.edu.cn/debian
|
||||||
|
PYPI_INDEX_URL=https://pypi.tuna.tsinghua.edu.cn/simple
|
||||||
@ -20,11 +20,15 @@ RUN --mount=type=cache,target=/root/.npm \
|
|||||||
|
|
||||||
COPY frontend/ ./
|
COPY frontend/ ./
|
||||||
|
|
||||||
ARG NEXT_PUBLIC_AUTH_PORTAL_URL=""
|
ARG NEXT_PUBLIC_KEYCLOAK_ISSUER="https://keycloak.bwgdi.com/realms/beaver"
|
||||||
ARG NEXT_PUBLIC_AUTH_PORTAL_PORT="3081"
|
ARG NEXT_PUBLIC_KEYCLOAK_CLIENT_ID="beaver-agnet"
|
||||||
|
ARG NEXT_PUBLIC_KEYCLOAK_REDIRECT_URI=""
|
||||||
|
ARG NEXT_PUBLIC_KEYCLOAK_POST_LOGOUT_REDIRECT_URI=""
|
||||||
|
|
||||||
ENV NEXT_PUBLIC_AUTH_PORTAL_URL=${NEXT_PUBLIC_AUTH_PORTAL_URL}
|
ENV NEXT_PUBLIC_KEYCLOAK_ISSUER=${NEXT_PUBLIC_KEYCLOAK_ISSUER}
|
||||||
ENV NEXT_PUBLIC_AUTH_PORTAL_PORT=${NEXT_PUBLIC_AUTH_PORTAL_PORT}
|
ENV NEXT_PUBLIC_KEYCLOAK_CLIENT_ID=${NEXT_PUBLIC_KEYCLOAK_CLIENT_ID}
|
||||||
|
ENV NEXT_PUBLIC_KEYCLOAK_REDIRECT_URI=${NEXT_PUBLIC_KEYCLOAK_REDIRECT_URI}
|
||||||
|
ENV NEXT_PUBLIC_KEYCLOAK_POST_LOGOUT_REDIRECT_URI=${NEXT_PUBLIC_KEYCLOAK_POST_LOGOUT_REDIRECT_URI}
|
||||||
|
|
||||||
# API / WS 走同域反代,不在构建时写死实例地址。
|
# API / WS 走同域反代,不在构建时写死实例地址。
|
||||||
RUN npm run build
|
RUN npm run build
|
||||||
@ -39,7 +43,6 @@ ENV DEBIAN_FRONTEND=noninteractive \
|
|||||||
BEAVER_HOME=/root/.beaver \
|
BEAVER_HOME=/root/.beaver \
|
||||||
BEAVER_CONFIG_PATH=/root/.beaver/config.json \
|
BEAVER_CONFIG_PATH=/root/.beaver/config.json \
|
||||||
BEAVER_WORKSPACE=/root/.beaver/workspace \
|
BEAVER_WORKSPACE=/root/.beaver/workspace \
|
||||||
BEAVER_AUTH_FILE=/root/.beaver/web_auth_users.json \
|
|
||||||
PORT=3000 \
|
PORT=3000 \
|
||||||
HOSTNAME=127.0.0.1
|
HOSTNAME=127.0.0.1
|
||||||
|
|
||||||
@ -47,8 +50,12 @@ ARG NPM_REGISTRY="https://registry.npmmirror.com"
|
|||||||
ARG NPM_FETCH_RETRIES="5"
|
ARG NPM_FETCH_RETRIES="5"
|
||||||
ARG NPM_FETCH_RETRY_MIN_TIMEOUT="20000"
|
ARG NPM_FETCH_RETRY_MIN_TIMEOUT="20000"
|
||||||
ARG NPM_FETCH_RETRY_MAX_TIMEOUT="120000"
|
ARG NPM_FETCH_RETRY_MAX_TIMEOUT="120000"
|
||||||
|
ARG APT_MIRROR="https://mirrors.tuna.tsinghua.edu.cn/debian"
|
||||||
|
ARG PYPI_INDEX_URL="https://pypi.tuna.tsinghua.edu.cn/simple"
|
||||||
|
|
||||||
RUN apt-get update && \
|
RUN find /etc/apt -type f \( -name "*.list" -o -name "*.sources" \) -exec \
|
||||||
|
sed -i "s|http://deb.debian.org/debian-security|${APT_MIRROR}-security|g; s|http://deb.debian.org/debian|${APT_MIRROR}|g; s|http://security.debian.org/debian-security|${APT_MIRROR}-security|g" {} + && \
|
||||||
|
apt-get update && \
|
||||||
apt-get install -y --no-install-recommends curl ca-certificates gnupg git nginx dumb-init && \
|
apt-get install -y --no-install-recommends curl ca-certificates gnupg git nginx dumb-init && \
|
||||||
mkdir -p /etc/apt/keyrings && \
|
mkdir -p /etc/apt/keyrings && \
|
||||||
curl -fsSL https://deb.nodesource.com/gpgkey/nodesource-repo.gpg.key | gpg --dearmor -o /etc/apt/keyrings/nodesource.gpg && \
|
curl -fsSL https://deb.nodesource.com/gpgkey/nodesource-repo.gpg.key | gpg --dearmor -o /etc/apt/keyrings/nodesource.gpg && \
|
||||||
@ -63,7 +70,7 @@ WORKDIR /opt/app/backend
|
|||||||
|
|
||||||
COPY backend/pyproject.toml backend/README.md ./
|
COPY backend/pyproject.toml backend/README.md ./
|
||||||
COPY backend/beaver/ ./beaver/
|
COPY backend/beaver/ ./beaver/
|
||||||
RUN uv pip install --system --no-cache .
|
RUN uv pip install --system --no-cache --index-url "${PYPI_INDEX_URL}" ".[channels]"
|
||||||
|
|
||||||
WORKDIR /opt/app/frontend
|
WORKDIR /opt/app/frontend
|
||||||
COPY --from=frontend-builder /build/frontend/next.config.js ./
|
COPY --from=frontend-builder /build/frontend/next.config.js ./
|
||||||
|
|||||||
@ -1,118 +1,56 @@
|
|||||||
# app-instance
|
# app-instance
|
||||||
|
|
||||||
单实例应用单元:
|
单体 app instance:一个 Docker 容器内运行前端、后端和 Nginx。这个目录只提供干净的应用实例,不再负责用户注册登录、实例编排、注册表或反向代理路由。
|
||||||
|
|
||||||
- 一个 Docker 容器里同时运行前端、后端和 Nginx 反代
|
|
||||||
- 前端走 `/`
|
|
||||||
- 后端 API 走 `/api`
|
|
||||||
- WebSocket 走 `/ws`
|
|
||||||
|
|
||||||
## 关键文件
|
## 关键文件
|
||||||
|
|
||||||
- `Dockerfile`
|
- `Dockerfile`: 统一镜像构建入口
|
||||||
- 统一镜像构建入口
|
- `entrypoint.sh`: 容器内启动前端、后端、Nginx
|
||||||
- `entrypoint.sh`
|
- `run-standalone.sh`: 从 `.env` 读取参数,构建并运行单体容器
|
||||||
- 容器内启动前端、后端、Nginx
|
- `.env.example`: 可复制的部署参数模板
|
||||||
- `create-instance.sh`
|
|
||||||
- 创建实例目录、生成配置、启动容器、写注册表
|
|
||||||
- `remove-instance.sh`
|
|
||||||
- 删除容器、移除注册表、可选清理实例目录
|
|
||||||
- `list-instances.sh`
|
|
||||||
- 查看当前注册实例
|
|
||||||
- `instance-registry.py`
|
|
||||||
- 维护 `runtime/registry/instances.json`
|
|
||||||
|
|
||||||
## 注册表
|
## Keycloak SSO
|
||||||
|
|
||||||
默认注册表路径:
|
当前登录交给 Keycloak:
|
||||||
|
|
||||||
|
- issuer: `https://keycloak.bwgdi.com/realms/beaver`
|
||||||
|
- client_id: `beaver-agnet`
|
||||||
|
- flow: Authorization Code + PKCE S256
|
||||||
|
- password grant: disabled
|
||||||
|
- client secret: none
|
||||||
|
|
||||||
|
应用需要在 Keycloak client 中配置:
|
||||||
|
|
||||||
|
- `redirect_uri`: `http://<host-ip>:<host-port>/auth/callback`
|
||||||
|
- `post_logout_redirect_uri`: `http://<host-ip>:<host-port>/logout/callback`
|
||||||
|
- `web_origin`: `http://<host-ip>:<host-port>`
|
||||||
|
|
||||||
|
默认单机部署示例:
|
||||||
|
|
||||||
```text
|
```text
|
||||||
runtime/registry/instances.json
|
redirect_uri: http://172.19.0.245:18080/auth/callback
|
||||||
|
post_logout_redirect_uri: http://172.19.0.245:18080/logout/callback
|
||||||
|
web_origin: http://172.19.0.245:18080
|
||||||
```
|
```
|
||||||
|
|
||||||
每条记录至少包含:
|
## 部署
|
||||||
|
|
||||||
- `instance_id`
|
|
||||||
- `instance_slug`
|
|
||||||
- `container_name`
|
|
||||||
- `host_port`
|
|
||||||
- `public_url`
|
|
||||||
- `instance_root`
|
|
||||||
- `image_name`
|
|
||||||
|
|
||||||
## 常用命令
|
|
||||||
|
|
||||||
### 1. 构建镜像
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
docker build -t beaver/app-instance:latest .
|
cp .env.example .env
|
||||||
|
./run-standalone.sh --build --replace
|
||||||
```
|
```
|
||||||
|
|
||||||
### 2. 创建实例
|
脚本会输出当前需要填到 Keycloak 的 `redirect_uri`、`post_logout_redirect_uri` 和 `web_origin`。
|
||||||
|
|
||||||
```bash
|
## 数据目录
|
||||||
./create-instance.sh \
|
|
||||||
--image beaver/app-instance:latest \
|
|
||||||
--instance-id demo-001 \
|
|
||||||
--auth-username admin \
|
|
||||||
--auth-password 123456 \
|
|
||||||
--api-key 'your-api-key'
|
|
||||||
```
|
|
||||||
|
|
||||||
可选参数:
|
默认数据目录:
|
||||||
|
|
||||||
- `--host-port`
|
|
||||||
- `--public-url`
|
|
||||||
- `--username`
|
|
||||||
- `--email`
|
|
||||||
- `--instance-host`
|
|
||||||
- `--authz-base-url`
|
|
||||||
- `--backend-id`
|
|
||||||
- `--client-id`
|
|
||||||
- `--client-secret`
|
|
||||||
- `--network`
|
|
||||||
- `--host-bind-ip`
|
|
||||||
- `--build`
|
|
||||||
- `--replace`
|
|
||||||
|
|
||||||
### 3. 查看实例
|
|
||||||
|
|
||||||
```bash
|
|
||||||
./list-instances.sh
|
|
||||||
./list-instances.sh --json
|
|
||||||
```
|
|
||||||
|
|
||||||
### 4. 删除实例
|
|
||||||
|
|
||||||
```bash
|
|
||||||
./remove-instance.sh --instance-id demo-001
|
|
||||||
```
|
|
||||||
|
|
||||||
如果要把实例目录也一并清掉:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
./remove-instance.sh --instance-id demo-001 --purge-data
|
|
||||||
```
|
|
||||||
|
|
||||||
## 目录约定
|
|
||||||
|
|
||||||
默认实例数据目录:
|
|
||||||
|
|
||||||
```text
|
```text
|
||||||
runtime/instances/<instance-slug>/
|
runtime/standalone/beaver-home/
|
||||||
```
|
```
|
||||||
|
|
||||||
其中会生成:
|
容器内挂载到:
|
||||||
|
|
||||||
```text
|
|
||||||
runtime/instances/<instance-slug>/
|
|
||||||
└── beaver-home
|
|
||||||
├── config.json
|
|
||||||
├── web_auth_users.json
|
|
||||||
└── workspace/
|
|
||||||
```
|
|
||||||
|
|
||||||
这个目录是单用户 sandbox 的配置与数据边界。容器内会把它挂到:
|
|
||||||
|
|
||||||
```text
|
```text
|
||||||
/root/.beaver/
|
/root/.beaver/
|
||||||
@ -125,28 +63,4 @@ BEAVER_CONFIG_PATH=/root/.beaver/config.json
|
|||||||
BEAVER_WORKSPACE=/root/.beaver/workspace
|
BEAVER_WORKSPACE=/root/.beaver/workspace
|
||||||
```
|
```
|
||||||
|
|
||||||
所以模型 `provider/api_key/api_base/model` 配一次即可,Web / channel 请求不需要、也不应该携带 API Key。
|
模型 provider、API key、workspace 数据都在这个目录下维护。
|
||||||
|
|
||||||
## 当前状态
|
|
||||||
|
|
||||||
这层已经支持:
|
|
||||||
|
|
||||||
- 统一镜像构建
|
|
||||||
- 镜像内安装并启动新的 `beaver` 后端
|
|
||||||
- 实例创建
|
|
||||||
- 实例删除
|
|
||||||
- 实例列表
|
|
||||||
- 基于注册表的端口分配
|
|
||||||
- 为 deploy-control / router-proxy 记录用户名和实例 host
|
|
||||||
|
|
||||||
## 生产注意
|
|
||||||
|
|
||||||
- 实例容器的宿主机端口默认只绑定 `127.0.0.1`
|
|
||||||
- 外部访问应统一走 `router-proxy`
|
|
||||||
- 如果你确实要把单个实例端口直接暴露到公网,再显式传 `--host-bind-ip 0.0.0.0`
|
|
||||||
|
|
||||||
下一步可以继续接:
|
|
||||||
|
|
||||||
- portal 调用创建实例
|
|
||||||
- URL 分配和反向代理
|
|
||||||
- 实例续期 / 停用 / 启用
|
|
||||||
|
|||||||
4
app-instance/backend/agents/registry.json
Normal file
4
app-instance/backend/agents/registry.json
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
{
|
||||||
|
"agents": [],
|
||||||
|
"version": 1
|
||||||
|
}
|
||||||
@ -18,8 +18,9 @@ if TYPE_CHECKING:
|
|||||||
class TeamGraphScheduler:
|
class TeamGraphScheduler:
|
||||||
"""Execute sequence, parallel, and DAG team graphs."""
|
"""Execute sequence, parallel, and DAG team graphs."""
|
||||||
|
|
||||||
def __init__(self, runner: LocalAgentRunner) -> None:
|
def __init__(self, runner: LocalAgentRunner, *, max_parallel_team_nodes: int = 3) -> None:
|
||||||
self.runner = runner
|
self.runner = runner
|
||||||
|
self.max_parallel_team_nodes = max(1, int(max_parallel_team_nodes))
|
||||||
|
|
||||||
async def run(
|
async def run(
|
||||||
self,
|
self,
|
||||||
@ -96,7 +97,18 @@ class TeamGraphScheduler:
|
|||||||
nodes: list[ExecutionNode],
|
nodes: list[ExecutionNode],
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> list[NodeRunResult]:
|
) -> list[NodeRunResult]:
|
||||||
return list(await asyncio.gather(*(self._run_node(node, dependency_outputs={}, **kwargs) for node in nodes)))
|
semaphore = asyncio.Semaphore(self.max_parallel_team_nodes)
|
||||||
|
|
||||||
|
async def run_one(node: ExecutionNode) -> NodeRunResult:
|
||||||
|
async with semaphore:
|
||||||
|
return await self._run_node(
|
||||||
|
node,
|
||||||
|
dependency_outputs={},
|
||||||
|
execution_mode="isolated_loop",
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
return list(await asyncio.gather(*(run_one(node) for node in nodes)))
|
||||||
|
|
||||||
async def _run_dag(
|
async def _run_dag(
|
||||||
self,
|
self,
|
||||||
@ -164,6 +176,7 @@ class TeamGraphScheduler:
|
|||||||
inherited_pinned_skill_contexts: list["SkillContext"],
|
inherited_pinned_skill_contexts: list["SkillContext"],
|
||||||
allow_candidate_generation: bool,
|
allow_candidate_generation: bool,
|
||||||
dependency_outputs: dict[str, str],
|
dependency_outputs: dict[str, str],
|
||||||
|
execution_mode: str = "shared_loop",
|
||||||
) -> NodeRunResult:
|
) -> NodeRunResult:
|
||||||
try:
|
try:
|
||||||
pinned = self._merge_pinned(inherited_pinned_skills, node.inherited_pinned_skills)
|
pinned = self._merge_pinned(inherited_pinned_skills, node.inherited_pinned_skills)
|
||||||
@ -189,6 +202,7 @@ class TeamGraphScheduler:
|
|||||||
envelope,
|
envelope,
|
||||||
provider_bundle=node_provider_bundle,
|
provider_bundle=node_provider_bundle,
|
||||||
allow_candidate_generation=allow_candidate_generation,
|
allow_candidate_generation=allow_candidate_generation,
|
||||||
|
execution_mode=execution_mode,
|
||||||
)
|
)
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
raise
|
raise
|
||||||
@ -241,7 +255,7 @@ class TeamGraphScheduler:
|
|||||||
failed = [item for item in results if not item.success]
|
failed = [item for item in results if not item.success]
|
||||||
if failed:
|
if failed:
|
||||||
failure_lines = [
|
failure_lines = [
|
||||||
f"- {item.node_id}: {item.error or item.finish_reason}"
|
f"- {item.node_id}: {item.error or item.finish_reason} evidence={'yes' if item.evidence else 'no'}"
|
||||||
for item in failed
|
for item in failed
|
||||||
]
|
]
|
||||||
summary_parts.append("Failed nodes:\n" + "\n".join(failure_lines))
|
summary_parts.append("Failed nodes:\n" + "\n".join(failure_lines))
|
||||||
|
|||||||
@ -6,6 +6,7 @@ from uuid import uuid4
|
|||||||
|
|
||||||
from beaver.engine import AgentLoop
|
from beaver.engine import AgentLoop
|
||||||
from beaver.engine.providers import ProviderBundle
|
from beaver.engine.providers import ProviderBundle
|
||||||
|
from beaver.tasks.evidence import EvidenceBuilder
|
||||||
|
|
||||||
from .models import DelegationEnvelope, NodeRunResult
|
from .models import DelegationEnvelope, NodeRunResult
|
||||||
|
|
||||||
@ -22,6 +23,7 @@ class LocalAgentRunner:
|
|||||||
*,
|
*,
|
||||||
provider_bundle: ProviderBundle | None = None,
|
provider_bundle: ProviderBundle | None = None,
|
||||||
allow_candidate_generation: bool = False,
|
allow_candidate_generation: bool = False,
|
||||||
|
execution_mode: str = "shared_loop",
|
||||||
) -> NodeRunResult:
|
) -> NodeRunResult:
|
||||||
if provider_bundle is not None and (envelope.agent.model or envelope.agent.provider_name):
|
if provider_bundle is not None and (envelope.agent.model or envelope.agent.provider_name):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -29,7 +31,14 @@ class LocalAgentRunner:
|
|||||||
"build a node-specific provider bundle instead."
|
"build a node-specific provider bundle instead."
|
||||||
)
|
)
|
||||||
child_session_id = self._child_session_id(envelope)
|
child_session_id = self._child_session_id(envelope)
|
||||||
runner = self.loop.submit_direct if self.loop.is_running else self.loop.process_direct
|
target_loop = self.loop
|
||||||
|
if execution_mode == "isolated_loop":
|
||||||
|
target_loop = AgentLoop(profile=self.loop.profile, loader=self.loop.loader)
|
||||||
|
runner = (
|
||||||
|
target_loop.process_direct
|
||||||
|
if execution_mode == "isolated_loop"
|
||||||
|
else (self.loop.submit_direct if self.loop.is_running else self.loop.process_direct)
|
||||||
|
)
|
||||||
result = await runner(
|
result = await runner(
|
||||||
envelope.task,
|
envelope.task,
|
||||||
session_id=child_session_id,
|
session_id=child_session_id,
|
||||||
@ -47,6 +56,13 @@ class LocalAgentRunner:
|
|||||||
pinned_skill_contexts=envelope.inherited_pinned_skill_contexts,
|
pinned_skill_contexts=envelope.inherited_pinned_skill_contexts,
|
||||||
allow_candidate_generation=allow_candidate_generation,
|
allow_candidate_generation=allow_candidate_generation,
|
||||||
)
|
)
|
||||||
|
loaded = target_loop.boot()
|
||||||
|
evidence = EvidenceBuilder(loaded.session_manager).build_run_evidence(
|
||||||
|
result.session_id,
|
||||||
|
result.run_id,
|
||||||
|
result.output_text,
|
||||||
|
result.finish_reason,
|
||||||
|
)
|
||||||
success = result.finish_reason == "stop"
|
success = result.finish_reason == "stop"
|
||||||
return NodeRunResult(
|
return NodeRunResult(
|
||||||
node_id=envelope.node_id or envelope.agent.name,
|
node_id=envelope.node_id or envelope.agent.name,
|
||||||
@ -56,6 +72,7 @@ class LocalAgentRunner:
|
|||||||
session_id=result.session_id,
|
session_id=result.session_id,
|
||||||
finish_reason=result.finish_reason,
|
finish_reason=result.finish_reason,
|
||||||
error=None if success else (result.output_text or result.finish_reason),
|
error=None if success else (result.output_text or result.finish_reason),
|
||||||
|
evidence=evidence,
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Literal
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from beaver.engine.context import SkillContext
|
from beaver.engine.context import SkillContext
|
||||||
|
from beaver.tasks.evidence import RunEvidence
|
||||||
|
|
||||||
|
|
||||||
TeamStrategy = Literal[
|
TeamStrategy = Literal[
|
||||||
@ -116,6 +117,7 @@ class NodeRunResult:
|
|||||||
session_id: str | None = None
|
session_id: str | None = None
|
||||||
finish_reason: str = "stop"
|
finish_reason: str = "stop"
|
||||||
error: str | None = None
|
error: str | None = None
|
||||||
|
evidence: "RunEvidence | None" = None
|
||||||
|
|
||||||
def to_dict(self) -> dict[str, Any]:
|
def to_dict(self) -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
@ -126,6 +128,7 @@ class NodeRunResult:
|
|||||||
"session_id": self.session_id,
|
"session_id": self.session_id,
|
||||||
"finish_reason": self.finish_reason,
|
"finish_reason": self.finish_reason,
|
||||||
"error": self.error,
|
"error": self.error,
|
||||||
|
"evidence": self.evidence.to_dict() if self.evidence is not None else None,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -15,7 +15,9 @@ class AgentRegistry:
|
|||||||
self.path = self.workspace / "agents" / "registry.json"
|
self.path = self.workspace / "agents" / "registry.json"
|
||||||
self.path.parent.mkdir(parents=True, exist_ok=True)
|
self.path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
if not self.path.exists():
|
if not self.path.exists():
|
||||||
self._write_agents(_builtin_agents())
|
self._write_agents([])
|
||||||
|
else:
|
||||||
|
self._drop_legacy_builtin_agents()
|
||||||
|
|
||||||
def list_agents(self, *, include_disabled: bool = True) -> list[RegisteredAgent]:
|
def list_agents(self, *, include_disabled: bool = True) -> list[RegisteredAgent]:
|
||||||
agents = self._read_agents()
|
agents = self._read_agents()
|
||||||
@ -125,72 +127,14 @@ class AgentRegistry:
|
|||||||
payload = {"version": 1, "agents": [agent.to_dict() for agent in agents]}
|
payload = {"version": 1, "agents": [agent.to_dict() for agent in agents]}
|
||||||
self.path.write_text(json.dumps(payload, ensure_ascii=False, indent=2, sort_keys=True) + "\n", encoding="utf-8")
|
self.path.write_text(json.dumps(payload, ensure_ascii=False, indent=2, sort_keys=True) + "\n", encoding="utf-8")
|
||||||
|
|
||||||
|
def _drop_legacy_builtin_agents(self) -> None:
|
||||||
|
agents = self._read_agents()
|
||||||
|
migrated = [agent for agent in agents if agent.source != "builtin"]
|
||||||
|
if len(migrated) != len(agents):
|
||||||
|
self._write_agents(migrated)
|
||||||
|
|
||||||
|
|
||||||
def _terms(text: str) -> set[str]:
|
def _terms(text: str) -> set[str]:
|
||||||
normalized = "".join(ch.lower() if ch.isalnum() else " " for ch in text)
|
normalized = "".join(ch.lower() if ch.isalnum() else " " for ch in text)
|
||||||
return {part for part in normalized.split() if part}
|
return {part for part in normalized.split() if part}
|
||||||
|
|
||||||
|
|
||||||
def _builtin_agents() -> list[RegisteredAgent]:
|
|
||||||
return [
|
|
||||||
RegisteredAgent(
|
|
||||||
agent_id="researcher",
|
|
||||||
name="researcher",
|
|
||||||
display_name="Researcher",
|
|
||||||
role="research",
|
|
||||||
description="Finds facts, references, constraints, and implementation options.",
|
|
||||||
system_prompt="You are a research specialist. Gather concise evidence and tradeoffs for the parent task.",
|
|
||||||
capabilities=["research", "analysis", "source review", "requirements"],
|
|
||||||
tags=["planning", "research"],
|
|
||||||
priority=50,
|
|
||||||
source="builtin",
|
|
||||||
),
|
|
||||||
RegisteredAgent(
|
|
||||||
agent_id="implementer",
|
|
||||||
name="implementer",
|
|
||||||
display_name="Implementer",
|
|
||||||
role="implementation",
|
|
||||||
description="Builds scoped implementation slices and proposes concrete changes.",
|
|
||||||
system_prompt="You are an implementation specialist. Produce practical, scoped implementation output.",
|
|
||||||
capabilities=["implementation", "coding", "refactor", "integration"],
|
|
||||||
tags=["coding", "build"],
|
|
||||||
priority=45,
|
|
||||||
source="builtin",
|
|
||||||
),
|
|
||||||
RegisteredAgent(
|
|
||||||
agent_id="reviewer",
|
|
||||||
name="reviewer",
|
|
||||||
display_name="Reviewer",
|
|
||||||
role="review",
|
|
||||||
description="Reviews plans, code, outputs, and risks before final synthesis.",
|
|
||||||
system_prompt="You are a review specialist. Focus on defects, missing requirements, and risks.",
|
|
||||||
capabilities=["review", "quality", "risk", "verification"],
|
|
||||||
tags=["review", "quality"],
|
|
||||||
priority=45,
|
|
||||||
source="builtin",
|
|
||||||
),
|
|
||||||
RegisteredAgent(
|
|
||||||
agent_id="tester",
|
|
||||||
name="tester",
|
|
||||||
display_name="Tester",
|
|
||||||
role="testing",
|
|
||||||
description="Designs and executes verification checks for task outputs.",
|
|
||||||
system_prompt="You are a testing specialist. Identify focused checks and report pass/fail evidence.",
|
|
||||||
capabilities=["testing", "verification", "regression", "qa"],
|
|
||||||
tags=["test", "quality"],
|
|
||||||
priority=40,
|
|
||||||
source="builtin",
|
|
||||||
),
|
|
||||||
RegisteredAgent(
|
|
||||||
agent_id="documenter",
|
|
||||||
name="documenter",
|
|
||||||
display_name="Documenter",
|
|
||||||
role="documentation",
|
|
||||||
description="Writes and reconciles user-facing and internal documentation updates.",
|
|
||||||
system_prompt="You are a documentation specialist. Produce concise docs aligned with the implementation.",
|
|
||||||
capabilities=["documentation", "explanation", "migration notes", "release notes"],
|
|
||||||
tags=["docs", "communication"],
|
|
||||||
priority=35,
|
|
||||||
source="builtin",
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|||||||
@ -4,6 +4,7 @@ from .builder import (
|
|||||||
ContextBuildInput,
|
ContextBuildInput,
|
||||||
ContextBuildResult,
|
ContextBuildResult,
|
||||||
ContextBuilder,
|
ContextBuilder,
|
||||||
|
RuntimeContext,
|
||||||
SessionContext,
|
SessionContext,
|
||||||
SkillContext,
|
SkillContext,
|
||||||
)
|
)
|
||||||
@ -12,6 +13,7 @@ __all__ = [
|
|||||||
"ContextBuildInput",
|
"ContextBuildInput",
|
||||||
"ContextBuildResult",
|
"ContextBuildResult",
|
||||||
"ContextBuilder",
|
"ContextBuilder",
|
||||||
|
"RuntimeContext",
|
||||||
"SessionContext",
|
"SessionContext",
|
||||||
"SkillContext",
|
"SkillContext",
|
||||||
]
|
]
|
||||||
|
|||||||
@ -27,13 +27,7 @@ from dataclasses import dataclass, field
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from beaver.memory.curated.snapshot import MemorySnapshot
|
from beaver.memory.curated.snapshot import MemorySnapshot
|
||||||
|
from beaver.prompts import get_main_agent_prompt
|
||||||
|
|
||||||
BEAVER_USER_ASSISTANT_IDENTITY_PROMPT = (
|
|
||||||
"You are 海狸 (Beaver), an AI assistant developed by 博维资讯系统有限公司. "
|
|
||||||
"When communicating with users, keep this identity consistent. "
|
|
||||||
"If users ask who you are, say that you are 海狸 (Beaver), 博维资讯系统有限公司研发的 AI 助手."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(slots=True)
|
@dataclass(slots=True)
|
||||||
@ -76,10 +70,25 @@ class SessionContext:
|
|||||||
model: str | None = None
|
model: str | None = None
|
||||||
user_id: str | None = None
|
user_id: str | None = None
|
||||||
channel: str | None = None
|
channel: str | None = None
|
||||||
|
channel_kind: str | None = None
|
||||||
|
account_id: str | None = None
|
||||||
|
peer_id: str | None = None
|
||||||
|
peer_type: str | None = None
|
||||||
chat_id: str | None = None
|
chat_id: str | None = None
|
||||||
|
thread_id: str | None = None
|
||||||
parent_session_id: str | None = None
|
parent_session_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class RuntimeContext:
|
||||||
|
"""Per-run runtime facts that should be visible to the model."""
|
||||||
|
|
||||||
|
utc_datetime: str
|
||||||
|
local_datetime: str
|
||||||
|
timezone: str | None = None
|
||||||
|
utc_offset: str | None = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass(slots=True)
|
@dataclass(slots=True)
|
||||||
class ContextBuildInput:
|
class ContextBuildInput:
|
||||||
"""一次上下文构建所需的全部输入。
|
"""一次上下文构建所需的全部输入。
|
||||||
@ -98,11 +107,13 @@ class ContextBuildInput:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
base_system_prompt: str = ""
|
base_system_prompt: str = ""
|
||||||
|
prompt_locale: str | None = None
|
||||||
history: list[dict[str, Any]] = field(default_factory=list)
|
history: list[dict[str, Any]] = field(default_factory=list)
|
||||||
current_user_input: str | list[dict[str, Any]] | None = None
|
current_user_input: str | list[dict[str, Any]] | None = None
|
||||||
memory_snapshot: MemorySnapshot | None = None
|
memory_snapshot: MemorySnapshot | None = None
|
||||||
activated_skills: list[SkillContext] = field(default_factory=list)
|
activated_skills: list[SkillContext] = field(default_factory=list)
|
||||||
session_context: SessionContext | None = None
|
session_context: SessionContext | None = None
|
||||||
|
runtime_context: RuntimeContext | None = None
|
||||||
execution_context: str | None = None
|
execution_context: str | None = None
|
||||||
extra_sections: list[str] = field(default_factory=list)
|
extra_sections: list[str] = field(default_factory=list)
|
||||||
|
|
||||||
@ -143,9 +154,10 @@ class ContextBuilder:
|
|||||||
1. Beaver user-facing assistant identity
|
1. Beaver user-facing assistant identity
|
||||||
2. base system prompt
|
2. base system prompt
|
||||||
3. session metadata
|
3. session metadata
|
||||||
4. execution context
|
4. runtime date/time
|
||||||
5. frozen memory snapshot
|
5. execution context
|
||||||
6. extra sections
|
6. frozen memory snapshot
|
||||||
|
7. extra sections
|
||||||
|
|
||||||
这样设计的原因:
|
这样设计的原因:
|
||||||
- 身份与总规则要最靠前
|
- 身份与总规则要最靠前
|
||||||
@ -154,7 +166,7 @@ class ContextBuilder:
|
|||||||
- activated skill 正文放到显式消息里,避免 system prompt 持续膨胀
|
- activated skill 正文放到显式消息里,避免 system prompt 持续膨胀
|
||||||
"""
|
"""
|
||||||
|
|
||||||
sections: list[str] = [BEAVER_USER_ASSISTANT_IDENTITY_PROMPT]
|
sections: list[str] = [get_main_agent_prompt(build_input.prompt_locale)]
|
||||||
|
|
||||||
base_system_prompt = (build_input.base_system_prompt or "").strip()
|
base_system_prompt = (build_input.base_system_prompt or "").strip()
|
||||||
if base_system_prompt:
|
if base_system_prompt:
|
||||||
@ -164,6 +176,10 @@ class ContextBuilder:
|
|||||||
if session_section:
|
if session_section:
|
||||||
sections.append(session_section)
|
sections.append(session_section)
|
||||||
|
|
||||||
|
runtime_section = self._render_runtime_section(build_input.runtime_context)
|
||||||
|
if runtime_section:
|
||||||
|
sections.append(runtime_section)
|
||||||
|
|
||||||
execution_context = (build_input.execution_context or "").strip()
|
execution_context = (build_input.execution_context or "").strip()
|
||||||
if execution_context:
|
if execution_context:
|
||||||
sections.append(f"# Execution Context\n\n{execution_context}")
|
sections.append(f"# Execution Context\n\n{execution_context}")
|
||||||
@ -338,8 +354,18 @@ class ContextBuilder:
|
|||||||
rows.append(f"User ID: {session_context.user_id}")
|
rows.append(f"User ID: {session_context.user_id}")
|
||||||
if session_context.channel:
|
if session_context.channel:
|
||||||
rows.append(f"Channel: {session_context.channel}")
|
rows.append(f"Channel: {session_context.channel}")
|
||||||
|
if session_context.channel_kind:
|
||||||
|
rows.append(f"Channel Kind: {session_context.channel_kind}")
|
||||||
|
if session_context.account_id:
|
||||||
|
rows.append(f"Account ID: {session_context.account_id}")
|
||||||
|
if session_context.peer_id:
|
||||||
|
rows.append(f"Peer ID: {session_context.peer_id}")
|
||||||
|
if session_context.peer_type:
|
||||||
|
rows.append(f"Peer Type: {session_context.peer_type}")
|
||||||
if session_context.chat_id:
|
if session_context.chat_id:
|
||||||
rows.append(f"Chat ID: {session_context.chat_id}")
|
rows.append(f"Chat ID: {session_context.chat_id}")
|
||||||
|
if session_context.thread_id:
|
||||||
|
rows.append(f"Thread ID: {session_context.thread_id}")
|
||||||
if session_context.parent_session_id:
|
if session_context.parent_session_id:
|
||||||
rows.append(f"Parent Session ID: {session_context.parent_session_id}")
|
rows.append(f"Parent Session ID: {session_context.parent_session_id}")
|
||||||
|
|
||||||
@ -347,6 +373,31 @@ class ContextBuilder:
|
|||||||
return None
|
return None
|
||||||
return "# Current Session\n\n" + "\n".join(rows)
|
return "# Current Session\n\n" + "\n".join(rows)
|
||||||
|
|
||||||
|
def _render_runtime_section(self, runtime_context: RuntimeContext | None) -> str | None:
|
||||||
|
"""Render date/time facts captured for the current model run."""
|
||||||
|
|
||||||
|
if runtime_context is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
rows: list[str] = []
|
||||||
|
if runtime_context.utc_datetime:
|
||||||
|
rows.append(f"Current UTC time: {runtime_context.utc_datetime}")
|
||||||
|
if runtime_context.local_datetime:
|
||||||
|
rows.append(f"Current local time: {runtime_context.local_datetime}")
|
||||||
|
if runtime_context.timezone:
|
||||||
|
rows.append(f"Local timezone: {runtime_context.timezone}")
|
||||||
|
if runtime_context.utc_offset:
|
||||||
|
rows.append(f"Local UTC offset: {runtime_context.utc_offset}")
|
||||||
|
|
||||||
|
if not rows:
|
||||||
|
return None
|
||||||
|
return (
|
||||||
|
"# Current Date and Time\n\n"
|
||||||
|
+ "\n".join(rows)
|
||||||
|
+ "\n\nUse this section as authoritative for relative date/time references such as "
|
||||||
|
'"today", "tomorrow", "now", "this week", and "next month".'
|
||||||
|
)
|
||||||
|
|
||||||
def build_skill_activation_messages(self, activated_skills: list[SkillContext]) -> list[dict[str, str]]:
|
def build_skill_activation_messages(self, activated_skills: list[SkillContext]) -> list[dict[str, str]]:
|
||||||
"""把已激活 skill 转成显式消息。
|
"""把已激活 skill 转成显式消息。
|
||||||
|
|
||||||
|
|||||||
@ -24,7 +24,7 @@ from beaver.skills.learning.eval import SkillDraftEvaluator
|
|||||||
from beaver.skills.publisher import SkillPublisher
|
from beaver.skills.publisher import SkillPublisher
|
||||||
from beaver.skills.reviews import ReviewService
|
from beaver.skills.reviews import ReviewService
|
||||||
from beaver.skills.specs import SkillSpecStore
|
from beaver.skills.specs import SkillSpecStore
|
||||||
from beaver.tasks import TaskExecutionPlanner, TaskService, ValidationService
|
from beaver.tasks import TaskExecutionPlanner, TaskService
|
||||||
from beaver.tasks.skill_resolver import TaskSkillResolver
|
from beaver.tasks.skill_resolver import TaskSkillResolver
|
||||||
from beaver.skills import SkillAssembler, SkillsLoader
|
from beaver.skills import SkillAssembler, SkillsLoader
|
||||||
from beaver.tools import ObjectBackedTool, ToolAssembler, ToolExecutor, ToolRegistry
|
from beaver.tools import ObjectBackedTool, ToolAssembler, ToolExecutor, ToolRegistry
|
||||||
@ -44,6 +44,7 @@ from beaver.tools.builtins import (
|
|||||||
SpawnTool,
|
SpawnTool,
|
||||||
SessionSearchTool,
|
SessionSearchTool,
|
||||||
SkillManageTool,
|
SkillManageTool,
|
||||||
|
SkillViewTool,
|
||||||
SkillsListTool,
|
SkillsListTool,
|
||||||
TerminalTool,
|
TerminalTool,
|
||||||
TodoTool,
|
TodoTool,
|
||||||
@ -97,7 +98,6 @@ class EngineLoadResult:
|
|||||||
task_skill_resolver: TaskSkillResolver | None = None
|
task_skill_resolver: TaskSkillResolver | None = None
|
||||||
task_service: TaskService | None = None
|
task_service: TaskService | None = None
|
||||||
task_execution_planner: TaskExecutionPlanner | None = None
|
task_execution_planner: TaskExecutionPlanner | None = None
|
||||||
validation_service: ValidationService | None = None
|
|
||||||
mcp_manager: MCPConnectionManager | None = None
|
mcp_manager: MCPConnectionManager | None = None
|
||||||
mcp_report: dict[str, dict] = field(default_factory=dict)
|
mcp_report: dict[str, dict] = field(default_factory=dict)
|
||||||
closeables: list[tuple[str, Callable[[], None]]] = field(default_factory=list, repr=False)
|
closeables: list[tuple[str, Callable[[], None]]] = field(default_factory=list, repr=False)
|
||||||
@ -172,7 +172,6 @@ class EngineLoader:
|
|||||||
task_skill_resolver: TaskSkillResolver | None = None,
|
task_skill_resolver: TaskSkillResolver | None = None,
|
||||||
task_service: TaskService | None = None,
|
task_service: TaskService | None = None,
|
||||||
task_execution_planner: TaskExecutionPlanner | None = None,
|
task_execution_planner: TaskExecutionPlanner | None = None,
|
||||||
validation_service: ValidationService | None = None,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
self.config = config or load_config(workspace=workspace, config_path=config_path)
|
self.config = config or load_config(workspace=workspace, config_path=config_path)
|
||||||
configured_workspace = self.config.agents_defaults.workspace
|
configured_workspace = self.config.agents_defaults.workspace
|
||||||
@ -198,7 +197,6 @@ class EngineLoader:
|
|||||||
self._task_skill_resolver = task_skill_resolver
|
self._task_skill_resolver = task_skill_resolver
|
||||||
self._task_service = task_service
|
self._task_service = task_service
|
||||||
self._task_execution_planner = task_execution_planner
|
self._task_execution_planner = task_execution_planner
|
||||||
self._validation_service = validation_service
|
|
||||||
|
|
||||||
def load(self) -> EngineLoadResult:
|
def load(self) -> EngineLoadResult:
|
||||||
"""装配当前主链需要的最小 runtime 对象。"""
|
"""装配当前主链需要的最小 runtime 对象。"""
|
||||||
@ -245,6 +243,7 @@ class EngineLoader:
|
|||||||
ObjectBackedTool(DelegateTool()),
|
ObjectBackedTool(DelegateTool()),
|
||||||
ObjectBackedTool(SpawnTool()),
|
ObjectBackedTool(SpawnTool()),
|
||||||
SkillsListTool(),
|
SkillsListTool(),
|
||||||
|
ObjectBackedTool(SkillViewTool(loader=skills_loader)),
|
||||||
SkillManageTool(),
|
SkillManageTool(),
|
||||||
CronTool(),
|
CronTool(),
|
||||||
]
|
]
|
||||||
@ -288,7 +287,6 @@ class EngineLoader:
|
|||||||
)
|
)
|
||||||
task_service = self._task_service or TaskService(workspace / "tasks")
|
task_service = self._task_service or TaskService(workspace / "tasks")
|
||||||
task_execution_planner = self._task_execution_planner or TaskExecutionPlanner(task_skill_resolver=task_skill_resolver)
|
task_execution_planner = self._task_execution_planner or TaskExecutionPlanner(task_skill_resolver=task_skill_resolver)
|
||||||
validation_service = self._validation_service or ValidationService()
|
|
||||||
mcp_manager = MCPConnectionManager(
|
mcp_manager = MCPConnectionManager(
|
||||||
self.config.tools.mcp_servers,
|
self.config.tools.mcp_servers,
|
||||||
authz_config=self.config.authz,
|
authz_config=self.config.authz,
|
||||||
@ -323,7 +321,6 @@ class EngineLoader:
|
|||||||
task_skill_resolver=task_skill_resolver,
|
task_skill_resolver=task_skill_resolver,
|
||||||
task_service=task_service,
|
task_service=task_service,
|
||||||
task_execution_planner=task_execution_planner,
|
task_execution_planner=task_execution_planner,
|
||||||
validation_service=validation_service,
|
|
||||||
mcp_manager=mcp_manager,
|
mcp_manager=mcp_manager,
|
||||||
)
|
)
|
||||||
if self._session_manager is None:
|
if self._session_manager is None:
|
||||||
|
|||||||
@ -4,12 +4,16 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
|
import os
|
||||||
|
import re
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
|
||||||
|
|
||||||
from beaver.engine.context import ContextBuildInput, SessionContext, SkillContext
|
from beaver.engine.context import ContextBuildInput, RuntimeContext, SessionContext, SkillContext
|
||||||
|
from beaver.foundation.events import ChannelIdentity
|
||||||
from beaver.memory.runs import RunRecord, SkillEffectRecord
|
from beaver.memory.runs import RunRecord, SkillEffectRecord
|
||||||
from beaver.skills.learning import RunReceiptContext
|
from beaver.skills.learning import RunReceiptContext
|
||||||
from beaver.skills.catalog.utils import strip_frontmatter
|
from beaver.skills.catalog.utils import strip_frontmatter
|
||||||
@ -26,6 +30,17 @@ TOOL_FAILURE_GUIDANCE_PROMPT = (
|
|||||||
"Use available materials, state uncertainty clearly, and provide partial confirmed results."
|
"Use available materials, state uncertainty clearly, and provide partial confirmed results."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
RAW_TOOL_CALL_FALLBACK = (
|
||||||
|
"The run reached the configured tool-call limit before producing a reliable final answer. "
|
||||||
|
"The model attempted another tool call instead of answering, so the raw tool call was suppressed. "
|
||||||
|
"Please request a revision to continue the task."
|
||||||
|
)
|
||||||
|
|
||||||
|
_RAW_TOOL_CALL_RE = re.compile(
|
||||||
|
r"^\s*<tool_call\b[\s\S]*?</tool_call>\s*$|^\s*<function=[^>]+>[\s\S]*?</function>\s*$",
|
||||||
|
re.IGNORECASE,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass(slots=True)
|
@dataclass(slots=True)
|
||||||
class AgentProfile:
|
class AgentProfile:
|
||||||
@ -34,9 +49,10 @@ class AgentProfile:
|
|||||||
name: str = "default"
|
name: str = "default"
|
||||||
system_prompt: str = ""
|
system_prompt: str = ""
|
||||||
default_model: str = "gpt-4.1-mini"
|
default_model: str = "gpt-4.1-mini"
|
||||||
max_tokens: int = 4096
|
max_tokens: int | None = None
|
||||||
|
max_context_messages: int = 1000
|
||||||
temperature: float = 0.2
|
temperature: float = 0.2
|
||||||
max_tool_iterations: int = 8
|
max_tool_iterations: int = 30
|
||||||
|
|
||||||
|
|
||||||
@dataclass(slots=True)
|
@dataclass(slots=True)
|
||||||
@ -74,6 +90,7 @@ class AgentLoop:
|
|||||||
self.loaded: EngineLoadResult | None = None
|
self.loaded: EngineLoadResult | None = None
|
||||||
self.runtime_services: dict[str, Any] = {}
|
self.runtime_services: dict[str, Any] = {}
|
||||||
self._run_queue: asyncio.Queue[_DirectRunRequest | None] | None = None
|
self._run_queue: asyncio.Queue[_DirectRunRequest | None] | None = None
|
||||||
|
self._active_direct_task: asyncio.Task[Any] | None = None
|
||||||
self._running = False
|
self._running = False
|
||||||
self._stop_requested = False
|
self._stop_requested = False
|
||||||
|
|
||||||
@ -115,6 +132,8 @@ class AgentLoop:
|
|||||||
if item.future.cancelled():
|
if item.future.cancelled():
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
previous_direct_task = self._active_direct_task
|
||||||
|
self._active_direct_task = asyncio.current_task()
|
||||||
try:
|
try:
|
||||||
result = await self._process_direct_impl(item.task, **item.kwargs)
|
result = await self._process_direct_impl(item.task, **item.kwargs)
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
@ -127,6 +146,8 @@ class AgentLoop:
|
|||||||
else:
|
else:
|
||||||
if not item.future.done():
|
if not item.future.done():
|
||||||
item.future.set_result(result)
|
item.future.set_result(result)
|
||||||
|
finally:
|
||||||
|
self._active_direct_task = previous_direct_task
|
||||||
finally:
|
finally:
|
||||||
if self._run_queue is not None:
|
if self._run_queue is not None:
|
||||||
while True:
|
while True:
|
||||||
@ -168,6 +189,9 @@ class AgentLoop:
|
|||||||
if self._stop_requested:
|
if self._stop_requested:
|
||||||
raise RuntimeError("AgentLoop.submit_direct() is not accepting new tasks after stop()")
|
raise RuntimeError("AgentLoop.submit_direct() is not accepting new tasks after stop()")
|
||||||
|
|
||||||
|
if asyncio.current_task() is self._active_direct_task:
|
||||||
|
return await self._process_direct_impl(task, **kwargs)
|
||||||
|
|
||||||
future: asyncio.Future[AgentRunResult] = asyncio.get_running_loop().create_future()
|
future: asyncio.Future[AgentRunResult] = asyncio.get_running_loop().create_future()
|
||||||
await self._run_queue.put(_DirectRunRequest(task=task, kwargs=dict(kwargs), future=future))
|
await self._run_queue.put(_DirectRunRequest(task=task, kwargs=dict(kwargs), future=future))
|
||||||
return await future
|
return await future
|
||||||
@ -200,6 +224,7 @@ class AgentLoop:
|
|||||||
title: str | None = None,
|
title: str | None = None,
|
||||||
execution_context: str | None = None,
|
execution_context: str | None = None,
|
||||||
skill_selection_context: str | None = None,
|
skill_selection_context: str | None = None,
|
||||||
|
prompt_locale: str | None = None,
|
||||||
model: str | None = None,
|
model: str | None = None,
|
||||||
provider_name: str | None = None,
|
provider_name: str | None = None,
|
||||||
api_key: str | None = None,
|
api_key: str | None = None,
|
||||||
@ -223,8 +248,10 @@ class AgentLoop:
|
|||||||
attempt_index: int | None = None,
|
attempt_index: int | None = None,
|
||||||
pinned_skill_names: list[str] | None = None,
|
pinned_skill_names: list[str] | None = None,
|
||||||
pinned_skill_contexts: list[SkillContext] | None = None,
|
pinned_skill_contexts: list[SkillContext] | None = None,
|
||||||
|
tool_executor_override: Any = None,
|
||||||
allow_candidate_generation: bool = False,
|
allow_candidate_generation: bool = False,
|
||||||
intent_agent_decision: dict[str, Any] | None = None,
|
intent_agent_decision: dict[str, Any] | None = None,
|
||||||
|
channel_identity: ChannelIdentity | None = None,
|
||||||
) -> AgentRunResult:
|
) -> AgentRunResult:
|
||||||
"""跑通最小 direct run 主链。
|
"""跑通最小 direct run 主链。
|
||||||
|
|
||||||
@ -249,6 +276,7 @@ class AgentLoop:
|
|||||||
title=title,
|
title=title,
|
||||||
execution_context=execution_context,
|
execution_context=execution_context,
|
||||||
skill_selection_context=skill_selection_context,
|
skill_selection_context=skill_selection_context,
|
||||||
|
prompt_locale=prompt_locale,
|
||||||
model=model,
|
model=model,
|
||||||
provider_name=provider_name,
|
provider_name=provider_name,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
@ -272,8 +300,10 @@ class AgentLoop:
|
|||||||
attempt_index=attempt_index,
|
attempt_index=attempt_index,
|
||||||
pinned_skill_names=pinned_skill_names,
|
pinned_skill_names=pinned_skill_names,
|
||||||
pinned_skill_contexts=pinned_skill_contexts,
|
pinned_skill_contexts=pinned_skill_contexts,
|
||||||
|
tool_executor_override=tool_executor_override,
|
||||||
allow_candidate_generation=allow_candidate_generation,
|
allow_candidate_generation=allow_candidate_generation,
|
||||||
intent_agent_decision=intent_agent_decision,
|
intent_agent_decision=intent_agent_decision,
|
||||||
|
channel_identity=channel_identity,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _process_direct_impl(
|
async def _process_direct_impl(
|
||||||
@ -286,6 +316,7 @@ class AgentLoop:
|
|||||||
title: str | None = None,
|
title: str | None = None,
|
||||||
execution_context: str | None = None,
|
execution_context: str | None = None,
|
||||||
skill_selection_context: str | None = None,
|
skill_selection_context: str | None = None,
|
||||||
|
prompt_locale: str | None = None,
|
||||||
model: str | None = None,
|
model: str | None = None,
|
||||||
provider_name: str | None = None,
|
provider_name: str | None = None,
|
||||||
api_key: str | None = None,
|
api_key: str | None = None,
|
||||||
@ -309,8 +340,10 @@ class AgentLoop:
|
|||||||
attempt_index: int | None = None,
|
attempt_index: int | None = None,
|
||||||
pinned_skill_names: list[str] | None = None,
|
pinned_skill_names: list[str] | None = None,
|
||||||
pinned_skill_contexts: list[SkillContext] | None = None,
|
pinned_skill_contexts: list[SkillContext] | None = None,
|
||||||
|
tool_executor_override: Any = None,
|
||||||
allow_candidate_generation: bool = False,
|
allow_candidate_generation: bool = False,
|
||||||
intent_agent_decision: dict[str, Any] | None = None,
|
intent_agent_decision: dict[str, Any] | None = None,
|
||||||
|
channel_identity: ChannelIdentity | None = None,
|
||||||
) -> AgentRunResult:
|
) -> AgentRunResult:
|
||||||
"""真正执行一轮 direct run 的内部实现。
|
"""真正执行一轮 direct run 的内部实现。
|
||||||
|
|
||||||
@ -327,6 +360,7 @@ class AgentLoop:
|
|||||||
tool_registry = self._require_loaded("tool_registry")
|
tool_registry = self._require_loaded("tool_registry")
|
||||||
tool_assembler = self._require_loaded("tool_assembler")
|
tool_assembler = self._require_loaded("tool_assembler")
|
||||||
tool_executor = self._require_loaded("tool_executor")
|
tool_executor = self._require_loaded("tool_executor")
|
||||||
|
effective_tool_executor = tool_executor_override or tool_executor
|
||||||
skills_loader = self._require_loaded("skills_loader")
|
skills_loader = self._require_loaded("skills_loader")
|
||||||
skill_assembler = self._require_loaded("skill_assembler")
|
skill_assembler = self._require_loaded("skill_assembler")
|
||||||
skill_learning_service = self._require_loaded("skill_learning_service")
|
skill_learning_service = self._require_loaded("skill_learning_service")
|
||||||
@ -348,7 +382,7 @@ class AgentLoop:
|
|||||||
resolved_request_timeout_seconds = configured_provider.get("request_timeout_seconds")
|
resolved_request_timeout_seconds = configured_provider.get("request_timeout_seconds")
|
||||||
resolved_embedding_model = embedding_model or config.default_embedding_model
|
resolved_embedding_model = embedding_model or config.default_embedding_model
|
||||||
resolved_embedding_target = embedding_target or config.resolve_embedding_target()
|
resolved_embedding_target = embedding_target or config.resolve_embedding_target()
|
||||||
resolved_max_tokens = max_tokens or self.profile.max_tokens
|
resolved_max_tokens = self.profile.max_tokens if max_tokens is None else max_tokens
|
||||||
resolved_temperature = self.profile.temperature if temperature is None else temperature
|
resolved_temperature = self.profile.temperature if temperature is None else temperature
|
||||||
resolved_max_tool_iterations = (
|
resolved_max_tool_iterations = (
|
||||||
self.profile.max_tool_iterations if max_tool_iterations is None else max_tool_iterations
|
self.profile.max_tool_iterations if max_tool_iterations is None else max_tool_iterations
|
||||||
@ -446,7 +480,7 @@ class AgentLoop:
|
|||||||
*(pinned_skill_contexts or []),
|
*(pinned_skill_contexts or []),
|
||||||
*self._load_pinned_skill_contexts(skills_loader, pinned_skill_names or []),
|
*self._load_pinned_skill_contexts(skills_loader, pinned_skill_names or []),
|
||||||
]
|
]
|
||||||
if not include_skill_assembly or thinking_enabled is False:
|
if not include_skill_assembly:
|
||||||
activated_skills = self._merge_skill_contexts(pinned_skills, [])
|
activated_skills = self._merge_skill_contexts(pinned_skills, [])
|
||||||
else:
|
else:
|
||||||
skill_query = skill_selection_context or task
|
skill_query = skill_selection_context or task
|
||||||
@ -512,8 +546,6 @@ class AgentLoop:
|
|||||||
|
|
||||||
if not include_tools:
|
if not include_tools:
|
||||||
selected_tool_specs = []
|
selected_tool_specs = []
|
||||||
elif thinking_enabled is False:
|
|
||||||
selected_tool_specs = tool_registry.list_specs()
|
|
||||||
else:
|
else:
|
||||||
selected_tool_specs = await tool_assembler.assemble(
|
selected_tool_specs = await tool_assembler.assemble(
|
||||||
task_description=task,
|
task_description=task,
|
||||||
@ -543,7 +575,11 @@ class AgentLoop:
|
|||||||
|
|
||||||
build_input = ContextBuildInput(
|
build_input = ContextBuildInput(
|
||||||
base_system_prompt=self.profile.system_prompt,
|
base_system_prompt=self.profile.system_prompt,
|
||||||
history=session_manager.get_history(resolved_session_id),
|
prompt_locale=prompt_locale,
|
||||||
|
history=session_manager.get_history(
|
||||||
|
resolved_session_id,
|
||||||
|
max_messages=max(1, self.profile.max_context_messages),
|
||||||
|
),
|
||||||
current_user_input=task,
|
current_user_input=task,
|
||||||
memory_snapshot=memory_snapshot,
|
memory_snapshot=memory_snapshot,
|
||||||
activated_skills=activated_skills,
|
activated_skills=activated_skills,
|
||||||
@ -552,8 +588,16 @@ class AgentLoop:
|
|||||||
source=source,
|
source=source,
|
||||||
model=resolved_model,
|
model=resolved_model,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
|
channel=channel_identity.channel_id if channel_identity else None,
|
||||||
|
channel_kind=channel_identity.kind if channel_identity else None,
|
||||||
|
account_id=channel_identity.account_id if channel_identity else None,
|
||||||
|
peer_id=channel_identity.peer_id if channel_identity else None,
|
||||||
|
peer_type=channel_identity.peer_type if channel_identity else None,
|
||||||
|
chat_id=channel_identity.peer_id if channel_identity else None,
|
||||||
|
thread_id=channel_identity.thread_id if channel_identity else None,
|
||||||
parent_session_id=parent_session_id,
|
parent_session_id=parent_session_id,
|
||||||
),
|
),
|
||||||
|
runtime_context=self._current_runtime_context(),
|
||||||
execution_context=execution_context,
|
execution_context=execution_context,
|
||||||
extra_sections=[TOOL_FAILURE_GUIDANCE_PROMPT],
|
extra_sections=[TOOL_FAILURE_GUIDANCE_PROMPT],
|
||||||
)
|
)
|
||||||
@ -643,36 +687,39 @@ class AgentLoop:
|
|||||||
while True:
|
while True:
|
||||||
chat_kwargs: dict[str, Any] = {
|
chat_kwargs: dict[str, Any] = {
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
"tools": tool_schemas,
|
"tools": tool_schemas if include_tools else None,
|
||||||
"model": final_model,
|
"model": final_model,
|
||||||
"max_tokens": resolved_max_tokens,
|
"max_tokens": resolved_max_tokens,
|
||||||
"temperature": resolved_temperature,
|
"temperature": resolved_temperature,
|
||||||
}
|
}
|
||||||
if thinking_enabled is not None:
|
if thinking_enabled is not None:
|
||||||
chat_kwargs["thinking_enabled"] = thinking_enabled
|
chat_kwargs["thinking_enabled"] = thinking_enabled
|
||||||
|
message_char_length = len(json.dumps(messages, ensure_ascii=False, default=str))
|
||||||
|
tool_schema_char_length = len(json.dumps(tool_schemas, ensure_ascii=False, default=str))
|
||||||
|
tool_names = [
|
||||||
|
str(tool.get("function", {}).get("name") or tool.get("name") or "tool")
|
||||||
|
for tool in (tool_schemas or [])
|
||||||
|
if isinstance(tool, dict)
|
||||||
|
]
|
||||||
|
snapshot_payload = {
|
||||||
|
"iteration": iterations,
|
||||||
|
"provider_name": final_provider_name,
|
||||||
|
"model": final_model,
|
||||||
|
"message_count": len(messages),
|
||||||
|
"tool_names": tool_names,
|
||||||
|
"message_char_length": message_char_length,
|
||||||
|
"tool_schema_char_length": tool_schema_char_length,
|
||||||
|
"max_tokens": resolved_max_tokens,
|
||||||
|
"temperature": resolved_temperature,
|
||||||
|
"thinking_enabled": thinking_enabled,
|
||||||
|
}
|
||||||
session_manager.append_message(
|
session_manager.append_message(
|
||||||
resolved_session_id,
|
resolved_session_id,
|
||||||
run_id=resolved_run_id,
|
run_id=resolved_run_id,
|
||||||
role="system",
|
role="system",
|
||||||
event_type="llm_request_snapshotted",
|
event_type="llm_request_snapshotted",
|
||||||
event_payload={
|
event_payload=snapshot_payload,
|
||||||
"iteration": iterations,
|
content=json.dumps(snapshot_payload, ensure_ascii=False, default=str),
|
||||||
"provider_name": final_provider_name,
|
|
||||||
"model": final_model,
|
|
||||||
"messages": messages,
|
|
||||||
"tools": tool_schemas,
|
|
||||||
"max_tokens": resolved_max_tokens,
|
|
||||||
"temperature": resolved_temperature,
|
|
||||||
"thinking_enabled": thinking_enabled,
|
|
||||||
},
|
|
||||||
content=json.dumps(
|
|
||||||
{
|
|
||||||
"messages": messages,
|
|
||||||
"tools": tool_schemas,
|
|
||||||
},
|
|
||||||
ensure_ascii=False,
|
|
||||||
default=str,
|
|
||||||
),
|
|
||||||
context_visible=False,
|
context_visible=False,
|
||||||
source=source,
|
source=source,
|
||||||
title=title,
|
title=title,
|
||||||
@ -696,6 +743,7 @@ class AgentLoop:
|
|||||||
tool_calls=assistant_tool_calls or None,
|
tool_calls=assistant_tool_calls or None,
|
||||||
finish_reason=response.finish_reason,
|
finish_reason=response.finish_reason,
|
||||||
reasoning=response.reasoning_content,
|
reasoning=response.reasoning_content,
|
||||||
|
context_visible=not bool(assistant_tool_calls),
|
||||||
source=source,
|
source=source,
|
||||||
title=title,
|
title=title,
|
||||||
model=final_model,
|
model=final_model,
|
||||||
@ -710,12 +758,24 @@ class AgentLoop:
|
|||||||
|
|
||||||
if not response.has_tool_calls:
|
if not response.has_tool_calls:
|
||||||
final_text = response.content or ""
|
final_text = response.content or ""
|
||||||
|
if self._looks_like_raw_tool_call(final_text):
|
||||||
|
final_text = RAW_TOOL_CALL_FALLBACK
|
||||||
|
final_finish_reason = "invalid_tool_call_text"
|
||||||
|
else:
|
||||||
final_finish_reason = response.finish_reason or "stop"
|
final_finish_reason = response.finish_reason or "stop"
|
||||||
break
|
break
|
||||||
|
|
||||||
if iterations >= resolved_max_tool_iterations:
|
if iterations >= resolved_max_tool_iterations:
|
||||||
final_text = response.content or "Tool loop stopped after reaching the configured iteration limit."
|
finalized = await self._finalize_after_tool_limit(
|
||||||
final_finish_reason = "max_tool_iterations"
|
provider=provider,
|
||||||
|
messages=messages,
|
||||||
|
model=final_model,
|
||||||
|
max_tokens=resolved_max_tokens,
|
||||||
|
temperature=resolved_temperature,
|
||||||
|
thinking_enabled=thinking_enabled,
|
||||||
|
)
|
||||||
|
final_text = finalized or RAW_TOOL_CALL_FALLBACK
|
||||||
|
final_finish_reason = "max_tool_iterations_finalized" if finalized else "max_tool_iterations"
|
||||||
session_manager.append_message(
|
session_manager.append_message(
|
||||||
resolved_session_id,
|
resolved_session_id,
|
||||||
run_id=resolved_run_id,
|
run_id=resolved_run_id,
|
||||||
@ -737,7 +797,7 @@ class AgentLoop:
|
|||||||
|
|
||||||
iterations += 1
|
iterations += 1
|
||||||
for tool_call in response.tool_calls:
|
for tool_call in response.tool_calls:
|
||||||
result = await tool_executor.execute_tool_call(tool_call, context=tool_context)
|
result = await effective_tool_executor.execute_tool_call(tool_call, context=tool_context)
|
||||||
session_manager.append_message(
|
session_manager.append_message(
|
||||||
resolved_session_id,
|
resolved_session_id,
|
||||||
run_id=resolved_run_id,
|
run_id=resolved_run_id,
|
||||||
@ -859,6 +919,56 @@ class AgentLoop:
|
|||||||
raise RuntimeError(f"Engine loader did not provide required dependency {field_name!r}")
|
raise RuntimeError(f"Engine loader did not provide required dependency {field_name!r}")
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def _finalize_after_tool_limit(
|
||||||
|
*,
|
||||||
|
provider: Any,
|
||||||
|
messages: list[dict[str, Any]],
|
||||||
|
model: str,
|
||||||
|
max_tokens: int | None,
|
||||||
|
temperature: float,
|
||||||
|
thinking_enabled: bool | None,
|
||||||
|
) -> str:
|
||||||
|
final_messages = AgentLoop._with_system_guidance(
|
||||||
|
messages,
|
||||||
|
(
|
||||||
|
"The configured tool iteration budget is exhausted. Do not call tools. "
|
||||||
|
"Produce the best final answer from the existing conversation and tool results. "
|
||||||
|
"State uncertainty explicitly."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
kwargs: dict[str, Any] = {
|
||||||
|
"messages": final_messages,
|
||||||
|
"tools": None,
|
||||||
|
"model": model,
|
||||||
|
"max_tokens": max_tokens,
|
||||||
|
"temperature": temperature,
|
||||||
|
}
|
||||||
|
if thinking_enabled is not None:
|
||||||
|
kwargs["thinking_enabled"] = thinking_enabled
|
||||||
|
response = await provider.chat(**kwargs)
|
||||||
|
if response.has_tool_calls:
|
||||||
|
return ""
|
||||||
|
content = (response.content or "").strip()
|
||||||
|
if AgentLoop._looks_like_raw_tool_call(content):
|
||||||
|
return ""
|
||||||
|
return content
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _looks_like_raw_tool_call(content: str | None) -> bool:
|
||||||
|
if not content:
|
||||||
|
return False
|
||||||
|
return bool(_RAW_TOOL_CALL_RE.match(content))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _with_system_guidance(messages: list[dict[str, Any]], guidance: str) -> list[dict[str, Any]]:
|
||||||
|
copied = [dict(message) for message in messages]
|
||||||
|
if copied and copied[0].get("role") == "system":
|
||||||
|
existing = str(copied[0].get("content") or "").strip()
|
||||||
|
copied[0]["content"] = "\n\n".join(part for part in (existing, guidance.strip()) if part)
|
||||||
|
return copied
|
||||||
|
return [{"role": "system", "content": guidance.strip()}, *copied]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _load_pinned_skill_contexts(skills_loader: Any, skill_names: list[str]) -> list[SkillContext]:
|
def _load_pinned_skill_contexts(skills_loader: Any, skill_names: list[str]) -> list[SkillContext]:
|
||||||
contexts: list[SkillContext] = []
|
contexts: list[SkillContext] = []
|
||||||
@ -1092,3 +1202,49 @@ class AgentLoop:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def _utc_now() -> str:
|
def _utc_now() -> str:
|
||||||
return datetime.now(timezone.utc).isoformat()
|
return datetime.now(timezone.utc).isoformat()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _current_runtime_context() -> RuntimeContext:
|
||||||
|
utc_now = datetime.now(timezone.utc)
|
||||||
|
timezone_name = AgentLoop._configured_timezone_name()
|
||||||
|
local_now = datetime.now().astimezone()
|
||||||
|
rendered_timezone = local_now.tzname()
|
||||||
|
|
||||||
|
if timezone_name:
|
||||||
|
try:
|
||||||
|
local_now = utc_now.astimezone(ZoneInfo(timezone_name))
|
||||||
|
rendered_timezone = timezone_name
|
||||||
|
except ZoneInfoNotFoundError:
|
||||||
|
rendered_timezone = local_now.tzname() or timezone_name
|
||||||
|
|
||||||
|
return RuntimeContext(
|
||||||
|
utc_datetime=utc_now.isoformat(),
|
||||||
|
local_datetime=local_now.isoformat(),
|
||||||
|
timezone=rendered_timezone,
|
||||||
|
utc_offset=AgentLoop._format_utc_offset(local_now),
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _configured_timezone_name() -> str | None:
|
||||||
|
for value in (os.getenv("BEAVER_RUNTIME_TIMEZONE"), os.getenv("TZ")):
|
||||||
|
cleaned = (value or "").strip()
|
||||||
|
if cleaned:
|
||||||
|
return cleaned
|
||||||
|
|
||||||
|
try:
|
||||||
|
timezone_file = "/etc/timezone"
|
||||||
|
if os.path.exists(timezone_file):
|
||||||
|
with open(timezone_file, encoding="utf-8") as file:
|
||||||
|
cleaned = file.read().strip()
|
||||||
|
if cleaned:
|
||||||
|
return cleaned
|
||||||
|
except OSError:
|
||||||
|
return None
|
||||||
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _format_utc_offset(value: datetime) -> str | None:
|
||||||
|
raw = value.strftime("%z")
|
||||||
|
if not raw:
|
||||||
|
return None
|
||||||
|
return f"{raw[:3]}:{raw[3:]}"
|
||||||
|
|||||||
@ -43,7 +43,7 @@ class AnthropicProvider(LLMProvider):
|
|||||||
messages: list[dict[str, Any]],
|
messages: list[dict[str, Any]],
|
||||||
tools: list[dict[str, Any]] | None = None,
|
tools: list[dict[str, Any]] | None = None,
|
||||||
model: str | None = None,
|
model: str | None = None,
|
||||||
max_tokens: int = 4096,
|
max_tokens: int | None = None,
|
||||||
temperature: float = 0.7,
|
temperature: float = 0.7,
|
||||||
thinking_enabled: bool | None = None,
|
thinking_enabled: bool | None = None,
|
||||||
) -> LLMResponse:
|
) -> LLMResponse:
|
||||||
@ -57,9 +57,14 @@ class AnthropicProvider(LLMProvider):
|
|||||||
"model": model or self.default_model,
|
"model": model or self.default_model,
|
||||||
"system": system_prompt or "",
|
"system": system_prompt or "",
|
||||||
"messages": anthropic_messages,
|
"messages": anthropic_messages,
|
||||||
"max_tokens": max(1, max_tokens),
|
|
||||||
"temperature": temperature,
|
"temperature": temperature,
|
||||||
}
|
}
|
||||||
|
resolved_max_tokens = (
|
||||||
|
_default_max_tokens_for_model(model or self.default_model)
|
||||||
|
if max_tokens is None
|
||||||
|
else max(1, max_tokens)
|
||||||
|
)
|
||||||
|
kwargs["max_tokens"] = resolved_max_tokens
|
||||||
if tools:
|
if tools:
|
||||||
kwargs["tools"] = _convert_tools(tools)
|
kwargs["tools"] = _convert_tools(tools)
|
||||||
|
|
||||||
@ -100,6 +105,17 @@ class AnthropicProvider(LLMProvider):
|
|||||||
return self.default_model
|
return self.default_model
|
||||||
|
|
||||||
|
|
||||||
|
def _default_max_tokens_for_model(model: str) -> int:
|
||||||
|
"""Return a conservative native output ceiling for Anthropic Messages."""
|
||||||
|
|
||||||
|
normalized = model.lower().replace("_", "-")
|
||||||
|
if "sonnet-4" in normalized or "opus-4" in normalized or "3-7" in normalized or "3.7" in normalized:
|
||||||
|
return 64_000
|
||||||
|
if "haiku" in normalized:
|
||||||
|
return 4_096
|
||||||
|
return 8_192
|
||||||
|
|
||||||
|
|
||||||
def _convert_messages(messages: list[dict[str, Any]]) -> tuple[str, list[dict[str, Any]]]:
|
def _convert_messages(messages: list[dict[str, Any]]) -> tuple[str, list[dict[str, Any]]]:
|
||||||
system_prompt = ""
|
system_prompt = ""
|
||||||
converted: list[dict[str, Any]] = []
|
converted: list[dict[str, Any]] = []
|
||||||
|
|||||||
@ -88,7 +88,7 @@ class LLMProvider(ABC):
|
|||||||
messages: list[dict[str, Any]],
|
messages: list[dict[str, Any]],
|
||||||
tools: list[dict[str, Any]] | None = None,
|
tools: list[dict[str, Any]] | None = None,
|
||||||
model: str | None = None,
|
model: str | None = None,
|
||||||
max_tokens: int = 4096,
|
max_tokens: int | None = None,
|
||||||
temperature: float = 0.7,
|
temperature: float = 0.7,
|
||||||
thinking_enabled: bool | None = None,
|
thinking_enabled: bool | None = None,
|
||||||
) -> LLMResponse:
|
) -> LLMResponse:
|
||||||
|
|||||||
@ -56,7 +56,7 @@ class FallbackProviderChain(LLMProvider):
|
|||||||
messages: list[dict],
|
messages: list[dict],
|
||||||
tools: list[dict] | None = None,
|
tools: list[dict] | None = None,
|
||||||
model: str | None = None,
|
model: str | None = None,
|
||||||
max_tokens: int = 4096,
|
max_tokens: int | None = None,
|
||||||
temperature: float = 0.7,
|
temperature: float = 0.7,
|
||||||
thinking_enabled: bool | None = None,
|
thinking_enabled: bool | None = None,
|
||||||
) -> LLMResponse:
|
) -> LLMResponse:
|
||||||
@ -115,7 +115,7 @@ class FallbackProviderChain(LLMProvider):
|
|||||||
messages: list[dict],
|
messages: list[dict],
|
||||||
tools: list[dict] | None,
|
tools: list[dict] | None,
|
||||||
model: str,
|
model: str,
|
||||||
max_tokens: int,
|
max_tokens: int | None,
|
||||||
temperature: float,
|
temperature: float,
|
||||||
thinking_enabled: bool | None,
|
thinking_enabled: bool | None,
|
||||||
) -> LLMResponse:
|
) -> LLMResponse:
|
||||||
|
|||||||
@ -39,7 +39,7 @@ class OpenAICodexProvider(LLMProvider):
|
|||||||
messages: list[dict[str, Any]],
|
messages: list[dict[str, Any]],
|
||||||
tools: list[dict[str, Any]] | None = None,
|
tools: list[dict[str, Any]] | None = None,
|
||||||
model: str | None = None,
|
model: str | None = None,
|
||||||
max_tokens: int = 4096,
|
max_tokens: int | None = None,
|
||||||
temperature: float = 0.7,
|
temperature: float = 0.7,
|
||||||
thinking_enabled: bool | None = None,
|
thinking_enabled: bool | None = None,
|
||||||
) -> LLMResponse:
|
) -> LLMResponse:
|
||||||
|
|||||||
@ -47,7 +47,7 @@ class CustomProvider(LLMProvider):
|
|||||||
messages: list[dict[str, Any]],
|
messages: list[dict[str, Any]],
|
||||||
tools: list[dict[str, Any]] | None = None,
|
tools: list[dict[str, Any]] | None = None,
|
||||||
model: str | None = None,
|
model: str | None = None,
|
||||||
max_tokens: int = 4096,
|
max_tokens: int | None = None,
|
||||||
temperature: float = 0.7,
|
temperature: float = 0.7,
|
||||||
thinking_enabled: bool | None = None,
|
thinking_enabled: bool | None = None,
|
||||||
) -> LLMResponse:
|
) -> LLMResponse:
|
||||||
@ -55,9 +55,10 @@ class CustomProvider(LLMProvider):
|
|||||||
kwargs: dict[str, Any] = {
|
kwargs: dict[str, Any] = {
|
||||||
"model": model or self.default_model,
|
"model": model or self.default_model,
|
||||||
"messages": self.sanitize_empty_content(messages),
|
"messages": self.sanitize_empty_content(messages),
|
||||||
"max_tokens": max(1, max_tokens),
|
|
||||||
"temperature": temperature,
|
"temperature": temperature,
|
||||||
}
|
}
|
||||||
|
if max_tokens is not None:
|
||||||
|
kwargs["max_tokens"] = max(1, max_tokens)
|
||||||
if tools:
|
if tools:
|
||||||
kwargs.update(tools=tools, tool_choice="auto")
|
kwargs.update(tools=tools, tool_choice="auto")
|
||||||
try:
|
try:
|
||||||
|
|||||||
@ -3,9 +3,11 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
from ipaddress import ip_address
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
from urllib.parse import urlsplit
|
||||||
|
|
||||||
from .base import LLMProvider, LLMResponse, ToolCallRequest
|
from .base import LLMProvider, LLMResponse, ToolCallRequest
|
||||||
from .registry import find_by_model, find_by_name, find_gateway
|
from .registry import find_by_model, find_by_name, find_gateway
|
||||||
@ -23,7 +25,24 @@ except ModuleNotFoundError: # pragma: no cover
|
|||||||
litellm = None # type: ignore[assignment]
|
litellm = None # type: ignore[assignment]
|
||||||
acompletion = None # type: ignore[assignment]
|
acompletion = None # type: ignore[assignment]
|
||||||
|
|
||||||
_ALLOWED_MSG_KEYS = frozenset({"role", "content", "tool_calls", "tool_call_id", "name"})
|
_ALLOWED_MSG_KEYS = frozenset({"role", "content", "tool_calls", "tool_call_id", "name", "reasoning_content"})
|
||||||
|
|
||||||
|
|
||||||
|
def _looks_like_local_vllm_api_base(api_base: str | None) -> bool:
|
||||||
|
if not api_base:
|
||||||
|
return False
|
||||||
|
lowered = api_base.lower()
|
||||||
|
if "vllm" in lowered or "localhost" in lowered:
|
||||||
|
return True
|
||||||
|
|
||||||
|
host = urlsplit(lowered).hostname or ""
|
||||||
|
if host in {"127.0.0.1", "::1", "0.0.0.0"}:
|
||||||
|
return True
|
||||||
|
try:
|
||||||
|
parsed_host = ip_address(host)
|
||||||
|
except ValueError:
|
||||||
|
return False
|
||||||
|
return parsed_host.is_private or parsed_host.is_loopback
|
||||||
|
|
||||||
|
|
||||||
class LiteLLMProvider(LLMProvider):
|
class LiteLLMProvider(LLMProvider):
|
||||||
@ -119,13 +138,23 @@ class LiteLLMProvider(LLMProvider):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def _sanitize_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
def _sanitize_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||||
sanitized = []
|
sanitized = []
|
||||||
|
system_contents: list[str] = []
|
||||||
for message in messages:
|
for message in messages:
|
||||||
clean = {key: value for key, value in message.items() if key in _ALLOWED_MSG_KEYS}
|
clean = {key: value for key, value in message.items() if key in _ALLOWED_MSG_KEYS}
|
||||||
|
if clean.get("role") == "system":
|
||||||
|
content = clean.get("content")
|
||||||
|
if isinstance(content, str) and content.strip():
|
||||||
|
system_contents.append(content.strip())
|
||||||
|
elif content is not None:
|
||||||
|
system_contents.append(str(content))
|
||||||
|
continue
|
||||||
if clean.get("role") == "assistant" and "content" not in clean:
|
if clean.get("role") == "assistant" and "content" not in clean:
|
||||||
clean["content"] = None
|
clean["content"] = None
|
||||||
if isinstance(clean.get("tool_calls"), list):
|
if isinstance(clean.get("tool_calls"), list):
|
||||||
clean["tool_calls"] = LiteLLMProvider._sanitize_tool_calls(clean["tool_calls"])
|
clean["tool_calls"] = LiteLLMProvider._sanitize_tool_calls(clean["tool_calls"])
|
||||||
sanitized.append(clean)
|
sanitized.append(clean)
|
||||||
|
if system_contents:
|
||||||
|
sanitized.insert(0, {"role": "system", "content": "\n\n".join(system_contents)})
|
||||||
return sanitized
|
return sanitized
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -175,23 +204,34 @@ class LiteLLMProvider(LLMProvider):
|
|||||||
kwargs["provider"] = provider_payload
|
kwargs["provider"] = provider_payload
|
||||||
|
|
||||||
def _apply_thinking_mode(self, original_model: str, resolved_model: str, kwargs: dict[str, Any], enabled: bool | None) -> None:
|
def _apply_thinking_mode(self, original_model: str, resolved_model: str, kwargs: dict[str, Any], enabled: bool | None) -> None:
|
||||||
if enabled is None:
|
if self._uses_mistral_reasoning_parser(original_model, resolved_model):
|
||||||
return
|
if enabled is not None:
|
||||||
model_key = f"{original_model} {resolved_model}".lower()
|
extra_body = dict(kwargs.get("extra_body") or {})
|
||||||
if "qwen" not in model_key:
|
extra_body["reasoning_effort"] = "high" if enabled else "none"
|
||||||
|
kwargs["extra_body"] = extra_body
|
||||||
return
|
return
|
||||||
|
|
||||||
extra_body = dict(kwargs.get("extra_body") or {})
|
extra_body = dict(kwargs.get("extra_body") or {})
|
||||||
chat_template_kwargs = dict(extra_body.get("chat_template_kwargs") or {})
|
chat_template_kwargs = dict(extra_body.get("chat_template_kwargs") or {})
|
||||||
chat_template_kwargs["enable_thinking"] = bool(enabled)
|
chat_template_kwargs["enable_thinking"] = False
|
||||||
extra_body["chat_template_kwargs"] = chat_template_kwargs
|
extra_body["chat_template_kwargs"] = chat_template_kwargs
|
||||||
|
extra_body["thinking"] = {"type": "disabled"}
|
||||||
kwargs["extra_body"] = extra_body
|
kwargs["extra_body"] = extra_body
|
||||||
|
|
||||||
|
def _uses_mistral_reasoning_parser(self, original_model: str, resolved_model: str) -> bool:
|
||||||
|
model_names = f"{original_model} {resolved_model}".lower()
|
||||||
|
if "mistral" not in model_names:
|
||||||
|
return False
|
||||||
|
if self.provider_name == "vllm":
|
||||||
|
return True
|
||||||
|
return self.provider_name in {"openai", "custom"} and _looks_like_local_vllm_api_base(self.api_base)
|
||||||
|
|
||||||
async def chat(
|
async def chat(
|
||||||
self,
|
self,
|
||||||
messages: list[dict[str, Any]],
|
messages: list[dict[str, Any]],
|
||||||
tools: list[dict[str, Any]] | None = None,
|
tools: list[dict[str, Any]] | None = None,
|
||||||
model: str | None = None,
|
model: str | None = None,
|
||||||
max_tokens: int = 4096,
|
max_tokens: int | None = None,
|
||||||
temperature: float = 0.7,
|
temperature: float = 0.7,
|
||||||
thinking_enabled: bool | None = None,
|
thinking_enabled: bool | None = None,
|
||||||
) -> LLMResponse:
|
) -> LLMResponse:
|
||||||
@ -204,10 +244,11 @@ class LiteLLMProvider(LLMProvider):
|
|||||||
kwargs: dict[str, Any] = {
|
kwargs: dict[str, Any] = {
|
||||||
"model": resolved_model,
|
"model": resolved_model,
|
||||||
"messages": sanitized_messages,
|
"messages": sanitized_messages,
|
||||||
"max_tokens": max(1, max_tokens),
|
|
||||||
"temperature": temperature,
|
"temperature": temperature,
|
||||||
"timeout": self.request_timeout_seconds or 45.0,
|
"timeout": self.request_timeout_seconds or 45.0,
|
||||||
}
|
}
|
||||||
|
if max_tokens is not None:
|
||||||
|
kwargs["max_tokens"] = max(1, max_tokens)
|
||||||
if self.api_key:
|
if self.api_key:
|
||||||
kwargs["api_key"] = self.api_key
|
kwargs["api_key"] = self.api_key
|
||||||
if self.api_base:
|
if self.api_base:
|
||||||
|
|||||||
@ -84,8 +84,10 @@ class MessageRecord:
|
|||||||
payload["task_id"] = self.event_payload.get("task_id")
|
payload["task_id"] = self.event_payload.get("task_id")
|
||||||
if self.event_payload.get("task_status"):
|
if self.event_payload.get("task_status"):
|
||||||
payload["task_status"] = self.event_payload.get("task_status")
|
payload["task_status"] = self.event_payload.get("task_status")
|
||||||
if self.event_payload.get("validation_status"):
|
if self.event_payload.get("evidence_status"):
|
||||||
payload["validation_status"] = self.event_payload.get("validation_status")
|
payload["evidence_status"] = self.event_payload.get("evidence_status")
|
||||||
|
if self.event_payload.get("acceptance_state"):
|
||||||
|
payload["acceptance_state"] = self.event_payload.get("acceptance_state")
|
||||||
if self.event_payload.get("feedback_state"):
|
if self.event_payload.get("feedback_state"):
|
||||||
payload["feedback_state"] = self.event_payload.get("feedback_state")
|
payload["feedback_state"] = self.event_payload.get("feedback_state")
|
||||||
if self.event_payload.get("feedback_error"):
|
if self.event_payload.get("feedback_error"):
|
||||||
|
|||||||
@ -13,6 +13,7 @@ from .schema import (
|
|||||||
AuthzConfig,
|
AuthzConfig,
|
||||||
BackendIdentityConfig,
|
BackendIdentityConfig,
|
||||||
BeaverConfig,
|
BeaverConfig,
|
||||||
|
ChannelConfig,
|
||||||
EmbeddingConfig,
|
EmbeddingConfig,
|
||||||
MCPServerConfig,
|
MCPServerConfig,
|
||||||
ProviderConfig,
|
ProviderConfig,
|
||||||
@ -73,6 +74,7 @@ def load_config(
|
|||||||
embedding=_parse_embedding(data),
|
embedding=_parse_embedding(data),
|
||||||
tools=_parse_tools(data.get("tools")),
|
tools=_parse_tools(data.get("tools")),
|
||||||
authz=_parse_authz(data.get("authz")),
|
authz=_parse_authz(data.get("authz")),
|
||||||
|
channels=_parse_channels(data.get("channels")),
|
||||||
backend_identity=_parse_backend_identity(data.get("backend_identity") or data.get("backendIdentity")),
|
backend_identity=_parse_backend_identity(data.get("backend_identity") or data.get("backendIdentity")),
|
||||||
config_path=path,
|
config_path=path,
|
||||||
)
|
)
|
||||||
@ -86,6 +88,25 @@ def _parse_agent_defaults(data: dict[str, Any]) -> AgentDefaultsConfig:
|
|||||||
model=_string(defaults.get("model") or data.get("model")),
|
model=_string(defaults.get("model") or data.get("model")),
|
||||||
provider=_string(defaults.get("provider") or data.get("provider")),
|
provider=_string(defaults.get("provider") or data.get("provider")),
|
||||||
embedding_model=_string(defaults.get("embeddingModel") or defaults.get("embedding_model") or data.get("embeddingModel")),
|
embedding_model=_string(defaults.get("embeddingModel") or defaults.get("embedding_model") or data.get("embeddingModel")),
|
||||||
|
max_tokens=_int(_first_config_value(
|
||||||
|
defaults.get("maxTokens"),
|
||||||
|
defaults.get("max_tokens"),
|
||||||
|
data.get("maxTokens"),
|
||||||
|
data.get("max_tokens"),
|
||||||
|
)),
|
||||||
|
temperature=_float(_first_config_value(defaults.get("temperature"), data.get("temperature"))),
|
||||||
|
max_context_messages=_int(
|
||||||
|
defaults.get("maxContextMessages")
|
||||||
|
or defaults.get("max_context_messages")
|
||||||
|
or data.get("maxContextMessages")
|
||||||
|
or data.get("max_context_messages")
|
||||||
|
),
|
||||||
|
max_tool_iterations=_int(_first_config_value(
|
||||||
|
defaults.get("maxToolIterations"),
|
||||||
|
defaults.get("max_tool_iterations"),
|
||||||
|
data.get("maxToolIterations"),
|
||||||
|
data.get("max_tool_iterations"),
|
||||||
|
)),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -177,6 +198,48 @@ def _parse_authz(raw: Any) -> AuthzConfig:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_channels(raw: Any) -> dict[str, ChannelConfig]:
|
||||||
|
channels: dict[str, ChannelConfig] = {}
|
||||||
|
for channel_id, payload in _as_dict(raw).items():
|
||||||
|
cleaned_id = str(channel_id).strip()
|
||||||
|
if not cleaned_id:
|
||||||
|
continue
|
||||||
|
channels[cleaned_id] = _parse_channel_config(payload)
|
||||||
|
return channels
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_channel_config(payload: Any) -> ChannelConfig:
|
||||||
|
data = _as_dict(payload)
|
||||||
|
return ChannelConfig(
|
||||||
|
enabled=_bool(data.get("enabled"), default=False),
|
||||||
|
kind=_string(data.get("kind")) or "",
|
||||||
|
mode=_string(data.get("mode")) or "webhook",
|
||||||
|
account_id=_string(data.get("accountId") or data.get("account_id")) or "",
|
||||||
|
display_name=_string(data.get("displayName") or data.get("display_name")) or "",
|
||||||
|
config=_normalize_config_map(data.get("config")),
|
||||||
|
secrets=_string_dict(data.get("secrets")),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_config_map(value: Any) -> dict[str, Any]:
|
||||||
|
if not isinstance(value, dict):
|
||||||
|
return {}
|
||||||
|
return {
|
||||||
|
_camel_to_snake_key(str(key)): item
|
||||||
|
for key, item in value.items()
|
||||||
|
if str(key).strip()
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _camel_to_snake_key(value: str) -> str:
|
||||||
|
result: list[str] = []
|
||||||
|
for char in value:
|
||||||
|
if char.isupper() and result:
|
||||||
|
result.append("_")
|
||||||
|
result.append(char.lower())
|
||||||
|
return "".join(result)
|
||||||
|
|
||||||
|
|
||||||
def _parse_backend_identity(raw: Any) -> BackendIdentityConfig:
|
def _parse_backend_identity(raw: Any) -> BackendIdentityConfig:
|
||||||
data = _as_dict(raw)
|
data = _as_dict(raw)
|
||||||
return BackendIdentityConfig(
|
return BackendIdentityConfig(
|
||||||
@ -192,6 +255,13 @@ def _as_dict(value: Any) -> dict[str, Any]:
|
|||||||
return value if isinstance(value, dict) else {}
|
return value if isinstance(value, dict) else {}
|
||||||
|
|
||||||
|
|
||||||
|
def _first_config_value(*values: Any) -> Any:
|
||||||
|
for value in values:
|
||||||
|
if value not in (None, ""):
|
||||||
|
return value
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def _string(value: Any) -> str | None:
|
def _string(value: Any) -> str | None:
|
||||||
if value is None:
|
if value is None:
|
||||||
return None
|
return None
|
||||||
@ -217,6 +287,13 @@ def _float(value: Any) -> float | None:
|
|||||||
return float(value)
|
return float(value)
|
||||||
|
|
||||||
|
|
||||||
|
def _int(value: Any) -> int | None:
|
||||||
|
parsed = _float(value)
|
||||||
|
if parsed is None:
|
||||||
|
return None
|
||||||
|
return int(parsed)
|
||||||
|
|
||||||
|
|
||||||
def _bool(value: Any, *, default: bool) -> bool:
|
def _bool(value: Any, *, default: bool) -> bool:
|
||||||
if isinstance(value, bool):
|
if isinstance(value, bool):
|
||||||
return value
|
return value
|
||||||
|
|||||||
@ -25,6 +25,10 @@ class AgentDefaultsConfig:
|
|||||||
model: str | None = None
|
model: str | None = None
|
||||||
provider: str | None = None
|
provider: str | None = None
|
||||||
embedding_model: str | None = None
|
embedding_model: str | None = None
|
||||||
|
max_tokens: int | None = None
|
||||||
|
temperature: float | None = None
|
||||||
|
max_context_messages: int | None = None
|
||||||
|
max_tool_iterations: int | None = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass(slots=True)
|
@dataclass(slots=True)
|
||||||
@ -87,6 +91,19 @@ class AuthzConfig:
|
|||||||
outlook_mcp_url: str = ""
|
outlook_mcp_url: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class ChannelConfig:
|
||||||
|
"""One configured channel adapter instance."""
|
||||||
|
|
||||||
|
enabled: bool = False
|
||||||
|
kind: str = ""
|
||||||
|
mode: str = "webhook"
|
||||||
|
account_id: str = ""
|
||||||
|
display_name: str = ""
|
||||||
|
config: dict[str, Any] = field(default_factory=dict)
|
||||||
|
secrets: dict[str, str] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
@dataclass(slots=True)
|
@dataclass(slots=True)
|
||||||
class BackendIdentityConfig:
|
class BackendIdentityConfig:
|
||||||
"""This backend's AuthZ client identity."""
|
"""This backend's AuthZ client identity."""
|
||||||
@ -107,6 +124,7 @@ class BeaverConfig:
|
|||||||
embedding: EmbeddingConfig = field(default_factory=EmbeddingConfig)
|
embedding: EmbeddingConfig = field(default_factory=EmbeddingConfig)
|
||||||
tools: ToolsConfig = field(default_factory=ToolsConfig)
|
tools: ToolsConfig = field(default_factory=ToolsConfig)
|
||||||
authz: AuthzConfig = field(default_factory=AuthzConfig)
|
authz: AuthzConfig = field(default_factory=AuthzConfig)
|
||||||
|
channels: dict[str, ChannelConfig] = field(default_factory=dict)
|
||||||
backend_identity: BackendIdentityConfig = field(default_factory=BackendIdentityConfig)
|
backend_identity: BackendIdentityConfig = field(default_factory=BackendIdentityConfig)
|
||||||
config_path: Path | None = None
|
config_path: Path | None = None
|
||||||
|
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
"""Event contracts and dispatch helpers."""
|
"""Event contracts and dispatch helpers."""
|
||||||
|
|
||||||
from .message_bus import InboundMessage, MessageBus, OutboundMessage
|
from .message_bus import ChannelIdentity, InboundMessage, MessageBus, OutboundMessage
|
||||||
|
|
||||||
__all__ = ["InboundMessage", "MessageBus", "OutboundMessage"]
|
__all__ = ["ChannelIdentity", "InboundMessage", "MessageBus", "OutboundMessage"]
|
||||||
|
|||||||
@ -9,12 +9,58 @@ from typing import Any
|
|||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class ChannelIdentity:
|
||||||
|
"""Normalized channel routing identity.
|
||||||
|
|
||||||
|
`channel_id` is the Beaver adapter instance id, not the platform kind.
|
||||||
|
"""
|
||||||
|
|
||||||
|
channel_id: str
|
||||||
|
kind: str
|
||||||
|
account_id: str
|
||||||
|
peer_id: str
|
||||||
|
thread_id: str | None = None
|
||||||
|
peer_type: str = "unknown"
|
||||||
|
user_id: str | None = None
|
||||||
|
message_id: str | None = None
|
||||||
|
|
||||||
|
def validation_error(self) -> str | None:
|
||||||
|
if not self.channel_id.strip():
|
||||||
|
return "channel_id is required"
|
||||||
|
if not self.account_id.strip():
|
||||||
|
return "account_id is required"
|
||||||
|
if not self.peer_id.strip():
|
||||||
|
return "peer_id is required"
|
||||||
|
return None
|
||||||
|
|
||||||
|
def session_id(self) -> str:
|
||||||
|
parts = [self.channel_id, self.account_id, self.peer_id]
|
||||||
|
if self.thread_id:
|
||||||
|
parts.append(self.thread_id)
|
||||||
|
return ":".join(_clean_session_part(part) for part in parts)
|
||||||
|
|
||||||
|
def dedupe_key(self) -> str | None:
|
||||||
|
if not self.message_id:
|
||||||
|
return None
|
||||||
|
return f"{self.session_id()}:{_clean_session_part(self.message_id)}"
|
||||||
|
|
||||||
|
|
||||||
|
def _clean_session_part(value: str) -> str:
|
||||||
|
cleaned = str(value).strip()
|
||||||
|
if not cleaned:
|
||||||
|
return "unknown"
|
||||||
|
return cleaned.replace(":", "_")
|
||||||
|
|
||||||
|
|
||||||
@dataclass(slots=True)
|
@dataclass(slots=True)
|
||||||
class InboundMessage:
|
class InboundMessage:
|
||||||
"""A minimal inbound message accepted by the gateway bridge."""
|
"""A minimal inbound message accepted by the gateway bridge."""
|
||||||
|
|
||||||
channel: str
|
channel: str
|
||||||
content: str
|
content: str
|
||||||
|
content_type: str = "text"
|
||||||
|
channel_identity: ChannelIdentity | None = None
|
||||||
session_id: str | None = None
|
session_id: str | None = None
|
||||||
user_id: str | None = None
|
user_id: str | None = None
|
||||||
title: str | None = None
|
title: str | None = None
|
||||||
@ -35,6 +81,8 @@ class OutboundMessage:
|
|||||||
content: str
|
content: str
|
||||||
session_id: str | None
|
session_id: str | None
|
||||||
finish_reason: str
|
finish_reason: str
|
||||||
|
content_type: str = "text"
|
||||||
|
channel_identity: ChannelIdentity | None = None
|
||||||
message_id: str = field(default_factory=lambda: str(uuid4()))
|
message_id: str = field(default_factory=lambda: str(uuid4()))
|
||||||
run_id: str | None = None
|
run_id: str | None = None
|
||||||
provider_name: str | None = None
|
provider_name: str | None = None
|
||||||
|
|||||||
@ -1,7 +1,17 @@
|
|||||||
"""Channel interfaces."""
|
"""Channel interfaces."""
|
||||||
|
|
||||||
from .base import ChannelAdapter
|
from .base import ChannelAdapter
|
||||||
|
from .base import ChannelInboundSink
|
||||||
|
from .external_connector import ExternalConnectorChannel
|
||||||
from .manager import ChannelManager
|
from .manager import ChannelManager
|
||||||
from .memory import MemoryChannelAdapter
|
from .memory import MemoryChannelAdapter
|
||||||
|
from .terminal_websocket import TerminalWebSocketAdapter
|
||||||
|
|
||||||
__all__ = ["ChannelAdapter", "ChannelManager", "MemoryChannelAdapter"]
|
__all__ = [
|
||||||
|
"ChannelAdapter",
|
||||||
|
"ChannelInboundSink",
|
||||||
|
"ExternalConnectorChannel",
|
||||||
|
"ChannelManager",
|
||||||
|
"MemoryChannelAdapter",
|
||||||
|
"TerminalWebSocketAdapter",
|
||||||
|
]
|
||||||
|
|||||||
@ -2,16 +2,17 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Protocol
|
from typing import Any, Protocol
|
||||||
|
|
||||||
from beaver.foundation.events import MessageBus, OutboundMessage
|
from beaver.foundation.events import InboundMessage, OutboundMessage
|
||||||
|
|
||||||
|
|
||||||
class ChannelAdapter(Protocol):
|
class ChannelAdapter(Protocol):
|
||||||
"""Minimal contract every gateway channel must implement."""
|
"""Minimal contract every runtime channel adapter must implement."""
|
||||||
|
|
||||||
name: str
|
channel_id: str
|
||||||
bus: MessageBus
|
kind: str
|
||||||
|
mode: str
|
||||||
|
|
||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
"""Prepare the channel before messages are routed."""
|
"""Prepare the channel before messages are routed."""
|
||||||
@ -22,3 +23,9 @@ class ChannelAdapter(Protocol):
|
|||||||
async def send(self, message: OutboundMessage) -> None:
|
async def send(self, message: OutboundMessage) -> None:
|
||||||
"""Deliver an outbound message to the concrete channel."""
|
"""Deliver an outbound message to the concrete channel."""
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelInboundSink(Protocol):
|
||||||
|
"""Runtime callback used by adapters to submit normalized inbound messages."""
|
||||||
|
|
||||||
|
async def accept_inbound(self, message: InboundMessage) -> Any:
|
||||||
|
"""Accept a normalized inbound message from an adapter."""
|
||||||
|
|||||||
@ -0,0 +1,29 @@
|
|||||||
|
"""Channel connection setup layer."""
|
||||||
|
|
||||||
|
from .connectors import ChannelConnector, ChannelConnectorRegistry
|
||||||
|
from .dedupe import ConnectorMessageDedupeRecord, DedupeBeginResult, MessageDedupeStore
|
||||||
|
from .external import ExternalConnectorBase, FeishuConnector, WeixinConnector
|
||||||
|
from .models import ChannelConnection, ChannelRuntimeSpec, PairingSession, ValidationResult
|
||||||
|
from .sidecar_client import ConnectorSidecarClient
|
||||||
|
from .store import ChannelConnectionStore, CredentialStore, PairingTokenStore
|
||||||
|
from .telegram import TelegramConnector
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"ChannelConnector",
|
||||||
|
"ChannelConnectorRegistry",
|
||||||
|
"ConnectorMessageDedupeRecord",
|
||||||
|
"DedupeBeginResult",
|
||||||
|
"MessageDedupeStore",
|
||||||
|
"ExternalConnectorBase",
|
||||||
|
"FeishuConnector",
|
||||||
|
"WeixinConnector",
|
||||||
|
"ConnectorSidecarClient",
|
||||||
|
"ChannelConnection",
|
||||||
|
"ChannelRuntimeSpec",
|
||||||
|
"PairingSession",
|
||||||
|
"ValidationResult",
|
||||||
|
"ChannelConnectionStore",
|
||||||
|
"CredentialStore",
|
||||||
|
"PairingTokenStore",
|
||||||
|
"TelegramConnector",
|
||||||
|
]
|
||||||
@ -0,0 +1,93 @@
|
|||||||
|
"""Channel connector registry."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Protocol
|
||||||
|
|
||||||
|
from beaver.foundation.config.schema import ChannelConfig
|
||||||
|
|
||||||
|
from .models import ChannelRuntimeSpec, ValidationResult
|
||||||
|
from .store import ChannelConnectionStore, CredentialStore
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelConnector(Protocol):
|
||||||
|
kind: str
|
||||||
|
|
||||||
|
async def validate(self, connection_id: str) -> ValidationResult:
|
||||||
|
...
|
||||||
|
|
||||||
|
async def materialize_runtime(self, connection_id: str) -> ChannelRuntimeSpec:
|
||||||
|
...
|
||||||
|
|
||||||
|
async def revoke(self, connection_id: str) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelConnectorRegistry:
|
||||||
|
def __init__(self, *, connection_store: ChannelConnectionStore, credential_store: CredentialStore) -> None:
|
||||||
|
self.connection_store = connection_store
|
||||||
|
self.credential_store = credential_store
|
||||||
|
self._connectors: dict[str, ChannelConnector] = {}
|
||||||
|
|
||||||
|
def register(self, connector: ChannelConnector) -> None:
|
||||||
|
kind = connector.kind.strip()
|
||||||
|
if not kind:
|
||||||
|
raise ValueError("Connector kind is required")
|
||||||
|
if kind in self._connectors:
|
||||||
|
raise ValueError(f"Connector already registered: {kind}")
|
||||||
|
self._connectors[kind] = connector
|
||||||
|
|
||||||
|
def connectors(self) -> list[dict[str, str]]:
|
||||||
|
return [{"kind": kind} for kind in sorted(self._connectors)]
|
||||||
|
|
||||||
|
def connector_for_kind(self, kind: str) -> ChannelConnector:
|
||||||
|
return self._connector(kind)
|
||||||
|
|
||||||
|
async def validate(self, connection_id: str) -> ValidationResult:
|
||||||
|
connection = self.connection_store.get(connection_id)
|
||||||
|
connector = self._connector(connection.kind)
|
||||||
|
result = await connector.validate(connection_id)
|
||||||
|
self.connection_store.update_status(
|
||||||
|
connection_id,
|
||||||
|
status=result.status,
|
||||||
|
last_error=result.error,
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def materialize_runtime(self, connection_id: str) -> ChannelRuntimeSpec:
|
||||||
|
connection = self.connection_store.get(connection_id)
|
||||||
|
return await self._connector(connection.kind).materialize_runtime(connection_id)
|
||||||
|
|
||||||
|
async def materialize_connected_runtime_specs(self) -> list[ChannelRuntimeSpec]:
|
||||||
|
specs: list[ChannelRuntimeSpec] = []
|
||||||
|
for connection in self.connection_store.list():
|
||||||
|
if connection.status not in {"connected", "running"}:
|
||||||
|
continue
|
||||||
|
specs.append(await self._connector(connection.kind).materialize_runtime(connection.connection_id))
|
||||||
|
return specs
|
||||||
|
|
||||||
|
async def materialize_channel_configs(self) -> dict[str, ChannelConfig]:
|
||||||
|
channels: dict[str, ChannelConfig] = {}
|
||||||
|
for spec in await self.materialize_connected_runtime_specs():
|
||||||
|
secrets = self.credential_store.get(spec.secrets_ref) if spec.secrets_ref else {}
|
||||||
|
channels[spec.channel_id] = ChannelConfig(
|
||||||
|
enabled=True,
|
||||||
|
kind=spec.kind,
|
||||||
|
mode=spec.mode,
|
||||||
|
account_id=spec.account_id,
|
||||||
|
display_name=spec.display_name,
|
||||||
|
config=dict(spec.config),
|
||||||
|
secrets=secrets,
|
||||||
|
)
|
||||||
|
return channels
|
||||||
|
|
||||||
|
async def revoke(self, connection_id: str) -> None:
|
||||||
|
connection = self.connection_store.get(connection_id)
|
||||||
|
await self._connector(connection.kind).revoke(connection_id)
|
||||||
|
self.connection_store.revoke(connection_id)
|
||||||
|
|
||||||
|
def _connector(self, kind: str) -> ChannelConnector:
|
||||||
|
connector = self._connectors.get(kind)
|
||||||
|
if connector is None:
|
||||||
|
raise KeyError(f"Connector not registered: {kind}")
|
||||||
|
return connector
|
||||||
@ -0,0 +1,144 @@
|
|||||||
|
"""Bridge event dedupe store for external connector retries."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from dataclasses import asdict, dataclass
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from pathlib import Path
|
||||||
|
from threading import Lock
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
def _iso_now() -> str:
|
||||||
|
return datetime.now(timezone.utc).isoformat()
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_iso(value: str) -> datetime:
|
||||||
|
return datetime.fromisoformat(value.replace("Z", "+00:00"))
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class ConnectorMessageDedupeRecord:
|
||||||
|
dedupe_key: str
|
||||||
|
connection_id: str
|
||||||
|
event_id: str
|
||||||
|
status: str
|
||||||
|
first_seen_at: str
|
||||||
|
updated_at: str
|
||||||
|
delivery_attempts: int
|
||||||
|
message_id: str | None = None
|
||||||
|
last_error: str | None = None
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
return asdict(self)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: dict[str, Any]) -> "ConnectorMessageDedupeRecord":
|
||||||
|
return cls(
|
||||||
|
dedupe_key=str(data.get("dedupe_key") or ""),
|
||||||
|
connection_id=str(data.get("connection_id") or ""),
|
||||||
|
event_id=str(data.get("event_id") or ""),
|
||||||
|
status=str(data.get("status") or "processing"),
|
||||||
|
first_seen_at=str(data.get("first_seen_at") or _iso_now()),
|
||||||
|
updated_at=str(data.get("updated_at") or _iso_now()),
|
||||||
|
delivery_attempts=int(data.get("delivery_attempts") or 0),
|
||||||
|
message_id=str(data["message_id"]) if data.get("message_id") is not None else None,
|
||||||
|
last_error=str(data["last_error"]) if data.get("last_error") is not None else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class DedupeBeginResult:
|
||||||
|
should_process: bool
|
||||||
|
dedupe_key: str
|
||||||
|
status: str
|
||||||
|
http_status: int
|
||||||
|
retry_after_seconds: int | None
|
||||||
|
record: ConnectorMessageDedupeRecord
|
||||||
|
|
||||||
|
|
||||||
|
class MessageDedupeStore:
|
||||||
|
def __init__(self, path: Path, *, processing_ttl_seconds: int = 60) -> None:
|
||||||
|
self.path = Path(path)
|
||||||
|
self.processing_ttl_seconds = int(processing_ttl_seconds)
|
||||||
|
self._lock = Lock()
|
||||||
|
|
||||||
|
def begin(self, *, connection_id: str, event_id: str, delivery_attempt: int) -> DedupeBeginResult:
|
||||||
|
dedupe_key = f"{connection_id}:{event_id}"
|
||||||
|
now = _iso_now()
|
||||||
|
with self._lock:
|
||||||
|
data = self._load()
|
||||||
|
raw = data["records"].get(dedupe_key)
|
||||||
|
if isinstance(raw, dict):
|
||||||
|
record = ConnectorMessageDedupeRecord.from_dict(raw)
|
||||||
|
if record.status == "completed":
|
||||||
|
return DedupeBeginResult(False, dedupe_key, record.status, 200, None, record)
|
||||||
|
if record.status == "processing" and not self._is_stale(record, now):
|
||||||
|
return DedupeBeginResult(False, dedupe_key, record.status, 409, 5, record)
|
||||||
|
record.status = "processing"
|
||||||
|
record.updated_at = now
|
||||||
|
record.delivery_attempts = max(record.delivery_attempts + 1, int(delivery_attempt))
|
||||||
|
record.last_error = None
|
||||||
|
else:
|
||||||
|
record = ConnectorMessageDedupeRecord(
|
||||||
|
dedupe_key=dedupe_key,
|
||||||
|
connection_id=connection_id,
|
||||||
|
event_id=event_id,
|
||||||
|
status="processing",
|
||||||
|
first_seen_at=now,
|
||||||
|
updated_at=now,
|
||||||
|
delivery_attempts=max(1, int(delivery_attempt)),
|
||||||
|
)
|
||||||
|
data["records"][dedupe_key] = record.to_dict()
|
||||||
|
self._save(data)
|
||||||
|
return DedupeBeginResult(True, dedupe_key, record.status, 200, None, record)
|
||||||
|
|
||||||
|
def complete(self, dedupe_key: str, *, message_id: str | None) -> ConnectorMessageDedupeRecord:
|
||||||
|
return self._mark(dedupe_key, status="completed", message_id=message_id, error=None)
|
||||||
|
|
||||||
|
def fail(self, dedupe_key: str, *, error: str) -> ConnectorMessageDedupeRecord:
|
||||||
|
return self._mark(dedupe_key, status="failed", message_id=None, error=error)
|
||||||
|
|
||||||
|
def _mark(
|
||||||
|
self,
|
||||||
|
dedupe_key: str,
|
||||||
|
*,
|
||||||
|
status: str,
|
||||||
|
message_id: str | None,
|
||||||
|
error: str | None,
|
||||||
|
) -> ConnectorMessageDedupeRecord:
|
||||||
|
with self._lock:
|
||||||
|
data = self._load()
|
||||||
|
raw = data["records"].get(dedupe_key)
|
||||||
|
if not isinstance(raw, dict):
|
||||||
|
raise KeyError(dedupe_key)
|
||||||
|
record = ConnectorMessageDedupeRecord.from_dict(raw)
|
||||||
|
record.status = status
|
||||||
|
record.updated_at = _iso_now()
|
||||||
|
record.message_id = message_id or record.message_id
|
||||||
|
record.last_error = error
|
||||||
|
data["records"][dedupe_key] = record.to_dict()
|
||||||
|
self._save(data)
|
||||||
|
return record
|
||||||
|
|
||||||
|
def _is_stale(self, record: ConnectorMessageDedupeRecord, now: str) -> bool:
|
||||||
|
age = (_parse_iso(now) - _parse_iso(record.updated_at)).total_seconds()
|
||||||
|
return age >= self.processing_ttl_seconds
|
||||||
|
|
||||||
|
def _load(self) -> dict[str, Any]:
|
||||||
|
if not self.path.exists():
|
||||||
|
return {"records": {}}
|
||||||
|
try:
|
||||||
|
data = json.loads(self.path.read_text(encoding="utf-8"))
|
||||||
|
except (OSError, json.JSONDecodeError):
|
||||||
|
return {"records": {}}
|
||||||
|
if not isinstance(data, dict) or not isinstance(data.get("records"), dict):
|
||||||
|
return {"records": {}}
|
||||||
|
return data
|
||||||
|
|
||||||
|
def _save(self, data: dict[str, Any]) -> None:
|
||||||
|
self.path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
tmp_path = self.path.with_name(f"{self.path.name}.tmp")
|
||||||
|
tmp_path.write_text(json.dumps(data, ensure_ascii=False, indent=2) + "\n", encoding="utf-8")
|
||||||
|
tmp_path.replace(self.path)
|
||||||
@ -0,0 +1,210 @@
|
|||||||
|
"""Sidecar-backed channel connectors."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from .models import ChannelRuntimeSpec, ValidationResult
|
||||||
|
from .sidecar_client import ConnectorSidecarClient
|
||||||
|
from .store import ChannelConnectionStore, CredentialStore
|
||||||
|
|
||||||
|
POLICY_CONFIG_KEYS = {
|
||||||
|
"allowFrom",
|
||||||
|
"groupAllowFrom",
|
||||||
|
"requireMentionInGroups",
|
||||||
|
"respondToMentionAll",
|
||||||
|
"dmMode",
|
||||||
|
"maxMessageChars",
|
||||||
|
"textBatchDelayMs",
|
||||||
|
"textBatchMaxMessages",
|
||||||
|
"textBatchMaxChars",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class ExternalConnectorBase:
|
||||||
|
kind = ""
|
||||||
|
capabilities: list[str] = []
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
connection_store: ChannelConnectionStore,
|
||||||
|
credential_store: CredentialStore,
|
||||||
|
sidecar_client: ConnectorSidecarClient | Any,
|
||||||
|
sidecar_base_url: str,
|
||||||
|
) -> None:
|
||||||
|
self.connection_store = connection_store
|
||||||
|
self.credential_store = credential_store
|
||||||
|
self.sidecar_client = sidecar_client
|
||||||
|
self.sidecar_base_url = sidecar_base_url
|
||||||
|
self.callback_base_url = _callback_base_url()
|
||||||
|
|
||||||
|
async def start_session(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
display_name: str,
|
||||||
|
owner_user_id: str | None,
|
||||||
|
options: dict[str, Any],
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
runtime_config = {"sidecarBaseUrl": self.sidecar_base_url}
|
||||||
|
runtime_config.update(_policy_runtime_config(options))
|
||||||
|
connection = self.connection_store.create(
|
||||||
|
kind=self.kind,
|
||||||
|
mode="sidecar",
|
||||||
|
display_name=display_name or self.kind,
|
||||||
|
account_id="",
|
||||||
|
owner_user_id=owner_user_id,
|
||||||
|
auth_type="connector_session",
|
||||||
|
runtime_config=runtime_config,
|
||||||
|
capabilities=list(self.capabilities),
|
||||||
|
)
|
||||||
|
connection = self.connection_store.update_status(connection.connection_id, status="pairing", last_error=None)
|
||||||
|
payload = {
|
||||||
|
"kind": self.kind,
|
||||||
|
"connectionId": connection.connection_id,
|
||||||
|
"channelId": connection.channel_id,
|
||||||
|
"displayName": connection.display_name,
|
||||||
|
"callbackBaseUrl": self.callback_base_url,
|
||||||
|
"options": dict(options),
|
||||||
|
}
|
||||||
|
view = dict(await self.sidecar_client.start_session(payload))
|
||||||
|
connection.pairing_session_id = str(view.get("sessionId") or "")
|
||||||
|
connection = self.connection_store.update(connection)
|
||||||
|
connection = self._apply_session_view(connection, view)
|
||||||
|
view["connectionId"] = connection.connection_id
|
||||||
|
view["channelId"] = connection.channel_id
|
||||||
|
return view
|
||||||
|
|
||||||
|
async def poll_session(self, session_id: str) -> dict[str, Any]:
|
||||||
|
view = dict(await self.sidecar_client.get_session(session_id))
|
||||||
|
connection = self._connection_for_session(session_id)
|
||||||
|
connection = self._apply_session_view(connection, view)
|
||||||
|
view["connectionId"] = connection.connection_id
|
||||||
|
view["channelId"] = connection.channel_id
|
||||||
|
return view
|
||||||
|
|
||||||
|
def _apply_session_view(self, connection: Any, view: dict[str, Any]) -> Any:
|
||||||
|
status = str(view.get("status") or "")
|
||||||
|
if status == "connected":
|
||||||
|
connection.account_id = str(view.get("accountId") or connection.account_id)
|
||||||
|
connection.display_name = str(view.get("displayName") or connection.display_name)
|
||||||
|
metadata = view.get("metadata") if isinstance(view.get("metadata"), dict) else {}
|
||||||
|
state_ref = metadata.get("stateRef")
|
||||||
|
if state_ref:
|
||||||
|
connection.credentials_ref = self.credential_store.put(kind=self.kind, values={"stateRef": state_ref})
|
||||||
|
self.connection_store.update(connection)
|
||||||
|
self.connection_store.update_status(connection.connection_id, status="connected", last_error=None)
|
||||||
|
elif status in {"expired", "error", "cancelled"}:
|
||||||
|
self.connection_store.update_status(
|
||||||
|
connection.connection_id,
|
||||||
|
status="error",
|
||||||
|
last_error=str(view.get("error") or status),
|
||||||
|
)
|
||||||
|
return self.connection_store.get(connection.connection_id)
|
||||||
|
|
||||||
|
async def validate(self, connection_id: str) -> ValidationResult:
|
||||||
|
connection = self.connection_store.get(connection_id)
|
||||||
|
if connection.status in {"connected", "running"}:
|
||||||
|
return ValidationResult(
|
||||||
|
ok=True,
|
||||||
|
status="connected",
|
||||||
|
account_id=connection.account_id,
|
||||||
|
display_name=connection.display_name,
|
||||||
|
)
|
||||||
|
return ValidationResult(ok=False, status=connection.status, error=connection.last_error)
|
||||||
|
|
||||||
|
async def materialize_runtime(self, connection_id: str) -> ChannelRuntimeSpec:
|
||||||
|
connection = self.connection_store.get(connection_id)
|
||||||
|
if connection.status not in {"connected", "running"}:
|
||||||
|
raise ValueError(f"Connection is not connected: {connection.connection_id}")
|
||||||
|
return ChannelRuntimeSpec(
|
||||||
|
channel_id=connection.channel_id,
|
||||||
|
kind="external_connector",
|
||||||
|
mode="http",
|
||||||
|
account_id=connection.account_id,
|
||||||
|
display_name=connection.display_name,
|
||||||
|
config={
|
||||||
|
"platformKind": self.kind,
|
||||||
|
"connectionId": connection.connection_id,
|
||||||
|
**dict(connection.runtime_config),
|
||||||
|
"sidecarBaseUrl": connection.runtime_config.get("sidecarBaseUrl") or self.sidecar_base_url,
|
||||||
|
},
|
||||||
|
secrets_ref=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def revoke(self, connection_id: str) -> None:
|
||||||
|
await self.sidecar_client.logout(connection_id)
|
||||||
|
|
||||||
|
def _connection_for_session(self, session_id: str):
|
||||||
|
for connection in self.connection_store.list():
|
||||||
|
if connection.pairing_session_id == session_id:
|
||||||
|
return connection
|
||||||
|
raise KeyError(session_id)
|
||||||
|
|
||||||
|
|
||||||
|
class WeixinConnector(ExternalConnectorBase):
|
||||||
|
kind = "weixin"
|
||||||
|
capabilities = ["receive_text", "send_text", "receive_media", "direct_messages"]
|
||||||
|
|
||||||
|
|
||||||
|
class FeishuConnector(ExternalConnectorBase):
|
||||||
|
kind = "feishu"
|
||||||
|
capabilities = ["receive_text", "send_text", "receive_media", "groups"]
|
||||||
|
|
||||||
|
|
||||||
|
def _policy_runtime_config(options: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
result: dict[str, Any] = {}
|
||||||
|
for key in POLICY_CONFIG_KEYS:
|
||||||
|
if key not in options:
|
||||||
|
continue
|
||||||
|
value = options[key]
|
||||||
|
if key in {"allowFrom", "groupAllowFrom"}:
|
||||||
|
items = _string_list(value)
|
||||||
|
if items:
|
||||||
|
result[key] = items
|
||||||
|
continue
|
||||||
|
if key in {"maxMessageChars", "textBatchDelayMs", "textBatchMaxMessages", "textBatchMaxChars"}:
|
||||||
|
number = _positive_int(value)
|
||||||
|
if number is not None:
|
||||||
|
result[key] = number
|
||||||
|
continue
|
||||||
|
if key in {"requireMentionInGroups", "respondToMentionAll"}:
|
||||||
|
result[key] = _bool(value)
|
||||||
|
continue
|
||||||
|
text = str(value or "").strip()
|
||||||
|
if text:
|
||||||
|
result[key] = text
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _callback_base_url() -> str:
|
||||||
|
for name in ("EXTERNAL_CONNECTOR_CALLBACK_BASE_URL", "BEAVER_CONNECTOR_CALLBACK_BASE_URL"):
|
||||||
|
value = os.environ.get(name, "").strip()
|
||||||
|
if value:
|
||||||
|
return value.rstrip("/")
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
def _string_list(value: Any) -> list[str]:
|
||||||
|
if isinstance(value, str):
|
||||||
|
raw_items = value.replace("\n", ",").split(",")
|
||||||
|
elif isinstance(value, list):
|
||||||
|
raw_items = value
|
||||||
|
else:
|
||||||
|
raw_items = []
|
||||||
|
return [str(item).strip() for item in raw_items if str(item).strip()]
|
||||||
|
|
||||||
|
|
||||||
|
def _positive_int(value: Any) -> int | None:
|
||||||
|
try:
|
||||||
|
number = int(value)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return None
|
||||||
|
return number if number > 0 else None
|
||||||
|
|
||||||
|
|
||||||
|
def _bool(value: Any) -> bool:
|
||||||
|
if isinstance(value, bool):
|
||||||
|
return value
|
||||||
|
return str(value).strip().lower() in {"1", "true", "yes", "on"}
|
||||||
@ -0,0 +1,117 @@
|
|||||||
|
"""Channel connection setup models."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import asdict, dataclass, field
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
CONNECTION_STATUSES = {"draft", "pairing", "connected", "running", "degraded", "error", "revoked"}
|
||||||
|
|
||||||
|
|
||||||
|
def iso_now() -> str:
|
||||||
|
return datetime.now(timezone.utc).isoformat()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class ChannelConnection:
|
||||||
|
connection_id: str
|
||||||
|
owner_user_id: str | None
|
||||||
|
channel_id: str
|
||||||
|
kind: str
|
||||||
|
mode: str
|
||||||
|
display_name: str
|
||||||
|
account_id: str
|
||||||
|
status: str
|
||||||
|
auth_type: str
|
||||||
|
credentials_ref: str | None = None
|
||||||
|
connector_ref: str | None = None
|
||||||
|
pairing_session_id: str | None = None
|
||||||
|
runtime_config: dict[str, Any] = field(default_factory=dict)
|
||||||
|
capabilities: list[str] = field(default_factory=list)
|
||||||
|
created_at: str = field(default_factory=iso_now)
|
||||||
|
updated_at: str = field(default_factory=iso_now)
|
||||||
|
last_seen_at: str | None = None
|
||||||
|
last_error: str | None = None
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
return asdict(self)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: dict[str, Any]) -> "ChannelConnection":
|
||||||
|
return cls(
|
||||||
|
connection_id=str(data.get("connection_id") or ""),
|
||||||
|
owner_user_id=_optional_string(data.get("owner_user_id")),
|
||||||
|
channel_id=str(data.get("channel_id") or ""),
|
||||||
|
kind=str(data.get("kind") or ""),
|
||||||
|
mode=str(data.get("mode") or ""),
|
||||||
|
display_name=str(data.get("display_name") or ""),
|
||||||
|
account_id=str(data.get("account_id") or ""),
|
||||||
|
status=str(data.get("status") or "draft"),
|
||||||
|
auth_type=str(data.get("auth_type") or ""),
|
||||||
|
credentials_ref=_optional_string(data.get("credentials_ref")),
|
||||||
|
connector_ref=_optional_string(data.get("connector_ref")),
|
||||||
|
pairing_session_id=_optional_string(data.get("pairing_session_id")),
|
||||||
|
runtime_config=dict(data.get("runtime_config") or {}),
|
||||||
|
capabilities=[str(item) for item in data.get("capabilities") or []],
|
||||||
|
created_at=str(data.get("created_at") or iso_now()),
|
||||||
|
updated_at=str(data.get("updated_at") or iso_now()),
|
||||||
|
last_seen_at=_optional_string(data.get("last_seen_at")),
|
||||||
|
last_error=_optional_string(data.get("last_error")),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class PairingSession:
|
||||||
|
pairing_session_id: str
|
||||||
|
kind: str
|
||||||
|
scope: str
|
||||||
|
token: str
|
||||||
|
status: str
|
||||||
|
expires_at_ms: int
|
||||||
|
created_at_ms: int
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
return asdict(self)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: dict[str, Any]) -> "PairingSession":
|
||||||
|
return cls(
|
||||||
|
pairing_session_id=str(data.get("pairing_session_id") or ""),
|
||||||
|
kind=str(data.get("kind") or ""),
|
||||||
|
scope=str(data.get("scope") or ""),
|
||||||
|
token=str(data.get("token") or ""),
|
||||||
|
status=str(data.get("status") or "pending"),
|
||||||
|
expires_at_ms=int(data.get("expires_at_ms") or 0),
|
||||||
|
created_at_ms=int(data.get("created_at_ms") or 0),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class ChannelRuntimeSpec:
|
||||||
|
channel_id: str
|
||||||
|
kind: str
|
||||||
|
mode: str
|
||||||
|
account_id: str
|
||||||
|
display_name: str
|
||||||
|
config: dict[str, Any] = field(default_factory=dict)
|
||||||
|
secrets_ref: str | None = None
|
||||||
|
external_endpoint: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class ValidationResult:
|
||||||
|
ok: bool
|
||||||
|
status: str
|
||||||
|
account_id: str | None = None
|
||||||
|
display_name: str | None = None
|
||||||
|
error: str | None = None
|
||||||
|
metadata: dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
def _optional_string(value: Any) -> str | None:
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
text = str(value).strip()
|
||||||
|
return text or None
|
||||||
@ -0,0 +1,39 @@
|
|||||||
|
"""HTTP client for the generic external connector sidecar."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
|
||||||
|
class ConnectorSidecarClient:
|
||||||
|
def __init__(self, *, base_url: str, token: str, timeout_seconds: float = 20.0) -> None:
|
||||||
|
self.base_url = base_url.rstrip("/")
|
||||||
|
self.token = token
|
||||||
|
self.timeout_seconds = float(timeout_seconds)
|
||||||
|
|
||||||
|
async def get_connectors(self) -> list[dict[str, Any]]:
|
||||||
|
return await self._request("GET", "/connectors")
|
||||||
|
|
||||||
|
async def start_session(self, payload: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
return await self._request("POST", "/connector-sessions", json=payload)
|
||||||
|
|
||||||
|
async def get_session(self, session_id: str) -> dict[str, Any]:
|
||||||
|
return await self._request("GET", f"/connector-sessions/{session_id}")
|
||||||
|
|
||||||
|
async def cancel_session(self, session_id: str) -> dict[str, Any]:
|
||||||
|
return await self._request("POST", f"/connector-sessions/{session_id}/cancel", json={})
|
||||||
|
|
||||||
|
async def logout(self, connection_id: str) -> dict[str, Any]:
|
||||||
|
return await self._request("POST", f"/connections/{connection_id}/logout", json={})
|
||||||
|
|
||||||
|
async def send(self, payload: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
return await self._request("POST", "/send", json=payload)
|
||||||
|
|
||||||
|
async def _request(self, method: str, path: str, *, json: dict[str, Any] | None = None) -> Any:
|
||||||
|
headers = {"Authorization": f"Bearer {self.token}"} if self.token else {}
|
||||||
|
async with httpx.AsyncClient(timeout=self.timeout_seconds) as client:
|
||||||
|
response = await client.request(method, f"{self.base_url}{path}", json=json, headers=headers)
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.json()
|
||||||
@ -0,0 +1,222 @@
|
|||||||
|
"""Persistent channel connection stores."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from threading import Lock
|
||||||
|
from typing import Any
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
from .models import CONNECTION_STATUSES, ChannelConnection, PairingSession, iso_now
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelConnectionStore:
|
||||||
|
def __init__(self, path: Path) -> None:
|
||||||
|
self.path = Path(path)
|
||||||
|
self._lock = Lock()
|
||||||
|
|
||||||
|
def create(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
kind: str,
|
||||||
|
mode: str,
|
||||||
|
display_name: str,
|
||||||
|
account_id: str,
|
||||||
|
owner_user_id: str | None,
|
||||||
|
auth_type: str,
|
||||||
|
runtime_config: dict[str, Any] | None = None,
|
||||||
|
capabilities: list[str] | None = None,
|
||||||
|
credentials_ref: str | None = None,
|
||||||
|
) -> ChannelConnection:
|
||||||
|
with self._lock:
|
||||||
|
data = self._load()
|
||||||
|
connection_id = f"conn_{uuid4().hex}"
|
||||||
|
channel_id = f"{_slug(kind)}-{uuid4().hex[:8]}"
|
||||||
|
now = iso_now()
|
||||||
|
connection = ChannelConnection(
|
||||||
|
connection_id=connection_id,
|
||||||
|
owner_user_id=owner_user_id,
|
||||||
|
channel_id=channel_id,
|
||||||
|
kind=kind,
|
||||||
|
mode=mode,
|
||||||
|
display_name=display_name or channel_id,
|
||||||
|
account_id=account_id,
|
||||||
|
status="draft",
|
||||||
|
auth_type=auth_type,
|
||||||
|
credentials_ref=credentials_ref,
|
||||||
|
runtime_config=runtime_config or {},
|
||||||
|
capabilities=capabilities or [],
|
||||||
|
created_at=now,
|
||||||
|
updated_at=now,
|
||||||
|
)
|
||||||
|
data["connections"][connection_id] = connection.to_dict()
|
||||||
|
self._save(data)
|
||||||
|
return connection
|
||||||
|
|
||||||
|
def get(self, connection_id: str) -> ChannelConnection:
|
||||||
|
data = self._load()
|
||||||
|
raw = data["connections"].get(connection_id)
|
||||||
|
if not isinstance(raw, dict):
|
||||||
|
raise KeyError(connection_id)
|
||||||
|
return ChannelConnection.from_dict(raw)
|
||||||
|
|
||||||
|
def list(self) -> list[ChannelConnection]:
|
||||||
|
data = self._load()
|
||||||
|
return [ChannelConnection.from_dict(item) for item in data["connections"].values() if isinstance(item, dict)]
|
||||||
|
|
||||||
|
def update(self, connection: ChannelConnection) -> ChannelConnection:
|
||||||
|
with self._lock:
|
||||||
|
data = self._load()
|
||||||
|
if connection.connection_id not in data["connections"]:
|
||||||
|
raise KeyError(connection.connection_id)
|
||||||
|
connection.updated_at = iso_now()
|
||||||
|
data["connections"][connection.connection_id] = connection.to_dict()
|
||||||
|
self._save(data)
|
||||||
|
return connection
|
||||||
|
|
||||||
|
def update_status(self, connection_id: str, *, status: str, last_error: str | None) -> ChannelConnection:
|
||||||
|
if status not in CONNECTION_STATUSES:
|
||||||
|
raise ValueError(f"Unsupported connection status: {status}")
|
||||||
|
connection = self.get(connection_id)
|
||||||
|
connection.status = status
|
||||||
|
connection.last_error = last_error
|
||||||
|
if status in {"connected", "running"}:
|
||||||
|
connection.last_seen_at = iso_now()
|
||||||
|
return self.update(connection)
|
||||||
|
|
||||||
|
def revoke(self, connection_id: str) -> ChannelConnection:
|
||||||
|
return self.update_status(connection_id, status="revoked", last_error=None)
|
||||||
|
|
||||||
|
def _load(self) -> dict[str, Any]:
|
||||||
|
if not self.path.exists():
|
||||||
|
return {"connections": {}}
|
||||||
|
try:
|
||||||
|
data = json.loads(self.path.read_text(encoding="utf-8"))
|
||||||
|
except (OSError, json.JSONDecodeError):
|
||||||
|
return {"connections": {}}
|
||||||
|
if not isinstance(data, dict) or not isinstance(data.get("connections"), dict):
|
||||||
|
return {"connections": {}}
|
||||||
|
return data
|
||||||
|
|
||||||
|
def _save(self, data: dict[str, Any]) -> None:
|
||||||
|
self.path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
tmp_path = self.path.with_name(f"{self.path.name}.tmp")
|
||||||
|
tmp_path.write_text(json.dumps(data, ensure_ascii=False, indent=2) + "\n", encoding="utf-8")
|
||||||
|
tmp_path.replace(self.path)
|
||||||
|
|
||||||
|
|
||||||
|
class CredentialStore:
|
||||||
|
def __init__(self, path: Path) -> None:
|
||||||
|
self.path = Path(path)
|
||||||
|
self._lock = Lock()
|
||||||
|
|
||||||
|
def put(self, *, kind: str, values: dict[str, Any]) -> str:
|
||||||
|
cleaned = {str(key): str(value) for key, value in values.items() if str(key).strip() and str(value).strip()}
|
||||||
|
ref = f"cred_{uuid4().hex}"
|
||||||
|
with self._lock:
|
||||||
|
data = self._load()
|
||||||
|
data["credentials"][ref] = {"kind": kind, "values": cleaned, "created_at": iso_now()}
|
||||||
|
self._save(data)
|
||||||
|
return ref
|
||||||
|
|
||||||
|
def get(self, ref: str) -> dict[str, str]:
|
||||||
|
data = self._load()
|
||||||
|
item = data["credentials"].get(ref)
|
||||||
|
if not isinstance(item, dict):
|
||||||
|
raise KeyError(ref)
|
||||||
|
values = item.get("values")
|
||||||
|
if not isinstance(values, dict):
|
||||||
|
return {}
|
||||||
|
return {str(key): str(value) for key, value in values.items()}
|
||||||
|
|
||||||
|
def redacted(self, ref: str | None) -> dict[str, str]:
|
||||||
|
if not ref:
|
||||||
|
return {}
|
||||||
|
try:
|
||||||
|
values = self.get(ref)
|
||||||
|
except KeyError:
|
||||||
|
return {}
|
||||||
|
return {key: "***" for key in values}
|
||||||
|
|
||||||
|
def _load(self) -> dict[str, Any]:
|
||||||
|
if not self.path.exists():
|
||||||
|
return {"credentials": {}}
|
||||||
|
try:
|
||||||
|
data = json.loads(self.path.read_text(encoding="utf-8"))
|
||||||
|
except (OSError, json.JSONDecodeError):
|
||||||
|
return {"credentials": {}}
|
||||||
|
if not isinstance(data, dict) or not isinstance(data.get("credentials"), dict):
|
||||||
|
return {"credentials": {}}
|
||||||
|
return data
|
||||||
|
|
||||||
|
def _save(self, data: dict[str, Any]) -> None:
|
||||||
|
self.path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
tmp_path = self.path.with_name(f"{self.path.name}.tmp")
|
||||||
|
tmp_path.write_text(json.dumps(data, ensure_ascii=False, indent=2) + "\n", encoding="utf-8")
|
||||||
|
tmp_path.replace(self.path)
|
||||||
|
|
||||||
|
|
||||||
|
class PairingTokenStore:
|
||||||
|
def __init__(self, path: Path) -> None:
|
||||||
|
self.path = Path(path)
|
||||||
|
self._lock = Lock()
|
||||||
|
|
||||||
|
def create(self, *, kind: str, ttl_seconds: int, scope: str) -> PairingSession:
|
||||||
|
now_ms = _now_ms()
|
||||||
|
session = PairingSession(
|
||||||
|
pairing_session_id=f"pair_{uuid4().hex}",
|
||||||
|
kind=kind,
|
||||||
|
scope=scope,
|
||||||
|
token=f"pair_{uuid4().hex}",
|
||||||
|
status="pending",
|
||||||
|
expires_at_ms=now_ms + int(ttl_seconds * 1000),
|
||||||
|
created_at_ms=now_ms,
|
||||||
|
)
|
||||||
|
with self._lock:
|
||||||
|
data = self._load()
|
||||||
|
data["sessions"][session.pairing_session_id] = session.to_dict()
|
||||||
|
self._save(data)
|
||||||
|
return session
|
||||||
|
|
||||||
|
def consume(self, token: str, *, expected_kind: str) -> PairingSession | None:
|
||||||
|
with self._lock:
|
||||||
|
data = self._load()
|
||||||
|
for key, raw in data["sessions"].items():
|
||||||
|
session = PairingSession.from_dict(raw)
|
||||||
|
if session.token != token or session.kind != expected_kind:
|
||||||
|
continue
|
||||||
|
if session.status != "pending" or session.expires_at_ms <= _now_ms():
|
||||||
|
return None
|
||||||
|
session.status = "consumed"
|
||||||
|
data["sessions"][key] = session.to_dict()
|
||||||
|
self._save(data)
|
||||||
|
return session
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _load(self) -> dict[str, Any]:
|
||||||
|
if not self.path.exists():
|
||||||
|
return {"sessions": {}}
|
||||||
|
try:
|
||||||
|
data = json.loads(self.path.read_text(encoding="utf-8"))
|
||||||
|
except (OSError, json.JSONDecodeError):
|
||||||
|
return {"sessions": {}}
|
||||||
|
if not isinstance(data, dict) or not isinstance(data.get("sessions"), dict):
|
||||||
|
return {"sessions": {}}
|
||||||
|
return data
|
||||||
|
|
||||||
|
def _save(self, data: dict[str, Any]) -> None:
|
||||||
|
self.path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
tmp_path = self.path.with_name(f"{self.path.name}.tmp")
|
||||||
|
tmp_path.write_text(json.dumps(data, ensure_ascii=False, indent=2) + "\n", encoding="utf-8")
|
||||||
|
tmp_path.replace(self.path)
|
||||||
|
|
||||||
|
|
||||||
|
def _now_ms() -> int:
|
||||||
|
return int(time.time() * 1000)
|
||||||
|
|
||||||
|
|
||||||
|
def _slug(value: str) -> str:
|
||||||
|
text = "".join(char if char.isalnum() else "-" for char in str(value).strip().lower())
|
||||||
|
return "-".join(part for part in text.split("-") if part) or "channel"
|
||||||
@ -0,0 +1,92 @@
|
|||||||
|
"""Telegram channel connector."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Callable
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from .models import ChannelRuntimeSpec, ValidationResult
|
||||||
|
from .store import ChannelConnectionStore, CredentialStore
|
||||||
|
|
||||||
|
|
||||||
|
class TelegramConnector:
|
||||||
|
kind = "telegram"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
connection_store: ChannelConnectionStore,
|
||||||
|
credential_store: CredentialStore,
|
||||||
|
client_factory: Callable[[str], Any] | None = None,
|
||||||
|
) -> None:
|
||||||
|
self.connection_store = connection_store
|
||||||
|
self.credential_store = credential_store
|
||||||
|
self.client_factory = client_factory or _default_client_factory
|
||||||
|
|
||||||
|
async def validate(self, connection_id: str) -> ValidationResult:
|
||||||
|
connection = self.connection_store.get(connection_id)
|
||||||
|
token = self._bot_token(connection.credentials_ref)
|
||||||
|
try:
|
||||||
|
client = self.client_factory(token)
|
||||||
|
raw = await client.get_me()
|
||||||
|
bot_id = _value(raw, "id")
|
||||||
|
username = _value(raw, "username")
|
||||||
|
first_name = _value(raw, "first_name") or "Telegram Bot"
|
||||||
|
account_id = f"telegram:{bot_id}" if bot_id else connection.account_id
|
||||||
|
display_name = f"{first_name} (@{username})" if username else first_name
|
||||||
|
connection.account_id = account_id
|
||||||
|
connection.display_name = display_name
|
||||||
|
connection.capabilities = ["receive_text", "send_text", "receive_media", "groups"]
|
||||||
|
self.connection_store.update(connection)
|
||||||
|
return ValidationResult(
|
||||||
|
ok=True,
|
||||||
|
status="connected",
|
||||||
|
account_id=account_id,
|
||||||
|
display_name=display_name,
|
||||||
|
metadata={"username": username} if username else {},
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
return ValidationResult(ok=False, status="error", error=str(exc))
|
||||||
|
|
||||||
|
async def materialize_runtime(self, connection_id: str) -> ChannelRuntimeSpec:
|
||||||
|
connection = self.connection_store.get(connection_id)
|
||||||
|
if connection.status not in {"connected", "running"}:
|
||||||
|
raise ValueError(f"Connection is not connected: {connection.connection_id}")
|
||||||
|
return ChannelRuntimeSpec(
|
||||||
|
channel_id=connection.channel_id,
|
||||||
|
kind=connection.kind,
|
||||||
|
mode=connection.mode,
|
||||||
|
account_id=connection.account_id,
|
||||||
|
display_name=connection.display_name,
|
||||||
|
config=dict(connection.runtime_config),
|
||||||
|
secrets_ref=connection.credentials_ref,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def revoke(self, connection_id: str) -> None:
|
||||||
|
# Telegram bot tokens do not have a Beaver-managed platform revoke action.
|
||||||
|
# The registry owns local connection state transitions.
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _bot_token(self, credentials_ref: str | None) -> str:
|
||||||
|
if not credentials_ref:
|
||||||
|
raise ValueError("Telegram credentials are missing")
|
||||||
|
token = self.credential_store.get(credentials_ref).get("botToken")
|
||||||
|
if not token:
|
||||||
|
raise ValueError("botToken is required")
|
||||||
|
return token
|
||||||
|
|
||||||
|
|
||||||
|
def _value(raw: Any, key: str) -> str:
|
||||||
|
if isinstance(raw, dict):
|
||||||
|
value = raw.get(key)
|
||||||
|
else:
|
||||||
|
value = getattr(raw, key, None)
|
||||||
|
return str(value).strip() if value is not None else ""
|
||||||
|
|
||||||
|
|
||||||
|
def _default_client_factory(token: str) -> Any:
|
||||||
|
try:
|
||||||
|
from telegram import Bot
|
||||||
|
except ImportError as exc: # pragma: no cover - optional live dependency
|
||||||
|
raise RuntimeError("Install beaver-backend[telegram] to validate Telegram connections") from exc
|
||||||
|
return Bot(token=token)
|
||||||
@ -0,0 +1,97 @@
|
|||||||
|
"""Generic runtime channel backed by an external connector sidecar."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from beaver.foundation.events import OutboundMessage
|
||||||
|
from beaver.interfaces.channels.connections.sidecar_client import ConnectorSidecarClient
|
||||||
|
|
||||||
|
|
||||||
|
class ExternalConnectorChannel:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
channel_id: str,
|
||||||
|
platform_kind: str,
|
||||||
|
connection_id: str,
|
||||||
|
account_id: str,
|
||||||
|
display_name: str,
|
||||||
|
sidecar_client: ConnectorSidecarClient | Any,
|
||||||
|
) -> None:
|
||||||
|
self.channel_id = channel_id
|
||||||
|
self.kind = "external_connector"
|
||||||
|
self.mode = "http"
|
||||||
|
self.platform_kind = platform_kind
|
||||||
|
self.connection_id = connection_id
|
||||||
|
self.account_id = account_id
|
||||||
|
self.display_name = display_name or channel_id
|
||||||
|
self.sidecar_client = sidecar_client
|
||||||
|
self.started = False
|
||||||
|
|
||||||
|
async def start(self) -> None:
|
||||||
|
self.started = True
|
||||||
|
|
||||||
|
async def stop(self) -> None:
|
||||||
|
self.started = False
|
||||||
|
|
||||||
|
async def send(self, message: OutboundMessage) -> None:
|
||||||
|
identity = message.channel_identity
|
||||||
|
if identity is None:
|
||||||
|
raise ValueError("channel_identity is required for external connector sends")
|
||||||
|
metadata = {
|
||||||
|
"inboundMessageId": identity.message_id,
|
||||||
|
"sessionId": message.session_id,
|
||||||
|
}
|
||||||
|
context_token = _context_token(message)
|
||||||
|
if context_token:
|
||||||
|
metadata["contextToken"] = context_token
|
||||||
|
payload = {
|
||||||
|
"requestId": _request_id(message),
|
||||||
|
"connectionId": self.connection_id,
|
||||||
|
"channelId": self.channel_id,
|
||||||
|
"kind": self.platform_kind,
|
||||||
|
"target": {
|
||||||
|
"peerId": identity.peer_id,
|
||||||
|
"peerType": identity.peer_type,
|
||||||
|
"threadId": identity.thread_id,
|
||||||
|
},
|
||||||
|
"content": message.content,
|
||||||
|
"metadata": metadata,
|
||||||
|
}
|
||||||
|
await self.sidecar_client.send(payload)
|
||||||
|
|
||||||
|
|
||||||
|
def _request_id(message: OutboundMessage) -> str:
|
||||||
|
identity = message.channel_identity
|
||||||
|
channel = message.channel or (identity.channel_id if identity else "unknown")
|
||||||
|
session_id = message.session_id or (identity.session_id() if identity else "unknown")
|
||||||
|
message_id = str(message.message_id or "").strip()
|
||||||
|
if not message_id:
|
||||||
|
basis = "|".join(
|
||||||
|
[
|
||||||
|
message.content,
|
||||||
|
identity.message_id if identity and identity.message_id else "",
|
||||||
|
identity.peer_id if identity else "",
|
||||||
|
message.finish_reason,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
message_id = hashlib.sha256(basis.encode("utf-8")).hexdigest()[:24]
|
||||||
|
return f"out_{channel}:{session_id}:{message_id}"
|
||||||
|
|
||||||
|
|
||||||
|
def _context_token(message: OutboundMessage) -> str | None:
|
||||||
|
inbound_metadata = message.metadata.get("inbound_metadata")
|
||||||
|
if isinstance(inbound_metadata, dict):
|
||||||
|
value = _clean_optional(inbound_metadata.get("contextToken") or inbound_metadata.get("context_token"))
|
||||||
|
if value:
|
||||||
|
return value
|
||||||
|
return _clean_optional(message.metadata.get("contextToken") or message.metadata.get("context_token"))
|
||||||
|
|
||||||
|
|
||||||
|
def _clean_optional(value: Any) -> str | None:
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
text = str(value).strip()
|
||||||
|
return text or None
|
||||||
@ -0,0 +1,116 @@
|
|||||||
|
"""Generic fixed-schema text webhook channel adapter."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from beaver.foundation.events import ChannelIdentity, InboundMessage, OutboundMessage
|
||||||
|
from beaver.interfaces.channels.base import ChannelInboundSink
|
||||||
|
|
||||||
|
|
||||||
|
class GenericWebhookAdapter:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
channel_id: str,
|
||||||
|
kind: str,
|
||||||
|
mode: str,
|
||||||
|
account_id: str,
|
||||||
|
display_name: str = "",
|
||||||
|
inbound_sink: ChannelInboundSink,
|
||||||
|
response_timeout_seconds: float = 1800,
|
||||||
|
) -> None:
|
||||||
|
self.channel_id = channel_id
|
||||||
|
self.kind = kind
|
||||||
|
self.mode = mode
|
||||||
|
self.account_id = account_id
|
||||||
|
self.display_name = display_name or channel_id
|
||||||
|
self.inbound_sink = inbound_sink
|
||||||
|
self.response_timeout_seconds = max(1.0, float(response_timeout_seconds))
|
||||||
|
self.started = False
|
||||||
|
self._pending: dict[str, asyncio.Future[OutboundMessage]] = {}
|
||||||
|
|
||||||
|
async def start(self) -> None:
|
||||||
|
self.started = True
|
||||||
|
|
||||||
|
async def stop(self) -> None:
|
||||||
|
self.started = False
|
||||||
|
for future in list(self._pending.values()):
|
||||||
|
if not future.done():
|
||||||
|
future.cancel()
|
||||||
|
self._pending.clear()
|
||||||
|
|
||||||
|
async def handle_webhook_payload(self, payload: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
text = str(payload.get("text") or "").strip()
|
||||||
|
peer_id = str(payload.get("peer_id") or "").strip()
|
||||||
|
message_id = str(payload.get("message_id") or "").strip()
|
||||||
|
thread_id = str(payload.get("thread_id") or "").strip() or None
|
||||||
|
peer_type = str(payload.get("peer_type") or "unknown").strip() or "unknown"
|
||||||
|
user_id = str(payload.get("user_id") or "").strip() or None
|
||||||
|
if not text:
|
||||||
|
return {"ok": False, "error": "text is required"}
|
||||||
|
if not peer_id:
|
||||||
|
return {"ok": False, "error": "peer_id is required"}
|
||||||
|
if not message_id:
|
||||||
|
return {"ok": False, "error": "message_id is required"}
|
||||||
|
|
||||||
|
identity = ChannelIdentity(
|
||||||
|
channel_id=self.channel_id,
|
||||||
|
kind=self.kind,
|
||||||
|
account_id=self.account_id,
|
||||||
|
peer_id=peer_id,
|
||||||
|
thread_id=thread_id,
|
||||||
|
peer_type=peer_type,
|
||||||
|
user_id=user_id,
|
||||||
|
message_id=message_id,
|
||||||
|
)
|
||||||
|
inbound = InboundMessage(
|
||||||
|
channel=self.channel_id,
|
||||||
|
content=text,
|
||||||
|
user_id=user_id,
|
||||||
|
channel_identity=identity,
|
||||||
|
metadata={"webhook": {"peer_type": peer_type}},
|
||||||
|
)
|
||||||
|
future = asyncio.get_running_loop().create_future()
|
||||||
|
self._pending[inbound.message_id] = future
|
||||||
|
accept = await self.inbound_sink.accept_inbound(inbound)
|
||||||
|
if not accept.accepted:
|
||||||
|
self._pending.pop(inbound.message_id, None)
|
||||||
|
record = accept.record or {}
|
||||||
|
return {
|
||||||
|
"ok": accept.error is None,
|
||||||
|
"duplicate": accept.duplicate,
|
||||||
|
"pending": accept.pending,
|
||||||
|
"session_id": accept.session_id,
|
||||||
|
"status": record.get("status"),
|
||||||
|
"run_id": record.get("run_id"),
|
||||||
|
"reply": record.get("reply"),
|
||||||
|
"error": accept.error or record.get("error"),
|
||||||
|
}
|
||||||
|
try:
|
||||||
|
outbound = await asyncio.wait_for(future, timeout=self.response_timeout_seconds)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
self._pending.pop(inbound.message_id, None)
|
||||||
|
return {
|
||||||
|
"ok": True,
|
||||||
|
"duplicate": False,
|
||||||
|
"pending": True,
|
||||||
|
"session_id": accept.session_id,
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
"ok": outbound.finish_reason != "error",
|
||||||
|
"duplicate": False,
|
||||||
|
"pending": False,
|
||||||
|
"session_id": outbound.session_id,
|
||||||
|
"run_id": outbound.run_id,
|
||||||
|
"reply": outbound.content,
|
||||||
|
"error": outbound.metadata.get("error"),
|
||||||
|
}
|
||||||
|
|
||||||
|
async def send(self, message: OutboundMessage) -> None:
|
||||||
|
future = self._pending.pop(message.message_id, None)
|
||||||
|
if future is None or future.done():
|
||||||
|
message.metadata["delivery_status"] = "unclaimed"
|
||||||
|
return
|
||||||
|
future.set_result(message)
|
||||||
@ -3,6 +3,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from collections.abc import Awaitable, Callable
|
||||||
from contextlib import suppress
|
from contextlib import suppress
|
||||||
|
|
||||||
from beaver.foundation.events import MessageBus, OutboundMessage
|
from beaver.foundation.events import MessageBus, OutboundMessage
|
||||||
@ -20,13 +21,17 @@ class ChannelManager:
|
|||||||
self.started = False
|
self.started = False
|
||||||
|
|
||||||
def register(self, channel: ChannelAdapter) -> None:
|
def register(self, channel: ChannelAdapter) -> None:
|
||||||
if self.started:
|
if channel.channel_id in self.channels:
|
||||||
raise RuntimeError("Cannot register channels after ChannelManager.start()")
|
raise ValueError(f"Channel already registered: {channel.channel_id}")
|
||||||
if channel.name in self.channels:
|
self.channels[channel.channel_id] = channel
|
||||||
raise ValueError(f"Channel already registered: {channel.name}")
|
|
||||||
if channel.bus is not self.bus:
|
def unregister(self, channel_id: str) -> ChannelAdapter | None:
|
||||||
raise ValueError("Channel must share the same MessageBus as ChannelManager")
|
return self.channels.pop(channel_id, None)
|
||||||
self.channels[channel.name] = channel
|
|
||||||
|
def replace_registered(self, channel: ChannelAdapter) -> ChannelAdapter | None:
|
||||||
|
old = self.channels.get(channel.channel_id)
|
||||||
|
self.channels[channel.channel_id] = channel
|
||||||
|
return old
|
||||||
|
|
||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
started: list[ChannelAdapter] = []
|
started: list[ChannelAdapter] = []
|
||||||
@ -53,7 +58,13 @@ class ChannelManager:
|
|||||||
if errors:
|
if errors:
|
||||||
raise RuntimeError(f"Failed to stop {len(errors)} channel(s)") from errors[0]
|
raise RuntimeError(f"Failed to stop {len(errors)} channel(s)") from errors[0]
|
||||||
|
|
||||||
async def dispatch_outbound(self, stop_event: asyncio.Event) -> None:
|
async def dispatch_outbound(
|
||||||
|
self,
|
||||||
|
stop_event: asyncio.Event,
|
||||||
|
*,
|
||||||
|
on_delivered: Callable[[OutboundMessage], Awaitable[None]] | None = None,
|
||||||
|
on_failed: Callable[[OutboundMessage, Exception | None], Awaitable[None]] | None = None,
|
||||||
|
) -> None:
|
||||||
"""Route bus outbound messages until stopped and the queue is drained."""
|
"""Route bus outbound messages until stopped and the queue is drained."""
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
@ -68,9 +79,16 @@ class ChannelManager:
|
|||||||
channel = self.channels.get(message.channel)
|
channel = self.channels.get(message.channel)
|
||||||
if channel is None:
|
if channel is None:
|
||||||
self.undeliverable.append(message)
|
self.undeliverable.append(message)
|
||||||
|
if on_failed is not None:
|
||||||
|
await on_failed(message, None)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await channel.send(message)
|
await channel.send(message)
|
||||||
except Exception: # pragma: no cover - defensive channel isolation
|
except Exception as exc: # pragma: no cover - defensive channel isolation
|
||||||
self.undeliverable.append(message)
|
self.undeliverable.append(message)
|
||||||
|
if on_failed is not None:
|
||||||
|
await on_failed(message, exc)
|
||||||
|
else:
|
||||||
|
if on_delivered is not None:
|
||||||
|
await on_delivered(message)
|
||||||
|
|||||||
@ -4,15 +4,27 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from beaver.foundation.events import InboundMessage, MessageBus, OutboundMessage
|
from beaver.foundation.events import ChannelIdentity, InboundMessage, OutboundMessage
|
||||||
|
from beaver.interfaces.channels.base import ChannelInboundSink
|
||||||
|
|
||||||
|
|
||||||
class MemoryChannelAdapter:
|
class MemoryChannelAdapter:
|
||||||
"""A local channel that stores outbound messages in memory."""
|
"""A local channel that stores outbound messages in memory."""
|
||||||
|
|
||||||
def __init__(self, bus: MessageBus, *, name: str = "memory") -> None:
|
def __init__(
|
||||||
self.name = name
|
self,
|
||||||
self.bus = bus
|
inbound_sink: ChannelInboundSink,
|
||||||
|
*,
|
||||||
|
channel_id: str = "memory-dev",
|
||||||
|
kind: str = "memory",
|
||||||
|
mode: str = "webhook",
|
||||||
|
account_id: str = "memory",
|
||||||
|
) -> None:
|
||||||
|
self.channel_id = channel_id
|
||||||
|
self.kind = kind
|
||||||
|
self.mode = mode
|
||||||
|
self.account_id = account_id
|
||||||
|
self.inbound_sink = inbound_sink
|
||||||
self.started = False
|
self.started = False
|
||||||
self.sent_messages: list[OutboundMessage] = []
|
self.sent_messages: list[OutboundMessage] = []
|
||||||
|
|
||||||
@ -36,12 +48,24 @@ class MemoryChannelAdapter:
|
|||||||
model: str | None = None,
|
model: str | None = None,
|
||||||
provider_name: str | None = None,
|
provider_name: str | None = None,
|
||||||
embedding_model: str | None = None,
|
embedding_model: str | None = None,
|
||||||
|
peer_id: str = "default",
|
||||||
|
thread_id: str | None = None,
|
||||||
|
message_id: str | None = None,
|
||||||
metadata: dict[str, Any] | None = None,
|
metadata: dict[str, Any] | None = None,
|
||||||
) -> InboundMessage:
|
) -> InboundMessage:
|
||||||
"""Publish a text message from this channel into the shared bus."""
|
"""Publish a text message from this channel into the shared bus."""
|
||||||
|
|
||||||
|
identity = ChannelIdentity(
|
||||||
|
channel_id=self.channel_id,
|
||||||
|
kind=self.kind,
|
||||||
|
account_id=self.account_id,
|
||||||
|
peer_id=peer_id,
|
||||||
|
thread_id=thread_id,
|
||||||
|
user_id=user_id,
|
||||||
|
message_id=message_id,
|
||||||
|
)
|
||||||
message = InboundMessage(
|
message = InboundMessage(
|
||||||
channel=self.name,
|
channel=self.channel_id,
|
||||||
content=content,
|
content=content,
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
@ -50,9 +74,10 @@ class MemoryChannelAdapter:
|
|||||||
model=model,
|
model=model,
|
||||||
provider_name=provider_name,
|
provider_name=provider_name,
|
||||||
embedding_model=embedding_model,
|
embedding_model=embedding_model,
|
||||||
|
channel_identity=identity,
|
||||||
metadata=metadata or {},
|
metadata=metadata or {},
|
||||||
)
|
)
|
||||||
await self.bus.publish_inbound(message)
|
await self.inbound_sink.accept_inbound(message)
|
||||||
return message
|
return message
|
||||||
|
|
||||||
async def publish_external_text(
|
async def publish_external_text(
|
||||||
@ -73,9 +98,6 @@ class MemoryChannelAdapter:
|
|||||||
the shared gateway bus.
|
the shared gateway bus.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
session_parts = [self.name, chat_id]
|
|
||||||
if thread_id:
|
|
||||||
session_parts.append(thread_id)
|
|
||||||
metadata = {
|
metadata = {
|
||||||
"chat_id": chat_id,
|
"chat_id": chat_id,
|
||||||
"message_id": message_id,
|
"message_id": message_id,
|
||||||
@ -84,8 +106,10 @@ class MemoryChannelAdapter:
|
|||||||
}
|
}
|
||||||
return await self.publish_text(
|
return await self.publish_text(
|
||||||
content,
|
content,
|
||||||
session_id=":".join(str(part) for part in session_parts if str(part)),
|
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
title=title,
|
title=title,
|
||||||
|
peer_id=chat_id,
|
||||||
|
thread_id=thread_id,
|
||||||
|
message_id=message_id,
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -0,0 +1 @@
|
|||||||
|
"""Platform channel adapters."""
|
||||||
@ -0,0 +1,138 @@
|
|||||||
|
"""Shared helpers for platform channel adapters."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from beaver.foundation.events import ChannelIdentity, InboundMessage, OutboundMessage
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class OutboundTarget:
|
||||||
|
peer_id: str | None
|
||||||
|
thread_id: str | None = None
|
||||||
|
peer_type: str = "unknown"
|
||||||
|
user_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class PlatformDeliveryError(RuntimeError):
|
||||||
|
"""Raised when a platform client rejects a delivery."""
|
||||||
|
|
||||||
|
|
||||||
|
def config_bool(config: dict[str, Any], key: str, *, default: bool = False) -> bool:
|
||||||
|
value = config.get(key)
|
||||||
|
if value is None:
|
||||||
|
return default
|
||||||
|
if isinstance(value, bool):
|
||||||
|
return value
|
||||||
|
if isinstance(value, (int, float)):
|
||||||
|
return bool(value)
|
||||||
|
text = str(value).strip().lower()
|
||||||
|
if text in {"1", "true", "yes", "on"}:
|
||||||
|
return True
|
||||||
|
if text in {"0", "false", "no", "off"}:
|
||||||
|
return False
|
||||||
|
return default
|
||||||
|
|
||||||
|
|
||||||
|
def config_list(config: dict[str, Any], key: str) -> list[str]:
|
||||||
|
value = config.get(key)
|
||||||
|
if value is None:
|
||||||
|
return []
|
||||||
|
if isinstance(value, str):
|
||||||
|
return [part.strip() for part in value.split(",") if part.strip()]
|
||||||
|
if isinstance(value, (list, tuple, set)):
|
||||||
|
return [str(item).strip() for item in value if str(item).strip()]
|
||||||
|
text = str(value).strip()
|
||||||
|
return [text] if text else []
|
||||||
|
|
||||||
|
|
||||||
|
def chunk_text(text: str, *, max_chars: int) -> list[str]:
|
||||||
|
if max_chars <= 0:
|
||||||
|
raise ValueError("max_chars must be positive")
|
||||||
|
if not text:
|
||||||
|
return [""]
|
||||||
|
return [text[index : index + max_chars] for index in range(0, len(text), max_chars)]
|
||||||
|
|
||||||
|
|
||||||
|
def compact_media_summary(media_type: str, *, file_name: str | None = None) -> str:
|
||||||
|
label = str(media_type or "attachment").strip() or "attachment"
|
||||||
|
if file_name:
|
||||||
|
return f"[{label}: {file_name}]"
|
||||||
|
return f"[{label}]"
|
||||||
|
|
||||||
|
|
||||||
|
def target_from_session_id(session_id: str | None) -> OutboundTarget:
|
||||||
|
if not session_id:
|
||||||
|
return OutboundTarget(peer_id=None)
|
||||||
|
parts = str(session_id).split(":")
|
||||||
|
if len(parts) < 3:
|
||||||
|
return OutboundTarget(peer_id=None)
|
||||||
|
thread_id = parts[3] if len(parts) > 3 and parts[3] else None
|
||||||
|
return OutboundTarget(peer_id=parts[2] or None, thread_id=thread_id)
|
||||||
|
|
||||||
|
|
||||||
|
def outbound_target(message: OutboundMessage) -> OutboundTarget:
|
||||||
|
identity = message.channel_identity
|
||||||
|
if identity is None:
|
||||||
|
return target_from_session_id(message.session_id)
|
||||||
|
return OutboundTarget(
|
||||||
|
peer_id=identity.peer_id,
|
||||||
|
thread_id=identity.thread_id,
|
||||||
|
peer_type=identity.peer_type,
|
||||||
|
user_id=identity.user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def mark_unclaimed(message: OutboundMessage) -> None:
|
||||||
|
message.metadata["delivery_status"] = "unclaimed"
|
||||||
|
|
||||||
|
|
||||||
|
def build_inbound_message(
|
||||||
|
*,
|
||||||
|
channel_id: str,
|
||||||
|
kind: str,
|
||||||
|
account_id: str,
|
||||||
|
peer_id: str,
|
||||||
|
content: str,
|
||||||
|
message_id: str | None,
|
||||||
|
peer_type: str,
|
||||||
|
user_id: str | None = None,
|
||||||
|
thread_id: str | None = None,
|
||||||
|
metadata: dict[str, Any] | None = None,
|
||||||
|
) -> InboundMessage:
|
||||||
|
identity = ChannelIdentity(
|
||||||
|
channel_id=channel_id,
|
||||||
|
kind=kind,
|
||||||
|
account_id=account_id,
|
||||||
|
peer_id=peer_id,
|
||||||
|
thread_id=thread_id,
|
||||||
|
peer_type=peer_type,
|
||||||
|
user_id=user_id,
|
||||||
|
message_id=message_id,
|
||||||
|
)
|
||||||
|
return InboundMessage(
|
||||||
|
channel=channel_id,
|
||||||
|
content=content,
|
||||||
|
session_id=identity.session_id(),
|
||||||
|
user_id=user_id,
|
||||||
|
message_id=message_id or "",
|
||||||
|
channel_identity=identity,
|
||||||
|
metadata=metadata or {},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def allowed_by_policy(
|
||||||
|
*,
|
||||||
|
policy: str | None,
|
||||||
|
identifier: str | None,
|
||||||
|
allowlist: list[str],
|
||||||
|
default: str = "open",
|
||||||
|
) -> bool:
|
||||||
|
effective = (policy or default).strip().lower()
|
||||||
|
if effective == "disabled":
|
||||||
|
return False
|
||||||
|
if effective == "allowlist":
|
||||||
|
return bool(identifier and identifier in allowlist)
|
||||||
|
return True
|
||||||
@ -0,0 +1,207 @@
|
|||||||
|
"""Feishu/Lark channel adapter."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from collections.abc import Callable
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from beaver.foundation.events import InboundMessage, OutboundMessage
|
||||||
|
from beaver.interfaces.channels.base import ChannelInboundSink
|
||||||
|
|
||||||
|
from .base import (
|
||||||
|
build_inbound_message,
|
||||||
|
chunk_text,
|
||||||
|
compact_media_summary,
|
||||||
|
config_bool,
|
||||||
|
config_list,
|
||||||
|
mark_unclaimed,
|
||||||
|
outbound_target,
|
||||||
|
)
|
||||||
|
|
||||||
|
EventRecorder = Callable[..., None]
|
||||||
|
|
||||||
|
|
||||||
|
class FeishuAdapter:
|
||||||
|
"""Feishu/Lark bot adapter with injectable client support."""
|
||||||
|
|
||||||
|
KIND = "feishu"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
channel_id: str,
|
||||||
|
kind: str,
|
||||||
|
mode: str,
|
||||||
|
account_id: str,
|
||||||
|
display_name: str | None,
|
||||||
|
inbound_sink: ChannelInboundSink,
|
||||||
|
secrets: dict[str, Any] | None = None,
|
||||||
|
config: dict[str, Any] | None = None,
|
||||||
|
event_recorder: EventRecorder | None = None,
|
||||||
|
client: Any | None = None,
|
||||||
|
) -> None:
|
||||||
|
self.channel_id = channel_id
|
||||||
|
self.kind = kind
|
||||||
|
self.mode = mode
|
||||||
|
self.account_id = account_id
|
||||||
|
self.display_name = display_name
|
||||||
|
self.inbound_sink = inbound_sink
|
||||||
|
self.secrets = secrets or {}
|
||||||
|
self.config = config or {}
|
||||||
|
self.event_recorder = event_recorder
|
||||||
|
self._client = client
|
||||||
|
self.max_message_chars = int(self.config.get("maxMessageChars") or 4096)
|
||||||
|
|
||||||
|
async def start(self) -> None:
|
||||||
|
if self._client is not None:
|
||||||
|
return
|
||||||
|
if self.mode not in {"websocket", "webhook"}:
|
||||||
|
raise ValueError(f"Unsupported feishu mode: {self.mode}")
|
||||||
|
self._client = self._build_client()
|
||||||
|
|
||||||
|
async def stop(self) -> None:
|
||||||
|
close = getattr(self._client, "close", None)
|
||||||
|
if close is not None:
|
||||||
|
result = close()
|
||||||
|
if hasattr(result, "__await__"):
|
||||||
|
await result
|
||||||
|
|
||||||
|
async def handle_event_payload(self, payload: dict[str, Any]) -> None:
|
||||||
|
message = self._normalize_payload(payload)
|
||||||
|
if message is None:
|
||||||
|
return
|
||||||
|
await self.inbound_sink.accept_inbound(message)
|
||||||
|
|
||||||
|
async def send(self, message: OutboundMessage) -> None:
|
||||||
|
target = outbound_target(message)
|
||||||
|
if not target.peer_id:
|
||||||
|
mark_unclaimed(message)
|
||||||
|
return
|
||||||
|
client = self._require_client()
|
||||||
|
for chunk in chunk_text(message.content, max_chars=self.max_message_chars):
|
||||||
|
await client.send_text(receive_id_type="chat_id", receive_id=target.peer_id, text=chunk)
|
||||||
|
|
||||||
|
def _normalize_payload(self, payload: dict[str, Any]) -> InboundMessage | None:
|
||||||
|
event = payload.get("event") if isinstance(payload.get("event"), dict) else payload
|
||||||
|
message = event.get("message") if isinstance(event.get("message"), dict) else {}
|
||||||
|
sender = event.get("sender") if isinstance(event.get("sender"), dict) else {}
|
||||||
|
|
||||||
|
peer_id = _string_or_none(message.get("chat_id"))
|
||||||
|
if not peer_id:
|
||||||
|
return None
|
||||||
|
|
||||||
|
message_id = _string_or_none(message.get("message_id"))
|
||||||
|
message_type = str(message.get("message_type") or "unknown")
|
||||||
|
chat_type = str(message.get("chat_type") or "unknown")
|
||||||
|
peer_type = "dm" if chat_type == "p2p" else "group"
|
||||||
|
user_id = _sender_open_id(sender)
|
||||||
|
|
||||||
|
if peer_type == "dm" and not self._dm_allowed(user_id or peer_id):
|
||||||
|
return None
|
||||||
|
if peer_type == "group" and not self._group_allowed(peer_id, user_id):
|
||||||
|
return None
|
||||||
|
if peer_type == "group" and config_bool(self.config, "requireMentionInGroups", default=False):
|
||||||
|
if not self._message_mentions_bot(message):
|
||||||
|
return None
|
||||||
|
|
||||||
|
content = self._message_content(message_type, message)
|
||||||
|
if not content:
|
||||||
|
return None
|
||||||
|
|
||||||
|
metadata = {
|
||||||
|
"chat_id": peer_id,
|
||||||
|
"message_id": message_id,
|
||||||
|
"chat_type": chat_type,
|
||||||
|
"message_type": message_type,
|
||||||
|
}
|
||||||
|
|
||||||
|
return build_inbound_message(
|
||||||
|
channel_id=self.channel_id,
|
||||||
|
kind=self.kind,
|
||||||
|
account_id=self.account_id,
|
||||||
|
peer_id=peer_id,
|
||||||
|
peer_type=peer_type,
|
||||||
|
user_id=user_id,
|
||||||
|
message_id=message_id,
|
||||||
|
content=content,
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _message_content(self, message_type: str, message: dict[str, Any]) -> str:
|
||||||
|
content = _parse_json_object(message.get("content"))
|
||||||
|
if message_type == "text":
|
||||||
|
return str(content.get("text") or "").strip()
|
||||||
|
file_name = _string_or_none(content.get("file_name") or content.get("name"))
|
||||||
|
return compact_media_summary(message_type, file_name=file_name)
|
||||||
|
|
||||||
|
def _message_mentions_bot(self, message: dict[str, Any]) -> bool:
|
||||||
|
bot_open_id = _string_or_none(self.config.get("botOpenId"))
|
||||||
|
if not bot_open_id:
|
||||||
|
return False
|
||||||
|
mentions = message.get("mentions")
|
||||||
|
if not isinstance(mentions, list):
|
||||||
|
return False
|
||||||
|
for mention in mentions:
|
||||||
|
if not isinstance(mention, dict):
|
||||||
|
continue
|
||||||
|
mention_id = mention.get("id") if isinstance(mention.get("id"), dict) else {}
|
||||||
|
if _string_or_none(mention_id.get("open_id")) == bot_open_id:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _dm_allowed(self, identifier: str | None) -> bool:
|
||||||
|
allowlist = config_list(self.config, "allowFrom")
|
||||||
|
if not allowlist:
|
||||||
|
return True
|
||||||
|
return bool(identifier and identifier in allowlist)
|
||||||
|
|
||||||
|
def _group_allowed(self, peer_id: str | None, user_id: str | None) -> bool:
|
||||||
|
allowlist = config_list(self.config, "groupAllowFrom")
|
||||||
|
if not allowlist:
|
||||||
|
return True
|
||||||
|
return bool((peer_id and peer_id in allowlist) or (user_id and user_id in allowlist))
|
||||||
|
|
||||||
|
def _require_client(self) -> Any:
|
||||||
|
if self._client is None:
|
||||||
|
self._client = self._build_client()
|
||||||
|
return self._client
|
||||||
|
|
||||||
|
def _build_client(self) -> Any:
|
||||||
|
self._require_secret("appId")
|
||||||
|
self._require_secret("appSecret")
|
||||||
|
try:
|
||||||
|
import lark_oapi # noqa: F401
|
||||||
|
except ImportError as exc: # pragma: no cover - optional live dependency
|
||||||
|
raise RuntimeError("Install beaver-backend[feishu] to enable FeishuAdapter") from exc
|
||||||
|
raise RuntimeError("Feishu live client is not configured for direct construction")
|
||||||
|
|
||||||
|
def _require_secret(self, key: str) -> str:
|
||||||
|
value = self.secrets.get(key)
|
||||||
|
if not value:
|
||||||
|
raise ValueError(f"{key} is required")
|
||||||
|
return str(value)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_json_object(value: Any) -> dict[str, Any]:
|
||||||
|
if isinstance(value, dict):
|
||||||
|
return value
|
||||||
|
if not isinstance(value, str):
|
||||||
|
return {}
|
||||||
|
try:
|
||||||
|
parsed = json.loads(value)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
return {}
|
||||||
|
return parsed if isinstance(parsed, dict) else {}
|
||||||
|
|
||||||
|
|
||||||
|
def _sender_open_id(sender: dict[str, Any]) -> str | None:
|
||||||
|
sender_id = sender.get("sender_id") if isinstance(sender.get("sender_id"), dict) else {}
|
||||||
|
return _string_or_none(sender_id.get("open_id"))
|
||||||
|
|
||||||
|
|
||||||
|
def _string_or_none(value: Any) -> str | None:
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
text = str(value).strip()
|
||||||
|
return text or None
|
||||||
@ -0,0 +1,206 @@
|
|||||||
|
"""QQ Bot channel adapter."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Callable
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from beaver.foundation.events import InboundMessage, OutboundMessage
|
||||||
|
from beaver.interfaces.channels.base import ChannelInboundSink
|
||||||
|
|
||||||
|
from .base import (
|
||||||
|
allowed_by_policy,
|
||||||
|
build_inbound_message,
|
||||||
|
chunk_text,
|
||||||
|
compact_media_summary,
|
||||||
|
config_list,
|
||||||
|
mark_unclaimed,
|
||||||
|
outbound_target,
|
||||||
|
)
|
||||||
|
|
||||||
|
EventRecorder = Callable[..., None]
|
||||||
|
|
||||||
|
|
||||||
|
class QQBotAdapter:
|
||||||
|
"""QQ Bot API adapter with injectable client support."""
|
||||||
|
|
||||||
|
KIND = "qqbot"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
channel_id: str,
|
||||||
|
kind: str,
|
||||||
|
mode: str,
|
||||||
|
account_id: str,
|
||||||
|
display_name: str | None,
|
||||||
|
inbound_sink: ChannelInboundSink,
|
||||||
|
secrets: dict[str, Any] | None = None,
|
||||||
|
config: dict[str, Any] | None = None,
|
||||||
|
event_recorder: EventRecorder | None = None,
|
||||||
|
client: Any | None = None,
|
||||||
|
) -> None:
|
||||||
|
self.channel_id = channel_id
|
||||||
|
self.kind = kind
|
||||||
|
self.mode = mode
|
||||||
|
self.account_id = account_id
|
||||||
|
self.display_name = display_name
|
||||||
|
self.inbound_sink = inbound_sink
|
||||||
|
self.secrets = secrets or {}
|
||||||
|
self.config = config or {}
|
||||||
|
self.event_recorder = event_recorder
|
||||||
|
self._client = client
|
||||||
|
self.max_message_chars = int(self.config.get("maxMessageChars") or 2000)
|
||||||
|
|
||||||
|
async def start(self) -> None:
|
||||||
|
if self._client is not None:
|
||||||
|
return
|
||||||
|
if self.mode != "websocket":
|
||||||
|
raise ValueError(f"Unsupported qqbot mode: {self.mode}")
|
||||||
|
self._client = self._build_client()
|
||||||
|
|
||||||
|
async def stop(self) -> None:
|
||||||
|
close = getattr(self._client, "close", None)
|
||||||
|
if close is not None:
|
||||||
|
result = close()
|
||||||
|
if hasattr(result, "__await__"):
|
||||||
|
await result
|
||||||
|
|
||||||
|
async def handle_event_payload(self, payload: dict[str, Any]) -> None:
|
||||||
|
message = self._normalize_payload(payload)
|
||||||
|
if message is None:
|
||||||
|
return
|
||||||
|
await self.inbound_sink.accept_inbound(message)
|
||||||
|
|
||||||
|
async def send(self, message: OutboundMessage) -> None:
|
||||||
|
target = outbound_target(message)
|
||||||
|
if not target.peer_id:
|
||||||
|
mark_unclaimed(message)
|
||||||
|
return
|
||||||
|
client = self._require_client()
|
||||||
|
platform_message_id = message.channel_identity.message_id if message.channel_identity else None
|
||||||
|
for chunk in chunk_text(message.content, max_chars=self.max_message_chars):
|
||||||
|
await client.send_text(
|
||||||
|
peer_type=target.peer_type,
|
||||||
|
peer_id=target.peer_id,
|
||||||
|
content=chunk,
|
||||||
|
message_id=platform_message_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _normalize_payload(self, payload: dict[str, Any]) -> InboundMessage | None:
|
||||||
|
event_type = str(payload.get("t") or payload.get("type") or "")
|
||||||
|
data = payload.get("d") if isinstance(payload.get("d"), dict) else payload
|
||||||
|
author = data.get("author") if isinstance(data.get("author"), dict) else {}
|
||||||
|
|
||||||
|
route = self._route(event_type, data, author)
|
||||||
|
if route is None:
|
||||||
|
return None
|
||||||
|
peer_id, peer_type, user_id, thread_id = route
|
||||||
|
|
||||||
|
if peer_type == "dm":
|
||||||
|
if not allowed_by_policy(
|
||||||
|
policy=self.config.get("dmPolicy"),
|
||||||
|
identifier=user_id or peer_id,
|
||||||
|
allowlist=config_list(self.config, "allowFrom"),
|
||||||
|
default="open",
|
||||||
|
):
|
||||||
|
return None
|
||||||
|
elif peer_type == "group":
|
||||||
|
if not allowed_by_policy(
|
||||||
|
policy=self.config.get("groupPolicy"),
|
||||||
|
identifier=peer_id,
|
||||||
|
allowlist=config_list(self.config, "groupAllowFrom"),
|
||||||
|
default="open",
|
||||||
|
):
|
||||||
|
return None
|
||||||
|
|
||||||
|
message_id = _string_or_none(data.get("id"))
|
||||||
|
content = str(data.get("content") or "").strip()
|
||||||
|
media_entries = self._media_entries(data)
|
||||||
|
if media_entries:
|
||||||
|
content = "\n".join([part for part in [content, *media_entries] if part]).strip()
|
||||||
|
if not content:
|
||||||
|
return None
|
||||||
|
|
||||||
|
metadata = {
|
||||||
|
"event_type": event_type,
|
||||||
|
"message_id": message_id,
|
||||||
|
"peer_type": peer_type,
|
||||||
|
}
|
||||||
|
if media_entries:
|
||||||
|
metadata["media"] = media_entries
|
||||||
|
|
||||||
|
return build_inbound_message(
|
||||||
|
channel_id=self.channel_id,
|
||||||
|
kind=self.kind,
|
||||||
|
account_id=self.account_id,
|
||||||
|
peer_id=peer_id,
|
||||||
|
thread_id=thread_id,
|
||||||
|
peer_type=peer_type,
|
||||||
|
user_id=user_id,
|
||||||
|
message_id=message_id,
|
||||||
|
content=content,
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _route(
|
||||||
|
self,
|
||||||
|
event_type: str,
|
||||||
|
data: dict[str, Any],
|
||||||
|
author: dict[str, Any],
|
||||||
|
) -> tuple[str, str, str | None, str | None] | None:
|
||||||
|
if event_type == "C2C_MESSAGE_CREATE":
|
||||||
|
peer_id = _string_or_none(author.get("user_openid"))
|
||||||
|
if not peer_id:
|
||||||
|
return None
|
||||||
|
return peer_id, "dm", peer_id, None
|
||||||
|
if event_type == "GROUP_AT_MESSAGE_CREATE":
|
||||||
|
peer_id = _string_or_none(data.get("group_openid"))
|
||||||
|
if not peer_id:
|
||||||
|
return None
|
||||||
|
return peer_id, "group", _string_or_none(author.get("member_openid")), None
|
||||||
|
if data.get("guild_id") and data.get("channel_id"):
|
||||||
|
peer_id = _string_or_none(data.get("channel_id"))
|
||||||
|
if not peer_id:
|
||||||
|
return None
|
||||||
|
return peer_id, "channel", _string_or_none(author.get("id")), _string_or_none(data.get("guild_id"))
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _media_entries(self, data: dict[str, Any]) -> list[str]:
|
||||||
|
entries: list[str] = []
|
||||||
|
attachments = data.get("attachments")
|
||||||
|
if not isinstance(attachments, list):
|
||||||
|
return entries
|
||||||
|
for attachment in attachments:
|
||||||
|
if not isinstance(attachment, dict):
|
||||||
|
continue
|
||||||
|
media_type = str(attachment.get("content_type") or attachment.get("type") or "attachment")
|
||||||
|
entries.append(compact_media_summary(media_type, file_name=_string_or_none(attachment.get("filename"))))
|
||||||
|
return entries
|
||||||
|
|
||||||
|
def _require_client(self) -> Any:
|
||||||
|
if self._client is None:
|
||||||
|
self._client = self._build_client()
|
||||||
|
return self._client
|
||||||
|
|
||||||
|
def _build_client(self) -> Any:
|
||||||
|
self._require_secret("appId")
|
||||||
|
self._require_secret("clientSecret")
|
||||||
|
try:
|
||||||
|
import aiohttp # noqa: F401
|
||||||
|
except ImportError as exc: # pragma: no cover - optional live dependency
|
||||||
|
raise RuntimeError("Install beaver-backend[qqbot] to enable QQBotAdapter") from exc
|
||||||
|
raise RuntimeError("QQBot live client is not configured for direct construction")
|
||||||
|
|
||||||
|
def _require_secret(self, key: str) -> str:
|
||||||
|
value = self.secrets.get(key)
|
||||||
|
if not value:
|
||||||
|
raise ValueError(f"{key} is required")
|
||||||
|
return str(value)
|
||||||
|
|
||||||
|
|
||||||
|
def _string_or_none(value: Any) -> str | None:
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
text = str(value).strip()
|
||||||
|
return text or None
|
||||||
@ -0,0 +1,244 @@
|
|||||||
|
"""Telegram channel adapter."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Awaitable, Callable
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from beaver.foundation.events import InboundMessage, OutboundMessage
|
||||||
|
from beaver.interfaces.channels.base import ChannelInboundSink
|
||||||
|
|
||||||
|
from .base import (
|
||||||
|
build_inbound_message,
|
||||||
|
chunk_text,
|
||||||
|
compact_media_summary,
|
||||||
|
config_bool,
|
||||||
|
config_list,
|
||||||
|
mark_unclaimed,
|
||||||
|
outbound_target,
|
||||||
|
)
|
||||||
|
|
||||||
|
EventRecorder = Callable[..., None]
|
||||||
|
|
||||||
|
|
||||||
|
class TelegramAdapter:
|
||||||
|
"""Telegram Bot API adapter with injectable client support."""
|
||||||
|
|
||||||
|
KIND = "telegram"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
channel_id: str,
|
||||||
|
kind: str,
|
||||||
|
mode: str,
|
||||||
|
account_id: str,
|
||||||
|
display_name: str | None,
|
||||||
|
inbound_sink: ChannelInboundSink,
|
||||||
|
secrets: dict[str, Any] | None = None,
|
||||||
|
config: dict[str, Any] | None = None,
|
||||||
|
event_recorder: EventRecorder | None = None,
|
||||||
|
client: Any | None = None,
|
||||||
|
application_factory: Callable[[], Any] | None = None,
|
||||||
|
) -> None:
|
||||||
|
self.channel_id = channel_id
|
||||||
|
self.kind = kind
|
||||||
|
self.mode = mode
|
||||||
|
self.account_id = account_id
|
||||||
|
self.display_name = display_name
|
||||||
|
self.inbound_sink = inbound_sink
|
||||||
|
self.secrets = secrets or {}
|
||||||
|
self.config = config or {}
|
||||||
|
self.event_recorder = event_recorder
|
||||||
|
self._client = client
|
||||||
|
self._application_factory = application_factory
|
||||||
|
self._application: Any | None = None
|
||||||
|
self.max_message_chars = int(self.config.get("maxMessageChars") or 4096)
|
||||||
|
|
||||||
|
async def start(self) -> None:
|
||||||
|
if self._client is not None:
|
||||||
|
return
|
||||||
|
if self.mode == "polling":
|
||||||
|
self._application = self._build_application()
|
||||||
|
await self._application.initialize()
|
||||||
|
await self._application.start()
|
||||||
|
if getattr(self._application, "updater", None) is not None:
|
||||||
|
await self._application.updater.start_polling()
|
||||||
|
self._client = self._application.bot
|
||||||
|
return
|
||||||
|
if self.mode == "webhook":
|
||||||
|
self._client = self._build_bot()
|
||||||
|
return
|
||||||
|
raise ValueError(f"Unsupported telegram mode: {self.mode}")
|
||||||
|
|
||||||
|
async def stop(self) -> None:
|
||||||
|
if self._application is None:
|
||||||
|
return
|
||||||
|
updater = getattr(self._application, "updater", None)
|
||||||
|
if updater is not None:
|
||||||
|
await updater.stop()
|
||||||
|
await self._application.stop()
|
||||||
|
await self._application.shutdown()
|
||||||
|
self._application = None
|
||||||
|
|
||||||
|
async def handle_update_payload(self, payload: dict[str, Any]) -> None:
|
||||||
|
message = self._normalize_payload(payload)
|
||||||
|
if message is None:
|
||||||
|
return
|
||||||
|
await self.inbound_sink.accept_inbound(message)
|
||||||
|
|
||||||
|
async def send(self, message: OutboundMessage) -> None:
|
||||||
|
target = outbound_target(message)
|
||||||
|
if not target.peer_id:
|
||||||
|
mark_unclaimed(message)
|
||||||
|
return
|
||||||
|
client = self._require_client()
|
||||||
|
kwargs: dict[str, Any] = {"chat_id": target.peer_id}
|
||||||
|
if target.thread_id:
|
||||||
|
kwargs["message_thread_id"] = int(target.thread_id) if str(target.thread_id).isdigit() else target.thread_id
|
||||||
|
for chunk in chunk_text(message.content, max_chars=self.max_message_chars):
|
||||||
|
await client.send_message(**kwargs, text=chunk)
|
||||||
|
|
||||||
|
def _normalize_payload(self, payload: dict[str, Any]) -> InboundMessage | None:
|
||||||
|
data = payload.get("message") or payload.get("edited_message")
|
||||||
|
if not isinstance(data, dict):
|
||||||
|
return None
|
||||||
|
|
||||||
|
chat = data.get("chat") if isinstance(data.get("chat"), dict) else {}
|
||||||
|
sender = data.get("from") if isinstance(data.get("from"), dict) else {}
|
||||||
|
peer_id = _string_or_none(chat.get("id"))
|
||||||
|
if not peer_id:
|
||||||
|
return None
|
||||||
|
|
||||||
|
chat_type = str(chat.get("type") or "unknown")
|
||||||
|
peer_type = self._peer_type(chat_type)
|
||||||
|
user_id = _string_or_none(sender.get("id"))
|
||||||
|
message_id = _string_or_none(data.get("message_id"))
|
||||||
|
thread_id = _string_or_none(data.get("message_thread_id"))
|
||||||
|
|
||||||
|
content = str(data.get("text") or data.get("caption") or "").strip()
|
||||||
|
media_entries = self._media_entries(data)
|
||||||
|
if media_entries:
|
||||||
|
content = "\n".join([part for part in [content, *media_entries] if part]).strip()
|
||||||
|
if not content:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if peer_type in {"group", "channel"} and not self._group_allowed(peer_id, user_id):
|
||||||
|
return None
|
||||||
|
if peer_type == "dm" and not self._dm_allowed(user_id or peer_id):
|
||||||
|
return None
|
||||||
|
|
||||||
|
if peer_type in {"group", "channel"} and config_bool(self.config, "requireMentionInGroups", default=False):
|
||||||
|
gated = self._strip_required_mention(content)
|
||||||
|
if gated is None:
|
||||||
|
return None
|
||||||
|
content = gated
|
||||||
|
|
||||||
|
metadata = {
|
||||||
|
"chat_id": peer_id,
|
||||||
|
"message_id": message_id,
|
||||||
|
"chat_type": chat_type,
|
||||||
|
}
|
||||||
|
if media_entries:
|
||||||
|
metadata["media"] = media_entries
|
||||||
|
|
||||||
|
return build_inbound_message(
|
||||||
|
channel_id=self.channel_id,
|
||||||
|
kind=self.kind,
|
||||||
|
account_id=self.account_id,
|
||||||
|
peer_id=peer_id,
|
||||||
|
thread_id=thread_id,
|
||||||
|
peer_type=peer_type,
|
||||||
|
user_id=user_id,
|
||||||
|
message_id=message_id,
|
||||||
|
content=content,
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _media_entries(self, data: dict[str, Any]) -> list[str]:
|
||||||
|
entries: list[str] = []
|
||||||
|
if data.get("photo"):
|
||||||
|
entries.append(compact_media_summary("photo"))
|
||||||
|
for media_type in ("document", "audio", "video"):
|
||||||
|
value = data.get(media_type)
|
||||||
|
if isinstance(value, dict):
|
||||||
|
entries.append(compact_media_summary(media_type, file_name=_string_or_none(value.get("file_name"))))
|
||||||
|
return entries
|
||||||
|
|
||||||
|
def _strip_required_mention(self, content: str) -> str | None:
|
||||||
|
username = str(self.config.get("botUsername") or "").strip().lstrip("@")
|
||||||
|
if not username:
|
||||||
|
return None
|
||||||
|
mention = f"@{username}"
|
||||||
|
if mention not in content:
|
||||||
|
return None
|
||||||
|
return content.replace(mention, "", 1).strip()
|
||||||
|
|
||||||
|
def _dm_allowed(self, identifier: str | None) -> bool:
|
||||||
|
allowlist = config_list(self.config, "allowFrom")
|
||||||
|
if not allowlist:
|
||||||
|
return True
|
||||||
|
return bool(identifier and identifier in allowlist)
|
||||||
|
|
||||||
|
def _group_allowed(self, peer_id: str | None, user_id: str | None) -> bool:
|
||||||
|
allowlist = config_list(self.config, "groupAllowFrom")
|
||||||
|
if not allowlist:
|
||||||
|
return True
|
||||||
|
return bool((peer_id and peer_id in allowlist) or (user_id and user_id in allowlist))
|
||||||
|
|
||||||
|
def _peer_type(self, chat_type: str) -> str:
|
||||||
|
if chat_type == "private":
|
||||||
|
return "dm"
|
||||||
|
if chat_type in {"group", "supergroup"}:
|
||||||
|
return "group"
|
||||||
|
if chat_type == "channel":
|
||||||
|
return "channel"
|
||||||
|
return chat_type or "unknown"
|
||||||
|
|
||||||
|
def _require_client(self) -> Any:
|
||||||
|
if self._client is None:
|
||||||
|
self._client = self._build_bot()
|
||||||
|
return self._client
|
||||||
|
|
||||||
|
def _build_bot(self) -> Any:
|
||||||
|
token = self._require_secret("botToken")
|
||||||
|
try:
|
||||||
|
from telegram import Bot
|
||||||
|
except ImportError as exc: # pragma: no cover - optional live dependency
|
||||||
|
raise RuntimeError("Install beaver-backend[telegram] to enable TelegramAdapter") from exc
|
||||||
|
return Bot(token=token)
|
||||||
|
|
||||||
|
def _build_application(self) -> Any:
|
||||||
|
if self._application_factory is not None:
|
||||||
|
return self._application_factory()
|
||||||
|
token = self._require_secret("botToken")
|
||||||
|
try:
|
||||||
|
from telegram.ext import Application
|
||||||
|
except ImportError as exc: # pragma: no cover - optional live dependency
|
||||||
|
raise RuntimeError("Install beaver-backend[telegram] to enable TelegramAdapter") from exc
|
||||||
|
|
||||||
|
async def handle(update: Any, context: Any) -> None:
|
||||||
|
if hasattr(update, "to_dict"):
|
||||||
|
await self.handle_update_payload(update.to_dict())
|
||||||
|
|
||||||
|
application = Application.builder().token(token).build()
|
||||||
|
try:
|
||||||
|
from telegram.ext import MessageHandler, filters
|
||||||
|
|
||||||
|
application.add_handler(MessageHandler(filters.ALL, handle))
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return application
|
||||||
|
|
||||||
|
def _require_secret(self, key: str) -> str:
|
||||||
|
value = self.secrets.get(key)
|
||||||
|
if not value:
|
||||||
|
raise ValueError(f"{key} is required")
|
||||||
|
return str(value)
|
||||||
|
|
||||||
|
|
||||||
|
def _string_or_none(value: Any) -> str | None:
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
text = str(value).strip()
|
||||||
|
return text or None
|
||||||
@ -0,0 +1,180 @@
|
|||||||
|
"""Weixin channel adapter."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Callable
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from beaver.foundation.events import InboundMessage, OutboundMessage
|
||||||
|
from beaver.interfaces.channels.base import ChannelInboundSink
|
||||||
|
|
||||||
|
from .base import (
|
||||||
|
allowed_by_policy,
|
||||||
|
build_inbound_message,
|
||||||
|
chunk_text,
|
||||||
|
compact_media_summary,
|
||||||
|
config_list,
|
||||||
|
mark_unclaimed,
|
||||||
|
outbound_target,
|
||||||
|
)
|
||||||
|
|
||||||
|
EventRecorder = Callable[..., None]
|
||||||
|
|
||||||
|
|
||||||
|
class WeixinAdapter:
|
||||||
|
"""Tencent iLink-style Weixin adapter with injectable client support."""
|
||||||
|
|
||||||
|
KIND = "weixin"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
channel_id: str,
|
||||||
|
kind: str,
|
||||||
|
mode: str,
|
||||||
|
account_id: str,
|
||||||
|
display_name: str | None,
|
||||||
|
inbound_sink: ChannelInboundSink,
|
||||||
|
secrets: dict[str, Any] | None = None,
|
||||||
|
config: dict[str, Any] | None = None,
|
||||||
|
event_recorder: EventRecorder | None = None,
|
||||||
|
client: Any | None = None,
|
||||||
|
) -> None:
|
||||||
|
self.channel_id = channel_id
|
||||||
|
self.kind = kind
|
||||||
|
self.mode = mode
|
||||||
|
self.account_id = account_id
|
||||||
|
self.display_name = display_name
|
||||||
|
self.inbound_sink = inbound_sink
|
||||||
|
self.secrets = secrets or {}
|
||||||
|
self.config = config or {}
|
||||||
|
self.event_recorder = event_recorder
|
||||||
|
self._client = client
|
||||||
|
self.max_message_chars = int(self.config.get("maxMessageChars") or 2000)
|
||||||
|
|
||||||
|
async def start(self) -> None:
|
||||||
|
if self._client is not None:
|
||||||
|
return
|
||||||
|
if self.mode != "polling":
|
||||||
|
raise ValueError(f"Unsupported weixin mode: {self.mode}")
|
||||||
|
self._client = self._build_client()
|
||||||
|
|
||||||
|
async def stop(self) -> None:
|
||||||
|
close = getattr(self._client, "close", None)
|
||||||
|
if close is not None:
|
||||||
|
result = close()
|
||||||
|
if hasattr(result, "__await__"):
|
||||||
|
await result
|
||||||
|
|
||||||
|
async def handle_message_payload(self, payload: dict[str, Any]) -> None:
|
||||||
|
message = self._normalize_payload(payload)
|
||||||
|
if message is None:
|
||||||
|
return
|
||||||
|
await self.inbound_sink.accept_inbound(message)
|
||||||
|
|
||||||
|
async def send(self, message: OutboundMessage) -> None:
|
||||||
|
target = outbound_target(message)
|
||||||
|
if not target.peer_id:
|
||||||
|
mark_unclaimed(message)
|
||||||
|
return
|
||||||
|
client = self._require_client()
|
||||||
|
context_token = self._context_token(message)
|
||||||
|
for chunk in chunk_text(message.content, max_chars=self.max_message_chars):
|
||||||
|
await client.send_text(peer_id=target.peer_id, text=chunk, context_token=context_token)
|
||||||
|
|
||||||
|
def _normalize_payload(self, payload: dict[str, Any]) -> InboundMessage | None:
|
||||||
|
sender_id = _string_or_none(payload.get("from") or payload.get("from_user"))
|
||||||
|
room_id = _string_or_none(payload.get("room_id") or payload.get("roomId"))
|
||||||
|
message_id = _string_or_none(payload.get("id") or payload.get("message_id"))
|
||||||
|
message_type = str(payload.get("type") or payload.get("message_type") or "text")
|
||||||
|
|
||||||
|
if room_id:
|
||||||
|
peer_id = room_id
|
||||||
|
peer_type = "group"
|
||||||
|
user_id = sender_id
|
||||||
|
if not allowed_by_policy(
|
||||||
|
policy=self.config.get("groupPolicy"),
|
||||||
|
identifier=peer_id,
|
||||||
|
allowlist=config_list(self.config, "groupAllowFrom"),
|
||||||
|
default="disabled",
|
||||||
|
):
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
peer_id = sender_id
|
||||||
|
peer_type = "dm"
|
||||||
|
user_id = sender_id
|
||||||
|
if not allowed_by_policy(
|
||||||
|
policy=self.config.get("dmPolicy"),
|
||||||
|
identifier=peer_id,
|
||||||
|
allowlist=config_list(self.config, "allowFrom"),
|
||||||
|
default="open",
|
||||||
|
):
|
||||||
|
return None
|
||||||
|
if not peer_id:
|
||||||
|
return None
|
||||||
|
|
||||||
|
content = self._content(message_type, payload)
|
||||||
|
if not content:
|
||||||
|
return None
|
||||||
|
|
||||||
|
metadata = {
|
||||||
|
"message_id": message_id,
|
||||||
|
"message_type": message_type,
|
||||||
|
}
|
||||||
|
context_token = _string_or_none(payload.get("context_token") or payload.get("contextToken"))
|
||||||
|
if context_token:
|
||||||
|
metadata["context_token"] = context_token
|
||||||
|
if room_id:
|
||||||
|
metadata["room_id"] = room_id
|
||||||
|
|
||||||
|
return build_inbound_message(
|
||||||
|
channel_id=self.channel_id,
|
||||||
|
kind=self.kind,
|
||||||
|
account_id=self.account_id,
|
||||||
|
peer_id=peer_id,
|
||||||
|
peer_type=peer_type,
|
||||||
|
user_id=user_id,
|
||||||
|
message_id=message_id,
|
||||||
|
content=content,
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _content(self, message_type: str, payload: dict[str, Any]) -> str:
|
||||||
|
if message_type == "text":
|
||||||
|
return str(payload.get("text") or payload.get("content") or "").strip()
|
||||||
|
file_name = _string_or_none(payload.get("file_name") or payload.get("filename"))
|
||||||
|
return compact_media_summary(message_type, file_name=file_name)
|
||||||
|
|
||||||
|
def _context_token(self, message: OutboundMessage) -> str | None:
|
||||||
|
inbound_metadata = message.metadata.get("inbound_metadata")
|
||||||
|
if isinstance(inbound_metadata, dict):
|
||||||
|
value = _string_or_none(inbound_metadata.get("context_token"))
|
||||||
|
if value:
|
||||||
|
return value
|
||||||
|
return _string_or_none(message.metadata.get("context_token"))
|
||||||
|
|
||||||
|
def _require_client(self) -> Any:
|
||||||
|
if self._client is None:
|
||||||
|
self._client = self._build_client()
|
||||||
|
return self._client
|
||||||
|
|
||||||
|
def _build_client(self) -> Any:
|
||||||
|
self._require_secret("token")
|
||||||
|
try:
|
||||||
|
import aiohttp # noqa: F401
|
||||||
|
except ImportError as exc: # pragma: no cover - optional live dependency
|
||||||
|
raise RuntimeError("Install beaver-backend[weixin] to enable WeixinAdapter") from exc
|
||||||
|
raise RuntimeError("Weixin live client is not configured for direct construction")
|
||||||
|
|
||||||
|
def _require_secret(self, key: str) -> str:
|
||||||
|
value = self.secrets.get(key)
|
||||||
|
if not value:
|
||||||
|
raise ValueError(f"{key} is required")
|
||||||
|
return str(value)
|
||||||
|
|
||||||
|
|
||||||
|
def _string_or_none(value: Any) -> str | None:
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
text = str(value).strip()
|
||||||
|
return text or None
|
||||||
526
app-instance/backend/beaver/interfaces/channels/runtime.py
Normal file
526
app-instance/backend/beaver/interfaces/channels/runtime.py
Normal file
@ -0,0 +1,526 @@
|
|||||||
|
"""Channel runtime host for adapter lifecycle and bus-first routing."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from beaver.foundation.config.schema import ChannelConfig
|
||||||
|
from beaver.foundation.events import InboundMessage, MessageBus, OutboundMessage
|
||||||
|
from beaver.interfaces.channels.base import ChannelAdapter
|
||||||
|
from beaver.interfaces.channels.manager import ChannelManager
|
||||||
|
from beaver.interfaces.channels.state import ChannelDedupeStore, ChannelEventLog
|
||||||
|
from beaver.services.agent_service import AgentService
|
||||||
|
|
||||||
|
|
||||||
|
def _iso_now() -> str:
|
||||||
|
return datetime.now(timezone.utc).isoformat()
|
||||||
|
|
||||||
|
|
||||||
|
def _channel_capabilities(kind: str, mode: str) -> list[str]:
|
||||||
|
if kind == "webhook":
|
||||||
|
return ["receive_text", "send_text", "sync_webhook_response"]
|
||||||
|
if kind == "terminal" and mode == "websocket":
|
||||||
|
return ["receive_text", "send_text", "persistent_connection"]
|
||||||
|
if kind in {"feishu", "qqbot", "telegram"}:
|
||||||
|
return ["receive_text", "send_text", "receive_media", "groups"]
|
||||||
|
if kind == "weixin":
|
||||||
|
return ["receive_text", "send_text", "receive_media", "direct_messages"]
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class ChannelAcceptResult:
|
||||||
|
accepted: bool
|
||||||
|
duplicate: bool = False
|
||||||
|
pending: bool = False
|
||||||
|
rejected: bool = False
|
||||||
|
session_id: str | None = None
|
||||||
|
dedupe_key: str | None = None
|
||||||
|
record: dict[str, Any] | None = None
|
||||||
|
error: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelRuntime:
|
||||||
|
"""Own channel adapters, state, and the inbound/outbound bus bridge."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
service: AgentService,
|
||||||
|
workspace: Path,
|
||||||
|
channels: dict[str, ChannelConfig],
|
||||||
|
bus: MessageBus | None = None,
|
||||||
|
) -> None:
|
||||||
|
self.service = service
|
||||||
|
self.workspace = Path(workspace)
|
||||||
|
self.bus = bus or MessageBus()
|
||||||
|
self.manager = ChannelManager(self.bus)
|
||||||
|
self.channel_configs = dict(channels)
|
||||||
|
self.adapters: dict[str, ChannelAdapter] = {}
|
||||||
|
self.states: dict[str, dict[str, Any]] = {}
|
||||||
|
state_dir = self.workspace / "state" / "channels"
|
||||||
|
retention = self._default_dedupe_retention_hours()
|
||||||
|
self.dedupe = ChannelDedupeStore(state_dir / "dedupe.json", retention_hours=retention)
|
||||||
|
self.events = ChannelEventLog(state_dir / "events.jsonl")
|
||||||
|
self._bridge_task: asyncio.Task[None] | None = None
|
||||||
|
self._dispatch_task: asyncio.Task[None] | None = None
|
||||||
|
self._stop_event = asyncio.Event()
|
||||||
|
self._dispatch_stop_event = asyncio.Event()
|
||||||
|
self._lifecycle_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
async def start(self) -> None:
|
||||||
|
self._stop_event.clear()
|
||||||
|
self._dispatch_stop_event.clear()
|
||||||
|
for channel_id, cfg in self.channel_configs.items():
|
||||||
|
if not cfg.enabled:
|
||||||
|
self.states[channel_id] = {"state": "disabled", "last_error": None}
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
adapter = self._build_adapter(channel_id, cfg)
|
||||||
|
self.adapters[channel_id] = adapter
|
||||||
|
self.manager.register(adapter)
|
||||||
|
await adapter.start()
|
||||||
|
self.states[channel_id] = {
|
||||||
|
"state": "running",
|
||||||
|
"last_error": None,
|
||||||
|
"started_at": _iso_now(),
|
||||||
|
}
|
||||||
|
self.events.record(channel_id=channel_id, kind="adapter_started")
|
||||||
|
except Exception as exc: # pragma: no cover - defensive startup isolation
|
||||||
|
self.states[channel_id] = {"state": "error", "last_error": str(exc)}
|
||||||
|
self.events.record(
|
||||||
|
channel_id=channel_id,
|
||||||
|
kind="adapter_error",
|
||||||
|
status="error",
|
||||||
|
error=str(exc),
|
||||||
|
)
|
||||||
|
self._bridge_task = asyncio.create_task(self._bridge_inbound_to_agent())
|
||||||
|
self._dispatch_task = asyncio.create_task(
|
||||||
|
self.manager.dispatch_outbound(
|
||||||
|
self._dispatch_stop_event,
|
||||||
|
on_delivered=self._record_outbound_delivered,
|
||||||
|
on_failed=self._record_outbound_failed,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def stop(self) -> None:
|
||||||
|
self._stop_event.set()
|
||||||
|
if self._bridge_task is not None:
|
||||||
|
self._bridge_task.cancel()
|
||||||
|
try:
|
||||||
|
await self._bridge_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
self._dispatch_stop_event.set()
|
||||||
|
if self._dispatch_task is not None:
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(self._dispatch_task, timeout=1.0)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
self._dispatch_task.cancel()
|
||||||
|
try:
|
||||||
|
await self._dispatch_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
await self.manager.stop()
|
||||||
|
for channel_id in self.adapters:
|
||||||
|
self.events.record(channel_id=channel_id, kind="adapter_stopped")
|
||||||
|
|
||||||
|
async def add_channel(self, channel_id: str, config: ChannelConfig) -> None:
|
||||||
|
async with self._lifecycle_lock:
|
||||||
|
current = self.channel_configs.get(channel_id)
|
||||||
|
if current == config and channel_id in self.adapters:
|
||||||
|
return
|
||||||
|
if not config.enabled:
|
||||||
|
await self._remove_channel_locked(channel_id)
|
||||||
|
self.channel_configs[channel_id] = config
|
||||||
|
self.states[channel_id] = {"state": "disabled", "last_error": None}
|
||||||
|
return
|
||||||
|
|
||||||
|
adapter = self._build_adapter(channel_id, config)
|
||||||
|
await adapter.start()
|
||||||
|
old_adapter = self.adapters.get(channel_id)
|
||||||
|
self.manager.replace_registered(adapter)
|
||||||
|
self.adapters[channel_id] = adapter
|
||||||
|
self.channel_configs[channel_id] = config
|
||||||
|
self.states[channel_id] = {"state": "running", "last_error": None, "started_at": _iso_now()}
|
||||||
|
self.events.record(channel_id=channel_id, kind="adapter_started")
|
||||||
|
if old_adapter is not None and old_adapter is not adapter:
|
||||||
|
await old_adapter.stop()
|
||||||
|
|
||||||
|
async def remove_channel(self, channel_id: str) -> None:
|
||||||
|
async with self._lifecycle_lock:
|
||||||
|
await self._remove_channel_locked(channel_id)
|
||||||
|
|
||||||
|
async def _remove_channel_locked(self, channel_id: str) -> None:
|
||||||
|
adapter = self.adapters.pop(channel_id, None)
|
||||||
|
self.manager.unregister(channel_id)
|
||||||
|
self.channel_configs.pop(channel_id, None)
|
||||||
|
if adapter is not None:
|
||||||
|
await adapter.stop()
|
||||||
|
self.events.record(channel_id=channel_id, kind="adapter_stopped")
|
||||||
|
self.states[channel_id] = {"state": "removed", "last_error": None}
|
||||||
|
|
||||||
|
async def accept_inbound(self, message: InboundMessage) -> ChannelAcceptResult:
|
||||||
|
identity = message.channel_identity
|
||||||
|
if identity is None:
|
||||||
|
self.events.record(
|
||||||
|
channel_id=message.channel,
|
||||||
|
kind="inbound_rejected",
|
||||||
|
status="error",
|
||||||
|
error="channel_identity is required",
|
||||||
|
)
|
||||||
|
return ChannelAcceptResult(
|
||||||
|
accepted=False,
|
||||||
|
rejected=True,
|
||||||
|
error="channel_identity is required",
|
||||||
|
)
|
||||||
|
|
||||||
|
validation_error = identity.validation_error()
|
||||||
|
if validation_error:
|
||||||
|
self.events.record(
|
||||||
|
channel_id=identity.channel_id,
|
||||||
|
kind="inbound_rejected",
|
||||||
|
status="error",
|
||||||
|
error=validation_error,
|
||||||
|
)
|
||||||
|
return ChannelAcceptResult(accepted=False, rejected=True, error=validation_error)
|
||||||
|
|
||||||
|
expected_session_id = identity.session_id()
|
||||||
|
if message.session_id != expected_session_id:
|
||||||
|
self.events.record(
|
||||||
|
channel_id=identity.channel_id,
|
||||||
|
kind="session_id_normalized",
|
||||||
|
session_id=expected_session_id,
|
||||||
|
message_id=identity.message_id,
|
||||||
|
)
|
||||||
|
message.session_id = expected_session_id
|
||||||
|
message.channel = identity.channel_id
|
||||||
|
|
||||||
|
dedupe_key = identity.dedupe_key()
|
||||||
|
if dedupe_key:
|
||||||
|
write = self.dedupe.mark_processing(
|
||||||
|
dedupe_key=dedupe_key,
|
||||||
|
session_id=expected_session_id,
|
||||||
|
message_id=identity.message_id or "",
|
||||||
|
)
|
||||||
|
if not write.created:
|
||||||
|
record = write.record or {}
|
||||||
|
self.events.record(
|
||||||
|
channel_id=identity.channel_id,
|
||||||
|
kind="inbound_duplicate",
|
||||||
|
session_id=expected_session_id,
|
||||||
|
message_id=identity.message_id,
|
||||||
|
status=str(record.get("status") or "processing"),
|
||||||
|
)
|
||||||
|
return ChannelAcceptResult(
|
||||||
|
accepted=False,
|
||||||
|
duplicate=True,
|
||||||
|
pending=record.get("status") == "processing",
|
||||||
|
session_id=expected_session_id,
|
||||||
|
dedupe_key=dedupe_key,
|
||||||
|
record=record,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.events.record(
|
||||||
|
channel_id=identity.channel_id,
|
||||||
|
kind="inbound_accepted",
|
||||||
|
session_id=expected_session_id,
|
||||||
|
message_id=identity.message_id,
|
||||||
|
text=message.content,
|
||||||
|
)
|
||||||
|
await self.bus.publish_inbound(message)
|
||||||
|
return ChannelAcceptResult(
|
||||||
|
accepted=True,
|
||||||
|
session_id=expected_session_id,
|
||||||
|
dedupe_key=dedupe_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
def statuses(self) -> list[dict[str, Any]]:
|
||||||
|
items: list[dict[str, Any]] = []
|
||||||
|
recent = self.events.recent(limit=500)
|
||||||
|
last_by_channel = {event["channel_id"]: event for event in recent if event.get("channel_id")}
|
||||||
|
for channel_id, cfg in self.channel_configs.items():
|
||||||
|
state = self.states.get(channel_id, {"state": "configured", "last_error": None})
|
||||||
|
capabilities = _channel_capabilities(cfg.kind, cfg.mode)
|
||||||
|
webhook_url = None
|
||||||
|
websocket_url = None
|
||||||
|
connected_peers = 0
|
||||||
|
if cfg.kind == "webhook":
|
||||||
|
webhook_url = f"/api/channels/{channel_id}/webhook"
|
||||||
|
elif cfg.kind == "terminal" and cfg.mode == "websocket":
|
||||||
|
websocket_url = f"/api/channels/{channel_id}/ws"
|
||||||
|
adapter = self.adapters.get(channel_id)
|
||||||
|
if adapter is not None and hasattr(adapter, "status_extra"):
|
||||||
|
extra = adapter.status_extra() # type: ignore[attr-defined]
|
||||||
|
connected_peers = int(extra.get("connected_peers") or 0)
|
||||||
|
items.append(
|
||||||
|
{
|
||||||
|
"channel_id": channel_id,
|
||||||
|
"name": channel_id,
|
||||||
|
"kind": cfg.kind,
|
||||||
|
"mode": cfg.mode,
|
||||||
|
"display_name": cfg.display_name or channel_id,
|
||||||
|
"enabled": cfg.enabled,
|
||||||
|
"state": state.get("state", "configured"),
|
||||||
|
"account_id": cfg.account_id,
|
||||||
|
"last_error": state.get("last_error"),
|
||||||
|
"started_at": state.get("started_at"),
|
||||||
|
"last_event_at": last_by_channel.get(channel_id, {}).get("created_at"),
|
||||||
|
"capabilities": capabilities,
|
||||||
|
"webhook_url": webhook_url,
|
||||||
|
"websocket_url": websocket_url,
|
||||||
|
"connected_peers": connected_peers,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return items
|
||||||
|
|
||||||
|
def recent_events(self, channel_id: str, *, limit: int = 100) -> list[dict[str, Any]]:
|
||||||
|
return self.events.recent(channel_id=channel_id, limit=limit)
|
||||||
|
|
||||||
|
def record_event(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
channel_id: str,
|
||||||
|
kind: str,
|
||||||
|
session_id: str | None = None,
|
||||||
|
message_id: str | None = None,
|
||||||
|
run_id: str | None = None,
|
||||||
|
status: str = "ok",
|
||||||
|
error: str | None = None,
|
||||||
|
metadata: dict[str, Any] | None = None,
|
||||||
|
) -> None:
|
||||||
|
self.events.record(
|
||||||
|
channel_id=channel_id,
|
||||||
|
kind=kind,
|
||||||
|
session_id=session_id,
|
||||||
|
message_id=message_id,
|
||||||
|
run_id=run_id,
|
||||||
|
status=status,
|
||||||
|
error=error,
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _build_adapter(self, channel_id: str, cfg: ChannelConfig) -> ChannelAdapter:
|
||||||
|
if cfg.kind == "webhook" and cfg.mode == "webhook":
|
||||||
|
from beaver.interfaces.channels.generic_webhook import GenericWebhookAdapter
|
||||||
|
|
||||||
|
return GenericWebhookAdapter(
|
||||||
|
channel_id=channel_id,
|
||||||
|
kind=cfg.kind,
|
||||||
|
mode=cfg.mode,
|
||||||
|
account_id=cfg.account_id,
|
||||||
|
display_name=cfg.display_name,
|
||||||
|
inbound_sink=self,
|
||||||
|
response_timeout_seconds=float(cfg.config.get("response_timeout_seconds") or 1800),
|
||||||
|
)
|
||||||
|
|
||||||
|
if cfg.kind == "terminal" and cfg.mode == "websocket":
|
||||||
|
from beaver.interfaces.channels.terminal_websocket import TerminalWebSocketAdapter
|
||||||
|
|
||||||
|
return TerminalWebSocketAdapter(
|
||||||
|
channel_id=channel_id,
|
||||||
|
kind=cfg.kind,
|
||||||
|
mode=cfg.mode,
|
||||||
|
account_id=cfg.account_id,
|
||||||
|
display_name=cfg.display_name,
|
||||||
|
inbound_sink=self,
|
||||||
|
event_recorder=self.record_event,
|
||||||
|
heartbeat_seconds=float(cfg.config.get("heartbeat_seconds") or 30),
|
||||||
|
max_message_chars=int(cfg.config.get("max_message_chars") or 20000),
|
||||||
|
)
|
||||||
|
|
||||||
|
if cfg.kind == "telegram" and cfg.mode in {"polling", "webhook"}:
|
||||||
|
from beaver.interfaces.channels.platforms.telegram import TelegramAdapter
|
||||||
|
|
||||||
|
return TelegramAdapter(
|
||||||
|
channel_id=channel_id,
|
||||||
|
kind=cfg.kind,
|
||||||
|
mode=cfg.mode,
|
||||||
|
account_id=cfg.account_id,
|
||||||
|
display_name=cfg.display_name,
|
||||||
|
inbound_sink=self,
|
||||||
|
secrets=cfg.secrets,
|
||||||
|
config=cfg.config,
|
||||||
|
event_recorder=self.record_event,
|
||||||
|
)
|
||||||
|
|
||||||
|
if cfg.kind == "feishu" and cfg.mode in {"websocket", "webhook"}:
|
||||||
|
from beaver.interfaces.channels.platforms.feishu import FeishuAdapter
|
||||||
|
|
||||||
|
return FeishuAdapter(
|
||||||
|
channel_id=channel_id,
|
||||||
|
kind=cfg.kind,
|
||||||
|
mode=cfg.mode,
|
||||||
|
account_id=cfg.account_id,
|
||||||
|
display_name=cfg.display_name,
|
||||||
|
inbound_sink=self,
|
||||||
|
secrets=cfg.secrets,
|
||||||
|
config=cfg.config,
|
||||||
|
event_recorder=self.record_event,
|
||||||
|
)
|
||||||
|
|
||||||
|
if cfg.kind == "qqbot" and cfg.mode == "websocket":
|
||||||
|
from beaver.interfaces.channels.platforms.qqbot import QQBotAdapter
|
||||||
|
|
||||||
|
return QQBotAdapter(
|
||||||
|
channel_id=channel_id,
|
||||||
|
kind=cfg.kind,
|
||||||
|
mode=cfg.mode,
|
||||||
|
account_id=cfg.account_id,
|
||||||
|
display_name=cfg.display_name,
|
||||||
|
inbound_sink=self,
|
||||||
|
secrets=cfg.secrets,
|
||||||
|
config=cfg.config,
|
||||||
|
event_recorder=self.record_event,
|
||||||
|
)
|
||||||
|
|
||||||
|
if cfg.kind == "weixin" and cfg.mode == "polling":
|
||||||
|
from beaver.interfaces.channels.platforms.weixin import WeixinAdapter
|
||||||
|
|
||||||
|
return WeixinAdapter(
|
||||||
|
channel_id=channel_id,
|
||||||
|
kind=cfg.kind,
|
||||||
|
mode=cfg.mode,
|
||||||
|
account_id=cfg.account_id,
|
||||||
|
display_name=cfg.display_name,
|
||||||
|
inbound_sink=self,
|
||||||
|
secrets=cfg.secrets,
|
||||||
|
config=cfg.config,
|
||||||
|
event_recorder=self.record_event,
|
||||||
|
)
|
||||||
|
|
||||||
|
if cfg.kind == "external_connector" and cfg.mode == "http":
|
||||||
|
import os
|
||||||
|
|
||||||
|
from beaver.interfaces.channels.connections.sidecar_client import ConnectorSidecarClient
|
||||||
|
from beaver.interfaces.channels.external_connector import ExternalConnectorChannel
|
||||||
|
|
||||||
|
base_url = str(cfg.config.get("sidecarBaseUrl") or os.getenv("EXTERNAL_CONNECTOR_BASE_URL") or "").strip()
|
||||||
|
token = os.getenv("EXTERNAL_CONNECTOR_TOKEN", "")
|
||||||
|
platform_kind = str(cfg.config.get("platformKind") or "").strip()
|
||||||
|
connection_id = str(cfg.config.get("connectionId") or "").strip()
|
||||||
|
if not base_url:
|
||||||
|
raise ValueError("external connector sidecarBaseUrl is required")
|
||||||
|
if not platform_kind:
|
||||||
|
raise ValueError("external connector platformKind is required")
|
||||||
|
if not connection_id:
|
||||||
|
raise ValueError("external connector connectionId is required")
|
||||||
|
return ExternalConnectorChannel(
|
||||||
|
channel_id=channel_id,
|
||||||
|
platform_kind=platform_kind,
|
||||||
|
connection_id=connection_id,
|
||||||
|
account_id=cfg.account_id,
|
||||||
|
display_name=cfg.display_name,
|
||||||
|
sidecar_client=ConnectorSidecarClient(base_url=base_url, token=token),
|
||||||
|
)
|
||||||
|
|
||||||
|
raise ValueError(f"Unsupported channel kind/mode: {cfg.kind}/{cfg.mode}")
|
||||||
|
|
||||||
|
async def _bridge_inbound_to_agent(self) -> None:
|
||||||
|
current_inbound: InboundMessage | None = None
|
||||||
|
while not self._stop_event.is_set():
|
||||||
|
try:
|
||||||
|
current_inbound = await asyncio.wait_for(self.bus.consume_inbound(), timeout=0.25)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
continue
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
|
inbound = current_inbound
|
||||||
|
identity = inbound.channel_identity
|
||||||
|
try:
|
||||||
|
self.events.record(
|
||||||
|
channel_id=inbound.channel,
|
||||||
|
kind="direct_run_started",
|
||||||
|
session_id=inbound.session_id,
|
||||||
|
message_id=identity.message_id if identity else inbound.message_id,
|
||||||
|
)
|
||||||
|
outbound = await self.service.handle_inbound_message(inbound)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
outbound = AgentService.build_outbound_error(
|
||||||
|
inbound,
|
||||||
|
detail="Channel runtime stopped before completing the inbound message",
|
||||||
|
finish_reason="cancelled",
|
||||||
|
)
|
||||||
|
self._mark_dedupe_result(inbound, outbound)
|
||||||
|
await self.bus.publish_outbound(outbound)
|
||||||
|
current_inbound = None
|
||||||
|
raise
|
||||||
|
except Exception as exc:
|
||||||
|
self.events.record(
|
||||||
|
channel_id=inbound.channel,
|
||||||
|
kind="direct_run_failed",
|
||||||
|
session_id=inbound.session_id,
|
||||||
|
message_id=identity.message_id if identity else inbound.message_id,
|
||||||
|
status="error",
|
||||||
|
error=str(exc),
|
||||||
|
)
|
||||||
|
outbound = AgentService.build_outbound_error(
|
||||||
|
inbound,
|
||||||
|
detail=str(exc),
|
||||||
|
finish_reason="error",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.events.record(
|
||||||
|
channel_id=outbound.channel,
|
||||||
|
kind="direct_run_finished",
|
||||||
|
session_id=outbound.session_id,
|
||||||
|
message_id=identity.message_id if identity else inbound.message_id,
|
||||||
|
run_id=outbound.run_id,
|
||||||
|
)
|
||||||
|
self._mark_dedupe_result(inbound, outbound)
|
||||||
|
await self.bus.publish_outbound(outbound)
|
||||||
|
current_inbound = None
|
||||||
|
|
||||||
|
def _mark_dedupe_result(self, inbound: InboundMessage, outbound: OutboundMessage) -> None:
|
||||||
|
identity = inbound.channel_identity
|
||||||
|
dedupe_key = identity.dedupe_key() if identity else None
|
||||||
|
if not dedupe_key:
|
||||||
|
return
|
||||||
|
cfg = self.channel_configs.get(identity.channel_id)
|
||||||
|
max_reply_chars = int((cfg.config if cfg else {}).get("max_cached_reply_chars") or 20000)
|
||||||
|
max_error_chars = int((cfg.config if cfg else {}).get("max_cached_error_chars") or 4000)
|
||||||
|
if outbound.finish_reason == "error":
|
||||||
|
self.dedupe.mark_error(
|
||||||
|
dedupe_key=dedupe_key,
|
||||||
|
error=outbound.content,
|
||||||
|
max_error_chars=max_error_chars,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.dedupe.mark_done(
|
||||||
|
dedupe_key=dedupe_key,
|
||||||
|
run_id=outbound.run_id,
|
||||||
|
reply=outbound.content,
|
||||||
|
max_reply_chars=max_reply_chars,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _record_outbound_delivered(self, message: OutboundMessage) -> None:
|
||||||
|
kind = "outbound_unclaimed" if message.metadata.get("delivery_status") == "unclaimed" else "outbound_delivered"
|
||||||
|
self.events.record(
|
||||||
|
channel_id=message.channel,
|
||||||
|
kind=kind,
|
||||||
|
session_id=message.session_id,
|
||||||
|
message_id=message.channel_identity.message_id if message.channel_identity else message.message_id,
|
||||||
|
run_id=message.run_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _record_outbound_failed(self, message: OutboundMessage, exc: Exception | None) -> None:
|
||||||
|
self.events.record(
|
||||||
|
channel_id=message.channel,
|
||||||
|
kind="outbound_delivery_failed",
|
||||||
|
session_id=message.session_id,
|
||||||
|
message_id=message.channel_identity.message_id if message.channel_identity else message.message_id,
|
||||||
|
run_id=message.run_id,
|
||||||
|
status="error",
|
||||||
|
error=str(exc) if exc else "channel not registered",
|
||||||
|
)
|
||||||
|
|
||||||
|
def _default_dedupe_retention_hours(self) -> int:
|
||||||
|
for cfg in self.channel_configs.values():
|
||||||
|
value = cfg.config.get("dedupe_retention_hours")
|
||||||
|
if value is not None:
|
||||||
|
return int(value)
|
||||||
|
return 48
|
||||||
198
app-instance/backend/beaver/interfaces/channels/state.py
Normal file
198
app-instance/backend/beaver/interfaces/channels/state.py
Normal file
@ -0,0 +1,198 @@
|
|||||||
|
"""Persistent channel runtime state."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from pathlib import Path
|
||||||
|
from threading import Lock
|
||||||
|
from typing import Any
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
|
||||||
|
def _now_ms() -> int:
|
||||||
|
return int(time.time() * 1000)
|
||||||
|
|
||||||
|
|
||||||
|
def _iso_now() -> str:
|
||||||
|
return datetime.now(timezone.utc).isoformat()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class DedupeWriteResult:
|
||||||
|
created: bool
|
||||||
|
record: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelDedupeStore:
|
||||||
|
def __init__(self, path: Path, *, retention_hours: int = 48) -> None:
|
||||||
|
self.path = path
|
||||||
|
self.retention_ms = max(1, int(retention_hours)) * 60 * 60 * 1000
|
||||||
|
self._lock = Lock()
|
||||||
|
|
||||||
|
def get(self, dedupe_key: str) -> dict[str, Any] | None:
|
||||||
|
with self._lock:
|
||||||
|
data = self._load()
|
||||||
|
self._prune_unlocked(data, _now_ms())
|
||||||
|
record = data["records"].get(dedupe_key)
|
||||||
|
self._save(data)
|
||||||
|
return record
|
||||||
|
|
||||||
|
def mark_processing(self, *, dedupe_key: str, session_id: str, message_id: str) -> DedupeWriteResult:
|
||||||
|
with self._lock:
|
||||||
|
data = self._load()
|
||||||
|
now_ms = _now_ms()
|
||||||
|
self._prune_unlocked(data, now_ms)
|
||||||
|
existing = data["records"].get(dedupe_key)
|
||||||
|
if existing is not None:
|
||||||
|
self._save(data)
|
||||||
|
return DedupeWriteResult(created=False, record=existing)
|
||||||
|
|
||||||
|
record = {
|
||||||
|
"dedupe_key": dedupe_key,
|
||||||
|
"status": "processing",
|
||||||
|
"session_id": session_id,
|
||||||
|
"message_id": message_id,
|
||||||
|
"run_id": None,
|
||||||
|
"reply": None,
|
||||||
|
"error": None,
|
||||||
|
"created_at_ms": now_ms,
|
||||||
|
"updated_at_ms": now_ms,
|
||||||
|
}
|
||||||
|
data["records"][dedupe_key] = record
|
||||||
|
self._save(data)
|
||||||
|
return DedupeWriteResult(created=True, record=record)
|
||||||
|
|
||||||
|
def mark_done(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
dedupe_key: str,
|
||||||
|
run_id: str | None,
|
||||||
|
reply: str,
|
||||||
|
max_reply_chars: int,
|
||||||
|
) -> None:
|
||||||
|
self._mark_result(
|
||||||
|
dedupe_key=dedupe_key,
|
||||||
|
status="done",
|
||||||
|
run_id=run_id,
|
||||||
|
reply=reply[: max(0, int(max_reply_chars))],
|
||||||
|
error=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
def mark_error(self, *, dedupe_key: str, error: str, max_error_chars: int) -> None:
|
||||||
|
self._mark_result(
|
||||||
|
dedupe_key=dedupe_key,
|
||||||
|
status="error",
|
||||||
|
run_id=None,
|
||||||
|
reply=None,
|
||||||
|
error=error[: max(0, int(max_error_chars))],
|
||||||
|
)
|
||||||
|
|
||||||
|
def _mark_result(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
dedupe_key: str,
|
||||||
|
status: str,
|
||||||
|
run_id: str | None,
|
||||||
|
reply: str | None,
|
||||||
|
error: str | None,
|
||||||
|
) -> None:
|
||||||
|
with self._lock:
|
||||||
|
data = self._load()
|
||||||
|
record = data["records"].get(dedupe_key)
|
||||||
|
if record is None:
|
||||||
|
record = {"dedupe_key": dedupe_key, "created_at_ms": _now_ms()}
|
||||||
|
data["records"][dedupe_key] = record
|
||||||
|
record.update(
|
||||||
|
{
|
||||||
|
"status": status,
|
||||||
|
"run_id": run_id,
|
||||||
|
"reply": reply,
|
||||||
|
"error": error,
|
||||||
|
"updated_at_ms": _now_ms(),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
self._save(data)
|
||||||
|
|
||||||
|
def _load(self) -> dict[str, Any]:
|
||||||
|
if not self.path.exists():
|
||||||
|
return {"records": {}}
|
||||||
|
try:
|
||||||
|
data = json.loads(self.path.read_text(encoding="utf-8"))
|
||||||
|
except (OSError, json.JSONDecodeError):
|
||||||
|
return {"records": {}}
|
||||||
|
if not isinstance(data, dict) or not isinstance(data.get("records"), dict):
|
||||||
|
return {"records": {}}
|
||||||
|
return data
|
||||||
|
|
||||||
|
def _save(self, data: dict[str, Any]) -> None:
|
||||||
|
self.path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
tmp_path = self.path.with_name(f"{self.path.name}.tmp")
|
||||||
|
tmp_path.write_text(json.dumps(data, ensure_ascii=False, indent=2) + "\n", encoding="utf-8")
|
||||||
|
tmp_path.replace(self.path)
|
||||||
|
|
||||||
|
def _prune_unlocked(self, data: dict[str, Any], now_ms: int) -> None:
|
||||||
|
records = data.get("records", {})
|
||||||
|
expired_before = now_ms - self.retention_ms
|
||||||
|
for key, record in list(records.items()):
|
||||||
|
updated_at_ms = int(record.get("updated_at_ms") or record.get("created_at_ms") or 0)
|
||||||
|
if updated_at_ms < expired_before:
|
||||||
|
records.pop(key, None)
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelEventLog:
|
||||||
|
def __init__(self, path: Path) -> None:
|
||||||
|
self.path = path
|
||||||
|
self._lock = Lock()
|
||||||
|
|
||||||
|
def record(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
channel_id: str,
|
||||||
|
kind: str,
|
||||||
|
session_id: str | None = None,
|
||||||
|
message_id: str | None = None,
|
||||||
|
run_id: str | None = None,
|
||||||
|
status: str = "ok",
|
||||||
|
error: str | None = None,
|
||||||
|
text: str | None = None,
|
||||||
|
metadata: dict[str, Any] | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
entry = {
|
||||||
|
"event_id": uuid4().hex,
|
||||||
|
"channel_id": channel_id,
|
||||||
|
"kind": kind,
|
||||||
|
"session_id": session_id,
|
||||||
|
"message_id": message_id,
|
||||||
|
"run_id": run_id,
|
||||||
|
"status": status,
|
||||||
|
"error": error,
|
||||||
|
"text_preview": (text or "")[:120] if text else None,
|
||||||
|
"text_length": len(text) if text else 0,
|
||||||
|
"metadata": metadata or {},
|
||||||
|
"created_at": _iso_now(),
|
||||||
|
}
|
||||||
|
with self._lock:
|
||||||
|
self.path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
with self.path.open("a", encoding="utf-8") as handle:
|
||||||
|
handle.write(json.dumps(entry, ensure_ascii=False) + "\n")
|
||||||
|
return entry
|
||||||
|
|
||||||
|
def recent(self, *, channel_id: str | None = None, limit: int = 100) -> list[dict[str, Any]]:
|
||||||
|
if not self.path.exists():
|
||||||
|
return []
|
||||||
|
lines = self.path.read_text(encoding="utf-8").splitlines()
|
||||||
|
items: list[dict[str, Any]] = []
|
||||||
|
for line in reversed(lines):
|
||||||
|
try:
|
||||||
|
item = json.loads(line)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue
|
||||||
|
if channel_id and item.get("channel_id") != channel_id:
|
||||||
|
continue
|
||||||
|
items.append(item)
|
||||||
|
if len(items) >= max(1, int(limit)):
|
||||||
|
break
|
||||||
|
return list(reversed(items))
|
||||||
@ -0,0 +1,301 @@
|
|||||||
|
"""Text-only terminal WebSocket channel adapter."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Callable
|
||||||
|
from contextlib import suppress
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from beaver.foundation.events import ChannelIdentity, InboundMessage, OutboundMessage
|
||||||
|
from beaver.interfaces.channels.base import ChannelInboundSink
|
||||||
|
|
||||||
|
try:
|
||||||
|
from fastapi import WebSocket
|
||||||
|
from starlette.websockets import WebSocketDisconnect
|
||||||
|
except ModuleNotFoundError: # pragma: no cover - import-only fallback
|
||||||
|
class WebSocketDisconnect(Exception):
|
||||||
|
"""Fallback disconnect exception for skeleton import environments."""
|
||||||
|
|
||||||
|
class WebSocket: # type: ignore[override]
|
||||||
|
"""Fallback websocket annotation shim."""
|
||||||
|
|
||||||
|
|
||||||
|
def _clean(value: Any) -> str:
|
||||||
|
return str(value or "").strip()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class TerminalConnection:
|
||||||
|
websocket: WebSocket
|
||||||
|
peer_id: str
|
||||||
|
session_id: str
|
||||||
|
thread_id: str | None = None
|
||||||
|
user_id: str | None = None
|
||||||
|
device_name: str = ""
|
||||||
|
capabilities: list[str] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class TerminalWebSocketAdapter:
|
||||||
|
"""Accept text terminal websocket frames and deliver final assistant replies."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
channel_id: str,
|
||||||
|
kind: str,
|
||||||
|
mode: str,
|
||||||
|
account_id: str,
|
||||||
|
display_name: str = "",
|
||||||
|
inbound_sink: ChannelInboundSink,
|
||||||
|
event_recorder: Callable[..., None] | None = None,
|
||||||
|
heartbeat_seconds: float = 30,
|
||||||
|
max_message_chars: int = 20000,
|
||||||
|
) -> None:
|
||||||
|
self.channel_id = channel_id
|
||||||
|
self.kind = kind
|
||||||
|
self.mode = mode
|
||||||
|
self.account_id = account_id
|
||||||
|
self.display_name = display_name or channel_id
|
||||||
|
self.inbound_sink = inbound_sink
|
||||||
|
self.event_recorder = event_recorder
|
||||||
|
self.heartbeat_seconds = max(1.0, float(heartbeat_seconds))
|
||||||
|
self.max_message_chars = max(1, int(max_message_chars))
|
||||||
|
self.started = False
|
||||||
|
self._connections_by_session: dict[str, TerminalConnection] = {}
|
||||||
|
self._session_by_peer: dict[str, str] = {}
|
||||||
|
|
||||||
|
async def start(self) -> None:
|
||||||
|
self.started = True
|
||||||
|
|
||||||
|
async def stop(self) -> None:
|
||||||
|
self.started = False
|
||||||
|
for connection in list(self._connections_by_session.values()):
|
||||||
|
with suppress(Exception):
|
||||||
|
await connection.websocket.close(code=1001)
|
||||||
|
self._connections_by_session.clear()
|
||||||
|
self._session_by_peer.clear()
|
||||||
|
|
||||||
|
def status_extra(self) -> dict[str, Any]:
|
||||||
|
return {"connected_peers": len(self._connections_by_session)}
|
||||||
|
|
||||||
|
async def handle_websocket(self, websocket: WebSocket) -> None:
|
||||||
|
await websocket.accept()
|
||||||
|
connection: TerminalConnection | None = None
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
payload = await websocket.receive_json()
|
||||||
|
except WebSocketDisconnect:
|
||||||
|
break
|
||||||
|
except ValueError:
|
||||||
|
await websocket.send_json({"type": "error", "error": "Invalid websocket JSON payload"})
|
||||||
|
continue
|
||||||
|
if not isinstance(payload, dict):
|
||||||
|
await websocket.send_json({"type": "error", "error": "Websocket payload must be a JSON object"})
|
||||||
|
continue
|
||||||
|
|
||||||
|
frame_type = _clean(payload.get("type")).lower()
|
||||||
|
if frame_type == "ping":
|
||||||
|
await websocket.send_json({"type": "pong"})
|
||||||
|
continue
|
||||||
|
if frame_type == "connect":
|
||||||
|
connection = await self._handle_connect(websocket, payload, current=connection)
|
||||||
|
continue
|
||||||
|
if frame_type == "message":
|
||||||
|
if connection is None:
|
||||||
|
await websocket.send_json({"type": "error", "error": "connect is required before message"})
|
||||||
|
continue
|
||||||
|
await self._handle_message(websocket, connection, payload)
|
||||||
|
continue
|
||||||
|
|
||||||
|
await websocket.send_json(
|
||||||
|
{
|
||||||
|
"type": "error",
|
||||||
|
"error": f"Unsupported websocket frame type: {frame_type or '<empty>'}",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
if connection is not None:
|
||||||
|
self._remove_connection(connection)
|
||||||
|
self._record(
|
||||||
|
kind="terminal_disconnected",
|
||||||
|
session_id=connection.session_id,
|
||||||
|
metadata={"peer_id": connection.peer_id, "device_name": connection.device_name},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _handle_connect(
|
||||||
|
self,
|
||||||
|
websocket: WebSocket,
|
||||||
|
payload: dict[str, Any],
|
||||||
|
*,
|
||||||
|
current: TerminalConnection | None,
|
||||||
|
) -> TerminalConnection | None:
|
||||||
|
peer_id = _clean(payload.get("peer_id"))
|
||||||
|
if not peer_id:
|
||||||
|
await websocket.send_json({"type": "error", "error": "peer_id is required"})
|
||||||
|
return current
|
||||||
|
|
||||||
|
thread_id = _clean(payload.get("thread_id")) or None
|
||||||
|
user_id = _clean(payload.get("user_id")) or None
|
||||||
|
device_name = _clean(payload.get("device_name"))
|
||||||
|
capabilities = [str(item) for item in payload.get("capabilities") or [] if item is not None]
|
||||||
|
identity = ChannelIdentity(
|
||||||
|
channel_id=self.channel_id,
|
||||||
|
kind=self.kind,
|
||||||
|
account_id=self.account_id,
|
||||||
|
peer_id=peer_id,
|
||||||
|
thread_id=thread_id,
|
||||||
|
peer_type="terminal",
|
||||||
|
user_id=user_id,
|
||||||
|
)
|
||||||
|
session_id = identity.session_id()
|
||||||
|
connection = TerminalConnection(
|
||||||
|
websocket=websocket,
|
||||||
|
peer_id=peer_id,
|
||||||
|
session_id=session_id,
|
||||||
|
thread_id=thread_id,
|
||||||
|
user_id=user_id,
|
||||||
|
device_name=device_name,
|
||||||
|
capabilities=capabilities,
|
||||||
|
)
|
||||||
|
|
||||||
|
if current is not None and current.session_id != session_id:
|
||||||
|
self._remove_connection(current)
|
||||||
|
old = self._connections_by_session.get(session_id)
|
||||||
|
if old is not None and old.websocket is not websocket:
|
||||||
|
with suppress(Exception):
|
||||||
|
await old.websocket.close(code=1000)
|
||||||
|
self._connections_by_session[session_id] = connection
|
||||||
|
self._session_by_peer[peer_id] = session_id
|
||||||
|
self._record(
|
||||||
|
kind="terminal_connected",
|
||||||
|
session_id=session_id,
|
||||||
|
metadata={"peer_id": peer_id, "device_name": device_name, "capabilities": capabilities},
|
||||||
|
)
|
||||||
|
await websocket.send_json(
|
||||||
|
{
|
||||||
|
"type": "connected",
|
||||||
|
"channel_id": self.channel_id,
|
||||||
|
"session_id": session_id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return connection
|
||||||
|
|
||||||
|
async def _handle_message(
|
||||||
|
self,
|
||||||
|
websocket: WebSocket,
|
||||||
|
connection: TerminalConnection,
|
||||||
|
payload: dict[str, Any],
|
||||||
|
) -> None:
|
||||||
|
message_id = _clean(payload.get("message_id"))
|
||||||
|
text = _clean(payload.get("text"))
|
||||||
|
if not message_id:
|
||||||
|
await websocket.send_json({"type": "error", "error": "message_id is required"})
|
||||||
|
return
|
||||||
|
if not text:
|
||||||
|
await websocket.send_json({"type": "error", "error": "text is required"})
|
||||||
|
return
|
||||||
|
if len(text) > self.max_message_chars:
|
||||||
|
await websocket.send_json(
|
||||||
|
{
|
||||||
|
"type": "error",
|
||||||
|
"error": f"text exceeds max_message_chars ({self.max_message_chars})",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
thread_id = _clean(payload.get("thread_id")) or connection.thread_id
|
||||||
|
user_id = _clean(payload.get("user_id")) or connection.user_id
|
||||||
|
identity = ChannelIdentity(
|
||||||
|
channel_id=self.channel_id,
|
||||||
|
kind=self.kind,
|
||||||
|
account_id=self.account_id,
|
||||||
|
peer_id=connection.peer_id,
|
||||||
|
thread_id=thread_id,
|
||||||
|
peer_type="terminal",
|
||||||
|
user_id=user_id,
|
||||||
|
message_id=message_id,
|
||||||
|
)
|
||||||
|
inbound = InboundMessage(
|
||||||
|
channel=self.channel_id,
|
||||||
|
content=text,
|
||||||
|
content_type="text",
|
||||||
|
user_id=user_id,
|
||||||
|
channel_identity=identity,
|
||||||
|
metadata={
|
||||||
|
"terminal": {
|
||||||
|
"peer_id": connection.peer_id,
|
||||||
|
"device_name": connection.device_name,
|
||||||
|
"capabilities": connection.capabilities,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
accept = await self.inbound_sink.accept_inbound(inbound)
|
||||||
|
ack: dict[str, Any] = {
|
||||||
|
"type": "ack",
|
||||||
|
"message_id": message_id,
|
||||||
|
"session_id": accept.session_id or identity.session_id(),
|
||||||
|
"accepted": accept.accepted,
|
||||||
|
}
|
||||||
|
if accept.duplicate:
|
||||||
|
ack["duplicate"] = True
|
||||||
|
ack["pending"] = accept.pending
|
||||||
|
record = accept.record or {}
|
||||||
|
if record.get("reply"):
|
||||||
|
ack["reply"] = record["reply"]
|
||||||
|
if accept.error or record.get("error"):
|
||||||
|
ack["error"] = accept.error or record.get("error")
|
||||||
|
await websocket.send_json(ack)
|
||||||
|
|
||||||
|
async def send(self, message: OutboundMessage) -> None:
|
||||||
|
session_id = message.session_id
|
||||||
|
if not session_id and message.channel_identity is not None:
|
||||||
|
session_id = message.channel_identity.session_id()
|
||||||
|
connection = self._connections_by_session.get(session_id or "")
|
||||||
|
if connection is None:
|
||||||
|
message.metadata["delivery_status"] = "unclaimed"
|
||||||
|
return
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"type": "message",
|
||||||
|
"role": "assistant",
|
||||||
|
"message_id": message.channel_identity.message_id if message.channel_identity else message.message_id,
|
||||||
|
"run_id": message.run_id,
|
||||||
|
"text": message.content,
|
||||||
|
"finish_reason": message.finish_reason,
|
||||||
|
}
|
||||||
|
try:
|
||||||
|
await connection.websocket.send_json(payload)
|
||||||
|
except Exception:
|
||||||
|
message.metadata["delivery_status"] = "unclaimed"
|
||||||
|
self._remove_connection(connection)
|
||||||
|
|
||||||
|
def _remove_connection(self, connection: TerminalConnection) -> None:
|
||||||
|
current = self._connections_by_session.get(connection.session_id)
|
||||||
|
if current is connection:
|
||||||
|
self._connections_by_session.pop(connection.session_id, None)
|
||||||
|
if self._session_by_peer.get(connection.peer_id) == connection.session_id:
|
||||||
|
self._session_by_peer.pop(connection.peer_id, None)
|
||||||
|
|
||||||
|
def _record(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
kind: str,
|
||||||
|
session_id: str | None = None,
|
||||||
|
message_id: str | None = None,
|
||||||
|
status: str = "ok",
|
||||||
|
error: str | None = None,
|
||||||
|
metadata: dict[str, Any] | None = None,
|
||||||
|
) -> None:
|
||||||
|
if self.event_recorder is None:
|
||||||
|
return
|
||||||
|
self.event_recorder(
|
||||||
|
channel_id=self.channel_id,
|
||||||
|
kind=kind,
|
||||||
|
session_id=session_id,
|
||||||
|
message_id=message_id,
|
||||||
|
status=status,
|
||||||
|
error=error,
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
File diff suppressed because it is too large
Load Diff
152
app-instance/backend/beaver/interfaces/web/keycloak_auth.py
Normal file
152
app-instance/backend/beaver/interfaces/web/keycloak_auth.py
Normal file
@ -0,0 +1,152 @@
|
|||||||
|
"""Keycloak OIDC token verification for the Beaver web app."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import jwt
|
||||||
|
from jwt import PyJWKClient
|
||||||
|
|
||||||
|
try:
|
||||||
|
from fastapi import HTTPException
|
||||||
|
except ModuleNotFoundError: # pragma: no cover
|
||||||
|
class HTTPException(Exception): # type: ignore[override]
|
||||||
|
def __init__(self, status_code: int, detail: str) -> None:
|
||||||
|
super().__init__(detail)
|
||||||
|
self.status_code = status_code
|
||||||
|
self.detail = detail
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_KEYCLOAK_ISSUER = "https://keycloak.bwgdi.com/realms/beaver"
|
||||||
|
DEFAULT_KEYCLOAK_CLIENT_ID = "beaver-agnet"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, slots=True)
|
||||||
|
class KeycloakAuthConfig:
|
||||||
|
issuer: str
|
||||||
|
client_id: str
|
||||||
|
token_url: str
|
||||||
|
jwks_url: str
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_env(cls) -> "KeycloakAuthConfig":
|
||||||
|
issuer = _clean_base_url(os.getenv("BEAVER_KEYCLOAK_ISSUER") or DEFAULT_KEYCLOAK_ISSUER)
|
||||||
|
client_id = (os.getenv("BEAVER_KEYCLOAK_CLIENT_ID") or DEFAULT_KEYCLOAK_CLIENT_ID).strip()
|
||||||
|
token_url = (
|
||||||
|
os.getenv("BEAVER_KEYCLOAK_TOKEN_URL", "").strip()
|
||||||
|
or f"{issuer}/protocol/openid-connect/token"
|
||||||
|
)
|
||||||
|
jwks_url = (
|
||||||
|
os.getenv("BEAVER_KEYCLOAK_JWKS_URL", "").strip()
|
||||||
|
or f"{issuer}/protocol/openid-connect/certs"
|
||||||
|
)
|
||||||
|
return cls(issuer=issuer, client_id=client_id, token_url=token_url, jwks_url=jwks_url)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, slots=True)
|
||||||
|
class KeycloakIdentity:
|
||||||
|
user_id: str
|
||||||
|
username: str
|
||||||
|
email: str = ""
|
||||||
|
name: str = ""
|
||||||
|
realm_roles: tuple[str, ...] = ()
|
||||||
|
client_roles: tuple[str, ...] = ()
|
||||||
|
|
||||||
|
|
||||||
|
def extract_bearer_token(authorization: str | None) -> str:
|
||||||
|
if not authorization:
|
||||||
|
raise HTTPException(status_code=401, detail="Missing Authorization header")
|
||||||
|
prefix = "bearer "
|
||||||
|
if not authorization.lower().startswith(prefix):
|
||||||
|
raise HTTPException(status_code=401, detail="Invalid Authorization header")
|
||||||
|
token = authorization[len(prefix):].strip()
|
||||||
|
if not token:
|
||||||
|
raise HTTPException(status_code=401, detail="Invalid token")
|
||||||
|
return token
|
||||||
|
|
||||||
|
|
||||||
|
class KeycloakTokenVerifier:
|
||||||
|
def __init__(self, *, config: KeycloakAuthConfig) -> None:
|
||||||
|
self.config = config
|
||||||
|
self._jwks_client = PyJWKClient(config.jwks_url)
|
||||||
|
|
||||||
|
def verify(self, token: str, *, expected_nonce: str | None = None) -> KeycloakIdentity:
|
||||||
|
try:
|
||||||
|
signing_key = self._jwks_client.get_signing_key_from_jwt(token).key
|
||||||
|
claims = jwt.decode(
|
||||||
|
token,
|
||||||
|
signing_key,
|
||||||
|
algorithms=["RS256"],
|
||||||
|
issuer=self.config.issuer,
|
||||||
|
options={
|
||||||
|
"require": ["exp", "iat", "iss"],
|
||||||
|
"verify_aud": False,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
except Exception as exc: # noqa: BLE001 - normalize JWT/JWKS failures for HTTP callers
|
||||||
|
raise HTTPException(status_code=401, detail=f"Invalid token: {exc}") from exc
|
||||||
|
return self.validate_claims(claims, expected_nonce=expected_nonce)
|
||||||
|
|
||||||
|
def validate_claims(self, claims: dict[str, Any], *, expected_nonce: str | None = None) -> KeycloakIdentity:
|
||||||
|
now = int(time.time())
|
||||||
|
issuer = str(claims.get("iss") or "")
|
||||||
|
if issuer != self.config.issuer:
|
||||||
|
raise HTTPException(status_code=401, detail="Invalid token issuer")
|
||||||
|
exp = _int_claim(claims, "exp")
|
||||||
|
iat = _int_claim(claims, "iat")
|
||||||
|
if exp <= now:
|
||||||
|
raise HTTPException(status_code=401, detail="Token expired")
|
||||||
|
if iat > now + 120:
|
||||||
|
raise HTTPException(status_code=401, detail="Token issued in the future")
|
||||||
|
if not _matches_client(claims.get("aud"), self.config.client_id) and claims.get("azp") != self.config.client_id:
|
||||||
|
raise HTTPException(status_code=401, detail="Invalid token audience")
|
||||||
|
if expected_nonce is not None and claims.get("nonce") != expected_nonce:
|
||||||
|
raise HTTPException(status_code=401, detail="Invalid token nonce")
|
||||||
|
|
||||||
|
user_id = str(claims.get("sub") or "").strip()
|
||||||
|
if not user_id:
|
||||||
|
raise HTTPException(status_code=401, detail="Token subject is required")
|
||||||
|
username = (
|
||||||
|
str(claims.get("preferred_username") or "").strip()
|
||||||
|
or str(claims.get("email") or "").strip()
|
||||||
|
or user_id
|
||||||
|
)
|
||||||
|
return KeycloakIdentity(
|
||||||
|
user_id=user_id,
|
||||||
|
username=username,
|
||||||
|
email=str(claims.get("email") or "").strip(),
|
||||||
|
name=str(claims.get("name") or "").strip(),
|
||||||
|
realm_roles=_roles_from(claims.get("realm_access")),
|
||||||
|
client_roles=_roles_from((claims.get("resource_access") or {}).get(self.config.client_id) if isinstance(claims.get("resource_access"), dict) else None),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _clean_base_url(value: str) -> str:
|
||||||
|
return value.strip().rstrip("/")
|
||||||
|
|
||||||
|
|
||||||
|
def _int_claim(claims: dict[str, Any], key: str) -> int:
|
||||||
|
try:
|
||||||
|
return int(claims[key])
|
||||||
|
except (KeyError, TypeError, ValueError) as exc:
|
||||||
|
raise HTTPException(status_code=401, detail=f"Token {key} claim is required") from exc
|
||||||
|
|
||||||
|
|
||||||
|
def _matches_client(audience: Any, client_id: str) -> bool:
|
||||||
|
if isinstance(audience, str):
|
||||||
|
return audience == client_id
|
||||||
|
if isinstance(audience, list):
|
||||||
|
return client_id in {str(item) for item in audience}
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _roles_from(value: Any) -> tuple[str, ...]:
|
||||||
|
if not isinstance(value, dict):
|
||||||
|
return ()
|
||||||
|
roles = value.get("roles")
|
||||||
|
if not isinstance(roles, list):
|
||||||
|
return ()
|
||||||
|
return tuple(str(role) for role in roles if str(role).strip())
|
||||||
@ -1,11 +1,25 @@
|
|||||||
"""Web request and response schemas."""
|
"""Web request and response schemas."""
|
||||||
|
|
||||||
from .chat import (
|
from .chat import (
|
||||||
|
WebChatAcceptanceRequest,
|
||||||
|
WebChatAcceptanceResponse,
|
||||||
WebChatFeedbackRequest,
|
WebChatFeedbackRequest,
|
||||||
WebChatFeedbackResponse,
|
WebChatFeedbackResponse,
|
||||||
WebChatRequest,
|
WebChatRequest,
|
||||||
WebChatResponse,
|
WebChatResponse,
|
||||||
WebErrorResponse,
|
WebErrorResponse,
|
||||||
|
WebAgentConfigRequest,
|
||||||
|
WebAgentConfigResponse,
|
||||||
|
WebChannelConfigRequest,
|
||||||
|
WebChannelConfigResponse,
|
||||||
|
WebChannelConnectionCreateRequest,
|
||||||
|
WebChannelConnectionResponse,
|
||||||
|
WebChannelConnectionUpdateRequest,
|
||||||
|
WebChannelValidationResponse,
|
||||||
|
WebConnectorBridgeEventRequest,
|
||||||
|
WebConnectorBridgeEventResponse,
|
||||||
|
WebConnectorSessionCreateRequest,
|
||||||
|
WebConnectorSessionResponse,
|
||||||
WebProviderConfigRequest,
|
WebProviderConfigRequest,
|
||||||
WebProviderConfigResponse,
|
WebProviderConfigResponse,
|
||||||
WebProviderTarget,
|
WebProviderTarget,
|
||||||
@ -13,11 +27,25 @@ from .chat import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"WebChatAcceptanceRequest",
|
||||||
|
"WebChatAcceptanceResponse",
|
||||||
"WebChatFeedbackRequest",
|
"WebChatFeedbackRequest",
|
||||||
"WebChatFeedbackResponse",
|
"WebChatFeedbackResponse",
|
||||||
"WebChatRequest",
|
"WebChatRequest",
|
||||||
"WebChatResponse",
|
"WebChatResponse",
|
||||||
"WebErrorResponse",
|
"WebErrorResponse",
|
||||||
|
"WebAgentConfigRequest",
|
||||||
|
"WebAgentConfigResponse",
|
||||||
|
"WebChannelConfigRequest",
|
||||||
|
"WebChannelConfigResponse",
|
||||||
|
"WebChannelConnectionCreateRequest",
|
||||||
|
"WebChannelConnectionResponse",
|
||||||
|
"WebChannelConnectionUpdateRequest",
|
||||||
|
"WebChannelValidationResponse",
|
||||||
|
"WebConnectorBridgeEventRequest",
|
||||||
|
"WebConnectorBridgeEventResponse",
|
||||||
|
"WebConnectorSessionCreateRequest",
|
||||||
|
"WebConnectorSessionResponse",
|
||||||
"WebProviderConfigRequest",
|
"WebProviderConfigRequest",
|
||||||
"WebProviderConfigResponse",
|
"WebProviderConfigResponse",
|
||||||
"WebProviderTarget",
|
"WebProviderTarget",
|
||||||
|
|||||||
@ -55,6 +55,7 @@ class WebChatRequest(BaseModel):
|
|||||||
user_id: str | None = None
|
user_id: str | None = None
|
||||||
title: str | None = None
|
title: str | None = None
|
||||||
execution_context: str | None = None
|
execution_context: str | None = None
|
||||||
|
prompt_locale: str | None = None
|
||||||
model: str | None = None
|
model: str | None = None
|
||||||
provider_name: str | None = None
|
provider_name: str | None = None
|
||||||
embedding_model: str | None = None
|
embedding_model: str | None = None
|
||||||
@ -82,11 +83,34 @@ class WebChatResponse(BaseModel):
|
|||||||
usage: dict[str, Any] = Field(default_factory=dict)
|
usage: dict[str, Any] = Field(default_factory=dict)
|
||||||
task_id: str | None = None
|
task_id: str | None = None
|
||||||
task_status: str | None = None
|
task_status: str | None = None
|
||||||
|
evidence_status: str | None = None
|
||||||
|
acceptance_state: str | None = None
|
||||||
validation_result: dict[str, Any] | None = None
|
validation_result: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class WebChatAcceptanceRequest(BaseModel):
|
||||||
|
"""User acceptance on the latest assistant result in chat."""
|
||||||
|
|
||||||
|
session_id: str
|
||||||
|
run_id: str
|
||||||
|
acceptance_type: str
|
||||||
|
comment: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class WebChatAcceptanceResponse(BaseModel):
|
||||||
|
"""Acceptance recording result."""
|
||||||
|
|
||||||
|
session_id: str
|
||||||
|
run_id: str
|
||||||
|
task_id: str
|
||||||
|
task_status: str
|
||||||
|
acceptance_type: str
|
||||||
|
feedback_type: str
|
||||||
|
learning_candidates: list[dict[str, Any]] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
class WebChatFeedbackRequest(BaseModel):
|
class WebChatFeedbackRequest(BaseModel):
|
||||||
"""Feedback on the latest assistant result in chat."""
|
"""Backward-compatible feedback payload."""
|
||||||
|
|
||||||
session_id: str
|
session_id: str
|
||||||
run_id: str
|
run_id: str
|
||||||
@ -94,15 +118,8 @@ class WebChatFeedbackRequest(BaseModel):
|
|||||||
comment: str | None = None
|
comment: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class WebChatFeedbackResponse(BaseModel):
|
class WebChatFeedbackResponse(WebChatAcceptanceResponse):
|
||||||
"""Feedback recording result."""
|
"""Backward-compatible feedback response."""
|
||||||
|
|
||||||
session_id: str
|
|
||||||
run_id: str
|
|
||||||
task_id: str
|
|
||||||
task_status: str
|
|
||||||
feedback_type: str
|
|
||||||
learning_candidates: list[dict[str, Any]] = Field(default_factory=list)
|
|
||||||
|
|
||||||
|
|
||||||
class WebProviderConfigRequest(BaseModel):
|
class WebProviderConfigRequest(BaseModel):
|
||||||
@ -123,6 +140,127 @@ class WebProviderConfigResponse(BaseModel):
|
|||||||
enabled: bool
|
enabled: bool
|
||||||
|
|
||||||
|
|
||||||
|
class WebChannelConfigRequest(BaseModel):
|
||||||
|
"""Channel config update from the settings page."""
|
||||||
|
|
||||||
|
enabled: bool = False
|
||||||
|
kind: str
|
||||||
|
mode: str
|
||||||
|
account_id: str | None = None
|
||||||
|
display_name: str | None = None
|
||||||
|
config: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
secrets: dict[str, str | None] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class WebChannelConfigResponse(BaseModel):
|
||||||
|
"""Channel config update result."""
|
||||||
|
|
||||||
|
ok: bool
|
||||||
|
channel_id: str
|
||||||
|
restart_required: bool
|
||||||
|
channel: dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
class WebChannelConnectionCreateRequest(BaseModel):
|
||||||
|
"""Create a channel connection from the setup UI."""
|
||||||
|
|
||||||
|
kind: str
|
||||||
|
mode: str
|
||||||
|
display_name: str | None = Field(default=None, alias="displayName")
|
||||||
|
owner_user_id: str | None = Field(default=None, alias="ownerUserId")
|
||||||
|
auth_type: str = Field(default="token", alias="authType")
|
||||||
|
account_id: str | None = Field(default=None, alias="accountId")
|
||||||
|
config: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
secrets: dict[str, str | None] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class WebChannelConnectionResponse(BaseModel):
|
||||||
|
"""Channel connection response with redacted credentials."""
|
||||||
|
|
||||||
|
connection: dict[str, Any]
|
||||||
|
credentials: dict[str, str] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class WebChannelConnectionUpdateRequest(BaseModel):
|
||||||
|
"""Update editable channel connection setup fields."""
|
||||||
|
|
||||||
|
display_name: str | None = Field(default=None, alias="displayName")
|
||||||
|
account_id: str | None = Field(default=None, alias="accountId")
|
||||||
|
config: dict[str, Any] | None = None
|
||||||
|
secrets: dict[str, str | None] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class WebChannelValidationResponse(BaseModel):
|
||||||
|
"""Connector validation response."""
|
||||||
|
|
||||||
|
ok: bool
|
||||||
|
status: str
|
||||||
|
account_id: str | None = None
|
||||||
|
display_name: str | None = None
|
||||||
|
error: str | None = None
|
||||||
|
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
connection: dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
class WebConnectorBridgeEventRequest(BaseModel):
|
||||||
|
"""Inbound connector bridge event from the external sidecar."""
|
||||||
|
|
||||||
|
event_id: str = Field(alias="eventId")
|
||||||
|
timestamp: str
|
||||||
|
delivery_attempt: int = Field(default=1, alias="deliveryAttempt")
|
||||||
|
connection_id: str = Field(alias="connectionId")
|
||||||
|
channel_id: str = Field(alias="channelId")
|
||||||
|
kind: str
|
||||||
|
account_id: str = Field(alias="accountId")
|
||||||
|
peer_id: str = Field(alias="peerId")
|
||||||
|
peer_type: str = Field(default="unknown", alias="peerType")
|
||||||
|
user_id: str | None = Field(default=None, alias="userId")
|
||||||
|
thread_id: str | None = Field(default=None, alias="threadId")
|
||||||
|
message_id: str = Field(alias="messageId")
|
||||||
|
message_type: str = Field(default="text", alias="messageType")
|
||||||
|
content: str
|
||||||
|
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class WebConnectorBridgeEventResponse(BaseModel):
|
||||||
|
"""Connector bridge event accept/dedupe response."""
|
||||||
|
|
||||||
|
accepted: bool
|
||||||
|
duplicate: bool = False
|
||||||
|
pending: bool = False
|
||||||
|
retry_after_seconds: int | None = Field(default=None, alias="retryAfterSeconds")
|
||||||
|
|
||||||
|
|
||||||
|
class WebConnectorSessionCreateRequest(BaseModel):
|
||||||
|
"""Start a connector-managed onboarding session."""
|
||||||
|
|
||||||
|
kind: str
|
||||||
|
display_name: str | None = Field(default=None, alias="displayName")
|
||||||
|
owner_user_id: str | None = Field(default=None, alias="ownerUserId")
|
||||||
|
options: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class WebConnectorSessionResponse(BaseModel):
|
||||||
|
"""Connector session view plus optional connection view."""
|
||||||
|
|
||||||
|
session: dict[str, Any]
|
||||||
|
connection: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class WebAgentConfigRequest(BaseModel):
|
||||||
|
"""Agent runtime defaults update from the settings page."""
|
||||||
|
|
||||||
|
max_tokens: int | None = None
|
||||||
|
temperature: float
|
||||||
|
max_tool_iterations: int
|
||||||
|
|
||||||
|
|
||||||
|
class WebAgentConfigResponse(BaseModel):
|
||||||
|
"""Agent runtime defaults update result."""
|
||||||
|
|
||||||
|
ok: bool
|
||||||
|
|
||||||
|
|
||||||
class WebStatusResponse(BaseModel):
|
class WebStatusResponse(BaseModel):
|
||||||
"""Web 宿主层状态响应。"""
|
"""Web 宿主层状态响应。"""
|
||||||
|
|
||||||
|
|||||||
@ -227,6 +227,15 @@ class SkillDraftEvalReport:
|
|||||||
cases: list[dict[str, Any]] = field(default_factory=list)
|
cases: list[dict[str, Any]] = field(default_factory=list)
|
||||||
status: str = "completed"
|
status: str = "completed"
|
||||||
created_at: str = ""
|
created_at: str = ""
|
||||||
|
eval_version: str = "heuristic-v1"
|
||||||
|
mode: str = "heuristic"
|
||||||
|
execution_coverage: float = 0.0
|
||||||
|
surrogate_coverage: float = 0.0
|
||||||
|
blocked_coverage: float = 0.0
|
||||||
|
confidence: str = "low"
|
||||||
|
case_reports: list[dict[str, Any]] = field(default_factory=list)
|
||||||
|
tool_mode_summary: dict[str, Any] = field(default_factory=dict)
|
||||||
|
preservation_report: dict[str, Any] | None = None
|
||||||
|
|
||||||
def to_dict(self) -> dict[str, Any]:
|
def to_dict(self) -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
@ -244,6 +253,17 @@ class SkillDraftEvalReport:
|
|||||||
"cases": [dict(item) for item in self.cases],
|
"cases": [dict(item) for item in self.cases],
|
||||||
"status": self.status,
|
"status": self.status,
|
||||||
"created_at": self.created_at,
|
"created_at": self.created_at,
|
||||||
|
"eval_version": self.eval_version,
|
||||||
|
"mode": self.mode,
|
||||||
|
"execution_coverage": self.execution_coverage,
|
||||||
|
"surrogate_coverage": self.surrogate_coverage,
|
||||||
|
"blocked_coverage": self.blocked_coverage,
|
||||||
|
"confidence": self.confidence,
|
||||||
|
"case_reports": [dict(item) for item in self.case_reports],
|
||||||
|
"tool_mode_summary": dict(self.tool_mode_summary),
|
||||||
|
"preservation_report": (
|
||||||
|
dict(self.preservation_report) if self.preservation_report is not None else None
|
||||||
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -263,6 +283,23 @@ class SkillDraftEvalReport:
|
|||||||
cases=[dict(item) for item in payload.get("cases") or [] if isinstance(item, dict)],
|
cases=[dict(item) for item in payload.get("cases") or [] if isinstance(item, dict)],
|
||||||
status=str(payload.get("status") or "completed"),
|
status=str(payload.get("status") or "completed"),
|
||||||
created_at=str(payload.get("created_at") or ""),
|
created_at=str(payload.get("created_at") or ""),
|
||||||
|
eval_version=str(payload.get("eval_version") or "heuristic-v1"),
|
||||||
|
mode=str(payload.get("mode") or "heuristic"),
|
||||||
|
execution_coverage=_bounded_float(payload.get("execution_coverage"), default=0.0),
|
||||||
|
surrogate_coverage=_bounded_float(payload.get("surrogate_coverage"), default=0.0),
|
||||||
|
blocked_coverage=_bounded_float(payload.get("blocked_coverage"), default=0.0),
|
||||||
|
confidence=str(payload.get("confidence") or "low"),
|
||||||
|
case_reports=[
|
||||||
|
dict(item)
|
||||||
|
for item in payload.get("case_reports") or []
|
||||||
|
if isinstance(item, dict)
|
||||||
|
],
|
||||||
|
tool_mode_summary=dict(payload.get("tool_mode_summary") or {}),
|
||||||
|
preservation_report=(
|
||||||
|
dict(payload["preservation_report"])
|
||||||
|
if isinstance(payload.get("preservation_report"), dict)
|
||||||
|
else None
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -272,6 +309,15 @@ def _optional_str(value: Any) -> str | None:
|
|||||||
return str(value)
|
return str(value)
|
||||||
|
|
||||||
|
|
||||||
|
def _bounded_float(value: Any, *, default: float = 0.0) -> float:
|
||||||
|
if value in (None, ""):
|
||||||
|
return default
|
||||||
|
try:
|
||||||
|
return max(0.0, min(1.0, float(value)))
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return default
|
||||||
|
|
||||||
|
|
||||||
def _summarize_evidence(payload: dict[str, Any]) -> str:
|
def _summarize_evidence(payload: dict[str, Any]) -> str:
|
||||||
evidence = payload.get("evidence")
|
evidence = payload.get("evidence")
|
||||||
if isinstance(evidence, dict):
|
if isinstance(evidence, dict):
|
||||||
|
|||||||
5
app-instance/backend/beaver/prompts/__init__.py
Normal file
5
app-instance/backend/beaver/prompts/__init__.py
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
"""Prompt templates used by Beaver runtime components."""
|
||||||
|
|
||||||
|
from .main_agent import get_main_agent_prompt
|
||||||
|
|
||||||
|
__all__ = ["get_main_agent_prompt"]
|
||||||
55
app-instance/backend/beaver/prompts/main_agent.py
Normal file
55
app-instance/backend/beaver/prompts/main_agent.py
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
"""Locale-aware main agent prompt loading."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from functools import lru_cache
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
DEFAULT_MAIN_AGENT_PROMPT_LOCALE = "zh-Hans"
|
||||||
|
|
||||||
|
_PROMPT_FILES = {
|
||||||
|
"zh-Hans": "zh-Hans.md",
|
||||||
|
"zh-Hant": "zh-Hant.md",
|
||||||
|
"en": "en.md",
|
||||||
|
}
|
||||||
|
|
||||||
|
_LOCALE_ALIASES = {
|
||||||
|
"zh": "zh-Hans",
|
||||||
|
"zh-cn": "zh-Hans",
|
||||||
|
"zh-hans": "zh-Hans",
|
||||||
|
"zh-sg": "zh-Hans",
|
||||||
|
"zh-hant": "zh-Hant",
|
||||||
|
"zh-tw": "zh-Hant",
|
||||||
|
"zh-hk": "zh-Hant",
|
||||||
|
"zh-mo": "zh-Hant",
|
||||||
|
"en": "en",
|
||||||
|
"en-us": "en",
|
||||||
|
"en-gb": "en",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_main_agent_prompt(locale: str | None = None) -> str:
|
||||||
|
"""Return the main-agent identity prompt for a prompt locale."""
|
||||||
|
|
||||||
|
prompt_locale = normalize_main_agent_prompt_locale(locale)
|
||||||
|
return _load_main_agent_prompt(prompt_locale)
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_main_agent_prompt_locale(locale: str | None = None) -> str:
|
||||||
|
cleaned = (locale or DEFAULT_MAIN_AGENT_PROMPT_LOCALE).strip()
|
||||||
|
if not cleaned:
|
||||||
|
return DEFAULT_MAIN_AGENT_PROMPT_LOCALE
|
||||||
|
normalized = _LOCALE_ALIASES.get(cleaned.lower())
|
||||||
|
if normalized:
|
||||||
|
return normalized
|
||||||
|
return cleaned if cleaned in _PROMPT_FILES else DEFAULT_MAIN_AGENT_PROMPT_LOCALE
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=len(_PROMPT_FILES))
|
||||||
|
def _load_main_agent_prompt(locale: str) -> str:
|
||||||
|
filename = _PROMPT_FILES.get(locale, _PROMPT_FILES[DEFAULT_MAIN_AGENT_PROMPT_LOCALE])
|
||||||
|
path = Path(__file__).with_name("main_agent") / filename
|
||||||
|
if not path.exists():
|
||||||
|
fallback_path = Path(__file__).with_name("main_agent") / _PROMPT_FILES[DEFAULT_MAIN_AGENT_PROMPT_LOCALE]
|
||||||
|
return fallback_path.read_text(encoding="utf-8").strip()
|
||||||
|
return path.read_text(encoding="utf-8").strip()
|
||||||
7
app-instance/backend/beaver/prompts/main_agent/en.md
Normal file
7
app-instance/backend/beaver/prompts/main_agent/en.md
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
You are Beaver, an AI assistant developed by Boway Information Systems Co., Ltd.
|
||||||
|
|
||||||
|
When communicating with users, keep this identity consistent. If users ask who you are, say that you are Beaver, an AI assistant developed by Boway Information Systems Co., Ltd.
|
||||||
|
|
||||||
|
# Language
|
||||||
|
|
||||||
|
Use English for user-facing replies, task titles, summaries, plans, and final reports while this prompt is active. If the user explicitly asks for another language, follow that request.
|
||||||
@ -0,0 +1,7 @@
|
|||||||
|
你是海狸 (Beaver),博维资讯系统有限公司研发的 AI 助手。
|
||||||
|
|
||||||
|
与用户沟通时,保持这个身份一致。用户问你是谁时,说明你是海狸 (Beaver),博维资讯系统有限公司研发的 AI 助手。
|
||||||
|
|
||||||
|
# 语言
|
||||||
|
|
||||||
|
使用简体中文进行面向用户的回复、任务标题、摘要、计划和最终报告。若用户明确要求其他语言,则按用户要求执行。
|
||||||
@ -0,0 +1,7 @@
|
|||||||
|
你是海狸 (Beaver),博維資訊系統有限公司研發的 AI 助手。
|
||||||
|
|
||||||
|
與使用者溝通時,保持這個身份一致。使用者問你是誰時,說明你是海狸 (Beaver),博維資訊系統有限公司研發的 AI 助手。
|
||||||
|
|
||||||
|
# 語言
|
||||||
|
|
||||||
|
使用繁體中文進行面向使用者的回覆、任務標題、摘要、計劃和最終報告。若使用者明確要求其他語言,則按使用者要求執行。
|
||||||
@ -22,7 +22,17 @@ from beaver.engine import AgentLoop, AgentProfile, AgentRunResult, EngineLoader
|
|||||||
from beaver.engine.providers import make_provider_bundle
|
from beaver.engine.providers import make_provider_bundle
|
||||||
from beaver.foundation.events import InboundMessage, OutboundMessage
|
from beaver.foundation.events import InboundMessage, OutboundMessage
|
||||||
from beaver.foundation.models import CronJob, CronRunRecord
|
from beaver.foundation.models import CronJob, CronRunRecord
|
||||||
from beaver.tasks import MainAgentRouter, TaskExecutionPlan, TaskRecord, ValidationResult
|
from beaver.prompts.main_agent import normalize_main_agent_prompt_locale
|
||||||
|
from beaver.tasks import (
|
||||||
|
EvidenceBuilder,
|
||||||
|
MainAgentRouter,
|
||||||
|
RunEvidence,
|
||||||
|
TaskEvidencePacket,
|
||||||
|
TaskExecutionPlan,
|
||||||
|
TaskRecord,
|
||||||
|
render_task_evidence,
|
||||||
|
)
|
||||||
|
from beaver.tasks.service import normalize_acceptance_type
|
||||||
|
|
||||||
|
|
||||||
NOTIFICATION_SESSION_ID = "notify:default:scheduled"
|
NOTIFICATION_SESSION_ID = "notify:default:scheduled"
|
||||||
@ -51,11 +61,27 @@ class AgentService:
|
|||||||
) -> None:
|
) -> None:
|
||||||
self.profile = profile or AgentProfile()
|
self.profile = profile or AgentProfile()
|
||||||
self.loader = loader or EngineLoader(workspace=workspace, config_path=config_path)
|
self.loader = loader or EngineLoader(workspace=workspace, config_path=config_path)
|
||||||
|
self._apply_configured_profile_defaults()
|
||||||
self._loop: AgentLoop | None = None
|
self._loop: AgentLoop | None = None
|
||||||
self._run_task: asyncio.Task[None] | None = None
|
self._run_task: asyncio.Task[None] | None = None
|
||||||
self._main_agent_router = MainAgentRouter()
|
self._main_agent_router = MainAgentRouter()
|
||||||
self._runtime_services: dict[str, Any] = {}
|
self._runtime_services: dict[str, Any] = {}
|
||||||
|
|
||||||
|
def _apply_configured_profile_defaults(self) -> None:
|
||||||
|
defaults = self.loader.config.agents_defaults
|
||||||
|
self.profile.max_tokens = None
|
||||||
|
self.profile.temperature = 0.2
|
||||||
|
self.profile.max_context_messages = 1000
|
||||||
|
self.profile.max_tool_iterations = 30
|
||||||
|
if defaults.max_tokens is not None:
|
||||||
|
self.profile.max_tokens = max(1, defaults.max_tokens)
|
||||||
|
if defaults.temperature is not None:
|
||||||
|
self.profile.temperature = defaults.temperature
|
||||||
|
if defaults.max_context_messages is not None:
|
||||||
|
self.profile.max_context_messages = max(1, defaults.max_context_messages)
|
||||||
|
if defaults.max_tool_iterations is not None:
|
||||||
|
self.profile.max_tool_iterations = max(0, defaults.max_tool_iterations)
|
||||||
|
|
||||||
def create_loop(self) -> AgentLoop:
|
def create_loop(self) -> AgentLoop:
|
||||||
"""创建并缓存当前 service 使用的 AgentLoop。"""
|
"""创建并缓存当前 service 使用的 AgentLoop。"""
|
||||||
|
|
||||||
@ -223,7 +249,7 @@ class AgentService:
|
|||||||
|
|
||||||
Scheduled jobs are product-level Tasks, not hidden one-off agent turns.
|
Scheduled jobs are product-level Tasks, not hidden one-off agent turns.
|
||||||
This entry bypasses the main-agent classifier and forces Task mode so
|
This entry bypasses the main-agent classifier and forces Task mode so
|
||||||
every trigger produces a TaskRecord, validation, feedback state, and a
|
every trigger produces a TaskRecord, evidence, acceptance state, and a
|
||||||
run_id that the scheduled-task history can link to.
|
run_id that the scheduled-task history can link to.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -271,9 +297,9 @@ class AgentService:
|
|||||||
result.run_id,
|
result.run_id,
|
||||||
{
|
{
|
||||||
"message_type": "scheduled_reply",
|
"message_type": "scheduled_reply",
|
||||||
"scheduled_job_id": job.id,
|
"scheduled_job_id": cron_job_id,
|
||||||
"scheduled_run_id": run.scheduled_run_id,
|
"scheduled_run_id": scheduled_run_id,
|
||||||
"cron_job_name": job.name,
|
"cron_job_name": cron_job_name,
|
||||||
"mode": "notification",
|
"mode": "notification",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@ -394,15 +420,15 @@ class AgentService:
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
async def submit_feedback(
|
async def submit_acceptance(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
session_id: str,
|
session_id: str,
|
||||||
run_id: str,
|
run_id: str,
|
||||||
feedback_type: str,
|
acceptance_type: str,
|
||||||
comment: str | None = None,
|
comment: str | None = None,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Record chat feedback for the internal task linked to a run."""
|
"""Record user acceptance for the internal task linked to a run."""
|
||||||
|
|
||||||
loaded = self.create_loop().boot()
|
loaded = self.create_loop().boot()
|
||||||
task_service = self._require_loaded(loaded, "task_service")
|
task_service = self._require_loaded(loaded, "task_service")
|
||||||
@ -410,32 +436,31 @@ class AgentService:
|
|||||||
if task is None or task.session_id != session_id:
|
if task is None or task.session_id != session_id:
|
||||||
raise ValueError(f"No internal task found for run_id={run_id!r}")
|
raise ValueError(f"No internal task found for run_id={run_id!r}")
|
||||||
|
|
||||||
normalized = feedback_type.strip().lower()
|
normalized = normalize_acceptance_type(acceptance_type)
|
||||||
if normalized not in {"satisfied", "revise", "abandon"}:
|
legacy_feedback_type = "satisfied" if normalized == "accept" else normalized
|
||||||
raise ValueError("feedback_type must be one of: satisfied, revise, abandon")
|
|
||||||
|
|
||||||
already_recorded = any(
|
already_recorded = any(
|
||||||
item.get("run_id") == run_id and item.get("feedback_type") == normalized
|
item.get("run_id") == run_id and item.get("acceptance_type") == normalized
|
||||||
for item in task.feedback
|
for item in task.feedback
|
||||||
)
|
)
|
||||||
conflicting_feedback = next(
|
conflicting_acceptance = next(
|
||||||
(
|
(
|
||||||
item
|
item
|
||||||
for item in task.feedback
|
for item in task.feedback
|
||||||
if item.get("run_id") == run_id and item.get("feedback_type") != normalized
|
if item.get("run_id") == run_id and item.get("acceptance_type") != normalized
|
||||||
),
|
),
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
if conflicting_feedback is not None:
|
if conflicting_acceptance is not None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Feedback for run_id={run_id!r} was already recorded as "
|
f"Acceptance for run_id={run_id!r} was already recorded as "
|
||||||
f"{conflicting_feedback.get('feedback_type')!r}"
|
f"{conflicting_acceptance.get('acceptance_type')!r}"
|
||||||
)
|
)
|
||||||
if task.status in {"closed", "abandoned"} and not already_recorded:
|
if task.status in {"closed", "abandoned"} and not already_recorded:
|
||||||
raise ValueError(f"Task {task.task_id} is already finalized as {task.status!r}")
|
raise ValueError(f"Task {task.task_id} is already finalized as {task.status!r}")
|
||||||
updated = task if already_recorded else task_service.add_feedback(
|
updated = task if already_recorded else task_service.add_acceptance(
|
||||||
task.task_id,
|
task.task_id,
|
||||||
feedback_type=normalized,
|
acceptance_type=normalized,
|
||||||
comment=comment,
|
comment=comment,
|
||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
)
|
)
|
||||||
@ -446,7 +471,8 @@ class AgentService:
|
|||||||
{
|
{
|
||||||
"task_id": updated.task_id,
|
"task_id": updated.task_id,
|
||||||
"task_status": updated.status,
|
"task_status": updated.status,
|
||||||
"feedback_state": normalized,
|
"acceptance_state": normalized,
|
||||||
|
"feedback_state": legacy_feedback_type,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
if not already_recorded:
|
if not already_recorded:
|
||||||
@ -454,10 +480,11 @@ class AgentService:
|
|||||||
session_id,
|
session_id,
|
||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
role="system",
|
role="system",
|
||||||
event_type="task_feedback_recorded",
|
event_type="task_acceptance_recorded",
|
||||||
event_payload={
|
event_payload={
|
||||||
"task_id": task.task_id,
|
"task_id": task.task_id,
|
||||||
"feedback_type": normalized,
|
"acceptance_type": normalized,
|
||||||
|
"feedback_type": legacy_feedback_type,
|
||||||
"comment": comment,
|
"comment": comment,
|
||||||
"task_status": updated.status,
|
"task_status": updated.status,
|
||||||
},
|
},
|
||||||
@ -466,35 +493,36 @@ class AgentService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
generated_candidates = []
|
generated_candidates = []
|
||||||
validation = ValidationResult.from_dict(updated.validation_result)
|
|
||||||
if not already_recorded:
|
if not already_recorded:
|
||||||
run_memory_store = self._require_loaded(loaded, "run_memory_store")
|
run_memory_store = self._require_loaded(loaded, "run_memory_store")
|
||||||
feedback_payload = {
|
acceptance_payload = {
|
||||||
"feedback_type": normalized,
|
"acceptance_type": normalized,
|
||||||
|
"feedback_type": legacy_feedback_type,
|
||||||
"comment": comment or "",
|
"comment": comment or "",
|
||||||
"task_status": updated.status,
|
"task_status": updated.status,
|
||||||
|
"final_accepted_run_id": updated.metadata.get("final_accepted_run_id"),
|
||||||
}
|
}
|
||||||
run_memory_store.update_run_record(
|
run_memory_store.update_run_record(
|
||||||
run_id,
|
run_id,
|
||||||
success=normalized == "satisfied",
|
success=normalized == "accept",
|
||||||
feedback=feedback_payload,
|
feedback=acceptance_payload,
|
||||||
)
|
)
|
||||||
run_memory_store.update_skill_effects_for_run(
|
run_memory_store.update_skill_effects_for_run(
|
||||||
run_id,
|
run_id,
|
||||||
success=normalized == "satisfied",
|
success=normalized == "accept",
|
||||||
feedback_score=self._feedback_score_for_learning(normalized, validation),
|
feedback_score=self._acceptance_score_for_learning(normalized),
|
||||||
notes=(comment or normalized).strip(),
|
notes=(comment or normalized).strip(),
|
||||||
)
|
)
|
||||||
skill_learning_service = self._require_loaded(loaded, "skill_learning_service")
|
skill_learning_service = self._require_loaded(loaded, "skill_learning_service")
|
||||||
skill_learning_service.rescore_skill_versions()
|
skill_learning_service.rescore_skill_versions()
|
||||||
if already_recorded:
|
if already_recorded:
|
||||||
generated_candidates = []
|
generated_candidates = []
|
||||||
elif normalized == "satisfied" and validation is not None and validation.accepted:
|
elif normalized == "accept":
|
||||||
generated_candidates = [
|
generated_candidates = [
|
||||||
item.to_dict()
|
item.to_dict()
|
||||||
for item in skill_learning_service.build_learning_candidates_for_task(
|
for item in skill_learning_service.build_learning_candidates_for_task(
|
||||||
updated.task_id,
|
updated.task_id,
|
||||||
trigger_run_id=run_id,
|
final_accepted_run_id=run_id,
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
elif normalized == "abandon":
|
elif normalized == "abandon":
|
||||||
@ -505,7 +533,8 @@ class AgentService:
|
|||||||
event_type="task_failure_evidence_recorded",
|
event_type="task_failure_evidence_recorded",
|
||||||
event_payload={
|
event_payload={
|
||||||
"task_id": updated.task_id,
|
"task_id": updated.task_id,
|
||||||
"feedback_type": normalized,
|
"acceptance_type": normalized,
|
||||||
|
"feedback_type": legacy_feedback_type,
|
||||||
"comment": comment or "",
|
"comment": comment or "",
|
||||||
"task_status": updated.status,
|
"task_status": updated.status,
|
||||||
"durable_memory_written": False,
|
"durable_memory_written": False,
|
||||||
@ -519,10 +548,28 @@ class AgentService:
|
|||||||
"run_id": run_id,
|
"run_id": run_id,
|
||||||
"task_id": updated.task_id,
|
"task_id": updated.task_id,
|
||||||
"task_status": updated.status,
|
"task_status": updated.status,
|
||||||
"feedback_type": normalized,
|
"acceptance_type": normalized,
|
||||||
|
"feedback_type": legacy_feedback_type,
|
||||||
"learning_candidates": generated_candidates,
|
"learning_candidates": generated_candidates,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async def submit_feedback(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
session_id: str,
|
||||||
|
run_id: str,
|
||||||
|
feedback_type: str,
|
||||||
|
comment: str | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Backward-compatible wrapper for older clients."""
|
||||||
|
|
||||||
|
return await self.submit_acceptance(
|
||||||
|
session_id=session_id,
|
||||||
|
run_id=run_id,
|
||||||
|
acceptance_type=feedback_type,
|
||||||
|
comment=comment,
|
||||||
|
)
|
||||||
|
|
||||||
async def _process_with_main_agent(
|
async def _process_with_main_agent(
|
||||||
self,
|
self,
|
||||||
message: str,
|
message: str,
|
||||||
@ -558,6 +605,8 @@ class AgentService:
|
|||||||
if active_task is not None and decision.short_title and not active_task.metadata.get("short_title"):
|
if active_task is not None and decision.short_title and not active_task.metadata.get("short_title"):
|
||||||
active_task.metadata["short_title"] = decision.short_title
|
active_task.metadata["short_title"] = decision.short_title
|
||||||
task_service.store.upsert_task(active_task)
|
task_service.store.upsert_task(active_task)
|
||||||
|
if active_task is not None and (decision.action == "simple_chat" or decision.starts_new_task):
|
||||||
|
await self._accept_active_task_for_new_topic(active_task)
|
||||||
if active_task is not None and decision.closes_task:
|
if active_task is not None and decision.closes_task:
|
||||||
task_service.close_task(active_task.task_id, reason=decision.reason)
|
task_service.close_task(active_task.task_id, reason=decision.reason)
|
||||||
return await runner(message, **kwargs)
|
return await runner(message, **kwargs)
|
||||||
@ -574,6 +623,7 @@ class AgentService:
|
|||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
description=message,
|
description=message,
|
||||||
metadata={
|
metadata={
|
||||||
|
"prompt_locale": normalize_main_agent_prompt_locale(kwargs.get("prompt_locale")),
|
||||||
"router_reason": decision.reason,
|
"router_reason": decision.reason,
|
||||||
**({"short_title": decision.short_title} if decision.short_title else {}),
|
**({"short_title": decision.short_title} if decision.short_title else {}),
|
||||||
},
|
},
|
||||||
@ -582,7 +632,7 @@ class AgentService:
|
|||||||
else active_task
|
else active_task
|
||||||
)
|
)
|
||||||
if active_task is not None and decision.action == "revise_task" and task.task_id == active_task.task_id:
|
if active_task is not None and decision.action == "revise_task" and task.task_id == active_task.task_id:
|
||||||
task = self._record_revision_feedback_for_task(
|
task = self._record_revision_acceptance_for_task(
|
||||||
loaded,
|
loaded,
|
||||||
task=task,
|
task=task,
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
@ -590,7 +640,21 @@ class AgentService:
|
|||||||
)
|
)
|
||||||
return await self._run_task_mode(message, runner=runner, kwargs=kwargs, task=task)
|
return await self._run_task_mode(message, runner=runner, kwargs=kwargs, task=task)
|
||||||
|
|
||||||
def _record_revision_feedback_for_task(
|
async def _accept_active_task_for_new_topic(self, task: TaskRecord) -> None:
|
||||||
|
"""Accept a completed active Task before routing an unrelated new topic."""
|
||||||
|
|
||||||
|
if task.status != "awaiting_acceptance":
|
||||||
|
return
|
||||||
|
run_id = next((item for item in reversed(task.run_ids) if item), None)
|
||||||
|
if not run_id:
|
||||||
|
return
|
||||||
|
await self.submit_acceptance(
|
||||||
|
session_id=task.session_id,
|
||||||
|
run_id=run_id,
|
||||||
|
acceptance_type="accept",
|
||||||
|
)
|
||||||
|
|
||||||
|
def _record_revision_acceptance_for_task(
|
||||||
self,
|
self,
|
||||||
loaded: Any,
|
loaded: Any,
|
||||||
*,
|
*,
|
||||||
@ -598,9 +662,9 @@ class AgentService:
|
|||||||
session_id: str,
|
session_id: str,
|
||||||
comment: str,
|
comment: str,
|
||||||
) -> TaskRecord:
|
) -> TaskRecord:
|
||||||
"""Mark the latest feedback-eligible run as revised before continuing a task."""
|
"""Mark the latest acceptance-eligible run as revised before continuing a task."""
|
||||||
|
|
||||||
if task.status not in {"awaiting_feedback", "needs_revision"}:
|
if task.status not in {"awaiting_acceptance", "needs_revision"}:
|
||||||
return task
|
return task
|
||||||
run_id = next((item for item in reversed(task.run_ids) if item), None)
|
run_id = next((item for item in reversed(task.run_ids) if item), None)
|
||||||
if not run_id:
|
if not run_id:
|
||||||
@ -608,15 +672,15 @@ class AgentService:
|
|||||||
|
|
||||||
existing = next((item for item in task.feedback if item.get("run_id") == run_id), None)
|
existing = next((item for item in task.feedback if item.get("run_id") == run_id), None)
|
||||||
if existing is not None:
|
if existing is not None:
|
||||||
if existing.get("feedback_type") != "revise":
|
if existing.get("acceptance_type") != "revise":
|
||||||
return task
|
return task
|
||||||
updated = task
|
updated = task
|
||||||
already_recorded = True
|
already_recorded = True
|
||||||
else:
|
else:
|
||||||
task_service = self._require_loaded(loaded, "task_service")
|
task_service = self._require_loaded(loaded, "task_service")
|
||||||
updated = task_service.add_feedback(
|
updated = task_service.add_acceptance(
|
||||||
task.task_id,
|
task.task_id,
|
||||||
feedback_type="revise",
|
acceptance_type="revise",
|
||||||
comment=comment,
|
comment=comment,
|
||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
)
|
)
|
||||||
@ -629,6 +693,7 @@ class AgentService:
|
|||||||
{
|
{
|
||||||
"task_id": updated.task_id,
|
"task_id": updated.task_id,
|
||||||
"task_status": updated.status,
|
"task_status": updated.status,
|
||||||
|
"acceptance_state": "revise",
|
||||||
"feedback_state": "revise",
|
"feedback_state": "revise",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@ -639,9 +704,10 @@ class AgentService:
|
|||||||
session_id,
|
session_id,
|
||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
role="system",
|
role="system",
|
||||||
event_type="task_feedback_recorded",
|
event_type="task_acceptance_recorded",
|
||||||
event_payload={
|
event_payload={
|
||||||
"task_id": updated.task_id,
|
"task_id": updated.task_id,
|
||||||
|
"acceptance_type": "revise",
|
||||||
"feedback_type": "revise",
|
"feedback_type": "revise",
|
||||||
"comment": comment,
|
"comment": comment,
|
||||||
"task_status": updated.status,
|
"task_status": updated.status,
|
||||||
@ -650,12 +716,12 @@ class AgentService:
|
|||||||
content=comment,
|
content=comment,
|
||||||
context_visible=False,
|
context_visible=False,
|
||||||
)
|
)
|
||||||
validation = ValidationResult.from_dict(updated.validation_result)
|
|
||||||
run_memory_store = self._require_loaded(loaded, "run_memory_store")
|
run_memory_store = self._require_loaded(loaded, "run_memory_store")
|
||||||
run_memory_store.update_run_record(
|
run_memory_store.update_run_record(
|
||||||
run_id,
|
run_id,
|
||||||
success=False,
|
success=False,
|
||||||
feedback={
|
feedback={
|
||||||
|
"acceptance_type": "revise",
|
||||||
"feedback_type": "revise",
|
"feedback_type": "revise",
|
||||||
"comment": comment,
|
"comment": comment,
|
||||||
"task_status": updated.status,
|
"task_status": updated.status,
|
||||||
@ -664,7 +730,7 @@ class AgentService:
|
|||||||
run_memory_store.update_skill_effects_for_run(
|
run_memory_store.update_skill_effects_for_run(
|
||||||
run_id,
|
run_id,
|
||||||
success=False,
|
success=False,
|
||||||
feedback_score=self._feedback_score_for_learning("revise", validation),
|
feedback_score=self._acceptance_score_for_learning("revise"),
|
||||||
notes=comment.strip() or "revise",
|
notes=comment.strip() or "revise",
|
||||||
)
|
)
|
||||||
skill_learning_service = self._require_loaded(loaded, "skill_learning_service")
|
skill_learning_service = self._require_loaded(loaded, "skill_learning_service")
|
||||||
@ -681,26 +747,23 @@ class AgentService:
|
|||||||
) -> AgentRunResult:
|
) -> AgentRunResult:
|
||||||
loaded = self.create_loop().boot()
|
loaded = self.create_loop().boot()
|
||||||
task_service = self._require_loaded(loaded, "task_service")
|
task_service = self._require_loaded(loaded, "task_service")
|
||||||
validation_service = self._require_loaded(loaded, "validation_service")
|
|
||||||
task_execution_planner = self._require_loaded(loaded, "task_execution_planner")
|
task_execution_planner = self._require_loaded(loaded, "task_execution_planner")
|
||||||
session_manager = self._require_loaded(loaded, "session_manager")
|
session_manager = self._require_loaded(loaded, "session_manager")
|
||||||
run_memory_store = self._require_loaded(loaded, "run_memory_store")
|
|
||||||
|
|
||||||
last_result: AgentRunResult | None = None
|
|
||||||
latest_validation: ValidationResult | None = None
|
|
||||||
base_execution_context = kwargs.get("execution_context")
|
base_execution_context = kwargs.get("execution_context")
|
||||||
|
prompt_locale = kwargs.get("prompt_locale") or task.metadata.get("prompt_locale")
|
||||||
|
output_language_instruction = self._output_language_instruction(prompt_locale)
|
||||||
provider_bundle = kwargs.get("provider_bundle") or self._make_provider_bundle_for_task(loaded, kwargs)
|
provider_bundle = kwargs.get("provider_bundle") or self._make_provider_bundle_for_task(loaded, kwargs)
|
||||||
kwargs = dict(kwargs)
|
kwargs = dict(kwargs)
|
||||||
team_provider_bundle_factory = kwargs.pop("team_provider_bundle_factory", None)
|
team_provider_bundle_factory = kwargs.pop("team_provider_bundle_factory", None)
|
||||||
kwargs["provider_bundle"] = provider_bundle
|
kwargs["provider_bundle"] = provider_bundle
|
||||||
|
|
||||||
for attempt_index in (1, 2):
|
attempt_index = int(task.metadata.get("latest_attempt_index") or 0) + 1
|
||||||
task_service.start_run(task.task_id, user_message=message, attempt_index=attempt_index)
|
task_service.start_run(task.task_id, user_message=message, attempt_index=attempt_index)
|
||||||
plan = await task_execution_planner.plan(
|
plan = await task_execution_planner.plan(
|
||||||
task=task,
|
task=task,
|
||||||
user_message=message,
|
user_message=message,
|
||||||
attempt_index=attempt_index,
|
attempt_index=attempt_index,
|
||||||
latest_validation=latest_validation,
|
|
||||||
provider_bundle=provider_bundle,
|
provider_bundle=provider_bundle,
|
||||||
)
|
)
|
||||||
self._append_task_observation(
|
self._append_task_observation(
|
||||||
@ -715,6 +778,7 @@ class AgentService:
|
|||||||
)
|
)
|
||||||
team_summaries: list[str] = []
|
team_summaries: list[str] = []
|
||||||
team_execution_context = ""
|
team_execution_context = ""
|
||||||
|
team_result: TeamRunResult | None = None
|
||||||
if plan.is_team:
|
if plan.is_team:
|
||||||
team_result, team_error = await self._run_team_for_task(
|
team_result, team_error = await self._run_team_for_task(
|
||||||
plan,
|
plan,
|
||||||
@ -725,7 +789,18 @@ class AgentService:
|
|||||||
)
|
)
|
||||||
if team_result is not None:
|
if team_result is not None:
|
||||||
team_summaries = [self._team_summary_for_validation(team_result)]
|
team_summaries = [self._team_summary_for_validation(team_result)]
|
||||||
team_execution_context = self._team_execution_context(plan, team_result)
|
team_packet = TaskEvidencePacket(
|
||||||
|
task_id=task.task_id,
|
||||||
|
attempt_index=attempt_index,
|
||||||
|
main_run=None,
|
||||||
|
team_runs=self._team_run_evidence(team_result),
|
||||||
|
team_node_results=list(team_result.node_results),
|
||||||
|
final_output="",
|
||||||
|
)
|
||||||
|
team_execution_context = self._join_context(
|
||||||
|
self._team_execution_context(plan, team_result),
|
||||||
|
"Rendered team evidence:\n" + render_task_evidence(team_packet),
|
||||||
|
)
|
||||||
self._append_task_observation(
|
self._append_task_observation(
|
||||||
session_manager,
|
session_manager,
|
||||||
task.session_id,
|
task.session_id,
|
||||||
@ -772,27 +847,23 @@ class AgentService:
|
|||||||
"allow_candidate_generation": False,
|
"allow_candidate_generation": False,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
if attempt_index == 2 and latest_validation is not None:
|
|
||||||
revision_context = latest_validation.recommended_revision_prompt.strip()
|
|
||||||
if revision_context:
|
|
||||||
attempt_kwargs["execution_context"] = self._join_context(
|
attempt_kwargs["execution_context"] = self._join_context(
|
||||||
base_execution_context,
|
base_execution_context,
|
||||||
f"Task validation revision request:\n{revision_context}",
|
output_language_instruction,
|
||||||
team_execution_context,
|
team_execution_context,
|
||||||
)
|
)
|
||||||
elif team_execution_context:
|
if plan.is_team and team_execution_context:
|
||||||
attempt_kwargs["execution_context"] = self._join_context(base_execution_context, team_execution_context)
|
attempt_kwargs["include_tools"] = False
|
||||||
|
attempt_kwargs["max_tool_iterations"] = 0
|
||||||
attempt_kwargs["skill_selection_context"] = self._build_skill_selection_context(
|
attempt_kwargs["skill_selection_context"] = self._build_skill_selection_context(
|
||||||
task=task,
|
task=task,
|
||||||
user_message=message,
|
user_message=message,
|
||||||
attempt_index=attempt_index,
|
attempt_index=attempt_index,
|
||||||
latest_validation=latest_validation,
|
|
||||||
plan=plan,
|
plan=plan,
|
||||||
team_summaries=team_summaries,
|
team_summaries=team_summaries,
|
||||||
)
|
)
|
||||||
|
|
||||||
result = await runner(message, **attempt_kwargs)
|
result = await runner(message, **attempt_kwargs)
|
||||||
last_result = result
|
|
||||||
self._append_task_observation(
|
self._append_task_observation(
|
||||||
session_manager,
|
session_manager,
|
||||||
task.session_id,
|
task.session_id,
|
||||||
@ -810,53 +881,57 @@ class AgentService:
|
|||||||
result.run_id,
|
result.run_id,
|
||||||
skill_names=self._skill_names_for_run(loaded, result.run_id),
|
skill_names=self._skill_names_for_run(loaded, result.run_id),
|
||||||
)
|
)
|
||||||
validation = await validation_service.validate_task_result(
|
evidence_packet = self._build_task_evidence_packet(
|
||||||
|
session_manager=session_manager,
|
||||||
task=task,
|
task=task,
|
||||||
user_message=message,
|
attempt_index=attempt_index,
|
||||||
final_output=result.output_text,
|
result=result,
|
||||||
transcript_excerpt=self._run_excerpt(session_manager, result.session_id, result.run_id),
|
team_result=team_result,
|
||||||
tool_summaries=self._tool_summaries(session_manager, result.session_id, result.run_id),
|
|
||||||
team_summaries=team_summaries,
|
|
||||||
provider_bundle=provider_bundle,
|
|
||||||
)
|
)
|
||||||
latest_validation = validation
|
evidence_text = render_task_evidence(evidence_packet)
|
||||||
task = task_service.record_validation(task.task_id, result.run_id, validation)
|
evidence_debug = {
|
||||||
run_memory_store.update_run_record(result.run_id, validation_result=validation.to_dict())
|
"evidence_run_ids": [
|
||||||
|
item.run_id for item in [evidence_packet.main_run, *evidence_packet.team_runs] if item is not None
|
||||||
|
],
|
||||||
|
"evidence_session_ids": [
|
||||||
|
item.session_id
|
||||||
|
for item in [evidence_packet.main_run, *evidence_packet.team_runs]
|
||||||
|
if item is not None
|
||||||
|
],
|
||||||
|
"tool_result_count": sum(
|
||||||
|
len(item.tool_results)
|
||||||
|
for item in [evidence_packet.main_run, *evidence_packet.team_runs]
|
||||||
|
if item is not None
|
||||||
|
),
|
||||||
|
"evidence_length": len(evidence_text),
|
||||||
|
}
|
||||||
session_manager.update_latest_assistant_event_payload(
|
session_manager.update_latest_assistant_event_payload(
|
||||||
result.session_id,
|
result.session_id,
|
||||||
result.run_id,
|
result.run_id,
|
||||||
{
|
{
|
||||||
"task_id": task.task_id,
|
"task_id": task.task_id,
|
||||||
"task_status": task.status,
|
"task_status": task.status,
|
||||||
"validation_status": "passed" if validation.accepted else "failed",
|
"evidence_status": "recorded",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
session_manager.append_message(
|
session_manager.append_message(
|
||||||
result.session_id,
|
result.session_id,
|
||||||
run_id=result.run_id,
|
run_id=result.run_id,
|
||||||
role="system",
|
role="system",
|
||||||
event_type="task_validation_snapshotted",
|
event_type="task_evidence_recorded",
|
||||||
event_payload={
|
event_payload={
|
||||||
"task_id": task.task_id,
|
"task_id": task.task_id,
|
||||||
"attempt_index": attempt_index,
|
"attempt_index": attempt_index,
|
||||||
"validation_result": validation.to_dict(),
|
"evidence_debug": evidence_debug,
|
||||||
"retry_scheduled": not validation.accepted and attempt_index == 1,
|
|
||||||
},
|
},
|
||||||
content=validation.recommended_revision_prompt or None,
|
content=None,
|
||||||
context_visible=False,
|
context_visible=False,
|
||||||
)
|
)
|
||||||
if not validation.accepted and attempt_index == 1:
|
|
||||||
session_manager.set_run_context_visible(result.session_id, result.run_id, False)
|
|
||||||
result.task_id = task.task_id
|
result.task_id = task.task_id
|
||||||
result.task_status = task.status
|
result.task_status = task.status
|
||||||
result.validation_result = validation.to_dict()
|
result.validation_result = None
|
||||||
if validation.accepted or attempt_index == 2:
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
if last_result is None: # pragma: no cover - defensive
|
|
||||||
raise RuntimeError("Task mode did not produce a run result")
|
|
||||||
return last_result
|
|
||||||
|
|
||||||
async def _run_team_for_task(
|
async def _run_team_for_task(
|
||||||
self,
|
self,
|
||||||
plan: TaskExecutionPlan,
|
plan: TaskExecutionPlan,
|
||||||
@ -911,6 +986,24 @@ class AgentService:
|
|||||||
"short_title": decision.short_title,
|
"short_title": decision.short_title,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _output_language_instruction(prompt_locale: str | None) -> str:
|
||||||
|
locale = normalize_main_agent_prompt_locale(prompt_locale)
|
||||||
|
if locale == "en":
|
||||||
|
return (
|
||||||
|
"Output language: English. Use English for user-facing task titles, summaries, plans, "
|
||||||
|
"and final answers unless the user explicitly requests another language."
|
||||||
|
)
|
||||||
|
if locale == "zh-Hant":
|
||||||
|
return (
|
||||||
|
"輸出語言:繁體中文。除非使用者明確要求其他語言,所有面向使用者的任務標題、摘要、"
|
||||||
|
"計劃與最終回答都使用繁體中文。"
|
||||||
|
)
|
||||||
|
return (
|
||||||
|
"输出语言:简体中文。除非用户明确要求其他语言,所有面向用户的任务标题、摘要、"
|
||||||
|
"计划与最终回答都使用简体中文。"
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _skill_names_for_run(loaded: Any, run_id: str) -> list[str]:
|
def _skill_names_for_run(loaded: Any, run_id: str) -> list[str]:
|
||||||
store = getattr(loaded, "run_memory_store", None)
|
store = getattr(loaded, "run_memory_store", None)
|
||||||
@ -922,12 +1015,10 @@ class AgentService:
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _feedback_score_for_learning(feedback_type: str, validation: ValidationResult | None) -> float:
|
def _acceptance_score_for_learning(acceptance_type: str) -> float:
|
||||||
if feedback_type == "satisfied":
|
if acceptance_type == "accept":
|
||||||
if validation is not None:
|
|
||||||
return max(0.0, min(1.0, float(validation.score)))
|
|
||||||
return 1.0
|
return 1.0
|
||||||
if feedback_type == "revise":
|
if acceptance_type == "revise":
|
||||||
return 0.5
|
return 0.5
|
||||||
return 0.0
|
return 0.0
|
||||||
|
|
||||||
@ -937,12 +1028,11 @@ class AgentService:
|
|||||||
task: TaskRecord,
|
task: TaskRecord,
|
||||||
user_message: str,
|
user_message: str,
|
||||||
attempt_index: int,
|
attempt_index: int,
|
||||||
latest_validation: ValidationResult | None = None,
|
|
||||||
plan: TaskExecutionPlan | None = None,
|
plan: TaskExecutionPlan | None = None,
|
||||||
team_summaries: list[str] | None = None,
|
team_summaries: list[str] | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
phase = f"attempt_{attempt_index}"
|
phase = f"attempt_{attempt_index}"
|
||||||
if latest_validation is not None:
|
if task.feedback and task.feedback[-1].get("acceptance_type") == "revise":
|
||||||
phase = f"revision_attempt_{attempt_index}"
|
phase = f"revision_attempt_{attempt_index}"
|
||||||
elif plan is not None and plan.is_team:
|
elif plan is not None and plan.is_team:
|
||||||
phase = f"team_synthesis_attempt_{attempt_index}"
|
phase = f"team_synthesis_attempt_{attempt_index}"
|
||||||
@ -963,24 +1053,14 @@ class AgentService:
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
sections.append("Previously activated skills:\nNone")
|
sections.append("Previously activated skills:\nNone")
|
||||||
if latest_validation is not None:
|
if task.feedback:
|
||||||
validation_lines = [
|
history_lines = []
|
||||||
f"accepted: {latest_validation.accepted}",
|
for item in task.feedback[-5:]:
|
||||||
f"score: {latest_validation.score}",
|
kind = item.get("acceptance_type") or item.get("feedback_type")
|
||||||
]
|
comment = item.get("comment") or ""
|
||||||
if latest_validation.issues:
|
run_id = item.get("run_id") or ""
|
||||||
validation_lines.append("issues:\n" + "\n".join(f"- {item}" for item in latest_validation.issues))
|
history_lines.append(f"- {kind} run={run_id}: {comment}".strip())
|
||||||
if latest_validation.missing_requirements:
|
sections.append("Task acceptance history:\n" + "\n".join(history_lines))
|
||||||
validation_lines.append(
|
|
||||||
"missing requirements:\n"
|
|
||||||
+ "\n".join(f"- {item}" for item in latest_validation.missing_requirements)
|
|
||||||
)
|
|
||||||
if latest_validation.recommended_revision_prompt:
|
|
||||||
validation_lines.append(
|
|
||||||
"recommended revision:\n"
|
|
||||||
+ latest_validation.recommended_revision_prompt
|
|
||||||
)
|
|
||||||
sections.append("Validation feedback:\n" + "\n".join(validation_lines))
|
|
||||||
if plan is not None:
|
if plan is not None:
|
||||||
plan_lines = [
|
plan_lines = [
|
||||||
f"mode: {plan.mode}",
|
f"mode: {plan.mode}",
|
||||||
@ -1083,6 +1163,36 @@ class AgentService:
|
|||||||
payloads.append(payload)
|
payloads.append(payload)
|
||||||
return payloads
|
return payloads
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _team_run_evidence(result: TeamRunResult | None) -> list[RunEvidence]:
|
||||||
|
if result is None:
|
||||||
|
return []
|
||||||
|
return [node.evidence for node in result.node_results if node.evidence is not None]
|
||||||
|
|
||||||
|
def _build_task_evidence_packet(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
session_manager: Any,
|
||||||
|
task: TaskRecord,
|
||||||
|
attempt_index: int,
|
||||||
|
result: AgentRunResult,
|
||||||
|
team_result: TeamRunResult | None,
|
||||||
|
) -> TaskEvidencePacket:
|
||||||
|
main_run = EvidenceBuilder(session_manager).build_run_evidence(
|
||||||
|
result.session_id,
|
||||||
|
result.run_id,
|
||||||
|
result.output_text,
|
||||||
|
result.finish_reason,
|
||||||
|
)
|
||||||
|
return TaskEvidencePacket(
|
||||||
|
task_id=task.task_id,
|
||||||
|
attempt_index=attempt_index,
|
||||||
|
main_run=main_run,
|
||||||
|
team_runs=self._team_run_evidence(team_result),
|
||||||
|
team_node_results=list(team_result.node_results) if team_result is not None else [],
|
||||||
|
final_output=result.output_text,
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _team_execution_context(plan: TaskExecutionPlan, result: TeamRunResult) -> str:
|
def _team_execution_context(plan: TaskExecutionPlan, result: TeamRunResult) -> str:
|
||||||
node_lines = [
|
node_lines = [
|
||||||
@ -1168,17 +1278,19 @@ class AgentService:
|
|||||||
async def handle_inbound_message(self, inbound: InboundMessage) -> OutboundMessage:
|
async def handle_inbound_message(self, inbound: InboundMessage) -> OutboundMessage:
|
||||||
"""把 bus inbound 映射成标准 runtime 调用,并返回结构化 outbound。"""
|
"""把 bus inbound 映射成标准 runtime 调用,并返回结构化 outbound。"""
|
||||||
|
|
||||||
|
channel_identity = inbound.channel_identity
|
||||||
try:
|
try:
|
||||||
result = await self.submit_direct(
|
result = await self.submit_direct(
|
||||||
inbound.content,
|
inbound.content,
|
||||||
session_id=inbound.session_id,
|
session_id=inbound.session_id,
|
||||||
source=f"gateway:{inbound.channel}",
|
source=f"gateway:{inbound.channel}",
|
||||||
user_id=inbound.user_id,
|
user_id=inbound.user_id or (channel_identity.user_id if channel_identity else None),
|
||||||
title=inbound.title,
|
title=inbound.title,
|
||||||
execution_context=inbound.execution_context,
|
execution_context=inbound.execution_context,
|
||||||
model=inbound.model,
|
model=inbound.model,
|
||||||
provider_name=inbound.provider_name,
|
provider_name=inbound.provider_name,
|
||||||
embedding_model=inbound.embedding_model,
|
embedding_model=inbound.embedding_model,
|
||||||
|
channel_identity=channel_identity,
|
||||||
)
|
)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
return self.build_outbound_error(
|
return self.build_outbound_error(
|
||||||
@ -1214,12 +1326,15 @@ class AgentService:
|
|||||||
finish_reason=result.finish_reason,
|
finish_reason=result.finish_reason,
|
||||||
provider_name=result.provider_name,
|
provider_name=result.provider_name,
|
||||||
model=result.model,
|
model=result.model,
|
||||||
|
content_type=inbound.content_type,
|
||||||
|
channel_identity=inbound.channel_identity,
|
||||||
usage=dict(result.usage),
|
usage=dict(result.usage),
|
||||||
metadata={
|
metadata={
|
||||||
"inbound_metadata": dict(inbound.metadata),
|
"inbound_metadata": dict(inbound.metadata),
|
||||||
"task_id": getattr(result, "task_id", None),
|
"task_id": getattr(result, "task_id", None),
|
||||||
"task_status": getattr(result, "task_status", None),
|
"task_status": getattr(result, "task_status", None),
|
||||||
"validation_result": getattr(result, "validation_result", None),
|
"evidence_status": "recorded" if getattr(result, "task_id", None) else None,
|
||||||
|
"validation_result": None,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1238,6 +1353,8 @@ class AgentService:
|
|||||||
session_id=inbound.session_id,
|
session_id=inbound.session_id,
|
||||||
content=detail,
|
content=detail,
|
||||||
finish_reason=finish_reason,
|
finish_reason=finish_reason,
|
||||||
|
content_type=inbound.content_type,
|
||||||
|
channel_identity=inbound.channel_identity,
|
||||||
metadata={"error": detail, "inbound_metadata": dict(inbound.metadata)},
|
metadata={"error": detail, "inbound_metadata": dict(inbound.metadata)},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -50,10 +50,11 @@ class SessionProcessProjector:
|
|||||||
|
|
||||||
for record in records:
|
for record in records:
|
||||||
payload = dict(record.event_payload or {})
|
payload = dict(record.event_payload or {})
|
||||||
task_id = payload.get("task_id")
|
run_record_for_event = run_records.get(str(record.run_id)) if record.run_id else None
|
||||||
|
task_id = payload.get("task_id") or getattr(run_record_for_event, "task_id", None)
|
||||||
if not task_id:
|
if not task_id:
|
||||||
continue
|
continue
|
||||||
attempt_index = int(payload.get("attempt_index") or 1)
|
attempt_index = int(payload.get("attempt_index") or getattr(run_record_for_event, "attempt_index", None) or 1)
|
||||||
root_run_id = f"task:{task_id}:attempt:{attempt_index}"
|
root_run_id = f"task:{task_id}:attempt:{attempt_index}"
|
||||||
created_at = _timestamp(record.timestamp)
|
created_at = _timestamp(record.timestamp)
|
||||||
root = runs.setdefault(
|
root = runs.setdefault(
|
||||||
@ -73,15 +74,70 @@ class SessionProcessProjector:
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
if record.event_type == "task_execution_planned":
|
if record.event_type == "assistant_message_added" and record.tool_calls:
|
||||||
|
run_id = record.run_id or root_run_id
|
||||||
|
parent_run_id = root_run_id if run_id != root_run_id else None
|
||||||
|
for index, tool_call in enumerate(record.tool_calls):
|
||||||
|
if not isinstance(tool_call, dict):
|
||||||
|
continue
|
||||||
|
tool_name = _tool_call_name(tool_call)
|
||||||
|
add_event(
|
||||||
|
event_id=f"{_event_id(record, 'tool-call')}:{index}",
|
||||||
|
run_id=run_id,
|
||||||
|
parent_run_id=parent_run_id,
|
||||||
|
kind="tool_call_started",
|
||||||
|
actor_type="tool",
|
||||||
|
actor_id=tool_name,
|
||||||
|
actor_name=tool_name,
|
||||||
|
text=f"Calling tool: {tool_name}.",
|
||||||
|
created_at=created_at,
|
||||||
|
status="running",
|
||||||
|
metadata={
|
||||||
|
"task_id": task_id,
|
||||||
|
"attempt_index": attempt_index,
|
||||||
|
"timeline_type": "tool_call",
|
||||||
|
"tool_name": tool_name,
|
||||||
|
"tool_call_id": tool_call.get("id"),
|
||||||
|
"arguments": _tool_call_arguments(tool_call),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
elif record.event_type == "tool_result_recorded":
|
||||||
|
run_id = record.run_id or root_run_id
|
||||||
|
parent_run_id = root_run_id if run_id != root_run_id else None
|
||||||
|
tool_name = str(record.tool_name or payload.get("tool_name") or "tool")
|
||||||
|
add_event(
|
||||||
|
event_id=_event_id(record, "tool-result"),
|
||||||
|
run_id=run_id,
|
||||||
|
parent_run_id=parent_run_id,
|
||||||
|
kind="tool_call_finished",
|
||||||
|
actor_type="tool",
|
||||||
|
actor_id=tool_name,
|
||||||
|
actor_name=tool_name,
|
||||||
|
text=_truncate(str(record.content or payload.get("error") or "")),
|
||||||
|
created_at=created_at,
|
||||||
|
status="done" if payload.get("success", True) else "error",
|
||||||
|
metadata={
|
||||||
|
**dict(payload),
|
||||||
|
"task_id": task_id,
|
||||||
|
"attempt_index": attempt_index,
|
||||||
|
"timeline_type": "tool_result",
|
||||||
|
"tool_name": tool_name,
|
||||||
|
"tool_call_id": record.tool_call_id,
|
||||||
|
"result_summary": _truncate(str(record.content or payload.get("error") or "")),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
elif record.event_type == "task_execution_planned":
|
||||||
|
plan_mode = payload.get("plan_mode") or "single"
|
||||||
strategy = payload.get("strategy") or "single"
|
strategy = payload.get("strategy") or "single"
|
||||||
node_ids = payload.get("node_ids") or []
|
node_ids = payload.get("node_ids") or []
|
||||||
root["title"] = f"{payload.get('plan_mode', 'single')} plan: {strategy}"
|
root["title"] = f"{plan_mode} plan: {strategy}"
|
||||||
root["summary"] = payload.get("reason") or ""
|
root["summary"] = payload.get("reason") or ""
|
||||||
root["metadata"] = {
|
root["metadata"] = {
|
||||||
**root.get("metadata", {}),
|
**root.get("metadata", {}),
|
||||||
"plan_mode": payload.get("plan_mode"),
|
"plan_mode": plan_mode,
|
||||||
"strategy": payload.get("strategy"),
|
"strategy": strategy,
|
||||||
"node_ids": node_ids,
|
"node_ids": node_ids,
|
||||||
"skill_queries": payload.get("skill_queries") or [],
|
"skill_queries": payload.get("skill_queries") or [],
|
||||||
"selected_skill_names": payload.get("selected_skill_names") or [],
|
"selected_skill_names": payload.get("selected_skill_names") or [],
|
||||||
@ -92,36 +148,65 @@ class SessionProcessProjector:
|
|||||||
add_event(
|
add_event(
|
||||||
event_id=_event_id(record, "planned"),
|
event_id=_event_id(record, "planned"),
|
||||||
run_id=root_run_id,
|
run_id=root_run_id,
|
||||||
kind="run_started",
|
kind="task_planned",
|
||||||
actor_type="system",
|
actor_type="system",
|
||||||
actor_id="task",
|
actor_id="task",
|
||||||
actor_name="Task Planner",
|
actor_name="Task Planner",
|
||||||
text=f"Planned {payload.get('plan_mode')} execution via {strategy}. {payload.get('reason') or ''}".strip(),
|
text=f"Beaver planned {plan_mode} execution via {strategy}. {payload.get('reason') or ''}".strip(),
|
||||||
created_at=created_at,
|
created_at=created_at,
|
||||||
status="running",
|
status="running",
|
||||||
metadata=root["metadata"],
|
metadata={
|
||||||
|
**root["metadata"],
|
||||||
|
"timeline_type": "plan",
|
||||||
|
"user_summary": f"Beaver will use {plan_mode} execution for this task.",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
selected_skill_names = [
|
||||||
|
str(item)
|
||||||
|
for item in payload.get("selected_skill_names") or []
|
||||||
|
if str(item).strip()
|
||||||
|
]
|
||||||
|
if selected_skill_names:
|
||||||
|
add_event(
|
||||||
|
event_id=_event_id(record, "skills"),
|
||||||
|
run_id=root_run_id,
|
||||||
|
kind="skill_selected",
|
||||||
|
actor_type="system",
|
||||||
|
actor_id="skill-selector",
|
||||||
|
actor_name="Skill Selector",
|
||||||
|
text=f"Selected skill guidance: {', '.join(selected_skill_names)}.",
|
||||||
|
created_at=created_at,
|
||||||
|
status="done",
|
||||||
|
metadata={
|
||||||
|
"task_id": task_id,
|
||||||
|
"attempt_index": attempt_index,
|
||||||
|
"timeline_type": "skill",
|
||||||
|
"skill_names": selected_skill_names,
|
||||||
|
"reason": payload.get("reason") or "Selected from task planning context.",
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
elif record.event_type in {"task_team_run_completed", "task_team_run_failed"}:
|
elif record.event_type in {"task_team_run_completed", "task_team_run_failed"}:
|
||||||
team_success = bool(payload.get("team_success"))
|
team_success = bool(payload.get("team_success"))
|
||||||
root["status"] = "running"
|
root["status"] = "running"
|
||||||
|
team_run_ids = payload.get("team_run_ids") or []
|
||||||
root["metadata"] = {
|
root["metadata"] = {
|
||||||
**root.get("metadata", {}),
|
**root.get("metadata", {}),
|
||||||
"team_success": team_success,
|
"team_success": team_success,
|
||||||
"team_run_ids": payload.get("team_run_ids") or [],
|
"team_run_ids": team_run_ids,
|
||||||
"team_error": payload.get("error"),
|
"team_error": payload.get("error"),
|
||||||
}
|
}
|
||||||
add_event(
|
add_event(
|
||||||
event_id=_event_id(record, "team"),
|
event_id=_event_id(record, "team"),
|
||||||
run_id=root_run_id,
|
run_id=root_run_id,
|
||||||
kind="run_status",
|
kind="agent_team_created",
|
||||||
actor_type="system",
|
actor_type="system",
|
||||||
actor_id="team",
|
actor_id="team",
|
||||||
actor_name="Task Team",
|
actor_name="Task Team",
|
||||||
text=payload.get("error") or ("Team completed" if team_success else "Team completed with failed nodes"),
|
text=payload.get("error") or ("Team completed" if team_success else "Team completed with failed nodes"),
|
||||||
created_at=created_at,
|
created_at=created_at,
|
||||||
status="done" if team_success else "error",
|
status="done" if team_success else "error",
|
||||||
metadata=dict(payload),
|
metadata={**dict(payload), "timeline_type": "agent_team", "team_run_ids": team_run_ids},
|
||||||
)
|
)
|
||||||
node_results = payload.get("node_results") or []
|
node_results = payload.get("node_results") or []
|
||||||
for item in node_results:
|
for item in node_results:
|
||||||
@ -192,20 +277,26 @@ class SessionProcessProjector:
|
|||||||
event_id=f"{_event_id(record, 'node')}:{item.get('node_id')}",
|
event_id=f"{_event_id(record, 'node')}:{item.get('node_id')}",
|
||||||
run_id=str(node_run_id),
|
run_id=str(node_run_id),
|
||||||
parent_run_id=root_run_id,
|
parent_run_id=root_run_id,
|
||||||
kind="run_finished",
|
kind="agent_finished",
|
||||||
actor_type="agent",
|
actor_type="agent",
|
||||||
actor_id=str(item.get("node_id") or "sub-agent"),
|
actor_id=str(item.get("node_id") or "sub-agent"),
|
||||||
actor_name=str(item.get("node_id") or "Sub-agent"),
|
actor_name=str(item.get("node_id") or "Sub-agent"),
|
||||||
text=_truncate(str(item.get("output_text") or item.get("error") or "")),
|
text=_truncate(str(item.get("output_text") or item.get("error") or "")),
|
||||||
created_at=created_at,
|
created_at=created_at,
|
||||||
status=status,
|
status=status,
|
||||||
metadata=dict(item),
|
metadata={
|
||||||
|
**dict(item),
|
||||||
|
"task_id": task_id,
|
||||||
|
"attempt_index": attempt_index,
|
||||||
|
"timeline_type": "agent_progress",
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
elif record.event_type == "task_synthesis_completed":
|
elif record.event_type == "task_synthesis_completed":
|
||||||
main_run_id = str(payload.get("main_run_id") or "")
|
main_run_id = str(payload.get("main_run_id") or "")
|
||||||
if main_run_id:
|
if main_run_id:
|
||||||
run_record = run_records.get(main_run_id)
|
run_record = run_records.get(main_run_id)
|
||||||
|
activated_skill_names = _activated_skill_names(run_record)
|
||||||
runs[main_run_id] = {
|
runs[main_run_id] = {
|
||||||
"run_id": main_run_id,
|
"run_id": main_run_id,
|
||||||
"parent_run_id": root_run_id,
|
"parent_run_id": root_run_id,
|
||||||
@ -219,8 +310,32 @@ class SessionProcessProjector:
|
|||||||
"started_at": run_record.started_at if run_record is not None else created_at,
|
"started_at": run_record.started_at if run_record is not None else created_at,
|
||||||
"finished_at": run_record.ended_at if run_record is not None else created_at,
|
"finished_at": run_record.ended_at if run_record is not None else created_at,
|
||||||
"summary": _truncate(run_record.task_text if run_record is not None else ""),
|
"summary": _truncate(run_record.task_text if run_record is not None else ""),
|
||||||
"metadata": {"task_id": task_id, "attempt_index": attempt_index},
|
"metadata": {
|
||||||
|
"task_id": task_id,
|
||||||
|
"attempt_index": attempt_index,
|
||||||
|
"skill_names": activated_skill_names,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
if activated_skill_names:
|
||||||
|
add_event(
|
||||||
|
event_id=_event_id(record, "synthesis-skills"),
|
||||||
|
run_id=main_run_id,
|
||||||
|
parent_run_id=root_run_id,
|
||||||
|
kind="skill_selected",
|
||||||
|
actor_type="system",
|
||||||
|
actor_id="skill-selector",
|
||||||
|
actor_name="Skill Selector",
|
||||||
|
text=f"Selected skill guidance: {', '.join(activated_skill_names)}.",
|
||||||
|
created_at=created_at,
|
||||||
|
status="done",
|
||||||
|
metadata={
|
||||||
|
"task_id": task_id,
|
||||||
|
"attempt_index": attempt_index,
|
||||||
|
"timeline_type": "skill",
|
||||||
|
"skill_names": activated_skill_names,
|
||||||
|
"activation_reasons": _activated_skill_reasons(run_record),
|
||||||
|
},
|
||||||
|
)
|
||||||
add_event(
|
add_event(
|
||||||
event_id=_event_id(record, "synthesis"),
|
event_id=_event_id(record, "synthesis"),
|
||||||
run_id=main_run_id,
|
run_id=main_run_id,
|
||||||
@ -235,27 +350,46 @@ class SessionProcessProjector:
|
|||||||
metadata=dict(payload),
|
metadata=dict(payload),
|
||||||
)
|
)
|
||||||
|
|
||||||
elif record.event_type == "task_validation_snapshotted":
|
elif record.event_type == "task_evidence_recorded":
|
||||||
validation = payload.get("validation_result") if isinstance(payload.get("validation_result"), dict) else {}
|
root["status"] = "waiting"
|
||||||
accepted = bool(validation.get("accepted"))
|
root["finished_at"] = None
|
||||||
root["status"] = "done" if accepted or attempt_index == 2 else "waiting"
|
|
||||||
root["finished_at"] = created_at if root["status"] == "done" else None
|
|
||||||
add_event(
|
add_event(
|
||||||
event_id=_event_id(record, "validation"),
|
event_id=_event_id(record, "evidence"),
|
||||||
run_id=record.run_id or root_run_id,
|
run_id=record.run_id or root_run_id,
|
||||||
parent_run_id=root_run_id if record.run_id else None,
|
parent_run_id=root_run_id if record.run_id else None,
|
||||||
kind="run_status",
|
kind="task_result_ready",
|
||||||
actor_type="system",
|
actor_type="system",
|
||||||
actor_id="validator",
|
actor_id="evidence-recorder",
|
||||||
actor_name="Validator",
|
actor_name="Evidence",
|
||||||
text=(
|
text="The task result is ready for user acceptance.",
|
||||||
f"Validation {'passed' if accepted else 'failed'} "
|
|
||||||
f"(score={validation.get('score')})."
|
|
||||||
+ (" Retry scheduled." if payload.get("retry_scheduled") else "")
|
|
||||||
),
|
|
||||||
created_at=created_at,
|
created_at=created_at,
|
||||||
status="done" if accepted else "error",
|
status="done",
|
||||||
metadata=dict(payload),
|
metadata={**dict(payload), "timeline_type": "result"},
|
||||||
|
)
|
||||||
|
|
||||||
|
elif record.event_type == "task_acceptance_recorded":
|
||||||
|
acceptance_type = str(payload.get("acceptance_type") or payload.get("feedback_type") or "")
|
||||||
|
if acceptance_type == "accept":
|
||||||
|
root["status"] = "done"
|
||||||
|
root["finished_at"] = created_at
|
||||||
|
elif acceptance_type == "abandon":
|
||||||
|
root["status"] = "cancelled"
|
||||||
|
root["finished_at"] = created_at
|
||||||
|
else:
|
||||||
|
root["status"] = "waiting"
|
||||||
|
root["finished_at"] = None
|
||||||
|
add_event(
|
||||||
|
event_id=_event_id(record, "acceptance"),
|
||||||
|
run_id=record.run_id or root_run_id,
|
||||||
|
parent_run_id=root_run_id if record.run_id else None,
|
||||||
|
kind="task_acceptance_recorded",
|
||||||
|
actor_type="user",
|
||||||
|
actor_id="user-acceptance",
|
||||||
|
actor_name="User Acceptance",
|
||||||
|
text=f"User acceptance recorded: {acceptance_type or 'unknown'}.",
|
||||||
|
created_at=created_at,
|
||||||
|
status="done",
|
||||||
|
metadata={**dict(payload), "timeline_type": "acceptance"},
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
@ -281,3 +415,49 @@ def _truncate(text: str, limit: int = 800) -> str:
|
|||||||
if len(cleaned) <= limit:
|
if len(cleaned) <= limit:
|
||||||
return cleaned
|
return cleaned
|
||||||
return cleaned[: limit - 1] + "..."
|
return cleaned[: limit - 1] + "..."
|
||||||
|
|
||||||
|
|
||||||
|
def _activated_skill_names(run_record: Any | None) -> list[str]:
|
||||||
|
if run_record is None:
|
||||||
|
return []
|
||||||
|
names = []
|
||||||
|
for receipt in getattr(run_record, "activated_skills", []) or []:
|
||||||
|
skill_name = str(getattr(receipt, "skill_name", "") or "").strip()
|
||||||
|
if skill_name:
|
||||||
|
names.append(skill_name)
|
||||||
|
return list(dict.fromkeys(names))
|
||||||
|
|
||||||
|
|
||||||
|
def _activated_skill_reasons(run_record: Any | None) -> list[str]:
|
||||||
|
if run_record is None:
|
||||||
|
return []
|
||||||
|
reasons = []
|
||||||
|
for receipt in getattr(run_record, "activated_skills", []) or []:
|
||||||
|
reason = str(getattr(receipt, "activation_reason", "") or "").strip()
|
||||||
|
if reason:
|
||||||
|
reasons.append(reason)
|
||||||
|
return reasons
|
||||||
|
|
||||||
|
|
||||||
|
def _tool_call_name(tool_call: dict[str, Any]) -> str:
|
||||||
|
function_payload = tool_call.get("function")
|
||||||
|
if isinstance(function_payload, dict):
|
||||||
|
name = function_payload.get("name")
|
||||||
|
if name:
|
||||||
|
return str(name)
|
||||||
|
for key in ("name", "tool_name"):
|
||||||
|
value = tool_call.get(key)
|
||||||
|
if value:
|
||||||
|
return str(value)
|
||||||
|
return "tool"
|
||||||
|
|
||||||
|
|
||||||
|
def _tool_call_arguments(tool_call: dict[str, Any]) -> Any:
|
||||||
|
function_payload = tool_call.get("function")
|
||||||
|
if isinstance(function_payload, dict) and "arguments" in function_payload:
|
||||||
|
return function_payload.get("arguments")
|
||||||
|
if "arguments" in tool_call:
|
||||||
|
return tool_call.get("arguments")
|
||||||
|
if "args" in tool_call:
|
||||||
|
return tool_call.get("args")
|
||||||
|
return None
|
||||||
|
|||||||
@ -16,10 +16,10 @@ if TYPE_CHECKING:
|
|||||||
class TeamService:
|
class TeamService:
|
||||||
"""Internal service for Beaver-native multi-agent execution."""
|
"""Internal service for Beaver-native multi-agent execution."""
|
||||||
|
|
||||||
def __init__(self, loop: AgentLoop) -> None:
|
def __init__(self, loop: AgentLoop, *, max_parallel_team_nodes: int = 3) -> None:
|
||||||
self.loop = loop
|
self.loop = loop
|
||||||
self.runner = LocalAgentRunner(loop)
|
self.runner = LocalAgentRunner(loop)
|
||||||
self.scheduler = TeamGraphScheduler(self.runner)
|
self.scheduler = TeamGraphScheduler(self.runner, max_parallel_team_nodes=max_parallel_team_nodes)
|
||||||
|
|
||||||
async def run_team(
|
async def run_team(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -32,16 +32,23 @@ When there is an active task, do not force every new user message into that task
|
|||||||
|
|
||||||
- Choose `revise_task` when the user asks to change, correct, refine, expand, reformat, or redo the latest active task result.
|
- Choose `revise_task` when the user asks to change, correct, refine, expand, reformat, or redo the latest active task result.
|
||||||
- Choose `continue_task` for neutral follow-up questions or additional next steps that still belong to the active task.
|
- Choose `continue_task` for neutral follow-up questions or additional next steps that still belong to the active task.
|
||||||
- Choose `new_task` when the user asks for clearly unrelated work.
|
- Choose `simple_chat` for unrelated lightweight conversation. This starts a new topic and the previous task will be accepted automatically.
|
||||||
|
- Choose `new_task` when the user asks for clearly unrelated work that needs Task capabilities. This starts a new topic and the previous task will be accepted automatically.
|
||||||
- Choose `close_task` when the user says the task is satisfactory or finished, such as "可以了", "就这样", or "that's good".
|
- Choose `close_task` when the user says the task is satisfactory or finished, such as "可以了", "就这样", or "that's good".
|
||||||
- Choose `abandon_task` when the user says to stop, cancel, or no longer do the active task.
|
- Choose `abandon_task` when the user says to stop, cancel, or no longer do the active task.
|
||||||
|
|
||||||
|
Do not classify unrelated lightweight conversation as `revise_task` merely because
|
||||||
|
the active task is awaiting acceptance. A revision must ask to change or correct
|
||||||
|
the active task result.
|
||||||
|
|
||||||
Examples with an active weather task:
|
Examples with an active weather task:
|
||||||
|
|
||||||
- "再详细一点" -> `revise_task`
|
- "再详细一点" -> `revise_task`
|
||||||
- "加上明后天穿衣建议" -> `revise_task`
|
- "加上明后天穿衣建议" -> `revise_task`
|
||||||
- "顺便查一下深圳" -> `continue_task`
|
- "顺便查一下深圳" -> `continue_task`
|
||||||
- "帮我写一个采购合同" -> `new_task`
|
- "帮我写一个采购合同" -> `new_task`
|
||||||
|
- "吃饭没" -> `simple_chat`
|
||||||
|
- "我在冰岛" -> `simple_chat`
|
||||||
- "可以了" -> `close_task`
|
- "可以了" -> `close_task`
|
||||||
- "不用了" -> `abandon_task`
|
- "不用了" -> `abandon_task`
|
||||||
|
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
"""Skill learning loop helpers."""
|
"""Skill learning loop helpers."""
|
||||||
|
|
||||||
|
from .case_selection import select_replay_cases
|
||||||
from .evidence import EvidencePacket, EvidenceSelector
|
from .evidence import EvidencePacket, EvidenceSelector
|
||||||
from .eval import SkillDraftEvaluator
|
from .eval import SkillDraftEvaluator
|
||||||
from .missing_skill import (
|
from .missing_skill import (
|
||||||
@ -9,11 +10,15 @@ from .missing_skill import (
|
|||||||
MissingSkillSynthesizer,
|
MissingSkillSynthesizer,
|
||||||
)
|
)
|
||||||
from .pipeline import SkillLearningPipelineService
|
from .pipeline import SkillLearningPipelineService
|
||||||
|
from .preservation import check_preservation
|
||||||
|
from .replay import ReplayArmRequest, ReplayRunner, ReplayToolExecutor, ReplayToolPolicy, classify_tool_mode
|
||||||
from .service import RunReceiptContext, SkillLearningService
|
from .service import RunReceiptContext, SkillLearningService
|
||||||
|
from .surrogate import SurrogateToolEvaluator
|
||||||
from .synthesizer import SkillDraftSynthesizer
|
from .synthesizer import SkillDraftSynthesizer
|
||||||
from .worker import SkillLearningWorker, SkillLearningWorkerConfig, SkillLearningWorkerResult
|
from .worker import SkillLearningWorker, SkillLearningWorkerConfig, SkillLearningWorkerResult
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"select_replay_cases",
|
||||||
"EvidencePacket",
|
"EvidencePacket",
|
||||||
"EvidenceSelector",
|
"EvidenceSelector",
|
||||||
"SkillDraftEvaluator",
|
"SkillDraftEvaluator",
|
||||||
@ -23,6 +28,13 @@ __all__ = [
|
|||||||
"MissingSkillSynthesizer",
|
"MissingSkillSynthesizer",
|
||||||
"RunReceiptContext",
|
"RunReceiptContext",
|
||||||
"SkillLearningPipelineService",
|
"SkillLearningPipelineService",
|
||||||
|
"check_preservation",
|
||||||
|
"ReplayToolExecutor",
|
||||||
|
"ReplayToolPolicy",
|
||||||
|
"ReplayArmRequest",
|
||||||
|
"ReplayRunner",
|
||||||
|
"classify_tool_mode",
|
||||||
|
"SurrogateToolEvaluator",
|
||||||
"SkillDraftSynthesizer",
|
"SkillDraftSynthesizer",
|
||||||
"SkillLearningService",
|
"SkillLearningService",
|
||||||
"SkillLearningWorker",
|
"SkillLearningWorker",
|
||||||
|
|||||||
109
app-instance/backend/beaver/skills/learning/case_selection.py
Normal file
109
app-instance/backend/beaver/skills/learning/case_selection.py
Normal file
@ -0,0 +1,109 @@
|
|||||||
|
"""Historical replay case selection for skill draft evaluation."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from beaver.memory.runs import RunRecord
|
||||||
|
from beaver.memory.skills import SkillLearningCandidate
|
||||||
|
|
||||||
|
MAX_REPLAY_CASES = 10
|
||||||
|
|
||||||
|
|
||||||
|
def select_replay_cases(candidate: SkillLearningCandidate, runs: list[RunRecord]) -> list[dict[str, Any]]:
|
||||||
|
accepted = [record for record in runs if _is_accepted(record)]
|
||||||
|
if candidate.kind == "revise_skill":
|
||||||
|
selected = _select_revise(candidate, accepted)
|
||||||
|
elif candidate.kind == "merge_skills":
|
||||||
|
selected = _select_merge(candidate, accepted)
|
||||||
|
else:
|
||||||
|
selected = _select_new(candidate, accepted)
|
||||||
|
return [_case_payload(candidate, record) for record in selected[:MAX_REPLAY_CASES]]
|
||||||
|
|
||||||
|
|
||||||
|
def _select_revise(candidate: SkillLearningCandidate, runs: list[RunRecord]) -> list[RunRecord]:
|
||||||
|
target = candidate.related_skill_names[0] if candidate.related_skill_names else ""
|
||||||
|
version = str(candidate.evidence.get("skill_version") or "")
|
||||||
|
matches = [
|
||||||
|
record
|
||||||
|
for record in runs
|
||||||
|
if any(
|
||||||
|
receipt.skill_name == target and (not version or receipt.skill_version == version)
|
||||||
|
for receipt in record.activated_skills
|
||||||
|
)
|
||||||
|
]
|
||||||
|
return _recent_diverse(matches)
|
||||||
|
|
||||||
|
|
||||||
|
def _select_merge(candidate: SkillLearningCandidate, runs: list[RunRecord]) -> list[RunRecord]:
|
||||||
|
targets = set(candidate.related_skill_names)
|
||||||
|
matches = [
|
||||||
|
record
|
||||||
|
for record in runs
|
||||||
|
if targets and targets.issubset({receipt.skill_name for receipt in record.activated_skills})
|
||||||
|
]
|
||||||
|
return _recent_diverse(matches)
|
||||||
|
|
||||||
|
|
||||||
|
def _select_new(candidate: SkillLearningCandidate, runs: list[RunRecord]) -> list[RunRecord]:
|
||||||
|
source_ids = set(candidate.source_run_ids)
|
||||||
|
if source_ids:
|
||||||
|
matches = [record for record in runs if record.run_id in source_ids]
|
||||||
|
else:
|
||||||
|
theme = str(candidate.evidence.get("theme") or "").lower().strip()
|
||||||
|
matches = [record for record in runs if theme and theme in record.task_text.lower()]
|
||||||
|
return _recent_diverse(matches)
|
||||||
|
|
||||||
|
|
||||||
|
def _case_payload(candidate: SkillLearningCandidate, record: RunRecord) -> dict[str, Any]:
|
||||||
|
baseline_skill_names = []
|
||||||
|
if candidate.kind == "revise_skill":
|
||||||
|
baseline_skill_names = list(candidate.related_skill_names[:1])
|
||||||
|
elif candidate.kind == "merge_skills":
|
||||||
|
baseline_skill_names = list(candidate.related_skill_names)
|
||||||
|
return {
|
||||||
|
"run_id": record.run_id,
|
||||||
|
"task_id": record.task_id,
|
||||||
|
"session_id": record.session_id,
|
||||||
|
"task_text": record.task_text,
|
||||||
|
"baseline_skill_names": baseline_skill_names,
|
||||||
|
"candidate_skill_name": candidate.draft_skill_name,
|
||||||
|
"accepted_score": _score(record),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _recent_diverse(runs: list[RunRecord]) -> list[RunRecord]:
|
||||||
|
sorted_runs = sorted(runs, key=lambda item: (item.started_at, item.run_id), reverse=True)
|
||||||
|
result: list[RunRecord] = []
|
||||||
|
seen_tasks: set[str] = set()
|
||||||
|
for record in sorted_runs:
|
||||||
|
task_key = record.task_id or record.task_text
|
||||||
|
if task_key in seen_tasks and len(sorted_runs) > MAX_REPLAY_CASES:
|
||||||
|
continue
|
||||||
|
seen_tasks.add(task_key)
|
||||||
|
result.append(record)
|
||||||
|
if len(result) >= MAX_REPLAY_CASES:
|
||||||
|
break
|
||||||
|
if len(result) < min(len(sorted_runs), MAX_REPLAY_CASES):
|
||||||
|
seen_run_ids = {record.run_id for record in result}
|
||||||
|
result.extend(record for record in sorted_runs if record.run_id not in seen_run_ids)
|
||||||
|
return result[:MAX_REPLAY_CASES]
|
||||||
|
|
||||||
|
|
||||||
|
def _is_accepted(record: RunRecord) -> bool:
|
||||||
|
feedback = record.feedback or {}
|
||||||
|
acceptance = feedback.get("acceptance_type")
|
||||||
|
if acceptance is None and feedback.get("feedback_type") == "satisfied":
|
||||||
|
acceptance = "accept"
|
||||||
|
return bool(record.success) and acceptance == "accept"
|
||||||
|
|
||||||
|
|
||||||
|
def _score(record: RunRecord) -> float:
|
||||||
|
validation = record.validation_result or {}
|
||||||
|
value = validation.get("score") if isinstance(validation, dict) else None
|
||||||
|
if value is not None:
|
||||||
|
try:
|
||||||
|
return max(0.0, min(1.0, float(value)))
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
pass
|
||||||
|
return 0.8 if record.success else 0.4
|
||||||
@ -4,17 +4,28 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
|
from beaver.engine.context import SkillContext
|
||||||
from beaver.engine.providers import ProviderBundle
|
from beaver.engine.providers import ProviderBundle
|
||||||
from beaver.memory.runs import RunMemoryStore
|
from beaver.memory.runs import RunMemoryStore
|
||||||
from beaver.memory.skills import SkillDraftEvalReport, SkillLearningCandidate
|
from beaver.memory.skills import SkillDraftEvalReport, SkillLearningCandidate
|
||||||
|
from beaver.skills.learning.case_selection import select_replay_cases
|
||||||
|
from beaver.skills.learning.preservation import check_preservation
|
||||||
|
from beaver.skills.learning.replay import ReplayArmRequest, ReplayRunner
|
||||||
|
from beaver.skills.learning.surrogate import SurrogateToolEvaluator
|
||||||
from beaver.skills.specs import SkillDraft
|
from beaver.skills.specs import SkillDraft
|
||||||
|
|
||||||
|
|
||||||
class SkillDraftEvaluator:
|
class SkillDraftEvaluator:
|
||||||
"""Builds a bounded eval report without writing user-visible sessions."""
|
"""Builds a bounded eval report without writing user-visible sessions."""
|
||||||
|
|
||||||
def __init__(self, run_store: RunMemoryStore) -> None:
|
def __init__(
|
||||||
|
self,
|
||||||
|
run_store: RunMemoryStore,
|
||||||
|
*,
|
||||||
|
surrogate_evaluator: SurrogateToolEvaluator | None = None,
|
||||||
|
) -> None:
|
||||||
self.run_store = run_store
|
self.run_store = run_store
|
||||||
|
self.surrogate_evaluator = surrogate_evaluator or SurrogateToolEvaluator()
|
||||||
|
|
||||||
async def evaluate(
|
async def evaluate(
|
||||||
self,
|
self,
|
||||||
@ -22,11 +33,30 @@ class SkillDraftEvaluator:
|
|||||||
candidate: SkillLearningCandidate,
|
candidate: SkillLearningCandidate,
|
||||||
draft: SkillDraft,
|
draft: SkillDraft,
|
||||||
provider_bundle: ProviderBundle | None,
|
provider_bundle: ProviderBundle | None,
|
||||||
|
replay_runner: ReplayRunner | None = None,
|
||||||
) -> SkillDraftEvalReport:
|
) -> SkillDraftEvalReport:
|
||||||
if provider_bundle is None or provider_bundle.main_provider is None:
|
if provider_bundle is None or provider_bundle.main_provider is None:
|
||||||
return self._skipped(candidate, draft)
|
return self._skipped(candidate, draft)
|
||||||
|
|
||||||
runs_by_id = {record.run_id: record for record in self.run_store.list_runs()}
|
runs = self.run_store.list_runs()
|
||||||
|
replay_cases = select_replay_cases(candidate, runs)
|
||||||
|
if replay_runner is not None and replay_cases:
|
||||||
|
return await self._evaluate_replay(
|
||||||
|
candidate=candidate,
|
||||||
|
draft=draft,
|
||||||
|
replay_cases=replay_cases,
|
||||||
|
provider_bundle=provider_bundle,
|
||||||
|
replay_runner=replay_runner,
|
||||||
|
)
|
||||||
|
return self._evaluate_heuristic(candidate, draft, runs)
|
||||||
|
|
||||||
|
def _evaluate_heuristic(
|
||||||
|
self,
|
||||||
|
candidate: SkillLearningCandidate,
|
||||||
|
draft: SkillDraft,
|
||||||
|
runs: list,
|
||||||
|
) -> SkillDraftEvalReport:
|
||||||
|
runs_by_id = {record.run_id: record for record in runs}
|
||||||
cases: list[dict] = []
|
cases: list[dict] = []
|
||||||
for run_id in candidate.source_run_ids[:8]:
|
for run_id in candidate.source_run_ids[:8]:
|
||||||
record = runs_by_id.get(run_id)
|
record = runs_by_id.get(run_id)
|
||||||
@ -78,6 +108,78 @@ class SkillDraftEvaluator:
|
|||||||
created_at=_utc_now(),
|
created_at=_utc_now(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def _evaluate_replay(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
candidate: SkillLearningCandidate,
|
||||||
|
draft: SkillDraft,
|
||||||
|
replay_cases: list[dict],
|
||||||
|
provider_bundle: ProviderBundle,
|
||||||
|
replay_runner: ReplayRunner,
|
||||||
|
) -> SkillDraftEvalReport:
|
||||||
|
case_reports: list[dict] = []
|
||||||
|
legacy_cases: list[dict] = []
|
||||||
|
for case in replay_cases:
|
||||||
|
baseline = await replay_runner.run_arm(
|
||||||
|
ReplayArmRequest(
|
||||||
|
case_id=f"{case['run_id']}:baseline",
|
||||||
|
arm="baseline",
|
||||||
|
task_text=str(case["task_text"]),
|
||||||
|
pinned_skill_names=list(case.get("baseline_skill_names") or []),
|
||||||
|
pinned_skill_contexts=[],
|
||||||
|
provider_bundle=provider_bundle,
|
||||||
|
model_settings={"max_tool_iterations": 4, "temperature": 0.0},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
candidate_arm = await replay_runner.run_arm(
|
||||||
|
ReplayArmRequest(
|
||||||
|
case_id=f"{case['run_id']}:candidate",
|
||||||
|
arm="candidate",
|
||||||
|
task_text=str(case["task_text"]),
|
||||||
|
pinned_skill_names=[],
|
||||||
|
pinned_skill_contexts=[_draft_skill_context(draft)],
|
||||||
|
provider_bundle=provider_bundle,
|
||||||
|
model_settings={"max_tool_iterations": 4, "temperature": 0.0},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
surrogate = await self.surrogate_evaluator.evaluate(
|
||||||
|
task_text=str(case["task_text"]),
|
||||||
|
baseline=baseline,
|
||||||
|
candidate=candidate_arm,
|
||||||
|
)
|
||||||
|
baseline_score = surrogate["baseline_score"]
|
||||||
|
candidate_score = surrogate["candidate_score"]
|
||||||
|
case_report = {
|
||||||
|
"run_id": case["run_id"],
|
||||||
|
"task_id": case.get("task_id"),
|
||||||
|
"session_id": case.get("session_id"),
|
||||||
|
"baseline": baseline,
|
||||||
|
"candidate": candidate_arm,
|
||||||
|
"baseline_score": baseline_score,
|
||||||
|
"candidate_score": candidate_score,
|
||||||
|
"delta": round(candidate_score - baseline_score, 4),
|
||||||
|
"execution_coverage": _arm_mode_coverage(baseline, candidate_arm, "executed"),
|
||||||
|
"surrogate_coverage": _arm_mode_coverage(baseline, candidate_arm, "surrogate"),
|
||||||
|
"blocked_tool_count": _arm_mode_count(baseline, candidate_arm, "blocked"),
|
||||||
|
"confidence": surrogate["confidence"],
|
||||||
|
"tool_calls": [*baseline.get("tool_calls", []), *candidate_arm.get("tool_calls", [])],
|
||||||
|
"artifacts": [*baseline.get("artifacts", []), *candidate_arm.get("artifacts", [])],
|
||||||
|
"side_effects": [*baseline.get("side_effects", []), *candidate_arm.get("side_effects", [])],
|
||||||
|
"validator_notes": list(surrogate.get("notes") or []),
|
||||||
|
}
|
||||||
|
case_reports.append(case_report)
|
||||||
|
legacy_cases.append(
|
||||||
|
{
|
||||||
|
"run_id": case["run_id"],
|
||||||
|
"session_id": case.get("session_id") or "",
|
||||||
|
"baseline_score": baseline_score,
|
||||||
|
"candidate_score": candidate_score,
|
||||||
|
"delta": round(candidate_score - baseline_score, 4),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
preservation_report = _preservation_report(candidate, draft)
|
||||||
|
return _report_from_case_reports(candidate, draft, case_reports, legacy_cases, preservation_report)
|
||||||
|
|
||||||
def _skipped(self, candidate: SkillLearningCandidate, draft: SkillDraft) -> SkillDraftEvalReport:
|
def _skipped(self, candidate: SkillLearningCandidate, draft: SkillDraft) -> SkillDraftEvalReport:
|
||||||
return SkillDraftEvalReport(
|
return SkillDraftEvalReport(
|
||||||
report_id=uuid4().hex,
|
report_id=uuid4().hex,
|
||||||
@ -115,6 +217,108 @@ def _candidate_score(baseline: float, draft: SkillDraft) -> float:
|
|||||||
return min(1.0, max(0.75, baseline + 0.05))
|
return min(1.0, max(0.75, baseline + 0.05))
|
||||||
|
|
||||||
|
|
||||||
|
def _draft_skill_context(draft: SkillDraft) -> SkillContext:
|
||||||
|
tool_hints = draft.proposed_frontmatter.get("tools")
|
||||||
|
return SkillContext(
|
||||||
|
name=f"draft:{draft.skill_name}",
|
||||||
|
content=draft.proposed_content,
|
||||||
|
version=draft.draft_id,
|
||||||
|
content_hash="draft",
|
||||||
|
activation_reason="skill_replay_eval_candidate",
|
||||||
|
tool_hints=[str(item) for item in tool_hints if str(item).strip()] if isinstance(tool_hints, list) else [],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _preservation_report(candidate: SkillLearningCandidate, draft: SkillDraft) -> dict | None:
|
||||||
|
if candidate.kind not in {"revise_skill", "merge_skills"}:
|
||||||
|
return None
|
||||||
|
base_content = str(candidate.evidence.get("base_content") or "") if isinstance(candidate.evidence, dict) else ""
|
||||||
|
if not base_content.strip():
|
||||||
|
return None
|
||||||
|
return check_preservation(base_content=base_content, draft_content=draft.proposed_content)
|
||||||
|
|
||||||
|
|
||||||
|
def _report_from_case_reports(
|
||||||
|
candidate: SkillLearningCandidate,
|
||||||
|
draft: SkillDraft,
|
||||||
|
case_reports: list[dict],
|
||||||
|
legacy_cases: list[dict],
|
||||||
|
preservation_report: dict | None,
|
||||||
|
) -> SkillDraftEvalReport:
|
||||||
|
baseline_avg = sum(item["baseline_score"] for item in legacy_cases) / len(legacy_cases)
|
||||||
|
candidate_avg = sum(item["candidate_score"] for item in legacy_cases) / len(legacy_cases)
|
||||||
|
regressions = [item for item in legacy_cases if item["candidate_score"] < item["baseline_score"]]
|
||||||
|
improved = [item for item in legacy_cases if item["candidate_score"] > item["baseline_score"]]
|
||||||
|
unchanged = len(legacy_cases) - len(regressions) - len(improved)
|
||||||
|
execution, surrogate, blocked = _coverage(case_reports)
|
||||||
|
confidence = _confidence(execution, surrogate, blocked, [item.get("confidence") for item in case_reports])
|
||||||
|
score_delta = candidate_avg - baseline_avg
|
||||||
|
passed = candidate_avg >= 0.75 and not (regressions and score_delta <= 0) and blocked < 1.0
|
||||||
|
return SkillDraftEvalReport(
|
||||||
|
report_id=uuid4().hex,
|
||||||
|
skill_name=draft.skill_name,
|
||||||
|
draft_id=draft.draft_id,
|
||||||
|
candidate_id=candidate.candidate_id,
|
||||||
|
passed=passed,
|
||||||
|
baseline_score_avg=round(baseline_avg, 4),
|
||||||
|
candidate_score_avg=round(candidate_avg, 4),
|
||||||
|
score_delta=round(score_delta, 4),
|
||||||
|
regression_count=len(regressions),
|
||||||
|
improved_count=len(improved),
|
||||||
|
unchanged_count=unchanged,
|
||||||
|
cases=legacy_cases,
|
||||||
|
status="completed",
|
||||||
|
created_at=_utc_now(),
|
||||||
|
eval_version="replay-v1",
|
||||||
|
mode="replay",
|
||||||
|
execution_coverage=execution,
|
||||||
|
surrogate_coverage=surrogate,
|
||||||
|
blocked_coverage=blocked,
|
||||||
|
confidence=confidence,
|
||||||
|
case_reports=case_reports,
|
||||||
|
tool_mode_summary={"executed": execution, "surrogate": surrogate, "blocked": blocked},
|
||||||
|
preservation_report=preservation_report,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _coverage(case_reports: list[dict]) -> tuple[float, float, float]:
|
||||||
|
counts = {"executed": 0, "surrogate": 0, "blocked": 0}
|
||||||
|
for report in case_reports:
|
||||||
|
for call in report.get("tool_calls") or []:
|
||||||
|
if isinstance(call, dict) and call.get("mode") in counts:
|
||||||
|
counts[str(call["mode"])] += 1
|
||||||
|
total = sum(counts.values())
|
||||||
|
if total == 0:
|
||||||
|
return 1.0, 0.0, 0.0
|
||||||
|
return (
|
||||||
|
round(counts["executed"] / total, 4),
|
||||||
|
round(counts["surrogate"] / total, 4),
|
||||||
|
round(counts["blocked"] / total, 4),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _confidence(execution: float, surrogate: float, blocked: float, case_confidences: list[object]) -> str:
|
||||||
|
if blocked > 0.0:
|
||||||
|
return "low"
|
||||||
|
if execution >= 0.75 and surrogate <= 0.25:
|
||||||
|
return "high"
|
||||||
|
if execution >= 0.25 or "medium" in case_confidences:
|
||||||
|
return "medium"
|
||||||
|
return "low"
|
||||||
|
|
||||||
|
|
||||||
|
def _arm_mode_coverage(baseline: dict, candidate: dict, mode: str) -> float:
|
||||||
|
calls = [*baseline.get("tool_calls", []), *candidate.get("tool_calls", [])]
|
||||||
|
if not calls:
|
||||||
|
return 1.0 if mode == "executed" else 0.0
|
||||||
|
return round(sum(1 for call in calls if isinstance(call, dict) and call.get("mode") == mode) / len(calls), 4)
|
||||||
|
|
||||||
|
|
||||||
|
def _arm_mode_count(baseline: dict, candidate: dict, mode: str) -> int:
|
||||||
|
calls = [*baseline.get("tool_calls", []), *candidate.get("tool_calls", [])]
|
||||||
|
return sum(1 for call in calls if isinstance(call, dict) and call.get("mode") == mode)
|
||||||
|
|
||||||
|
|
||||||
def _utc_now() -> str:
|
def _utc_now() -> str:
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
|||||||
@ -8,6 +8,7 @@ from beaver.engine.providers import ProviderBundle
|
|||||||
from beaver.memory.skills import SkillDraftEvalReport, SkillDraftSafetyReport, SkillLearningCandidate, SkillLearningStore
|
from beaver.memory.skills import SkillDraftEvalReport, SkillDraftSafetyReport, SkillLearningCandidate, SkillLearningStore
|
||||||
from beaver.skills.drafts import DraftService
|
from beaver.skills.drafts import DraftService
|
||||||
from beaver.skills.learning.eval import SkillDraftEvaluator
|
from beaver.skills.learning.eval import SkillDraftEvaluator
|
||||||
|
from beaver.skills.learning.replay import ReplayRunner
|
||||||
from beaver.skills.learning.service import SkillLearningService
|
from beaver.skills.learning.service import SkillLearningService
|
||||||
from beaver.skills.learning.safety import SkillDraftSafetyChecker
|
from beaver.skills.learning.safety import SkillDraftSafetyChecker
|
||||||
from beaver.skills.publisher import SkillPublisher
|
from beaver.skills.publisher import SkillPublisher
|
||||||
@ -285,11 +286,17 @@ class SkillLearningPipelineService:
|
|||||||
draft_id: str,
|
draft_id: str,
|
||||||
*,
|
*,
|
||||||
provider_bundle: ProviderBundle | None,
|
provider_bundle: ProviderBundle | None,
|
||||||
|
replay_runner: ReplayRunner | None = None,
|
||||||
) -> SkillDraftEvalReport:
|
) -> SkillDraftEvalReport:
|
||||||
draft = self.get_draft(skill_name, draft_id)
|
draft = self.get_draft(skill_name, draft_id)
|
||||||
candidate = self.get_candidate(candidate_id)
|
candidate = self.get_candidate(candidate_id)
|
||||||
evaluator = self.evaluator or SkillDraftEvaluator(self.learning_service.run_store)
|
evaluator = self.evaluator or SkillDraftEvaluator(self.learning_service.run_store)
|
||||||
report = await evaluator.evaluate(candidate=candidate, draft=draft, provider_bundle=provider_bundle)
|
report = await evaluator.evaluate(
|
||||||
|
candidate=candidate,
|
||||||
|
draft=draft,
|
||||||
|
provider_bundle=provider_bundle,
|
||||||
|
replay_runner=replay_runner,
|
||||||
|
)
|
||||||
self.learning_store.write_eval_report(report)
|
self.learning_store.write_eval_report(report)
|
||||||
if report.status == "skipped_provider_unavailable":
|
if report.status == "skipped_provider_unavailable":
|
||||||
status = "draft_ready"
|
status = "draft_ready"
|
||||||
@ -330,6 +337,14 @@ class SkillLearningPipelineService:
|
|||||||
eval_report = self.get_eval_report(draft.skill_name, draft.draft_id)
|
eval_report = self.get_eval_report(draft.skill_name, draft.draft_id)
|
||||||
if eval_report is not None and eval_report.status != "skipped_provider_unavailable" and not eval_report.passed:
|
if eval_report is not None and eval_report.status != "skipped_provider_unavailable" and not eval_report.passed:
|
||||||
raise ValueError("Draft eval report did not pass")
|
raise ValueError("Draft eval report did not pass")
|
||||||
|
if eval_report is not None and eval_report.mode == "replay":
|
||||||
|
if eval_report.confidence == "low":
|
||||||
|
raise ValueError("Draft replay eval has low confidence and requires revision before publish")
|
||||||
|
if eval_report.blocked_coverage >= 1.0:
|
||||||
|
raise ValueError("Draft replay eval blocked all important tool calls")
|
||||||
|
preservation = eval_report.preservation_report or {}
|
||||||
|
if preservation.get("passed") is False:
|
||||||
|
raise ValueError("Draft preservation check did not pass")
|
||||||
|
|
||||||
def _mark_candidate_by_draft(
|
def _mark_candidate_by_draft(
|
||||||
self,
|
self,
|
||||||
|
|||||||
53
app-instance/backend/beaver/skills/learning/preservation.py
Normal file
53
app-instance/backend/beaver/skills/learning/preservation.py
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
"""Preservation checks for skill revision drafts."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
def check_preservation(*, base_content: str, draft_content: str) -> dict[str, Any]:
|
||||||
|
base_sections = _sections(base_content)
|
||||||
|
draft_sections = _sections(draft_content)
|
||||||
|
preserved: list[str] = []
|
||||||
|
changed: list[str] = []
|
||||||
|
dropped: list[str] = []
|
||||||
|
|
||||||
|
for heading, body in base_sections.items():
|
||||||
|
draft_body = draft_sections.get(heading)
|
||||||
|
if draft_body is None:
|
||||||
|
dropped.append(heading)
|
||||||
|
continue
|
||||||
|
preserved.append(heading)
|
||||||
|
if _normalize(body) != _normalize(draft_body):
|
||||||
|
changed.append(heading)
|
||||||
|
|
||||||
|
risk_level = "high" if dropped else "low"
|
||||||
|
return {
|
||||||
|
"passed": not dropped,
|
||||||
|
"risk_level": risk_level,
|
||||||
|
"preserved_sections": preserved,
|
||||||
|
"changed_sections": changed,
|
||||||
|
"dropped_sections": dropped,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _sections(content: str) -> dict[str, str]:
|
||||||
|
current = "body"
|
||||||
|
sections: dict[str, list[str]] = {current: []}
|
||||||
|
for line in (content or "").splitlines():
|
||||||
|
match = re.match(r"^#{1,6}\s+(.+?)\s*$", line)
|
||||||
|
if match:
|
||||||
|
current = match.group(1).strip()
|
||||||
|
sections.setdefault(current, [])
|
||||||
|
continue
|
||||||
|
sections.setdefault(current, []).append(line)
|
||||||
|
return {
|
||||||
|
heading: "\n".join(lines).strip()
|
||||||
|
for heading, lines in sections.items()
|
||||||
|
if "\n".join(lines).strip()
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize(value: str) -> str:
|
||||||
|
return re.sub(r"\s+", " ", value or "").strip().lower()
|
||||||
203
app-instance/backend/beaver/skills/learning/replay.py
Normal file
203
app-instance/backend/beaver/skills/learning/replay.py
Normal file
@ -0,0 +1,203 @@
|
|||||||
|
"""Replay execution helpers for skill draft evaluation."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any, Literal
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
from beaver.tools.base import ToolContext, ToolResult, ToolSpec
|
||||||
|
from beaver.tools.registry.tool_registry import ToolRegistry
|
||||||
|
from beaver.tools.runtime.executor import ToolExecutor
|
||||||
|
|
||||||
|
ToolExecutionMode = Literal["executed", "surrogate", "blocked"]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class ReplayToolPolicy:
|
||||||
|
safe_toolsets: set[str] = field(default_factory=lambda: {"filesystem", "user_files", "core", "web", "search"})
|
||||||
|
surrogate_transports: set[str] = field(default_factory=lambda: {"mcp", "connector"})
|
||||||
|
destructive_terms: tuple[str, ...] = (
|
||||||
|
"delete",
|
||||||
|
"remove",
|
||||||
|
"destroy",
|
||||||
|
"revoke",
|
||||||
|
"permission",
|
||||||
|
"credential",
|
||||||
|
"payment",
|
||||||
|
"pay",
|
||||||
|
)
|
||||||
|
external_write_terms: tuple[str, ...] = (
|
||||||
|
"send",
|
||||||
|
"post",
|
||||||
|
"publish",
|
||||||
|
"create",
|
||||||
|
"update",
|
||||||
|
"invite",
|
||||||
|
"reply",
|
||||||
|
"forward",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ReplayToolExecutor:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
inner: ToolExecutor,
|
||||||
|
*,
|
||||||
|
registry: ToolRegistry,
|
||||||
|
policy: ReplayToolPolicy | None = None,
|
||||||
|
) -> None:
|
||||||
|
self.inner = inner
|
||||||
|
self.registry = registry
|
||||||
|
self.policy = policy or ReplayToolPolicy()
|
||||||
|
self.traces: list[dict[str, Any]] = []
|
||||||
|
|
||||||
|
async def execute(
|
||||||
|
self,
|
||||||
|
tool_name: str,
|
||||||
|
arguments: dict[str, Any] | None,
|
||||||
|
*,
|
||||||
|
context: ToolContext | None = None,
|
||||||
|
) -> ToolResult:
|
||||||
|
tool = self.registry.get(tool_name)
|
||||||
|
spec = tool.spec if tool is not None else ToolSpec(
|
||||||
|
name=tool_name,
|
||||||
|
description="unregistered tool",
|
||||||
|
input_schema={"type": "object", "properties": {}},
|
||||||
|
toolset="unknown",
|
||||||
|
)
|
||||||
|
mode = classify_tool_mode(spec, self.policy)
|
||||||
|
trace = {
|
||||||
|
"trace_id": uuid4().hex,
|
||||||
|
"tool_name": tool_name,
|
||||||
|
"mode": mode,
|
||||||
|
"arguments": dict(arguments or {}),
|
||||||
|
"schema": dict(spec.input_schema),
|
||||||
|
"toolset": spec.toolset,
|
||||||
|
"metadata": dict(spec.metadata),
|
||||||
|
"classification_reason": _classification_reason(spec, mode),
|
||||||
|
}
|
||||||
|
if mode == "executed":
|
||||||
|
result = await self.inner.execute(tool_name, arguments or {}, context=context)
|
||||||
|
trace["result"] = {
|
||||||
|
"success": result.success,
|
||||||
|
"error": result.error,
|
||||||
|
"content": result.content[:2000],
|
||||||
|
}
|
||||||
|
self.traces.append(trace)
|
||||||
|
return result
|
||||||
|
if mode == "surrogate":
|
||||||
|
trace["result"] = {
|
||||||
|
"success": True,
|
||||||
|
"error": "replay_surrogate",
|
||||||
|
"content": "Tool call recorded for surrogate evaluation.",
|
||||||
|
}
|
||||||
|
self.traces.append(trace)
|
||||||
|
return ToolResult(
|
||||||
|
success=True,
|
||||||
|
content="Tool call recorded for surrogate evaluation.",
|
||||||
|
tool_name=tool_name,
|
||||||
|
error="replay_surrogate",
|
||||||
|
raw_output=trace,
|
||||||
|
)
|
||||||
|
trace["result"] = {
|
||||||
|
"success": False,
|
||||||
|
"error": "replay_blocked",
|
||||||
|
"content": "Tool call blocked by replay policy.",
|
||||||
|
}
|
||||||
|
self.traces.append(trace)
|
||||||
|
return ToolResult(
|
||||||
|
success=False,
|
||||||
|
content="Tool call blocked by replay policy.",
|
||||||
|
tool_name=tool_name,
|
||||||
|
error="replay_blocked",
|
||||||
|
raw_output=trace,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def execute_tool_call(self, tool_call: Any, *, context: ToolContext | None = None) -> ToolResult:
|
||||||
|
tool_name, arguments = ToolExecutor._normalize_tool_call(tool_call)
|
||||||
|
return await self.execute(tool_name, arguments, context=context)
|
||||||
|
|
||||||
|
|
||||||
|
def classify_tool_mode(spec: ToolSpec, policy: ReplayToolPolicy | None = None) -> ToolExecutionMode:
|
||||||
|
policy = policy or ReplayToolPolicy()
|
||||||
|
name = spec.name.lower()
|
||||||
|
toolset = spec.toolset.lower()
|
||||||
|
metadata = {str(key).lower(): str(value).lower() for key, value in spec.metadata.items()}
|
||||||
|
if any(term in name for term in policy.destructive_terms):
|
||||||
|
return "blocked"
|
||||||
|
if toolset in policy.safe_toolsets:
|
||||||
|
return "executed"
|
||||||
|
if metadata.get("transport") in policy.surrogate_transports or toolset in {"mcp", "connector", "external"}:
|
||||||
|
if any(term in name for term in policy.external_write_terms):
|
||||||
|
return "surrogate"
|
||||||
|
return "executed"
|
||||||
|
return "surrogate"
|
||||||
|
|
||||||
|
|
||||||
|
def _classification_reason(spec: ToolSpec, mode: ToolExecutionMode) -> str:
|
||||||
|
return f"{spec.name} classified as {mode} from toolset={spec.toolset} metadata={spec.metadata}"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class ReplayArmRequest:
|
||||||
|
case_id: str
|
||||||
|
arm: str
|
||||||
|
task_text: str
|
||||||
|
pinned_skill_names: list[str] = field(default_factory=list)
|
||||||
|
pinned_skill_contexts: list[Any] = field(default_factory=list)
|
||||||
|
provider_bundle: Any | None = None
|
||||||
|
model_settings: dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class ReplayRunner:
|
||||||
|
def __init__(self, *, agent_loop: Any, policy: ReplayToolPolicy | None = None) -> None:
|
||||||
|
self.agent_loop = agent_loop
|
||||||
|
self.policy = policy or ReplayToolPolicy()
|
||||||
|
|
||||||
|
async def run_arm(self, request: ReplayArmRequest) -> dict[str, Any]:
|
||||||
|
loaded = self.agent_loop.boot()
|
||||||
|
replay_executor = ReplayToolExecutor(
|
||||||
|
loaded.tool_executor,
|
||||||
|
registry=loaded.tool_registry,
|
||||||
|
policy=self.policy,
|
||||||
|
)
|
||||||
|
result = await self.agent_loop.process_direct(
|
||||||
|
request.task_text,
|
||||||
|
provider_bundle=request.provider_bundle,
|
||||||
|
include_skill_assembly=False,
|
||||||
|
include_tools=True,
|
||||||
|
pinned_skill_names=request.pinned_skill_names,
|
||||||
|
pinned_skill_contexts=request.pinned_skill_contexts,
|
||||||
|
max_tool_iterations=int(request.model_settings.get("max_tool_iterations") or 4),
|
||||||
|
temperature=float(request.model_settings.get("temperature") or 0.0),
|
||||||
|
source="skill_replay_eval",
|
||||||
|
tool_executor_override=replay_executor,
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"case_id": request.case_id,
|
||||||
|
"arm": request.arm,
|
||||||
|
"session_id": result.session_id,
|
||||||
|
"run_id": result.run_id,
|
||||||
|
"task_text": request.task_text,
|
||||||
|
"finish_reason": result.finish_reason,
|
||||||
|
"final_answer": result.output_text,
|
||||||
|
"tool_calls": list(replay_executor.traces),
|
||||||
|
"artifacts": [],
|
||||||
|
"side_effects": _side_effects_from_traces(replay_executor.traces),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _side_effects_from_traces(traces: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||||
|
effects: list[dict[str, Any]] = []
|
||||||
|
for trace in traces:
|
||||||
|
if trace.get("mode") in {"surrogate", "blocked"}:
|
||||||
|
effects.append(
|
||||||
|
{
|
||||||
|
"tool_name": trace.get("tool_name"),
|
||||||
|
"mode": trace.get("mode"),
|
||||||
|
"arguments": trace.get("arguments"),
|
||||||
|
"classification_reason": trace.get("classification_reason"),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return effects
|
||||||
@ -69,15 +69,24 @@ class SkillLearningService:
|
|||||||
existing_ids.add(candidate.candidate_id)
|
existing_ids.add(candidate.candidate_id)
|
||||||
return candidates
|
return candidates
|
||||||
|
|
||||||
def build_learning_candidates_for_task(self, task_id: str, *, trigger_run_id: str) -> list[SkillLearningCandidate]:
|
def build_learning_candidates_for_task(
|
||||||
"""Build candidates scoped to a single validated and satisfied Task run."""
|
self,
|
||||||
|
task_id: str,
|
||||||
|
*,
|
||||||
|
final_accepted_run_id: str | None = None,
|
||||||
|
trigger_run_id: str | None = None,
|
||||||
|
) -> list[SkillLearningCandidate]:
|
||||||
|
"""Build candidates from a user-accepted Task and all of its runs."""
|
||||||
|
|
||||||
|
final_accepted_run_id = final_accepted_run_id or trigger_run_id
|
||||||
|
if not final_accepted_run_id:
|
||||||
|
return []
|
||||||
runs = [record for record in self.run_store.list_runs() if record.task_id == task_id]
|
runs = [record for record in self.run_store.list_runs() if record.task_id == task_id]
|
||||||
trigger_run = next((record for record in runs if record.run_id == trigger_run_id), None)
|
final_run = next((record for record in runs if record.run_id == final_accepted_run_id), None)
|
||||||
if trigger_run is None or not self._is_confirmed_positive_run(trigger_run):
|
if final_run is None or not self._is_task_accepted_run(final_run):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
source_runs = [record for record in runs if self._is_confirmed_positive_run(record)]
|
source_runs = sorted(runs, key=lambda item: (item.started_at, item.run_id))
|
||||||
if not source_runs:
|
if not source_runs:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
@ -100,11 +109,16 @@ class SkillLearningService:
|
|||||||
source_session_ids=source_session_ids,
|
source_session_ids=source_session_ids,
|
||||||
related_skill_names=[],
|
related_skill_names=[],
|
||||||
reason=f"Task {task_id} completed successfully without a published skill; consider extracting reusable guidance.",
|
reason=f"Task {task_id} completed successfully without a published skill; consider extracting reusable guidance.",
|
||||||
evidence={"task_id": task_id, "trigger_run_id": trigger_run_id, "theme": self._task_theme(trigger_run.task_text)},
|
evidence={
|
||||||
|
"task_id": task_id,
|
||||||
|
"final_accepted_run_id": final_accepted_run_id,
|
||||||
|
"source_run_ids": source_run_ids,
|
||||||
|
"theme": self._task_theme(final_run.task_text),
|
||||||
|
},
|
||||||
status="open",
|
status="open",
|
||||||
priority=1,
|
priority=1,
|
||||||
confidence=0.8,
|
confidence=0.8,
|
||||||
trigger_reason="validation_accepted_and_user_satisfied",
|
trigger_reason="task_accepted",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@ -137,13 +151,14 @@ class SkillLearningService:
|
|||||||
),
|
),
|
||||||
evidence={
|
evidence={
|
||||||
"task_id": task_id,
|
"task_id": task_id,
|
||||||
"trigger_run_id": trigger_run_id,
|
"final_accepted_run_id": final_accepted_run_id,
|
||||||
|
"source_run_ids": source_run_ids,
|
||||||
"skill_version": receipt.skill_version,
|
"skill_version": receipt.skill_version,
|
||||||
},
|
},
|
||||||
status="open",
|
status="open",
|
||||||
priority=1,
|
priority=1,
|
||||||
confidence=0.7,
|
confidence=0.7,
|
||||||
trigger_reason="validation_accepted_and_user_satisfied",
|
trigger_reason="task_accepted",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -190,7 +205,13 @@ class SkillLearningService:
|
|||||||
)
|
)
|
||||||
if candidate.kind == "merge_skills":
|
if candidate.kind == "merge_skills":
|
||||||
target_name = self._suggest_skill_name(candidate, packet)
|
target_name = self._suggest_skill_name(candidate, packet)
|
||||||
payload = await self.synthesizer.synthesize_merge(candidate, packet, provider, model)
|
payload = await self.synthesizer.synthesize_merge(
|
||||||
|
candidate,
|
||||||
|
packet,
|
||||||
|
provider,
|
||||||
|
model,
|
||||||
|
base_skill=self._merged_base_skill_snapshot(candidate.related_skill_names),
|
||||||
|
)
|
||||||
return self.draft_service.create_merge_draft(
|
return self.draft_service.create_merge_draft(
|
||||||
skill_name=target_name,
|
skill_name=target_name,
|
||||||
base_version=None,
|
base_version=None,
|
||||||
@ -202,7 +223,13 @@ class SkillLearningService:
|
|||||||
)
|
)
|
||||||
target_skill = candidate.related_skill_names[0]
|
target_skill = candidate.related_skill_names[0]
|
||||||
base_version = candidate.evidence.get("skill_version")
|
base_version = candidate.evidence.get("skill_version")
|
||||||
payload = await self.synthesizer.synthesize_revision(candidate, packet, provider, model)
|
payload = await self.synthesizer.synthesize_revision(
|
||||||
|
candidate,
|
||||||
|
packet,
|
||||||
|
provider,
|
||||||
|
model,
|
||||||
|
base_skill=self._base_skill_snapshot(target_skill, base_version),
|
||||||
|
)
|
||||||
return self.draft_service.create_revision_draft(
|
return self.draft_service.create_revision_draft(
|
||||||
skill_name=target_skill,
|
skill_name=target_skill,
|
||||||
base_version=base_version,
|
base_version=base_version,
|
||||||
@ -213,6 +240,46 @@ class SkillLearningService:
|
|||||||
evidence_refs=[{"run_id": item} for item in candidate.source_run_ids],
|
evidence_refs=[{"run_id": item} for item in candidate.source_run_ids],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _base_skill_snapshot(self, skill_name: str, version: str | None) -> dict[str, Any] | None:
|
||||||
|
loaded = self.draft_service.store.read_published_skill(skill_name, version)
|
||||||
|
if loaded is None:
|
||||||
|
return None
|
||||||
|
return {
|
||||||
|
"skill_name": loaded.version.skill_name,
|
||||||
|
"version": loaded.version.version,
|
||||||
|
"frontmatter": dict(loaded.version.frontmatter),
|
||||||
|
"content": loaded.content,
|
||||||
|
"summary": loaded.version.summary,
|
||||||
|
"tool_hints": list(loaded.version.tool_hints),
|
||||||
|
}
|
||||||
|
|
||||||
|
def _merged_base_skill_snapshot(self, skill_names: list[str]) -> dict[str, Any] | None:
|
||||||
|
snapshots = [
|
||||||
|
snapshot
|
||||||
|
for name in skill_names
|
||||||
|
if (snapshot := self._base_skill_snapshot(name, None)) is not None
|
||||||
|
]
|
||||||
|
if not snapshots:
|
||||||
|
return None
|
||||||
|
return {
|
||||||
|
"skill_name": "merge:" + ",".join(str(item["skill_name"]) for item in snapshots),
|
||||||
|
"version": "mixed",
|
||||||
|
"frontmatter": {"merged_skills": [item["frontmatter"] for item in snapshots]},
|
||||||
|
"content": "\n\n".join(
|
||||||
|
f"<!-- base skill: {item['skill_name']} {item['version']} -->\n{item['content']}"
|
||||||
|
for item in snapshots
|
||||||
|
),
|
||||||
|
"summary": "\n".join(str(item["summary"]) for item in snapshots if item.get("summary")),
|
||||||
|
"tool_hints": list(
|
||||||
|
dict.fromkeys(
|
||||||
|
tool
|
||||||
|
for item in snapshots
|
||||||
|
for tool in item.get("tool_hints", [])
|
||||||
|
if str(tool).strip()
|
||||||
|
)
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
def rescore_skill_versions(self) -> list[SkillPerformanceSnapshot]:
|
def rescore_skill_versions(self) -> list[SkillPerformanceSnapshot]:
|
||||||
snapshots: list[SkillPerformanceSnapshot] = []
|
snapshots: list[SkillPerformanceSnapshot] = []
|
||||||
grouped: dict[tuple[str, str], list[SkillEffectRecord]] = {}
|
grouped: dict[tuple[str, str], list[SkillEffectRecord]] = {}
|
||||||
@ -269,7 +336,7 @@ class SkillLearningService:
|
|||||||
groups.setdefault(key, []).append(record)
|
groups.setdefault(key, []).append(record)
|
||||||
candidates: list[SkillLearningCandidate] = []
|
candidates: list[SkillLearningCandidate] = []
|
||||||
for theme, runs in groups.items():
|
for theme, runs in groups.items():
|
||||||
successful = [record for record in runs if self._is_confirmed_positive_run(record)]
|
successful = [record for record in runs if self._is_task_accepted_run(record)]
|
||||||
if len(successful) < 2:
|
if len(successful) < 2:
|
||||||
continue
|
continue
|
||||||
if any(record.activated_skills for record in successful):
|
if any(record.activated_skills for record in successful):
|
||||||
@ -290,7 +357,7 @@ class SkillLearningService:
|
|||||||
def _build_merge_candidates(self) -> list[SkillLearningCandidate]:
|
def _build_merge_candidates(self) -> list[SkillLearningCandidate]:
|
||||||
pair_counts: dict[tuple[str, str], list[RunRecord]] = {}
|
pair_counts: dict[tuple[str, str], list[RunRecord]] = {}
|
||||||
for record in self.run_store.list_runs():
|
for record in self.run_store.list_runs():
|
||||||
if not self._is_confirmed_positive_run(record):
|
if not self._is_task_accepted_run(record):
|
||||||
continue
|
continue
|
||||||
unique = sorted({receipt.skill_name for receipt in record.activated_skills})
|
unique = sorted({receipt.skill_name for receipt in record.activated_skills})
|
||||||
for pair in combinations(unique, 2):
|
for pair in combinations(unique, 2):
|
||||||
@ -351,14 +418,15 @@ class SkillLearningService:
|
|||||||
return effects
|
return effects
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _is_confirmed_positive_run(record: RunRecord) -> bool:
|
def _is_task_accepted_run(record: RunRecord) -> bool:
|
||||||
validation = record.validation_result or {}
|
|
||||||
feedback = record.feedback or {}
|
feedback = record.feedback or {}
|
||||||
|
acceptance_type = feedback.get("acceptance_type")
|
||||||
|
if acceptance_type is None and feedback.get("feedback_type") == "satisfied":
|
||||||
|
acceptance_type = "accept"
|
||||||
return (
|
return (
|
||||||
bool(record.success)
|
bool(record.success)
|
||||||
and bool(record.task_id)
|
and bool(record.task_id)
|
||||||
and validation.get("accepted") is True
|
and acceptance_type == "accept"
|
||||||
and feedback.get("feedback_type") == "satisfied"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
53
app-instance/backend/beaver/skills/learning/surrogate.py
Normal file
53
app-instance/backend/beaver/skills/learning/surrogate.py
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
"""Surrogate evaluation for replay tool calls that cannot execute safely."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
class SurrogateToolEvaluator:
|
||||||
|
async def evaluate(self, *, task_text: str, baseline: dict[str, Any], candidate: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
baseline_score = _score_arm(task_text, baseline)
|
||||||
|
candidate_score = _score_arm(task_text, candidate)
|
||||||
|
surrogate_count = _mode_count(baseline, "surrogate") + _mode_count(candidate, "surrogate")
|
||||||
|
blocked_count = _mode_count(baseline, "blocked") + _mode_count(candidate, "blocked")
|
||||||
|
confidence = "low" if blocked_count else ("medium" if surrogate_count <= 2 else "low")
|
||||||
|
return {
|
||||||
|
"baseline_score": baseline_score,
|
||||||
|
"candidate_score": candidate_score,
|
||||||
|
"delta": round(candidate_score - baseline_score, 4),
|
||||||
|
"surrogate_tool_count": surrogate_count,
|
||||||
|
"blocked_tool_count": blocked_count,
|
||||||
|
"confidence": confidence,
|
||||||
|
"notes": [
|
||||||
|
"Surrogate score is based on intended tool calls, schemas, arguments, and task relevance.",
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _score_arm(task_text: str, arm: dict[str, Any]) -> float:
|
||||||
|
calls = [item for item in arm.get("tool_calls") or [] if isinstance(item, dict)]
|
||||||
|
if not calls:
|
||||||
|
return 0.5
|
||||||
|
scores = [_score_call(task_text, call) for call in calls]
|
||||||
|
return round(sum(scores) / len(scores), 4)
|
||||||
|
|
||||||
|
|
||||||
|
def _score_call(task_text: str, call: dict[str, Any]) -> float:
|
||||||
|
if call.get("mode") == "blocked":
|
||||||
|
return 0.2
|
||||||
|
if call.get("mode") == "executed":
|
||||||
|
result = call.get("result") if isinstance(call.get("result"), dict) else {}
|
||||||
|
return 0.85 if result.get("success") is not False else 0.35
|
||||||
|
arguments = dict(call.get("arguments") or {})
|
||||||
|
if not arguments:
|
||||||
|
return 0.45
|
||||||
|
non_empty = sum(1 for value in arguments.values() if str(value).strip())
|
||||||
|
completeness = non_empty / max(1, len(arguments))
|
||||||
|
argument_text = " ".join(str(value).lower() for value in arguments.values())
|
||||||
|
relevance = 0.15 if any(token and token in argument_text for token in task_text.lower().split()[:16]) else 0.0
|
||||||
|
return round(min(0.9, 0.5 + 0.3 * completeness + relevance), 4)
|
||||||
|
|
||||||
|
|
||||||
|
def _mode_count(arm: dict[str, Any], mode: str) -> int:
|
||||||
|
return sum(1 for item in arm.get("tool_calls") or [] if isinstance(item, dict) and item.get("mode") == mode)
|
||||||
@ -17,8 +17,9 @@ class SkillDraftSynthesizer:
|
|||||||
evidence_packet: EvidencePacket,
|
evidence_packet: EvidencePacket,
|
||||||
provider: LLMProvider,
|
provider: LLMProvider,
|
||||||
model: str,
|
model: str,
|
||||||
|
base_skill: dict[str, Any] | None = None,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
return await self._synthesize(candidate, evidence_packet, provider, model, "revise")
|
return await self._synthesize(candidate, evidence_packet, provider, model, "revise", base_skill=base_skill)
|
||||||
|
|
||||||
async def synthesize_new_skill(
|
async def synthesize_new_skill(
|
||||||
self,
|
self,
|
||||||
@ -27,7 +28,7 @@ class SkillDraftSynthesizer:
|
|||||||
provider: LLMProvider,
|
provider: LLMProvider,
|
||||||
model: str,
|
model: str,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
return await self._synthesize(candidate, evidence_packet, provider, model, "new")
|
return await self._synthesize(candidate, evidence_packet, provider, model, "new", base_skill=None)
|
||||||
|
|
||||||
async def synthesize_merge(
|
async def synthesize_merge(
|
||||||
self,
|
self,
|
||||||
@ -35,8 +36,9 @@ class SkillDraftSynthesizer:
|
|||||||
evidence_packet: EvidencePacket,
|
evidence_packet: EvidencePacket,
|
||||||
provider: LLMProvider,
|
provider: LLMProvider,
|
||||||
model: str,
|
model: str,
|
||||||
|
base_skill: dict[str, Any] | None = None,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
return await self._synthesize(candidate, evidence_packet, provider, model, "merge")
|
return await self._synthesize(candidate, evidence_packet, provider, model, "merge", base_skill=base_skill)
|
||||||
|
|
||||||
async def _synthesize(
|
async def _synthesize(
|
||||||
self,
|
self,
|
||||||
@ -45,15 +47,18 @@ class SkillDraftSynthesizer:
|
|||||||
provider: LLMProvider,
|
provider: LLMProvider,
|
||||||
model: str,
|
model: str,
|
||||||
action: str,
|
action: str,
|
||||||
|
*,
|
||||||
|
base_skill: dict[str, Any] | None,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
prompt = self._build_prompt(candidate, evidence_packet, action)
|
prompt = self._build_prompt(candidate, evidence_packet, action, base_skill=base_skill)
|
||||||
response = await provider.chat(
|
response = await provider.chat(
|
||||||
messages=[
|
messages=[
|
||||||
{
|
{
|
||||||
"role": "system",
|
"role": "system",
|
||||||
"content": (
|
"content": (
|
||||||
"You synthesize Beaver skill drafts from execution evidence. "
|
"You synthesize Beaver skill drafts from execution evidence. "
|
||||||
"Return only JSON with keys: frontmatter, content, change_reason."
|
"Return only JSON with keys: frontmatter, content, change_reason, "
|
||||||
|
"preserved_sections, changed_sections, dropped_sections."
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
{"role": "user", "content": prompt},
|
{"role": "user", "content": prompt},
|
||||||
@ -69,11 +74,30 @@ class SkillDraftSynthesizer:
|
|||||||
return self._fallback_payload(candidate, evidence_packet, action)
|
return self._fallback_payload(candidate, evidence_packet, action)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _build_prompt(candidate: SkillLearningCandidate, evidence_packet: EvidencePacket, action: str) -> str:
|
def _build_prompt(
|
||||||
|
candidate: SkillLearningCandidate,
|
||||||
|
evidence_packet: EvidencePacket,
|
||||||
|
action: str,
|
||||||
|
base_skill: dict[str, Any] | None = None,
|
||||||
|
) -> str:
|
||||||
tool_names = _coerce_string_list(evidence_packet.metadata.get("tool_names"))
|
tool_names = _coerce_string_list(evidence_packet.metadata.get("tool_names"))
|
||||||
tool_section = ", ".join(tool_names) if tool_names else "none observed"
|
tool_section = ", ".join(tool_names) if tool_names else "none observed"
|
||||||
selected_tool_names = _coerce_string_list(evidence_packet.metadata.get("selected_tool_names"))
|
selected_tool_names = _coerce_string_list(evidence_packet.metadata.get("selected_tool_names"))
|
||||||
selected_tool_section = ", ".join(selected_tool_names) if selected_tool_names else "none recorded"
|
selected_tool_section = ", ".join(selected_tool_names) if selected_tool_names else "none recorded"
|
||||||
|
base_section = ""
|
||||||
|
if base_skill:
|
||||||
|
base_section = (
|
||||||
|
"\n\nBase skill snapshot:\n"
|
||||||
|
f"- skill_name: {base_skill.get('skill_name')}\n"
|
||||||
|
f"- version: {base_skill.get('version')}\n"
|
||||||
|
f"- frontmatter: {json.dumps(base_skill.get('frontmatter') or {}, ensure_ascii=False, sort_keys=True)}\n"
|
||||||
|
f"- tool_hints: {base_skill.get('tool_hints') or []}\n"
|
||||||
|
f"- summary: {base_skill.get('summary') or ''}\n"
|
||||||
|
"Base skill content:\n"
|
||||||
|
f"{base_skill.get('content') or ''}\n"
|
||||||
|
"Preserve existing instructions unless the evidence requires a change. "
|
||||||
|
"If any section is changed or dropped, explain it in changed_sections or dropped_sections."
|
||||||
|
)
|
||||||
return (
|
return (
|
||||||
f"Action: {action}\n"
|
f"Action: {action}\n"
|
||||||
f"Candidate kind: {candidate.kind}\n"
|
f"Candidate kind: {candidate.kind}\n"
|
||||||
@ -83,11 +107,13 @@ class SkillDraftSynthesizer:
|
|||||||
f"Run-selected tool names: {selected_tool_section}\n"
|
f"Run-selected tool names: {selected_tool_section}\n"
|
||||||
f"Task summaries:\n- " + "\n- ".join(evidence_packet.task_summaries)
|
f"Task summaries:\n- " + "\n- ".join(evidence_packet.task_summaries)
|
||||||
+ "\n\nSession excerpts:\n" + "\n\n".join(evidence_packet.session_excerpts)
|
+ "\n\nSession excerpts:\n" + "\n\n".join(evidence_packet.session_excerpts)
|
||||||
|
+ base_section
|
||||||
+ "\n\nReturn JSON only. The frontmatter object must include:"
|
+ "\n\nReturn JSON only. The frontmatter object must include:"
|
||||||
+ "\n- description: a concise skill description"
|
+ "\n- description: a concise skill description"
|
||||||
+ "\n- tools: an explicit JSON array of exact tool names this skill needs. "
|
+ "\n- tools: an explicit JSON array of exact tool names this skill needs. "
|
||||||
+ "Prefer called tool names when the workflow depends on them; use run-selected tool names only when clearly required. "
|
+ "Prefer called tool names when the workflow depends on them; use run-selected tool names only when clearly required. "
|
||||||
+ "Use [] only when no tool is required."
|
+ "Use [] only when no tool is required."
|
||||||
|
+ "\nThe JSON may include preserved_sections, changed_sections, and dropped_sections arrays."
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -111,6 +137,9 @@ class SkillDraftSynthesizer:
|
|||||||
"frontmatter": frontmatter,
|
"frontmatter": frontmatter,
|
||||||
"content": content_value.strip(),
|
"content": content_value.strip(),
|
||||||
"change_reason": str(payload.get("change_reason") or ""),
|
"change_reason": str(payload.get("change_reason") or ""),
|
||||||
|
"preserved_sections": _coerce_string_list(payload.get("preserved_sections")),
|
||||||
|
"changed_sections": _coerce_string_list(payload.get("changed_sections")),
|
||||||
|
"dropped_sections": _coerce_string_list(payload.get("dropped_sections")),
|
||||||
}
|
}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -124,6 +153,9 @@ class SkillDraftSynthesizer:
|
|||||||
"frontmatter": frontmatter,
|
"frontmatter": frontmatter,
|
||||||
"content": str(payload.get("content") or "").strip(),
|
"content": str(payload.get("content") or "").strip(),
|
||||||
"change_reason": str(payload.get("change_reason") or ""),
|
"change_reason": str(payload.get("change_reason") or ""),
|
||||||
|
"preserved_sections": _coerce_string_list(payload.get("preserved_sections")),
|
||||||
|
"changed_sections": _coerce_string_list(payload.get("changed_sections")),
|
||||||
|
"dropped_sections": _coerce_string_list(payload.get("dropped_sections")),
|
||||||
}
|
}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -138,6 +170,9 @@ class SkillDraftSynthesizer:
|
|||||||
},
|
},
|
||||||
"content": f"# {title}\n\n## Evidence\n\n{content}\n",
|
"content": f"# {title}\n\n## Evidence\n\n{content}\n",
|
||||||
"change_reason": candidate.reason or f"Fallback {action} synthesis.",
|
"change_reason": candidate.reason or f"Fallback {action} synthesis.",
|
||||||
|
"preserved_sections": [],
|
||||||
|
"changed_sections": [],
|
||||||
|
"dropped_sections": [],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,22 +1,27 @@
|
|||||||
"""Internal task tracking for automatic Main Agent task mode."""
|
"""Internal task tracking for automatic Main Agent task mode."""
|
||||||
|
|
||||||
from .models import MainAgentDecision, TaskEvent, TaskRecord, ValidationResult
|
from .evidence import EvidenceBuilder, RunEvidence, TaskEvidencePacket, ToolEvidence, render_task_evidence
|
||||||
|
from .models import MainAgentDecision, TaskEvent, TaskRecord, ValidationResult, ValidationStatus
|
||||||
from .planner import TaskExecutionPlan, TaskExecutionPlanner
|
from .planner import TaskExecutionPlan, TaskExecutionPlanner
|
||||||
from .router import MainAgentRouter
|
from .router import MainAgentRouter
|
||||||
from .service import TaskService
|
from .service import TaskService
|
||||||
from .skill_resolver import SkillResolutionReport, TaskSkillResolver
|
from .skill_resolver import SkillResolutionReport, TaskSkillResolver
|
||||||
from .validation import ValidationService
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"EvidenceBuilder",
|
||||||
"MainAgentDecision",
|
"MainAgentDecision",
|
||||||
"MainAgentRouter",
|
"MainAgentRouter",
|
||||||
|
"RunEvidence",
|
||||||
"TaskEvent",
|
"TaskEvent",
|
||||||
|
"TaskEvidencePacket",
|
||||||
"TaskExecutionPlan",
|
"TaskExecutionPlan",
|
||||||
"TaskExecutionPlanner",
|
"TaskExecutionPlanner",
|
||||||
"TaskRecord",
|
"TaskRecord",
|
||||||
"TaskService",
|
"TaskService",
|
||||||
"SkillResolutionReport",
|
"SkillResolutionReport",
|
||||||
"TaskSkillResolver",
|
"TaskSkillResolver",
|
||||||
|
"ToolEvidence",
|
||||||
"ValidationResult",
|
"ValidationResult",
|
||||||
"ValidationService",
|
"ValidationStatus",
|
||||||
|
"render_task_evidence",
|
||||||
]
|
]
|
||||||
|
|||||||
183
app-instance/backend/beaver/tasks/evidence.py
Normal file
183
app-instance/backend/beaver/tasks/evidence.py
Normal file
@ -0,0 +1,183 @@
|
|||||||
|
"""Structured evidence for task synthesis and validation."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class ToolEvidence:
|
||||||
|
tool_name: str
|
||||||
|
tool_call_id: str | None
|
||||||
|
content: str
|
||||||
|
event_payload: dict[str, Any] = field(default_factory=dict)
|
||||||
|
url: str | None = None
|
||||||
|
title: str | None = None
|
||||||
|
created_at: str | None = None
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"tool_name": self.tool_name,
|
||||||
|
"tool_call_id": self.tool_call_id,
|
||||||
|
"content": self.content,
|
||||||
|
"event_payload": dict(self.event_payload),
|
||||||
|
"url": self.url,
|
||||||
|
"title": self.title,
|
||||||
|
"created_at": self.created_at,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class RunEvidence:
|
||||||
|
run_id: str
|
||||||
|
session_id: str
|
||||||
|
output_text: str
|
||||||
|
finish_reason: str
|
||||||
|
transcript: list[dict[str, Any]] = field(default_factory=list)
|
||||||
|
tool_results: list[ToolEvidence] = field(default_factory=list)
|
||||||
|
warnings: list[str] = field(default_factory=list)
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"run_id": self.run_id,
|
||||||
|
"session_id": self.session_id,
|
||||||
|
"output_text": self.output_text,
|
||||||
|
"finish_reason": self.finish_reason,
|
||||||
|
"transcript": list(self.transcript),
|
||||||
|
"tool_results": [item.to_dict() for item in self.tool_results],
|
||||||
|
"warnings": list(self.warnings),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class TaskEvidencePacket:
|
||||||
|
task_id: str
|
||||||
|
attempt_index: int
|
||||||
|
main_run: RunEvidence | None
|
||||||
|
team_runs: list[RunEvidence] = field(default_factory=list)
|
||||||
|
team_node_results: list[Any] = field(default_factory=list)
|
||||||
|
final_output: str = ""
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"task_id": self.task_id,
|
||||||
|
"attempt_index": self.attempt_index,
|
||||||
|
"main_run": self.main_run.to_dict() if self.main_run else None,
|
||||||
|
"team_runs": [item.to_dict() for item in self.team_runs],
|
||||||
|
"team_node_results": [
|
||||||
|
item.to_dict() if hasattr(item, "to_dict") else dict(item)
|
||||||
|
for item in self.team_node_results
|
||||||
|
],
|
||||||
|
"final_output": self.final_output,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class EvidenceBuilder:
|
||||||
|
def __init__(self, session_manager: Any) -> None:
|
||||||
|
self.session_manager = session_manager
|
||||||
|
|
||||||
|
def build_run_evidence(
|
||||||
|
self,
|
||||||
|
session_id: str,
|
||||||
|
run_id: str,
|
||||||
|
output_text: str,
|
||||||
|
finish_reason: str,
|
||||||
|
) -> RunEvidence:
|
||||||
|
events = self.session_manager.get_run_event_records(session_id, run_id)
|
||||||
|
transcript: list[dict[str, Any]] = []
|
||||||
|
tool_results: list[ToolEvidence] = []
|
||||||
|
warnings: list[str] = []
|
||||||
|
for event in events:
|
||||||
|
payload = dict(event.event_payload or {})
|
||||||
|
transcript.append(
|
||||||
|
{
|
||||||
|
"role": event.role,
|
||||||
|
"event_type": event.event_type,
|
||||||
|
"content": event.content,
|
||||||
|
"tool_name": event.tool_name,
|
||||||
|
"tool_call_id": event.tool_call_id,
|
||||||
|
"finish_reason": event.finish_reason,
|
||||||
|
"event_payload": payload,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if event.event_type == "tool_result_recorded":
|
||||||
|
tool_results.append(
|
||||||
|
ToolEvidence(
|
||||||
|
tool_name=event.tool_name or "tool",
|
||||||
|
tool_call_id=event.tool_call_id,
|
||||||
|
content=event.content or "",
|
||||||
|
event_payload=payload,
|
||||||
|
url=_optional_str(payload.get("url")),
|
||||||
|
title=_optional_str(payload.get("title")),
|
||||||
|
created_at=_optional_str(payload.get("created_at")),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if finish_reason and finish_reason != "stop":
|
||||||
|
warnings.append(f"finish_reason={finish_reason}")
|
||||||
|
return RunEvidence(
|
||||||
|
run_id=run_id,
|
||||||
|
session_id=session_id,
|
||||||
|
output_text=output_text,
|
||||||
|
finish_reason=finish_reason,
|
||||||
|
transcript=transcript,
|
||||||
|
tool_results=tool_results,
|
||||||
|
warnings=warnings,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def render_task_evidence(packet: TaskEvidencePacket) -> str:
|
||||||
|
sections = [
|
||||||
|
f"Task evidence packet: task_id={packet.task_id} attempt={packet.attempt_index}",
|
||||||
|
f"Final output:\n{packet.final_output}",
|
||||||
|
]
|
||||||
|
if packet.main_run is not None:
|
||||||
|
sections.append("Main run evidence:\n" + render_run_evidence(packet.main_run))
|
||||||
|
if packet.team_runs:
|
||||||
|
sections.append(
|
||||||
|
"Team run evidence:\n"
|
||||||
|
+ "\n\n".join(render_run_evidence(item) for item in packet.team_runs)
|
||||||
|
)
|
||||||
|
if packet.team_node_results:
|
||||||
|
lines = []
|
||||||
|
for item in packet.team_node_results:
|
||||||
|
lines.append(
|
||||||
|
f"- {getattr(item, 'node_id', '')}: success={getattr(item, 'success', False)} "
|
||||||
|
f"finish_reason={getattr(item, 'finish_reason', '')} error={getattr(item, 'error', '') or ''}"
|
||||||
|
)
|
||||||
|
sections.append("Team node results:\n" + "\n".join(lines))
|
||||||
|
return "\n\n".join(section for section in sections if section.strip())
|
||||||
|
|
||||||
|
|
||||||
|
def render_run_evidence(evidence: RunEvidence) -> str:
|
||||||
|
lines = [
|
||||||
|
f"run_id={evidence.run_id}",
|
||||||
|
f"session_id={evidence.session_id}",
|
||||||
|
f"finish_reason={evidence.finish_reason}",
|
||||||
|
]
|
||||||
|
if evidence.output_text:
|
||||||
|
lines.append(f"output:\n{evidence.output_text}")
|
||||||
|
if evidence.warnings:
|
||||||
|
lines.append("warnings:\n" + "\n".join(f"- {item}" for item in evidence.warnings))
|
||||||
|
if evidence.tool_results:
|
||||||
|
lines.append(
|
||||||
|
"tool_results:\n"
|
||||||
|
+ "\n\n".join(_render_tool_evidence(item) for item in evidence.tool_results)
|
||||||
|
)
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
def _render_tool_evidence(item: ToolEvidence) -> str:
|
||||||
|
header = f"- tool={item.tool_name} call_id={item.tool_call_id or ''}"
|
||||||
|
metadata = []
|
||||||
|
if item.url:
|
||||||
|
metadata.append(f"url={item.url}")
|
||||||
|
if item.title:
|
||||||
|
metadata.append(f"title={item.title}")
|
||||||
|
if item.created_at:
|
||||||
|
metadata.append(f"created_at={item.created_at}")
|
||||||
|
return "\n".join([header, *metadata, item.content])
|
||||||
|
|
||||||
|
|
||||||
|
def _optional_str(value: Any) -> str | None:
|
||||||
|
return str(value) if value is not None else None
|
||||||
@ -1,33 +1,70 @@
|
|||||||
"""Models for internal task tracking and validation."""
|
"""Models for internal task tracking and user acceptance."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
|
||||||
TASK_OPEN_STATUSES = {"open", "running", "validating", "awaiting_feedback", "needs_revision"}
|
ValidationStatus = Literal["accepted", "rejected", "insufficient_evidence", "validator_error"]
|
||||||
|
|
||||||
|
VALIDATION_STATUSES = {"accepted", "rejected", "insufficient_evidence", "validator_error"}
|
||||||
|
TASK_OPEN_STATUSES = {"open", "running", "awaiting_acceptance", "needs_revision"}
|
||||||
|
LEGACY_STATUS_MAP = {
|
||||||
|
"validating": "running",
|
||||||
|
"awaiting_feedback": "awaiting_acceptance",
|
||||||
|
"needs_review": "awaiting_acceptance",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@dataclass(slots=True)
|
@dataclass(slots=True)
|
||||||
class ValidationResult:
|
class ValidationResult:
|
||||||
passed: bool
|
status: ValidationStatus = "rejected"
|
||||||
score: float
|
score: float = 0.0
|
||||||
issues: list[str] = field(default_factory=list)
|
issues: list[str] = field(default_factory=list)
|
||||||
missing_requirements: list[str] = field(default_factory=list)
|
missing_requirements: list[str] = field(default_factory=list)
|
||||||
|
evidence_gaps: list[str] = field(default_factory=list)
|
||||||
recommended_revision_prompt: str = ""
|
recommended_revision_prompt: str = ""
|
||||||
validator: str = "heuristic"
|
validator: str = "heuristic"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
status: ValidationStatus | None = None,
|
||||||
|
passed: bool | None = None,
|
||||||
|
score: float = 0.0,
|
||||||
|
issues: list[str] | None = None,
|
||||||
|
missing_requirements: list[str] | None = None,
|
||||||
|
evidence_gaps: list[str] | None = None,
|
||||||
|
recommended_revision_prompt: str = "",
|
||||||
|
validator: str = "heuristic",
|
||||||
|
) -> None:
|
||||||
|
if status is not None and status not in VALIDATION_STATUSES:
|
||||||
|
raise ValueError(f"unknown validation status: {status}")
|
||||||
|
self.status = status or ("accepted" if passed and score >= 0.75 else "rejected")
|
||||||
|
self.score = max(0.0, min(1.0, float(score or 0.0)))
|
||||||
|
self.issues = list(issues or [])
|
||||||
|
self.missing_requirements = list(missing_requirements or [])
|
||||||
|
self.evidence_gaps = list(evidence_gaps or [])
|
||||||
|
self.recommended_revision_prompt = recommended_revision_prompt
|
||||||
|
self.validator = validator
|
||||||
|
|
||||||
|
@property
|
||||||
|
def passed(self) -> bool:
|
||||||
|
return self.status == "accepted"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def accepted(self) -> bool:
|
def accepted(self) -> bool:
|
||||||
return self.passed and self.score >= 0.75
|
return self.status == "accepted"
|
||||||
|
|
||||||
def to_dict(self) -> dict[str, Any]:
|
def to_dict(self) -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
|
"status": self.status,
|
||||||
"passed": self.passed,
|
"passed": self.passed,
|
||||||
"score": self.score,
|
"score": self.score,
|
||||||
"issues": list(self.issues),
|
"issues": list(self.issues),
|
||||||
"missing_requirements": list(self.missing_requirements),
|
"missing_requirements": list(self.missing_requirements),
|
||||||
|
"evidence_gaps": list(self.evidence_gaps),
|
||||||
"recommended_revision_prompt": self.recommended_revision_prompt,
|
"recommended_revision_prompt": self.recommended_revision_prompt,
|
||||||
"validator": self.validator,
|
"validator": self.validator,
|
||||||
"accepted": self.accepted,
|
"accepted": self.accepted,
|
||||||
@ -37,11 +74,17 @@ class ValidationResult:
|
|||||||
def from_dict(cls, payload: dict[str, Any] | None) -> "ValidationResult | None":
|
def from_dict(cls, payload: dict[str, Any] | None) -> "ValidationResult | None":
|
||||||
if not isinstance(payload, dict):
|
if not isinstance(payload, dict):
|
||||||
return None
|
return None
|
||||||
|
raw_status = payload.get("status")
|
||||||
|
if "status" in payload and raw_status not in VALIDATION_STATUSES:
|
||||||
|
raise ValueError(f"unknown validation status: {raw_status}")
|
||||||
|
status: ValidationStatus | None = raw_status if "status" in payload else None
|
||||||
return cls(
|
return cls(
|
||||||
passed=bool(payload.get("passed")),
|
status=status,
|
||||||
|
passed=bool(payload.get("passed")) if "status" not in payload else None,
|
||||||
score=float(payload.get("score", 0.0) or 0.0),
|
score=float(payload.get("score", 0.0) or 0.0),
|
||||||
issues=[str(item) for item in payload.get("issues") or []],
|
issues=[str(item) for item in payload.get("issues") or []],
|
||||||
missing_requirements=[str(item) for item in payload.get("missing_requirements") or []],
|
missing_requirements=[str(item) for item in payload.get("missing_requirements") or []],
|
||||||
|
evidence_gaps=[str(item) for item in payload.get("evidence_gaps") or []],
|
||||||
recommended_revision_prompt=str(payload.get("recommended_revision_prompt") or ""),
|
recommended_revision_prompt=str(payload.get("recommended_revision_prompt") or ""),
|
||||||
validator=str(payload.get("validator") or "unknown"),
|
validator=str(payload.get("validator") or "unknown"),
|
||||||
)
|
)
|
||||||
@ -73,6 +116,14 @@ class TaskRecord:
|
|||||||
def is_open(self) -> bool:
|
def is_open(self) -> bool:
|
||||||
return self.status in TASK_OPEN_STATUSES
|
return self.status in TASK_OPEN_STATUSES
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_execution_active(self) -> bool:
|
||||||
|
return self.status == "running"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def requires_user_action(self) -> bool:
|
||||||
|
return self.status in {"awaiting_acceptance", "needs_revision"}
|
||||||
|
|
||||||
def to_dict(self) -> dict[str, Any]:
|
def to_dict(self) -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"task_id": self.task_id,
|
"task_id": self.task_id,
|
||||||
@ -91,6 +142,7 @@ class TaskRecord:
|
|||||||
"satisfaction": self.satisfaction,
|
"satisfaction": self.satisfaction,
|
||||||
"run_ids": list(self.run_ids),
|
"run_ids": list(self.run_ids),
|
||||||
"skill_names": list(self.skill_names),
|
"skill_names": list(self.skill_names),
|
||||||
|
"acceptance": list(self.feedback),
|
||||||
"feedback": list(self.feedback),
|
"feedback": list(self.feedback),
|
||||||
"validation_result": self.validation_result,
|
"validation_result": self.validation_result,
|
||||||
"metadata": dict(self.metadata),
|
"metadata": dict(self.metadata),
|
||||||
@ -106,7 +158,7 @@ class TaskRecord:
|
|||||||
goal=str(payload.get("goal") or payload.get("description") or ""),
|
goal=str(payload.get("goal") or payload.get("description") or ""),
|
||||||
constraints=[str(item) for item in payload.get("constraints") or []],
|
constraints=[str(item) for item in payload.get("constraints") or []],
|
||||||
priority=int(payload.get("priority", 0) or 0),
|
priority=int(payload.get("priority", 0) or 0),
|
||||||
status=str(payload.get("status") or "open"),
|
status=LEGACY_STATUS_MAP.get(str(payload.get("status") or "open"), str(payload.get("status") or "open")),
|
||||||
creator=str(payload.get("creator") or "main-agent"),
|
creator=str(payload.get("creator") or "main-agent"),
|
||||||
created_at=str(payload.get("created_at") or ""),
|
created_at=str(payload.get("created_at") or ""),
|
||||||
updated_at=str(payload.get("updated_at") or ""),
|
updated_at=str(payload.get("updated_at") or ""),
|
||||||
@ -115,7 +167,11 @@ class TaskRecord:
|
|||||||
satisfaction=_optional_float(payload.get("satisfaction")),
|
satisfaction=_optional_float(payload.get("satisfaction")),
|
||||||
run_ids=[str(item) for item in payload.get("run_ids") or []],
|
run_ids=[str(item) for item in payload.get("run_ids") or []],
|
||||||
skill_names=[str(item) for item in payload.get("skill_names") or []],
|
skill_names=[str(item) for item in payload.get("skill_names") or []],
|
||||||
feedback=[dict(item) for item in payload.get("feedback") or [] if isinstance(item, dict)],
|
feedback=[
|
||||||
|
_normalize_acceptance_entry(dict(item))
|
||||||
|
for item in (payload.get("acceptance") or payload.get("feedback") or [])
|
||||||
|
if isinstance(item, dict)
|
||||||
|
],
|
||||||
validation_result=dict(payload["validation_result"]) if isinstance(payload.get("validation_result"), dict) else None,
|
validation_result=dict(payload["validation_result"]) if isinstance(payload.get("validation_result"), dict) else None,
|
||||||
metadata=dict(payload.get("metadata") or {}),
|
metadata=dict(payload.get("metadata") or {}),
|
||||||
)
|
)
|
||||||
@ -180,3 +236,13 @@ def _optional_float(value: Any) -> float | None:
|
|||||||
if value in (None, ""):
|
if value in (None, ""):
|
||||||
return None
|
return None
|
||||||
return float(value)
|
return float(value)
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_acceptance_entry(entry: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
if entry.get("acceptance_type") is None and entry.get("feedback_type") is not None:
|
||||||
|
feedback_type = str(entry.get("feedback_type") or "")
|
||||||
|
entry["acceptance_type"] = "accept" if feedback_type == "satisfied" else feedback_type
|
||||||
|
if entry.get("feedback_type") is None and entry.get("acceptance_type") is not None:
|
||||||
|
acceptance_type = str(entry.get("acceptance_type") or "")
|
||||||
|
entry["feedback_type"] = "satisfied" if acceptance_type == "accept" else acceptance_type
|
||||||
|
return entry
|
||||||
|
|||||||
@ -10,7 +10,7 @@ from typing import Any, Literal
|
|||||||
from beaver.coordinator.models import AgentDescriptor, ExecutionGraph, ExecutionNode
|
from beaver.coordinator.models import AgentDescriptor, ExecutionGraph, ExecutionNode
|
||||||
from beaver.engine.providers import ProviderBundle
|
from beaver.engine.providers import ProviderBundle
|
||||||
|
|
||||||
from .models import TaskRecord, ValidationResult
|
from .models import TaskRecord
|
||||||
from .skill_resolver import SkillResolutionReport, TaskSkillResolver
|
from .skill_resolver import SkillResolutionReport, TaskSkillResolver
|
||||||
|
|
||||||
|
|
||||||
@ -76,7 +76,6 @@ class TaskExecutionPlanner:
|
|||||||
task: TaskRecord,
|
task: TaskRecord,
|
||||||
user_message: str,
|
user_message: str,
|
||||||
attempt_index: int,
|
attempt_index: int,
|
||||||
latest_validation: ValidationResult | None = None,
|
|
||||||
provider_bundle: ProviderBundle | None = None,
|
provider_bundle: ProviderBundle | None = None,
|
||||||
timeout_seconds: float = 30.0,
|
timeout_seconds: float = 30.0,
|
||||||
) -> TaskExecutionPlan:
|
) -> TaskExecutionPlan:
|
||||||
@ -105,7 +104,6 @@ class TaskExecutionPlanner:
|
|||||||
task=task,
|
task=task,
|
||||||
user_message=user_message,
|
user_message=user_message,
|
||||||
attempt_index=attempt_index,
|
attempt_index=attempt_index,
|
||||||
latest_validation=latest_validation,
|
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
@ -230,14 +228,10 @@ class TaskExecutionPlanner:
|
|||||||
task: TaskRecord,
|
task: TaskRecord,
|
||||||
user_message: str,
|
user_message: str,
|
||||||
attempt_index: int,
|
attempt_index: int,
|
||||||
latest_validation: ValidationResult | None,
|
|
||||||
) -> str:
|
) -> str:
|
||||||
validation_note = ""
|
history_note = ""
|
||||||
if latest_validation is not None:
|
if task.feedback:
|
||||||
validation_note = (
|
history_note = "\nRelevant task history:\n" + json.dumps(task.feedback[-5:], ensure_ascii=False)
|
||||||
"\nPrevious validation issues:\n"
|
|
||||||
+ json.dumps(latest_validation.to_dict(), ensure_ascii=False)
|
|
||||||
)
|
|
||||||
return (
|
return (
|
||||||
"Decide execution mode for this internal Task attempt.\n"
|
"Decide execution mode for this internal Task attempt.\n"
|
||||||
"Use mode=team only when independent research, review, implementation slices, or staged checks "
|
"Use mode=team only when independent research, review, implementation slices, or staged checks "
|
||||||
@ -254,7 +248,7 @@ class TaskExecutionPlanner:
|
|||||||
f"Task goal:\n{task.goal}\n\n"
|
f"Task goal:\n{task.goal}\n\n"
|
||||||
f"Current user request:\n{user_message}\n\n"
|
f"Current user request:\n{user_message}\n\n"
|
||||||
f"Attempt index: {attempt_index}\n"
|
f"Attempt index: {attempt_index}\n"
|
||||||
f"{validation_note}"
|
f"{history_note}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@ -161,6 +161,9 @@ class MainAgentRouter:
|
|||||||
"Critical policy:\n"
|
"Critical policy:\n"
|
||||||
"- If there is an active Task, choose continue_task or revise_task unless the user's topic is completely unrelated "
|
"- If there is an active Task, choose continue_task or revise_task unless the user's topic is completely unrelated "
|
||||||
"to that Task or the user explicitly closes/abandons it.\n"
|
"to that Task or the user explicitly closes/abandons it.\n"
|
||||||
|
"- With an active Task, choose simple_chat for unrelated lightweight conversation and new_task for unrelated work "
|
||||||
|
"that needs Task capabilities. Either decision starts a new topic.\n"
|
||||||
|
"- An unrelated lightweight conversation must not be classified as revise_task merely because the active Task is awaiting acceptance.\n"
|
||||||
"- Choose revise_task when the active Task is awaiting feedback or needs revision and the user asks for changes "
|
"- Choose revise_task when the active Task is awaiting feedback or needs revision and the user asks for changes "
|
||||||
"such as '改一下', '加上', '删除', '换成', '再详细点', '格式改成', '不要', or equivalent wording.\n"
|
"such as '改一下', '加上', '删除', '换成', '再详细点', '格式改成', '不要', or equivalent wording.\n"
|
||||||
"- Choose continue_task for neutral follow-up questions or additional next steps that do not imply dissatisfaction with the previous result.\n"
|
"- Choose continue_task for neutral follow-up questions or additional next steps that do not imply dissatisfaction with the previous result.\n"
|
||||||
|
|||||||
@ -7,7 +7,7 @@ from pathlib import Path
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
from .models import TaskEvent, TaskRecord, ValidationResult
|
from .models import TaskEvent, TaskRecord
|
||||||
from .store import TaskStore
|
from .store import TaskStore
|
||||||
|
|
||||||
|
|
||||||
@ -77,6 +77,8 @@ class TaskService:
|
|||||||
payload = task.to_dict()
|
payload = task.to_dict()
|
||||||
payload["short_title"] = self.ensure_short_title(task).metadata.get("short_title")
|
payload["short_title"] = self.ensure_short_title(task).metadata.get("short_title")
|
||||||
payload["is_open"] = task.is_open
|
payload["is_open"] = task.is_open
|
||||||
|
payload["is_execution_active"] = task.is_execution_active
|
||||||
|
payload["requires_user_action"] = task.requires_user_action
|
||||||
return payload
|
return payload
|
||||||
|
|
||||||
def ensure_short_title(self, task: TaskRecord) -> TaskRecord:
|
def ensure_short_title(self, task: TaskRecord) -> TaskRecord:
|
||||||
@ -103,18 +105,70 @@ class TaskService:
|
|||||||
for name in skill_names or []:
|
for name in skill_names or []:
|
||||||
if name not in task.skill_names:
|
if name not in task.skill_names:
|
||||||
task.skill_names.append(name)
|
task.skill_names.append(name)
|
||||||
|
task.status = "awaiting_acceptance"
|
||||||
task.updated_at = self._now()
|
task.updated_at = self._now()
|
||||||
self.store.upsert_task(task)
|
self.store.upsert_task(task)
|
||||||
self._event(task, "run_completed", run_id=run_id, payload={"skill_names": skill_names or []})
|
self._event(task, "run_completed", run_id=run_id, payload={"skill_names": skill_names or []})
|
||||||
|
self._event(task, "evidence_recorded", run_id=run_id, payload={"skill_names": skill_names or []})
|
||||||
return task
|
return task
|
||||||
|
|
||||||
def record_validation(self, task_id: str, run_id: str, validation: ValidationResult) -> TaskRecord:
|
def add_acceptance(
|
||||||
|
self,
|
||||||
|
task_id: str,
|
||||||
|
*,
|
||||||
|
acceptance_type: str,
|
||||||
|
comment: str | None = None,
|
||||||
|
run_id: str | None = None,
|
||||||
|
) -> TaskRecord:
|
||||||
task = self._require(task_id)
|
task = self._require(task_id)
|
||||||
task.status = "awaiting_feedback"
|
now = self._now()
|
||||||
task.updated_at = self._now()
|
normalized = normalize_acceptance_type(acceptance_type)
|
||||||
task.validation_result = validation.to_dict()
|
matching_acceptance = any(
|
||||||
|
item.get("run_id") == run_id and item.get("acceptance_type") == normalized
|
||||||
|
for item in task.feedback
|
||||||
|
)
|
||||||
|
conflicting_acceptance = next(
|
||||||
|
(
|
||||||
|
item
|
||||||
|
for item in task.feedback
|
||||||
|
if item.get("run_id") == run_id and item.get("acceptance_type") != normalized
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
if conflicting_acceptance is not None:
|
||||||
|
raise ValueError(
|
||||||
|
f"Acceptance for run_id={run_id!r} was already recorded as "
|
||||||
|
f"{conflicting_acceptance.get('acceptance_type')!r}"
|
||||||
|
)
|
||||||
|
if task.status in {"closed", "abandoned"} and not matching_acceptance:
|
||||||
|
raise ValueError(f"Task {task.task_id} is already finalized as {task.status!r}")
|
||||||
|
if matching_acceptance:
|
||||||
|
return task
|
||||||
|
|
||||||
|
entry = {
|
||||||
|
"acceptance_type": normalized,
|
||||||
|
"feedback_type": "satisfied" if normalized == "accept" else normalized,
|
||||||
|
"comment": comment or "",
|
||||||
|
"run_id": run_id,
|
||||||
|
"created_at": now,
|
||||||
|
}
|
||||||
|
task.feedback.append(entry)
|
||||||
|
if normalized == "revise":
|
||||||
|
task.status = "needs_revision"
|
||||||
|
elif normalized == "abandon":
|
||||||
|
task.status = "abandoned"
|
||||||
|
task.closed_at = now
|
||||||
|
task.close_reason = comment or "abandoned"
|
||||||
|
elif normalized == "accept":
|
||||||
|
task.status = "closed"
|
||||||
|
task.closed_at = now
|
||||||
|
task.close_reason = "accepted"
|
||||||
|
task.satisfaction = 1.0
|
||||||
|
if run_id:
|
||||||
|
task.metadata["final_accepted_run_id"] = run_id
|
||||||
|
task.updated_at = now
|
||||||
self.store.upsert_task(task)
|
self.store.upsert_task(task)
|
||||||
self._event(task, "validated", run_id=run_id, payload=validation.to_dict())
|
self._event(task, f"acceptance_{normalized}", run_id=run_id, payload=entry)
|
||||||
return task
|
return task
|
||||||
|
|
||||||
def add_feedback(
|
def add_feedback(
|
||||||
@ -125,52 +179,12 @@ class TaskService:
|
|||||||
comment: str | None = None,
|
comment: str | None = None,
|
||||||
run_id: str | None = None,
|
run_id: str | None = None,
|
||||||
) -> TaskRecord:
|
) -> TaskRecord:
|
||||||
task = self._require(task_id)
|
return self.add_acceptance(
|
||||||
now = self._now()
|
task_id,
|
||||||
matching_feedback = any(
|
acceptance_type=feedback_type,
|
||||||
item.get("run_id") == run_id and item.get("feedback_type") == feedback_type
|
comment=comment,
|
||||||
for item in task.feedback
|
run_id=run_id,
|
||||||
)
|
)
|
||||||
conflicting_feedback = next(
|
|
||||||
(
|
|
||||||
item
|
|
||||||
for item in task.feedback
|
|
||||||
if item.get("run_id") == run_id and item.get("feedback_type") != feedback_type
|
|
||||||
),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
if conflicting_feedback is not None:
|
|
||||||
raise ValueError(
|
|
||||||
f"Feedback for run_id={run_id!r} was already recorded as "
|
|
||||||
f"{conflicting_feedback.get('feedback_type')!r}"
|
|
||||||
)
|
|
||||||
if task.status in {"closed", "abandoned"} and not matching_feedback:
|
|
||||||
raise ValueError(f"Task {task.task_id} is already finalized as {task.status!r}")
|
|
||||||
if matching_feedback:
|
|
||||||
return task
|
|
||||||
|
|
||||||
entry = {
|
|
||||||
"feedback_type": feedback_type,
|
|
||||||
"comment": comment or "",
|
|
||||||
"run_id": run_id,
|
|
||||||
"created_at": now,
|
|
||||||
}
|
|
||||||
task.feedback.append(entry)
|
|
||||||
if feedback_type == "revise":
|
|
||||||
task.status = "needs_revision"
|
|
||||||
elif feedback_type == "abandon":
|
|
||||||
task.status = "abandoned"
|
|
||||||
task.closed_at = now
|
|
||||||
task.close_reason = comment or "abandoned"
|
|
||||||
elif feedback_type == "satisfied":
|
|
||||||
task.status = "closed"
|
|
||||||
task.closed_at = now
|
|
||||||
task.close_reason = "satisfied"
|
|
||||||
task.satisfaction = 1.0
|
|
||||||
task.updated_at = now
|
|
||||||
self.store.upsert_task(task)
|
|
||||||
self._event(task, f"feedback_{feedback_type}", run_id=run_id, payload=entry)
|
|
||||||
return task
|
|
||||||
|
|
||||||
def close_task(self, task_id: str, *, reason: str = "closed") -> TaskRecord:
|
def close_task(self, task_id: str, *, reason: str = "closed") -> TaskRecord:
|
||||||
task = self._require(task_id)
|
task = self._require(task_id)
|
||||||
@ -245,3 +259,12 @@ def short_task_title(text: str) -> str:
|
|||||||
if len(words) <= 4:
|
if len(words) <= 4:
|
||||||
return cleaned[:40]
|
return cleaned[:40]
|
||||||
return " ".join(words[:4])[:40]
|
return " ".join(words[:4])[:40]
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_acceptance_type(value: str) -> str:
|
||||||
|
normalized = (value or "").strip().lower()
|
||||||
|
if normalized == "satisfied":
|
||||||
|
return "accept"
|
||||||
|
if normalized not in {"accept", "revise", "abandon"}:
|
||||||
|
raise ValueError("acceptance_type must be one of: accept, revise, abandon")
|
||||||
|
return normalized
|
||||||
|
|||||||
@ -1,138 +0,0 @@
|
|||||||
"""Automatic validation for internal Task mode."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import json
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from beaver.engine.providers import ProviderBundle
|
|
||||||
|
|
||||||
from .models import TaskRecord, ValidationResult
|
|
||||||
|
|
||||||
|
|
||||||
class ValidationService:
|
|
||||||
async def validate_task_result(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
task: TaskRecord,
|
|
||||||
user_message: str,
|
|
||||||
final_output: str,
|
|
||||||
transcript_excerpt: str = "",
|
|
||||||
tool_summaries: list[str] | None = None,
|
|
||||||
team_summaries: list[str] | None = None,
|
|
||||||
provider_bundle: ProviderBundle | None = None,
|
|
||||||
) -> ValidationResult:
|
|
||||||
provider = None
|
|
||||||
model = None
|
|
||||||
if provider_bundle is not None:
|
|
||||||
provider = provider_bundle.auxiliary_provider or provider_bundle.main_provider
|
|
||||||
runtime = provider_bundle.auxiliary_runtime or provider_bundle.main_runtime
|
|
||||||
model = getattr(runtime, "model", None)
|
|
||||||
if provider is not None:
|
|
||||||
try:
|
|
||||||
return await self._validate_with_provider(
|
|
||||||
provider=provider,
|
|
||||||
model=model,
|
|
||||||
task=task,
|
|
||||||
user_message=user_message,
|
|
||||||
final_output=final_output,
|
|
||||||
transcript_excerpt=transcript_excerpt,
|
|
||||||
tool_summaries=tool_summaries or [],
|
|
||||||
team_summaries=team_summaries or [],
|
|
||||||
)
|
|
||||||
except Exception as exc:
|
|
||||||
return ValidationResult(
|
|
||||||
passed=False,
|
|
||||||
score=0.0,
|
|
||||||
issues=[f"Validator failed: {exc}"],
|
|
||||||
missing_requirements=["A valid automatic validation result is required before accepting the task."],
|
|
||||||
recommended_revision_prompt=(
|
|
||||||
"Review the task result again because automatic validation failed, "
|
|
||||||
"then provide a corrected final answer that explicitly satisfies the task goal."
|
|
||||||
),
|
|
||||||
validator="llm_error",
|
|
||||||
)
|
|
||||||
return self._heuristic_validate(final_output)
|
|
||||||
|
|
||||||
async def _validate_with_provider(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
provider: Any,
|
|
||||||
model: str | None,
|
|
||||||
task: TaskRecord,
|
|
||||||
user_message: str,
|
|
||||||
final_output: str,
|
|
||||||
transcript_excerpt: str,
|
|
||||||
tool_summaries: list[str],
|
|
||||||
team_summaries: list[str],
|
|
||||||
) -> ValidationResult:
|
|
||||||
prompt = (
|
|
||||||
"Validate whether the assistant output satisfies the task. "
|
|
||||||
"Return only compact JSON with keys: passed, score, issues, "
|
|
||||||
"missing_requirements, recommended_revision_prompt.\n\n"
|
|
||||||
f"Task goal:\n{task.goal}\n\n"
|
|
||||||
f"Current user request:\n{user_message}\n\n"
|
|
||||||
f"Transcript excerpt:\n{transcript_excerpt[:2500]}\n\n"
|
|
||||||
f"Tool summaries:\n{json.dumps(tool_summaries[:12], ensure_ascii=False)}\n\n"
|
|
||||||
f"Team summaries:\n{json.dumps(team_summaries[:12], ensure_ascii=False)}\n\n"
|
|
||||||
f"Assistant final output:\n{final_output[:4000]}"
|
|
||||||
)
|
|
||||||
response = await provider.chat(
|
|
||||||
messages=[
|
|
||||||
{"role": "system", "content": "You are a strict task result validator."},
|
|
||||||
{"role": "user", "content": prompt},
|
|
||||||
],
|
|
||||||
tools=None,
|
|
||||||
model=model,
|
|
||||||
max_tokens=4096,
|
|
||||||
temperature=0.0,
|
|
||||||
)
|
|
||||||
payload = self._parse_json_object(response.content or "")
|
|
||||||
return ValidationResult(
|
|
||||||
passed=bool(payload.get("passed")),
|
|
||||||
score=max(0.0, min(1.0, float(payload.get("score", 0.0) or 0.0))),
|
|
||||||
issues=[str(item) for item in payload.get("issues") or []],
|
|
||||||
missing_requirements=[str(item) for item in payload.get("missing_requirements") or []],
|
|
||||||
recommended_revision_prompt=str(payload.get("recommended_revision_prompt") or ""),
|
|
||||||
validator="llm",
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _heuristic_validate(final_output: str) -> ValidationResult:
|
|
||||||
text = final_output.strip()
|
|
||||||
if not text:
|
|
||||||
return ValidationResult(
|
|
||||||
passed=False,
|
|
||||||
score=0.0,
|
|
||||||
issues=["Assistant output is empty."],
|
|
||||||
missing_requirements=["A non-empty result is required."],
|
|
||||||
recommended_revision_prompt="Produce a complete, non-empty answer for the task.",
|
|
||||||
validator="heuristic",
|
|
||||||
)
|
|
||||||
lowered = text.lower()
|
|
||||||
if "run failed before completion" in lowered or "tool loop stopped" in lowered:
|
|
||||||
return ValidationResult(
|
|
||||||
passed=False,
|
|
||||||
score=0.35,
|
|
||||||
issues=["The run did not complete cleanly."],
|
|
||||||
missing_requirements=["A successful final result is required."],
|
|
||||||
recommended_revision_prompt="Retry the task and address the failure before returning the final answer.",
|
|
||||||
validator="heuristic",
|
|
||||||
)
|
|
||||||
return ValidationResult(passed=True, score=0.85, validator="heuristic")
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _parse_json_object(text: str) -> dict[str, Any]:
|
|
||||||
cleaned = text.strip()
|
|
||||||
if cleaned.startswith("```"):
|
|
||||||
cleaned = cleaned.strip("`")
|
|
||||||
if cleaned.lower().startswith("json"):
|
|
||||||
cleaned = cleaned[4:].strip()
|
|
||||||
start = cleaned.find("{")
|
|
||||||
end = cleaned.rfind("}")
|
|
||||||
if start >= 0 and end >= start:
|
|
||||||
cleaned = cleaned[start : end + 1]
|
|
||||||
payload = json.loads(cleaned)
|
|
||||||
if not isinstance(payload, dict):
|
|
||||||
raise ValueError("validator response must be a JSON object")
|
|
||||||
return payload
|
|
||||||
@ -51,7 +51,7 @@ class WebFetchTool:
|
|||||||
try:
|
try:
|
||||||
safe_url = _safe_url(url)
|
safe_url = _safe_url(url)
|
||||||
limit = max(1000, min(int(max_chars or 12000), 50000))
|
limit = max(1000, min(int(max_chars or 12000), 50000))
|
||||||
async with httpx.AsyncClient(timeout=20, follow_redirects=True, trust_env=False) as client:
|
async with httpx.AsyncClient(timeout=20, follow_redirects=True, trust_env=True) as client:
|
||||||
response = await client.get(
|
response = await client.get(
|
||||||
safe_url,
|
safe_url,
|
||||||
headers={"User-Agent": "Mozilla/5.0 Beaver/1.0"},
|
headers={"User-Agent": "Mozilla/5.0 Beaver/1.0"},
|
||||||
@ -96,7 +96,7 @@ class WebSearchTool:
|
|||||||
raise ValueError("query is required")
|
raise ValueError("query is required")
|
||||||
bounded = max(1, min(int(limit or 5), 10))
|
bounded = max(1, min(int(limit or 5), 10))
|
||||||
url = f"https://duckduckgo.com/html/?q={quote_plus(query)}"
|
url = f"https://duckduckgo.com/html/?q={quote_plus(query)}"
|
||||||
async with httpx.AsyncClient(timeout=20, follow_redirects=True, trust_env=False) as client:
|
async with httpx.AsyncClient(timeout=20, follow_redirects=True, trust_env=True) as client:
|
||||||
response = await client.get(url, headers={"User-Agent": "Mozilla/5.0 Beaver/1.0"})
|
response = await client.get(url, headers={"User-Agent": "Mozilla/5.0 Beaver/1.0"})
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
html = response.text
|
html = response.text
|
||||||
|
|||||||
@ -886,11 +886,9 @@ tool_hints:
|
|||||||
<div class="panel">
|
<div class="panel">
|
||||||
<h3>Beaver 平台部署层</h3>
|
<h3>Beaver 平台部署层</h3>
|
||||||
<ul>
|
<ul>
|
||||||
<li><span class="code">auth-portal</span>:用户入口、登录注册。</li>
|
<li><span class="code">Keycloak OIDC</span>:用户入口、SSO 登录与登出。</li>
|
||||||
<li><span class="code">authz-service</span>:授权服务和后端身份。</li>
|
<li><span class="code">app-instance</span>:单体应用容器,内置前端、后端和 Nginx。</li>
|
||||||
<li><span class="code">deploy-control</span>:创建和管理单用户实例。</li>
|
<li>外部编排系统负责部署和访问入口,当前应用不维护实例注册表。</li>
|
||||||
<li><span class="code">router-proxy</span>:按 Host 转发到实例。</li>
|
|
||||||
<li>每个用户一个 <span class="code">app-instance</span>,Docker 容器隔离 + per-instance workspace。</li>
|
|
||||||
</ul>
|
</ul>
|
||||||
</div>
|
</div>
|
||||||
<div class="panel">
|
<div class="panel">
|
||||||
|
|||||||
@ -13,6 +13,7 @@ dependencies = [
|
|||||||
"litellm>=1.79.0,<2.0.0",
|
"litellm>=1.79.0,<2.0.0",
|
||||||
"minio>=7.2.0,<8.0.0",
|
"minio>=7.2.0,<8.0.0",
|
||||||
"openai>=1.79.0,<2.0.0",
|
"openai>=1.79.0,<2.0.0",
|
||||||
|
"PyJWT[crypto]>=2.9.0,<3.0.0",
|
||||||
"pydantic>=2.12.0,<3.0.0",
|
"pydantic>=2.12.0,<3.0.0",
|
||||||
"python-multipart>=0.0.20,<1.0.0",
|
"python-multipart>=0.0.20,<1.0.0",
|
||||||
"typer>=0.20.0,<1.0.0",
|
"typer>=0.20.0,<1.0.0",
|
||||||
@ -22,6 +23,24 @@ dependencies = [
|
|||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
dev = [
|
dev = [
|
||||||
"pytest>=9.0.0,<10.0.0",
|
"pytest>=9.0.0,<10.0.0",
|
||||||
|
"pytest-asyncio>=1.0.0,<2.0.0",
|
||||||
|
]
|
||||||
|
telegram = [
|
||||||
|
"python-telegram-bot>=22.0,<23.0",
|
||||||
|
]
|
||||||
|
feishu = [
|
||||||
|
"lark-oapi>=1.4.22,<2.0.0",
|
||||||
|
]
|
||||||
|
qqbot = [
|
||||||
|
"aiohttp>=3.9.0,<4.0.0",
|
||||||
|
]
|
||||||
|
weixin = [
|
||||||
|
"aiohttp>=3.9.0,<4.0.0",
|
||||||
|
]
|
||||||
|
channels = [
|
||||||
|
"python-telegram-bot>=22.0,<23.0",
|
||||||
|
"lark-oapi>=1.4.22,<2.0.0",
|
||||||
|
"aiohttp>=3.9.0,<4.0.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
|
|||||||
47
app-instance/backend/tests/unit/test_agent_loop.py
Normal file
47
app-instance/backend/tests/unit/test_agent_loop.py
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
import asyncio
|
||||||
|
from contextlib import suppress
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from beaver.engine import AgentLoop, AgentRunResult, EngineLoader
|
||||||
|
|
||||||
|
|
||||||
|
def _run_result(run_id: str, output_text: str) -> AgentRunResult:
|
||||||
|
return AgentRunResult(
|
||||||
|
session_id="web:test",
|
||||||
|
run_id=run_id,
|
||||||
|
output_text=output_text,
|
||||||
|
finish_reason="stop",
|
||||||
|
tool_iterations=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_running_loop_handles_reentrant_submit_direct(tmp_path) -> None:
|
||||||
|
async def run_case() -> None:
|
||||||
|
loop = AgentLoop(loader=EngineLoader(workspace=tmp_path))
|
||||||
|
calls: list[str] = []
|
||||||
|
|
||||||
|
async def fake_process_direct(task: str, **kwargs: Any) -> AgentRunResult:
|
||||||
|
calls.append(task)
|
||||||
|
if task == "outer":
|
||||||
|
return await loop.submit_direct("inner", session_id="web:test")
|
||||||
|
return _run_result(task, "inner completed")
|
||||||
|
|
||||||
|
loop._process_direct_impl = fake_process_direct # type: ignore[method-assign]
|
||||||
|
|
||||||
|
loop_task = asyncio.create_task(loop.run())
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
try:
|
||||||
|
result = await asyncio.wait_for(loop.submit_direct("outer", session_id="web:test"), timeout=1)
|
||||||
|
finally:
|
||||||
|
await loop.stop()
|
||||||
|
with suppress(asyncio.TimeoutError):
|
||||||
|
await asyncio.wait_for(loop_task, timeout=1)
|
||||||
|
if not loop_task.done():
|
||||||
|
loop_task.cancel()
|
||||||
|
with suppress(asyncio.CancelledError):
|
||||||
|
await loop_task
|
||||||
|
|
||||||
|
assert result.output_text == "inner completed"
|
||||||
|
assert calls == ["outer", "inner"]
|
||||||
|
|
||||||
|
asyncio.run(run_case())
|
||||||
@ -0,0 +1,71 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from beaver.engine.loader import EngineLoader
|
||||||
|
from beaver.engine.loop import AgentLoop
|
||||||
|
from beaver.engine.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||||
|
from beaver.engine.providers.factory import ProviderBundle
|
||||||
|
from beaver.skills.learning.replay import ReplayToolExecutor, ReplayToolPolicy
|
||||||
|
|
||||||
|
|
||||||
|
class ToolCallingProvider(LLMProvider):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.calls = 0
|
||||||
|
|
||||||
|
async def chat(
|
||||||
|
self,
|
||||||
|
messages: list[dict],
|
||||||
|
tools: list[dict] | None = None,
|
||||||
|
model: str | None = None,
|
||||||
|
max_tokens: int | None = None,
|
||||||
|
temperature: float = 0.7,
|
||||||
|
thinking_enabled: bool | None = None,
|
||||||
|
) -> LLMResponse:
|
||||||
|
self.calls += 1
|
||||||
|
if self.calls == 1:
|
||||||
|
return LLMResponse(
|
||||||
|
content="",
|
||||||
|
tool_calls=[
|
||||||
|
ToolCallRequest(
|
||||||
|
id="call-1",
|
||||||
|
name="read_file",
|
||||||
|
arguments={"path": "README.md"},
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
return LLMResponse(content="done")
|
||||||
|
|
||||||
|
def get_default_model(self) -> str:
|
||||||
|
return "stub"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_process_direct_uses_replay_tool_executor(tmp_path: Path) -> None:
|
||||||
|
loop = AgentLoop(loader=EngineLoader(workspace=tmp_path))
|
||||||
|
loaded = loop.boot()
|
||||||
|
provider = ToolCallingProvider()
|
||||||
|
runtime = SimpleNamespace(model="stub", provider_name="stub")
|
||||||
|
replay_executor = ReplayToolExecutor(
|
||||||
|
loaded.tool_executor,
|
||||||
|
registry=loaded.tool_registry,
|
||||||
|
policy=ReplayToolPolicy(),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await loop.process_direct(
|
||||||
|
"Read the README.",
|
||||||
|
provider_bundle=ProviderBundle(main_runtime=runtime, main_provider=provider), # type: ignore[arg-type]
|
||||||
|
include_skill_assembly=False,
|
||||||
|
pinned_skill_names=[],
|
||||||
|
tool_executor_override=replay_executor,
|
||||||
|
max_tool_iterations=2,
|
||||||
|
source="skill_replay_eval",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.output_text == "done"
|
||||||
|
assert replay_executor.traces
|
||||||
|
assert replay_executor.traces[0]["tool_name"] == "read_file"
|
||||||
@ -1,5 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
|
||||||
from beaver.coordinator.models import AgentDescriptor, ExecutionGraph, ExecutionNode
|
from beaver.coordinator.models import AgentDescriptor, ExecutionGraph, ExecutionNode
|
||||||
from beaver.coordinator.registry import AgentRegistry, RegisteredAgent, TargetResolver
|
from beaver.coordinator.registry import AgentRegistry, RegisteredAgent, TargetResolver
|
||||||
from beaver.tasks import TaskRecord
|
from beaver.tasks import TaskRecord
|
||||||
@ -20,22 +22,64 @@ def _task() -> TaskRecord:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_registry_seeds_builtin_agents_and_filters_disabled(tmp_path) -> None:
|
def test_registry_starts_empty_and_filters_disabled(tmp_path) -> None:
|
||||||
registry = AgentRegistry(tmp_path)
|
registry = AgentRegistry(tmp_path)
|
||||||
|
|
||||||
assert {agent.agent_id for agent in registry.list_active_agents()} >= {
|
assert registry.list_agents() == []
|
||||||
"researcher",
|
|
||||||
"implementer",
|
|
||||||
"reviewer",
|
|
||||||
"tester",
|
|
||||||
"documenter",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
registry.upsert_agent(
|
||||||
|
RegisteredAgent(
|
||||||
|
agent_id="tester",
|
||||||
|
name="tester",
|
||||||
|
display_name="Tester",
|
||||||
|
role="testing",
|
||||||
|
description="Runs checks.",
|
||||||
|
system_prompt="test",
|
||||||
|
)
|
||||||
|
)
|
||||||
registry.disable_agent("tester")
|
registry.disable_agent("tester")
|
||||||
|
|
||||||
assert "tester" not in {agent.agent_id for agent in registry.list_active_agents()}
|
assert "tester" not in {agent.agent_id for agent in registry.list_active_agents()}
|
||||||
|
|
||||||
|
|
||||||
|
def test_registry_drops_legacy_builtin_agents(tmp_path) -> None:
|
||||||
|
registry_path = tmp_path / "agents" / "registry.json"
|
||||||
|
registry_path.parent.mkdir(parents=True)
|
||||||
|
registry_path.write_text(
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"version": 1,
|
||||||
|
"agents": [
|
||||||
|
{
|
||||||
|
"agent_id": "researcher",
|
||||||
|
"name": "researcher",
|
||||||
|
"display_name": "Researcher",
|
||||||
|
"role": "research",
|
||||||
|
"description": "legacy builtin",
|
||||||
|
"system_prompt": "research",
|
||||||
|
"source": "builtin",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"agent_id": "workspace-agent",
|
||||||
|
"name": "workspace-agent",
|
||||||
|
"display_name": "Workspace Agent",
|
||||||
|
"role": "workspace",
|
||||||
|
"description": "user configured",
|
||||||
|
"system_prompt": "work",
|
||||||
|
"source": "workspace",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
+ "\n",
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
|
||||||
|
registry = AgentRegistry(tmp_path)
|
||||||
|
|
||||||
|
assert [agent.agent_id for agent in registry.list_agents()] == ["workspace-agent"]
|
||||||
|
|
||||||
|
|
||||||
def test_resolver_selects_registered_agent_by_role_and_capabilities(tmp_path) -> None:
|
def test_resolver_selects_registered_agent_by_role_and_capabilities(tmp_path) -> None:
|
||||||
registry = AgentRegistry(tmp_path)
|
registry = AgentRegistry(tmp_path)
|
||||||
registry.upsert_agent(
|
registry.upsert_agent(
|
||||||
@ -88,4 +132,3 @@ def test_resolver_falls_back_to_ephemeral_agent_when_no_match(tmp_path) -> None:
|
|||||||
assert resolved.nodes[0].agent.name == "rare"
|
assert resolved.nodes[0].agent.name == "rare"
|
||||||
assert resolved.nodes[0].agent.metadata["resolution"] == "fallback_ephemeral"
|
assert resolved.nodes[0].agent.metadata["resolution"] == "fallback_ephemeral"
|
||||||
assert reports[0].fallback_used is True
|
assert reports[0].fallback_used is True
|
||||||
|
|
||||||
|
|||||||
@ -45,6 +45,18 @@ class RecordingProvider(LLMProvider):
|
|||||||
return "stub-model"
|
return "stub-model"
|
||||||
|
|
||||||
|
|
||||||
|
class BlockingProvider(RecordingProvider):
|
||||||
|
def __init__(self, content: str, started: asyncio.Event, release: asyncio.Event) -> None:
|
||||||
|
super().__init__([_response(content)])
|
||||||
|
self.started = started
|
||||||
|
self.release = release
|
||||||
|
|
||||||
|
async def chat(self, *args, **kwargs) -> LLMResponse:
|
||||||
|
self.started.set()
|
||||||
|
await self.release.wait()
|
||||||
|
return await super().chat(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class StubSkillAssembler:
|
class StubSkillAssembler:
|
||||||
def __init__(self, activated_skills: list[SkillContext] | None = None) -> None:
|
def __init__(self, activated_skills: list[SkillContext] | None = None) -> None:
|
||||||
self.activated_skills = list(activated_skills or [])
|
self.activated_skills = list(activated_skills or [])
|
||||||
@ -153,6 +165,26 @@ def test_local_agent_runner_uses_shared_loop_and_records_parent_task(tmp_path: P
|
|||||||
assert child_session["parent_session_id"] == "session-root"
|
assert child_session["parent_session_id"] == "session-root"
|
||||||
|
|
||||||
|
|
||||||
|
def test_team_node_preserves_evidence_when_finish_reason_is_not_stop(tmp_path: Path) -> None:
|
||||||
|
loop = _loop(tmp_path)
|
||||||
|
provider = RecordingProvider([_response("partial evidence", finish_reason="max_tool_iterations")])
|
||||||
|
envelope = DelegationEnvelope(
|
||||||
|
parent_task_id="task-parent",
|
||||||
|
parent_session_id="session-root",
|
||||||
|
parent_run_id="run-root",
|
||||||
|
agent=AgentDescriptor(name="researcher", role="research"),
|
||||||
|
task="research the requested topic",
|
||||||
|
node_id="research",
|
||||||
|
)
|
||||||
|
|
||||||
|
result = asyncio.run(LocalAgentRunner(loop).run(envelope, provider_bundle=_bundle(provider)))
|
||||||
|
|
||||||
|
assert result.success is False
|
||||||
|
assert result.evidence is not None
|
||||||
|
assert result.evidence.output_text == "partial evidence"
|
||||||
|
assert result.evidence.finish_reason == "max_tool_iterations"
|
||||||
|
|
||||||
|
|
||||||
def test_pinned_skill_is_injected_into_delegated_run(tmp_path: Path) -> None:
|
def test_pinned_skill_is_injected_into_delegated_run(tmp_path: Path) -> None:
|
||||||
_publish_skill(
|
_publish_skill(
|
||||||
tmp_path,
|
tmp_path,
|
||||||
@ -278,6 +310,57 @@ def test_team_parallel_runs_all_nodes(tmp_path: Path) -> None:
|
|||||||
assert [item.output_text for item in result.node_results] == ["one", "two", "three"]
|
assert [item.output_text for item in result.node_results] == ["one", "two", "three"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_team_parallel_starts_nodes_concurrently_with_isolated_loops(tmp_path: Path) -> None:
|
||||||
|
loop = _loop(tmp_path)
|
||||||
|
first_started = asyncio.Event()
|
||||||
|
second_started = asyncio.Event()
|
||||||
|
release = asyncio.Event()
|
||||||
|
providers = {
|
||||||
|
"one": BlockingProvider("one", first_started, release),
|
||||||
|
"two": BlockingProvider("two", second_started, release),
|
||||||
|
}
|
||||||
|
graph = ExecutionGraph(
|
||||||
|
strategy="parallel",
|
||||||
|
nodes=[
|
||||||
|
ExecutionNode("one", "task one", AgentDescriptor(name="one")),
|
||||||
|
ExecutionNode("two", "task two", AgentDescriptor(name="two")),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
async def run_case():
|
||||||
|
loop_task = asyncio.create_task(loop.run())
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
task = asyncio.create_task(
|
||||||
|
TeamService(loop).run_team(
|
||||||
|
graph,
|
||||||
|
parent_task_id=None,
|
||||||
|
parent_session_id="session-root",
|
||||||
|
parent_run_id="run-root",
|
||||||
|
provider_bundle_factory=lambda node: _bundle(providers[node.node_id]),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(first_started.wait(), timeout=1)
|
||||||
|
await asyncio.wait_for(second_started.wait(), timeout=1)
|
||||||
|
release.set()
|
||||||
|
return await task
|
||||||
|
finally:
|
||||||
|
release.set()
|
||||||
|
if not task.done():
|
||||||
|
task.cancel()
|
||||||
|
try:
|
||||||
|
await task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
await loop.stop()
|
||||||
|
await loop_task
|
||||||
|
|
||||||
|
result = asyncio.run(run_case())
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
assert [item.node_id for item in result.node_results] == ["one", "two"]
|
||||||
|
|
||||||
|
|
||||||
def test_parallel_node_factory_error_is_normalized_and_keeps_completed_runs(tmp_path: Path) -> None:
|
def test_parallel_node_factory_error_is_normalized_and_keeps_completed_runs(tmp_path: Path) -> None:
|
||||||
loop = _loop(tmp_path)
|
loop = _loop(tmp_path)
|
||||||
loaded = loop.boot()
|
loaded = loop.boot()
|
||||||
@ -438,7 +521,7 @@ def test_team_summary_lists_only_failed_nodes_when_all_nodes_fail(tmp_path: Path
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert result.success is False
|
assert result.success is False
|
||||||
assert result.summary == "Failed nodes:\n- one: one down\n- two: two down"
|
assert result.summary == "Failed nodes:\n- one: one down evidence=no\n- two: two down evidence=no"
|
||||||
|
|
||||||
|
|
||||||
def test_graph_structure_errors_still_raise(tmp_path: Path) -> None:
|
def test_graph_structure_errors_still_raise(tmp_path: Path) -> None:
|
||||||
|
|||||||
@ -0,0 +1,84 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
from beaver.interfaces.web.app import create_app
|
||||||
|
from beaver.services.agent_service import AgentService
|
||||||
|
|
||||||
|
|
||||||
|
def test_channel_connection_api_creates_updates_lists_and_revokes(tmp_path) -> None:
|
||||||
|
config_path = tmp_path / "config.json"
|
||||||
|
config_path.write_text(
|
||||||
|
'{"agents": {"defaults": {"workspace": "%s"}}, "providers": {}}' % str(tmp_path),
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
service = AgentService(config_path=config_path)
|
||||||
|
app = create_app(service=service, manage_service_lifecycle=False)
|
||||||
|
|
||||||
|
try:
|
||||||
|
with TestClient(app) as client:
|
||||||
|
created = client.post(
|
||||||
|
"/api/channel-connections",
|
||||||
|
json={
|
||||||
|
"kind": "telegram",
|
||||||
|
"mode": "polling",
|
||||||
|
"displayName": "Telegram Main",
|
||||||
|
"authType": "token",
|
||||||
|
"secrets": {"botToken": "token-1"},
|
||||||
|
"config": {"maxMessageChars": 4096, "requireMentionInGroups": True},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert created.status_code == 200
|
||||||
|
body = created.json()
|
||||||
|
connection_id = body["connection"]["connection_id"]
|
||||||
|
assert body["connection"]["kind"] == "telegram"
|
||||||
|
assert body["connection"]["status"] == "draft"
|
||||||
|
assert "credentials_ref" not in body["connection"]
|
||||||
|
assert body["connection"]["runtime_config"] == {
|
||||||
|
"max_message_chars": 4096,
|
||||||
|
"require_mention_in_groups": True,
|
||||||
|
}
|
||||||
|
assert body["credentials"] == {"botToken": "***"}
|
||||||
|
|
||||||
|
patched = client.patch(
|
||||||
|
f"/api/channel-connections/{connection_id}",
|
||||||
|
json={
|
||||||
|
"displayName": "Telegram Ops",
|
||||||
|
"config": {"maxMessageChars": 2048},
|
||||||
|
"secrets": {"botToken": "token-2"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert patched.status_code == 200
|
||||||
|
assert patched.json()["connection"]["display_name"] == "Telegram Ops"
|
||||||
|
assert patched.json()["connection"]["runtime_config"] == {"max_message_chars": 2048}
|
||||||
|
assert patched.json()["credentials"] == {"botToken": "***"}
|
||||||
|
|
||||||
|
listed = client.get("/api/channel-connections")
|
||||||
|
assert listed.status_code == 200
|
||||||
|
assert listed.json()[0]["connection_id"] == connection_id
|
||||||
|
assert "credentials_ref" not in listed.json()[0]
|
||||||
|
|
||||||
|
revoked = client.post(f"/api/channel-connections/{connection_id}/revoke")
|
||||||
|
assert revoked.status_code == 200
|
||||||
|
assert revoked.json()["connection"]["status"] == "revoked"
|
||||||
|
finally:
|
||||||
|
service.close()
|
||||||
|
|
||||||
|
|
||||||
|
def test_channel_connectors_api_lists_registered_connectors(tmp_path) -> None:
|
||||||
|
config_path = tmp_path / "config.json"
|
||||||
|
config_path.write_text(
|
||||||
|
'{"agents": {"defaults": {"workspace": "%s"}}, "providers": {}}' % str(tmp_path),
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
service = AgentService(config_path=config_path)
|
||||||
|
app = create_app(service=service, manage_service_lifecycle=False)
|
||||||
|
|
||||||
|
try:
|
||||||
|
with TestClient(app) as client:
|
||||||
|
response = client.get("/api/channel-connectors")
|
||||||
|
finally:
|
||||||
|
service.close()
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == [{"kind": "feishu"}, {"kind": "telegram"}, {"kind": "weixin"}]
|
||||||
@ -0,0 +1,63 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from beaver.interfaces.channels.connections import (
|
||||||
|
ChannelConnectionStore,
|
||||||
|
CredentialStore,
|
||||||
|
PairingTokenStore,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_channel_connection_store_creates_updates_lists_and_revokes(tmp_path) -> None:
|
||||||
|
store = ChannelConnectionStore(tmp_path / "connections.json")
|
||||||
|
|
||||||
|
created = store.create(
|
||||||
|
kind="telegram",
|
||||||
|
mode="polling",
|
||||||
|
display_name="Telegram Main",
|
||||||
|
account_id="telegram:bot-main",
|
||||||
|
owner_user_id="user-1",
|
||||||
|
auth_type="token",
|
||||||
|
runtime_config={"max_message_chars": 4096},
|
||||||
|
capabilities=["receive_text", "send_text"],
|
||||||
|
)
|
||||||
|
updated = store.update_status(created.connection_id, status="connected", last_error=None)
|
||||||
|
revoked = store.revoke(created.connection_id)
|
||||||
|
|
||||||
|
assert created.connection_id
|
||||||
|
assert created.channel_id.startswith("telegram-")
|
||||||
|
assert created.status == "draft"
|
||||||
|
assert updated.status == "connected"
|
||||||
|
assert revoked.status == "revoked"
|
||||||
|
assert store.get(created.connection_id).status == "revoked"
|
||||||
|
assert [item.connection_id for item in store.list()] == [created.connection_id]
|
||||||
|
|
||||||
|
|
||||||
|
def test_credential_store_saves_values_by_reference_and_redacts_views(tmp_path) -> None:
|
||||||
|
store = CredentialStore(tmp_path / "credentials.json")
|
||||||
|
|
||||||
|
ref = store.put(kind="telegram", values={"botToken": "secret-token", "empty": ""})
|
||||||
|
|
||||||
|
assert ref.startswith("cred_")
|
||||||
|
assert store.get(ref) == {"botToken": "secret-token"}
|
||||||
|
assert store.redacted(ref) == {"botToken": "***"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_pairing_token_store_uses_one_time_expiring_tokens(tmp_path) -> None:
|
||||||
|
store = PairingTokenStore(tmp_path / "pairing.json")
|
||||||
|
|
||||||
|
session = store.create(kind="terminal", ttl_seconds=60, scope="channel:pair")
|
||||||
|
consumed = store.consume(session.token, expected_kind="terminal")
|
||||||
|
reused = store.consume(session.token, expected_kind="terminal")
|
||||||
|
|
||||||
|
assert session.status == "pending"
|
||||||
|
assert consumed is not None
|
||||||
|
assert consumed.status == "consumed"
|
||||||
|
assert reused is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_pairing_token_store_rejects_expired_tokens(tmp_path) -> None:
|
||||||
|
store = PairingTokenStore(tmp_path / "pairing.json")
|
||||||
|
|
||||||
|
session = store.create(kind="weixin", ttl_seconds=-1, scope="channel:pair")
|
||||||
|
|
||||||
|
assert store.consume(session.token, expected_kind="weixin") is None
|
||||||
@ -0,0 +1,164 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from beaver.foundation.config.schema import ChannelConfig
|
||||||
|
from beaver.interfaces.channels.connections import (
|
||||||
|
ChannelConnectionStore,
|
||||||
|
ChannelConnectorRegistry,
|
||||||
|
ChannelRuntimeSpec,
|
||||||
|
CredentialStore,
|
||||||
|
ValidationResult,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FakeConnector:
|
||||||
|
kind = "fake"
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.validated: list[str] = []
|
||||||
|
self.revoked: list[str] = []
|
||||||
|
|
||||||
|
async def validate(self, connection_id: str) -> ValidationResult:
|
||||||
|
self.validated.append(connection_id)
|
||||||
|
return ValidationResult(ok=True, status="connected", account_id="fake-account", display_name="Fake")
|
||||||
|
|
||||||
|
async def materialize_runtime(self, connection_id: str) -> ChannelRuntimeSpec:
|
||||||
|
return ChannelRuntimeSpec(
|
||||||
|
channel_id="fake-channel",
|
||||||
|
kind="fake",
|
||||||
|
mode="webhook",
|
||||||
|
account_id="fake-account",
|
||||||
|
display_name="Fake",
|
||||||
|
config={"enabled": True},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def revoke(self, connection_id: str) -> None:
|
||||||
|
self.revoked.append(connection_id)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def test_connector_registry_dispatches_by_kind(tmp_path) -> None:
|
||||||
|
async def run() -> None:
|
||||||
|
connection_store = ChannelConnectionStore(tmp_path / "connections.json")
|
||||||
|
credential_store = CredentialStore(tmp_path / "credentials.json")
|
||||||
|
connector = FakeConnector()
|
||||||
|
registry = ChannelConnectorRegistry(connection_store=connection_store, credential_store=credential_store)
|
||||||
|
registry.register(connector)
|
||||||
|
|
||||||
|
connection = connection_store.create(
|
||||||
|
kind="fake",
|
||||||
|
mode="webhook",
|
||||||
|
display_name="Fake",
|
||||||
|
account_id="fake-account",
|
||||||
|
owner_user_id=None,
|
||||||
|
auth_type="token",
|
||||||
|
)
|
||||||
|
result = await registry.validate(connection.connection_id)
|
||||||
|
spec = await registry.materialize_runtime(connection.connection_id)
|
||||||
|
|
||||||
|
assert result.ok is True
|
||||||
|
assert connector.validated == [connection.connection_id]
|
||||||
|
assert spec.channel_id == "fake-channel"
|
||||||
|
|
||||||
|
asyncio.run(run())
|
||||||
|
|
||||||
|
|
||||||
|
def test_connector_registry_materializes_channel_configs_with_credentials(tmp_path) -> None:
|
||||||
|
async def run() -> None:
|
||||||
|
connection_store = ChannelConnectionStore(tmp_path / "connections.json")
|
||||||
|
credential_store = CredentialStore(tmp_path / "credentials.json")
|
||||||
|
credentials_ref = credential_store.put(kind="telegram", values={"botToken": "token-1"})
|
||||||
|
connection = connection_store.create(
|
||||||
|
kind="fake",
|
||||||
|
mode="webhook",
|
||||||
|
display_name="Connected",
|
||||||
|
account_id="connected",
|
||||||
|
owner_user_id=None,
|
||||||
|
auth_type="token",
|
||||||
|
credentials_ref=credentials_ref,
|
||||||
|
)
|
||||||
|
connection_store.update_status(connection.connection_id, status="connected", last_error=None)
|
||||||
|
|
||||||
|
class CredentialAwareConnector(FakeConnector):
|
||||||
|
async def materialize_runtime(self, connection_id: str) -> ChannelRuntimeSpec:
|
||||||
|
stored = connection_store.get(connection_id)
|
||||||
|
return ChannelRuntimeSpec(
|
||||||
|
channel_id="fake-channel",
|
||||||
|
kind="fake",
|
||||||
|
mode="webhook",
|
||||||
|
account_id="fake-account",
|
||||||
|
display_name="Fake",
|
||||||
|
config={"enabled": True},
|
||||||
|
secrets_ref=stored.credentials_ref,
|
||||||
|
)
|
||||||
|
|
||||||
|
registry = ChannelConnectorRegistry(connection_store=connection_store, credential_store=credential_store)
|
||||||
|
registry.register(CredentialAwareConnector())
|
||||||
|
|
||||||
|
configs = await registry.materialize_channel_configs()
|
||||||
|
|
||||||
|
assert isinstance(configs["fake-channel"], ChannelConfig)
|
||||||
|
assert configs["fake-channel"].enabled is True
|
||||||
|
assert configs["fake-channel"].secrets == {"botToken": "token-1"}
|
||||||
|
|
||||||
|
asyncio.run(run())
|
||||||
|
|
||||||
|
|
||||||
|
def test_connector_registry_materializes_only_connected_connections(tmp_path) -> None:
|
||||||
|
async def run() -> None:
|
||||||
|
connection_store = ChannelConnectionStore(tmp_path / "connections.json")
|
||||||
|
credential_store = CredentialStore(tmp_path / "credentials.json")
|
||||||
|
registry = ChannelConnectorRegistry(connection_store=connection_store, credential_store=credential_store)
|
||||||
|
registry.register(FakeConnector())
|
||||||
|
|
||||||
|
draft = connection_store.create(
|
||||||
|
kind="fake",
|
||||||
|
mode="webhook",
|
||||||
|
display_name="Draft",
|
||||||
|
account_id="draft",
|
||||||
|
owner_user_id=None,
|
||||||
|
auth_type="token",
|
||||||
|
)
|
||||||
|
connected = connection_store.create(
|
||||||
|
kind="fake",
|
||||||
|
mode="webhook",
|
||||||
|
display_name="Connected",
|
||||||
|
account_id="connected",
|
||||||
|
owner_user_id=None,
|
||||||
|
auth_type="token",
|
||||||
|
)
|
||||||
|
connection_store.update_status(connected.connection_id, status="connected", last_error=None)
|
||||||
|
|
||||||
|
specs = await registry.materialize_connected_runtime_specs()
|
||||||
|
|
||||||
|
assert [spec.channel_id for spec in specs] == ["fake-channel"]
|
||||||
|
assert connection_store.get(draft.connection_id).status == "draft"
|
||||||
|
|
||||||
|
asyncio.run(run())
|
||||||
|
|
||||||
|
|
||||||
|
def test_connector_registry_revoke_calls_connector_and_updates_store(tmp_path) -> None:
|
||||||
|
async def run() -> None:
|
||||||
|
connection_store = ChannelConnectionStore(tmp_path / "connections.json")
|
||||||
|
credential_store = CredentialStore(tmp_path / "credentials.json")
|
||||||
|
connector = FakeConnector()
|
||||||
|
registry = ChannelConnectorRegistry(connection_store=connection_store, credential_store=credential_store)
|
||||||
|
registry.register(connector)
|
||||||
|
|
||||||
|
connection = connection_store.create(
|
||||||
|
kind="fake",
|
||||||
|
mode="webhook",
|
||||||
|
display_name="Fake",
|
||||||
|
account_id="fake-account",
|
||||||
|
owner_user_id=None,
|
||||||
|
auth_type="token",
|
||||||
|
)
|
||||||
|
connection_store.update_status(connection.connection_id, status="connected", last_error=None)
|
||||||
|
|
||||||
|
await registry.revoke(connection.connection_id)
|
||||||
|
|
||||||
|
assert connector.revoked == [connection.connection_id]
|
||||||
|
assert connection_store.get(connection.connection_id).status == "revoked"
|
||||||
|
|
||||||
|
asyncio.run(run())
|
||||||
414
app-instance/backend/tests/unit/test_channel_runtime.py
Normal file
414
app-instance/backend/tests/unit/test_channel_runtime.py
Normal file
@ -0,0 +1,414 @@
|
|||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
from beaver.foundation.config.schema import ChannelConfig
|
||||||
|
from beaver.foundation.events import ChannelIdentity, InboundMessage, OutboundMessage
|
||||||
|
from beaver.foundation.events import MessageBus
|
||||||
|
from beaver.interfaces.channels.generic_webhook import GenericWebhookAdapter
|
||||||
|
from beaver.interfaces.channels.runtime import ChannelRuntime
|
||||||
|
from beaver.interfaces.channels.state import ChannelDedupeStore, ChannelEventLog
|
||||||
|
from beaver.interfaces.web.app import _self_restart_enabled, create_app
|
||||||
|
from beaver.services.agent_service import AgentService
|
||||||
|
|
||||||
|
|
||||||
|
def test_channel_identity_builds_stable_session_id() -> None:
|
||||||
|
identity = ChannelIdentity(
|
||||||
|
channel_id="webhook-dev",
|
||||||
|
kind="webhook",
|
||||||
|
account_id="local",
|
||||||
|
peer_id="demo-user",
|
||||||
|
thread_id="main",
|
||||||
|
peer_type="dm",
|
||||||
|
message_id="msg-1",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert identity.session_id() == "webhook-dev:local:demo-user:main"
|
||||||
|
assert identity.dedupe_key() == "webhook-dev:local:demo-user:main:msg-1"
|
||||||
|
|
||||||
|
|
||||||
|
def test_channel_identity_requires_routing_fields() -> None:
|
||||||
|
identity = ChannelIdentity(channel_id="webhook-dev", kind="webhook", account_id="", peer_id="demo")
|
||||||
|
|
||||||
|
assert identity.validation_error() == "account_id is required"
|
||||||
|
|
||||||
|
|
||||||
|
def test_messages_carry_channel_identity() -> None:
|
||||||
|
identity = ChannelIdentity(
|
||||||
|
channel_id="webhook-dev",
|
||||||
|
kind="webhook",
|
||||||
|
account_id="local",
|
||||||
|
peer_id="demo-user",
|
||||||
|
message_id="msg-1",
|
||||||
|
)
|
||||||
|
|
||||||
|
inbound = InboundMessage(channel="webhook-dev", content="hello", channel_identity=identity)
|
||||||
|
outbound = OutboundMessage(
|
||||||
|
channel="webhook-dev",
|
||||||
|
content="ok",
|
||||||
|
session_id=identity.session_id(),
|
||||||
|
finish_reason="stop",
|
||||||
|
channel_identity=identity,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert inbound.channel_identity is identity
|
||||||
|
assert outbound.channel_identity is identity
|
||||||
|
|
||||||
|
|
||||||
|
def test_dedupe_store_tracks_processing_and_done(tmp_path) -> None:
|
||||||
|
store = ChannelDedupeStore(tmp_path / "dedupe.json", retention_hours=48)
|
||||||
|
|
||||||
|
created = store.mark_processing(
|
||||||
|
dedupe_key="webhook-dev:local:demo:msg-1",
|
||||||
|
session_id="webhook-dev:local:demo",
|
||||||
|
message_id="msg-1",
|
||||||
|
)
|
||||||
|
duplicate = store.mark_processing(
|
||||||
|
dedupe_key="webhook-dev:local:demo:msg-1",
|
||||||
|
session_id="webhook-dev:local:demo",
|
||||||
|
message_id="msg-1",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert created.created is True
|
||||||
|
assert duplicate.created is False
|
||||||
|
assert duplicate.record is not None
|
||||||
|
assert duplicate.record["status"] == "processing"
|
||||||
|
|
||||||
|
store.mark_done(
|
||||||
|
dedupe_key="webhook-dev:local:demo:msg-1",
|
||||||
|
run_id="run-1",
|
||||||
|
reply="hello" * 10000,
|
||||||
|
max_reply_chars=20,
|
||||||
|
)
|
||||||
|
|
||||||
|
done = store.get("webhook-dev:local:demo:msg-1")
|
||||||
|
assert done is not None
|
||||||
|
assert done["status"] == "done"
|
||||||
|
assert done["reply"] == "hellohellohellohello"
|
||||||
|
|
||||||
|
|
||||||
|
def test_channel_event_log_writes_recent_events(tmp_path) -> None:
|
||||||
|
log = ChannelEventLog(tmp_path / "events.jsonl")
|
||||||
|
log.record(
|
||||||
|
channel_id="webhook-dev",
|
||||||
|
kind="inbound_accepted",
|
||||||
|
session_id="webhook-dev:local:demo",
|
||||||
|
message_id="msg-1",
|
||||||
|
status="ok",
|
||||||
|
text="hello world",
|
||||||
|
)
|
||||||
|
|
||||||
|
events = log.recent(channel_id="webhook-dev", limit=10)
|
||||||
|
|
||||||
|
assert len(events) == 1
|
||||||
|
assert events[0]["kind"] == "inbound_accepted"
|
||||||
|
assert events[0]["text_preview"] == "hello world"
|
||||||
|
assert "raw_channel_payload" not in json.dumps(events[0])
|
||||||
|
|
||||||
|
|
||||||
|
class FakeAgentService:
|
||||||
|
is_running = True
|
||||||
|
|
||||||
|
async def handle_inbound_message(self, inbound):
|
||||||
|
return OutboundMessage(
|
||||||
|
message_id=inbound.message_id,
|
||||||
|
channel=inbound.channel,
|
||||||
|
content=f"echo:{inbound.content}",
|
||||||
|
session_id=inbound.session_id,
|
||||||
|
finish_reason="stop",
|
||||||
|
run_id="run-1",
|
||||||
|
channel_identity=inbound.channel_identity,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SlowFakeAgentService(FakeAgentService):
|
||||||
|
async def handle_inbound_message(self, inbound):
|
||||||
|
await asyncio.sleep(0.05)
|
||||||
|
return await super().handle_inbound_message(inbound)
|
||||||
|
|
||||||
|
|
||||||
|
def test_channel_runtime_accept_inbound_normalizes_session_and_dedupes(tmp_path) -> None:
|
||||||
|
async def run() -> None:
|
||||||
|
bus = MessageBus()
|
||||||
|
runtime = ChannelRuntime(
|
||||||
|
service=FakeAgentService(),
|
||||||
|
bus=bus,
|
||||||
|
workspace=tmp_path,
|
||||||
|
channels={},
|
||||||
|
)
|
||||||
|
identity = ChannelIdentity(
|
||||||
|
channel_id="webhook-dev",
|
||||||
|
kind="webhook",
|
||||||
|
account_id="local",
|
||||||
|
peer_id="demo",
|
||||||
|
message_id="msg-1",
|
||||||
|
)
|
||||||
|
result = await runtime.accept_inbound(
|
||||||
|
InboundMessage(
|
||||||
|
channel="webhook-dev",
|
||||||
|
content="hello",
|
||||||
|
session_id="wrong",
|
||||||
|
channel_identity=identity,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
duplicate = await runtime.accept_inbound(
|
||||||
|
InboundMessage(
|
||||||
|
channel="webhook-dev",
|
||||||
|
content="hello",
|
||||||
|
channel_identity=identity,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
queued = await bus.consume_inbound()
|
||||||
|
assert result.accepted is True
|
||||||
|
assert queued.session_id == "webhook-dev:local:demo"
|
||||||
|
assert duplicate.accepted is False
|
||||||
|
assert duplicate.duplicate is True
|
||||||
|
|
||||||
|
asyncio.run(run())
|
||||||
|
|
||||||
|
|
||||||
|
def test_generic_webhook_adapter_waits_for_outbound_reply(tmp_path) -> None:
|
||||||
|
async def run() -> None:
|
||||||
|
bus = MessageBus()
|
||||||
|
runtime = ChannelRuntime(
|
||||||
|
service=FakeAgentService(),
|
||||||
|
bus=bus,
|
||||||
|
workspace=tmp_path,
|
||||||
|
channels={},
|
||||||
|
)
|
||||||
|
adapter = GenericWebhookAdapter(
|
||||||
|
channel_id="webhook-dev",
|
||||||
|
kind="webhook",
|
||||||
|
mode="webhook",
|
||||||
|
account_id="local",
|
||||||
|
display_name="Webhook Dev",
|
||||||
|
inbound_sink=runtime,
|
||||||
|
response_timeout_seconds=1,
|
||||||
|
)
|
||||||
|
runtime.manager.register(adapter)
|
||||||
|
await runtime.start()
|
||||||
|
try:
|
||||||
|
response = await adapter.handle_webhook_payload(
|
||||||
|
{
|
||||||
|
"peer_id": "demo",
|
||||||
|
"message_id": "msg-1",
|
||||||
|
"text": "hello",
|
||||||
|
"peer_type": "dm",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
await runtime.stop()
|
||||||
|
|
||||||
|
assert response["ok"] is True
|
||||||
|
assert response["reply"] == "echo:hello"
|
||||||
|
assert response["session_id"] == "webhook-dev:local:demo"
|
||||||
|
|
||||||
|
asyncio.run(run())
|
||||||
|
|
||||||
|
|
||||||
|
def test_generic_webhook_records_unclaimed_outbound_after_timeout(tmp_path) -> None:
|
||||||
|
async def run() -> None:
|
||||||
|
bus = MessageBus()
|
||||||
|
runtime = ChannelRuntime(
|
||||||
|
service=SlowFakeAgentService(),
|
||||||
|
bus=bus,
|
||||||
|
workspace=tmp_path,
|
||||||
|
channels={},
|
||||||
|
)
|
||||||
|
adapter = GenericWebhookAdapter(
|
||||||
|
channel_id="webhook-dev",
|
||||||
|
kind="webhook",
|
||||||
|
mode="webhook",
|
||||||
|
account_id="local",
|
||||||
|
display_name="Webhook Dev",
|
||||||
|
inbound_sink=runtime,
|
||||||
|
response_timeout_seconds=1,
|
||||||
|
)
|
||||||
|
adapter.response_timeout_seconds = 0.01
|
||||||
|
runtime.manager.register(adapter)
|
||||||
|
await runtime.start()
|
||||||
|
try:
|
||||||
|
response = await adapter.handle_webhook_payload(
|
||||||
|
{
|
||||||
|
"peer_id": "demo",
|
||||||
|
"message_id": "msg-1",
|
||||||
|
"text": "hello",
|
||||||
|
"peer_type": "dm",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
events = runtime.recent_events("webhook-dev", limit=20)
|
||||||
|
finally:
|
||||||
|
await runtime.stop()
|
||||||
|
|
||||||
|
assert response["pending"] is True
|
||||||
|
assert any(event["kind"] == "outbound_unclaimed" for event in events)
|
||||||
|
|
||||||
|
asyncio.run(run())
|
||||||
|
|
||||||
|
|
||||||
|
def test_channel_runtime_starts_enabled_generic_webhook_and_reports_status(tmp_path) -> None:
|
||||||
|
async def run() -> None:
|
||||||
|
runtime = ChannelRuntime(
|
||||||
|
service=FakeAgentService(),
|
||||||
|
workspace=tmp_path,
|
||||||
|
channels={
|
||||||
|
"webhook-dev": ChannelConfig(
|
||||||
|
enabled=True,
|
||||||
|
kind="webhook",
|
||||||
|
mode="webhook",
|
||||||
|
account_id="local",
|
||||||
|
display_name="Webhook Dev",
|
||||||
|
config={"response_timeout_seconds": 1800},
|
||||||
|
),
|
||||||
|
"off": ChannelConfig(
|
||||||
|
enabled=False,
|
||||||
|
kind="webhook",
|
||||||
|
mode="webhook",
|
||||||
|
account_id="local",
|
||||||
|
),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
await runtime.start()
|
||||||
|
try:
|
||||||
|
statuses = runtime.statuses()
|
||||||
|
finally:
|
||||||
|
await runtime.stop()
|
||||||
|
|
||||||
|
by_id = {item["channel_id"]: item for item in statuses}
|
||||||
|
assert by_id["webhook-dev"]["state"] == "running"
|
||||||
|
assert by_id["webhook-dev"]["webhook_url"] == "/api/channels/webhook-dev/webhook"
|
||||||
|
assert by_id["off"]["state"] == "disabled"
|
||||||
|
|
||||||
|
asyncio.run(run())
|
||||||
|
|
||||||
|
|
||||||
|
def test_channel_runtime_builds_platform_adapters_without_starting_networks(tmp_path) -> None:
|
||||||
|
runtime = ChannelRuntime(
|
||||||
|
service=FakeAgentService(),
|
||||||
|
workspace=tmp_path,
|
||||||
|
channels={},
|
||||||
|
)
|
||||||
|
|
||||||
|
cases = {
|
||||||
|
"telegram-main": ChannelConfig(enabled=True, kind="telegram", mode="polling", account_id="bot-main"),
|
||||||
|
"feishu-main": ChannelConfig(enabled=True, kind="feishu", mode="websocket", account_id="tenant-main"),
|
||||||
|
"qq-main": ChannelConfig(enabled=True, kind="qqbot", mode="websocket", account_id="qq-main"),
|
||||||
|
"weixin-main": ChannelConfig(enabled=True, kind="weixin", mode="polling", account_id="wx-main"),
|
||||||
|
}
|
||||||
|
|
||||||
|
for channel_id, cfg in cases.items():
|
||||||
|
adapter = runtime._build_adapter(channel_id, cfg)
|
||||||
|
assert adapter.channel_id == channel_id
|
||||||
|
assert adapter.kind == cfg.kind
|
||||||
|
assert adapter.mode == cfg.mode
|
||||||
|
|
||||||
|
|
||||||
|
def test_channel_runtime_reports_platform_capabilities(tmp_path) -> None:
|
||||||
|
runtime = ChannelRuntime(
|
||||||
|
service=FakeAgentService(),
|
||||||
|
workspace=tmp_path,
|
||||||
|
channels={
|
||||||
|
"telegram-main": ChannelConfig(enabled=True, kind="telegram", mode="polling", account_id="bot-main"),
|
||||||
|
"weixin-main": ChannelConfig(enabled=True, kind="weixin", mode="polling", account_id="wx-main"),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
by_id = {item["channel_id"]: item for item in runtime.statuses()}
|
||||||
|
|
||||||
|
assert by_id["telegram-main"]["capabilities"] == [
|
||||||
|
"receive_text",
|
||||||
|
"send_text",
|
||||||
|
"receive_media",
|
||||||
|
"groups",
|
||||||
|
]
|
||||||
|
assert by_id["weixin-main"]["capabilities"] == [
|
||||||
|
"receive_text",
|
||||||
|
"send_text",
|
||||||
|
"receive_media",
|
||||||
|
"direct_messages",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_channel_runtime_platform_start_failure_does_not_stop_other_channels(tmp_path) -> None:
|
||||||
|
async def run() -> None:
|
||||||
|
runtime = ChannelRuntime(
|
||||||
|
service=FakeAgentService(),
|
||||||
|
workspace=tmp_path,
|
||||||
|
channels={
|
||||||
|
"telegram-main": ChannelConfig(
|
||||||
|
enabled=True,
|
||||||
|
kind="telegram",
|
||||||
|
mode="polling",
|
||||||
|
account_id="bot-main",
|
||||||
|
secrets={},
|
||||||
|
),
|
||||||
|
"off": ChannelConfig(
|
||||||
|
enabled=False,
|
||||||
|
kind="weixin",
|
||||||
|
mode="polling",
|
||||||
|
account_id="wx-main",
|
||||||
|
),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
await runtime.start()
|
||||||
|
try:
|
||||||
|
by_id = {item["channel_id"]: item for item in runtime.statuses()}
|
||||||
|
finally:
|
||||||
|
await runtime.stop()
|
||||||
|
|
||||||
|
assert by_id["telegram-main"]["state"] == "error"
|
||||||
|
assert "botToken" in by_id["telegram-main"]["last_error"]
|
||||||
|
assert by_id["off"]["state"] == "disabled"
|
||||||
|
|
||||||
|
asyncio.run(run())
|
||||||
|
|
||||||
|
|
||||||
|
def test_web_app_status_exposes_configured_channel(tmp_path) -> None:
|
||||||
|
config_path = tmp_path / "config.json"
|
||||||
|
workspace = tmp_path / "workspace"
|
||||||
|
workspace.mkdir()
|
||||||
|
config_path.write_text(
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"agents": {"defaults": {"workspace": str(workspace), "model": "openai/gpt-5"}},
|
||||||
|
"providers": {},
|
||||||
|
"channels": {
|
||||||
|
"webhook-dev": {
|
||||||
|
"enabled": True,
|
||||||
|
"kind": "webhook",
|
||||||
|
"mode": "webhook",
|
||||||
|
"accountId": "local",
|
||||||
|
"displayName": "Webhook Dev",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
),
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
service = AgentService(config_path=config_path)
|
||||||
|
app = create_app(service=service, manage_service_lifecycle=False)
|
||||||
|
|
||||||
|
with TestClient(app) as client:
|
||||||
|
payload = client.get("/api/status").json()
|
||||||
|
|
||||||
|
service.close()
|
||||||
|
assert payload["channels"][0]["channel_id"] == "webhook-dev"
|
||||||
|
assert payload["channels"][0]["state"] == "running"
|
||||||
|
assert payload["channels"][0]["webhook_url"] == "/api/channels/webhook-dev/webhook"
|
||||||
|
assert payload["runtime_controls"]["self_restart"] is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_self_restart_env_defaults_enabled(monkeypatch) -> None:
|
||||||
|
monkeypatch.delenv("BEAVER_ENABLE_SELF_RESTART", raising=False)
|
||||||
|
|
||||||
|
assert _self_restart_enabled() is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_self_restart_env_can_disable(monkeypatch) -> None:
|
||||||
|
monkeypatch.setenv("BEAVER_ENABLE_SELF_RESTART", "0")
|
||||||
|
|
||||||
|
assert _self_restart_enabled() is False
|
||||||
@ -0,0 +1,119 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from beaver.foundation.config.schema import ChannelConfig
|
||||||
|
from beaver.foundation.events import MessageBus, OutboundMessage
|
||||||
|
from beaver.interfaces.channels.runtime import ChannelRuntime
|
||||||
|
|
||||||
|
|
||||||
|
class FakeService:
|
||||||
|
async def handle_inbound_message(self, inbound):
|
||||||
|
return OutboundMessage(channel=inbound.channel, content="ok", session_id=inbound.session_id, finish_reason="stop")
|
||||||
|
|
||||||
|
|
||||||
|
def test_runtime_add_channel_starts_new_channel_after_runtime_start(tmp_path) -> None:
|
||||||
|
async def run() -> None:
|
||||||
|
runtime = ChannelRuntime(service=FakeService(), workspace=tmp_path, channels={}, bus=MessageBus())
|
||||||
|
await runtime.start()
|
||||||
|
try:
|
||||||
|
await runtime.add_channel(
|
||||||
|
"webhook-dev",
|
||||||
|
ChannelConfig(enabled=True, kind="webhook", mode="webhook", account_id="acct"),
|
||||||
|
)
|
||||||
|
assert "webhook-dev" in runtime.adapters
|
||||||
|
assert runtime.states["webhook-dev"]["state"] == "running"
|
||||||
|
finally:
|
||||||
|
await runtime.stop()
|
||||||
|
|
||||||
|
asyncio.run(run())
|
||||||
|
|
||||||
|
|
||||||
|
def test_runtime_add_channel_noops_for_same_config(tmp_path) -> None:
|
||||||
|
async def run() -> None:
|
||||||
|
cfg = ChannelConfig(enabled=True, kind="webhook", mode="webhook", account_id="acct")
|
||||||
|
runtime = ChannelRuntime(service=FakeService(), workspace=tmp_path, channels={}, bus=MessageBus())
|
||||||
|
await runtime.start()
|
||||||
|
try:
|
||||||
|
await runtime.add_channel("webhook-dev", cfg)
|
||||||
|
first = runtime.adapters["webhook-dev"]
|
||||||
|
await runtime.add_channel("webhook-dev", cfg)
|
||||||
|
assert runtime.adapters["webhook-dev"] is first
|
||||||
|
finally:
|
||||||
|
await runtime.stop()
|
||||||
|
|
||||||
|
asyncio.run(run())
|
||||||
|
|
||||||
|
|
||||||
|
def test_runtime_replacement_failure_keeps_old_channel(tmp_path) -> None:
|
||||||
|
async def run() -> None:
|
||||||
|
good = ChannelConfig(enabled=True, kind="webhook", mode="webhook", account_id="acct")
|
||||||
|
bad = ChannelConfig(enabled=True, kind="missing", mode="http", account_id="acct")
|
||||||
|
runtime = ChannelRuntime(service=FakeService(), workspace=tmp_path, channels={}, bus=MessageBus())
|
||||||
|
await runtime.start()
|
||||||
|
try:
|
||||||
|
await runtime.add_channel("webhook-dev", good)
|
||||||
|
old = runtime.adapters["webhook-dev"]
|
||||||
|
try:
|
||||||
|
await runtime.add_channel("webhook-dev", bad)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
raise AssertionError("Expected ValueError")
|
||||||
|
assert runtime.adapters["webhook-dev"] is old
|
||||||
|
assert runtime.channel_configs["webhook-dev"] == good
|
||||||
|
assert runtime.states["webhook-dev"]["state"] == "running"
|
||||||
|
finally:
|
||||||
|
await runtime.stop()
|
||||||
|
|
||||||
|
asyncio.run(run())
|
||||||
|
|
||||||
|
|
||||||
|
def test_runtime_remove_channel_stops_and_unregisters(tmp_path) -> None:
|
||||||
|
async def run() -> None:
|
||||||
|
runtime = ChannelRuntime(service=FakeService(), workspace=tmp_path, channels={}, bus=MessageBus())
|
||||||
|
await runtime.start()
|
||||||
|
try:
|
||||||
|
await runtime.add_channel(
|
||||||
|
"webhook-dev",
|
||||||
|
ChannelConfig(enabled=True, kind="webhook", mode="webhook", account_id="acct"),
|
||||||
|
)
|
||||||
|
await runtime.remove_channel("webhook-dev")
|
||||||
|
assert "webhook-dev" not in runtime.adapters
|
||||||
|
assert "webhook-dev" not in runtime.manager.channels
|
||||||
|
assert runtime.states["webhook-dev"]["state"] == "removed"
|
||||||
|
finally:
|
||||||
|
await runtime.stop()
|
||||||
|
|
||||||
|
asyncio.run(run())
|
||||||
|
|
||||||
|
|
||||||
|
def test_runtime_builds_external_connector_channel(tmp_path, monkeypatch) -> None:
|
||||||
|
async def run() -> None:
|
||||||
|
monkeypatch.setenv("EXTERNAL_CONNECTOR_TOKEN", "connector-token")
|
||||||
|
runtime = ChannelRuntime(service=FakeService(), workspace=tmp_path, channels={}, bus=MessageBus())
|
||||||
|
await runtime.start()
|
||||||
|
try:
|
||||||
|
await runtime.add_channel(
|
||||||
|
"weixin-main",
|
||||||
|
ChannelConfig(
|
||||||
|
enabled=True,
|
||||||
|
kind="external_connector",
|
||||||
|
mode="http",
|
||||||
|
account_id="weixin:me",
|
||||||
|
display_name="Weixin Main",
|
||||||
|
config={
|
||||||
|
"platformKind": "weixin",
|
||||||
|
"connectionId": "conn_1",
|
||||||
|
"sidecarBaseUrl": "http://external-connector:8787",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
adapter = runtime.adapters["weixin-main"]
|
||||||
|
assert adapter.kind == "external_connector"
|
||||||
|
assert adapter.mode == "http"
|
||||||
|
assert getattr(adapter, "platform_kind") == "weixin"
|
||||||
|
finally:
|
||||||
|
await runtime.stop()
|
||||||
|
|
||||||
|
asyncio.run(run())
|
||||||
@ -1,9 +1,14 @@
|
|||||||
import json
|
import json
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
from beaver.engine import AgentLoop, EngineLoader
|
from beaver.engine import AgentLoop, EngineLoader
|
||||||
from beaver.engine.providers import make_provider_bundle
|
from beaver.engine.providers import make_provider_bundle
|
||||||
from beaver.engine.providers.litellm import LiteLLMProvider
|
from beaver.engine.providers.litellm import LiteLLMProvider
|
||||||
from beaver.foundation.config import load_config
|
from beaver.foundation.config import load_config
|
||||||
|
from beaver.interfaces.web.app import create_app, _reload_agent_config
|
||||||
|
from beaver.services.agent_service import AgentService
|
||||||
|
|
||||||
|
|
||||||
def test_load_config_reads_current_instance_shape(tmp_path) -> None:
|
def test_load_config_reads_current_instance_shape(tmp_path) -> None:
|
||||||
@ -42,6 +47,44 @@ def test_load_config_reads_current_instance_shape(tmp_path) -> None:
|
|||||||
assert target["extra_headers"] == {"X-Test": "1"}
|
assert target["extra_headers"] == {"X-Test": "1"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_config_loader_reads_channels(tmp_path) -> None:
|
||||||
|
config_path = tmp_path / "config.json"
|
||||||
|
config_path.write_text(
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"agents": {"defaults": {"model": "openai/gpt-5"}},
|
||||||
|
"channels": {
|
||||||
|
"webhook-dev": {
|
||||||
|
"enabled": True,
|
||||||
|
"kind": "webhook",
|
||||||
|
"mode": "webhook",
|
||||||
|
"accountId": "local",
|
||||||
|
"displayName": "Webhook Dev",
|
||||||
|
"config": {
|
||||||
|
"responseTimeoutSeconds": 1800,
|
||||||
|
"dedupeRetentionHours": 48,
|
||||||
|
},
|
||||||
|
"secrets": {"ignored_for_status": "secret-value"},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
),
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
|
||||||
|
config = load_config(config_path=config_path)
|
||||||
|
|
||||||
|
channel = config.channels["webhook-dev"]
|
||||||
|
assert channel.enabled is True
|
||||||
|
assert channel.kind == "webhook"
|
||||||
|
assert channel.mode == "webhook"
|
||||||
|
assert channel.account_id == "local"
|
||||||
|
assert channel.display_name == "Webhook Dev"
|
||||||
|
assert channel.config["response_timeout_seconds"] == 1800
|
||||||
|
assert channel.config["dedupe_retention_hours"] == 48
|
||||||
|
assert channel.secrets == {"ignored_for_status": "secret-value"}
|
||||||
|
|
||||||
|
|
||||||
def test_provider_resolution_ignores_custom_and_disabled_overrides(tmp_path) -> None:
|
def test_provider_resolution_ignores_custom_and_disabled_overrides(tmp_path) -> None:
|
||||||
config_path = tmp_path / "config.json"
|
config_path = tmp_path / "config.json"
|
||||||
config_path.write_text(
|
config_path.write_text(
|
||||||
@ -124,6 +167,236 @@ def test_agent_loop_config_drives_provider_bundle(tmp_path) -> None:
|
|||||||
loop.close()
|
loop.close()
|
||||||
|
|
||||||
|
|
||||||
|
def test_reload_agent_config_updates_booted_loop_config(tmp_path) -> None:
|
||||||
|
workspace = tmp_path / "workspace"
|
||||||
|
config_path = tmp_path / "config.json"
|
||||||
|
config_path.write_text(
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"agents": {"defaults": {"workspace": str(workspace), "model": "old-model"}},
|
||||||
|
"providers": {"openai": {"apiKey": "sk-test", "apiBase": "https://old.example.com/v1"}},
|
||||||
|
}
|
||||||
|
),
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
service = AgentService(config_path=config_path)
|
||||||
|
loaded = service.create_loop().boot()
|
||||||
|
assert loaded.config.default_model == "old-model"
|
||||||
|
|
||||||
|
config_path.write_text(
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"agents": {"defaults": {"workspace": str(workspace), "model": "new-model"}},
|
||||||
|
"providers": {"openai": {"apiKey": "sk-test", "apiBase": "https://new.example.com/v1"}},
|
||||||
|
}
|
||||||
|
),
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
|
||||||
|
_reload_agent_config(service, config_path)
|
||||||
|
|
||||||
|
target = service.create_loop().boot().config.resolve_provider_target()
|
||||||
|
assert target["model"] == "new-model"
|
||||||
|
assert target["api_base"] == "https://new.example.com/v1"
|
||||||
|
assert target["api_key"] == "sk-test"
|
||||||
|
service.close()
|
||||||
|
|
||||||
|
|
||||||
|
def test_reload_agent_config_keeps_running_service_when_old_mcp_close_fails(tmp_path) -> None:
|
||||||
|
async def run_case() -> None:
|
||||||
|
workspace = tmp_path / "workspace"
|
||||||
|
config_path = tmp_path / "config.json"
|
||||||
|
config_path.write_text(
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"agents": {"defaults": {"workspace": str(workspace), "model": "old-model"}},
|
||||||
|
"providers": {"openai": {"apiKey": "sk-test", "apiBase": "https://old.example.com/v1"}},
|
||||||
|
}
|
||||||
|
),
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
service = AgentService(config_path=config_path)
|
||||||
|
await service.start()
|
||||||
|
|
||||||
|
class FailingMCPManager:
|
||||||
|
async def close(self) -> None:
|
||||||
|
raise RuntimeError("Attempted to exit cancel scope in a different task than it was entered in")
|
||||||
|
|
||||||
|
loaded = service.create_loop().boot()
|
||||||
|
loaded.mcp_manager = FailingMCPManager()
|
||||||
|
config_path.write_text(
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"agents": {"defaults": {"workspace": str(workspace), "model": "new-model"}},
|
||||||
|
"providers": {"openai": {"apiKey": "sk-test", "apiBase": "https://new.example.com/v1"}},
|
||||||
|
}
|
||||||
|
),
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
unhandled: list[dict[str, object]] = []
|
||||||
|
previous_handler = loop.get_exception_handler()
|
||||||
|
loop.set_exception_handler(lambda _loop, context: unhandled.append(context))
|
||||||
|
try:
|
||||||
|
_reload_agent_config(service, config_path)
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
target = service.create_loop().boot().config.resolve_provider_target()
|
||||||
|
assert service.is_running is True
|
||||||
|
assert target["model"] == "new-model"
|
||||||
|
assert target["api_base"] == "https://new.example.com/v1"
|
||||||
|
assert unhandled == []
|
||||||
|
finally:
|
||||||
|
loop.set_exception_handler(previous_handler)
|
||||||
|
await service.shutdown(force=True)
|
||||||
|
|
||||||
|
asyncio.run(run_case())
|
||||||
|
|
||||||
|
|
||||||
|
def test_agent_defaults_include_runtime_controls(tmp_path) -> None:
|
||||||
|
config_path = tmp_path / "config.json"
|
||||||
|
config_path.write_text(
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"agents": {
|
||||||
|
"defaults": {
|
||||||
|
"maxTokens": 12345,
|
||||||
|
"temperature": 0.4,
|
||||||
|
"maxToolIterations": 9,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
),
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
|
||||||
|
config = load_config(config_path=config_path)
|
||||||
|
service = AgentService(config_path=config_path)
|
||||||
|
|
||||||
|
assert config.agents_defaults.max_tokens == 12345
|
||||||
|
assert config.agents_defaults.temperature == 0.4
|
||||||
|
assert config.agents_defaults.max_tool_iterations == 9
|
||||||
|
assert service.profile.max_tokens == 12345
|
||||||
|
assert service.profile.temperature == 0.4
|
||||||
|
assert service.profile.max_tool_iterations == 9
|
||||||
|
service.close()
|
||||||
|
|
||||||
|
|
||||||
|
def test_agent_config_api_persists_and_reloads_defaults(tmp_path) -> None:
|
||||||
|
config_path = tmp_path / "config.json"
|
||||||
|
config_path.write_text(json.dumps({"agents": {"defaults": {}}}), encoding="utf-8")
|
||||||
|
service = AgentService(config_path=config_path)
|
||||||
|
app = create_app(service=service, manage_service_lifecycle=False)
|
||||||
|
|
||||||
|
with TestClient(app) as client:
|
||||||
|
response = client.post(
|
||||||
|
"/api/agent-config",
|
||||||
|
json={"max_tokens": 8192, "temperature": 0.6, "max_tool_iterations": 12},
|
||||||
|
)
|
||||||
|
status = client.get("/api/status")
|
||||||
|
|
||||||
|
saved = json.loads(config_path.read_text(encoding="utf-8"))
|
||||||
|
defaults = saved["agents"]["defaults"]
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {"ok": True}
|
||||||
|
assert defaults["maxTokens"] == 8192
|
||||||
|
assert defaults["temperature"] == 0.6
|
||||||
|
assert defaults["maxToolIterations"] == 12
|
||||||
|
assert service.profile.max_tokens == 8192
|
||||||
|
assert service.profile.temperature == 0.6
|
||||||
|
assert service.profile.max_tool_iterations == 12
|
||||||
|
assert status.json()["max_tokens"] == 8192
|
||||||
|
assert status.json()["temperature"] == 0.6
|
||||||
|
assert status.json()["max_tool_iterations"] == 12
|
||||||
|
service.close()
|
||||||
|
|
||||||
|
|
||||||
|
def test_agent_config_api_accepts_zero_temperature_and_iterations(tmp_path) -> None:
|
||||||
|
config_path = tmp_path / "config.json"
|
||||||
|
service = AgentService(config_path=config_path)
|
||||||
|
app = create_app(service=service, manage_service_lifecycle=False)
|
||||||
|
|
||||||
|
with TestClient(app) as client:
|
||||||
|
response = client.post(
|
||||||
|
"/api/agent-config",
|
||||||
|
json={"max_tokens": None, "temperature": 0, "max_tool_iterations": 0},
|
||||||
|
)
|
||||||
|
|
||||||
|
config = load_config(config_path=config_path)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert config.agents_defaults.max_tokens is None
|
||||||
|
assert config.agents_defaults.temperature == 0
|
||||||
|
assert config.agents_defaults.max_tool_iterations == 0
|
||||||
|
assert service.profile.max_tokens is None
|
||||||
|
assert service.profile.temperature == 0
|
||||||
|
assert service.profile.max_tool_iterations == 0
|
||||||
|
service.close()
|
||||||
|
|
||||||
|
|
||||||
|
def test_channel_config_api_persists_and_masks_secrets(tmp_path) -> None:
|
||||||
|
config_path = tmp_path / "config.json"
|
||||||
|
config_path.write_text(
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"agents": {"defaults": {"model": "openai/gpt-5"}},
|
||||||
|
"channels": {
|
||||||
|
"telegram-main": {
|
||||||
|
"enabled": False,
|
||||||
|
"kind": "telegram",
|
||||||
|
"mode": "polling",
|
||||||
|
"accountId": "bot-main",
|
||||||
|
"displayName": "Telegram Main",
|
||||||
|
"secrets": {"botToken": "1234567890abcdef"},
|
||||||
|
"config": {"requireMentionInGroups": True},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
),
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
service = AgentService(config_path=config_path)
|
||||||
|
app = create_app(service=service, manage_service_lifecycle=False)
|
||||||
|
|
||||||
|
with TestClient(app) as client:
|
||||||
|
before = client.get("/api/channels/telegram-main/config")
|
||||||
|
response = client.post(
|
||||||
|
"/api/channels/telegram-main/config",
|
||||||
|
json={
|
||||||
|
"enabled": True,
|
||||||
|
"kind": "telegram",
|
||||||
|
"mode": "polling",
|
||||||
|
"account_id": "bot-main",
|
||||||
|
"display_name": "Telegram Primary",
|
||||||
|
"secrets": {"botToken": ""},
|
||||||
|
"config": {
|
||||||
|
"requireMentionInGroups": False,
|
||||||
|
"allowFrom": ["1001", "1002"],
|
||||||
|
"maxMessageChars": 3000,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
saved = json.loads(config_path.read_text(encoding="utf-8"))
|
||||||
|
channel = saved["channels"]["telegram-main"]
|
||||||
|
|
||||||
|
assert before.status_code == 200
|
||||||
|
assert before.json()["secrets"] == {"botToken": "1234••••cdef"}
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["ok"] is True
|
||||||
|
assert response.json()["restart_required"] is True
|
||||||
|
assert response.json()["channel"]["display_name"] == "Telegram Primary"
|
||||||
|
assert response.json()["channel"]["secrets"] == {"botToken": "1234••••cdef"}
|
||||||
|
assert channel["enabled"] is True
|
||||||
|
assert channel["displayName"] == "Telegram Primary"
|
||||||
|
assert channel["secrets"]["botToken"] == "1234567890abcdef"
|
||||||
|
assert channel["config"]["allowFrom"] == ["1001", "1002"]
|
||||||
|
assert load_config(config_path=config_path).channels["telegram-main"].enabled is True
|
||||||
|
service.close()
|
||||||
|
|
||||||
|
|
||||||
def test_openai_compatible_qwen_config_keeps_openai_provider() -> None:
|
def test_openai_compatible_qwen_config_keeps_openai_provider() -> None:
|
||||||
bundle = make_provider_bundle(
|
bundle = make_provider_bundle(
|
||||||
model="qwen-plus",
|
model="qwen-plus",
|
||||||
|
|||||||
@ -0,0 +1,51 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from beaver.interfaces.channels.connections import MessageDedupeStore
|
||||||
|
|
||||||
|
|
||||||
|
def test_message_dedupe_store_completes_and_dedupes_completed(tmp_path) -> None:
|
||||||
|
store = MessageDedupeStore(tmp_path / "message_dedupe.json")
|
||||||
|
|
||||||
|
first = store.begin(connection_id="conn_1", event_id="evt_1", delivery_attempt=1)
|
||||||
|
store.complete(first.dedupe_key, message_id="msg_1")
|
||||||
|
duplicate = store.begin(connection_id="conn_1", event_id="evt_1", delivery_attempt=2)
|
||||||
|
|
||||||
|
assert first.should_process is True
|
||||||
|
assert duplicate.should_process is False
|
||||||
|
assert duplicate.status == "completed"
|
||||||
|
assert duplicate.http_status == 200
|
||||||
|
|
||||||
|
|
||||||
|
def test_message_dedupe_store_returns_conflict_for_active_processing(tmp_path) -> None:
|
||||||
|
store = MessageDedupeStore(tmp_path / "message_dedupe.json", processing_ttl_seconds=60)
|
||||||
|
|
||||||
|
store.begin(connection_id="conn_1", event_id="evt_1", delivery_attempt=1)
|
||||||
|
duplicate = store.begin(connection_id="conn_1", event_id="evt_1", delivery_attempt=2)
|
||||||
|
|
||||||
|
assert duplicate.should_process is False
|
||||||
|
assert duplicate.status == "processing"
|
||||||
|
assert duplicate.http_status == 409
|
||||||
|
assert duplicate.retry_after_seconds == 5
|
||||||
|
|
||||||
|
|
||||||
|
def test_message_dedupe_store_reprocesses_stale_processing(tmp_path) -> None:
|
||||||
|
store = MessageDedupeStore(tmp_path / "message_dedupe.json", processing_ttl_seconds=0)
|
||||||
|
|
||||||
|
store.begin(connection_id="conn_1", event_id="evt_1", delivery_attempt=1)
|
||||||
|
stale = store.begin(connection_id="conn_1", event_id="evt_1", delivery_attempt=2)
|
||||||
|
|
||||||
|
assert stale.should_process is True
|
||||||
|
assert stale.status == "processing"
|
||||||
|
assert stale.record.delivery_attempts == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_message_dedupe_store_reprocesses_failed_records(tmp_path) -> None:
|
||||||
|
store = MessageDedupeStore(tmp_path / "message_dedupe.json")
|
||||||
|
|
||||||
|
first = store.begin(connection_id="conn_1", event_id="evt_1", delivery_attempt=1)
|
||||||
|
store.fail(first.dedupe_key, error="runtime rejected")
|
||||||
|
retry = store.begin(connection_id="conn_1", event_id="evt_1", delivery_attempt=2)
|
||||||
|
|
||||||
|
assert retry.should_process is True
|
||||||
|
assert retry.record.delivery_attempts == 2
|
||||||
|
assert retry.record.last_error is None
|
||||||
51
app-instance/backend/tests/unit/test_context_builder.py
Normal file
51
app-instance/backend/tests/unit/test_context_builder.py
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from beaver.engine.context import ContextBuildInput, ContextBuilder, RuntimeContext, SessionContext
|
||||||
|
|
||||||
|
|
||||||
|
def test_context_builder_injects_current_date_and_time() -> None:
|
||||||
|
result = ContextBuilder().build_messages(
|
||||||
|
ContextBuildInput(
|
||||||
|
base_system_prompt="Follow user requests.",
|
||||||
|
current_user_input="今天几号?",
|
||||||
|
session_context=SessionContext(session_id="web:alpha", source="web", model="stub-model"),
|
||||||
|
runtime_context=RuntimeContext(
|
||||||
|
utc_datetime="2026-05-26T01:10:00+00:00",
|
||||||
|
local_datetime="2026-05-26T09:10:00+08:00",
|
||||||
|
timezone="Asia/Shanghai",
|
||||||
|
utc_offset="+08:00",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
system_prompt = result.messages[0]["content"]
|
||||||
|
assert "# Current Date and Time" in system_prompt
|
||||||
|
assert "Current UTC time: 2026-05-26T01:10:00+00:00" in system_prompt
|
||||||
|
assert "Current local time: 2026-05-26T09:10:00+08:00" in system_prompt
|
||||||
|
assert "Local timezone: Asia/Shanghai" in system_prompt
|
||||||
|
assert "Local UTC offset: +08:00" in system_prompt
|
||||||
|
assert '"today", "tomorrow", "now", "this week", and "next month"' in system_prompt
|
||||||
|
assert result.messages[-1] == {"role": "user", "content": "今天几号?"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_context_builder_uses_simplified_main_agent_prompt_by_default() -> None:
|
||||||
|
system_prompt = ContextBuilder().build_system_prompt(ContextBuildInput())
|
||||||
|
|
||||||
|
assert "你是海狸 (Beaver)" in system_prompt
|
||||||
|
assert "博维资讯系统有限公司研发的 AI 助手" in system_prompt
|
||||||
|
assert "使用简体中文进行面向用户的回复" in system_prompt
|
||||||
|
|
||||||
|
|
||||||
|
def test_context_builder_uses_traditional_main_agent_prompt_for_zh_hant() -> None:
|
||||||
|
system_prompt = ContextBuilder().build_system_prompt(ContextBuildInput(prompt_locale="zh-Hant"))
|
||||||
|
|
||||||
|
assert "你是海狸 (Beaver)" in system_prompt
|
||||||
|
assert "博維資訊系統有限公司研發的 AI 助手" in system_prompt
|
||||||
|
assert "使用繁體中文進行面向使用者的回覆" in system_prompt
|
||||||
|
|
||||||
|
|
||||||
|
def test_context_builder_uses_english_main_agent_prompt_for_en() -> None:
|
||||||
|
system_prompt = ContextBuilder().build_system_prompt(ContextBuildInput(prompt_locale="en"))
|
||||||
|
|
||||||
|
assert "You are Beaver, an AI assistant developed by Boway Information Systems Co., Ltd." in system_prompt
|
||||||
|
assert "Use English for user-facing replies" in system_prompt
|
||||||
@ -0,0 +1,193 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
from beaver.interfaces.channels.connections import ChannelConnectionStore
|
||||||
|
from beaver.interfaces.web.app import create_app
|
||||||
|
from beaver.services.agent_service import AgentService
|
||||||
|
|
||||||
|
|
||||||
|
def _app(tmp_path, monkeypatch):
|
||||||
|
monkeypatch.setenv("BEAVER_BRIDGE_TOKEN", "bridge-token")
|
||||||
|
config_path = tmp_path / "config.json"
|
||||||
|
config_path.write_text(
|
||||||
|
'{"agents": {"defaults": {"workspace": "%s"}}, "providers": {}}' % str(tmp_path),
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
service = AgentService(config_path=config_path)
|
||||||
|
app = create_app(service=service, manage_service_lifecycle=False)
|
||||||
|
return app, service
|
||||||
|
|
||||||
|
|
||||||
|
def _connected_connection(tmp_path):
|
||||||
|
state_dir = tmp_path / "state" / "channel_connections"
|
||||||
|
store = ChannelConnectionStore(state_dir / "connections.json")
|
||||||
|
connection = store.create(
|
||||||
|
kind="weixin",
|
||||||
|
mode="sidecar",
|
||||||
|
display_name="Weixin Main",
|
||||||
|
account_id="weixin:me",
|
||||||
|
owner_user_id=None,
|
||||||
|
auth_type="connector_session",
|
||||||
|
)
|
||||||
|
store.update_status(connection.connection_id, status="connected", last_error=None)
|
||||||
|
return connection
|
||||||
|
|
||||||
|
|
||||||
|
def _connection_with_status(tmp_path, status: str):
|
||||||
|
state_dir = tmp_path / "state" / "channel_connections"
|
||||||
|
store = ChannelConnectionStore(state_dir / "connections.json")
|
||||||
|
connection = store.create(
|
||||||
|
kind="feishu",
|
||||||
|
mode="sidecar",
|
||||||
|
display_name="Feishu Main",
|
||||||
|
account_id="feishu:app-1",
|
||||||
|
owner_user_id=None,
|
||||||
|
auth_type="connector_session",
|
||||||
|
)
|
||||||
|
store.update_status(connection.connection_id, status=status, last_error=None)
|
||||||
|
return connection
|
||||||
|
|
||||||
|
|
||||||
|
def _payload(connection, *, event_id: str = "evt-1", delivery_attempt: int = 1) -> dict:
|
||||||
|
return {
|
||||||
|
"eventId": event_id,
|
||||||
|
"timestamp": "2026-06-02T09:30:00Z",
|
||||||
|
"deliveryAttempt": delivery_attempt,
|
||||||
|
"connectionId": connection.connection_id,
|
||||||
|
"channelId": connection.channel_id,
|
||||||
|
"kind": "weixin",
|
||||||
|
"accountId": "weixin:me",
|
||||||
|
"peerId": "peer-1",
|
||||||
|
"peerType": "dm",
|
||||||
|
"userId": "sender-1",
|
||||||
|
"threadId": None,
|
||||||
|
"messageId": "msg-1",
|
||||||
|
"messageType": "text",
|
||||||
|
"content": "hello",
|
||||||
|
"metadata": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_bridge_endpoint_accepts_valid_event(tmp_path, monkeypatch) -> None:
|
||||||
|
app, service = _app(tmp_path, monkeypatch)
|
||||||
|
try:
|
||||||
|
with TestClient(app) as client:
|
||||||
|
connection = _connected_connection(tmp_path)
|
||||||
|
response = client.post(
|
||||||
|
"/api/channel-connector-bridge/events",
|
||||||
|
headers={"Authorization": "Bearer bridge-token"},
|
||||||
|
json=_payload(connection),
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["accepted"] is True
|
||||||
|
finally:
|
||||||
|
service.close()
|
||||||
|
|
||||||
|
|
||||||
|
def test_bridge_endpoint_rejects_invalid_token(tmp_path, monkeypatch) -> None:
|
||||||
|
app, service = _app(tmp_path, monkeypatch)
|
||||||
|
try:
|
||||||
|
with TestClient(app) as client:
|
||||||
|
connection = _connected_connection(tmp_path)
|
||||||
|
response = client.post(
|
||||||
|
"/api/channel-connector-bridge/events",
|
||||||
|
headers={"Authorization": "Bearer wrong"},
|
||||||
|
json=_payload(connection),
|
||||||
|
)
|
||||||
|
assert response.status_code == 401
|
||||||
|
finally:
|
||||||
|
service.close()
|
||||||
|
|
||||||
|
|
||||||
|
def test_bridge_endpoint_rejects_connection_identity_mismatch(tmp_path, monkeypatch) -> None:
|
||||||
|
app, service = _app(tmp_path, monkeypatch)
|
||||||
|
try:
|
||||||
|
with TestClient(app) as client:
|
||||||
|
connection = _connected_connection(tmp_path)
|
||||||
|
payload = _payload(connection)
|
||||||
|
payload["channelId"] = "forged-channel"
|
||||||
|
payload["kind"] = "feishu"
|
||||||
|
payload["accountId"] = "feishu:attacker"
|
||||||
|
response = client.post(
|
||||||
|
"/api/channel-connector-bridge/events",
|
||||||
|
headers={"Authorization": "Bearer bridge-token"},
|
||||||
|
json=payload,
|
||||||
|
)
|
||||||
|
assert response.status_code == 403
|
||||||
|
assert "does not match connection" in response.json()["detail"]
|
||||||
|
finally:
|
||||||
|
service.close()
|
||||||
|
|
||||||
|
|
||||||
|
def test_bridge_endpoint_rejects_inactive_connection(tmp_path, monkeypatch) -> None:
|
||||||
|
app, service = _app(tmp_path, monkeypatch)
|
||||||
|
try:
|
||||||
|
with TestClient(app) as client:
|
||||||
|
connection = _connection_with_status(tmp_path, "pairing")
|
||||||
|
response = client.post(
|
||||||
|
"/api/channel-connector-bridge/events",
|
||||||
|
headers={"Authorization": "Bearer bridge-token"},
|
||||||
|
json={
|
||||||
|
**_payload(connection),
|
||||||
|
"kind": "feishu",
|
||||||
|
"accountId": "feishu:app-1",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert response.status_code == 409
|
||||||
|
assert "not connected" in response.json()["detail"]
|
||||||
|
finally:
|
||||||
|
service.close()
|
||||||
|
|
||||||
|
|
||||||
|
def test_bridge_endpoint_rejects_empty_or_oversized_content(tmp_path, monkeypatch) -> None:
|
||||||
|
app, service = _app(tmp_path, monkeypatch)
|
||||||
|
try:
|
||||||
|
with TestClient(app) as client:
|
||||||
|
connection = _connection_with_status(tmp_path, "connected")
|
||||||
|
blank = client.post(
|
||||||
|
"/api/channel-connector-bridge/events",
|
||||||
|
headers={"Authorization": "Bearer bridge-token"},
|
||||||
|
json={
|
||||||
|
**_payload(connection, event_id="blank"),
|
||||||
|
"kind": "feishu",
|
||||||
|
"accountId": "feishu:app-1",
|
||||||
|
"content": " ",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
too_long = client.post(
|
||||||
|
"/api/channel-connector-bridge/events",
|
||||||
|
headers={"Authorization": "Bearer bridge-token"},
|
||||||
|
json={
|
||||||
|
**_payload(connection, event_id="too-long"),
|
||||||
|
"kind": "feishu",
|
||||||
|
"accountId": "feishu:app-1",
|
||||||
|
"content": "x" * 20001,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert blank.status_code == 400
|
||||||
|
assert too_long.status_code == 413
|
||||||
|
finally:
|
||||||
|
service.close()
|
||||||
|
|
||||||
|
|
||||||
|
def test_bridge_endpoint_dedupes_repeated_event(tmp_path, monkeypatch) -> None:
|
||||||
|
app, service = _app(tmp_path, monkeypatch)
|
||||||
|
try:
|
||||||
|
with TestClient(app) as client:
|
||||||
|
connection = _connected_connection(tmp_path)
|
||||||
|
first = client.post(
|
||||||
|
"/api/channel-connector-bridge/events",
|
||||||
|
headers={"Authorization": "Bearer bridge-token"},
|
||||||
|
json=_payload(connection),
|
||||||
|
)
|
||||||
|
second = client.post(
|
||||||
|
"/api/channel-connector-bridge/events",
|
||||||
|
headers={"Authorization": "Bearer bridge-token"},
|
||||||
|
json=_payload(connection, delivery_attempt=2),
|
||||||
|
)
|
||||||
|
assert first.status_code == 200
|
||||||
|
assert second.status_code in {200, 409}
|
||||||
|
assert second.json()["duplicate"] is True
|
||||||
|
finally:
|
||||||
|
service.close()
|
||||||
@ -0,0 +1,114 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from beaver.foundation.events import ChannelIdentity, OutboundMessage
|
||||||
|
from beaver.interfaces.channels.external_connector import ExternalConnectorChannel, _request_id
|
||||||
|
|
||||||
|
|
||||||
|
class FakeSidecarClient:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.sent: list[dict] = []
|
||||||
|
|
||||||
|
async def send(self, payload: dict) -> dict:
|
||||||
|
self.sent.append(payload)
|
||||||
|
return {"ok": True, "providerMessageId": "provider-1"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_external_connector_channel_sends_with_target_and_request_id() -> None:
|
||||||
|
async def run() -> None:
|
||||||
|
client = FakeSidecarClient()
|
||||||
|
channel = ExternalConnectorChannel(
|
||||||
|
channel_id="weixin-main",
|
||||||
|
platform_kind="weixin",
|
||||||
|
connection_id="conn_1",
|
||||||
|
account_id="weixin:me",
|
||||||
|
display_name="Weixin Main",
|
||||||
|
sidecar_client=client,
|
||||||
|
)
|
||||||
|
message = OutboundMessage(
|
||||||
|
channel="weixin-main",
|
||||||
|
content="reply",
|
||||||
|
session_id="s1",
|
||||||
|
finish_reason="stop",
|
||||||
|
message_id="out-msg-1",
|
||||||
|
channel_identity=ChannelIdentity(
|
||||||
|
channel_id="weixin-main",
|
||||||
|
kind="weixin",
|
||||||
|
account_id="weixin:me",
|
||||||
|
peer_id="peer-1",
|
||||||
|
peer_type="dm",
|
||||||
|
thread_id=None,
|
||||||
|
user_id="sender-1",
|
||||||
|
message_id="in-msg-1",
|
||||||
|
),
|
||||||
|
metadata={"inbound_metadata": {"contextToken": "ctx-1"}},
|
||||||
|
)
|
||||||
|
|
||||||
|
await channel.send(message)
|
||||||
|
|
||||||
|
assert client.sent == [
|
||||||
|
{
|
||||||
|
"requestId": "out_weixin-main:s1:out-msg-1",
|
||||||
|
"connectionId": "conn_1",
|
||||||
|
"channelId": "weixin-main",
|
||||||
|
"kind": "weixin",
|
||||||
|
"target": {"peerId": "peer-1", "peerType": "dm", "threadId": None},
|
||||||
|
"content": "reply",
|
||||||
|
"metadata": {"inboundMessageId": "in-msg-1", "sessionId": "s1", "contextToken": "ctx-1"},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
asyncio.run(run())
|
||||||
|
|
||||||
|
|
||||||
|
def test_external_connector_request_id_falls_back_when_message_id_is_none_or_blank() -> None:
|
||||||
|
identity = ChannelIdentity(
|
||||||
|
channel_id="weixin-main",
|
||||||
|
kind="weixin",
|
||||||
|
account_id="weixin:me",
|
||||||
|
peer_id="peer-1",
|
||||||
|
peer_type="dm",
|
||||||
|
message_id="in-msg-1",
|
||||||
|
)
|
||||||
|
first = OutboundMessage(
|
||||||
|
channel="weixin-main",
|
||||||
|
content="same reply",
|
||||||
|
session_id="s1",
|
||||||
|
finish_reason="stop",
|
||||||
|
message_id=None, # type: ignore[arg-type]
|
||||||
|
channel_identity=identity,
|
||||||
|
)
|
||||||
|
second = OutboundMessage(
|
||||||
|
channel="weixin-main",
|
||||||
|
content="same reply",
|
||||||
|
session_id="s1",
|
||||||
|
finish_reason="stop",
|
||||||
|
message_id="",
|
||||||
|
channel_identity=identity,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert _request_id(first) == _request_id(second)
|
||||||
|
assert _request_id(first).startswith("out_weixin-main:s1:")
|
||||||
|
|
||||||
|
|
||||||
|
def test_external_connector_channel_requires_identity() -> None:
|
||||||
|
async def run() -> None:
|
||||||
|
channel = ExternalConnectorChannel(
|
||||||
|
channel_id="weixin-main",
|
||||||
|
platform_kind="weixin",
|
||||||
|
connection_id="conn_1",
|
||||||
|
account_id="weixin:me",
|
||||||
|
display_name="Weixin Main",
|
||||||
|
sidecar_client=FakeSidecarClient(),
|
||||||
|
)
|
||||||
|
message = OutboundMessage(channel="weixin-main", content="reply", session_id="s1", finish_reason="stop")
|
||||||
|
|
||||||
|
try:
|
||||||
|
await channel.send(message)
|
||||||
|
except ValueError as exc:
|
||||||
|
assert "channel_identity is required" in str(exc)
|
||||||
|
else:
|
||||||
|
raise AssertionError("Expected ValueError")
|
||||||
|
|
||||||
|
asyncio.run(run())
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user