Compare commits
37 Commits
33a9845566
...
KEYCLOAK_L
| Author | SHA1 | Date | |
|---|---|---|---|
| 461d1300ad | |||
| fc9fd93c36 | |||
| 9cc3334ea7 | |||
| dc4c6f313d | |||
| 9e2c02a333 | |||
| b9171998b9 | |||
| 64d789a3d0 | |||
| cc1bf85517 | |||
| 4c8bc53d33 | |||
| 70014c0f70 | |||
| eb69bb168a | |||
| 7287e93f87 | |||
| a925f0e77f | |||
| 6dc580ab26 | |||
| 3a16dc283d | |||
| 0fd4df3c17 | |||
| f46a435bab | |||
| a28254c6b8 | |||
| 0cf4f44346 | |||
| e0bc6c55b0 | |||
| 2c5205b06e | |||
| 236ac19789 | |||
| ba2417c7f2 | |||
| f3655c86c9 | |||
| c3d84b904a | |||
| ffa1249403 | |||
| ee972441f5 | |||
| d335199a64 | |||
| feeaccc0e3 | |||
| cf35edb4ca | |||
| e0a4862af8 | |||
| c3a0aef104 | |||
| b25713a141 | |||
| d74a1c9c12 | |||
| 834d4e1e2f | |||
| 6a6ddc21c0 | |||
| 826db8ec2e |
40
.env.example
40
.env.example
@ -1,30 +1,24 @@
|
||||
# Shared values used by the root deployment flow in README.md
|
||||
# Standalone app-instance defaults
|
||||
|
||||
PROJECT_ROOT=/home/ivan/xuan/beaver_project
|
||||
BEAVER_NET=beaver-instance-edge
|
||||
BEAVER_PROXY_CONTAINER_NAME=beaver-router-proxy
|
||||
BEAVER_PUBLIC_URL=http://172.19.0.245:18080
|
||||
BEAVER_HOST_IP=172.19.0.245
|
||||
BEAVER_HOST_PORT=18080
|
||||
|
||||
BEAVER_DEPLOY_TOKEN=change-me
|
||||
BEAVER_AUTHZ_INTERNAL_TOKEN=change-me
|
||||
|
||||
BEAVER_SERVER_IP=203.0.113.10
|
||||
BEAVER_BASE_DOMAIN=203.0.113.10.nip.io
|
||||
|
||||
BEAVER_PROVIDER=openai
|
||||
BEAVER_MODEL=openai/gpt-5
|
||||
BEAVER_API_KEY=sk-xxxxxxxx
|
||||
BEAVER_API_BASE=
|
||||
|
||||
# Per-instance Beaver backend config. In the Docker app-instance, this should point
|
||||
# to the mounted single-user sandbox config, not to frontend env.
|
||||
BEAVER_HOME=/root/.beaver
|
||||
BEAVER_CONFIG_PATH=/root/.beaver/config.json
|
||||
BEAVER_WORKSPACE=/root/.beaver/workspace
|
||||
|
||||
# Must be reachable from app-instance containers.
|
||||
BEAVER_AUTHZ_URL=http://beaver-authz-service:19090
|
||||
BEAVER_OUTLOOK_MCP_URL=
|
||||
BEAVER_OUTLOOK_MCP_SERVER_ID=outlook_mcp
|
||||
# Keycloak SSO
|
||||
BEAVER_KEYCLOAK_ISSUER=https://keycloak.bwgdi.com/realms/beaver
|
||||
BEAVER_KEYCLOAK_CLIENT_ID=beaver-agnet
|
||||
BEAVER_KEYCLOAK_REDIRECT_URI=http://172.19.0.245:18080/auth/callback
|
||||
BEAVER_KEYCLOAK_POST_LOGOUT_REDIRECT_URI=http://172.19.0.245:18080/logout/callback
|
||||
|
||||
# Must be reachable from auth-portal and authz-service containers.
|
||||
BEAVER_DEPLOY_URL=http://beaver-deploy-control:8090
|
||||
# Model provider settings used when generating config.json with run-standalone.sh
|
||||
BEAVER_PROVIDER=openai
|
||||
BEAVER_MODEL=openai/gpt-5
|
||||
BEAVER_API_KEY=
|
||||
BEAVER_API_BASE=
|
||||
|
||||
# Optional runtime limits
|
||||
BEAVER_USER_FILES_MAX_UPLOAD_BYTES=5368709120
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@ -21,6 +21,7 @@ sessions/
|
||||
**/.ruff_cache/
|
||||
**/.mypy_cache/
|
||||
**/.cache/
|
||||
**/.codegraph/
|
||||
**/.venv/
|
||||
**/dist/
|
||||
**/build/
|
||||
|
||||
12
AGENTS.md
Normal file
12
AGENTS.md
Normal file
@ -0,0 +1,12 @@
|
||||
# Repository Instructions
|
||||
|
||||
This branch is scoped to the standalone `app-instance` package.
|
||||
|
||||
Useful entry points:
|
||||
|
||||
- `app-instance/README.md`
|
||||
- `app-instance/run-standalone.sh`
|
||||
- `app-instance/backend`
|
||||
- `app-instance/frontend`
|
||||
|
||||
The repository no longer owns the old auth portal, authz service, deploy control plane, router proxy, or per-user instance orchestration.
|
||||
411
DESIGN.md
411
DESIGN.md
@ -1,411 +0,0 @@
|
||||
# DESIGN.md
|
||||
|
||||
## Brand
|
||||
|
||||
**Beaver — Taupe**
|
||||
|
||||
A calm editorial UI system focused on rhythm, hierarchy, and soft neutral contrast.
|
||||
Designed for AI-native tools, dashboards, and minimalist productivity software.
|
||||
|
||||
The interface emphasizes:
|
||||
|
||||
- Spacious layouts
|
||||
- Soft grayscale surfaces
|
||||
- Typography-first hierarchy
|
||||
- Rounded geometry
|
||||
- Quiet shadows
|
||||
- Dense information with low visual noise
|
||||
|
||||
The visual tone should feel:
|
||||
|
||||
- thoughtful
|
||||
- mature
|
||||
- calm
|
||||
- premium
|
||||
- architectural
|
||||
- editorial
|
||||
|
||||
Avoid:
|
||||
|
||||
- saturated colors
|
||||
- hard borders
|
||||
- sharp corners
|
||||
- excessive gradients
|
||||
- loud shadows
|
||||
- playful illustration-heavy UI
|
||||
|
||||
---
|
||||
|
||||
# Colors
|
||||
|
||||
## Core Palette
|
||||
|
||||
| Token | Hex | Usage |
|
||||
|---|---|---|
|
||||
| background | `#F5F3F1` | Main app background |
|
||||
| foreground | `#0B0B0B` | Primary text |
|
||||
| primary | `#1D1715` | Primary actions |
|
||||
| secondary | `#E5E2DF` | Secondary surfaces |
|
||||
| muted | `#DDD9D6` | Muted backgrounds |
|
||||
| accent | `#CAC5C0` | Borders / subtle emphasis |
|
||||
|
||||
---
|
||||
|
||||
## Neutral Scale
|
||||
|
||||
| Token | Hex |
|
||||
|---|---|
|
||||
| zinc-50 | `#F7F5F4` |
|
||||
| zinc-100 | `#ECE8E5` |
|
||||
| zinc-200 | `#D8D2CE` |
|
||||
| zinc-300 | `#B8AEA8` |
|
||||
| zinc-400 | `#8B7E77` |
|
||||
| zinc-500 | `#6A5E58` |
|
||||
| zinc-600 | `#4F4642` |
|
||||
| zinc-700 | `#342E2B` |
|
||||
|
||||
---
|
||||
|
||||
## Semantic Colors
|
||||
|
||||
### Taupe
|
||||
|
||||
| Step | Hex |
|
||||
|---|---|
|
||||
| taupe-100 | `#E7E2DE` |
|
||||
| taupe-300 | `#B8AEA8` |
|
||||
| taupe-500 | `#8B7E77` |
|
||||
| taupe-700 | `#5F5550` |
|
||||
|
||||
### Sage
|
||||
|
||||
| Step | Hex |
|
||||
|---|---|
|
||||
| sage-100 | `#E3E8E2` |
|
||||
| sage-300 | `#B7C2B5` |
|
||||
| sage-500 | `#869683` |
|
||||
| sage-700 | `#657162` |
|
||||
|
||||
### Slate
|
||||
|
||||
| Step | Hex |
|
||||
|---|---|
|
||||
| slate-100 | `#E4E7EB` |
|
||||
| slate-300 | `#BCC4CE` |
|
||||
| slate-500 | `#8C96A3` |
|
||||
| slate-700 | `#697281` |
|
||||
|
||||
---
|
||||
|
||||
# Typography
|
||||
|
||||
## Philosophy
|
||||
|
||||
Typography drives hierarchy.
|
||||
|
||||
The system should feel like a modern editorial publication mixed with a productivity dashboard.
|
||||
|
||||
Large headings use elegant serif typography.
|
||||
UI and body copy use neutral grotesk sans-serif typography.
|
||||
|
||||
---
|
||||
|
||||
## Font Stack
|
||||
|
||||
### Serif
|
||||
|
||||
```css
|
||||
font-family: "Lora", Georgia, serif;
|
||||
```
|
||||
|
||||
Used for:
|
||||
|
||||
- hero titles
|
||||
- article headings
|
||||
- marketing emphasis
|
||||
- editorial sections
|
||||
|
||||
---
|
||||
|
||||
### Sans
|
||||
|
||||
```css
|
||||
font-family: "Public Sans", Inter, sans-serif;
|
||||
```
|
||||
|
||||
Used for:
|
||||
|
||||
- UI
|
||||
- labels
|
||||
- forms
|
||||
- dashboards
|
||||
- buttons
|
||||
- navigation
|
||||
|
||||
---
|
||||
|
||||
## Type Scale
|
||||
|
||||
| Style | Size | Weight | Line Height |
|
||||
|---|---|---|---|
|
||||
| h1 | 48px | 600 | 1.1 |
|
||||
| h2 | 36px | 600 | 1.15 |
|
||||
| h3 | 28px | 500 | 1.2 |
|
||||
| body-lg | 18px | 400 | 1.7 |
|
||||
| body | 16px | 400 | 1.6 |
|
||||
| small | 14px | 400 | 1.5 |
|
||||
| mono | 13px | 500 | 1.4 |
|
||||
|
||||
---
|
||||
|
||||
# Radius
|
||||
|
||||
Rounded geometry should feel soft but architectural.
|
||||
|
||||
| Token | Radius |
|
||||
|---|---|
|
||||
| xs | 4px |
|
||||
| sm | 8px |
|
||||
| md | 12px |
|
||||
| lg | 16px |
|
||||
| xl | 24px |
|
||||
| full | 999px |
|
||||
|
||||
Cards should primarily use:
|
||||
|
||||
```css
|
||||
border-radius: 16px;
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
# Shadows
|
||||
|
||||
Shadows should be subtle and diffused.
|
||||
|
||||
Avoid strong elevation.
|
||||
|
||||
## Soft
|
||||
|
||||
```css
|
||||
box-shadow:
|
||||
0 1px 2px rgba(0,0,0,0.04),
|
||||
0 6px 24px rgba(0,0,0,0.03);
|
||||
```
|
||||
|
||||
## Floating
|
||||
|
||||
```css
|
||||
box-shadow:
|
||||
0 12px 40px rgba(0,0,0,0.06);
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
# Grid
|
||||
|
||||
## Layout
|
||||
|
||||
- 12-column grid
|
||||
- Max width: `1280px`
|
||||
- Horizontal padding: `32px`
|
||||
- Large whitespace between sections
|
||||
|
||||
---
|
||||
|
||||
## Content Widths
|
||||
|
||||
| Type | Width |
|
||||
|---|---|
|
||||
| reading | 720px |
|
||||
| dashboard | 1280px |
|
||||
| modal | 480px |
|
||||
| form | 560px |
|
||||
|
||||
---
|
||||
|
||||
# Spacing
|
||||
|
||||
Base unit:
|
||||
|
||||
```txt
|
||||
4px
|
||||
```
|
||||
|
||||
Spacing scale:
|
||||
|
||||
| Token | Value |
|
||||
|---|---|
|
||||
| 1 | 4px |
|
||||
| 2 | 8px |
|
||||
| 3 | 12px |
|
||||
| 4 | 16px |
|
||||
| 5 | 20px |
|
||||
| 6 | 24px |
|
||||
| 8 | 32px |
|
||||
| 10 | 40px |
|
||||
| 12 | 48px |
|
||||
| 16 | 64px |
|
||||
|
||||
Use generous vertical rhythm.
|
||||
|
||||
Sections should breathe.
|
||||
|
||||
---
|
||||
|
||||
# Components
|
||||
|
||||
## Buttons
|
||||
|
||||
### Primary
|
||||
|
||||
- Dark background
|
||||
- White text
|
||||
- Pill radius
|
||||
- Minimal shadow
|
||||
|
||||
```css
|
||||
background: #1D1715;
|
||||
color: white;
|
||||
border-radius: 999px;
|
||||
height: 40px;
|
||||
padding: 0 16px;
|
||||
```
|
||||
|
||||
### Secondary
|
||||
|
||||
```css
|
||||
background: #ECE8E5;
|
||||
color: #1D1715;
|
||||
```
|
||||
|
||||
### Ghost
|
||||
|
||||
Transparent background with subtle hover fill.
|
||||
|
||||
---
|
||||
|
||||
## Cards
|
||||
|
||||
Cards are soft containers with quiet separation.
|
||||
|
||||
```css
|
||||
background: rgba(255,255,255,0.7);
|
||||
border: 1px solid rgba(0,0,0,0.04);
|
||||
border-radius: 16px;
|
||||
```
|
||||
|
||||
Avoid heavy borders.
|
||||
|
||||
---
|
||||
|
||||
## Inputs
|
||||
|
||||
Inputs should feel invisible until focused.
|
||||
|
||||
```css
|
||||
background: #F7F5F4;
|
||||
border: 1px solid transparent;
|
||||
```
|
||||
|
||||
Focus:
|
||||
|
||||
```css
|
||||
border-color: #8B7E77;
|
||||
box-shadow: 0 0 0 3px rgba(139,126,119,0.12);
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Charts
|
||||
|
||||
Charts should use muted earthy tones.
|
||||
|
||||
Preferred palette:
|
||||
|
||||
- taupe
|
||||
- sage
|
||||
- slate
|
||||
|
||||
Avoid:
|
||||
|
||||
- neon colors
|
||||
- bright blue dashboards
|
||||
- rainbow charts
|
||||
|
||||
---
|
||||
|
||||
# Motion
|
||||
|
||||
Motion should be restrained and smooth.
|
||||
|
||||
Preferred easing:
|
||||
|
||||
```css
|
||||
cubic-bezier(0.22, 1, 0.36, 1)
|
||||
```
|
||||
|
||||
Preferred duration:
|
||||
|
||||
| Type | Duration |
|
||||
|---|---|
|
||||
| hover | 150ms |
|
||||
| panel | 250ms |
|
||||
| modal | 350ms |
|
||||
|
||||
---
|
||||
|
||||
# Layout Skeleton
|
||||
|
||||
Application layout:
|
||||
|
||||
- Left sidebar
|
||||
- Large content canvas
|
||||
- Floating top toolbar
|
||||
- Soft dashboard cards
|
||||
- Spacious internal padding
|
||||
|
||||
The UI should always feel:
|
||||
|
||||
- breathable
|
||||
- editorial
|
||||
- premium
|
||||
- calm
|
||||
|
||||
It should never feel dense or overly enterprise-looking.
|
||||
|
||||
---
|
||||
|
||||
# Design Keywords
|
||||
|
||||
Use these words when generating UI:
|
||||
|
||||
- editorial
|
||||
- taupe
|
||||
- soft neutral
|
||||
- premium minimal
|
||||
- typography-first
|
||||
- architectural spacing
|
||||
- calm dashboard
|
||||
- quiet luxury
|
||||
- modern serif
|
||||
- subtle shadows
|
||||
- muted grayscale
|
||||
- sophisticated SaaS
|
||||
|
||||
---
|
||||
|
||||
# AI Agent Instructions
|
||||
|
||||
When generating UI:
|
||||
|
||||
1. Prioritize whitespace over density
|
||||
2. Typography should create hierarchy
|
||||
3. Use muted neutral palettes
|
||||
4. Prefer soft cards over hard sections
|
||||
5. Avoid excessive color usage
|
||||
6. Keep interactions subtle
|
||||
7. Use serif fonts sparingly for emphasis
|
||||
8. Maintain premium visual restraint
|
||||
9. Design should feel timeless rather than trendy
|
||||
10. Every screen should feel breathable
|
||||
198
README.md
198
README.md
@ -1,183 +1,75 @@
|
||||
# Beaver Project
|
||||
# Beaver App Instance
|
||||
|
||||
`Beaver Project` 是一套单机 Docker 部署的多实例运行环境:
|
||||
This branch contains a standalone Beaver app instance for an external orchestrator to deploy.
|
||||
|
||||
- 用户先进入独立的 `auth-portal` 完成注册或登录。
|
||||
- 注册会触发 `authz-service` 调用 `deploy-control`。
|
||||
- `deploy-control` 在同一台机器上创建一个独立的 `app-instance` 容器。
|
||||
- `router-proxy` 按实例域名把流量转发到对应容器。
|
||||
The app instance is a single Docker image that runs:
|
||||
|
||||
当前推荐的最小部署方式是一台 Linux / WSL2 Ubuntu 机器加 Docker。生产域名和 HTTPS 可以放在项目外层的 Nginx、Caddy、Traefik 或云负载均衡上。
|
||||
- Next.js frontend
|
||||
- FastAPI backend
|
||||
- Nginx reverse proxy
|
||||
|
||||
## 组件
|
||||
Authentication is delegated to Keycloak. This repository does not maintain local user registration, passwords, per-user instance routing, or deployment orchestration.
|
||||
|
||||
| 目录 | 职责 | 默认端口 |
|
||||
| --- | --- | --- |
|
||||
| `auth-portal/` | 用户登录、注册、模型配置引导入口 | `3081` |
|
||||
| `authz-service/` | AuthZ 服务,负责账号和 backend 身份编排 | `19090` |
|
||||
| `deploy-control/` | 部署控制面,调用 Docker 创建和管理实例 | `8090` |
|
||||
| `router-proxy/` | 统一实例入口代理,按 Host 分发到实例容器 | `8088` |
|
||||
| `app-instance/` | 单用户运行实例,容器内包含前端、后端和 Nginx | 容器内 `8080` |
|
||||
## Runtime Contract
|
||||
|
||||
公网环境通常只暴露:
|
||||
The external orchestrator is responsible for:
|
||||
|
||||
- `auth-portal`: `3081`,或外层代理后的 `https://portal.example.com`
|
||||
- `router-proxy`: `8088`,或外层代理后的 `https://<slug>.apps.example.com`
|
||||
- Building or pulling the app image
|
||||
- Providing a mounted Beaver home directory
|
||||
- Providing `config.json`
|
||||
- Choosing the public host and port
|
||||
- Registering the public callback URL and post-logout callback URL in Keycloak
|
||||
- Managing lifecycle, routing, TLS, and scaling
|
||||
|
||||
不要直接把 `deploy-control:8090` 和 `authz-service:19090` 暴露到公网。
|
||||
|
||||
## 请求链路
|
||||
|
||||
注册:
|
||||
The app instance exposes container port `8080`:
|
||||
|
||||
```text
|
||||
Browser
|
||||
-> auth-portal
|
||||
-> authz-service POST /portal/register
|
||||
-> deploy-control POST /api/instances/register
|
||||
-> app-instance/create-instance.sh
|
||||
-> app-instance POST /api/auth/register
|
||||
-> authz-service /oauth/register or /backends/register
|
||||
-> auth-portal provider onboarding
|
||||
-> deploy-control POST /api/instances/configure-provider
|
||||
/ frontend
|
||||
/api backend API
|
||||
/ws backend WebSocket
|
||||
```
|
||||
|
||||
登录:
|
||||
## Keycloak
|
||||
|
||||
Default IdP settings:
|
||||
|
||||
```text
|
||||
Browser
|
||||
-> auth-portal
|
||||
-> deploy-control POST /api/instances/resolve
|
||||
-> app-instance POST /api/auth/login
|
||||
-> app-instance frontend URL
|
||||
issuer: https://keycloak.bwgdi.com/realms/beaver
|
||||
client_id: beaver-agnet
|
||||
```
|
||||
|
||||
## 快速开始
|
||||
For the current plain-IP deployment, register these values in Keycloak:
|
||||
|
||||
本机完整流程见:
|
||||
```text
|
||||
web_origin: http://172.19.0.245:18080
|
||||
redirect_uri: http://172.19.0.245:18080/auth/callback
|
||||
post_logout_redirect_uri: http://172.19.0.245:18080/logout/callback
|
||||
```
|
||||
|
||||
- [部署指南.md](./部署指南.md)
|
||||
The app validates Keycloak JWT issuer, signature, expiry, audience/azp, and nonce where applicable. The Keycloak `sub` claim is used as the application user ID.
|
||||
|
||||
域名、HTTPS、公网反向代理说明见:
|
||||
|
||||
- [域名配置指引.md](./域名配置指引.md)
|
||||
|
||||
最小配置变量:
|
||||
## Build
|
||||
|
||||
```bash
|
||||
export PROJECT_ROOT=/home/ivan/xuan/beaver_project
|
||||
export BEAVER_NET=beaver-instance-edge
|
||||
export BEAVER_PROXY_CONTAINER_NAME=beaver-router-proxy
|
||||
|
||||
export BEAVER_DEPLOY_TOKEN="$(openssl rand -hex 32)"
|
||||
export BEAVER_AUTHZ_INTERNAL_TOKEN="$(openssl rand -hex 32)"
|
||||
|
||||
export BEAVER_BASE_DOMAIN=127.0.0.1.nip.io
|
||||
export BEAVER_AUTHZ_URL='http://beaver-authz-service:19090'
|
||||
export BEAVER_DEPLOY_URL='http://beaver-deploy-control:8090'
|
||||
|
||||
export BEAVER_OUTLOOK_MCP_URL=''
|
||||
export BEAVER_OUTLOOK_MCP_SERVER_ID='outlook_mcp'
|
||||
cd app-instance
|
||||
docker build -t beaver/app-instance:keycloak-login .
|
||||
```
|
||||
|
||||
启动顺序:
|
||||
|
||||
1. 创建运行目录。
|
||||
2. 构建四个镜像。
|
||||
3. 创建共享 Docker network。
|
||||
4. 启动 `router-proxy`。
|
||||
5. 启动 `authz-service`。
|
||||
6. 启动 `deploy-control`。
|
||||
7. 启动 `auth-portal`。
|
||||
8. 打开 `http://127.0.0.1:3081/register` 测试注册。
|
||||
|
||||
## 关键配置关系
|
||||
|
||||
`DEPLOY_API_TOKEN` 和 `DEPLOY_CONTROL_API_TOKEN` 必须相等:
|
||||
|
||||
- `auth-portal` / `authz-service` 用 `DEPLOY_API_TOKEN` 请求 `deploy-control`。
|
||||
- `deploy-control` 用 `DEPLOY_CONTROL_API_TOKEN` 校验请求。
|
||||
|
||||
`AUTHZ_ISSUER` 在这套单机部署里要写容器网络地址:
|
||||
|
||||
```text
|
||||
http://beaver-authz-service:19090
|
||||
```
|
||||
|
||||
不要写成 `http://127.0.0.1:19090`,因为新创建的 `app-instance` 容器里的 `127.0.0.1` 指向它自己,不是 AuthZ 容器。
|
||||
|
||||
`DEPLOY_PUBLIC_*` 决定新实例展示给用户的 URL:
|
||||
## Run Locally
|
||||
|
||||
```bash
|
||||
DEPLOY_PUBLIC_SCHEME=http
|
||||
DEPLOY_PUBLIC_BASE_DOMAIN=127.0.0.1.nip.io
|
||||
DEPLOY_PUBLIC_PORT=8088
|
||||
cd app-instance
|
||||
./run-standalone.sh --build --replace
|
||||
```
|
||||
|
||||
本机测试时实例 URL 形如:
|
||||
The script prints the `web_origin`, `redirect_uri`, and `post_logout_redirect_uri` values that must be configured in Keycloak.
|
||||
|
||||
```text
|
||||
http://alice.127.0.0.1.nip.io:8088
|
||||
```
|
||||
## Useful Paths
|
||||
|
||||
正式 HTTPS 域名通常改成:
|
||||
- `app-instance/` - app image, frontend, backend, runtime helper script, app tests
|
||||
- `skills/` - reusable skill content retained for product/runtime packaging
|
||||
- `agents/` - agent registry retained for runtime packaging
|
||||
- `docs/` - product and architecture documentation
|
||||
- `AGENTS.md` - local coding-agent instructions
|
||||
|
||||
```bash
|
||||
DEPLOY_PUBLIC_SCHEME=https
|
||||
DEPLOY_PUBLIC_BASE_DOMAIN=apps.example.com
|
||||
DEPLOY_PUBLIC_PORT=443
|
||||
```
|
||||
|
||||
实例 URL 形如:
|
||||
|
||||
```text
|
||||
https://alice.apps.example.com
|
||||
```
|
||||
|
||||
前提是你已经在项目外层把 `*.apps.example.com` 的 `80/443` 流量转发到 `router-proxy:8088`。
|
||||
|
||||
## 模型配置方式
|
||||
|
||||
当前版本不会在注册创建实例时写入模型 provider、model 或 API key。
|
||||
|
||||
流程是:
|
||||
|
||||
1. 注册先创建一个不含模型凭证的实例。
|
||||
2. `auth-portal` 进入模型配置引导页。
|
||||
3. 用户确认后,Portal 调用 `deploy-control /api/instances/configure-provider`。
|
||||
4. `deploy-control` 写入该实例的 `config.json` 并重启对应容器。
|
||||
|
||||
如果用户跳过引导,实例仍会创建成功,但后续需要在实例内补齐 provider 配置后才能正常调用模型。
|
||||
|
||||
## 持久化目录
|
||||
|
||||
至少保留:
|
||||
|
||||
```text
|
||||
authz-service/runtime/data
|
||||
app-instance/runtime/instances
|
||||
app-instance/runtime/registry
|
||||
router-proxy/runtime/conf.d
|
||||
```
|
||||
|
||||
不要在需要保留账号、实例或配置时删除这些目录。
|
||||
|
||||
## 模板文件
|
||||
|
||||
可参考这些环境变量模板:
|
||||
|
||||
- [`.env.example`](./.env.example)
|
||||
- [`auth-portal/src/.env.example`](./auth-portal/src/.env.example)
|
||||
- [`authz-service/.env.example`](./authz-service/.env.example)
|
||||
- [`deploy-control/.env.example`](./deploy-control/.env.example)
|
||||
- [`router-proxy/.env.example`](./router-proxy/.env.example)
|
||||
|
||||
这些模板不会被脚本自动加载。你可以手动 `export`,也可以在 `docker run` 时使用 `--env-file`。
|
||||
|
||||
## 子项目文档
|
||||
|
||||
- [`app-instance/README.md`](./app-instance/README.md)
|
||||
- [`auth-portal/src/README.md`](./auth-portal/src/README.md)
|
||||
- [`authz-service/README.md`](./authz-service/README.md)
|
||||
- [`deploy-control/README.md`](./deploy-control/README.md)
|
||||
- [`router-proxy/README.md`](./router-proxy/README.md)
|
||||
See `app-instance/README.md` for the app-specific contract and environment variables.
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
runtime/
|
||||
.env
|
||||
backend/.git/
|
||||
backend/.venv/
|
||||
backend/.pytest_cache/
|
||||
@ -13,4 +14,3 @@ frontend/node_modules/
|
||||
frontend/.next/
|
||||
frontend/.next-dev/
|
||||
frontend/tsconfig.tsbuildinfo
|
||||
|
||||
|
||||
28
app-instance/.env.example
Normal file
28
app-instance/.env.example
Normal file
@ -0,0 +1,28 @@
|
||||
# Docker image/runtime
|
||||
IMAGE_NAME=beaver/app-instance:keycloak
|
||||
CONTAINER_NAME=beaver-app-instance
|
||||
HOST_IP=172.19.0.245
|
||||
HOST_PORT=18080
|
||||
HOST_BIND_IP=0.0.0.0
|
||||
DATA_DIR=runtime/standalone
|
||||
DOCKER_NETWORK=
|
||||
|
||||
# Beaver model config rendered into runtime/standalone/beaver-home/config.json
|
||||
BEAVER_PROVIDER=custom
|
||||
BEAVER_MODEL=
|
||||
BEAVER_API_KEY=
|
||||
BEAVER_API_BASE=
|
||||
|
||||
# Keycloak public client
|
||||
KEYCLOAK_ISSUER=https://keycloak.bwgdi.com/realms/beaver
|
||||
KEYCLOAK_CLIENT_ID=beaver-agnet
|
||||
KEYCLOAK_REDIRECT_URI=
|
||||
KEYCLOAK_POST_LOGOUT_REDIRECT_URI=
|
||||
|
||||
# Build mirrors
|
||||
NPM_REGISTRY=https://registry.npmmirror.com
|
||||
NPM_FETCH_RETRIES=5
|
||||
NPM_FETCH_RETRY_MIN_TIMEOUT=20000
|
||||
NPM_FETCH_RETRY_MAX_TIMEOUT=120000
|
||||
APT_MIRROR=https://mirrors.tuna.tsinghua.edu.cn/debian
|
||||
PYPI_INDEX_URL=https://pypi.tuna.tsinghua.edu.cn/simple
|
||||
@ -20,11 +20,15 @@ RUN --mount=type=cache,target=/root/.npm \
|
||||
|
||||
COPY frontend/ ./
|
||||
|
||||
ARG NEXT_PUBLIC_AUTH_PORTAL_URL=""
|
||||
ARG NEXT_PUBLIC_AUTH_PORTAL_PORT="3081"
|
||||
ARG NEXT_PUBLIC_KEYCLOAK_ISSUER="https://keycloak.bwgdi.com/realms/beaver"
|
||||
ARG NEXT_PUBLIC_KEYCLOAK_CLIENT_ID="beaver-agnet"
|
||||
ARG NEXT_PUBLIC_KEYCLOAK_REDIRECT_URI=""
|
||||
ARG NEXT_PUBLIC_KEYCLOAK_POST_LOGOUT_REDIRECT_URI=""
|
||||
|
||||
ENV NEXT_PUBLIC_AUTH_PORTAL_URL=${NEXT_PUBLIC_AUTH_PORTAL_URL}
|
||||
ENV NEXT_PUBLIC_AUTH_PORTAL_PORT=${NEXT_PUBLIC_AUTH_PORTAL_PORT}
|
||||
ENV NEXT_PUBLIC_KEYCLOAK_ISSUER=${NEXT_PUBLIC_KEYCLOAK_ISSUER}
|
||||
ENV NEXT_PUBLIC_KEYCLOAK_CLIENT_ID=${NEXT_PUBLIC_KEYCLOAK_CLIENT_ID}
|
||||
ENV NEXT_PUBLIC_KEYCLOAK_REDIRECT_URI=${NEXT_PUBLIC_KEYCLOAK_REDIRECT_URI}
|
||||
ENV NEXT_PUBLIC_KEYCLOAK_POST_LOGOUT_REDIRECT_URI=${NEXT_PUBLIC_KEYCLOAK_POST_LOGOUT_REDIRECT_URI}
|
||||
|
||||
# API / WS 走同域反代,不在构建时写死实例地址。
|
||||
RUN npm run build
|
||||
@ -39,7 +43,6 @@ ENV DEBIAN_FRONTEND=noninteractive \
|
||||
BEAVER_HOME=/root/.beaver \
|
||||
BEAVER_CONFIG_PATH=/root/.beaver/config.json \
|
||||
BEAVER_WORKSPACE=/root/.beaver/workspace \
|
||||
BEAVER_AUTH_FILE=/root/.beaver/web_auth_users.json \
|
||||
PORT=3000 \
|
||||
HOSTNAME=127.0.0.1
|
||||
|
||||
@ -47,8 +50,12 @@ ARG NPM_REGISTRY="https://registry.npmmirror.com"
|
||||
ARG NPM_FETCH_RETRIES="5"
|
||||
ARG NPM_FETCH_RETRY_MIN_TIMEOUT="20000"
|
||||
ARG NPM_FETCH_RETRY_MAX_TIMEOUT="120000"
|
||||
ARG APT_MIRROR="https://mirrors.tuna.tsinghua.edu.cn/debian"
|
||||
ARG PYPI_INDEX_URL="https://pypi.tuna.tsinghua.edu.cn/simple"
|
||||
|
||||
RUN apt-get update && \
|
||||
RUN find /etc/apt -type f \( -name "*.list" -o -name "*.sources" \) -exec \
|
||||
sed -i "s|http://deb.debian.org/debian-security|${APT_MIRROR}-security|g; s|http://deb.debian.org/debian|${APT_MIRROR}|g; s|http://security.debian.org/debian-security|${APT_MIRROR}-security|g" {} + && \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends curl ca-certificates gnupg git nginx dumb-init && \
|
||||
mkdir -p /etc/apt/keyrings && \
|
||||
curl -fsSL https://deb.nodesource.com/gpgkey/nodesource-repo.gpg.key | gpg --dearmor -o /etc/apt/keyrings/nodesource.gpg && \
|
||||
@ -63,7 +70,7 @@ WORKDIR /opt/app/backend
|
||||
|
||||
COPY backend/pyproject.toml backend/README.md ./
|
||||
COPY backend/beaver/ ./beaver/
|
||||
RUN uv pip install --system --no-cache .
|
||||
RUN uv pip install --system --no-cache --index-url "${PYPI_INDEX_URL}" ".[channels]"
|
||||
|
||||
WORKDIR /opt/app/frontend
|
||||
COPY --from=frontend-builder /build/frontend/next.config.js ./
|
||||
|
||||
@ -1,118 +1,56 @@
|
||||
# app-instance
|
||||
|
||||
单实例应用单元:
|
||||
|
||||
- 一个 Docker 容器里同时运行前端、后端和 Nginx 反代
|
||||
- 前端走 `/`
|
||||
- 后端 API 走 `/api`
|
||||
- WebSocket 走 `/ws`
|
||||
单体 app instance:一个 Docker 容器内运行前端、后端和 Nginx。这个目录只提供干净的应用实例,不再负责用户注册登录、实例编排、注册表或反向代理路由。
|
||||
|
||||
## 关键文件
|
||||
|
||||
- `Dockerfile`
|
||||
- 统一镜像构建入口
|
||||
- `entrypoint.sh`
|
||||
- 容器内启动前端、后端、Nginx
|
||||
- `create-instance.sh`
|
||||
- 创建实例目录、生成配置、启动容器、写注册表
|
||||
- `remove-instance.sh`
|
||||
- 删除容器、移除注册表、可选清理实例目录
|
||||
- `list-instances.sh`
|
||||
- 查看当前注册实例
|
||||
- `instance-registry.py`
|
||||
- 维护 `runtime/registry/instances.json`
|
||||
- `Dockerfile`: 统一镜像构建入口
|
||||
- `entrypoint.sh`: 容器内启动前端、后端、Nginx
|
||||
- `run-standalone.sh`: 从 `.env` 读取参数,构建并运行单体容器
|
||||
- `.env.example`: 可复制的部署参数模板
|
||||
|
||||
## 注册表
|
||||
## Keycloak SSO
|
||||
|
||||
默认注册表路径:
|
||||
当前登录交给 Keycloak:
|
||||
|
||||
- issuer: `https://keycloak.bwgdi.com/realms/beaver`
|
||||
- client_id: `beaver-agnet`
|
||||
- flow: Authorization Code + PKCE S256
|
||||
- password grant: disabled
|
||||
- client secret: none
|
||||
|
||||
应用需要在 Keycloak client 中配置:
|
||||
|
||||
- `redirect_uri`: `http://<host-ip>:<host-port>/auth/callback`
|
||||
- `post_logout_redirect_uri`: `http://<host-ip>:<host-port>/logout/callback`
|
||||
- `web_origin`: `http://<host-ip>:<host-port>`
|
||||
|
||||
默认单机部署示例:
|
||||
|
||||
```text
|
||||
runtime/registry/instances.json
|
||||
redirect_uri: http://172.19.0.245:18080/auth/callback
|
||||
post_logout_redirect_uri: http://172.19.0.245:18080/logout/callback
|
||||
web_origin: http://172.19.0.245:18080
|
||||
```
|
||||
|
||||
每条记录至少包含:
|
||||
|
||||
- `instance_id`
|
||||
- `instance_slug`
|
||||
- `container_name`
|
||||
- `host_port`
|
||||
- `public_url`
|
||||
- `instance_root`
|
||||
- `image_name`
|
||||
|
||||
## 常用命令
|
||||
|
||||
### 1. 构建镜像
|
||||
## 部署
|
||||
|
||||
```bash
|
||||
docker build -t beaver/app-instance:latest .
|
||||
cp .env.example .env
|
||||
./run-standalone.sh --build --replace
|
||||
```
|
||||
|
||||
### 2. 创建实例
|
||||
脚本会输出当前需要填到 Keycloak 的 `redirect_uri`、`post_logout_redirect_uri` 和 `web_origin`。
|
||||
|
||||
```bash
|
||||
./create-instance.sh \
|
||||
--image beaver/app-instance:latest \
|
||||
--instance-id demo-001 \
|
||||
--auth-username admin \
|
||||
--auth-password 123456 \
|
||||
--api-key 'your-api-key'
|
||||
```
|
||||
## 数据目录
|
||||
|
||||
可选参数:
|
||||
|
||||
- `--host-port`
|
||||
- `--public-url`
|
||||
- `--username`
|
||||
- `--email`
|
||||
- `--instance-host`
|
||||
- `--authz-base-url`
|
||||
- `--backend-id`
|
||||
- `--client-id`
|
||||
- `--client-secret`
|
||||
- `--network`
|
||||
- `--host-bind-ip`
|
||||
- `--build`
|
||||
- `--replace`
|
||||
|
||||
### 3. 查看实例
|
||||
|
||||
```bash
|
||||
./list-instances.sh
|
||||
./list-instances.sh --json
|
||||
```
|
||||
|
||||
### 4. 删除实例
|
||||
|
||||
```bash
|
||||
./remove-instance.sh --instance-id demo-001
|
||||
```
|
||||
|
||||
如果要把实例目录也一并清掉:
|
||||
|
||||
```bash
|
||||
./remove-instance.sh --instance-id demo-001 --purge-data
|
||||
```
|
||||
|
||||
## 目录约定
|
||||
|
||||
默认实例数据目录:
|
||||
默认数据目录:
|
||||
|
||||
```text
|
||||
runtime/instances/<instance-slug>/
|
||||
runtime/standalone/beaver-home/
|
||||
```
|
||||
|
||||
其中会生成:
|
||||
|
||||
```text
|
||||
runtime/instances/<instance-slug>/
|
||||
└── beaver-home
|
||||
├── config.json
|
||||
├── web_auth_users.json
|
||||
└── workspace/
|
||||
```
|
||||
|
||||
这个目录是单用户 sandbox 的配置与数据边界。容器内会把它挂到:
|
||||
容器内挂载到:
|
||||
|
||||
```text
|
||||
/root/.beaver/
|
||||
@ -125,28 +63,4 @@ BEAVER_CONFIG_PATH=/root/.beaver/config.json
|
||||
BEAVER_WORKSPACE=/root/.beaver/workspace
|
||||
```
|
||||
|
||||
所以模型 `provider/api_key/api_base/model` 配一次即可,Web / channel 请求不需要、也不应该携带 API Key。
|
||||
|
||||
## 当前状态
|
||||
|
||||
这层已经支持:
|
||||
|
||||
- 统一镜像构建
|
||||
- 镜像内安装并启动新的 `beaver` 后端
|
||||
- 实例创建
|
||||
- 实例删除
|
||||
- 实例列表
|
||||
- 基于注册表的端口分配
|
||||
- 为 deploy-control / router-proxy 记录用户名和实例 host
|
||||
|
||||
## 生产注意
|
||||
|
||||
- 实例容器的宿主机端口默认只绑定 `127.0.0.1`
|
||||
- 外部访问应统一走 `router-proxy`
|
||||
- 如果你确实要把单个实例端口直接暴露到公网,再显式传 `--host-bind-ip 0.0.0.0`
|
||||
|
||||
下一步可以继续接:
|
||||
|
||||
- portal 调用创建实例
|
||||
- URL 分配和反向代理
|
||||
- 实例续期 / 停用 / 启用
|
||||
模型 provider、API key、workspace 数据都在这个目录下维护。
|
||||
|
||||
@ -1,145 +1,4 @@
|
||||
{
|
||||
"agents": [
|
||||
{
|
||||
"agent_id": "researcher",
|
||||
"capabilities": [
|
||||
"research",
|
||||
"analysis",
|
||||
"source review",
|
||||
"requirements"
|
||||
],
|
||||
"created_at": "2026-05-27T05:25:11.756341+00:00",
|
||||
"description": "Finds facts, references, constraints, and implementation options.",
|
||||
"display_name": "Researcher",
|
||||
"metadata": {},
|
||||
"model": null,
|
||||
"name": "researcher",
|
||||
"priority": 50,
|
||||
"provider_name": null,
|
||||
"role": "research",
|
||||
"skill_names": [],
|
||||
"source": "builtin",
|
||||
"status": "active",
|
||||
"system_prompt": "You are a research specialist. Gather concise evidence and tradeoffs for the parent task.",
|
||||
"tags": [
|
||||
"planning",
|
||||
"research"
|
||||
],
|
||||
"tool_hints": [],
|
||||
"updated_at": "2026-05-27T05:25:11.756349+00:00"
|
||||
},
|
||||
{
|
||||
"agent_id": "implementer",
|
||||
"capabilities": [
|
||||
"implementation",
|
||||
"coding",
|
||||
"refactor",
|
||||
"integration"
|
||||
],
|
||||
"created_at": "2026-05-27T05:25:11.756351+00:00",
|
||||
"description": "Builds scoped implementation slices and proposes concrete changes.",
|
||||
"display_name": "Implementer",
|
||||
"metadata": {},
|
||||
"model": null,
|
||||
"name": "implementer",
|
||||
"priority": 45,
|
||||
"provider_name": null,
|
||||
"role": "implementation",
|
||||
"skill_names": [],
|
||||
"source": "builtin",
|
||||
"status": "active",
|
||||
"system_prompt": "You are an implementation specialist. Produce practical, scoped implementation output.",
|
||||
"tags": [
|
||||
"coding",
|
||||
"build"
|
||||
],
|
||||
"tool_hints": [],
|
||||
"updated_at": "2026-05-27T05:25:11.756353+00:00"
|
||||
},
|
||||
{
|
||||
"agent_id": "reviewer",
|
||||
"capabilities": [
|
||||
"review",
|
||||
"quality",
|
||||
"risk",
|
||||
"verification"
|
||||
],
|
||||
"created_at": "2026-05-27T05:25:11.756355+00:00",
|
||||
"description": "Reviews plans, code, outputs, and risks before final synthesis.",
|
||||
"display_name": "Reviewer",
|
||||
"metadata": {},
|
||||
"model": null,
|
||||
"name": "reviewer",
|
||||
"priority": 45,
|
||||
"provider_name": null,
|
||||
"role": "review",
|
||||
"skill_names": [],
|
||||
"source": "builtin",
|
||||
"status": "active",
|
||||
"system_prompt": "You are a review specialist. Focus on defects, missing requirements, and risks.",
|
||||
"tags": [
|
||||
"review",
|
||||
"quality"
|
||||
],
|
||||
"tool_hints": [],
|
||||
"updated_at": "2026-05-27T05:25:11.756356+00:00"
|
||||
},
|
||||
{
|
||||
"agent_id": "tester",
|
||||
"capabilities": [
|
||||
"testing",
|
||||
"verification",
|
||||
"regression",
|
||||
"qa"
|
||||
],
|
||||
"created_at": "2026-05-27T05:25:11.756358+00:00",
|
||||
"description": "Designs and executes verification checks for task outputs.",
|
||||
"display_name": "Tester",
|
||||
"metadata": {},
|
||||
"model": null,
|
||||
"name": "tester",
|
||||
"priority": 40,
|
||||
"provider_name": null,
|
||||
"role": "testing",
|
||||
"skill_names": [],
|
||||
"source": "builtin",
|
||||
"status": "active",
|
||||
"system_prompt": "You are a testing specialist. Identify focused checks and report pass/fail evidence.",
|
||||
"tags": [
|
||||
"test",
|
||||
"quality"
|
||||
],
|
||||
"tool_hints": [],
|
||||
"updated_at": "2026-05-27T05:25:11.756358+00:00"
|
||||
},
|
||||
{
|
||||
"agent_id": "documenter",
|
||||
"capabilities": [
|
||||
"documentation",
|
||||
"explanation",
|
||||
"migration notes",
|
||||
"release notes"
|
||||
],
|
||||
"created_at": "2026-05-27T05:25:11.756360+00:00",
|
||||
"description": "Writes and reconciles user-facing and internal documentation updates.",
|
||||
"display_name": "Documenter",
|
||||
"metadata": {},
|
||||
"model": null,
|
||||
"name": "documenter",
|
||||
"priority": 35,
|
||||
"provider_name": null,
|
||||
"role": "documentation",
|
||||
"skill_names": [],
|
||||
"source": "builtin",
|
||||
"status": "active",
|
||||
"system_prompt": "You are a documentation specialist. Produce concise docs aligned with the implementation.",
|
||||
"tags": [
|
||||
"docs",
|
||||
"communication"
|
||||
],
|
||||
"tool_hints": [],
|
||||
"updated_at": "2026-05-27T05:25:11.756360+00:00"
|
||||
}
|
||||
],
|
||||
"agents": [],
|
||||
"version": 1
|
||||
}
|
||||
|
||||
@ -15,7 +15,9 @@ class AgentRegistry:
|
||||
self.path = self.workspace / "agents" / "registry.json"
|
||||
self.path.parent.mkdir(parents=True, exist_ok=True)
|
||||
if not self.path.exists():
|
||||
self._write_agents(_builtin_agents())
|
||||
self._write_agents([])
|
||||
else:
|
||||
self._drop_legacy_builtin_agents()
|
||||
|
||||
def list_agents(self, *, include_disabled: bool = True) -> list[RegisteredAgent]:
|
||||
agents = self._read_agents()
|
||||
@ -125,72 +127,14 @@ class AgentRegistry:
|
||||
payload = {"version": 1, "agents": [agent.to_dict() for agent in agents]}
|
||||
self.path.write_text(json.dumps(payload, ensure_ascii=False, indent=2, sort_keys=True) + "\n", encoding="utf-8")
|
||||
|
||||
def _drop_legacy_builtin_agents(self) -> None:
|
||||
agents = self._read_agents()
|
||||
migrated = [agent for agent in agents if agent.source != "builtin"]
|
||||
if len(migrated) != len(agents):
|
||||
self._write_agents(migrated)
|
||||
|
||||
|
||||
def _terms(text: str) -> set[str]:
|
||||
normalized = "".join(ch.lower() if ch.isalnum() else " " for ch in text)
|
||||
return {part for part in normalized.split() if part}
|
||||
|
||||
|
||||
def _builtin_agents() -> list[RegisteredAgent]:
    """Return the five built-in agent definitions, each with ``source="builtin"``.

    A fresh list (with fresh capability/tag lists) is built on every call.
    """
    # One row per agent:
    # (agent_id, display_name, role, description, system_prompt,
    #  capabilities, tags, priority). ``name`` always equals ``agent_id``.
    rows = [
        (
            "researcher",
            "Researcher",
            "research",
            "Finds facts, references, constraints, and implementation options.",
            "You are a research specialist. Gather concise evidence and tradeoffs for the parent task.",
            ["research", "analysis", "source review", "requirements"],
            ["planning", "research"],
            50,
        ),
        (
            "implementer",
            "Implementer",
            "implementation",
            "Builds scoped implementation slices and proposes concrete changes.",
            "You are an implementation specialist. Produce practical, scoped implementation output.",
            ["implementation", "coding", "refactor", "integration"],
            ["coding", "build"],
            45,
        ),
        (
            "reviewer",
            "Reviewer",
            "review",
            "Reviews plans, code, outputs, and risks before final synthesis.",
            "You are a review specialist. Focus on defects, missing requirements, and risks.",
            ["review", "quality", "risk", "verification"],
            ["review", "quality"],
            45,
        ),
        (
            "tester",
            "Tester",
            "testing",
            "Designs and executes verification checks for task outputs.",
            "You are a testing specialist. Identify focused checks and report pass/fail evidence.",
            ["testing", "verification", "regression", "qa"],
            ["test", "quality"],
            40,
        ),
        (
            "documenter",
            "Documenter",
            "documentation",
            "Writes and reconciles user-facing and internal documentation updates.",
            "You are a documentation specialist. Produce concise docs aligned with the implementation.",
            ["documentation", "explanation", "migration notes", "release notes"],
            ["docs", "communication"],
            35,
        ),
    ]
    return [
        RegisteredAgent(
            agent_id=agent_id,
            name=agent_id,
            display_name=display_name,
            role=role,
            description=description,
            system_prompt=system_prompt,
            capabilities=capabilities,
            tags=tags,
            priority=priority,
            source="builtin",
        )
        for (
            agent_id,
            display_name,
            role,
            description,
            system_prompt,
            capabilities,
            tags,
            priority,
        ) in rows
    ]
|
||||
|
||||
@ -27,13 +27,7 @@ from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from beaver.memory.curated.snapshot import MemorySnapshot
|
||||
|
||||
|
||||
BEAVER_USER_ASSISTANT_IDENTITY_PROMPT = (
|
||||
"You are 海狸 (Beaver), an AI assistant developed by 博维资讯系统有限公司. "
|
||||
"When communicating with users, keep this identity consistent. "
|
||||
"If users ask who you are, say that you are 海狸 (Beaver), 博维资讯系统有限公司研发的 AI 助手."
|
||||
)
|
||||
from beaver.prompts import get_main_agent_prompt
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
@ -76,7 +70,12 @@ class SessionContext:
|
||||
model: str | None = None
|
||||
user_id: str | None = None
|
||||
channel: str | None = None
|
||||
channel_kind: str | None = None
|
||||
account_id: str | None = None
|
||||
peer_id: str | None = None
|
||||
peer_type: str | None = None
|
||||
chat_id: str | None = None
|
||||
thread_id: str | None = None
|
||||
parent_session_id: str | None = None
|
||||
|
||||
|
||||
@ -108,6 +107,7 @@ class ContextBuildInput:
|
||||
"""
|
||||
|
||||
base_system_prompt: str = ""
|
||||
prompt_locale: str | None = None
|
||||
history: list[dict[str, Any]] = field(default_factory=list)
|
||||
current_user_input: str | list[dict[str, Any]] | None = None
|
||||
memory_snapshot: MemorySnapshot | None = None
|
||||
@ -166,7 +166,7 @@ class ContextBuilder:
|
||||
- activated skill 正文放到显式消息里,避免 system prompt 持续膨胀
|
||||
"""
|
||||
|
||||
sections: list[str] = [BEAVER_USER_ASSISTANT_IDENTITY_PROMPT]
|
||||
sections: list[str] = [get_main_agent_prompt(build_input.prompt_locale)]
|
||||
|
||||
base_system_prompt = (build_input.base_system_prompt or "").strip()
|
||||
if base_system_prompt:
|
||||
@ -354,8 +354,18 @@ class ContextBuilder:
|
||||
rows.append(f"User ID: {session_context.user_id}")
|
||||
if session_context.channel:
|
||||
rows.append(f"Channel: {session_context.channel}")
|
||||
if session_context.channel_kind:
|
||||
rows.append(f"Channel Kind: {session_context.channel_kind}")
|
||||
if session_context.account_id:
|
||||
rows.append(f"Account ID: {session_context.account_id}")
|
||||
if session_context.peer_id:
|
||||
rows.append(f"Peer ID: {session_context.peer_id}")
|
||||
if session_context.peer_type:
|
||||
rows.append(f"Peer Type: {session_context.peer_type}")
|
||||
if session_context.chat_id:
|
||||
rows.append(f"Chat ID: {session_context.chat_id}")
|
||||
if session_context.thread_id:
|
||||
rows.append(f"Thread ID: {session_context.thread_id}")
|
||||
if session_context.parent_session_id:
|
||||
rows.append(f"Parent Session ID: {session_context.parent_session_id}")
|
||||
|
||||
|
||||
@ -48,6 +48,12 @@ from beaver.tools.builtins import (
|
||||
SkillsListTool,
|
||||
TerminalTool,
|
||||
TodoTool,
|
||||
UserFilesCopyToWorkspaceTool,
|
||||
UserFilesListTool,
|
||||
UserFilesMkdirTool,
|
||||
UserFilesPublishOutputTool,
|
||||
UserFilesReadTool,
|
||||
UserFilesWriteTool,
|
||||
WebFetchTool,
|
||||
WebSearchTool,
|
||||
WriteFileTool,
|
||||
@ -220,6 +226,12 @@ class EngineLoader:
|
||||
ObjectBackedTool(SearchFilesTool()),
|
||||
ObjectBackedTool(WriteFileTool()),
|
||||
ObjectBackedTool(PatchFileTool()),
|
||||
ObjectBackedTool(UserFilesListTool()),
|
||||
ObjectBackedTool(UserFilesReadTool()),
|
||||
ObjectBackedTool(UserFilesWriteTool()),
|
||||
ObjectBackedTool(UserFilesMkdirTool()),
|
||||
ObjectBackedTool(UserFilesCopyToWorkspaceTool()),
|
||||
ObjectBackedTool(UserFilesPublishOutputTool()),
|
||||
ObjectBackedTool(WebFetchTool()),
|
||||
ObjectBackedTool(WebSearchTool()),
|
||||
ObjectBackedTool(TerminalTool()),
|
||||
|
||||
@ -13,6 +13,7 @@ from uuid import uuid4
|
||||
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
|
||||
|
||||
from beaver.engine.context import ContextBuildInput, RuntimeContext, SessionContext, SkillContext
|
||||
from beaver.foundation.events import ChannelIdentity
|
||||
from beaver.memory.runs import RunRecord, SkillEffectRecord
|
||||
from beaver.skills.learning import RunReceiptContext
|
||||
from beaver.skills.catalog.utils import strip_frontmatter
|
||||
@ -223,6 +224,7 @@ class AgentLoop:
|
||||
title: str | None = None,
|
||||
execution_context: str | None = None,
|
||||
skill_selection_context: str | None = None,
|
||||
prompt_locale: str | None = None,
|
||||
model: str | None = None,
|
||||
provider_name: str | None = None,
|
||||
api_key: str | None = None,
|
||||
@ -246,8 +248,10 @@ class AgentLoop:
|
||||
attempt_index: int | None = None,
|
||||
pinned_skill_names: list[str] | None = None,
|
||||
pinned_skill_contexts: list[SkillContext] | None = None,
|
||||
tool_executor_override: Any = None,
|
||||
allow_candidate_generation: bool = False,
|
||||
intent_agent_decision: dict[str, Any] | None = None,
|
||||
channel_identity: ChannelIdentity | None = None,
|
||||
) -> AgentRunResult:
|
||||
"""跑通最小 direct run 主链。
|
||||
|
||||
@ -272,6 +276,7 @@ class AgentLoop:
|
||||
title=title,
|
||||
execution_context=execution_context,
|
||||
skill_selection_context=skill_selection_context,
|
||||
prompt_locale=prompt_locale,
|
||||
model=model,
|
||||
provider_name=provider_name,
|
||||
api_key=api_key,
|
||||
@ -295,8 +300,10 @@ class AgentLoop:
|
||||
attempt_index=attempt_index,
|
||||
pinned_skill_names=pinned_skill_names,
|
||||
pinned_skill_contexts=pinned_skill_contexts,
|
||||
tool_executor_override=tool_executor_override,
|
||||
allow_candidate_generation=allow_candidate_generation,
|
||||
intent_agent_decision=intent_agent_decision,
|
||||
channel_identity=channel_identity,
|
||||
)
|
||||
|
||||
async def _process_direct_impl(
|
||||
@ -309,6 +316,7 @@ class AgentLoop:
|
||||
title: str | None = None,
|
||||
execution_context: str | None = None,
|
||||
skill_selection_context: str | None = None,
|
||||
prompt_locale: str | None = None,
|
||||
model: str | None = None,
|
||||
provider_name: str | None = None,
|
||||
api_key: str | None = None,
|
||||
@ -332,8 +340,10 @@ class AgentLoop:
|
||||
attempt_index: int | None = None,
|
||||
pinned_skill_names: list[str] | None = None,
|
||||
pinned_skill_contexts: list[SkillContext] | None = None,
|
||||
tool_executor_override: Any = None,
|
||||
allow_candidate_generation: bool = False,
|
||||
intent_agent_decision: dict[str, Any] | None = None,
|
||||
channel_identity: ChannelIdentity | None = None,
|
||||
) -> AgentRunResult:
|
||||
"""真正执行一轮 direct run 的内部实现。
|
||||
|
||||
@ -350,6 +360,7 @@ class AgentLoop:
|
||||
tool_registry = self._require_loaded("tool_registry")
|
||||
tool_assembler = self._require_loaded("tool_assembler")
|
||||
tool_executor = self._require_loaded("tool_executor")
|
||||
effective_tool_executor = tool_executor_override or tool_executor
|
||||
skills_loader = self._require_loaded("skills_loader")
|
||||
skill_assembler = self._require_loaded("skill_assembler")
|
||||
skill_learning_service = self._require_loaded("skill_learning_service")
|
||||
@ -564,6 +575,7 @@ class AgentLoop:
|
||||
|
||||
build_input = ContextBuildInput(
|
||||
base_system_prompt=self.profile.system_prompt,
|
||||
prompt_locale=prompt_locale,
|
||||
history=session_manager.get_history(
|
||||
resolved_session_id,
|
||||
max_messages=max(1, self.profile.max_context_messages),
|
||||
@ -576,6 +588,13 @@ class AgentLoop:
|
||||
source=source,
|
||||
model=resolved_model,
|
||||
user_id=user_id,
|
||||
channel=channel_identity.channel_id if channel_identity else None,
|
||||
channel_kind=channel_identity.kind if channel_identity else None,
|
||||
account_id=channel_identity.account_id if channel_identity else None,
|
||||
peer_id=channel_identity.peer_id if channel_identity else None,
|
||||
peer_type=channel_identity.peer_type if channel_identity else None,
|
||||
chat_id=channel_identity.peer_id if channel_identity else None,
|
||||
thread_id=channel_identity.thread_id if channel_identity else None,
|
||||
parent_session_id=parent_session_id,
|
||||
),
|
||||
runtime_context=self._current_runtime_context(),
|
||||
@ -646,11 +665,17 @@ class AgentLoop:
|
||||
"tool_registry": tool_registry,
|
||||
"skills_loader": skills_loader,
|
||||
"draft_service": getattr(loaded, "draft_service", None),
|
||||
"beaver_config": loaded.config,
|
||||
"task_id": task_id,
|
||||
"run_id": resolved_run_id,
|
||||
**self.runtime_services,
|
||||
},
|
||||
metadata={
|
||||
"source": source,
|
||||
"agent_name": self.profile.name,
|
||||
"session_id": resolved_session_id,
|
||||
"task_id": task_id,
|
||||
"run_id": resolved_run_id,
|
||||
},
|
||||
)
|
||||
|
||||
@ -772,7 +797,7 @@ class AgentLoop:
|
||||
|
||||
iterations += 1
|
||||
for tool_call in response.tool_calls:
|
||||
result = await tool_executor.execute_tool_call(tool_call, context=tool_context)
|
||||
result = await effective_tool_executor.execute_tool_call(tool_call, context=tool_context)
|
||||
session_manager.append_message(
|
||||
resolved_session_id,
|
||||
run_id=resolved_run_id,
|
||||
|
||||
@ -3,9 +3,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
from ipaddress import ip_address
|
||||
import json
|
||||
import os
|
||||
from typing import Any
|
||||
from urllib.parse import urlsplit
|
||||
|
||||
from .base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
from .registry import find_by_model, find_by_name, find_gateway
|
||||
@ -26,6 +28,23 @@ except ModuleNotFoundError: # pragma: no cover
|
||||
_ALLOWED_MSG_KEYS = frozenset({"role", "content", "tool_calls", "tool_call_id", "name", "reasoning_content"})
|
||||
|
||||
|
||||
def _looks_like_local_vllm_api_base(api_base: str | None) -> bool:
|
||||
if not api_base:
|
||||
return False
|
||||
lowered = api_base.lower()
|
||||
if "vllm" in lowered or "localhost" in lowered:
|
||||
return True
|
||||
|
||||
host = urlsplit(lowered).hostname or ""
|
||||
if host in {"127.0.0.1", "::1", "0.0.0.0"}:
|
||||
return True
|
||||
try:
|
||||
parsed_host = ip_address(host)
|
||||
except ValueError:
|
||||
return False
|
||||
return parsed_host.is_private or parsed_host.is_loopback
|
||||
|
||||
|
||||
class LiteLLMProvider(LLMProvider):
|
||||
"""通过 LiteLLM 统一访问大多数 provider。"""
|
||||
|
||||
@ -185,6 +204,13 @@ class LiteLLMProvider(LLMProvider):
|
||||
kwargs["provider"] = provider_payload
|
||||
|
||||
def _apply_thinking_mode(self, original_model: str, resolved_model: str, kwargs: dict[str, Any], enabled: bool | None) -> None:
|
||||
if self._uses_mistral_reasoning_parser(original_model, resolved_model):
|
||||
if enabled is not None:
|
||||
extra_body = dict(kwargs.get("extra_body") or {})
|
||||
extra_body["reasoning_effort"] = "high" if enabled else "none"
|
||||
kwargs["extra_body"] = extra_body
|
||||
return
|
||||
|
||||
extra_body = dict(kwargs.get("extra_body") or {})
|
||||
chat_template_kwargs = dict(extra_body.get("chat_template_kwargs") or {})
|
||||
chat_template_kwargs["enable_thinking"] = False
|
||||
@ -192,6 +218,14 @@ class LiteLLMProvider(LLMProvider):
|
||||
extra_body["thinking"] = {"type": "disabled"}
|
||||
kwargs["extra_body"] = extra_body
|
||||
|
||||
def _uses_mistral_reasoning_parser(self, original_model: str, resolved_model: str) -> bool:
|
||||
model_names = f"{original_model} {resolved_model}".lower()
|
||||
if "mistral" not in model_names:
|
||||
return False
|
||||
if self.provider_name == "vllm":
|
||||
return True
|
||||
return self.provider_name in {"openai", "custom"} and _looks_like_local_vllm_api_base(self.api_base)
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
|
||||
@ -12,6 +12,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import sqlite3
|
||||
import threading
|
||||
import time
|
||||
@ -110,6 +111,12 @@ END;
|
||||
"""
|
||||
|
||||
|
||||
def _sqlite_journal_mode() -> str:
|
||||
requested = os.getenv("BEAVER_SQLITE_JOURNAL_MODE", "DELETE").strip().upper()
|
||||
allowed = {"DELETE", "TRUNCATE", "PERSIST", "MEMORY", "OFF", "WAL"}
|
||||
return requested if requested in allowed else "DELETE"
|
||||
|
||||
|
||||
class SessionStore:
|
||||
"""SQLite-backed session store."""
|
||||
|
||||
@ -119,7 +126,9 @@ class SessionStore:
|
||||
self._lock = threading.Lock()
|
||||
self._conn = sqlite3.connect(str(self.db_path), check_same_thread=False, isolation_level=None)
|
||||
self._conn.row_factory = sqlite3.Row
|
||||
self._conn.execute("PRAGMA journal_mode=WAL")
|
||||
self._conn.execute("PRAGMA mmap_size=0")
|
||||
self._conn.execute("PRAGMA busy_timeout=5000")
|
||||
self._conn.execute(f"PRAGMA journal_mode={_sqlite_journal_mode()}")
|
||||
self._conn.execute("PRAGMA foreign_keys=ON")
|
||||
self._init_schema()
|
||||
|
||||
|
||||
@ -13,6 +13,7 @@ from .schema import (
|
||||
AuthzConfig,
|
||||
BackendIdentityConfig,
|
||||
BeaverConfig,
|
||||
ChannelConfig,
|
||||
EmbeddingConfig,
|
||||
MCPServerConfig,
|
||||
ProviderConfig,
|
||||
@ -20,7 +21,7 @@ from .schema import (
|
||||
)
|
||||
|
||||
LOCAL_MCP_CATEGORIES: dict[str, dict[str, str]] = {
|
||||
"local_filesystem_mcp": {"category": "filesystem", "display_name": "本地文件工具"},
|
||||
"local_filesystem_mcp": {"category": "filesystem", "display_name": "个人智能体文件系统工具"},
|
||||
"local_runtime_mcp": {"category": "runtime", "display_name": "本地运行工具"},
|
||||
"local_memory_mcp": {"category": "memory", "display_name": "本地记忆工具"},
|
||||
"local_skills_mcp": {"category": "skills", "display_name": "本地技能工具"},
|
||||
@ -73,6 +74,7 @@ def load_config(
|
||||
embedding=_parse_embedding(data),
|
||||
tools=_parse_tools(data.get("tools")),
|
||||
authz=_parse_authz(data.get("authz")),
|
||||
channels=_parse_channels(data.get("channels")),
|
||||
backend_identity=_parse_backend_identity(data.get("backend_identity") or data.get("backendIdentity")),
|
||||
config_path=path,
|
||||
)
|
||||
@ -196,6 +198,48 @@ def _parse_authz(raw: Any) -> AuthzConfig:
|
||||
)
|
||||
|
||||
|
||||
def _parse_channels(raw: Any) -> dict[str, ChannelConfig]:
    """Build the channel-id -> ``ChannelConfig`` mapping from raw config data.

    Ids are stringified and trimmed; entries whose id is blank after
    trimming are dropped.
    """
    parsed: dict[str, ChannelConfig] = {}
    for raw_id, raw_payload in _as_dict(raw).items():
        key = str(raw_id).strip()
        if key:
            parsed[key] = _parse_channel_config(raw_payload)
    return parsed
|
||||
|
||||
|
||||
def _parse_channel_config(payload: Any) -> ChannelConfig:
    """Convert one raw channel entry into a ``ChannelConfig``.

    ``accountId`` / ``displayName`` are accepted in camelCase or snake_case
    (camelCase wins when both are truthy). Missing values fall back to
    ``enabled=False``, ``mode="webhook"`` and empty strings.
    """
    source = _as_dict(payload)
    enabled = _bool(source.get("enabled"), default=False)
    kind = _string(source.get("kind")) or ""
    mode = _string(source.get("mode")) or "webhook"
    account_id = _string(source.get("accountId") or source.get("account_id")) or ""
    display_name = _string(source.get("displayName") or source.get("display_name")) or ""
    return ChannelConfig(
        enabled=enabled,
        kind=kind,
        mode=mode,
        account_id=account_id,
        display_name=display_name,
        config=_normalize_config_map(source.get("config")),
        secrets=_string_dict(source.get("secrets")),
    )
|
||||
|
||||
|
||||
def _normalize_config_map(value: Any) -> dict[str, Any]:
|
||||
if not isinstance(value, dict):
|
||||
return {}
|
||||
return {
|
||||
_camel_to_snake_key(str(key)): item
|
||||
for key, item in value.items()
|
||||
if str(key).strip()
|
||||
}
|
||||
|
||||
|
||||
def _camel_to_snake_key(value: str) -> str:
|
||||
result: list[str] = []
|
||||
for char in value:
|
||||
if char.isupper() and result:
|
||||
result.append("_")
|
||||
result.append(char.lower())
|
||||
return "".join(result)
|
||||
|
||||
|
||||
def _parse_backend_identity(raw: Any) -> BackendIdentityConfig:
|
||||
data = _as_dict(raw)
|
||||
return BackendIdentityConfig(
|
||||
|
||||
@ -91,6 +91,19 @@ class AuthzConfig:
|
||||
outlook_mcp_url: str = ""
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ChannelConfig:
|
||||
"""One configured channel adapter instance."""
|
||||
|
||||
enabled: bool = False
|
||||
kind: str = ""
|
||||
mode: str = "webhook"
|
||||
account_id: str = ""
|
||||
display_name: str = ""
|
||||
config: dict[str, Any] = field(default_factory=dict)
|
||||
secrets: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class BackendIdentityConfig:
|
||||
"""This backend's AuthZ client identity."""
|
||||
@ -111,6 +124,7 @@ class BeaverConfig:
|
||||
embedding: EmbeddingConfig = field(default_factory=EmbeddingConfig)
|
||||
tools: ToolsConfig = field(default_factory=ToolsConfig)
|
||||
authz: AuthzConfig = field(default_factory=AuthzConfig)
|
||||
channels: dict[str, ChannelConfig] = field(default_factory=dict)
|
||||
backend_identity: BackendIdentityConfig = field(default_factory=BackendIdentityConfig)
|
||||
config_path: Path | None = None
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
"""Event contracts and dispatch helpers."""
|
||||
|
||||
from .message_bus import InboundMessage, MessageBus, OutboundMessage
|
||||
from .message_bus import ChannelIdentity, InboundMessage, MessageBus, OutboundMessage
|
||||
|
||||
__all__ = ["InboundMessage", "MessageBus", "OutboundMessage"]
|
||||
__all__ = ["ChannelIdentity", "InboundMessage", "MessageBus", "OutboundMessage"]
|
||||
|
||||
@ -9,12 +9,58 @@ from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ChannelIdentity:
|
||||
"""Normalized channel routing identity.
|
||||
|
||||
`channel_id` is the Beaver adapter instance id, not the platform kind.
|
||||
"""
|
||||
|
||||
channel_id: str
|
||||
kind: str
|
||||
account_id: str
|
||||
peer_id: str
|
||||
thread_id: str | None = None
|
||||
peer_type: str = "unknown"
|
||||
user_id: str | None = None
|
||||
message_id: str | None = None
|
||||
|
||||
def validation_error(self) -> str | None:
|
||||
if not self.channel_id.strip():
|
||||
return "channel_id is required"
|
||||
if not self.account_id.strip():
|
||||
return "account_id is required"
|
||||
if not self.peer_id.strip():
|
||||
return "peer_id is required"
|
||||
return None
|
||||
|
||||
def session_id(self) -> str:
|
||||
parts = [self.channel_id, self.account_id, self.peer_id]
|
||||
if self.thread_id:
|
||||
parts.append(self.thread_id)
|
||||
return ":".join(_clean_session_part(part) for part in parts)
|
||||
|
||||
def dedupe_key(self) -> str | None:
|
||||
if not self.message_id:
|
||||
return None
|
||||
return f"{self.session_id()}:{_clean_session_part(self.message_id)}"
|
||||
|
||||
|
||||
def _clean_session_part(value: str) -> str:
|
||||
cleaned = str(value).strip()
|
||||
if not cleaned:
|
||||
return "unknown"
|
||||
return cleaned.replace(":", "_")
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class InboundMessage:
|
||||
"""A minimal inbound message accepted by the gateway bridge."""
|
||||
|
||||
channel: str
|
||||
content: str
|
||||
content_type: str = "text"
|
||||
channel_identity: ChannelIdentity | None = None
|
||||
session_id: str | None = None
|
||||
user_id: str | None = None
|
||||
title: str | None = None
|
||||
@ -35,6 +81,8 @@ class OutboundMessage:
|
||||
content: str
|
||||
session_id: str | None
|
||||
finish_reason: str
|
||||
content_type: str = "text"
|
||||
channel_identity: ChannelIdentity | None = None
|
||||
message_id: str = field(default_factory=lambda: str(uuid4()))
|
||||
run_id: str | None = None
|
||||
provider_name: str | None = None
|
||||
|
||||
@ -109,3 +109,15 @@ class AuthzClient:
|
||||
async def delete_outlook_settings(self, backend_id: str) -> dict[str, Any]:
    """Send ``DELETE /backends/{backend_id}/settings/outlook``.

    Returns the response when it is a dict, otherwise ``{}``.
    """
    response = await self._request("DELETE", f"/backends/{backend_id}/settings/outlook")
    if isinstance(response, dict):
        return response
    return {}
|
||||
|
||||
async def get_minio_settings(self, backend_id: str) -> dict[str, Any]:
    """Send ``GET /backends/{backend_id}/settings/minio``.

    Returns the response when it is a dict, otherwise ``{}``.
    """
    response = await self._request("GET", f"/backends/{backend_id}/settings/minio")
    if isinstance(response, dict):
        return response
    return {}
|
||||
|
||||
async def set_minio_settings(self, backend_id: str, payload: dict[str, Any]) -> dict[str, Any]:
    """Send ``POST /backends/{backend_id}/settings/minio`` with ``payload`` as the JSON body.

    Returns the response when it is a dict, otherwise ``{}``.
    """
    response = await self._request(
        "POST",
        f"/backends/{backend_id}/settings/minio",
        json_body=payload,
    )
    if isinstance(response, dict):
        return response
    return {}
|
||||
|
||||
async def delete_minio_settings(self, backend_id: str) -> dict[str, Any]:
    """Send ``DELETE /backends/{backend_id}/settings/minio``.

    Returns the response when it is a dict, otherwise ``{}``.
    """
    response = await self._request("DELETE", f"/backends/{backend_id}/settings/minio")
    if isinstance(response, dict):
        return response
    return {}
|
||||
|
||||
@ -1,7 +1,17 @@
|
||||
"""Channel interfaces."""
|
||||
|
||||
from .base import ChannelAdapter
|
||||
from .base import ChannelInboundSink
|
||||
from .external_connector import ExternalConnectorChannel
|
||||
from .manager import ChannelManager
|
||||
from .memory import MemoryChannelAdapter
|
||||
from .terminal_websocket import TerminalWebSocketAdapter
|
||||
|
||||
__all__ = ["ChannelAdapter", "ChannelManager", "MemoryChannelAdapter"]
|
||||
__all__ = [
|
||||
"ChannelAdapter",
|
||||
"ChannelInboundSink",
|
||||
"ExternalConnectorChannel",
|
||||
"ChannelManager",
|
||||
"MemoryChannelAdapter",
|
||||
"TerminalWebSocketAdapter",
|
||||
]
|
||||
|
||||
@ -2,16 +2,17 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Protocol
|
||||
from typing import Any, Protocol
|
||||
|
||||
from beaver.foundation.events import MessageBus, OutboundMessage
|
||||
from beaver.foundation.events import InboundMessage, OutboundMessage
|
||||
|
||||
|
||||
class ChannelAdapter(Protocol):
|
||||
"""Minimal contract every gateway channel must implement."""
|
||||
"""Minimal contract every runtime channel adapter must implement."""
|
||||
|
||||
name: str
|
||||
bus: MessageBus
|
||||
channel_id: str
|
||||
kind: str
|
||||
mode: str
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Prepare the channel before messages are routed."""
|
||||
@ -22,3 +23,9 @@ class ChannelAdapter(Protocol):
|
||||
async def send(self, message: OutboundMessage) -> None:
|
||||
"""Deliver an outbound message to the concrete channel."""
|
||||
|
||||
|
||||
class ChannelInboundSink(Protocol):
    """Callback surface through which adapters hand normalized inbound messages to the runtime."""

    async def accept_inbound(self, message: InboundMessage) -> Any:
        """Receive one normalized inbound message submitted by an adapter."""
|
||||
|
||||
@ -0,0 +1,29 @@
|
||||
"""Channel connection setup layer."""
|
||||
|
||||
from .connectors import ChannelConnector, ChannelConnectorRegistry
|
||||
from .dedupe import ConnectorMessageDedupeRecord, DedupeBeginResult, MessageDedupeStore
|
||||
from .external import ExternalConnectorBase, FeishuConnector, WeixinConnector
|
||||
from .models import ChannelConnection, ChannelRuntimeSpec, PairingSession, ValidationResult
|
||||
from .sidecar_client import ConnectorSidecarClient
|
||||
from .store import ChannelConnectionStore, CredentialStore, PairingTokenStore
|
||||
from .telegram import TelegramConnector
|
||||
|
||||
__all__ = [
|
||||
"ChannelConnector",
|
||||
"ChannelConnectorRegistry",
|
||||
"ConnectorMessageDedupeRecord",
|
||||
"DedupeBeginResult",
|
||||
"MessageDedupeStore",
|
||||
"ExternalConnectorBase",
|
||||
"FeishuConnector",
|
||||
"WeixinConnector",
|
||||
"ConnectorSidecarClient",
|
||||
"ChannelConnection",
|
||||
"ChannelRuntimeSpec",
|
||||
"PairingSession",
|
||||
"ValidationResult",
|
||||
"ChannelConnectionStore",
|
||||
"CredentialStore",
|
||||
"PairingTokenStore",
|
||||
"TelegramConnector",
|
||||
]
|
||||
@ -0,0 +1,93 @@
|
||||
"""Channel connector registry."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Protocol
|
||||
|
||||
from beaver.foundation.config.schema import ChannelConfig
|
||||
|
||||
from .models import ChannelRuntimeSpec, ValidationResult
|
||||
from .store import ChannelConnectionStore, CredentialStore
|
||||
|
||||
|
||||
class ChannelConnector(Protocol):
    """Structural interface a per-kind connector exposes to the registry."""

    # Key under which the connector is registered.
    kind: str

    async def validate(self, connection_id: str) -> ValidationResult:
        """Check the stored connection and report its status and error."""

    async def materialize_runtime(self, connection_id: str) -> ChannelRuntimeSpec:
        """Produce the runtime spec for the stored connection."""

    async def revoke(self, connection_id: str) -> None:
        """Perform connector-side revocation for the stored connection."""
|
||||
|
||||
|
||||
class ChannelConnectorRegistry:
    """Maps connector kinds to connector objects and runs them against stored connections."""

    def __init__(self, *, connection_store: ChannelConnectionStore, credential_store: CredentialStore) -> None:
        """Keep references to both stores and start with no connectors registered."""
        self.connection_store = connection_store
        self.credential_store = credential_store
        self._connectors: dict[str, ChannelConnector] = {}

    def register(self, connector: ChannelConnector) -> None:
        """Register ``connector`` under its trimmed ``kind``.

        Raises:
            ValueError: if the kind is blank or already registered.
        """
        kind = connector.kind.strip()
        if not kind:
            raise ValueError("Connector kind is required")
        if kind in self._connectors:
            raise ValueError(f"Connector already registered: {kind}")
        self._connectors[kind] = connector

    def connectors(self) -> list[dict[str, str]]:
        """Return ``{"kind": ...}`` entries for every registered kind, sorted by kind."""
        return [{"kind": kind} for kind in sorted(self._connectors)]

    def connector_for_kind(self, kind: str) -> ChannelConnector:
        """Return the connector registered for ``kind``; raises ``KeyError`` if none."""
        return self._connector(kind)

    async def validate(self, connection_id: str) -> ValidationResult:
        """Validate a stored connection and persist the resulting status and error."""
        connection = self.connection_store.get(connection_id)
        connector = self._connector(connection.kind)
        result = await connector.validate(connection_id)
        self.connection_store.update_status(
            connection_id,
            status=result.status,
            last_error=result.error,
        )
        return result

    async def materialize_runtime(self, connection_id: str) -> ChannelRuntimeSpec:
        """Build the runtime spec for one stored connection via its connector."""
        connection = self.connection_store.get(connection_id)
        return await self._connector(connection.kind).materialize_runtime(connection_id)

    async def materialize_connected_runtime_specs(self) -> list[ChannelRuntimeSpec]:
        """Build runtime specs for every connection whose status is ``connected`` or ``running``."""
        specs: list[ChannelRuntimeSpec] = []
        for connection in self.connection_store.list():
            if connection.status not in {"connected", "running"}:
                continue
            specs.append(await self._connector(connection.kind).materialize_runtime(connection.connection_id))
        return specs

    async def materialize_channel_configs(self) -> dict[str, ChannelConfig]:
        """Turn the active runtime specs into enabled ``ChannelConfig`` entries keyed by channel id.

        Secrets are read from the credential store only when the spec carries
        a ``secrets_ref``; otherwise an empty dict is used.
        """
        channels: dict[str, ChannelConfig] = {}
        for spec in await self.materialize_connected_runtime_specs():
            secrets = self.credential_store.get(spec.secrets_ref) if spec.secrets_ref else {}
            channels[spec.channel_id] = ChannelConfig(
                enabled=True,
                kind=spec.kind,
                mode=spec.mode,
                account_id=spec.account_id,
                display_name=spec.display_name,
                config=dict(spec.config),
                secrets=secrets,
            )
        return channels

    async def revoke(self, connection_id: str) -> None:
        """Run the connector's revoke, then mark the connection revoked in the store."""
        connection = self.connection_store.get(connection_id)
        await self._connector(connection.kind).revoke(connection_id)
        self.connection_store.revoke(connection_id)

    def _connector(self, kind: str) -> ChannelConnector:
        """Look up the connector for ``kind``.

        register() stores connectors under the trimmed kind, so the lookup
        trims as well; otherwise a kind carrying stray whitespace could be
        registered but never found.

        Raises:
            KeyError: if no connector is registered for ``kind``.
        """
        lookup = kind.strip() if isinstance(kind, str) else kind
        connector = self._connectors.get(lookup)
        if connector is None:
            raise KeyError(f"Connector not registered: {kind}")
        return connector
|
||||
@ -0,0 +1,144 @@
|
||||
"""Bridge event dedupe store for external connector retries."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from dataclasses import asdict, dataclass
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from threading import Lock
|
||||
from typing import Any
|
||||
|
||||
|
||||
def _iso_now() -> str:
|
||||
return datetime.now(timezone.utc).isoformat()
|
||||
|
||||
|
||||
def _parse_iso(value: str) -> datetime:
|
||||
return datetime.fromisoformat(value.replace("Z", "+00:00"))
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ConnectorMessageDedupeRecord:
|
||||
dedupe_key: str
|
||||
connection_id: str
|
||||
event_id: str
|
||||
status: str
|
||||
first_seen_at: str
|
||||
updated_at: str
|
||||
delivery_attempts: int
|
||||
message_id: str | None = None
|
||||
last_error: str | None = None
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return asdict(self)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "ConnectorMessageDedupeRecord":
|
||||
return cls(
|
||||
dedupe_key=str(data.get("dedupe_key") or ""),
|
||||
connection_id=str(data.get("connection_id") or ""),
|
||||
event_id=str(data.get("event_id") or ""),
|
||||
status=str(data.get("status") or "processing"),
|
||||
first_seen_at=str(data.get("first_seen_at") or _iso_now()),
|
||||
updated_at=str(data.get("updated_at") or _iso_now()),
|
||||
delivery_attempts=int(data.get("delivery_attempts") or 0),
|
||||
message_id=str(data["message_id"]) if data.get("message_id") is not None else None,
|
||||
last_error=str(data["last_error"]) if data.get("last_error") is not None else None,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class DedupeBeginResult:
|
||||
should_process: bool
|
||||
dedupe_key: str
|
||||
status: str
|
||||
http_status: int
|
||||
retry_after_seconds: int | None
|
||||
record: ConnectorMessageDedupeRecord
|
||||
|
||||
|
||||
class MessageDedupeStore:
    """JSON-file-backed ledger used to deduplicate connector event deliveries.

    Records are keyed by ``"<connection_id>:<event_id>"``. All read-modify-write
    cycles run under an in-process lock.
    """

    def __init__(self, path: Path, *, processing_ttl_seconds: int = 60) -> None:
        """Remember the storage path and the age after which "processing" is stale."""
        self.path = Path(path)
        self.processing_ttl_seconds = int(processing_ttl_seconds)
        self._lock = Lock()

    def begin(self, *, connection_id: str, event_id: str, delivery_attempt: int) -> DedupeBeginResult:
        """Claim an event for processing, or report why it must not be processed.

        Returns a result with ``should_process=False`` when the event is already
        completed (HTTP 200) or is still being processed and not yet stale
        (HTTP 409, retry after 5 seconds). Otherwise the record is created or
        reset to "processing" and ``should_process=True`` is returned.
        """
        dedupe_key = f"{connection_id}:{event_id}"
        now = _iso_now()
        with self._lock:
            data = self._load()
            records = data["records"]
            stored = records.get(dedupe_key)
            if not isinstance(stored, dict):
                # First time this event is seen.
                record = ConnectorMessageDedupeRecord(
                    dedupe_key=dedupe_key,
                    connection_id=connection_id,
                    event_id=event_id,
                    status="processing",
                    first_seen_at=now,
                    updated_at=now,
                    delivery_attempts=max(1, int(delivery_attempt)),
                )
            else:
                record = ConnectorMessageDedupeRecord.from_dict(stored)
                if record.status == "completed":
                    return DedupeBeginResult(False, dedupe_key, record.status, 200, None, record)
                in_flight = record.status == "processing" and not self._is_stale(record, now)
                if in_flight:
                    return DedupeBeginResult(False, dedupe_key, record.status, 409, 5, record)
                # Failed or stale: take the event over for another attempt.
                record.status = "processing"
                record.updated_at = now
                record.delivery_attempts = max(record.delivery_attempts + 1, int(delivery_attempt))
                record.last_error = None
            records[dedupe_key] = record.to_dict()
            self._save(data)
            return DedupeBeginResult(True, dedupe_key, record.status, 200, None, record)

    def complete(self, dedupe_key: str, *, message_id: str | None) -> ConnectorMessageDedupeRecord:
        """Mark the event as completed, optionally recording the produced message id."""
        return self._mark(dedupe_key, status="completed", message_id=message_id, error=None)

    def fail(self, dedupe_key: str, *, error: str) -> ConnectorMessageDedupeRecord:
        """Mark the event as failed and store the error text."""
        return self._mark(dedupe_key, status="failed", message_id=None, error=error)

    def _mark(
        self,
        dedupe_key: str,
        *,
        status: str,
        message_id: str | None,
        error: str | None,
    ) -> ConnectorMessageDedupeRecord:
        """Update an existing record's status; raise ``KeyError`` if it is unknown."""
        with self._lock:
            data = self._load()
            stored = data["records"].get(dedupe_key)
            if not isinstance(stored, dict):
                raise KeyError(dedupe_key)
            record = ConnectorMessageDedupeRecord.from_dict(stored)
            record.status = status
            record.updated_at = _iso_now()
            # Keep a previously stored message id when none is supplied.
            record.message_id = message_id or record.message_id
            record.last_error = error
            data["records"][dedupe_key] = record.to_dict()
            self._save(data)
            return record

    def _is_stale(self, record: ConnectorMessageDedupeRecord, now: str) -> bool:
        """Return True when the record was last updated at least the TTL ago."""
        elapsed = _parse_iso(now) - _parse_iso(record.updated_at)
        return elapsed.total_seconds() >= self.processing_ttl_seconds

    def _load(self) -> dict[str, Any]:
        """Read the ledger; a missing, unreadable or malformed file yields an empty ledger."""
        if not self.path.exists():
            return {"records": {}}
        try:
            text = self.path.read_text(encoding="utf-8")
            data = json.loads(text)
        except (OSError, json.JSONDecodeError):
            return {"records": {}}
        well_formed = isinstance(data, dict) and isinstance(data.get("records"), dict)
        return data if well_formed else {"records": {}}

    def _save(self, data: dict[str, Any]) -> None:
        """Write the ledger to a sibling ``.tmp`` file, then rename it over the target."""
        self.path.parent.mkdir(parents=True, exist_ok=True)
        serialized = json.dumps(data, ensure_ascii=False, indent=2) + "\n"
        tmp_path = self.path.with_name(f"{self.path.name}.tmp")
        tmp_path.write_text(serialized, encoding="utf-8")
        tmp_path.replace(self.path)
|
||||
@ -0,0 +1,210 @@
|
||||
"""Sidecar-backed channel connectors."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from .models import ChannelRuntimeSpec, ValidationResult
|
||||
from .sidecar_client import ConnectorSidecarClient
|
||||
from .store import ChannelConnectionStore, CredentialStore
|
||||
|
||||
# Option keys that _policy_runtime_config() copies into a connection's runtime config.
POLICY_CONFIG_KEYS = {
    # list-valued options
    "allowFrom",
    "groupAllowFrom",
    # boolean options
    "requireMentionInGroups",
    "respondToMentionAll",
    # free-text option
    "dmMode",
    # positive-integer options
    "maxMessageChars",
    "textBatchDelayMs",
    "textBatchMaxMessages",
    "textBatchMaxChars",
}
|
||||
|
||||
|
||||
class ExternalConnectorBase:
    """Shared logic for connectors whose pairing sessions run in a connector sidecar.

    Subclasses provide ``kind`` and ``capabilities``. The connector keeps the
    local connection record in step with the session views returned by the
    sidecar client.
    """

    kind = ""
    # Never mutated in place; start_session() copies it before storing.
    capabilities: list[str] = []

    def __init__(
        self,
        *,
        connection_store: ChannelConnectionStore,
        credential_store: CredentialStore,
        sidecar_client: ConnectorSidecarClient | Any,
        sidecar_base_url: str,
    ) -> None:
        """Store collaborators and resolve the callback base URL from the environment."""
        self.connection_store = connection_store
        self.credential_store = credential_store
        self.sidecar_client = sidecar_client
        self.sidecar_base_url = sidecar_base_url
        self.callback_base_url = _callback_base_url()

    async def start_session(
        self,
        *,
        display_name: str,
        owner_user_id: str | None,
        options: dict[str, Any],
    ) -> dict[str, Any]:
        """Create a connection record and start a pairing session in the sidecar.

        Returns the sidecar's session view with ``connectionId`` and
        ``channelId`` added.

        Raises:
            Exception: whatever the sidecar client raises. The connection is
                marked ``"error"`` first so it is not left in ``"pairing"``.
        """
        runtime_config = {"sidecarBaseUrl": self.sidecar_base_url}
        runtime_config.update(_policy_runtime_config(options))
        connection = self.connection_store.create(
            kind=self.kind,
            mode="sidecar",
            display_name=display_name or self.kind,
            account_id="",
            owner_user_id=owner_user_id,
            auth_type="connector_session",
            runtime_config=runtime_config,
            capabilities=list(self.capabilities),
        )
        connection = self.connection_store.update_status(connection.connection_id, status="pairing", last_error=None)
        payload = {
            "kind": self.kind,
            "connectionId": connection.connection_id,
            "channelId": connection.channel_id,
            "displayName": connection.display_name,
            "callbackBaseUrl": self.callback_base_url,
            "options": dict(options),
        }
        try:
            view = dict(await self.sidecar_client.start_session(payload))
        except Exception as exc:
            # The sidecar never accepted the session, so nothing will ever move
            # this record out of "pairing"; record the failure before re-raising.
            self.connection_store.update_status(
                connection.connection_id,
                status="error",
                last_error=str(exc) or type(exc).__name__,
            )
            raise
        connection.pairing_session_id = str(view.get("sessionId") or "")
        connection = self.connection_store.update(connection)
        connection = self._apply_session_view(connection, view)
        view["connectionId"] = connection.connection_id
        view["channelId"] = connection.channel_id
        return view

    async def poll_session(self, session_id: str) -> dict[str, Any]:
        """Fetch the sidecar's view of a session and apply it to the matching connection.

        Raises:
            KeyError: if no stored connection references ``session_id``.
        """
        view = dict(await self.sidecar_client.get_session(session_id))
        connection = self._connection_for_session(session_id)
        connection = self._apply_session_view(connection, view)
        view["connectionId"] = connection.connection_id
        view["channelId"] = connection.channel_id
        return view

    def _apply_session_view(self, connection: Any, view: dict[str, Any]) -> Any:
        """Persist the effects of a session view and return the stored connection.

        A ``"connected"`` view updates account id, display name and (when the
        view's metadata carries a ``stateRef``) the credential reference. An
        ``"expired"``, ``"error"`` or ``"cancelled"`` view marks the connection
        as ``"error"``. Any other status leaves the record untouched.
        """
        status = str(view.get("status") or "")
        if status == "connected":
            connection.account_id = str(view.get("accountId") or connection.account_id)
            connection.display_name = str(view.get("displayName") or connection.display_name)
            metadata = view.get("metadata") if isinstance(view.get("metadata"), dict) else {}
            state_ref = metadata.get("stateRef")
            if state_ref:
                self._store_state_ref(connection, state_ref)
            self.connection_store.update(connection)
            self.connection_store.update_status(connection.connection_id, status="connected", last_error=None)
        elif status in {"expired", "error", "cancelled"}:
            self.connection_store.update_status(
                connection.connection_id,
                status="error",
                last_error=str(view.get("error") or status),
            )
        return self.connection_store.get(connection.connection_id)

    def _store_state_ref(self, connection: Any, state_ref: Any) -> None:
        """Point the connection at a credential holding ``state_ref``.

        The existing credential is reused when it already holds the same
        value, so repeated polls of a connected session do not pile up
        duplicate credential entries.
        """
        current_ref = connection.credentials_ref
        if current_ref:
            try:
                stored = self.credential_store.get(current_ref)
            except KeyError:
                # The referenced credential is gone; fall through and store a new one.
                stored = {}
            if stored.get("stateRef") == str(state_ref):
                return
        connection.credentials_ref = self.credential_store.put(kind=self.kind, values={"stateRef": state_ref})

    async def validate(self, connection_id: str) -> ValidationResult:
        """Report whether the stored connection is in a connected or running state."""
        connection = self.connection_store.get(connection_id)
        if connection.status in {"connected", "running"}:
            return ValidationResult(
                ok=True,
                status="connected",
                account_id=connection.account_id,
                display_name=connection.display_name,
            )
        return ValidationResult(ok=False, status=connection.status, error=connection.last_error)

    async def materialize_runtime(self, connection_id: str) -> ChannelRuntimeSpec:
        """Build the runtime spec for a connected connection.

        Raises:
            ValueError: if the connection is neither connected nor running.
        """
        connection = self.connection_store.get(connection_id)
        if connection.status not in {"connected", "running"}:
            raise ValueError(f"Connection is not connected: {connection.connection_id}")
        return ChannelRuntimeSpec(
            channel_id=connection.channel_id,
            kind="external_connector",
            mode="http",
            account_id=connection.account_id,
            display_name=connection.display_name,
            config={
                "platformKind": self.kind,
                "connectionId": connection.connection_id,
                **dict(connection.runtime_config),
                # Listed last so an empty stored value falls back to this connector's URL.
                "sidecarBaseUrl": connection.runtime_config.get("sidecarBaseUrl") or self.sidecar_base_url,
            },
            secrets_ref=None,
        )

    async def revoke(self, connection_id: str) -> None:
        """Ask the sidecar to log the connection out."""
        await self.sidecar_client.logout(connection_id)

    def _connection_for_session(self, session_id: str):
        """Return the stored connection whose pairing session id matches, else raise ``KeyError``."""
        for connection in self.connection_store.list():
            if connection.pairing_session_id == session_id:
                return connection
        raise KeyError(session_id)
|
||||
|
||||
|
||||
class WeixinConnector(ExternalConnectorBase):
    """Sidecar-backed connector registered under the ``"weixin"`` kind."""

    kind = "weixin"
    capabilities = [
        "receive_text",
        "send_text",
        "receive_media",
        "direct_messages",
    ]
|
||||
|
||||
|
||||
class FeishuConnector(ExternalConnectorBase):
    """Sidecar-backed connector registered under the ``"feishu"`` kind."""

    kind = "feishu"
    capabilities = [
        "receive_text",
        "send_text",
        "receive_media",
        "groups",
    ]
|
||||
|
||||
|
||||
def _policy_runtime_config(options: dict[str, Any]) -> dict[str, Any]:
    """Extract and normalize the policy options that belong in the runtime config.

    Only keys in ``POLICY_CONFIG_KEYS`` are considered. List and text options
    are dropped when empty, integer options when not a positive integer;
    boolean options are always kept once present.
    """
    list_keys = {"allowFrom", "groupAllowFrom"}
    int_keys = {"maxMessageChars", "textBatchDelayMs", "textBatchMaxMessages", "textBatchMaxChars"}
    flag_keys = {"requireMentionInGroups", "respondToMentionAll"}

    config: dict[str, Any] = {}
    for key in POLICY_CONFIG_KEYS:
        if key not in options:
            continue
        raw = options[key]
        if key in list_keys:
            normalized = _string_list(raw) or None
        elif key in int_keys:
            normalized = _positive_int(raw)
        elif key in flag_keys:
            normalized = _bool(raw)
        else:
            normalized = str(raw or "").strip() or None
        # None means "nothing usable"; False is a real value and is kept.
        if normalized is not None:
            config[key] = normalized
    return config
|
||||
|
||||
|
||||
def _callback_base_url() -> str:
    """Return the first configured callback base URL, without trailing slashes.

    ``EXTERNAL_CONNECTOR_CALLBACK_BASE_URL`` takes precedence over
    ``BEAVER_CONNECTOR_CALLBACK_BASE_URL``; an empty string is returned when
    neither variable holds a non-blank value.
    """
    variable_names = ("EXTERNAL_CONNECTOR_CALLBACK_BASE_URL", "BEAVER_CONNECTOR_CALLBACK_BASE_URL")
    configured = (os.environ.get(name, "").strip() for name in variable_names)
    return next((url.rstrip("/") for url in configured if url), "")
|
||||
|
||||
|
||||
def _string_list(value: Any) -> list[str]:
    """Normalize a comma/newline-separated string or a list into trimmed, non-empty strings.

    Any other input type yields an empty list.
    """
    if isinstance(value, str):
        candidates = value.replace("\n", ",").split(",")
    elif isinstance(value, list):
        candidates = value
    else:
        return []
    trimmed = (str(candidate).strip() for candidate in candidates)
    return [text for text in trimmed if text]
|
||||
|
||||
|
||||
def _positive_int(value: Any) -> int | None:
    """Convert ``value`` to a strictly positive ``int``, or return ``None``.

    ``None`` is returned when the value cannot be converted (wrong type,
    unparsable text, NaN, infinity) or when the result is zero or negative.
    """
    try:
        number = int(value)
    except (TypeError, ValueError, OverflowError):
        # OverflowError: int(float("inf")) — reachable because json.loads accepts "Infinity".
        return None
    return number if number > 0 else None
|
||||
|
||||
|
||||
def _bool(value: Any) -> bool:
    """Interpret ``value`` as a flag.

    Real booleans pass through; anything else is true only when its trimmed,
    lower-cased text is one of ``1``, ``true``, ``yes`` or ``on``.
    """
    if not isinstance(value, bool):
        truthy_words = {"1", "true", "yes", "on"}
        return str(value).strip().lower() in truthy_words
    return value
|
||||
@ -0,0 +1,117 @@
|
||||
"""Channel connection setup models."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
|
||||
# Every status value a channel connection may carry.
CONNECTION_STATUSES = {
    "draft",
    "pairing",
    "connected",
    "running",
    "degraded",
    "error",
    "revoked",
}
|
||||
|
||||
|
||||
def iso_now() -> str:
    """Return the current UTC time as an ISO-8601 string carrying its offset."""
    moment = datetime.now(tz=timezone.utc)
    return moment.isoformat()
|
||||
|
||||
|
||||
@dataclass(slots=True)
class ChannelConnection:
    """Stored description of one channel connection and its current state."""

    connection_id: str
    owner_user_id: str | None
    channel_id: str
    kind: str
    mode: str
    display_name: str
    account_id: str
    status: str
    auth_type: str
    credentials_ref: str | None = None
    connector_ref: str | None = None
    pairing_session_id: str | None = None
    runtime_config: dict[str, Any] = field(default_factory=dict)
    capabilities: list[str] = field(default_factory=list)
    created_at: str = field(default_factory=iso_now)
    updated_at: str = field(default_factory=iso_now)
    last_seen_at: str | None = None
    last_error: str | None = None

    def to_dict(self) -> dict[str, Any]:
        """Return the connection as a plain dict suitable for JSON storage."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "ChannelConnection":
        """Build a connection from stored data, filling defaults for missing fields.

        Required text fields default to ``""`` (status to ``"draft"``),
        optional ones to ``None``, and missing timestamps to the current time.
        """
        lookup = data.get

        def text(key: str) -> str:
            return str(lookup(key) or "")

        return cls(
            connection_id=text("connection_id"),
            owner_user_id=_optional_string(lookup("owner_user_id")),
            channel_id=text("channel_id"),
            kind=text("kind"),
            mode=text("mode"),
            display_name=text("display_name"),
            account_id=text("account_id"),
            status=str(lookup("status") or "draft"),
            auth_type=text("auth_type"),
            credentials_ref=_optional_string(lookup("credentials_ref")),
            connector_ref=_optional_string(lookup("connector_ref")),
            pairing_session_id=_optional_string(lookup("pairing_session_id")),
            runtime_config=dict(lookup("runtime_config") or {}),
            capabilities=[str(entry) for entry in lookup("capabilities") or []],
            created_at=str(lookup("created_at") or iso_now()),
            updated_at=str(lookup("updated_at") or iso_now()),
            last_seen_at=_optional_string(lookup("last_seen_at")),
            last_error=_optional_string(lookup("last_error")),
        )
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class PairingSession:
|
||||
pairing_session_id: str
|
||||
kind: str
|
||||
scope: str
|
||||
token: str
|
||||
status: str
|
||||
expires_at_ms: int
|
||||
created_at_ms: int
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return asdict(self)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "PairingSession":
|
||||
return cls(
|
||||
pairing_session_id=str(data.get("pairing_session_id") or ""),
|
||||
kind=str(data.get("kind") or ""),
|
||||
scope=str(data.get("scope") or ""),
|
||||
token=str(data.get("token") or ""),
|
||||
status=str(data.get("status") or "pending"),
|
||||
expires_at_ms=int(data.get("expires_at_ms") or 0),
|
||||
created_at_ms=int(data.get("created_at_ms") or 0),
|
||||
)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ChannelRuntimeSpec:
|
||||
channel_id: str
|
||||
kind: str
|
||||
mode: str
|
||||
account_id: str
|
||||
display_name: str
|
||||
config: dict[str, Any] = field(default_factory=dict)
|
||||
secrets_ref: str | None = None
|
||||
external_endpoint: str | None = None
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ValidationResult:
|
||||
ok: bool
|
||||
status: str
|
||||
account_id: str | None = None
|
||||
display_name: str | None = None
|
||||
error: str | None = None
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
def _optional_string(value: Any) -> str | None:
    """Return the trimmed text of ``value``, or ``None`` when it is ``None`` or blank."""
    if value is None:
        return None
    trimmed = str(value).strip()
    return trimmed if trimmed else None
|
||||
@ -0,0 +1,39 @@
|
||||
"""HTTP client for the generic external connector sidecar."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
|
||||
class ConnectorSidecarClient:
    """Thin async HTTP client for the connector sidecar's JSON API.

    Every call opens a short-lived ``httpx.AsyncClient``, sends a bearer token
    when one is configured, raises on non-success HTTP status and returns the
    decoded JSON body.
    """

    def __init__(self, *, base_url: str, token: str, timeout_seconds: float = 20.0) -> None:
        """Store the sidecar base URL (trailing slashes removed), token and timeout."""
        self.base_url = base_url.rstrip("/")
        self.token = token
        self.timeout_seconds = float(timeout_seconds)

    async def get_connectors(self) -> list[dict[str, Any]]:
        """GET ``/connectors``."""
        return await self._request("GET", "/connectors")

    async def start_session(self, payload: dict[str, Any]) -> dict[str, Any]:
        """POST ``payload`` to ``/connector-sessions``."""
        return await self._request("POST", "/connector-sessions", json=payload)

    async def get_session(self, session_id: str) -> dict[str, Any]:
        """GET the session view for ``session_id``."""
        return await self._request("GET", f"/connector-sessions/{self._segment(session_id)}")

    async def cancel_session(self, session_id: str) -> dict[str, Any]:
        """POST a cancel request for ``session_id``."""
        return await self._request("POST", f"/connector-sessions/{self._segment(session_id)}/cancel", json={})

    async def logout(self, connection_id: str) -> dict[str, Any]:
        """POST a logout request for ``connection_id``."""
        return await self._request("POST", f"/connections/{self._segment(connection_id)}/logout", json={})

    async def send(self, payload: dict[str, Any]) -> dict[str, Any]:
        """POST an outbound message ``payload`` to ``/send``."""
        return await self._request("POST", "/send", json=payload)

    @staticmethod
    def _segment(value: str) -> str:
        """Percent-encode ``value`` so it stays a single URL path segment.

        Identifiers may originate from callers; without escaping, a value
        containing ``/``, ``?`` or ``#`` would change the requested route.
        """
        from urllib.parse import quote

        return quote(str(value), safe="")

    async def _request(self, method: str, path: str, *, json: dict[str, Any] | None = None) -> Any:
        """Send one request to the sidecar and return the decoded JSON response.

        Raises:
            httpx.HTTPStatusError: when ``raise_for_status()`` rejects the response.
        """
        headers = {"Authorization": f"Bearer {self.token}"} if self.token else {}
        async with httpx.AsyncClient(timeout=self.timeout_seconds) as client:
            response = await client.request(method, f"{self.base_url}{path}", json=json, headers=headers)
            response.raise_for_status()
            return response.json()
|
||||
@ -0,0 +1,222 @@
|
||||
"""Persistent channel connection stores."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import time
|
||||
from pathlib import Path
|
||||
from threading import Lock
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from .models import CONNECTION_STATUSES, ChannelConnection, PairingSession, iso_now
|
||||
|
||||
|
||||
class ChannelConnectionStore:
    """JSON-file-backed store of channel connections, keyed by connection id.

    Mutations run under an in-process lock and are written via a temporary
    file that is renamed over the target.
    """

    def __init__(self, path: Path) -> None:
        """Remember the storage path; the file is created on first write."""
        self.path = Path(path)
        self._lock = Lock()

    def create(
        self,
        *,
        kind: str,
        mode: str,
        display_name: str,
        account_id: str,
        owner_user_id: str | None,
        auth_type: str,
        runtime_config: dict[str, Any] | None = None,
        capabilities: list[str] | None = None,
        credentials_ref: str | None = None,
    ) -> ChannelConnection:
        """Create and persist a new connection in ``"draft"`` status.

        The connection id is ``conn_<uuid hex>`` and the channel id is the
        slugged ``kind`` plus eight random hex characters.
        """
        with self._lock:
            data = self._load()
            connection_id = f"conn_{uuid4().hex}"
            channel_id = f"{_slug(kind)}-{uuid4().hex[:8]}"
            now = iso_now()
            connection = ChannelConnection(
                connection_id=connection_id,
                owner_user_id=owner_user_id,
                channel_id=channel_id,
                kind=kind,
                mode=mode,
                display_name=display_name or channel_id,
                account_id=account_id,
                status="draft",
                auth_type=auth_type,
                credentials_ref=credentials_ref,
                runtime_config=runtime_config or {},
                capabilities=capabilities or [],
                created_at=now,
                updated_at=now,
            )
            data["connections"][connection_id] = connection.to_dict()
            self._save(data)
            return connection

    def get(self, connection_id: str) -> ChannelConnection:
        """Return the stored connection; raise ``KeyError`` if it does not exist."""
        data = self._load()
        raw = data["connections"].get(connection_id)
        if not isinstance(raw, dict):
            raise KeyError(connection_id)
        return ChannelConnection.from_dict(raw)

    def list(self) -> list[ChannelConnection]:
        """Return every stored connection, skipping entries that are not dicts."""
        data = self._load()
        return [ChannelConnection.from_dict(item) for item in data["connections"].values() if isinstance(item, dict)]

    def update(self, connection: ChannelConnection) -> ChannelConnection:
        """Overwrite the stored connection and refresh its ``updated_at``.

        Raises:
            KeyError: if the connection id is not in the store.
        """
        with self._lock:
            data = self._load()
            if connection.connection_id not in data["connections"]:
                raise KeyError(connection.connection_id)
            connection.updated_at = iso_now()
            data["connections"][connection.connection_id] = connection.to_dict()
            self._save(data)
            return connection

    def update_status(self, connection_id: str, *, status: str, last_error: str | None) -> ChannelConnection:
        """Set a connection's status and last error in one atomic step.

        Raises:
            ValueError: if ``status`` is not in ``CONNECTION_STATUSES``.
            KeyError: if the connection id is not in the store.
        """
        if status not in CONNECTION_STATUSES:
            raise ValueError(f"Unsupported connection status: {status}")
        # Read, modify and write under one lock acquisition so that a concurrent
        # update() or update_status() cannot be overwritten by a stale copy.
        with self._lock:
            data = self._load()
            raw = data["connections"].get(connection_id)
            if not isinstance(raw, dict):
                raise KeyError(connection_id)
            connection = ChannelConnection.from_dict(raw)
            connection.status = status
            connection.last_error = last_error
            if status in {"connected", "running"}:
                connection.last_seen_at = iso_now()
            connection.updated_at = iso_now()
            data["connections"][connection_id] = connection.to_dict()
            self._save(data)
            return connection

    def revoke(self, connection_id: str) -> ChannelConnection:
        """Mark the connection as ``"revoked"`` and clear its last error."""
        return self.update_status(connection_id, status="revoked", last_error=None)

    def _load(self) -> dict[str, Any]:
        """Read the store; a missing, unreadable or malformed file yields an empty store."""
        if not self.path.exists():
            return {"connections": {}}
        try:
            data = json.loads(self.path.read_text(encoding="utf-8"))
        except (OSError, json.JSONDecodeError):
            return {"connections": {}}
        if not isinstance(data, dict) or not isinstance(data.get("connections"), dict):
            return {"connections": {}}
        return data

    def _save(self, data: dict[str, Any]) -> None:
        """Write the store to a sibling ``.tmp`` file, then rename it over the target."""
        self.path.parent.mkdir(parents=True, exist_ok=True)
        tmp_path = self.path.with_name(f"{self.path.name}.tmp")
        tmp_path.write_text(json.dumps(data, ensure_ascii=False, indent=2) + "\n", encoding="utf-8")
        tmp_path.replace(self.path)
|
||||
|
||||
|
||||
class CredentialStore:
    """JSON-file-backed store of credential value sets, addressed by opaque refs."""

    def __init__(self, path: Path) -> None:
        """Remember the storage path; the file is created on first write."""
        self.path = Path(path)
        self._lock = Lock()

    def put(self, *, kind: str, values: dict[str, Any]) -> str:
        """Store ``values`` under a new ``cred_<uuid hex>`` ref and return the ref.

        Keys and values are stored as strings. Entries with a blank key, a
        blank value or a ``None`` value are dropped; ``None`` is checked
        explicitly because ``str(None)`` would otherwise be stored as the
        literal text ``"None"``.
        """
        cleaned = {
            str(key): str(value)
            for key, value in values.items()
            if value is not None and str(key).strip() and str(value).strip()
        }
        ref = f"cred_{uuid4().hex}"
        with self._lock:
            data = self._load()
            data["credentials"][ref] = {"kind": kind, "values": cleaned, "created_at": iso_now()}
            self._save(data)
        return ref

    def get(self, ref: str) -> dict[str, str]:
        """Return the stored values for ``ref``.

        Raises:
            KeyError: if ``ref`` is unknown.
        """
        data = self._load()
        item = data["credentials"].get(ref)
        if not isinstance(item, dict):
            raise KeyError(ref)
        values = item.get("values")
        if not isinstance(values, dict):
            return {}
        return {str(key): str(value) for key, value in values.items()}

    def redacted(self, ref: str | None) -> dict[str, str]:
        """Return the stored keys for ``ref`` with every value masked as ``"***"``.

        An empty or unknown ref yields an empty dict.
        """
        if not ref:
            return {}
        try:
            values = self.get(ref)
        except KeyError:
            return {}
        return {key: "***" for key in values}

    def _load(self) -> dict[str, Any]:
        """Read the store; a missing, unreadable or malformed file yields an empty store."""
        if not self.path.exists():
            return {"credentials": {}}
        try:
            data = json.loads(self.path.read_text(encoding="utf-8"))
        except (OSError, json.JSONDecodeError):
            return {"credentials": {}}
        if not isinstance(data, dict) or not isinstance(data.get("credentials"), dict):
            return {"credentials": {}}
        return data

    def _save(self, data: dict[str, Any]) -> None:
        """Write the store to a sibling ``.tmp`` file, then rename it over the target."""
        self.path.parent.mkdir(parents=True, exist_ok=True)
        tmp_path = self.path.with_name(f"{self.path.name}.tmp")
        tmp_path.write_text(json.dumps(data, ensure_ascii=False, indent=2) + "\n", encoding="utf-8")
        tmp_path.replace(self.path)
|
||||
|
||||
|
||||
class PairingTokenStore:
    """JSON-file-backed store of single-use pairing tokens."""

    def __init__(self, path: Path) -> None:
        """Remember the storage path; the file is created on first write."""
        self.path = Path(path)
        self._lock = Lock()

    def create(self, *, kind: str, ttl_seconds: int, scope: str) -> PairingSession:
        """Create and persist a pending pairing session valid for ``ttl_seconds``."""
        now_ms = _now_ms()
        session = PairingSession(
            pairing_session_id=f"pair_{uuid4().hex}",
            kind=kind,
            scope=scope,
            token=f"pair_{uuid4().hex}",
            status="pending",
            expires_at_ms=now_ms + int(ttl_seconds * 1000),
            created_at_ms=now_ms,
        )
        with self._lock:
            data = self._load()
            data["sessions"][session.pairing_session_id] = session.to_dict()
            self._save(data)
        return session

    def consume(self, token: str, *, expected_kind: str) -> PairingSession | None:
        """Mark the matching pending session as consumed and return it.

        Returns ``None`` when no session matches ``token`` and
        ``expected_kind``, or when the matching session is no longer pending
        or has expired.
        """
        with self._lock:
            data = self._load()
            for key, raw in data["sessions"].items():
                if not isinstance(raw, dict):
                    # Skip malformed entries, as the other stores do, so one
                    # bad entry cannot make every consume() call fail.
                    continue
                session = PairingSession.from_dict(raw)
                if session.token != token or session.kind != expected_kind:
                    continue
                if session.status != "pending" or session.expires_at_ms <= _now_ms():
                    return None
                session.status = "consumed"
                data["sessions"][key] = session.to_dict()
                self._save(data)
                return session
        return None

    def _load(self) -> dict[str, Any]:
        """Read the store; a missing, unreadable or malformed file yields an empty store."""
        if not self.path.exists():
            return {"sessions": {}}
        try:
            data = json.loads(self.path.read_text(encoding="utf-8"))
        except (OSError, json.JSONDecodeError):
            return {"sessions": {}}
        if not isinstance(data, dict) or not isinstance(data.get("sessions"), dict):
            return {"sessions": {}}
        return data

    def _save(self, data: dict[str, Any]) -> None:
        """Write the store to a sibling ``.tmp`` file, then rename it over the target."""
        self.path.parent.mkdir(parents=True, exist_ok=True)
        tmp_path = self.path.with_name(f"{self.path.name}.tmp")
        tmp_path.write_text(json.dumps(data, ensure_ascii=False, indent=2) + "\n", encoding="utf-8")
        tmp_path.replace(self.path)
|
||||
|
||||
|
||||
def _now_ms() -> int:
    """Return the current Unix time in whole milliseconds."""
    seconds = time.time()
    return int(seconds * 1000)
|
||||
|
||||
|
||||
def _slug(value: str) -> str:
    """Lower-case ``value`` and join its alphanumeric runs with single dashes.

    Returns ``"channel"`` when the input has no alphanumeric characters.
    """
    lowered = str(value).strip().lower()
    dashed = "".join(char if char.isalnum() else "-" for char in lowered)
    pieces = [piece for piece in dashed.split("-") if piece]
    return "-".join(pieces) if pieces else "channel"
|
||||
@ -0,0 +1,92 @@
|
||||
"""Telegram channel connector."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from .models import ChannelRuntimeSpec, ValidationResult
|
||||
from .store import ChannelConnectionStore, CredentialStore
|
||||
|
||||
|
||||
class TelegramConnector:
    """Connector that validates Telegram bot tokens and builds runtime specs."""

    kind = "telegram"

    def __init__(
        self,
        *,
        connection_store: ChannelConnectionStore,
        credential_store: CredentialStore,
        client_factory: Callable[[str], Any] | None = None,
    ) -> None:
        """Store collaborators; ``client_factory`` builds a bot client from a token."""
        self.connection_store = connection_store
        self.credential_store = credential_store
        self.client_factory = client_factory or _default_client_factory

    async def validate(self, connection_id: str) -> ValidationResult:
        """Check the stored bot token by calling ``get_me()`` on a client.

        On success the connection's account id, display name and capabilities
        are updated and persisted. Any exception raised while talking to the
        client or saving is reported as a failed result; a missing connection
        or token raises before that point.
        """
        connection = self.connection_store.get(connection_id)
        token = self._bot_token(connection.credentials_ref)
        try:
            bot = self.client_factory(token)
            me = await bot.get_me()
            bot_id = _value(me, "id")
            username = _value(me, "username")
            first_name = _value(me, "first_name") or "Telegram Bot"

            if bot_id:
                account_id = f"telegram:{bot_id}"
            else:
                account_id = connection.account_id
            if username:
                display_name = f"{first_name} (@{username})"
                metadata = {"username": username}
            else:
                display_name = first_name
                metadata = {}

            connection.account_id = account_id
            connection.display_name = display_name
            connection.capabilities = ["receive_text", "send_text", "receive_media", "groups"]
            self.connection_store.update(connection)
            return ValidationResult(
                ok=True,
                status="connected",
                account_id=account_id,
                display_name=display_name,
                metadata=metadata,
            )
        except Exception as exc:
            return ValidationResult(ok=False, status="error", error=str(exc))

    async def materialize_runtime(self, connection_id: str) -> ChannelRuntimeSpec:
        """Build the runtime spec for a connected connection.

        Raises:
            ValueError: if the connection is neither connected nor running.
        """
        connection = self.connection_store.get(connection_id)
        usable = connection.status in {"connected", "running"}
        if not usable:
            raise ValueError(f"Connection is not connected: {connection.connection_id}")
        return ChannelRuntimeSpec(
            channel_id=connection.channel_id,
            kind=connection.kind,
            mode=connection.mode,
            account_id=connection.account_id,
            display_name=connection.display_name,
            config=dict(connection.runtime_config),
            secrets_ref=connection.credentials_ref,
        )

    async def revoke(self, connection_id: str) -> None:
        """Do nothing: there is no platform-side revoke for Telegram bot tokens."""
        # Telegram bot tokens do not have a Beaver-managed platform revoke action.
        # The registry owns local connection state transitions.
        return None

    def _bot_token(self, credentials_ref: str | None) -> str:
        """Return the ``botToken`` stored under ``credentials_ref``.

        Raises:
            ValueError: if the ref is empty or the stored values hold no token.
        """
        if not credentials_ref:
            raise ValueError("Telegram credentials are missing")
        stored = self.credential_store.get(credentials_ref)
        token = stored.get("botToken")
        if token:
            return token
        raise ValueError("botToken is required")
|
||||
|
||||
|
||||
def _value(raw: Any, key: str) -> str:
    """Read ``key`` from a dict or an object attribute and return it as trimmed text.

    A missing or ``None`` value yields an empty string.
    """
    found = raw.get(key) if isinstance(raw, dict) else getattr(raw, key, None)
    if found is None:
        return ""
    return str(found).strip()
|
||||
|
||||
|
||||
def _default_client_factory(token: str) -> Any:
    """Build a ``telegram.Bot`` for ``token``.

    Raises:
        RuntimeError: if the optional ``telegram`` package is not installed.
    """
    try:
        from telegram import Bot
    except ImportError as exc:  # pragma: no cover - optional live dependency
        raise RuntimeError("Install beaver-backend[telegram] to validate Telegram connections") from exc
    else:
        return Bot(token=token)
|
||||
@ -0,0 +1,97 @@
|
||||
"""Generic runtime channel backed by an external connector sidecar."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
from typing import Any
|
||||
|
||||
from beaver.foundation.events import OutboundMessage
|
||||
from beaver.interfaces.channels.connections.sidecar_client import ConnectorSidecarClient
|
||||
|
||||
|
||||
class ExternalConnectorChannel:
    """Runtime channel that delivers outbound messages through a connector sidecar."""

    def __init__(
        self,
        *,
        channel_id: str,
        platform_kind: str,
        connection_id: str,
        account_id: str,
        display_name: str,
        sidecar_client: ConnectorSidecarClient | Any,
    ) -> None:
        """Store channel identity and the sidecar client; the channel starts stopped."""
        # Fixed channel type markers.
        self.kind = "external_connector"
        self.mode = "http"
        self.channel_id = channel_id
        self.platform_kind = platform_kind
        self.connection_id = connection_id
        self.account_id = account_id
        self.display_name = display_name or channel_id
        self.sidecar_client = sidecar_client
        self.started = False

    async def start(self) -> None:
        """Mark the channel as started."""
        self.started = True

    async def stop(self) -> None:
        """Mark the channel as stopped."""
        self.started = False

    async def send(self, message: OutboundMessage) -> None:
        """Forward ``message`` to the sidecar's send endpoint.

        Raises:
            ValueError: if the message carries no channel identity.
        """
        identity = message.channel_identity
        if identity is None:
            raise ValueError("channel_identity is required for external connector sends")

        metadata = {"inboundMessageId": identity.message_id, "sessionId": message.session_id}
        context_token = _context_token(message)
        if context_token:
            metadata["contextToken"] = context_token

        target = {
            "peerId": identity.peer_id,
            "peerType": identity.peer_type,
            "threadId": identity.thread_id,
        }
        await self.sidecar_client.send(
            {
                "requestId": _request_id(message),
                "connectionId": self.connection_id,
                "channelId": self.channel_id,
                "kind": self.platform_kind,
                "target": target,
                "content": message.content,
                "metadata": metadata,
            }
        )
|
||||
|
||||
|
||||
def _request_id(message: OutboundMessage) -> str:
    """Build the ``out_<channel>:<session>:<message>`` request id for an outbound message.

    Channel and session fall back to the channel identity, then to
    ``"unknown"``. When the message has no id of its own, a 24-character
    SHA-256 prefix over content, inbound message id, peer id and finish
    reason is used instead.
    """
    identity = message.channel_identity
    channel = message.channel or (identity.channel_id if identity else "unknown")
    session_id = message.session_id or (identity.session_id() if identity else "unknown")
    message_id = str(message.message_id or "").strip()
    if not message_id:
        parts = (
            message.content,
            identity.message_id if identity and identity.message_id else "",
            identity.peer_id if identity else "",
            message.finish_reason,
        )
        # Coerce each part to text: str.join raises TypeError on None or any
        # non-string value. String inputs hash exactly as before.
        basis = "|".join("" if part is None else str(part) for part in parts)
        message_id = hashlib.sha256(basis.encode("utf-8")).hexdigest()[:24]
    return f"out_{channel}:{session_id}:{message_id}"
|
||||
|
||||
|
||||
def _context_token(message: OutboundMessage) -> str | None:
    """Find the context token carried in an outbound message's metadata.

    ``metadata["inbound_metadata"]`` is consulted first when it is a dict;
    the top-level metadata is the fallback. In both places the key
    ``contextToken`` is preferred over ``context_token``.

    Returns:
        The cleaned token, or ``None`` when neither location provides one.
    """

    def _lookup(source: Any) -> str | None:
        return _clean_optional(source.get("contextToken") or source.get("context_token"))

    metadata = message.metadata
    nested = metadata.get("inbound_metadata")
    if isinstance(nested, dict):
        nested_token = _lookup(nested)
        if nested_token:
            return nested_token
    return _lookup(metadata)
|
||||
|
||||
|
||||
def _clean_optional(value: Any) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
text = str(value).strip()
|
||||
return text or None
|
||||
@ -0,0 +1,116 @@
|
||||
"""Generic fixed-schema text webhook channel adapter."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
from beaver.foundation.events import ChannelIdentity, InboundMessage, OutboundMessage
|
||||
from beaver.interfaces.channels.base import ChannelInboundSink
|
||||
|
||||
|
||||
class GenericWebhookAdapter:
    """Request/response adapter for a fixed-schema text webhook.

    ``handle_webhook_payload`` turns a webhook body into an ``InboundMessage``,
    hands it to the inbound sink and then waits (up to
    ``response_timeout_seconds``) for ``send`` to be called with the outbound
    message whose ``message_id`` matches the inbound one. The waiting handlers
    are tracked in ``_pending`` as futures keyed by inbound message id.
    """

    def __init__(
        self,
        *,
        channel_id: str,
        kind: str,
        mode: str,
        account_id: str,
        display_name: str = "",
        inbound_sink: ChannelInboundSink,
        response_timeout_seconds: float = 1800,
    ) -> None:
        """Store the channel settings.

        ``display_name`` falls back to ``channel_id`` when empty, and
        ``response_timeout_seconds`` is clamped to at least one second.
        """
        self.channel_id = channel_id
        self.kind = kind
        self.mode = mode
        self.account_id = account_id
        self.display_name = display_name or channel_id
        self.inbound_sink = inbound_sink
        self.response_timeout_seconds = max(1.0, float(response_timeout_seconds))
        self.started = False
        self._pending: dict[str, asyncio.Future[OutboundMessage]] = {}

    async def start(self) -> None:
        """Mark the adapter as started."""
        self.started = True

    async def stop(self) -> None:
        """Mark the adapter as stopped and cancel every still-pending reply future.

        Handlers currently awaiting one of those futures observe the
        cancellation as ``asyncio.CancelledError``.
        """
        self.started = False
        for future in list(self._pending.values()):
            if not future.done():
                future.cancel()
        self._pending.clear()

    def _discard_pending(self, key: str, future: asyncio.Future[OutboundMessage]) -> None:
        """Remove ``key`` from ``_pending`` only if it still maps to ``future``."""
        if self._pending.get(key) is future:
            del self._pending[key]

    async def handle_webhook_payload(self, payload: dict[str, Any]) -> dict[str, Any]:
        """Process one webhook body and return the JSON-able response dict.

        ``text``, ``peer_id`` and ``message_id`` are required; a missing one
        yields ``{"ok": False, "error": ...}`` without touching the sink.
        ``thread_id``, ``peer_type`` (default ``"unknown"``) and ``user_id``
        are optional.

        Returns:
            * the sink's verdict (duplicate / pending / recorded reply) when
              the message is not accepted for processing,
            * ``{"ok": True, "pending": True, ...}`` when no reply arrives
              within the timeout,
            * otherwise the reply content together with run and session ids.
        """
        text = str(payload.get("text") or "").strip()
        peer_id = str(payload.get("peer_id") or "").strip()
        message_id = str(payload.get("message_id") or "").strip()
        thread_id = str(payload.get("thread_id") or "").strip() or None
        peer_type = str(payload.get("peer_type") or "unknown").strip() or "unknown"
        user_id = str(payload.get("user_id") or "").strip() or None
        if not text:
            return {"ok": False, "error": "text is required"}
        if not peer_id:
            return {"ok": False, "error": "peer_id is required"}
        if not message_id:
            return {"ok": False, "error": "message_id is required"}

        identity = ChannelIdentity(
            channel_id=self.channel_id,
            kind=self.kind,
            account_id=self.account_id,
            peer_id=peer_id,
            thread_id=thread_id,
            peer_type=peer_type,
            user_id=user_id,
            message_id=message_id,
        )
        inbound = InboundMessage(
            channel=self.channel_id,
            content=text,
            user_id=user_id,
            channel_identity=identity,
            metadata={"webhook": {"peer_type": peer_type}},
        )
        pending_key = inbound.message_id
        future = asyncio.get_running_loop().create_future()
        # Register before handing over to the sink so a reply that arrives
        # while accept_inbound is still running can already be claimed.
        self._pending[pending_key] = future
        try:
            accept = await self.inbound_sink.accept_inbound(inbound)
            if not accept.accepted:
                record = accept.record or {}
                return {
                    "ok": accept.error is None,
                    "duplicate": accept.duplicate,
                    "pending": accept.pending,
                    "session_id": accept.session_id,
                    "status": record.get("status"),
                    "run_id": record.get("run_id"),
                    "reply": record.get("reply"),
                    "error": accept.error or record.get("error"),
                }
            try:
                outbound = await asyncio.wait_for(future, timeout=self.response_timeout_seconds)
            except asyncio.TimeoutError:
                return {
                    "ok": True,
                    "duplicate": False,
                    "pending": True,
                    "session_id": accept.session_id,
                }
        finally:
            # Also runs when accept_inbound raises or this handler is cancelled,
            # so the future never lingers in _pending. On the success path
            # send() has already removed the entry and this is a no-op.
            self._discard_pending(pending_key, future)
        return {
            "ok": outbound.finish_reason != "error",
            "duplicate": False,
            "pending": False,
            "session_id": outbound.session_id,
            "run_id": outbound.run_id,
            "reply": outbound.content,
            "error": outbound.metadata.get("error"),
        }

    async def send(self, message: OutboundMessage) -> None:
        """Resolve the handler waiting for ``message.message_id``.

        When nobody is waiting (unknown id, or the future already finished,
        e.g. after a timeout) the message is tagged with
        ``metadata["delivery_status"] = "unclaimed"`` instead.
        """
        future = self._pending.pop(message.message_id, None)
        if future is None or future.done():
            message.metadata["delivery_status"] = "unclaimed"
            return
        future.set_result(message)
|
||||
@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Awaitable, Callable
|
||||
from contextlib import suppress
|
||||
|
||||
from beaver.foundation.events import MessageBus, OutboundMessage
|
||||
@ -20,13 +21,17 @@ class ChannelManager:
|
||||
self.started = False
|
||||
|
||||
def register(self, channel: ChannelAdapter) -> None:
|
||||
if self.started:
|
||||
raise RuntimeError("Cannot register channels after ChannelManager.start()")
|
||||
if channel.name in self.channels:
|
||||
raise ValueError(f"Channel already registered: {channel.name}")
|
||||
if channel.bus is not self.bus:
|
||||
raise ValueError("Channel must share the same MessageBus as ChannelManager")
|
||||
self.channels[channel.name] = channel
|
||||
if channel.channel_id in self.channels:
|
||||
raise ValueError(f"Channel already registered: {channel.channel_id}")
|
||||
self.channels[channel.channel_id] = channel
|
||||
|
||||
def unregister(self, channel_id: str) -> ChannelAdapter | None:
    """Drop the adapter stored under ``channel_id`` and return it (``None`` if absent)."""
    removed = self.channels.pop(channel_id, None)
    return removed
|
||||
|
||||
def replace_registered(self, channel: ChannelAdapter) -> ChannelAdapter | None:
    """Store ``channel`` under its ``channel_id``, overwriting any existing entry.

    Returns:
        The adapter previously stored under that id, or ``None``.
    """
    key = channel.channel_id
    previous = self.channels.get(key)
    self.channels[key] = channel
    return previous
|
||||
|
||||
async def start(self) -> None:
|
||||
started: list[ChannelAdapter] = []
|
||||
@ -53,7 +58,13 @@ class ChannelManager:
|
||||
if errors:
|
||||
raise RuntimeError(f"Failed to stop {len(errors)} channel(s)") from errors[0]
|
||||
|
||||
async def dispatch_outbound(self, stop_event: asyncio.Event) -> None:
|
||||
async def dispatch_outbound(
|
||||
self,
|
||||
stop_event: asyncio.Event,
|
||||
*,
|
||||
on_delivered: Callable[[OutboundMessage], Awaitable[None]] | None = None,
|
||||
on_failed: Callable[[OutboundMessage, Exception | None], Awaitable[None]] | None = None,
|
||||
) -> None:
|
||||
"""Route bus outbound messages until stopped and the queue is drained."""
|
||||
|
||||
while True:
|
||||
@ -68,9 +79,16 @@ class ChannelManager:
|
||||
channel = self.channels.get(message.channel)
|
||||
if channel is None:
|
||||
self.undeliverable.append(message)
|
||||
if on_failed is not None:
|
||||
await on_failed(message, None)
|
||||
continue
|
||||
|
||||
try:
|
||||
await channel.send(message)
|
||||
except Exception: # pragma: no cover - defensive channel isolation
|
||||
except Exception as exc: # pragma: no cover - defensive channel isolation
|
||||
self.undeliverable.append(message)
|
||||
if on_failed is not None:
|
||||
await on_failed(message, exc)
|
||||
else:
|
||||
if on_delivered is not None:
|
||||
await on_delivered(message)
|
||||
|
||||
@ -4,15 +4,27 @@ from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from beaver.foundation.events import InboundMessage, MessageBus, OutboundMessage
|
||||
from beaver.foundation.events import ChannelIdentity, InboundMessage, OutboundMessage
|
||||
from beaver.interfaces.channels.base import ChannelInboundSink
|
||||
|
||||
|
||||
class MemoryChannelAdapter:
|
||||
"""A local channel that stores outbound messages in memory."""
|
||||
|
||||
def __init__(self, bus: MessageBus, *, name: str = "memory") -> None:
|
||||
self.name = name
|
||||
self.bus = bus
|
||||
def __init__(
|
||||
self,
|
||||
inbound_sink: ChannelInboundSink,
|
||||
*,
|
||||
channel_id: str = "memory-dev",
|
||||
kind: str = "memory",
|
||||
mode: str = "webhook",
|
||||
account_id: str = "memory",
|
||||
) -> None:
|
||||
self.channel_id = channel_id
|
||||
self.kind = kind
|
||||
self.mode = mode
|
||||
self.account_id = account_id
|
||||
self.inbound_sink = inbound_sink
|
||||
self.started = False
|
||||
self.sent_messages: list[OutboundMessage] = []
|
||||
|
||||
@ -36,12 +48,24 @@ class MemoryChannelAdapter:
|
||||
model: str | None = None,
|
||||
provider_name: str | None = None,
|
||||
embedding_model: str | None = None,
|
||||
peer_id: str = "default",
|
||||
thread_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> InboundMessage:
|
||||
"""Publish a text message from this channel into the shared bus."""
|
||||
|
||||
identity = ChannelIdentity(
|
||||
channel_id=self.channel_id,
|
||||
kind=self.kind,
|
||||
account_id=self.account_id,
|
||||
peer_id=peer_id,
|
||||
thread_id=thread_id,
|
||||
user_id=user_id,
|
||||
message_id=message_id,
|
||||
)
|
||||
message = InboundMessage(
|
||||
channel=self.name,
|
||||
channel=self.channel_id,
|
||||
content=content,
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
@ -50,9 +74,10 @@ class MemoryChannelAdapter:
|
||||
model=model,
|
||||
provider_name=provider_name,
|
||||
embedding_model=embedding_model,
|
||||
channel_identity=identity,
|
||||
metadata=metadata or {},
|
||||
)
|
||||
await self.bus.publish_inbound(message)
|
||||
await self.inbound_sink.accept_inbound(message)
|
||||
return message
|
||||
|
||||
async def publish_external_text(
|
||||
@ -73,9 +98,6 @@ class MemoryChannelAdapter:
|
||||
the shared gateway bus.
|
||||
"""
|
||||
|
||||
session_parts = [self.name, chat_id]
|
||||
if thread_id:
|
||||
session_parts.append(thread_id)
|
||||
metadata = {
|
||||
"chat_id": chat_id,
|
||||
"message_id": message_id,
|
||||
@ -84,8 +106,10 @@ class MemoryChannelAdapter:
|
||||
}
|
||||
return await self.publish_text(
|
||||
content,
|
||||
session_id=":".join(str(part) for part in session_parts if str(part)),
|
||||
user_id=user_id,
|
||||
title=title,
|
||||
peer_id=chat_id,
|
||||
thread_id=thread_id,
|
||||
message_id=message_id,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
@ -0,0 +1 @@
|
||||
"""Platform channel adapters."""
|
||||
@ -0,0 +1,138 @@
|
||||
"""Shared helpers for platform channel adapters."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from beaver.foundation.events import ChannelIdentity, InboundMessage, OutboundMessage
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class OutboundTarget:
|
||||
peer_id: str | None
|
||||
thread_id: str | None = None
|
||||
peer_type: str = "unknown"
|
||||
user_id: str | None = None
|
||||
|
||||
|
||||
class PlatformDeliveryError(RuntimeError):
    """Signals that a platform client refused to deliver a message."""
|
||||
|
||||
|
||||
def config_bool(config: dict[str, Any], key: str, *, default: bool = False) -> bool:
    """Read ``config[key]`` as a boolean.

    Booleans and numbers are converted with ``bool()``. Any other value is
    compared, stripped and case-insensitively, against the usual textual
    spellings ("1"/"true"/"yes"/"on" and "0"/"false"/"no"/"off"). A missing
    key, ``None`` or unrecognised text yields ``default``.
    """
    raw = config.get(key)
    if raw is None:
        return default
    if isinstance(raw, (bool, int, float)):
        return bool(raw)
    spellings = {
        "1": True,
        "true": True,
        "yes": True,
        "on": True,
        "0": False,
        "false": False,
        "no": False,
        "off": False,
    }
    return spellings.get(str(raw).strip().lower(), default)
|
||||
|
||||
|
||||
def config_list(config: dict[str, Any], key: str) -> list[str]:
    """Read ``config[key]`` as a list of non-empty, stripped strings.

    A string is split on commas; a list, tuple or set is stringified item by
    item; any other non-``None`` value becomes a single-element list. Blank
    entries are dropped, and a missing key or ``None`` gives ``[]``.
    """
    raw = config.get(key)
    if raw is None:
        return []
    if isinstance(raw, str):
        candidates = raw.split(",")
    elif isinstance(raw, (list, tuple, set)):
        candidates = [str(item) for item in raw]
    else:
        candidates = [str(raw)]
    stripped = (candidate.strip() for candidate in candidates)
    return [entry for entry in stripped if entry]
|
||||
|
||||
|
||||
def chunk_text(text: str, *, max_chars: int) -> list[str]:
    """Cut ``text`` into consecutive pieces of at most ``max_chars`` characters.

    Empty text produces a single empty chunk (``[""]``), so callers always
    receive at least one piece.

    Raises:
        ValueError: if ``max_chars`` is zero or negative.
    """
    if max_chars <= 0:
        raise ValueError("max_chars must be positive")
    if not text:
        return [""]
    pieces: list[str] = []
    offset = 0
    total = len(text)
    while offset < total:
        pieces.append(text[offset : offset + max_chars])
        offset += max_chars
    return pieces
|
||||
|
||||
|
||||
def compact_media_summary(media_type: str, *, file_name: str | None = None) -> str:
    """Render a short bracketed placeholder for a media attachment.

    Produces ``[<type>: <file_name>]`` when a file name is given and
    ``[<type>]`` otherwise; a blank media type is shown as ``attachment``.
    """
    label = str(media_type or "attachment").strip() or "attachment"
    suffix = f": {file_name}" if file_name else ""
    return f"[{label}{suffix}]"
|
||||
|
||||
|
||||
def target_from_session_id(session_id: str | None) -> OutboundTarget:
    """Derive a delivery target from a colon-separated session id.

    Only positions 2 (peer id) and 3 (thread id, optional) of the split id
    are read. Presumably the first two segments identify channel and account;
    that is not visible here, so confirm against ``ChannelIdentity.session_id``.

    Returns:
        A target with ``peer_id=None`` when the id is empty or has fewer than
        three segments; otherwise the peer and optional thread taken from it.
    """
    segments = str(session_id).split(":") if session_id else []
    if len(segments) < 3:
        return OutboundTarget(peer_id=None)
    peer = segments[2] or None
    thread = (segments[3] or None) if len(segments) > 3 else None
    return OutboundTarget(peer_id=peer, thread_id=thread)
|
||||
|
||||
|
||||
def outbound_target(message: OutboundMessage) -> OutboundTarget:
    """Work out where ``message`` should be delivered.

    The message's ``channel_identity`` is authoritative when present; without
    one, the target is reconstructed from ``message.session_id``.
    """
    identity = message.channel_identity
    if identity is not None:
        return OutboundTarget(
            peer_id=identity.peer_id,
            thread_id=identity.thread_id,
            peer_type=identity.peer_type,
            user_id=identity.user_id,
        )
    return target_from_session_id(message.session_id)
|
||||
|
||||
|
||||
def mark_unclaimed(message: OutboundMessage) -> None:
    """Tag ``message`` in place with ``metadata["delivery_status"] = "unclaimed"``."""
    message.metadata["delivery_status"] = "unclaimed"
|
||||
|
||||
|
||||
def build_inbound_message(
    *,
    channel_id: str,
    kind: str,
    account_id: str,
    peer_id: str,
    content: str,
    message_id: str | None,
    peer_type: str,
    user_id: str | None = None,
    thread_id: str | None = None,
    metadata: dict[str, Any] | None = None,
) -> InboundMessage:
    """Assemble an ``InboundMessage`` together with its ``ChannelIdentity``.

    The session id is taken from ``identity.session_id()``. A missing
    ``message_id`` is passed to the message as ``""`` and missing
    ``metadata`` as an empty dict; the identity receives ``message_id``
    unchanged.
    """
    identity = ChannelIdentity(
        channel_id=channel_id,
        kind=kind,
        account_id=account_id,
        peer_id=peer_id,
        thread_id=thread_id,
        peer_type=peer_type,
        user_id=user_id,
        message_id=message_id,
    )
    session_id = identity.session_id()
    inbound = InboundMessage(
        channel=channel_id,
        content=content,
        session_id=session_id,
        user_id=user_id,
        message_id=message_id or "",
        channel_identity=identity,
        metadata=metadata or {},
    )
    return inbound
|
||||
|
||||
|
||||
def allowed_by_policy(
    *,
    policy: str | None,
    identifier: str | None,
    allowlist: list[str],
    default: str = "open",
) -> bool:
    """Decide whether ``identifier`` may pass under an access policy.

    The policy (or ``default`` when unset) is stripped and lower-cased, then:

    * ``"disabled"`` rejects everyone,
    * ``"allowlist"`` admits only identifiers present in ``allowlist``,
    * any other value, including unrecognised ones, admits everyone.
    """
    mode = (policy or default).strip().lower()
    if mode == "disabled":
        return False
    if mode != "allowlist":
        return True
    return bool(identifier and identifier in allowlist)
|
||||
@ -0,0 +1,207 @@
|
||||
"""Feishu/Lark channel adapter."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from beaver.foundation.events import InboundMessage, OutboundMessage
|
||||
from beaver.interfaces.channels.base import ChannelInboundSink
|
||||
|
||||
from .base import (
|
||||
build_inbound_message,
|
||||
chunk_text,
|
||||
compact_media_summary,
|
||||
config_bool,
|
||||
config_list,
|
||||
mark_unclaimed,
|
||||
outbound_target,
|
||||
)
|
||||
|
||||
EventRecorder = Callable[..., None]
|
||||
|
||||
|
||||
class FeishuAdapter:
    """Adapter between Feishu/Lark bot events and the inbound sink.

    A ready-made ``client`` may be injected. Without one the adapter tries to
    build a client on demand, which in this implementation always raises (see
    ``_build_client``).
    """

    KIND = "feishu"

    def __init__(
        self,
        *,
        channel_id: str,
        kind: str,
        mode: str,
        account_id: str,
        display_name: str | None,
        inbound_sink: ChannelInboundSink,
        secrets: dict[str, Any] | None = None,
        config: dict[str, Any] | None = None,
        event_recorder: EventRecorder | None = None,
        client: Any | None = None,
    ) -> None:
        """Store settings; ``maxMessageChars`` in config caps chunk size (default 4096)."""
        self.channel_id = channel_id
        self.kind = kind
        self.mode = mode
        self.account_id = account_id
        self.display_name = display_name
        self.inbound_sink = inbound_sink
        self.secrets = secrets or {}
        self.config = config or {}
        self.event_recorder = event_recorder
        self._client = client
        self.max_message_chars = int(self.config.get("maxMessageChars") or 4096)

    async def start(self) -> None:
        """Ensure a client exists; a no-op when one was injected.

        Raises:
            ValueError: if ``mode`` is neither ``"websocket"`` nor ``"webhook"``.
        """
        if self._client is None:
            if self.mode not in {"websocket", "webhook"}:
                raise ValueError(f"Unsupported feishu mode: {self.mode}")
            self._client = self._build_client()

    async def stop(self) -> None:
        """Call the client's ``close()`` if it has one, awaiting an awaitable result."""
        close = getattr(self._client, "close", None)
        if close is None:
            return
        outcome = close()
        if hasattr(outcome, "__await__"):
            await outcome

    async def handle_event_payload(self, payload: dict[str, Any]) -> None:
        """Normalise an event payload and pass it to the sink unless it was filtered out."""
        inbound = self._normalize_payload(payload)
        if inbound is not None:
            await self.inbound_sink.accept_inbound(inbound)

    async def send(self, message: OutboundMessage) -> None:
        """Send ``message.content`` to the target chat in size-limited chunks.

        A message without a resolvable peer is marked unclaimed and not sent.
        """
        target = outbound_target(message)
        if not target.peer_id:
            mark_unclaimed(message)
            return
        client = self._require_client()
        for part in chunk_text(message.content, max_chars=self.max_message_chars):
            await client.send_text(receive_id_type="chat_id", receive_id=target.peer_id, text=part)

    def _normalize_payload(self, payload: dict[str, Any]) -> InboundMessage | None:
        """Convert a raw event into an ``InboundMessage``.

        Returns ``None`` when the event has no chat id, is rejected by the
        DM/group allowlists or the mention requirement, or has no content.
        """
        raw_event = payload.get("event")
        event = raw_event if isinstance(raw_event, dict) else payload
        raw_message = event.get("message")
        message = raw_message if isinstance(raw_message, dict) else {}
        raw_sender = event.get("sender")
        sender = raw_sender if isinstance(raw_sender, dict) else {}

        peer_id = _string_or_none(message.get("chat_id"))
        if not peer_id:
            return None

        message_id = _string_or_none(message.get("message_id"))
        message_type = str(message.get("message_type") or "unknown")
        chat_type = str(message.get("chat_type") or "unknown")
        # Everything that is not a p2p chat is treated as a group.
        peer_type = "dm" if chat_type == "p2p" else "group"
        user_id = _sender_open_id(sender)

        if peer_type == "dm":
            if not self._dm_allowed(user_id or peer_id):
                return None
        else:
            if not self._group_allowed(peer_id, user_id):
                return None
            mention_required = config_bool(self.config, "requireMentionInGroups", default=False)
            if mention_required and not self._message_mentions_bot(message):
                return None

        content = self._message_content(message_type, message)
        if not content:
            return None

        return build_inbound_message(
            channel_id=self.channel_id,
            kind=self.kind,
            account_id=self.account_id,
            peer_id=peer_id,
            peer_type=peer_type,
            user_id=user_id,
            message_id=message_id,
            content=content,
            metadata={
                "chat_id": peer_id,
                "message_id": message_id,
                "chat_type": chat_type,
                "message_type": message_type,
            },
        )

    def _message_content(self, message_type: str, message: dict[str, Any]) -> str:
        """Return the text of a text message, or a bracketed summary for other types."""
        body = _parse_json_object(message.get("content"))
        if message_type == "text":
            return str(body.get("text") or "").strip()
        attachment_name = _string_or_none(body.get("file_name") or body.get("name"))
        return compact_media_summary(message_type, file_name=attachment_name)

    def _message_mentions_bot(self, message: dict[str, Any]) -> bool:
        """Tell whether any mention's ``id.open_id`` equals the configured ``botOpenId``.

        Returns ``False`` when ``botOpenId`` is not configured or the message
        has no mentions list.
        """
        bot_open_id = _string_or_none(self.config.get("botOpenId"))
        if not bot_open_id:
            return False
        mentions = message.get("mentions")
        if not isinstance(mentions, list):
            return False
        return any(
            _string_or_none(mention["id"].get("open_id")) == bot_open_id
            for mention in mentions
            if isinstance(mention, dict) and isinstance(mention.get("id"), dict)
        )

    def _dm_allowed(self, identifier: str | None) -> bool:
        """Admit a DM sender; an empty ``allowFrom`` list admits everyone."""
        allowlist = config_list(self.config, "allowFrom")
        return not allowlist or bool(identifier and identifier in allowlist)

    def _group_allowed(self, peer_id: str | None, user_id: str | None) -> bool:
        """Admit a group message if the group or the sender is in ``groupAllowFrom``.

        An empty list admits everyone.
        """
        allowlist = config_list(self.config, "groupAllowFrom")
        if not allowlist:
            return True
        group_listed = peer_id and peer_id in allowlist
        sender_listed = user_id and user_id in allowlist
        return bool(group_listed or sender_listed)

    def _require_client(self) -> Any:
        """Return the client, building it first when none exists yet."""
        if self._client is None:
            self._client = self._build_client()
        return self._client

    def _build_client(self) -> Any:
        """Validate prerequisites for a live client, then refuse to build one.

        Raises:
            ValueError: if the ``appId`` or ``appSecret`` secret is missing.
            RuntimeError: if ``lark_oapi`` is not installed, and otherwise
                unconditionally, because direct construction is not configured.
        """
        for secret_name in ("appId", "appSecret"):
            self._require_secret(secret_name)
        try:
            import lark_oapi  # noqa: F401
        except ImportError as exc:  # pragma: no cover - optional live dependency
            raise RuntimeError("Install beaver-backend[feishu] to enable FeishuAdapter") from exc
        raise RuntimeError("Feishu live client is not configured for direct construction")

    def _require_secret(self, key: str) -> str:
        """Return secret ``key`` as a string, raising ``ValueError`` when missing or empty."""
        secret = self.secrets.get(key)
        if secret:
            return str(secret)
        raise ValueError(f"{key} is required")
|
||||
|
||||
|
||||
def _parse_json_object(value: Any) -> dict[str, Any]:
|
||||
if isinstance(value, dict):
|
||||
return value
|
||||
if not isinstance(value, str):
|
||||
return {}
|
||||
try:
|
||||
parsed = json.loads(value)
|
||||
except json.JSONDecodeError:
|
||||
return {}
|
||||
return parsed if isinstance(parsed, dict) else {}
|
||||
|
||||
|
||||
def _sender_open_id(sender: dict[str, Any]) -> str | None:
    """Extract ``sender["sender_id"]["open_id"]`` as a cleaned string, or ``None``."""
    raw_ids = sender.get("sender_id")
    ids = raw_ids if isinstance(raw_ids, dict) else {}
    return _string_or_none(ids.get("open_id"))
|
||||
|
||||
|
||||
def _string_or_none(value: Any) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
text = str(value).strip()
|
||||
return text or None
|
||||
@ -0,0 +1,206 @@
|
||||
"""QQ Bot channel adapter."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from beaver.foundation.events import InboundMessage, OutboundMessage
|
||||
from beaver.interfaces.channels.base import ChannelInboundSink
|
||||
|
||||
from .base import (
|
||||
allowed_by_policy,
|
||||
build_inbound_message,
|
||||
chunk_text,
|
||||
compact_media_summary,
|
||||
config_list,
|
||||
mark_unclaimed,
|
||||
outbound_target,
|
||||
)
|
||||
|
||||
EventRecorder = Callable[..., None]
|
||||
|
||||
|
||||
class QQBotAdapter:
    """Adapter between QQ Bot events and the inbound sink.

    A ready-made ``client`` may be injected. Without one the adapter tries to
    build a client on demand, which in this implementation always raises (see
    ``_build_client``).
    """

    KIND = "qqbot"

    def __init__(
        self,
        *,
        channel_id: str,
        kind: str,
        mode: str,
        account_id: str,
        display_name: str | None,
        inbound_sink: ChannelInboundSink,
        secrets: dict[str, Any] | None = None,
        config: dict[str, Any] | None = None,
        event_recorder: EventRecorder | None = None,
        client: Any | None = None,
    ) -> None:
        """Store settings; ``maxMessageChars`` in config caps chunk size (default 2000)."""
        self.channel_id = channel_id
        self.kind = kind
        self.mode = mode
        self.account_id = account_id
        self.display_name = display_name
        self.inbound_sink = inbound_sink
        self.secrets = secrets or {}
        self.config = config or {}
        self.event_recorder = event_recorder
        self._client = client
        self.max_message_chars = int(self.config.get("maxMessageChars") or 2000)

    async def start(self) -> None:
        """Ensure a client exists; a no-op when one was injected.

        Raises:
            ValueError: if ``mode`` is not ``"websocket"``.
        """
        if self._client is None:
            if self.mode != "websocket":
                raise ValueError(f"Unsupported qqbot mode: {self.mode}")
            self._client = self._build_client()

    async def stop(self) -> None:
        """Call the client's ``close()`` if it has one, awaiting an awaitable result."""
        close = getattr(self._client, "close", None)
        if close is None:
            return
        outcome = close()
        if hasattr(outcome, "__await__"):
            await outcome

    async def handle_event_payload(self, payload: dict[str, Any]) -> None:
        """Normalise an event payload and pass it to the sink unless it was filtered out."""
        inbound = self._normalize_payload(payload)
        if inbound is not None:
            await self.inbound_sink.accept_inbound(inbound)

    async def send(self, message: OutboundMessage) -> None:
        """Send ``message.content`` to the target peer in size-limited chunks.

        Every chunk carries the platform message id from the message's channel
        identity (``None`` without an identity). A message without a
        resolvable peer is marked unclaimed and not sent.
        """
        target = outbound_target(message)
        if not target.peer_id:
            mark_unclaimed(message)
            return
        client = self._require_client()
        identity = message.channel_identity
        platform_message_id = identity.message_id if identity else None
        for part in chunk_text(message.content, max_chars=self.max_message_chars):
            await client.send_text(
                peer_type=target.peer_type,
                peer_id=target.peer_id,
                content=part,
                message_id=platform_message_id,
            )

    def _normalize_payload(self, payload: dict[str, Any]) -> InboundMessage | None:
        """Convert a raw event into an ``InboundMessage``.

        Returns ``None`` when the event cannot be routed, is rejected by the
        DM/group policy, or has neither text nor attachments.
        """
        event_type = str(payload.get("t") or payload.get("type") or "")
        raw_data = payload.get("d")
        data = raw_data if isinstance(raw_data, dict) else payload
        raw_author = data.get("author")
        author = raw_author if isinstance(raw_author, dict) else {}

        route = self._route(event_type, data, author)
        if route is None:
            return None
        peer_id, peer_type, user_id, thread_id = route

        # Policy inputs per peer type; other peer types have no policy check.
        policy_inputs = {
            "dm": ("dmPolicy", "allowFrom", user_id or peer_id),
            "group": ("groupPolicy", "groupAllowFrom", peer_id),
        }.get(peer_type)
        if policy_inputs is not None:
            policy_key, allowlist_key, identifier = policy_inputs
            permitted = allowed_by_policy(
                policy=self.config.get(policy_key),
                identifier=identifier,
                allowlist=config_list(self.config, allowlist_key),
                default="open",
            )
            if not permitted:
                return None

        message_id = _string_or_none(data.get("id"))
        content = str(data.get("content") or "").strip()
        media_entries = self._media_entries(data)
        if media_entries:
            # Append one summary line per attachment below any text.
            content = "\n".join([part for part in [content, *media_entries] if part]).strip()
        if not content:
            return None

        metadata = {
            "event_type": event_type,
            "message_id": message_id,
            "peer_type": peer_type,
        }
        if media_entries:
            metadata["media"] = media_entries

        return build_inbound_message(
            channel_id=self.channel_id,
            kind=self.kind,
            account_id=self.account_id,
            peer_id=peer_id,
            thread_id=thread_id,
            peer_type=peer_type,
            user_id=user_id,
            message_id=message_id,
            content=content,
            metadata=metadata,
        )

    def _route(
        self,
        event_type: str,
        data: dict[str, Any],
        author: dict[str, Any],
    ) -> tuple[str, str, str | None, str | None] | None:
        """Map an event to ``(peer_id, peer_type, user_id, thread_id)``.

        * ``C2C_MESSAGE_CREATE``: a DM keyed by the author's ``user_openid``.
        * ``GROUP_AT_MESSAGE_CREATE``: a group keyed by ``group_openid``.
        * any other event carrying both ``guild_id`` and ``channel_id``: a
          channel keyed by ``channel_id``, with ``guild_id`` as the thread id.

        Returns ``None`` when the event matches none of these or lacks the
        required peer id.
        """
        if event_type == "C2C_MESSAGE_CREATE":
            user_openid = _string_or_none(author.get("user_openid"))
            return (user_openid, "dm", user_openid, None) if user_openid else None

        if event_type == "GROUP_AT_MESSAGE_CREATE":
            group_openid = _string_or_none(data.get("group_openid"))
            if not group_openid:
                return None
            return group_openid, "group", _string_or_none(author.get("member_openid")), None

        if not (data.get("guild_id") and data.get("channel_id")):
            return None
        guild_channel = _string_or_none(data.get("channel_id"))
        if not guild_channel:
            return None
        return (
            guild_channel,
            "channel",
            _string_or_none(author.get("id")),
            _string_or_none(data.get("guild_id")),
        )

    def _media_entries(self, data: dict[str, Any]) -> list[str]:
        """Summarise each dict entry of ``data["attachments"]`` as a bracketed label."""
        attachments = data.get("attachments")
        if not isinstance(attachments, list):
            return []
        return [
            compact_media_summary(
                str(attachment.get("content_type") or attachment.get("type") or "attachment"),
                file_name=_string_or_none(attachment.get("filename")),
            )
            for attachment in attachments
            if isinstance(attachment, dict)
        ]

    def _require_client(self) -> Any:
        """Return the client, building it first when none exists yet."""
        if self._client is None:
            self._client = self._build_client()
        return self._client

    def _build_client(self) -> Any:
        """Validate prerequisites for a live client, then refuse to build one.

        Raises:
            ValueError: if the ``appId`` or ``clientSecret`` secret is missing.
            RuntimeError: if ``aiohttp`` is not installed, and otherwise
                unconditionally, because direct construction is not configured.
        """
        for secret_name in ("appId", "clientSecret"):
            self._require_secret(secret_name)
        try:
            import aiohttp  # noqa: F401
        except ImportError as exc:  # pragma: no cover - optional live dependency
            raise RuntimeError("Install beaver-backend[qqbot] to enable QQBotAdapter") from exc
        raise RuntimeError("QQBot live client is not configured for direct construction")

    def _require_secret(self, key: str) -> str:
        """Return secret ``key`` as a string, raising ``ValueError`` when missing or empty."""
        secret = self.secrets.get(key)
        if secret:
            return str(secret)
        raise ValueError(f"{key} is required")
|
||||
|
||||
|
||||
def _string_or_none(value: Any) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
text = str(value).strip()
|
||||
return text or None
|
||||
@ -0,0 +1,244 @@
|
||||
"""Telegram channel adapter."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
from beaver.foundation.events import InboundMessage, OutboundMessage
|
||||
from beaver.interfaces.channels.base import ChannelInboundSink
|
||||
|
||||
from .base import (
|
||||
build_inbound_message,
|
||||
chunk_text,
|
||||
compact_media_summary,
|
||||
config_bool,
|
||||
config_list,
|
||||
mark_unclaimed,
|
||||
outbound_target,
|
||||
)
|
||||
|
||||
EventRecorder = Callable[..., None]
|
||||
|
||||
|
||||
class TelegramAdapter:
|
||||
"""Telegram Bot API adapter with injectable client support."""
|
||||
|
||||
KIND = "telegram"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
channel_id: str,
|
||||
kind: str,
|
||||
mode: str,
|
||||
account_id: str,
|
||||
display_name: str | None,
|
||||
inbound_sink: ChannelInboundSink,
|
||||
secrets: dict[str, Any] | None = None,
|
||||
config: dict[str, Any] | None = None,
|
||||
event_recorder: EventRecorder | None = None,
|
||||
client: Any | None = None,
|
||||
application_factory: Callable[[], Any] | None = None,
|
||||
) -> None:
|
||||
self.channel_id = channel_id
|
||||
self.kind = kind
|
||||
self.mode = mode
|
||||
self.account_id = account_id
|
||||
self.display_name = display_name
|
||||
self.inbound_sink = inbound_sink
|
||||
self.secrets = secrets or {}
|
||||
self.config = config or {}
|
||||
self.event_recorder = event_recorder
|
||||
self._client = client
|
||||
self._application_factory = application_factory
|
||||
self._application: Any | None = None
|
||||
self.max_message_chars = int(self.config.get("maxMessageChars") or 4096)
|
||||
|
||||
async def start(self) -> None:
|
||||
if self._client is not None:
|
||||
return
|
||||
if self.mode == "polling":
|
||||
self._application = self._build_application()
|
||||
await self._application.initialize()
|
||||
await self._application.start()
|
||||
if getattr(self._application, "updater", None) is not None:
|
||||
await self._application.updater.start_polling()
|
||||
self._client = self._application.bot
|
||||
return
|
||||
if self.mode == "webhook":
|
||||
self._client = self._build_bot()
|
||||
return
|
||||
raise ValueError(f"Unsupported telegram mode: {self.mode}")
|
||||
|
||||
async def stop(self) -> None:
|
||||
if self._application is None:
|
||||
return
|
||||
updater = getattr(self._application, "updater", None)
|
||||
if updater is not None:
|
||||
await updater.stop()
|
||||
await self._application.stop()
|
||||
await self._application.shutdown()
|
||||
self._application = None
|
||||
|
||||
async def handle_update_payload(self, payload: dict[str, Any]) -> None:
    """Normalize one raw update and forward it to the inbound sink.

    Updates for which ``_normalize_payload`` returns ``None`` are dropped.
    """
    normalized = self._normalize_payload(payload)
    if normalized is not None:
        await self.inbound_sink.accept_inbound(normalized)
|
||||
|
||||
async def send(self, message: OutboundMessage) -> None:
    """Deliver an outbound message to its chat, split into chunks.

    A message whose target has no peer id is marked unclaimed and not
    sent. When the target has a thread id it is passed as
    ``message_thread_id``, converted to ``int`` when it consists only of
    digits. The content is split with ``chunk_text`` using
    ``max_message_chars`` and each chunk is sent in order.
    """
    destination = outbound_target(message)
    if not destination.peer_id:
        mark_unclaimed(message)
        return

    bot = self._require_client()
    send_args: dict[str, Any] = {"chat_id": destination.peer_id}
    thread = destination.thread_id
    if thread:
        send_args["message_thread_id"] = int(thread) if str(thread).isdigit() else thread

    for part in chunk_text(message.content, max_chars=self.max_message_chars):
        await bot.send_message(**send_args, text=part)
|
||||
|
||||
def _normalize_payload(self, payload: dict[str, Any]) -> InboundMessage | None:
|
||||
data = payload.get("message") or payload.get("edited_message")
|
||||
if not isinstance(data, dict):
|
||||
return None
|
||||
|
||||
chat = data.get("chat") if isinstance(data.get("chat"), dict) else {}
|
||||
sender = data.get("from") if isinstance(data.get("from"), dict) else {}
|
||||
peer_id = _string_or_none(chat.get("id"))
|
||||
if not peer_id:
|
||||
return None
|
||||
|
||||
chat_type = str(chat.get("type") or "unknown")
|
||||
peer_type = self._peer_type(chat_type)
|
||||
user_id = _string_or_none(sender.get("id"))
|
||||
message_id = _string_or_none(data.get("message_id"))
|
||||
thread_id = _string_or_none(data.get("message_thread_id"))
|
||||
|
||||
content = str(data.get("text") or data.get("caption") or "").strip()
|
||||
media_entries = self._media_entries(data)
|
||||
if media_entries:
|
||||
content = "\n".join([part for part in [content, *media_entries] if part]).strip()
|
||||
if not content:
|
||||
return None
|
||||
|
||||
if peer_type in {"group", "channel"} and not self._group_allowed(peer_id, user_id):
|
||||
return None
|
||||
if peer_type == "dm" and not self._dm_allowed(user_id or peer_id):
|
||||
return None
|
||||
|
||||
if peer_type in {"group", "channel"} and config_bool(self.config, "requireMentionInGroups", default=False):
|
||||
gated = self._strip_required_mention(content)
|
||||
if gated is None:
|
||||
return None
|
||||
content = gated
|
||||
|
||||
metadata = {
|
||||
"chat_id": peer_id,
|
||||
"message_id": message_id,
|
||||
"chat_type": chat_type,
|
||||
}
|
||||
if media_entries:
|
||||
metadata["media"] = media_entries
|
||||
|
||||
return build_inbound_message(
|
||||
channel_id=self.channel_id,
|
||||
kind=self.kind,
|
||||
account_id=self.account_id,
|
||||
peer_id=peer_id,
|
||||
thread_id=thread_id,
|
||||
peer_type=peer_type,
|
||||
user_id=user_id,
|
||||
message_id=message_id,
|
||||
content=content,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
def _media_entries(self, data: dict[str, Any]) -> list[str]:
|
||||
entries: list[str] = []
|
||||
if data.get("photo"):
|
||||
entries.append(compact_media_summary("photo"))
|
||||
for media_type in ("document", "audio", "video"):
|
||||
value = data.get(media_type)
|
||||
if isinstance(value, dict):
|
||||
entries.append(compact_media_summary(media_type, file_name=_string_or_none(value.get("file_name"))))
|
||||
return entries
|
||||
|
||||
def _strip_required_mention(self, content: str) -> str | None:
|
||||
username = str(self.config.get("botUsername") or "").strip().lstrip("@")
|
||||
if not username:
|
||||
return None
|
||||
mention = f"@{username}"
|
||||
if mention not in content:
|
||||
return None
|
||||
return content.replace(mention, "", 1).strip()
|
||||
|
||||
def _dm_allowed(self, identifier: str | None) -> bool:
    """Tell whether a direct-message sender passes the ``allowFrom`` list.

    An empty list allows everyone; otherwise the identifier must be
    non-empty and contained in the list.
    """
    permitted = config_list(self.config, "allowFrom")
    if permitted:
        return bool(identifier and identifier in permitted)
    return True
|
||||
|
||||
def _group_allowed(self, peer_id: str | None, user_id: str | None) -> bool:
    """Tell whether a group message passes the ``groupAllowFrom`` list.

    An empty list allows everything; otherwise either the chat id or the
    sender id must be non-empty and contained in the list.
    """
    permitted = config_list(self.config, "groupAllowFrom")
    if not permitted:
        return True
    return any(candidate and candidate in permitted for candidate in (peer_id, user_id))
|
||||
|
||||
def _peer_type(self, chat_type: str) -> str:
|
||||
if chat_type == "private":
|
||||
return "dm"
|
||||
if chat_type in {"group", "supergroup"}:
|
||||
return "group"
|
||||
if chat_type == "channel":
|
||||
return "channel"
|
||||
return chat_type or "unknown"
|
||||
|
||||
def _require_client(self) -> Any:
|
||||
if self._client is None:
|
||||
self._client = self._build_bot()
|
||||
return self._client
|
||||
|
||||
def _build_bot(self) -> Any:
|
||||
token = self._require_secret("botToken")
|
||||
try:
|
||||
from telegram import Bot
|
||||
except ImportError as exc: # pragma: no cover - optional live dependency
|
||||
raise RuntimeError("Install beaver-backend[telegram] to enable TelegramAdapter") from exc
|
||||
return Bot(token=token)
|
||||
|
||||
def _build_application(self) -> Any:
|
||||
if self._application_factory is not None:
|
||||
return self._application_factory()
|
||||
token = self._require_secret("botToken")
|
||||
try:
|
||||
from telegram.ext import Application
|
||||
except ImportError as exc: # pragma: no cover - optional live dependency
|
||||
raise RuntimeError("Install beaver-backend[telegram] to enable TelegramAdapter") from exc
|
||||
|
||||
async def handle(update: Any, context: Any) -> None:
|
||||
if hasattr(update, "to_dict"):
|
||||
await self.handle_update_payload(update.to_dict())
|
||||
|
||||
application = Application.builder().token(token).build()
|
||||
try:
|
||||
from telegram.ext import MessageHandler, filters
|
||||
|
||||
application.add_handler(MessageHandler(filters.ALL, handle))
|
||||
except Exception:
|
||||
pass
|
||||
return application
|
||||
|
||||
def _require_secret(self, key: str) -> str:
|
||||
value = self.secrets.get(key)
|
||||
if not value:
|
||||
raise ValueError(f"{key} is required")
|
||||
return str(value)
|
||||
|
||||
|
||||
def _string_or_none(value: Any) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
text = str(value).strip()
|
||||
return text or None
|
||||
@ -0,0 +1,180 @@
|
||||
"""Weixin channel adapter."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from beaver.foundation.events import InboundMessage, OutboundMessage
|
||||
from beaver.interfaces.channels.base import ChannelInboundSink
|
||||
|
||||
from .base import (
|
||||
allowed_by_policy,
|
||||
build_inbound_message,
|
||||
chunk_text,
|
||||
compact_media_summary,
|
||||
config_list,
|
||||
mark_unclaimed,
|
||||
outbound_target,
|
||||
)
|
||||
|
||||
EventRecorder = Callable[..., None]
|
||||
|
||||
|
||||
class WeixinAdapter:
    """Weixin (Tencent iLink-style) channel adapter with an injectable client.

    Inbound payloads are normalized into inbound messages and handed to
    the inbound sink; outbound messages are chunked and sent through the
    client's ``send_text``. No live client can be constructed by this
    class itself (``_build_client`` always raises), so a client has to be
    injected.
    """

    KIND = "weixin"

    def __init__(
        self,
        *,
        channel_id: str,
        kind: str,
        mode: str,
        account_id: str,
        display_name: str | None,
        inbound_sink: ChannelInboundSink,
        secrets: dict[str, Any] | None = None,
        config: dict[str, Any] | None = None,
        event_recorder: EventRecorder | None = None,
        client: Any | None = None,
    ) -> None:
        """Store identity, collaborators and the outbound chunk size.

        ``maxMessageChars`` from ``config`` sets the chunk size and
        defaults to 2000 when absent or falsy.
        """
        # Channel identity.
        self.channel_id = channel_id
        self.kind = kind
        self.mode = mode
        self.account_id = account_id
        self.display_name = display_name
        # Collaborators and settings.
        self.inbound_sink = inbound_sink
        self.secrets = secrets or {}
        self.config = config or {}
        self.event_recorder = event_recorder
        self._client = client
        self.max_message_chars = int(self.config.get("maxMessageChars") or 2000)

    async def start(self) -> None:
        """Ensure a client exists; only ``"polling"`` mode is supported.

        Raises:
            ValueError: If no client was injected and the mode is not
                ``"polling"``.
        """
        if self._client is None:
            if self.mode != "polling":
                raise ValueError(f"Unsupported weixin mode: {self.mode}")
            self._client = self._build_client()

    async def stop(self) -> None:
        """Close the client if it has a ``close`` method, awaiting it if needed."""
        closer = getattr(self._client, "close", None)
        if closer is None:
            return
        outcome = closer()
        # ``close`` may be synchronous or return an awaitable.
        if hasattr(outcome, "__await__"):
            await outcome

    async def handle_message_payload(self, payload: dict[str, Any]) -> None:
        """Normalize a raw payload and pass it to the inbound sink, if accepted."""
        normalized = self._normalize_payload(payload)
        if normalized is not None:
            await self.inbound_sink.accept_inbound(normalized)

    async def send(self, message: OutboundMessage) -> None:
        """Send an outbound message to its peer in chunks.

        A message without a target peer id is marked unclaimed and not
        sent. Every chunk carries the context token resolved from the
        message metadata.
        """
        destination = outbound_target(message)
        if not destination.peer_id:
            mark_unclaimed(message)
            return
        client = self._require_client()
        token = self._context_token(message)
        for part in chunk_text(message.content, max_chars=self.max_message_chars):
            await client.send_text(peer_id=destination.peer_id, text=part, context_token=token)

    def _normalize_payload(self, payload: dict[str, Any]) -> InboundMessage | None:
        """Convert a raw Weixin payload into an inbound message.

        A payload with a room id is treated as a group message and
        checked against ``groupPolicy``/``groupAllowFrom`` (default
        ``"disabled"``); otherwise it is a direct message checked against
        ``dmPolicy``/``allowFrom`` (default ``"open"``). Returns ``None``
        when the policy rejects the peer, no peer id is available, or the
        payload has no content.
        """
        sender_id = _string_or_none(payload.get("from") or payload.get("from_user"))
        room_id = _string_or_none(payload.get("room_id") or payload.get("roomId"))
        message_id = _string_or_none(payload.get("id") or payload.get("message_id"))
        message_type = str(payload.get("type") or payload.get("message_type") or "text")

        user_id = sender_id
        if room_id:
            peer_id, peer_type = room_id, "group"
            policy_key, allow_key, fallback = "groupPolicy", "groupAllowFrom", "disabled"
        else:
            peer_id, peer_type = sender_id, "dm"
            policy_key, allow_key, fallback = "dmPolicy", "allowFrom", "open"

        permitted = allowed_by_policy(
            policy=self.config.get(policy_key),
            identifier=peer_id,
            allowlist=config_list(self.config, allow_key),
            default=fallback,
        )
        if not permitted or not peer_id:
            return None

        content = self._content(message_type, payload)
        if not content:
            return None

        metadata = {"message_id": message_id, "message_type": message_type}
        context_token = _string_or_none(payload.get("context_token") or payload.get("contextToken"))
        if context_token:
            metadata["context_token"] = context_token
        if room_id:
            metadata["room_id"] = room_id

        return build_inbound_message(
            channel_id=self.channel_id,
            kind=self.kind,
            account_id=self.account_id,
            peer_id=peer_id,
            peer_type=peer_type,
            user_id=user_id,
            message_id=message_id,
            content=content,
            metadata=metadata,
        )

    def _content(self, message_type: str, payload: dict[str, Any]) -> str:
        """Return the text of a text message, or a media summary otherwise."""
        if message_type != "text":
            file_name = _string_or_none(payload.get("file_name") or payload.get("filename"))
            return compact_media_summary(message_type, file_name=file_name)
        return str(payload.get("text") or payload.get("content") or "").strip()

    def _context_token(self, message: OutboundMessage) -> str | None:
        """Resolve the context token to reply with.

        The token inside ``metadata["inbound_metadata"]`` wins; the
        top-level ``metadata["context_token"]`` is the fallback.
        """
        inbound_meta = message.metadata.get("inbound_metadata")
        if isinstance(inbound_meta, dict):
            token = _string_or_none(inbound_meta.get("context_token"))
            if token:
                return token
        return _string_or_none(message.metadata.get("context_token"))

    def _require_client(self) -> Any:
        """Return the client, attempting to build one when none is set."""
        client = self._client
        if client is None:
            client = self._client = self._build_client()
        return client

    def _build_client(self) -> Any:
        """Validate prerequisites for a live client; never returns one.

        Raises:
            ValueError: If the ``token`` secret is missing.
            RuntimeError: If ``aiohttp`` is not installed, and otherwise
                unconditionally, because direct construction of a live
                client is not configured.
        """
        self._require_secret("token")
        try:
            import aiohttp  # noqa: F401
        except ImportError as exc:  # pragma: no cover - optional live dependency
            raise RuntimeError("Install beaver-backend[weixin] to enable WeixinAdapter") from exc
        raise RuntimeError("Weixin live client is not configured for direct construction")

    def _require_secret(self, key: str) -> str:
        """Return the secret stored under ``key`` or raise ``ValueError``."""
        secret = self.secrets.get(key)
        if secret:
            return str(secret)
        raise ValueError(f"{key} is required")
|
||||
|
||||
|
||||
def _string_or_none(value: Any) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
text = str(value).strip()
|
||||
return text or None
|
||||
526
app-instance/backend/beaver/interfaces/channels/runtime.py
Normal file
526
app-instance/backend/beaver/interfaces/channels/runtime.py
Normal file
@ -0,0 +1,526 @@
|
||||
"""Channel runtime host for adapter lifecycle and bus-first routing."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from beaver.foundation.config.schema import ChannelConfig
|
||||
from beaver.foundation.events import InboundMessage, MessageBus, OutboundMessage
|
||||
from beaver.interfaces.channels.base import ChannelAdapter
|
||||
from beaver.interfaces.channels.manager import ChannelManager
|
||||
from beaver.interfaces.channels.state import ChannelDedupeStore, ChannelEventLog
|
||||
from beaver.services.agent_service import AgentService
|
||||
|
||||
|
||||
def _iso_now() -> str:
|
||||
return datetime.now(timezone.utc).isoformat()
|
||||
|
||||
|
||||
def _channel_capabilities(kind: str, mode: str) -> list[str]:
|
||||
if kind == "webhook":
|
||||
return ["receive_text", "send_text", "sync_webhook_response"]
|
||||
if kind == "terminal" and mode == "websocket":
|
||||
return ["receive_text", "send_text", "persistent_connection"]
|
||||
if kind in {"feishu", "qqbot", "telegram"}:
|
||||
return ["receive_text", "send_text", "receive_media", "groups"]
|
||||
if kind == "weixin":
|
||||
return ["receive_text", "send_text", "receive_media", "direct_messages"]
|
||||
return []
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ChannelAcceptResult:
|
||||
accepted: bool
|
||||
duplicate: bool = False
|
||||
pending: bool = False
|
||||
rejected: bool = False
|
||||
session_id: str | None = None
|
||||
dedupe_key: str | None = None
|
||||
record: dict[str, Any] | None = None
|
||||
error: str | None = None
|
||||
|
||||
|
||||
class ChannelRuntime:
|
||||
"""Own channel adapters, state, and the inbound/outbound bus bridge."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
service: AgentService,
|
||||
workspace: Path,
|
||||
channels: dict[str, ChannelConfig],
|
||||
bus: MessageBus | None = None,
|
||||
) -> None:
|
||||
self.service = service
|
||||
self.workspace = Path(workspace)
|
||||
self.bus = bus or MessageBus()
|
||||
self.manager = ChannelManager(self.bus)
|
||||
self.channel_configs = dict(channels)
|
||||
self.adapters: dict[str, ChannelAdapter] = {}
|
||||
self.states: dict[str, dict[str, Any]] = {}
|
||||
state_dir = self.workspace / "state" / "channels"
|
||||
retention = self._default_dedupe_retention_hours()
|
||||
self.dedupe = ChannelDedupeStore(state_dir / "dedupe.json", retention_hours=retention)
|
||||
self.events = ChannelEventLog(state_dir / "events.jsonl")
|
||||
self._bridge_task: asyncio.Task[None] | None = None
|
||||
self._dispatch_task: asyncio.Task[None] | None = None
|
||||
self._stop_event = asyncio.Event()
|
||||
self._dispatch_stop_event = asyncio.Event()
|
||||
self._lifecycle_lock = asyncio.Lock()
|
||||
|
||||
async def start(self) -> None:
|
||||
self._stop_event.clear()
|
||||
self._dispatch_stop_event.clear()
|
||||
for channel_id, cfg in self.channel_configs.items():
|
||||
if not cfg.enabled:
|
||||
self.states[channel_id] = {"state": "disabled", "last_error": None}
|
||||
continue
|
||||
try:
|
||||
adapter = self._build_adapter(channel_id, cfg)
|
||||
self.adapters[channel_id] = adapter
|
||||
self.manager.register(adapter)
|
||||
await adapter.start()
|
||||
self.states[channel_id] = {
|
||||
"state": "running",
|
||||
"last_error": None,
|
||||
"started_at": _iso_now(),
|
||||
}
|
||||
self.events.record(channel_id=channel_id, kind="adapter_started")
|
||||
except Exception as exc: # pragma: no cover - defensive startup isolation
|
||||
self.states[channel_id] = {"state": "error", "last_error": str(exc)}
|
||||
self.events.record(
|
||||
channel_id=channel_id,
|
||||
kind="adapter_error",
|
||||
status="error",
|
||||
error=str(exc),
|
||||
)
|
||||
self._bridge_task = asyncio.create_task(self._bridge_inbound_to_agent())
|
||||
self._dispatch_task = asyncio.create_task(
|
||||
self.manager.dispatch_outbound(
|
||||
self._dispatch_stop_event,
|
||||
on_delivered=self._record_outbound_delivered,
|
||||
on_failed=self._record_outbound_failed,
|
||||
)
|
||||
)
|
||||
|
||||
async def stop(self) -> None:
|
||||
self._stop_event.set()
|
||||
if self._bridge_task is not None:
|
||||
self._bridge_task.cancel()
|
||||
try:
|
||||
await self._bridge_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._dispatch_stop_event.set()
|
||||
if self._dispatch_task is not None:
|
||||
try:
|
||||
await asyncio.wait_for(self._dispatch_task, timeout=1.0)
|
||||
except asyncio.TimeoutError:
|
||||
self._dispatch_task.cancel()
|
||||
try:
|
||||
await self._dispatch_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
await self.manager.stop()
|
||||
for channel_id in self.adapters:
|
||||
self.events.record(channel_id=channel_id, kind="adapter_stopped")
|
||||
|
||||
async def add_channel(self, channel_id: str, config: ChannelConfig) -> None:
|
||||
async with self._lifecycle_lock:
|
||||
current = self.channel_configs.get(channel_id)
|
||||
if current == config and channel_id in self.adapters:
|
||||
return
|
||||
if not config.enabled:
|
||||
await self._remove_channel_locked(channel_id)
|
||||
self.channel_configs[channel_id] = config
|
||||
self.states[channel_id] = {"state": "disabled", "last_error": None}
|
||||
return
|
||||
|
||||
adapter = self._build_adapter(channel_id, config)
|
||||
await adapter.start()
|
||||
old_adapter = self.adapters.get(channel_id)
|
||||
self.manager.replace_registered(adapter)
|
||||
self.adapters[channel_id] = adapter
|
||||
self.channel_configs[channel_id] = config
|
||||
self.states[channel_id] = {"state": "running", "last_error": None, "started_at": _iso_now()}
|
||||
self.events.record(channel_id=channel_id, kind="adapter_started")
|
||||
if old_adapter is not None and old_adapter is not adapter:
|
||||
await old_adapter.stop()
|
||||
|
||||
async def remove_channel(self, channel_id: str) -> None:
|
||||
async with self._lifecycle_lock:
|
||||
await self._remove_channel_locked(channel_id)
|
||||
|
||||
async def _remove_channel_locked(self, channel_id: str) -> None:
|
||||
adapter = self.adapters.pop(channel_id, None)
|
||||
self.manager.unregister(channel_id)
|
||||
self.channel_configs.pop(channel_id, None)
|
||||
if adapter is not None:
|
||||
await adapter.stop()
|
||||
self.events.record(channel_id=channel_id, kind="adapter_stopped")
|
||||
self.states[channel_id] = {"state": "removed", "last_error": None}
|
||||
|
||||
async def accept_inbound(self, message: InboundMessage) -> ChannelAcceptResult:
|
||||
identity = message.channel_identity
|
||||
if identity is None:
|
||||
self.events.record(
|
||||
channel_id=message.channel,
|
||||
kind="inbound_rejected",
|
||||
status="error",
|
||||
error="channel_identity is required",
|
||||
)
|
||||
return ChannelAcceptResult(
|
||||
accepted=False,
|
||||
rejected=True,
|
||||
error="channel_identity is required",
|
||||
)
|
||||
|
||||
validation_error = identity.validation_error()
|
||||
if validation_error:
|
||||
self.events.record(
|
||||
channel_id=identity.channel_id,
|
||||
kind="inbound_rejected",
|
||||
status="error",
|
||||
error=validation_error,
|
||||
)
|
||||
return ChannelAcceptResult(accepted=False, rejected=True, error=validation_error)
|
||||
|
||||
expected_session_id = identity.session_id()
|
||||
if message.session_id != expected_session_id:
|
||||
self.events.record(
|
||||
channel_id=identity.channel_id,
|
||||
kind="session_id_normalized",
|
||||
session_id=expected_session_id,
|
||||
message_id=identity.message_id,
|
||||
)
|
||||
message.session_id = expected_session_id
|
||||
message.channel = identity.channel_id
|
||||
|
||||
dedupe_key = identity.dedupe_key()
|
||||
if dedupe_key:
|
||||
write = self.dedupe.mark_processing(
|
||||
dedupe_key=dedupe_key,
|
||||
session_id=expected_session_id,
|
||||
message_id=identity.message_id or "",
|
||||
)
|
||||
if not write.created:
|
||||
record = write.record or {}
|
||||
self.events.record(
|
||||
channel_id=identity.channel_id,
|
||||
kind="inbound_duplicate",
|
||||
session_id=expected_session_id,
|
||||
message_id=identity.message_id,
|
||||
status=str(record.get("status") or "processing"),
|
||||
)
|
||||
return ChannelAcceptResult(
|
||||
accepted=False,
|
||||
duplicate=True,
|
||||
pending=record.get("status") == "processing",
|
||||
session_id=expected_session_id,
|
||||
dedupe_key=dedupe_key,
|
||||
record=record,
|
||||
)
|
||||
|
||||
self.events.record(
|
||||
channel_id=identity.channel_id,
|
||||
kind="inbound_accepted",
|
||||
session_id=expected_session_id,
|
||||
message_id=identity.message_id,
|
||||
text=message.content,
|
||||
)
|
||||
await self.bus.publish_inbound(message)
|
||||
return ChannelAcceptResult(
|
||||
accepted=True,
|
||||
session_id=expected_session_id,
|
||||
dedupe_key=dedupe_key,
|
||||
)
|
||||
|
||||
def statuses(self) -> list[dict[str, Any]]:
|
||||
items: list[dict[str, Any]] = []
|
||||
recent = self.events.recent(limit=500)
|
||||
last_by_channel = {event["channel_id"]: event for event in recent if event.get("channel_id")}
|
||||
for channel_id, cfg in self.channel_configs.items():
|
||||
state = self.states.get(channel_id, {"state": "configured", "last_error": None})
|
||||
capabilities = _channel_capabilities(cfg.kind, cfg.mode)
|
||||
webhook_url = None
|
||||
websocket_url = None
|
||||
connected_peers = 0
|
||||
if cfg.kind == "webhook":
|
||||
webhook_url = f"/api/channels/{channel_id}/webhook"
|
||||
elif cfg.kind == "terminal" and cfg.mode == "websocket":
|
||||
websocket_url = f"/api/channels/{channel_id}/ws"
|
||||
adapter = self.adapters.get(channel_id)
|
||||
if adapter is not None and hasattr(adapter, "status_extra"):
|
||||
extra = adapter.status_extra() # type: ignore[attr-defined]
|
||||
connected_peers = int(extra.get("connected_peers") or 0)
|
||||
items.append(
|
||||
{
|
||||
"channel_id": channel_id,
|
||||
"name": channel_id,
|
||||
"kind": cfg.kind,
|
||||
"mode": cfg.mode,
|
||||
"display_name": cfg.display_name or channel_id,
|
||||
"enabled": cfg.enabled,
|
||||
"state": state.get("state", "configured"),
|
||||
"account_id": cfg.account_id,
|
||||
"last_error": state.get("last_error"),
|
||||
"started_at": state.get("started_at"),
|
||||
"last_event_at": last_by_channel.get(channel_id, {}).get("created_at"),
|
||||
"capabilities": capabilities,
|
||||
"webhook_url": webhook_url,
|
||||
"websocket_url": websocket_url,
|
||||
"connected_peers": connected_peers,
|
||||
}
|
||||
)
|
||||
return items
|
||||
|
||||
def recent_events(self, channel_id: str, *, limit: int = 100) -> list[dict[str, Any]]:
|
||||
return self.events.recent(channel_id=channel_id, limit=limit)
|
||||
|
||||
def record_event(
|
||||
self,
|
||||
*,
|
||||
channel_id: str,
|
||||
kind: str,
|
||||
session_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
run_id: str | None = None,
|
||||
status: str = "ok",
|
||||
error: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
self.events.record(
|
||||
channel_id=channel_id,
|
||||
kind=kind,
|
||||
session_id=session_id,
|
||||
message_id=message_id,
|
||||
run_id=run_id,
|
||||
status=status,
|
||||
error=error,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
def _build_adapter(self, channel_id: str, cfg: ChannelConfig) -> ChannelAdapter:
|
||||
if cfg.kind == "webhook" and cfg.mode == "webhook":
|
||||
from beaver.interfaces.channels.generic_webhook import GenericWebhookAdapter
|
||||
|
||||
return GenericWebhookAdapter(
|
||||
channel_id=channel_id,
|
||||
kind=cfg.kind,
|
||||
mode=cfg.mode,
|
||||
account_id=cfg.account_id,
|
||||
display_name=cfg.display_name,
|
||||
inbound_sink=self,
|
||||
response_timeout_seconds=float(cfg.config.get("response_timeout_seconds") or 1800),
|
||||
)
|
||||
|
||||
if cfg.kind == "terminal" and cfg.mode == "websocket":
|
||||
from beaver.interfaces.channels.terminal_websocket import TerminalWebSocketAdapter
|
||||
|
||||
return TerminalWebSocketAdapter(
|
||||
channel_id=channel_id,
|
||||
kind=cfg.kind,
|
||||
mode=cfg.mode,
|
||||
account_id=cfg.account_id,
|
||||
display_name=cfg.display_name,
|
||||
inbound_sink=self,
|
||||
event_recorder=self.record_event,
|
||||
heartbeat_seconds=float(cfg.config.get("heartbeat_seconds") or 30),
|
||||
max_message_chars=int(cfg.config.get("max_message_chars") or 20000),
|
||||
)
|
||||
|
||||
if cfg.kind == "telegram" and cfg.mode in {"polling", "webhook"}:
|
||||
from beaver.interfaces.channels.platforms.telegram import TelegramAdapter
|
||||
|
||||
return TelegramAdapter(
|
||||
channel_id=channel_id,
|
||||
kind=cfg.kind,
|
||||
mode=cfg.mode,
|
||||
account_id=cfg.account_id,
|
||||
display_name=cfg.display_name,
|
||||
inbound_sink=self,
|
||||
secrets=cfg.secrets,
|
||||
config=cfg.config,
|
||||
event_recorder=self.record_event,
|
||||
)
|
||||
|
||||
if cfg.kind == "feishu" and cfg.mode in {"websocket", "webhook"}:
|
||||
from beaver.interfaces.channels.platforms.feishu import FeishuAdapter
|
||||
|
||||
return FeishuAdapter(
|
||||
channel_id=channel_id,
|
||||
kind=cfg.kind,
|
||||
mode=cfg.mode,
|
||||
account_id=cfg.account_id,
|
||||
display_name=cfg.display_name,
|
||||
inbound_sink=self,
|
||||
secrets=cfg.secrets,
|
||||
config=cfg.config,
|
||||
event_recorder=self.record_event,
|
||||
)
|
||||
|
||||
if cfg.kind == "qqbot" and cfg.mode == "websocket":
|
||||
from beaver.interfaces.channels.platforms.qqbot import QQBotAdapter
|
||||
|
||||
return QQBotAdapter(
|
||||
channel_id=channel_id,
|
||||
kind=cfg.kind,
|
||||
mode=cfg.mode,
|
||||
account_id=cfg.account_id,
|
||||
display_name=cfg.display_name,
|
||||
inbound_sink=self,
|
||||
secrets=cfg.secrets,
|
||||
config=cfg.config,
|
||||
event_recorder=self.record_event,
|
||||
)
|
||||
|
||||
if cfg.kind == "weixin" and cfg.mode == "polling":
|
||||
from beaver.interfaces.channels.platforms.weixin import WeixinAdapter
|
||||
|
||||
return WeixinAdapter(
|
||||
channel_id=channel_id,
|
||||
kind=cfg.kind,
|
||||
mode=cfg.mode,
|
||||
account_id=cfg.account_id,
|
||||
display_name=cfg.display_name,
|
||||
inbound_sink=self,
|
||||
secrets=cfg.secrets,
|
||||
config=cfg.config,
|
||||
event_recorder=self.record_event,
|
||||
)
|
||||
|
||||
if cfg.kind == "external_connector" and cfg.mode == "http":
|
||||
import os
|
||||
|
||||
from beaver.interfaces.channels.connections.sidecar_client import ConnectorSidecarClient
|
||||
from beaver.interfaces.channels.external_connector import ExternalConnectorChannel
|
||||
|
||||
base_url = str(cfg.config.get("sidecarBaseUrl") or os.getenv("EXTERNAL_CONNECTOR_BASE_URL") or "").strip()
|
||||
token = os.getenv("EXTERNAL_CONNECTOR_TOKEN", "")
|
||||
platform_kind = str(cfg.config.get("platformKind") or "").strip()
|
||||
connection_id = str(cfg.config.get("connectionId") or "").strip()
|
||||
if not base_url:
|
||||
raise ValueError("external connector sidecarBaseUrl is required")
|
||||
if not platform_kind:
|
||||
raise ValueError("external connector platformKind is required")
|
||||
if not connection_id:
|
||||
raise ValueError("external connector connectionId is required")
|
||||
return ExternalConnectorChannel(
|
||||
channel_id=channel_id,
|
||||
platform_kind=platform_kind,
|
||||
connection_id=connection_id,
|
||||
account_id=cfg.account_id,
|
||||
display_name=cfg.display_name,
|
||||
sidecar_client=ConnectorSidecarClient(base_url=base_url, token=token),
|
||||
)
|
||||
|
||||
raise ValueError(f"Unsupported channel kind/mode: {cfg.kind}/{cfg.mode}")
|
||||
|
||||
async def _bridge_inbound_to_agent(self) -> None:
|
||||
current_inbound: InboundMessage | None = None
|
||||
while not self._stop_event.is_set():
|
||||
try:
|
||||
current_inbound = await asyncio.wait_for(self.bus.consume_inbound(), timeout=0.25)
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
inbound = current_inbound
|
||||
identity = inbound.channel_identity
|
||||
try:
|
||||
self.events.record(
|
||||
channel_id=inbound.channel,
|
||||
kind="direct_run_started",
|
||||
session_id=inbound.session_id,
|
||||
message_id=identity.message_id if identity else inbound.message_id,
|
||||
)
|
||||
outbound = await self.service.handle_inbound_message(inbound)
|
||||
except asyncio.CancelledError:
|
||||
outbound = AgentService.build_outbound_error(
|
||||
inbound,
|
||||
detail="Channel runtime stopped before completing the inbound message",
|
||||
finish_reason="cancelled",
|
||||
)
|
||||
self._mark_dedupe_result(inbound, outbound)
|
||||
await self.bus.publish_outbound(outbound)
|
||||
current_inbound = None
|
||||
raise
|
||||
except Exception as exc:
|
||||
self.events.record(
|
||||
channel_id=inbound.channel,
|
||||
kind="direct_run_failed",
|
||||
session_id=inbound.session_id,
|
||||
message_id=identity.message_id if identity else inbound.message_id,
|
||||
status="error",
|
||||
error=str(exc),
|
||||
)
|
||||
outbound = AgentService.build_outbound_error(
|
||||
inbound,
|
||||
detail=str(exc),
|
||||
finish_reason="error",
|
||||
)
|
||||
else:
|
||||
self.events.record(
|
||||
channel_id=outbound.channel,
|
||||
kind="direct_run_finished",
|
||||
session_id=outbound.session_id,
|
||||
message_id=identity.message_id if identity else inbound.message_id,
|
||||
run_id=outbound.run_id,
|
||||
)
|
||||
self._mark_dedupe_result(inbound, outbound)
|
||||
await self.bus.publish_outbound(outbound)
|
||||
current_inbound = None
|
||||
|
||||
def _mark_dedupe_result(self, inbound: InboundMessage, outbound: OutboundMessage) -> None:
|
||||
identity = inbound.channel_identity
|
||||
dedupe_key = identity.dedupe_key() if identity else None
|
||||
if not dedupe_key:
|
||||
return
|
||||
cfg = self.channel_configs.get(identity.channel_id)
|
||||
max_reply_chars = int((cfg.config if cfg else {}).get("max_cached_reply_chars") or 20000)
|
||||
max_error_chars = int((cfg.config if cfg else {}).get("max_cached_error_chars") or 4000)
|
||||
if outbound.finish_reason == "error":
|
||||
self.dedupe.mark_error(
|
||||
dedupe_key=dedupe_key,
|
||||
error=outbound.content,
|
||||
max_error_chars=max_error_chars,
|
||||
)
|
||||
else:
|
||||
self.dedupe.mark_done(
|
||||
dedupe_key=dedupe_key,
|
||||
run_id=outbound.run_id,
|
||||
reply=outbound.content,
|
||||
max_reply_chars=max_reply_chars,
|
||||
)
|
||||
|
||||
async def _record_outbound_delivered(self, message: OutboundMessage) -> None:
|
||||
kind = "outbound_unclaimed" if message.metadata.get("delivery_status") == "unclaimed" else "outbound_delivered"
|
||||
self.events.record(
|
||||
channel_id=message.channel,
|
||||
kind=kind,
|
||||
session_id=message.session_id,
|
||||
message_id=message.channel_identity.message_id if message.channel_identity else message.message_id,
|
||||
run_id=message.run_id,
|
||||
)
|
||||
|
||||
async def _record_outbound_failed(self, message: OutboundMessage, exc: Exception | None) -> None:
|
||||
self.events.record(
|
||||
channel_id=message.channel,
|
||||
kind="outbound_delivery_failed",
|
||||
session_id=message.session_id,
|
||||
message_id=message.channel_identity.message_id if message.channel_identity else message.message_id,
|
||||
run_id=message.run_id,
|
||||
status="error",
|
||||
error=str(exc) if exc else "channel not registered",
|
||||
)
|
||||
|
||||
def _default_dedupe_retention_hours(self) -> int:
    """Return the first configured ``dedupe_retention_hours`` across channels, else 48.

    Channel configs are scanned in mapping order and the first value that is
    not ``None`` is converted with ``int``.
    """
    configured = (
        channel_cfg.config.get("dedupe_retention_hours")
        for channel_cfg in self.channel_configs.values()
    )
    first_set = next((hours for hours in configured if hours is not None), None)
    return 48 if first_set is None else int(first_set)
|
||||
198
app-instance/backend/beaver/interfaces/channels/state.py
Normal file
198
app-instance/backend/beaver/interfaces/channels/state.py
Normal file
@ -0,0 +1,198 @@
|
||||
"""Persistent channel runtime state."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from threading import Lock
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
|
||||
def _now_ms() -> int:
    """Return the current wall-clock time as whole milliseconds since the epoch."""
    seconds = time.time()
    return int(seconds * 1000)
|
||||
|
||||
|
||||
def _iso_now() -> str:
    """Return the current UTC time as a timezone-aware ISO 8601 string."""
    current = datetime.now(tz=timezone.utc)
    return current.isoformat()
|
||||
|
||||
|
||||
@dataclass(slots=True)
class DedupeWriteResult:
    """Outcome of ``ChannelDedupeStore.mark_processing``."""

    # True when a new "processing" record was written; False when a record
    # for the key already existed.
    created: bool
    # The new record when ``created`` is True, otherwise the existing one.
    record: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class ChannelDedupeStore:
    """File-backed record of inbound messages already seen, keyed by dedupe key.

    All records live in one JSON document at ``path`` shaped like
    ``{"records": {dedupe_key: record}}``.  Every public operation holds an
    in-process ``threading.Lock`` while it loads, updates and rewrites that
    document.  Records whose last update is older than the retention window
    are pruned by ``get`` and ``mark_processing``.
    """

    def __init__(self, path: Path, *, retention_hours: int = 48) -> None:
        """Remember the storage path and the retention window (at least one hour)."""
        self.path = path
        self.retention_ms = max(1, int(retention_hours)) * 60 * 60 * 1000
        self._lock = Lock()

    def get(self, dedupe_key: str) -> dict[str, Any] | None:
        """Return the live record for ``dedupe_key``, or ``None`` if absent or expired."""
        with self._lock:
            data = self._load()
            # Rewrite the file only when pruning removed something; a plain
            # lookup should not cost a full-file write.
            if self._prune_unlocked(data, _now_ms()):
                self._save(data)
            return data["records"].get(dedupe_key)

    def mark_processing(self, *, dedupe_key: str, session_id: str, message_id: str) -> DedupeWriteResult:
        """Claim ``dedupe_key`` for processing unless a live record already exists.

        Returns ``created=False`` with the existing record when the key is
        already known, otherwise writes a fresh ``"processing"`` record and
        returns it with ``created=True``.
        """
        with self._lock:
            data = self._load()
            now_ms = _now_ms()
            pruned = self._prune_unlocked(data, now_ms)
            existing = data["records"].get(dedupe_key)
            if existing is not None:
                if pruned:
                    self._save(data)
                return DedupeWriteResult(created=False, record=existing)

            record = {
                "dedupe_key": dedupe_key,
                "status": "processing",
                "session_id": session_id,
                "message_id": message_id,
                "run_id": None,
                "reply": None,
                "error": None,
                "created_at_ms": now_ms,
                "updated_at_ms": now_ms,
            }
            data["records"][dedupe_key] = record
            self._save(data)
            return DedupeWriteResult(created=True, record=record)

    def mark_done(
        self,
        *,
        dedupe_key: str,
        run_id: str | None,
        reply: str,
        max_reply_chars: int,
    ) -> None:
        """Record a successful result, keeping at most ``max_reply_chars`` of the reply."""
        self._mark_result(
            dedupe_key=dedupe_key,
            status="done",
            run_id=run_id,
            reply=reply[: max(0, int(max_reply_chars))],
            error=None,
        )

    def mark_error(self, *, dedupe_key: str, error: str, max_error_chars: int) -> None:
        """Record a failed result, keeping at most ``max_error_chars`` of the error text."""
        self._mark_result(
            dedupe_key=dedupe_key,
            status="error",
            run_id=None,
            reply=None,
            error=error[: max(0, int(max_error_chars))],
        )

    def _mark_result(
        self,
        *,
        dedupe_key: str,
        status: str,
        run_id: str | None,
        reply: str | None,
        error: str | None,
    ) -> None:
        """Write the final status fields for ``dedupe_key``, creating the record if needed."""
        with self._lock:
            data = self._load()
            record = data["records"].get(dedupe_key)
            # A missing or malformed (non-dict) entry is replaced by a fresh record.
            if not isinstance(record, dict):
                record = {"dedupe_key": dedupe_key, "created_at_ms": _now_ms()}
                data["records"][dedupe_key] = record
            record.update(
                {
                    "status": status,
                    "run_id": run_id,
                    "reply": reply,
                    "error": error,
                    "updated_at_ms": _now_ms(),
                }
            )
            self._save(data)

    def _load(self) -> dict[str, Any]:
        """Read the state file; any unreadable or malformed file yields an empty store."""
        try:
            data = json.loads(self.path.read_text(encoding="utf-8"))
        except (OSError, ValueError):
            # OSError covers a missing/unreadable file; ValueError covers both
            # json.JSONDecodeError and UnicodeDecodeError (file is not UTF-8).
            return {"records": {}}
        if not isinstance(data, dict) or not isinstance(data.get("records"), dict):
            return {"records": {}}
        return data

    def _save(self, data: dict[str, Any]) -> None:
        """Write ``data`` to a sibling ``.tmp`` file and move it over the state file."""
        self.path.parent.mkdir(parents=True, exist_ok=True)
        tmp_path = self.path.with_name(f"{self.path.name}.tmp")
        tmp_path.write_text(json.dumps(data, ensure_ascii=False, indent=2) + "\n", encoding="utf-8")
        tmp_path.replace(self.path)

    def _prune_unlocked(self, data: dict[str, Any], now_ms: int) -> bool:
        """Drop expired or malformed records in place; the caller must hold the lock.

        A record expires when its ``updated_at_ms`` (falling back to
        ``created_at_ms``) is older than the retention window.  Entries that
        are not dicts, or whose timestamp cannot be read as an integer, are
        dropped as well.  Returns True when at least one record was removed.
        """
        records = data.get("records", {})
        expired_before = now_ms - self.retention_ms
        removed = False
        for key, record in list(records.items()):
            if isinstance(record, dict):
                try:
                    updated_at_ms = int(record.get("updated_at_ms") or record.get("created_at_ms") or 0)
                except (TypeError, ValueError, OverflowError):
                    updated_at_ms = 0  # unreadable timestamp: treat as expired
                if updated_at_ms >= expired_before:
                    continue
            records.pop(key, None)
            removed = True
        return removed
|
||||
|
||||
|
||||
class ChannelEventLog:
    """Append-only JSON-lines log of channel events stored at ``path``."""

    def __init__(self, path: Path) -> None:
        """Remember the log path; the file is created on the first ``record`` call."""
        self.path = path
        self._lock = Lock()

    def record(
        self,
        *,
        channel_id: str,
        kind: str,
        session_id: str | None = None,
        message_id: str | None = None,
        run_id: str | None = None,
        status: str = "ok",
        error: str | None = None,
        text: str | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Append one event line and return the entry that was written.

        Only the first 120 characters of ``text`` are stored (as
        ``text_preview``) together with its full length; the text itself is
        not logged.
        """
        entry = {
            "event_id": uuid4().hex,
            "channel_id": channel_id,
            "kind": kind,
            "session_id": session_id,
            "message_id": message_id,
            "run_id": run_id,
            "status": status,
            "error": error,
            "text_preview": text[:120] if text else None,
            "text_length": len(text) if text else 0,
            "metadata": metadata or {},
            "created_at": _iso_now(),
        }
        with self._lock:
            self.path.parent.mkdir(parents=True, exist_ok=True)
            with self.path.open("a", encoding="utf-8") as handle:
                handle.write(json.dumps(entry, ensure_ascii=False) + "\n")
        return entry

    def recent(self, *, channel_id: str | None = None, limit: int = 100) -> list[dict[str, Any]]:
        """Return up to ``limit`` most recent events, oldest first.

        ``limit`` is clamped to at least 1.  When ``channel_id`` is given only
        events of that channel are returned.  Lines that are not valid JSON,
        or that decode to something other than an object, are skipped.
        """
        try:
            lines = self.path.read_text(encoding="utf-8").splitlines()
        except FileNotFoundError:
            # No event has been recorded yet (or the log was removed meanwhile).
            return []
        wanted = max(1, int(limit))
        items: list[dict[str, Any]] = []
        # Walk newest-to-oldest so the scan can stop as soon as enough matched.
        for line in reversed(lines):
            try:
                item = json.loads(line)
            except json.JSONDecodeError:
                continue
            if not isinstance(item, dict):
                continue
            if channel_id and item.get("channel_id") != channel_id:
                continue
            items.append(item)
            if len(items) >= wanted:
                break
        items.reverse()
        return items
|
||||
@ -0,0 +1,301 @@
|
||||
"""Text-only terminal WebSocket channel adapter."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from contextlib import suppress
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from beaver.foundation.events import ChannelIdentity, InboundMessage, OutboundMessage
|
||||
from beaver.interfaces.channels.base import ChannelInboundSink
|
||||
|
||||
try:
|
||||
from fastapi import WebSocket
|
||||
from starlette.websockets import WebSocketDisconnect
|
||||
except ModuleNotFoundError: # pragma: no cover - import-only fallback
|
||||
class WebSocketDisconnect(Exception):
|
||||
"""Fallback disconnect exception for skeleton import environments."""
|
||||
|
||||
class WebSocket: # type: ignore[override]
|
||||
"""Fallback websocket annotation shim."""
|
||||
|
||||
|
||||
def _clean(value: Any) -> str:
    """Return ``value`` as a stripped string; any falsy value becomes ``""``."""
    if not value:
        return ""
    return str(value).strip()
|
||||
|
||||
|
||||
@dataclass(slots=True)
class TerminalConnection:
    """State kept for one terminal peer after its ``connect`` frame was accepted."""

    websocket: WebSocket  # socket the adapter sends frames to for this peer
    peer_id: str  # ``peer_id`` taken from the connect frame
    session_id: str  # session id derived from the peer's ChannelIdentity
    thread_id: str | None = None  # optional ``thread_id`` from the connect frame
    user_id: str | None = None  # optional ``user_id`` from the connect frame
    device_name: str = ""  # ``device_name`` from the connect frame, possibly empty
    capabilities: list[str] = field(default_factory=list)  # stringified ``capabilities`` entries
|
||||
|
||||
|
||||
class TerminalWebSocketAdapter:
    """Accept text terminal websocket frames and deliver final assistant replies.

    One live connection is tracked per session id.  A ``connect`` frame for a
    session that is already served by another socket closes that socket and
    takes its place.
    """

    def __init__(
        self,
        *,
        channel_id: str,
        kind: str,
        mode: str,
        account_id: str,
        display_name: str = "",
        inbound_sink: ChannelInboundSink,
        event_recorder: Callable[..., None] | None = None,
        heartbeat_seconds: float = 30,
        max_message_chars: int = 20000,
    ) -> None:
        """Store the channel settings; heartbeat and message limits are clamped to >= 1."""
        self.channel_id = channel_id
        self.kind = kind
        self.mode = mode
        self.account_id = account_id
        self.display_name = display_name or channel_id
        self.inbound_sink = inbound_sink
        self.event_recorder = event_recorder
        self.heartbeat_seconds = max(1.0, float(heartbeat_seconds))
        self.max_message_chars = max(1, int(max_message_chars))
        self.started = False
        self._connections_by_session: dict[str, TerminalConnection] = {}
        self._session_by_peer: dict[str, str] = {}

    async def start(self) -> None:
        """Mark the adapter as started."""
        self.started = True

    async def stop(self) -> None:
        """Mark the adapter as stopped and close every tracked websocket (code 1001)."""
        self.started = False
        for connection in list(self._connections_by_session.values()):
            # Best effort: a socket that is already gone must not block shutdown.
            with suppress(Exception):
                await connection.websocket.close(code=1001)
        self._connections_by_session.clear()
        self._session_by_peer.clear()

    def status_extra(self) -> dict[str, Any]:
        """Return the number of currently tracked connections."""
        return {"connected_peers": len(self._connections_by_session)}

    async def handle_websocket(self, websocket: WebSocket) -> None:
        """Serve one websocket until it disconnects.

        Supported frame types are ``ping``, ``connect`` and ``message``; every
        other frame, invalid JSON, or a non-object payload is answered with an
        ``error`` frame and the loop continues.  On exit the connection (if
        any) is unregistered and a ``terminal_disconnected`` event is recorded.
        """
        await websocket.accept()
        connection: TerminalConnection | None = None
        try:
            while True:
                try:
                    payload = await websocket.receive_json()
                except WebSocketDisconnect:
                    break
                except ValueError:
                    await websocket.send_json({"type": "error", "error": "Invalid websocket JSON payload"})
                    continue
                if not isinstance(payload, dict):
                    await websocket.send_json({"type": "error", "error": "Websocket payload must be a JSON object"})
                    continue

                frame_type = _clean(payload.get("type")).lower()
                if frame_type == "ping":
                    await websocket.send_json({"type": "pong"})
                    continue
                if frame_type == "connect":
                    connection = await self._handle_connect(websocket, payload, current=connection)
                    continue
                if frame_type == "message":
                    if connection is None:
                        await websocket.send_json({"type": "error", "error": "connect is required before message"})
                        continue
                    await self._handle_message(websocket, connection, payload)
                    continue

                await websocket.send_json(
                    {
                        "type": "error",
                        "error": f"Unsupported websocket frame type: {frame_type or '<empty>'}",
                    }
                )
        finally:
            if connection is not None:
                self._remove_connection(connection)
                self._record(
                    kind="terminal_disconnected",
                    session_id=connection.session_id,
                    metadata={"peer_id": connection.peer_id, "device_name": connection.device_name},
                )

    @staticmethod
    def _normalize_capabilities(value: Any) -> list[str]:
        """Return the client-declared capabilities as a list of strings.

        The value comes straight from an untrusted frame.  Only a JSON array
        is accepted; ``None`` entries are dropped and anything that is not a
        list or tuple yields an empty list instead of raising or being split
        into characters.
        """
        if not isinstance(value, (list, tuple)):
            return []
        return [str(item) for item in value if item is not None]

    async def _handle_connect(
        self,
        websocket: WebSocket,
        payload: dict[str, Any],
        *,
        current: TerminalConnection | None,
    ) -> TerminalConnection | None:
        """Register the socket for the session described by a ``connect`` frame.

        Returns the new connection, or ``current`` unchanged when the frame
        has no ``peer_id`` (an ``error`` frame is sent in that case).
        """
        peer_id = _clean(payload.get("peer_id"))
        if not peer_id:
            await websocket.send_json({"type": "error", "error": "peer_id is required"})
            return current

        thread_id = _clean(payload.get("thread_id")) or None
        user_id = _clean(payload.get("user_id")) or None
        device_name = _clean(payload.get("device_name"))
        capabilities = self._normalize_capabilities(payload.get("capabilities"))
        identity = ChannelIdentity(
            channel_id=self.channel_id,
            kind=self.kind,
            account_id=self.account_id,
            peer_id=peer_id,
            thread_id=thread_id,
            peer_type="terminal",
            user_id=user_id,
        )
        session_id = identity.session_id()
        connection = TerminalConnection(
            websocket=websocket,
            peer_id=peer_id,
            session_id=session_id,
            thread_id=thread_id,
            user_id=user_id,
            device_name=device_name,
            capabilities=capabilities,
        )

        # This socket re-connected under a different session: drop its old entry.
        if current is not None and current.session_id != session_id:
            self._remove_connection(current)
        # Another socket currently serves the session: close it and take over.
        old = self._connections_by_session.get(session_id)
        if old is not None and old.websocket is not websocket:
            with suppress(Exception):
                await old.websocket.close(code=1000)
        self._connections_by_session[session_id] = connection
        self._session_by_peer[peer_id] = session_id
        self._record(
            kind="terminal_connected",
            session_id=session_id,
            metadata={"peer_id": peer_id, "device_name": device_name, "capabilities": capabilities},
        )
        await websocket.send_json(
            {
                "type": "connected",
                "channel_id": self.channel_id,
                "session_id": session_id,
            }
        )
        return connection

    async def _handle_message(
        self,
        websocket: WebSocket,
        connection: TerminalConnection,
        payload: dict[str, Any],
    ) -> None:
        """Validate a ``message`` frame, hand it to the inbound sink and send an ``ack``.

        Frames without ``message_id`` or ``text``, or with text longer than
        ``max_message_chars``, are answered with an ``error`` frame and not
        forwarded.  The ack carries ``duplicate``/``pending`` for duplicates
        and a cached ``reply`` / ``error`` whenever the sink result has one.
        """
        message_id = _clean(payload.get("message_id"))
        text = _clean(payload.get("text"))
        if not message_id:
            await websocket.send_json({"type": "error", "error": "message_id is required"})
            return
        if not text:
            await websocket.send_json({"type": "error", "error": "text is required"})
            return
        if len(text) > self.max_message_chars:
            await websocket.send_json(
                {
                    "type": "error",
                    "error": f"text exceeds max_message_chars ({self.max_message_chars})",
                }
            )
            return

        # Per-frame values override the ones captured at connect time.
        thread_id = _clean(payload.get("thread_id")) or connection.thread_id
        user_id = _clean(payload.get("user_id")) or connection.user_id
        identity = ChannelIdentity(
            channel_id=self.channel_id,
            kind=self.kind,
            account_id=self.account_id,
            peer_id=connection.peer_id,
            thread_id=thread_id,
            peer_type="terminal",
            user_id=user_id,
            message_id=message_id,
        )
        inbound = InboundMessage(
            channel=self.channel_id,
            content=text,
            content_type="text",
            user_id=user_id,
            channel_identity=identity,
            metadata={
                "terminal": {
                    "peer_id": connection.peer_id,
                    "device_name": connection.device_name,
                    "capabilities": connection.capabilities,
                }
            },
        )
        accept = await self.inbound_sink.accept_inbound(inbound)
        ack: dict[str, Any] = {
            "type": "ack",
            "message_id": message_id,
            "session_id": accept.session_id or identity.session_id(),
            "accepted": accept.accepted,
        }
        if accept.duplicate:
            ack["duplicate"] = True
            ack["pending"] = accept.pending
        record = accept.record or {}
        if record.get("reply"):
            ack["reply"] = record["reply"]
        if accept.error or record.get("error"):
            ack["error"] = accept.error or record.get("error")
        await websocket.send_json(ack)

    async def send(self, message: OutboundMessage) -> None:
        """Push an assistant reply to the socket registered for the message's session.

        When no connection is registered, or the send fails, the message is
        flagged with ``metadata["delivery_status"] = "unclaimed"`` instead of
        raising; a failing connection is also unregistered.
        """
        session_id = message.session_id
        if not session_id and message.channel_identity is not None:
            session_id = message.channel_identity.session_id()
        connection = self._connections_by_session.get(session_id or "")
        if connection is None:
            message.metadata["delivery_status"] = "unclaimed"
            return

        payload = {
            "type": "message",
            "role": "assistant",
            "message_id": message.channel_identity.message_id if message.channel_identity else message.message_id,
            "run_id": message.run_id,
            "text": message.content,
            "finish_reason": message.finish_reason,
        }
        try:
            await connection.websocket.send_json(payload)
        except Exception:
            # Deliberate best effort: a broken socket downgrades the delivery
            # to "unclaimed" rather than failing the caller.
            message.metadata["delivery_status"] = "unclaimed"
            self._remove_connection(connection)

    def _remove_connection(self, connection: TerminalConnection) -> None:
        """Unregister ``connection`` without disturbing a newer one that replaced it.

        The session entry is removed only when it still points at this exact
        connection.  The peer entry is kept when another live connection of
        the same peer now owns the session, so a stale socket that finishes
        late cannot erase the mapping of its replacement.
        """
        session_id = connection.session_id
        if self._connections_by_session.get(session_id) is connection:
            self._connections_by_session.pop(session_id, None)
        owner = self._connections_by_session.get(session_id)
        if owner is not None and owner.peer_id == connection.peer_id:
            return
        if self._session_by_peer.get(connection.peer_id) == session_id:
            self._session_by_peer.pop(connection.peer_id, None)

    def _record(
        self,
        *,
        kind: str,
        session_id: str | None = None,
        message_id: str | None = None,
        status: str = "ok",
        error: str | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        """Forward an event to ``event_recorder`` when one was configured."""
        if self.event_recorder is None:
            return
        self.event_recorder(
            channel_id=self.channel_id,
            kind=kind,
            session_id=session_id,
            message_id=message_id,
            status=status,
            error=error,
            metadata=metadata,
        )
|
||||
@ -27,12 +27,8 @@ from beaver.tools.builtins import (
|
||||
CronTool,
|
||||
DelegateTool,
|
||||
ExecuteCodeTool,
|
||||
ListDirectoryTool,
|
||||
MemoryTool,
|
||||
PatchFileTool,
|
||||
ProcessTool,
|
||||
ReadFileTool,
|
||||
SearchFilesTool,
|
||||
SendMessageTool,
|
||||
SkillManageTool,
|
||||
SkillViewTool,
|
||||
@ -40,6 +36,12 @@ from beaver.tools.builtins import (
|
||||
SpawnTool,
|
||||
TerminalTool,
|
||||
TodoTool,
|
||||
UserFilesCopyToWorkspaceTool,
|
||||
UserFilesListTool,
|
||||
UserFilesMkdirTool,
|
||||
UserFilesPublishOutputTool,
|
||||
UserFilesReadTool,
|
||||
UserFilesWriteTool,
|
||||
WebFetchTool,
|
||||
WebSearchTool,
|
||||
WriteFileTool,
|
||||
@ -47,7 +49,7 @@ from beaver.tools.builtins import (
|
||||
|
||||
|
||||
LOCAL_TOOL_CATEGORIES = {
|
||||
"filesystem": "Beaver Local Filesystem Tools",
|
||||
"filesystem": "Beaver Personal Agent Filesystem Tools",
|
||||
"runtime": "Beaver Local Runtime Tools",
|
||||
"memory": "Beaver Local Memory Tools",
|
||||
"skills": "Beaver Local Skills Tools",
|
||||
@ -84,11 +86,12 @@ def _category_tools(category: str, workspace: Path) -> tuple[list[BaseTool], Too
|
||||
|
||||
if category == "filesystem":
|
||||
tools: list[BaseTool] = [
|
||||
ObjectBackedTool(ListDirectoryTool()),
|
||||
ObjectBackedTool(ReadFileTool()),
|
||||
ObjectBackedTool(SearchFilesTool()),
|
||||
ObjectBackedTool(WriteFileTool()),
|
||||
ObjectBackedTool(PatchFileTool()),
|
||||
ObjectBackedTool(UserFilesListTool()),
|
||||
ObjectBackedTool(UserFilesReadTool()),
|
||||
ObjectBackedTool(UserFilesWriteTool()),
|
||||
ObjectBackedTool(UserFilesMkdirTool()),
|
||||
ObjectBackedTool(UserFilesCopyToWorkspaceTool()),
|
||||
ObjectBackedTool(UserFilesPublishOutputTool()),
|
||||
]
|
||||
elif category == "runtime":
|
||||
tools = [
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
152
app-instance/backend/beaver/interfaces/web/keycloak_auth.py
Normal file
152
app-instance/backend/beaver/interfaces/web/keycloak_auth.py
Normal file
@ -0,0 +1,152 @@
|
||||
"""Keycloak OIDC token verification for the Beaver web app."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
import os
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import jwt
|
||||
from jwt import PyJWKClient
|
||||
|
||||
try:
|
||||
from fastapi import HTTPException
|
||||
except ModuleNotFoundError: # pragma: no cover
|
||||
class HTTPException(Exception): # type: ignore[override]
|
||||
def __init__(self, status_code: int, detail: str) -> None:
|
||||
super().__init__(detail)
|
||||
self.status_code = status_code
|
||||
self.detail = detail
|
||||
|
||||
|
||||
DEFAULT_KEYCLOAK_ISSUER = "https://keycloak.bwgdi.com/realms/beaver"
|
||||
DEFAULT_KEYCLOAK_CLIENT_ID = "beaver-agnet"
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class KeycloakAuthConfig:
|
||||
issuer: str
|
||||
client_id: str
|
||||
token_url: str
|
||||
jwks_url: str
|
||||
|
||||
@classmethod
|
||||
def from_env(cls) -> "KeycloakAuthConfig":
|
||||
issuer = _clean_base_url(os.getenv("BEAVER_KEYCLOAK_ISSUER") or DEFAULT_KEYCLOAK_ISSUER)
|
||||
client_id = (os.getenv("BEAVER_KEYCLOAK_CLIENT_ID") or DEFAULT_KEYCLOAK_CLIENT_ID).strip()
|
||||
token_url = (
|
||||
os.getenv("BEAVER_KEYCLOAK_TOKEN_URL", "").strip()
|
||||
or f"{issuer}/protocol/openid-connect/token"
|
||||
)
|
||||
jwks_url = (
|
||||
os.getenv("BEAVER_KEYCLOAK_JWKS_URL", "").strip()
|
||||
or f"{issuer}/protocol/openid-connect/certs"
|
||||
)
|
||||
return cls(issuer=issuer, client_id=client_id, token_url=token_url, jwks_url=jwks_url)
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
class KeycloakIdentity:
    """User identity extracted from validated Keycloak token claims."""

    user_id: str  # the token's ``sub`` claim
    username: str  # ``preferred_username``, else ``email``, else ``sub``
    email: str = ""
    name: str = ""
    realm_roles: tuple[str, ...] = ()  # from ``realm_access.roles``
    client_roles: tuple[str, ...] = ()  # from ``resource_access[<client_id>].roles``
|
||||
|
||||
|
||||
def extract_bearer_token(authorization: str | None) -> str:
    """Return the token from an ``Authorization: Bearer <token>`` header value.

    The scheme is matched case-insensitively and surrounding whitespace of
    the token is removed.

    Raises:
        HTTPException: 401 when the header is missing, does not use the
            bearer scheme, or carries an empty token.
    """
    scheme = "bearer "
    if not authorization:
        raise HTTPException(status_code=401, detail="Missing Authorization header")
    has_scheme = authorization.lower().startswith(scheme)
    if not has_scheme:
        raise HTTPException(status_code=401, detail="Invalid Authorization header")
    credentials = authorization[len(scheme):].strip()
    if credentials:
        return credentials
    raise HTTPException(status_code=401, detail="Invalid token")
|
||||
|
||||
|
||||
class KeycloakTokenVerifier:
    """Verify Keycloak-issued JWTs and turn their claims into a ``KeycloakIdentity``."""

    def __init__(self, *, config: KeycloakAuthConfig) -> None:
        """Keep the config and create a JWKS client for ``config.jwks_url``."""
        self.config = config
        self._jwks_client = PyJWKClient(config.jwks_url)

    def verify(self, token: str, *, expected_nonce: str | None = None) -> KeycloakIdentity:
        """Check the token signature and claims and return the caller's identity.

        Raises:
            HTTPException: 401 for any failure while resolving the signing key
                or decoding the token, and for any claim rejected by
                ``validate_claims``.
        """
        try:
            signing_key = self._jwks_client.get_signing_key_from_jwt(token).key
            claims = jwt.decode(
                token,
                signing_key,
                algorithms=["RS256"],
                issuer=self.config.issuer,
                options={
                    "require": ["exp", "iat", "iss"],
                    # Audience is not checked by the decoder; validate_claims
                    # accepts either a matching ``aud`` or a matching ``azp``.
                    "verify_aud": False,
                },
            )
        except Exception as exc:  # noqa: BLE001 - normalize JWT/JWKS failures for HTTP callers
            # NOTE(review): the exception text is echoed to the client in ``detail``.
            raise HTTPException(status_code=401, detail=f"Invalid token: {exc}") from exc
        return self.validate_claims(claims, expected_nonce=expected_nonce)

    def validate_claims(self, claims: dict[str, Any], *, expected_nonce: str | None = None) -> KeycloakIdentity:
        """Validate already-decoded claims and build the identity.

        Checks run in this order, each failing with a 401 ``HTTPException``:
        issuer, presence of integer ``exp``/``iat``, expiry, ``iat`` not more
        than 120 seconds in the future, audience (``aud`` or ``azp`` must be
        the configured client id), nonce (only when ``expected_nonce`` is
        given), and a non-empty ``sub``.
        """
        now = int(time.time())
        issuer = str(claims.get("iss") or "")
        if issuer != self.config.issuer:
            raise HTTPException(status_code=401, detail="Invalid token issuer")
        exp = _int_claim(claims, "exp")
        iat = _int_claim(claims, "iat")
        if exp <= now:
            raise HTTPException(status_code=401, detail="Token expired")
        # Tolerate up to two minutes of clock skew on the issue time.
        if iat > now + 120:
            raise HTTPException(status_code=401, detail="Token issued in the future")
        if not _matches_client(claims.get("aud"), self.config.client_id) and claims.get("azp") != self.config.client_id:
            raise HTTPException(status_code=401, detail="Invalid token audience")
        if expected_nonce is not None and claims.get("nonce") != expected_nonce:
            raise HTTPException(status_code=401, detail="Invalid token nonce")

        user_id = str(claims.get("sub") or "").strip()
        if not user_id:
            raise HTTPException(status_code=401, detail="Token subject is required")
        # Display name preference: preferred_username, then email, then the subject id.
        username = (
            str(claims.get("preferred_username") or "").strip()
            or str(claims.get("email") or "").strip()
            or user_id
        )
        return KeycloakIdentity(
            user_id=user_id,
            username=username,
            email=str(claims.get("email") or "").strip(),
            name=str(claims.get("name") or "").strip(),
            realm_roles=_roles_from(claims.get("realm_access")),
            # Client roles are read only for the configured client id.
            client_roles=_roles_from((claims.get("resource_access") or {}).get(self.config.client_id) if isinstance(claims.get("resource_access"), dict) else None),
        )
|
||||
|
||||
|
||||
def _clean_base_url(value: str) -> str:
    """Trim surrounding whitespace, then drop any trailing slashes."""
    trimmed = value.strip()
    return trimmed.rstrip("/")
|
||||
|
||||
|
||||
def _int_claim(claims: dict[str, Any], key: str) -> int:
    """Return ``claims[key]`` converted to ``int``.

    Raises:
        HTTPException: 401 when the claim is missing or cannot be converted,
            including non-finite floats such as ``Infinity``.
    """
    try:
        return int(claims[key])
    except (KeyError, TypeError, ValueError, OverflowError) as exc:
        # OverflowError comes from int(float("inf")); without it such a claim
        # would surface as an unhandled error instead of a 401.
        raise HTTPException(status_code=401, detail=f"Token {key} claim is required") from exc
|
||||
|
||||
|
||||
def _matches_client(audience: Any, client_id: str) -> bool:
    """Tell whether an ``aud`` claim names ``client_id``.

    A string audience must equal the client id; a list audience must contain
    it (entries are compared as strings).  Any other type never matches.
    """
    if isinstance(audience, list):
        return any(str(entry) == client_id for entry in audience)
    return isinstance(audience, str) and audience == client_id
|
||||
|
||||
|
||||
def _roles_from(value: Any) -> tuple[str, ...]:
    """Return the role names from a ``{"roles": [...]}`` claim fragment.

    Anything that is not a dict holding a list under ``roles`` yields an
    empty tuple.  Entries are stringified and those that are blank after
    stripping are skipped.
    """
    roles = value.get("roles") if isinstance(value, dict) else None
    if not isinstance(roles, list):
        return ()
    names = (str(entry) for entry in roles)
    return tuple(name for name in names if name.strip())
|
||||
@ -10,6 +10,16 @@ from .chat import (
|
||||
WebErrorResponse,
|
||||
WebAgentConfigRequest,
|
||||
WebAgentConfigResponse,
|
||||
WebChannelConfigRequest,
|
||||
WebChannelConfigResponse,
|
||||
WebChannelConnectionCreateRequest,
|
||||
WebChannelConnectionResponse,
|
||||
WebChannelConnectionUpdateRequest,
|
||||
WebChannelValidationResponse,
|
||||
WebConnectorBridgeEventRequest,
|
||||
WebConnectorBridgeEventResponse,
|
||||
WebConnectorSessionCreateRequest,
|
||||
WebConnectorSessionResponse,
|
||||
WebProviderConfigRequest,
|
||||
WebProviderConfigResponse,
|
||||
WebProviderTarget,
|
||||
@ -26,6 +36,16 @@ __all__ = [
|
||||
"WebErrorResponse",
|
||||
"WebAgentConfigRequest",
|
||||
"WebAgentConfigResponse",
|
||||
"WebChannelConfigRequest",
|
||||
"WebChannelConfigResponse",
|
||||
"WebChannelConnectionCreateRequest",
|
||||
"WebChannelConnectionResponse",
|
||||
"WebChannelConnectionUpdateRequest",
|
||||
"WebChannelValidationResponse",
|
||||
"WebConnectorBridgeEventRequest",
|
||||
"WebConnectorBridgeEventResponse",
|
||||
"WebConnectorSessionCreateRequest",
|
||||
"WebConnectorSessionResponse",
|
||||
"WebProviderConfigRequest",
|
||||
"WebProviderConfigResponse",
|
||||
"WebProviderTarget",
|
||||
|
||||
@ -55,6 +55,7 @@ class WebChatRequest(BaseModel):
|
||||
user_id: str | None = None
|
||||
title: str | None = None
|
||||
execution_context: str | None = None
|
||||
prompt_locale: str | None = None
|
||||
model: str | None = None
|
||||
provider_name: str | None = None
|
||||
embedding_model: str | None = None
|
||||
@ -139,6 +140,113 @@ class WebProviderConfigResponse(BaseModel):
|
||||
enabled: bool
|
||||
|
||||
|
||||
class WebChannelConfigRequest(BaseModel):
    """Channel config update from the settings page."""

    enabled: bool = False  # defaults to disabled when the request omits it
    kind: str  # required
    mode: str  # required
    account_id: str | None = None
    display_name: str | None = None
    config: dict[str, Any] = Field(default_factory=dict)
    # Values may be None.
    # NOTE(review): presumably None clears a stored secret — confirm against the handler.
    secrets: dict[str, str | None] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class WebChannelConfigResponse(BaseModel):
    """Channel config update result."""

    # All four fields are required; none has a default.
    ok: bool
    channel_id: str
    restart_required: bool
    channel: dict[str, Any]
|
||||
|
||||
|
||||
class WebChannelConnectionCreateRequest(BaseModel):
    """Create a channel connection from the setup UI."""

    kind: str  # required
    mode: str  # required
    # The fields below declare camelCase aliases for the wire format.
    display_name: str | None = Field(default=None, alias="displayName")
    owner_user_id: str | None = Field(default=None, alias="ownerUserId")
    auth_type: str = Field(default="token", alias="authType")
    account_id: str | None = Field(default=None, alias="accountId")
    config: dict[str, Any] = Field(default_factory=dict)
    secrets: dict[str, str | None] = Field(default_factory=dict)  # values may be None
|
||||
|
||||
|
||||
class WebChannelConnectionResponse(BaseModel):
    """Channel connection response with redacted credentials."""

    connection: dict[str, Any]  # required
    credentials: dict[str, str] = Field(default_factory=dict)  # empty when omitted
|
||||
|
||||
|
||||
class WebChannelConnectionUpdateRequest(BaseModel):
    """Update editable channel connection setup fields."""

    # Every field is optional and defaults to None.
    # NOTE(review): presumably None means "leave unchanged" — confirm against the handler.
    display_name: str | None = Field(default=None, alias="displayName")
    account_id: str | None = Field(default=None, alias="accountId")
    config: dict[str, Any] | None = None
    secrets: dict[str, str | None] | None = None
|
||||
|
||||
|
||||
class WebChannelValidationResponse(BaseModel):
    """Connector validation response."""

    ok: bool  # required
    status: str  # required
    account_id: str | None = None
    display_name: str | None = None
    error: str | None = None
    metadata: dict[str, Any] = Field(default_factory=dict)
    connection: dict[str, Any]  # required
|
||||
|
||||
|
||||
class WebConnectorBridgeEventRequest(BaseModel):
    """Inbound connector bridge event from the external sidecar."""

    # Multi-word fields declare camelCase aliases for the wire format.
    event_id: str = Field(alias="eventId")
    timestamp: str
    delivery_attempt: int = Field(default=1, alias="deliveryAttempt")  # defaults to the first attempt
    connection_id: str = Field(alias="connectionId")
    channel_id: str = Field(alias="channelId")
    kind: str
    account_id: str = Field(alias="accountId")
    peer_id: str = Field(alias="peerId")
    peer_type: str = Field(default="unknown", alias="peerType")
    user_id: str | None = Field(default=None, alias="userId")
    thread_id: str | None = Field(default=None, alias="threadId")
    message_id: str = Field(alias="messageId")
    message_type: str = Field(default="text", alias="messageType")
    content: str
    metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class WebConnectorBridgeEventResponse(BaseModel):
    """Connector bridge event accept/dedupe response."""

    accepted: bool  # required
    duplicate: bool = False
    pending: bool = False
    # Optional; carries the camelCase alias ``retryAfterSeconds``.
    retry_after_seconds: int | None = Field(default=None, alias="retryAfterSeconds")
|
||||
|
||||
|
||||
class WebConnectorSessionCreateRequest(BaseModel):
    """Start a connector-managed onboarding session."""

    kind: str  # required
    display_name: str | None = Field(default=None, alias="displayName")
    owner_user_id: str | None = Field(default=None, alias="ownerUserId")
    options: dict[str, Any] = Field(default_factory=dict)  # empty when omitted
|
||||
|
||||
|
||||
class WebConnectorSessionResponse(BaseModel):
    """Connector session view plus optional connection view."""

    session: dict[str, Any]  # required
    connection: dict[str, Any] | None = None  # absent unless a connection view is supplied
|
||||
|
||||
|
||||
class WebAgentConfigRequest(BaseModel):
|
||||
"""Agent runtime defaults update from the settings page."""
|
||||
|
||||
|
||||
@ -227,6 +227,15 @@ class SkillDraftEvalReport:
|
||||
cases: list[dict[str, Any]] = field(default_factory=list)
|
||||
status: str = "completed"
|
||||
created_at: str = ""
|
||||
eval_version: str = "heuristic-v1"
|
||||
mode: str = "heuristic"
|
||||
execution_coverage: float = 0.0
|
||||
surrogate_coverage: float = 0.0
|
||||
blocked_coverage: float = 0.0
|
||||
confidence: str = "low"
|
||||
case_reports: list[dict[str, Any]] = field(default_factory=list)
|
||||
tool_mode_summary: dict[str, Any] = field(default_factory=dict)
|
||||
preservation_report: dict[str, Any] | None = None
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
@ -244,6 +253,17 @@ class SkillDraftEvalReport:
|
||||
"cases": [dict(item) for item in self.cases],
|
||||
"status": self.status,
|
||||
"created_at": self.created_at,
|
||||
"eval_version": self.eval_version,
|
||||
"mode": self.mode,
|
||||
"execution_coverage": self.execution_coverage,
|
||||
"surrogate_coverage": self.surrogate_coverage,
|
||||
"blocked_coverage": self.blocked_coverage,
|
||||
"confidence": self.confidence,
|
||||
"case_reports": [dict(item) for item in self.case_reports],
|
||||
"tool_mode_summary": dict(self.tool_mode_summary),
|
||||
"preservation_report": (
|
||||
dict(self.preservation_report) if self.preservation_report is not None else None
|
||||
),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@ -263,6 +283,23 @@ class SkillDraftEvalReport:
|
||||
cases=[dict(item) for item in payload.get("cases") or [] if isinstance(item, dict)],
|
||||
status=str(payload.get("status") or "completed"),
|
||||
created_at=str(payload.get("created_at") or ""),
|
||||
eval_version=str(payload.get("eval_version") or "heuristic-v1"),
|
||||
mode=str(payload.get("mode") or "heuristic"),
|
||||
execution_coverage=_bounded_float(payload.get("execution_coverage"), default=0.0),
|
||||
surrogate_coverage=_bounded_float(payload.get("surrogate_coverage"), default=0.0),
|
||||
blocked_coverage=_bounded_float(payload.get("blocked_coverage"), default=0.0),
|
||||
confidence=str(payload.get("confidence") or "low"),
|
||||
case_reports=[
|
||||
dict(item)
|
||||
for item in payload.get("case_reports") or []
|
||||
if isinstance(item, dict)
|
||||
],
|
||||
tool_mode_summary=dict(payload.get("tool_mode_summary") or {}),
|
||||
preservation_report=(
|
||||
dict(payload["preservation_report"])
|
||||
if isinstance(payload.get("preservation_report"), dict)
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@ -272,6 +309,15 @@ def _optional_str(value: Any) -> str | None:
|
||||
return str(value)
|
||||
|
||||
|
||||
def _bounded_float(value: Any, *, default: float = 0.0) -> float:
|
||||
if value in (None, ""):
|
||||
return default
|
||||
try:
|
||||
return max(0.0, min(1.0, float(value)))
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
|
||||
|
||||
def _summarize_evidence(payload: dict[str, Any]) -> str:
|
||||
evidence = payload.get("evidence")
|
||||
if isinstance(evidence, dict):
|
||||
|
||||
5
app-instance/backend/beaver/prompts/__init__.py
Normal file
5
app-instance/backend/beaver/prompts/__init__.py
Normal file
@ -0,0 +1,5 @@
|
||||
"""Prompt templates used by Beaver runtime components."""
|
||||
|
||||
from .main_agent import get_main_agent_prompt
|
||||
|
||||
__all__ = ["get_main_agent_prompt"]
|
||||
55
app-instance/backend/beaver/prompts/main_agent.py
Normal file
55
app-instance/backend/beaver/prompts/main_agent.py
Normal file
@ -0,0 +1,55 @@
|
||||
"""Locale-aware main agent prompt loading."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
|
||||
DEFAULT_MAIN_AGENT_PROMPT_LOCALE = "zh-Hans"

# Canonical prompt locale -> Markdown file name in the sibling ``main_agent`` directory.
_PROMPT_FILES = {
    "zh-Hans": "zh-Hans.md",
    "zh-Hant": "zh-Hant.md",
    "en": "en.md",
}

# Lower-cased locale tag -> canonical prompt locale.
_LOCALE_ALIASES = {
    "zh": "zh-Hans",
    "zh-cn": "zh-Hans",
    "zh-hans": "zh-Hans",
    "zh-sg": "zh-Hans",
    "zh-hant": "zh-Hant",
    "zh-tw": "zh-Hant",
    "zh-hk": "zh-Hant",
    "zh-mo": "zh-Hant",
    "en": "en",
    "en-us": "en",
    "en-gb": "en",
}


def get_main_agent_prompt(locale: str | None = None) -> str:
    """Return the main-agent identity prompt for a prompt locale."""

    return _load_main_agent_prompt(normalize_main_agent_prompt_locale(locale))


def normalize_main_agent_prompt_locale(locale: str | None = None) -> str:
    """Map an arbitrary locale tag onto one of the supported prompt locales.

    Matching is case-insensitive and treats ``_`` like ``-`` (``zh_TW``).
    When the full tag has no alias, trailing subtags are dropped one at a
    time, so ``en-AU`` resolves through ``en`` and ``zh-Hant-HK`` through
    ``zh-hant`` instead of falling back to the default locale.

    Args:
        locale: Locale tag to normalize; ``None`` or blank selects the default.

    Returns:
        A key of ``_PROMPT_FILES``; ``DEFAULT_MAIN_AGENT_PROMPT_LOCALE`` when
        nothing matches.
    """

    cleaned = (locale or DEFAULT_MAIN_AGENT_PROMPT_LOCALE).strip()
    if not cleaned:
        return DEFAULT_MAIN_AGENT_PROMPT_LOCALE
    candidate = cleaned.replace("_", "-").lower()
    while candidate:
        canonical = _LOCALE_ALIASES.get(candidate)
        if canonical:
            return canonical
        # Drop the last subtag ("en-au" -> "en"); yields "" once none are left.
        candidate = candidate.rpartition("-")[0]
    return cleaned if cleaned in _PROMPT_FILES else DEFAULT_MAIN_AGENT_PROMPT_LOCALE


@lru_cache(maxsize=len(_PROMPT_FILES))
def _load_main_agent_prompt(locale: str) -> str:
    """Read and cache the stripped prompt text for ``locale``.

    Unknown locales, and locales whose file is missing on disk, use the
    default locale's file. Reading that file raises if it is missing too.
    """

    prompt_dir = Path(__file__).with_name("main_agent")
    default_file = _PROMPT_FILES[DEFAULT_MAIN_AGENT_PROMPT_LOCALE]
    path = prompt_dir / _PROMPT_FILES.get(locale, default_file)
    if not path.exists():
        path = prompt_dir / default_file
    return path.read_text(encoding="utf-8").strip()
|
||||
7
app-instance/backend/beaver/prompts/main_agent/en.md
Normal file
7
app-instance/backend/beaver/prompts/main_agent/en.md
Normal file
@ -0,0 +1,7 @@
|
||||
You are Beaver, an AI assistant developed by Boway Information Systems Co., Ltd.
|
||||
|
||||
When communicating with users, keep this identity consistent. If users ask who you are, say that you are Beaver, an AI assistant developed by Boway Information Systems Co., Ltd.
|
||||
|
||||
# Language
|
||||
|
||||
Use English for user-facing replies, task titles, summaries, plans, and final reports while this prompt is active. If the user explicitly asks for another language, follow that request.
|
||||
@ -0,0 +1,7 @@
|
||||
你是海狸 (Beaver),博维资讯系统有限公司研发的 AI 助手。
|
||||
|
||||
与用户沟通时,保持这个身份一致。用户问你是谁时,说明你是海狸 (Beaver),博维资讯系统有限公司研发的 AI 助手。
|
||||
|
||||
# 语言
|
||||
|
||||
使用简体中文进行面向用户的回复、任务标题、摘要、计划和最终报告。若用户明确要求其他语言,则按用户要求执行。
|
||||
@ -0,0 +1,7 @@
|
||||
你是海狸 (Beaver),博維資訊系統有限公司研發的 AI 助手。
|
||||
|
||||
與使用者溝通時,保持這個身份一致。使用者問你是誰時,說明你是海狸 (Beaver),博維資訊系統有限公司研發的 AI 助手。
|
||||
|
||||
# 語言
|
||||
|
||||
使用繁體中文進行面向使用者的回覆、任務標題、摘要、計劃和最終報告。若使用者明確要求其他語言,則按使用者要求執行。
|
||||
@ -22,6 +22,7 @@ from beaver.engine import AgentLoop, AgentProfile, AgentRunResult, EngineLoader
|
||||
from beaver.engine.providers import make_provider_bundle
|
||||
from beaver.foundation.events import InboundMessage, OutboundMessage
|
||||
from beaver.foundation.models import CronJob, CronRunRecord
|
||||
from beaver.prompts.main_agent import normalize_main_agent_prompt_locale
|
||||
from beaver.tasks import (
|
||||
EvidenceBuilder,
|
||||
MainAgentRouter,
|
||||
@ -604,6 +605,8 @@ class AgentService:
|
||||
if active_task is not None and decision.short_title and not active_task.metadata.get("short_title"):
|
||||
active_task.metadata["short_title"] = decision.short_title
|
||||
task_service.store.upsert_task(active_task)
|
||||
if active_task is not None and (decision.action == "simple_chat" or decision.starts_new_task):
|
||||
await self._accept_active_task_for_new_topic(active_task)
|
||||
if active_task is not None and decision.closes_task:
|
||||
task_service.close_task(active_task.task_id, reason=decision.reason)
|
||||
return await runner(message, **kwargs)
|
||||
@ -620,6 +623,7 @@ class AgentService:
|
||||
session_id=session_id,
|
||||
description=message,
|
||||
metadata={
|
||||
"prompt_locale": normalize_main_agent_prompt_locale(kwargs.get("prompt_locale")),
|
||||
"router_reason": decision.reason,
|
||||
**({"short_title": decision.short_title} if decision.short_title else {}),
|
||||
},
|
||||
@ -636,6 +640,20 @@ class AgentService:
|
||||
)
|
||||
return await self._run_task_mode(message, runner=runner, kwargs=kwargs, task=task)
|
||||
|
||||
async def _accept_active_task_for_new_topic(self, task: TaskRecord) -> None:
|
||||
"""Accept a completed active Task before routing an unrelated new topic."""
|
||||
|
||||
if task.status != "awaiting_acceptance":
|
||||
return
|
||||
run_id = next((item for item in reversed(task.run_ids) if item), None)
|
||||
if not run_id:
|
||||
return
|
||||
await self.submit_acceptance(
|
||||
session_id=task.session_id,
|
||||
run_id=run_id,
|
||||
acceptance_type="accept",
|
||||
)
|
||||
|
||||
def _record_revision_acceptance_for_task(
|
||||
self,
|
||||
loaded: Any,
|
||||
@ -733,6 +751,8 @@ class AgentService:
|
||||
session_manager = self._require_loaded(loaded, "session_manager")
|
||||
|
||||
base_execution_context = kwargs.get("execution_context")
|
||||
prompt_locale = kwargs.get("prompt_locale") or task.metadata.get("prompt_locale")
|
||||
output_language_instruction = self._output_language_instruction(prompt_locale)
|
||||
provider_bundle = kwargs.get("provider_bundle") or self._make_provider_bundle_for_task(loaded, kwargs)
|
||||
kwargs = dict(kwargs)
|
||||
team_provider_bundle_factory = kwargs.pop("team_provider_bundle_factory", None)
|
||||
@ -827,8 +847,11 @@ class AgentService:
|
||||
"allow_candidate_generation": False,
|
||||
}
|
||||
)
|
||||
if team_execution_context:
|
||||
attempt_kwargs["execution_context"] = self._join_context(base_execution_context, team_execution_context)
|
||||
attempt_kwargs["execution_context"] = self._join_context(
|
||||
base_execution_context,
|
||||
output_language_instruction,
|
||||
team_execution_context,
|
||||
)
|
||||
if plan.is_team and team_execution_context:
|
||||
attempt_kwargs["include_tools"] = False
|
||||
attempt_kwargs["max_tool_iterations"] = 0
|
||||
@ -963,6 +986,24 @@ class AgentService:
|
||||
"short_title": decision.short_title,
|
||||
}
|
||||
|
||||
@staticmethod
def _output_language_instruction(prompt_locale: str | None) -> str:
    """Return the output-language directive for a prompt locale.

    The locale is normalized first; English and Traditional Chinese have
    their own directive and every other locale gets Simplified Chinese.
    """

    simplified_chinese = (
        "输出语言:简体中文。除非用户明确要求其他语言,所有面向用户的任务标题、摘要、"
        "计划与最终回答都使用简体中文。"
    )
    by_locale = {
        "en": (
            "Output language: English. Use English for user-facing task titles, summaries, plans, "
            "and final answers unless the user explicitly requests another language."
        ),
        "zh-Hant": (
            "輸出語言:繁體中文。除非使用者明確要求其他語言,所有面向使用者的任務標題、摘要、"
            "計劃與最終回答都使用繁體中文。"
        ),
    }
    return by_locale.get(normalize_main_agent_prompt_locale(prompt_locale), simplified_chinese)
|
||||
|
||||
@staticmethod
|
||||
def _skill_names_for_run(loaded: Any, run_id: str) -> list[str]:
|
||||
store = getattr(loaded, "run_memory_store", None)
|
||||
@ -1237,17 +1278,19 @@ class AgentService:
|
||||
async def handle_inbound_message(self, inbound: InboundMessage) -> OutboundMessage:
|
||||
"""把 bus inbound 映射成标准 runtime 调用,并返回结构化 outbound。"""
|
||||
|
||||
channel_identity = inbound.channel_identity
|
||||
try:
|
||||
result = await self.submit_direct(
|
||||
inbound.content,
|
||||
session_id=inbound.session_id,
|
||||
source=f"gateway:{inbound.channel}",
|
||||
user_id=inbound.user_id,
|
||||
user_id=inbound.user_id or (channel_identity.user_id if channel_identity else None),
|
||||
title=inbound.title,
|
||||
execution_context=inbound.execution_context,
|
||||
model=inbound.model,
|
||||
provider_name=inbound.provider_name,
|
||||
embedding_model=inbound.embedding_model,
|
||||
channel_identity=channel_identity,
|
||||
)
|
||||
except Exception as exc:
|
||||
return self.build_outbound_error(
|
||||
@ -1283,6 +1326,8 @@ class AgentService:
|
||||
finish_reason=result.finish_reason,
|
||||
provider_name=result.provider_name,
|
||||
model=result.model,
|
||||
content_type=inbound.content_type,
|
||||
channel_identity=inbound.channel_identity,
|
||||
usage=dict(result.usage),
|
||||
metadata={
|
||||
"inbound_metadata": dict(inbound.metadata),
|
||||
@ -1308,6 +1353,8 @@ class AgentService:
|
||||
session_id=inbound.session_id,
|
||||
content=detail,
|
||||
finish_reason=finish_reason,
|
||||
content_type=inbound.content_type,
|
||||
channel_identity=inbound.channel_identity,
|
||||
metadata={"error": detail, "inbound_metadata": dict(inbound.metadata)},
|
||||
)
|
||||
|
||||
|
||||
201
app-instance/backend/beaver/services/user_file_resolver.py
Normal file
201
app-instance/backend/beaver/services/user_file_resolver.py
Normal file
@ -0,0 +1,201 @@
|
||||
"""Resolve the user-visible file system for web and agent callers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from beaver.foundation.config.schema import BeaverConfig
|
||||
|
||||
from .user_files import (
|
||||
LocalUserFileStorage,
|
||||
MinIOStorageConfig,
|
||||
MinIOUserFileStorage,
|
||||
USER_FILE_ROOTS,
|
||||
UserFileError,
|
||||
UserFileService,
|
||||
)
|
||||
|
||||
|
||||
class UserFileConfigurationError(UserFileError):
    """Signals that user file storage is not configured for this backend."""
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class FileAuthContext:
|
||||
"""Authenticated identity used by the personal file system boundary."""
|
||||
|
||||
username: str
|
||||
backend_id: str
|
||||
storage_namespace: str
|
||||
user_id: str | None = None
|
||||
scopes: tuple[str, ...] = field(default_factory=tuple)
|
||||
auth_source: str = "beaver-web-token"
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class UserFileStorageStatus:
|
||||
configured: bool
|
||||
storage_mode: str
|
||||
roots: list[str]
|
||||
workspace_visible: bool = False
|
||||
detail: str | None = None
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
payload: dict[str, Any] = {
|
||||
"configured": self.configured,
|
||||
"storage_mode": self.storage_mode,
|
||||
"roots": self.roots,
|
||||
"workspace_visible": self.workspace_visible,
|
||||
}
|
||||
if self.detail:
|
||||
payload["detail"] = self.detail
|
||||
return payload
|
||||
|
||||
|
||||
class UserFileStorageResolver:
    """Build `UserFileService` from the current Beaver identity and config."""

    def __init__(
        self,
        *,
        config: BeaverConfig,
        workspace: Path,
        auth_context: FileAuthContext,
    ) -> None:
        self.config = config
        self.workspace = Path(workspace)
        self.auth_context = auth_context

    async def service(self) -> UserFileService:
        """Return a file service for the storage mode chosen by ``_storage_mode``.

        Local mode stores files under ``<workspace>/user_files``. Otherwise the
        MinIO settings are fetched from AuthZ on every call.

        Raises:
            UserFileConfigurationError: If the MinIO settings cannot be loaded.
        """
        mode = _storage_mode(self.config)
        if mode == "local":
            return UserFileService(LocalUserFileStorage(self.workspace / "user_files"))
        settings = await self._load_minio_settings()
        return UserFileService(
            MinIOUserFileStorage(
                MinIOStorageConfig(
                    endpoint=str(settings.get("endpoint") or ""),
                    access_key=str(settings.get("access_key") or ""),
                    secret_key=str(settings.get("secret_key") or ""),
                    bucket=str(settings.get("bucket") or ""),
                    # NOTE(review): bool() treats any non-empty string (even "false")
                    # as True; this assumes AuthZ sends a JSON boolean — confirm.
                    secure=bool(settings.get("secure", False)),
                    region=_clean_optional(settings.get("region")),
                    namespace=str(settings.get("namespace") or self.auth_context.storage_namespace),
                )
            )
        )

    async def status(self) -> UserFileStorageStatus:
        """Report whether user file storage is usable, without building a service.

        In non-local mode the MinIO settings are loaded as a probe; a
        configuration error is reported as ``configured=False`` with the error
        text in ``detail`` instead of being raised.
        """
        mode = _storage_mode(self.config)
        if mode == "local":
            return UserFileStorageStatus(
                configured=True,
                storage_mode="local",
                roots=list(USER_FILE_ROOTS),
                workspace_visible=False,
            )
        try:
            await self._load_minio_settings()
        except UserFileConfigurationError as exc:
            return UserFileStorageStatus(
                configured=False,
                storage_mode="object",
                roots=list(USER_FILE_ROOTS),
                workspace_visible=False,
                detail=str(exc),
            )
        return UserFileStorageStatus(
            configured=True,
            storage_mode="object",
            roots=list(USER_FILE_ROOTS),
            workspace_visible=False,
        )

    async def _load_minio_settings(self) -> dict[str, Any]:
        """Fetch this backend's MinIO settings from the AuthZ internal API.

        Returns:
            The settings mapping, with ``namespace`` defaulted to the caller's
            storage namespace when the response omits it.

        Raises:
            UserFileConfigurationError: If the backend id, AuthZ config or
                internal token is missing, the request fails, or the response
                is an error, not JSON, not an object, or incomplete.
        """
        from urllib.parse import quote

        backend_id = self.auth_context.backend_id.strip()
        if not backend_id:
            raise UserFileConfigurationError("User file storage backend identity is not configured")
        base_url = self.config.authz.base_url.strip()
        if not (self.config.authz.enabled and base_url):
            raise UserFileConfigurationError("AuthZ is required for deployed user file storage")
        token = (
            os.getenv("BEAVER_AUTHZ_INTERNAL_TOKEN", "").strip()
            or os.getenv("AUTHZ_INTERNAL_TOKEN", "").strip()
        )
        if not token:
            raise UserFileConfigurationError("AuthZ internal token is not configured for user file storage")
        # The backend id can fall back to a username, so escape it fully: a "/",
        # "?" or "#" must not be able to change which internal endpoint is called.
        settings_url = f"{base_url.rstrip('/')}/internal/backends/{quote(backend_id, safe='')}/settings/minio"
        try:
            async with httpx.AsyncClient(
                timeout=self.config.authz.request_timeout_seconds,
                follow_redirects=True,
                trust_env=False,
            ) as client:
                response = await client.get(
                    settings_url,
                    headers={"Authorization": f"Bearer {token}"},
                )
        except httpx.HTTPError as exc:
            raise UserFileConfigurationError(f"Unable to load user file storage settings: {exc}") from exc
        if response.status_code == 404:
            raise UserFileConfigurationError("MinIO user file storage is not configured")
        if response.is_error:
            raise UserFileConfigurationError(
                f"Unable to load user file storage settings: HTTP {response.status_code}"
            )
        try:
            payload = response.json()
        except ValueError as exc:
            # A non-JSON body (e.g. an HTML page from a proxy) must surface as a
            # configuration error so status() can report it instead of crashing.
            raise UserFileConfigurationError("Invalid MinIO settings response") from exc
        if not isinstance(payload, dict):
            raise UserFileConfigurationError("Invalid MinIO settings response")
        if not all(str(payload.get(key) or "").strip() for key in ("endpoint", "access_key", "secret_key", "bucket")):
            raise UserFileConfigurationError("MinIO user file storage settings are incomplete")
        payload.setdefault("namespace", self.auth_context.storage_namespace)
        return payload
|
||||
|
||||
|
||||
def build_file_auth_context(
    *,
    username: str,
    config: BeaverConfig,
    user_id: str | None = None,
    scopes: tuple[str, ...] = (),
    auth_source: str = "beaver-web-token",
) -> FileAuthContext:
    """Assemble the file-system identity for a user.

    The backend id comes from the first non-blank of: the configured backend
    identity, the ``BEAVER_BACKEND_IDENTITY__BACKEND_ID`` environment
    variable, and the stripped username. The storage namespace is derived
    from that backend id.
    """
    backend_id = config.backend_identity.backend_id.strip()
    if not backend_id:
        backend_id = os.getenv("BEAVER_BACKEND_IDENTITY__BACKEND_ID", "").strip()
    if not backend_id:
        backend_id = username.strip()
    return FileAuthContext(
        username=username.strip(),
        backend_id=backend_id,
        storage_namespace=default_user_file_namespace(backend_id),
        user_id=user_id,
        scopes=scopes,
        auth_source=auth_source,
    )
|
||||
|
||||
|
||||
def default_user_file_namespace(backend_id: str) -> str:
    """Return the storage namespace ``users/<backend_id>`` for a backend.

    Surrounding whitespace is removed first, then leading and trailing
    slashes. An empty result maps to ``users/unconfigured``.
    """
    identifier = backend_id.strip().strip("/")
    if not identifier:
        return "users/unconfigured"
    return "users/" + identifier
|
||||
|
||||
|
||||
def _storage_mode(config: BeaverConfig) -> str:
|
||||
raw = os.getenv("BEAVER_USER_FILES_STORAGE_MODE", "").strip().lower()
|
||||
if raw in {"local", "dev-local", "development"}:
|
||||
return "local"
|
||||
if raw in {"minio", "object", "object-storage"}:
|
||||
return "minio"
|
||||
if config.authz.enabled and config.authz.base_url.strip() and config.backend_identity.backend_id.strip():
|
||||
return "minio"
|
||||
return "local"
|
||||
|
||||
|
||||
def _clean_optional(value: Any) -> str | None:
|
||||
text = str(value or "").strip()
|
||||
return text or None
|
||||
630
app-instance/backend/beaver/services/user_files.py
Normal file
630
app-instance/backend/beaver/services/user_files.py
Normal file
@ -0,0 +1,630 @@
|
||||
"""User-visible file system service.
|
||||
|
||||
This module owns the personal file-system boundary exposed to users and
|
||||
agents. Storage backends can change, but callers see only virtual paths under
|
||||
fixed roots.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import suppress
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from io import BytesIO
|
||||
import mimetypes
|
||||
from pathlib import Path, PurePosixPath
|
||||
import shutil
|
||||
import tempfile
|
||||
from typing import Protocol
|
||||
|
||||
|
||||
USER_FILE_ROOTS = ("uploads", "outputs", "shared", "tasks")
|
||||
MAX_PREVIEW_BYTES = 1024 * 1024
|
||||
AGENT_UPLOADS_ERROR = "uploads/ is user-provided input storage; agents may read it but must not write it"
|
||||
AGENT_DELETE_ERROR = "agents cannot delete user-visible files; use the Files page or user-side APIs"
|
||||
|
||||
|
||||
class UserFileError(ValueError):
    """Root of the user file exception hierarchy; a ``ValueError`` subclass."""
|
||||
|
||||
|
||||
class UserFilePathError(UserFileError):
    """A user file path was rejected by the virtual path policy."""
|
||||
|
||||
|
||||
class UserFileNotFoundError(UserFileError):
    """The requested user file path does not exist."""
|
||||
|
||||
|
||||
class UserFileSizeError(UserFileError):
    """An upload grew past the configured size limit."""
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
class AgentUserFilePolicy:
    """Path rules applied when an agent touches user-visible files.

    Every check first normalizes the path with ``allow_root=False``. Writes
    under ``uploads/`` are refused, writes under ``tasks/`` must sit inside
    the agent's own task namespace, and deletes are always refused.
    """

    task_id: str | None = None
    fallback_scope: str = "interactive"

    @property
    def task_namespace(self) -> str:
        """Namespace the agent may write to under ``tasks/``."""
        if not self.task_id:
            # No task id: fall back to a sanitized interactive scope.
            return f"tasks/interactive/{_safe_scope(self.fallback_scope)}"
        return f"tasks/{self.task_id}"

    def validate_read(self, path: str) -> str:
        """Return the normalized path for a read."""
        return normalize_user_path(path, allow_root=False)

    def validate_write(self, path: str) -> str:
        """Return the normalized path for a write, or raise ``UserFilePathError``."""
        normalized = normalize_user_path(path, allow_root=False)
        root, _, _ = normalized.partition("/")
        if root == "uploads":
            raise UserFilePathError(AGENT_UPLOADS_ERROR)
        if root == "tasks":
            self._validate_task_namespace(normalized)
        return normalized

    def validate_mkdir(self, path: str) -> str:
        """Directory creation follows the same rules as a write."""
        return self.validate_write(path)

    def validate_delete(self, path: str) -> str:
        """Normalize the path, then always raise: agents may not delete."""
        normalize_user_path(path, allow_root=False)
        raise UserFilePathError(AGENT_DELETE_ERROR)

    def _validate_task_namespace(self, normalized: str) -> None:
        """Raise unless ``normalized`` lies strictly inside ``task_namespace``."""
        namespace = self.task_namespace
        inside_namespace = normalized != "tasks" and normalized.startswith(f"{namespace}/")
        if not inside_namespace:
            raise UserFilePathError(f"Agent task files must be written under {namespace}/")
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class UserFileEntry:
|
||||
name: str
|
||||
path: str
|
||||
type: str
|
||||
size: int | None = None
|
||||
content_type: str | None = None
|
||||
modified: str | None = None
|
||||
|
||||
def to_dict(self) -> dict[str, object]:
|
||||
return {
|
||||
"name": self.name,
|
||||
"path": self.path,
|
||||
"type": self.type,
|
||||
"size": self.size,
|
||||
"content_type": self.content_type,
|
||||
"modified": self.modified,
|
||||
}
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class UserFileContent:
|
||||
name: str
|
||||
path: str
|
||||
size: int
|
||||
content_type: str
|
||||
modified: str | None
|
||||
content: bytes
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class UserFilePreview:
|
||||
name: str
|
||||
path: str
|
||||
size: int
|
||||
content_type: str
|
||||
modified: str | None
|
||||
is_binary: bool
|
||||
is_truncated: bool
|
||||
content: str | None
|
||||
|
||||
def to_dict(self) -> dict[str, object]:
|
||||
return {
|
||||
"name": self.name,
|
||||
"path": self.path,
|
||||
"size": self.size,
|
||||
"content_type": self.content_type,
|
||||
"modified": self.modified,
|
||||
"is_binary": self.is_binary,
|
||||
"is_truncated": self.is_truncated,
|
||||
"content": self.content,
|
||||
}
|
||||
|
||||
|
||||
class UserFileStorage(Protocol):
    """Structural interface a user-file storage backend provides to `UserFileService`."""

    async def list_dir(self, path: str) -> list[UserFileEntry]:
        """List the entries under a directory path."""
        ...

    async def read_file(self, path: str, *, max_bytes: int | None = None) -> UserFileContent:
        """Read a file, optionally limited to ``max_bytes``."""
        ...

    async def write_file(self, path: str, content: bytes, *, content_type: str) -> UserFileEntry:
        """Write ``content`` to ``path`` and return the resulting entry."""
        ...

    async def write_file_stream(
        self,
        path: str,
        stream: object,
        *,
        content_type: str,
        max_bytes: int | None = None,
        part_size: int = 10 * 1024 * 1024,
    ) -> UserFileEntry:
        """Write a stream to ``path`` in ``part_size`` chunks, honoring ``max_bytes``."""
        ...

    async def delete_path(self, path: str) -> bool:
        """Delete a path and return a success flag."""
        ...

    async def mkdir(self, path: str) -> UserFileEntry:
        """Create a directory and return its entry."""
        ...
|
||||
|
||||
|
||||
class UserFileService:
    """Validate virtual paths, then delegate file operations to a storage backend."""

    def __init__(self, storage: UserFileStorage) -> None:
        self.storage = storage

    async def browse(self, path: str = "") -> dict[str, object]:
        """List a directory; the empty path yields the fixed roots without touching storage."""
        normalized = normalize_user_path(path, allow_root=True)
        if normalized != "":
            listing = await self.storage.list_dir(normalized)
            return {"path": normalized, "items": [entry.to_dict() for entry in listing]}
        root_items = [
            UserFileEntry(name=root, path=root, type="directory").to_dict()
            for root in USER_FILE_ROOTS
        ]
        return {"path": "", "items": root_items}

    async def upload(self, directory: str, filename: str, content: bytes, *, content_type: str) -> dict[str, object]:
        """Store ``content`` as ``directory/filename`` and return the new entry."""
        target = self._upload_target(directory, filename)
        entry = await self.storage.write_file(target, content, content_type=content_type)
        return entry.to_dict()

    async def upload_stream(
        self,
        directory: str,
        filename: str,
        stream: object,
        *,
        content_type: str,
        max_bytes: int | None = None,
        part_size: int = 10 * 1024 * 1024,
    ) -> dict[str, object]:
        """Store a stream as ``directory/filename`` and return the new entry."""
        target = self._upload_target(directory, filename)
        entry = await self.storage.write_file_stream(
            target,
            stream,
            content_type=content_type,
            max_bytes=max_bytes,
            part_size=part_size,
        )
        return entry.to_dict()

    async def write_file(self, path: str, content: bytes | str, *, content_type: str = "text/plain") -> dict[str, object]:
        """Write text (encoded as UTF-8) or bytes to ``path`` and return the entry."""
        normalized = normalize_user_path(path, allow_root=False)
        if isinstance(content, str):
            raw = content.encode("utf-8")
        else:
            raw = bytes(content)
        entry = await self.storage.write_file(normalized, raw, content_type=content_type)
        return entry.to_dict()

    async def download(self, path: str) -> UserFileContent:
        """Read the whole file at ``path``."""
        normalized = normalize_user_path(path, allow_root=False)
        return await self.storage.read_file(normalized)

    async def preview(self, path: str, *, max_bytes: int = MAX_PREVIEW_BYTES) -> dict[str, object]:
        """Read up to ``max_bytes`` and return a preview; text is decoded leniently as UTF-8."""
        normalized = normalize_user_path(path, allow_root=False)
        loaded = await self.storage.read_file(normalized, max_bytes=max_bytes)
        binary = _is_probably_binary(loaded.content, loaded.content_type)
        if binary:
            text = None
        else:
            text = loaded.content.decode("utf-8", errors="replace")
        preview = UserFilePreview(
            name=loaded.name,
            path=loaded.path,
            size=loaded.size,
            content_type=loaded.content_type,
            modified=loaded.modified,
            is_binary=binary,
            # Fewer bytes read than the reported size means the preview was cut.
            is_truncated=loaded.size > len(loaded.content),
            content=text,
        )
        return preview.to_dict()

    async def delete(self, path: str) -> bool:
        """Delete a path; the virtual root folders themselves are protected."""
        normalized = normalize_user_path(path, allow_root=False)
        if normalized in USER_FILE_ROOTS:
            raise UserFilePathError("Cannot delete virtual root folders")
        return await self.storage.delete_path(normalized)

    async def mkdir(self, path: str) -> dict[str, object]:
        """Create a directory below one of the virtual roots and return its entry."""
        normalized = normalize_user_path(path, allow_root=False)
        if normalized in USER_FILE_ROOTS:
            raise UserFilePathError("Virtual root folders already exist")
        created = await self.storage.mkdir(normalized)
        return created.to_dict()

    @staticmethod
    def _upload_target(directory: str, filename: str) -> str:
        """Check the filename, then return the normalized ``directory/filename`` path."""
        if not is_safe_filename(filename):
            raise UserFilePathError("Invalid filename")
        return normalize_user_path(_join_user_path(directory, filename), allow_root=False)
|
||||
|
||||
|
||||
class LocalUserFileStorage:
    """Filesystem-backed storage adapter for tests and local development.

    The methods are ``async`` to match the storage interface but perform
    ordinary blocking filesystem calls.
    """

    def __init__(self, root: Path) -> None:
        """Resolve ``root`` and create it together with every virtual root folder."""
        self.root = Path(root).expanduser().resolve()
        self.root.mkdir(parents=True, exist_ok=True)
        for name in USER_FILE_ROOTS:
            (self.root / name).mkdir(parents=True, exist_ok=True)

    async def list_dir(self, path: str) -> list[UserFileEntry]:
        """List a directory, directories first, then by case-insensitive name.

        A missing directory is created and listed as empty. Dot-prefixed names
        are hidden, which also hides the temp files of in-progress streamed
        uploads.

        Raises:
            UserFilePathError: If the path exists but is not a directory.
        """
        target = self._path(path)
        if not target.exists():
            target.mkdir(parents=True, exist_ok=True)
        if not target.is_dir():
            raise UserFilePathError("Path is not a directory")
        entries: list[UserFileEntry] = []
        for child in sorted(target.iterdir(), key=lambda item: (not item.is_dir(), item.name.lower())):
            if child.name.startswith("."):
                continue
            # A child can vanish, or be a dangling symlink, between iterdir()
            # and stat(); skip it rather than failing the whole listing.
            with suppress(FileNotFoundError):
                entries.append(self._entry(child))
        return entries

    async def read_file(self, path: str, *, max_bytes: int | None = None) -> UserFileContent:
        """Read a file, returning at most ``max_bytes`` bytes when a limit is given.

        ``size`` in the result is always the full on-disk size, so callers can
        tell whether the content was cut short.

        Raises:
            UserFileNotFoundError: If the path is not an existing regular file.
        """
        target = self._path(path)
        if not target.is_file():
            raise UserFileNotFoundError("File not found")
        with target.open("rb") as handle:
            # Read only the requested prefix so a preview of a large file does
            # not pull the whole file into memory.
            selected = handle.read() if max_bytes is None else handle.read(max(max_bytes, 0))
        stat = target.stat()
        content_type, _ = mimetypes.guess_type(target.name)
        return UserFileContent(
            name=target.name,
            path=self._relative(target),
            size=stat.st_size,
            content_type=content_type or "application/octet-stream",
            modified=_iso_from_timestamp(stat.st_mtime),
            content=selected,
        )

    async def write_file(self, path: str, content: bytes, *, content_type: str) -> UserFileEntry:
        """Write ``content`` to ``path``, creating parent directories as needed."""
        target = self._path(path)
        target.parent.mkdir(parents=True, exist_ok=True)
        target.write_bytes(content)
        return self._entry(target, content_type=content_type)

    async def write_file_stream(
        self,
        path: str,
        stream: object,
        *,
        content_type: str,
        max_bytes: int | None = None,
        part_size: int = 10 * 1024 * 1024,
    ) -> UserFileEntry:
        """Copy ``stream`` to ``path`` in chunks via a temp file in the same directory.

        The temp file replaces the target only after the whole stream was
        written; on any error it is removed and the target is left untouched.

        Raises:
            UserFileSizeError: If more than ``max_bytes`` bytes are read.
        """
        target = self._path(path)
        target.parent.mkdir(parents=True, exist_ok=True)
        fd, tmp_name = tempfile.mkstemp(prefix=f".{target.name}.", suffix=".tmp", dir=target.parent)
        tmp_path = Path(tmp_name)
        total = 0
        try:
            with open(fd, "wb", closefd=True) as output:
                while True:
                    chunk = stream.read(part_size)  # type: ignore[attr-defined]
                    if not chunk:
                        break
                    total += len(chunk)
                    if max_bytes is not None and total > max_bytes:
                        raise UserFileSizeError(_size_error(max_bytes))
                    output.write(chunk)
            tmp_path.replace(target)
        except Exception:
            with suppress(FileNotFoundError):
                tmp_path.unlink()
            raise
        return self._entry(target, content_type=content_type)

    async def delete_path(self, path: str) -> bool:
        """Delete a file or a whole directory tree; return False when nothing exists."""
        target = self._path(path)
        if not target.exists():
            return False
        if target.is_dir():
            shutil.rmtree(target)
        else:
            target.unlink()
        return True

    async def mkdir(self, path: str) -> UserFileEntry:
        """Create a directory (and parents); an existing directory is not an error."""
        target = self._path(path)
        target.mkdir(parents=True, exist_ok=True)
        return self._entry(target)

    def _path(self, path: str) -> Path:
        """Map a virtual path to a resolved filesystem path inside the root.

        Raises:
            UserFilePathError: If the resolved path falls outside the root.
        """
        normalized = normalize_user_path(path, allow_root=False)
        target = (self.root / normalized).resolve()
        try:
            target.relative_to(self.root)
        except ValueError as exc:
            raise UserFilePathError("Path escapes user file root") from exc
        return target

    def _relative(self, path: Path) -> str:
        """Return ``path`` relative to the root, with POSIX separators."""
        return path.relative_to(self.root).as_posix()

    def _entry(self, path: Path, *, content_type: str | None = None) -> UserFileEntry:
        """Build the listing entry for an existing path.

        For files the explicit ``content_type`` wins, then the type guessed
        from the name, then ``application/octet-stream``.
        """
        stat = path.stat()
        guessed_type, _ = mimetypes.guess_type(path.name)
        return UserFileEntry(
            name=path.name,
            path=self._relative(path),
            type="directory" if path.is_dir() else "file",
            size=None if path.is_dir() else stat.st_size,
            content_type=None if path.is_dir() else (content_type or guessed_type or "application/octet-stream"),
            modified=_iso_from_timestamp(stat.st_mtime),
        )
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class MinIOStorageConfig:
|
||||
endpoint: str
|
||||
access_key: str
|
||||
secret_key: str
|
||||
bucket: str
|
||||
secure: bool = False
|
||||
region: str | None = None
|
||||
namespace: str = ""
|
||||
|
||||
|
||||
class MinIOUserFileStorage:
    """MinIO-backed user file storage adapter.

    User-facing paths are validated with ``normalize_user_path`` and mapped to
    object keys under the optional ``config.namespace`` prefix. A directory is
    represented by a zero-byte ``.keep`` object stored inside it.

    The public methods are coroutines, but the client calls inside them are
    made synchronously (nothing is awaited).
    """

    def __init__(self, config: MinIOStorageConfig) -> None:
        """Create the client from ``config``.

        Raises:
            ValueError: if endpoint, access key, secret key or bucket is empty.
        """
        if not config.endpoint or not config.access_key or not config.secret_key or not config.bucket:
            raise ValueError("MinIO storage requires endpoint, access key, secret key, and bucket")
        # Imported here rather than at module import time.
        from minio import Minio

        self.config = config
        self.client = Minio(
            endpoint=config.endpoint,
            access_key=config.access_key,
            secret_key=config.secret_key,
            secure=config.secure,
            region=config.region,
        )

    async def list_dir(self, path: str) -> list[UserFileEntry]:
        """List the direct children of ``path``, directories first, then by name.

        ``.keep`` marker objects and the listed directory itself are omitted.
        Files are always reported as ``application/octet-stream`` here.
        """
        prefix = self._object_prefix(path)
        # The listed directory as a normalized user path without trailing slash,
        # so the self-entry check does not depend on how the caller spelled it.
        listed_dir = self._user_path(prefix).rstrip("/")
        objects = self.client.list_objects(self.config.bucket, prefix=prefix, recursive=False)
        entries: list[UserFileEntry] = []
        for obj in objects:
            object_name = str(obj.object_name or "")
            user_path = self._user_path(object_name)
            trimmed = user_path.rstrip("/")
            if not trimmed or trimmed == listed_dir or user_path.endswith("/.keep"):
                continue
            is_dir = bool(getattr(obj, "is_dir", False)) or object_name.endswith("/")
            entries.append(
                UserFileEntry(
                    name=PurePosixPath(trimmed).name,
                    path=trimmed,
                    type="directory" if is_dir else "file",
                    size=None if is_dir else getattr(obj, "size", None),
                    content_type=None if is_dir else "application/octet-stream",
                    modified=obj.last_modified.isoformat() if getattr(obj, "last_modified", None) else None,
                )
            )
        return sorted(entries, key=lambda item: (item.type != "directory", item.name.lower()))

    async def read_file(self, path: str, *, max_bytes: int | None = None) -> UserFileContent:
        """Read an object, optionally limited to the first ``max_bytes`` bytes.

        Raises:
            UserFileNotFoundError: if the stat or the download fails for any
                reason (the original error is chained as the cause).
        """
        object_name = self._object_name(path)
        try:
            stat = self.client.stat_object(self.config.bucket, object_name)
            if max_bytes is None:
                response = self.client.get_object(self.config.bucket, object_name)
            else:
                response = self.client.get_object(self.config.bucket, object_name, length=max_bytes)
            try:
                raw = response.read()
            finally:
                # Always hand the connection back, even when read() fails.
                response.close()
                response.release_conn()
        except Exception as exc:
            raise UserFileNotFoundError("File not found") from exc
        return UserFileContent(
            name=PurePosixPath(path).name,
            path=path,
            size=int(stat.size or len(raw)),
            content_type=stat.content_type or "application/octet-stream",
            modified=stat.last_modified.isoformat() if stat.last_modified else None,
            content=raw,
        )

    async def write_file(self, path: str, content: bytes, *, content_type: str) -> UserFileEntry:
        """Store ``content`` at ``path`` and return the resulting file entry."""
        object_name = self._object_name(path)
        self.client.put_object(
            self.config.bucket,
            object_name,
            BytesIO(content),
            length=len(content),
            content_type=content_type,
        )
        return UserFileEntry(
            name=PurePosixPath(path).name,
            path=path,
            type="file",
            size=len(content),
            content_type=content_type,
            modified=datetime.now(timezone.utc).isoformat(),
        )

    async def write_file_stream(
        self,
        path: str,
        stream: object,
        *,
        content_type: str,
        max_bytes: int | None = None,
        part_size: int = 10 * 1024 * 1024,
    ) -> UserFileEntry:
        """Upload ``stream`` to ``path`` without knowing its size up front.

        The stream is wrapped in ``_LimitedReadStream`` so the upload aborts
        with ``UserFileSizeError`` once more than ``max_bytes`` bytes have been
        read. ``part_size`` is raised to at least 5 MiB.

        Raises:
            UserFileSizeError: if the stream exceeds ``max_bytes``.
        """
        object_name = self._object_name(path)
        reader = _LimitedReadStream(stream, max_bytes=max_bytes)
        try:
            self.client.put_object(
                self.config.bucket,
                object_name,
                reader,
                length=-1,
                part_size=max(5 * 1024 * 1024, part_size),
                content_type=content_type,
            )
        except UserFileSizeError:
            # Best-effort cleanup; a failure here must not hide the size error.
            # NOTE(review): if an object already existed at this key, this
            # cleanup removes it as well -- confirm that is intended.
            try:
                self.client.remove_object(self.config.bucket, object_name)
            except Exception:
                pass
            raise
        return UserFileEntry(
            name=PurePosixPath(path).name,
            path=path,
            type="file",
            size=reader.bytes_read,
            content_type=content_type,
            modified=datetime.now(timezone.utc).isoformat(),
        )

    async def delete_path(self, path: str) -> bool:
        """Delete the object at ``path`` and every object below ``path/``.

        Returns:
            True if at least one object was removed, False if nothing existed.
        """
        object_name = self._object_name(path)
        removed = False
        try:
            # remove_object alone is not relied on to tell whether the key
            # existed, so existence is established with stat_object first.
            self.client.stat_object(self.config.bucket, object_name)
            self.client.remove_object(self.config.bucket, object_name)
            removed = True
        except Exception:
            # Best effort: ``path`` may be a directory with no object of its
            # own; its children are handled by the prefix sweep below.
            pass
        prefix = f"{object_name.rstrip('/')}/"
        for obj in self.client.list_objects(self.config.bucket, prefix=prefix, recursive=True):
            self.client.remove_object(self.config.bucket, str(obj.object_name))
            removed = True
        return removed

    async def mkdir(self, path: str) -> UserFileEntry:
        """Create a directory by writing an empty ``.keep`` marker inside it."""
        object_name = f"{self._object_name(path).rstrip('/')}/.keep"
        self.client.put_object(
            self.config.bucket,
            object_name,
            BytesIO(b""),
            length=0,
            content_type="application/x-directory",
        )
        return UserFileEntry(
            name=PurePosixPath(path).name,
            path=path,
            type="directory",
            size=None,
            modified=datetime.now(timezone.utc).isoformat(),
        )

    def _namespace(self) -> str:
        """Return the configured namespace without leading/trailing slashes."""
        return self.config.namespace.strip("/")

    def _object_name(self, path: str) -> str:
        """Map a user path to its object key (namespace prefix + normalized path).

        Raises:
            UserFilePathError: if the path is invalid or the resulting key is
                absolute or contains a ``..`` component.
        """
        normalized = normalize_user_path(path, allow_root=False)
        namespace = self._namespace()
        object_name = f"{namespace}/{normalized}" if namespace else normalized
        if object_name.startswith("/") or "/../" in f"/{object_name}/":
            raise UserFilePathError("Object path escapes namespace")
        return object_name

    def _object_prefix(self, path: str) -> str:
        """Return the object key of ``path`` with exactly one trailing slash."""
        return f"{self._object_name(path).rstrip('/')}/"

    def _user_path(self, object_name: str) -> str:
        """Strip the namespace prefix from an object key.

        Raises:
            UserFilePathError: if a namespace is configured and the key is not
                under it.
        """
        namespace = self._namespace()
        if namespace:
            prefix = f"{namespace}/"
            if not object_name.startswith(prefix):
                raise UserFilePathError("Object path escapes namespace")
            return object_name[len(prefix) :]
        return object_name
|
||||
|
||||
|
||||
def normalize_user_path(path: str | None, *, allow_root: bool) -> str:
    """Validate a user-supplied path and return it in canonical form.

    Backslashes are treated as separators, surrounding whitespace and slashes
    are removed, and ``.`` components are dropped. The result is a
    ``/``-joined relative path whose first component is one of
    ``USER_FILE_ROOTS``.

    Args:
        path: The raw path; ``None`` is treated as empty.
        allow_root: When True an empty path yields ``""``; otherwise an empty
            path is an error.

    Raises:
        UserFilePathError: for NUL bytes, absolute paths, empty paths (unless
            ``allow_root``), ``..`` components, components starting with
            ``.``, or a first component outside ``USER_FILE_ROOTS``.
    """
    original = (path or "").replace("\\", "/").strip()
    # NUL can never be part of a valid file name; reject it here so it fails as
    # a path error instead of surfacing later from filesystem or client calls.
    if "\x00" in original:
        raise UserFilePathError("Path contains invalid characters")
    if original.startswith("/"):
        raise UserFilePathError("Absolute paths are not allowed")
    raw = original.strip("/")
    if raw == "":
        if allow_root:
            return ""
        raise UserFilePathError("Path is required")
    posix = PurePosixPath(raw)
    # Defensive: leading slashes were already stripped above.
    if posix.is_absolute():
        raise UserFilePathError("Absolute paths are not allowed")
    parts = [part for part in posix.parts if part not in ("", ".")]
    if any(part == ".." for part in parts):
        raise UserFilePathError("Parent-directory traversal is not allowed")
    if any(part.startswith(".") for part in parts):
        raise UserFilePathError("Hidden implementation paths are not allowed")
    if not parts or parts[0] not in USER_FILE_ROOTS:
        raise UserFilePathError("Path must be under uploads, outputs, shared, or tasks")
    return "/".join(parts)
|
||||
|
||||
|
||||
def is_safe_filename(filename: str) -> bool:
    """Return True for a non-empty name with no path separator and no leading dot."""
    if not filename:
        return False
    if filename.startswith("."):
        return False
    return not any(separator in filename for separator in ("/", "\\"))
|
||||
|
||||
|
||||
def _join_user_path(directory: str, filename: str) -> str:
    """Join ``filename`` onto the normalized form of ``directory``.

    ``directory`` is validated by ``normalize_user_path``; ``filename`` is
    appended as given.
    """
    base = normalize_user_path(directory, allow_root=False).rstrip("/")
    return f"{base}/{filename}"
|
||||
|
||||
|
||||
def _is_probably_binary(raw: bytes, content_type: str) -> bool:
|
||||
if content_type.startswith("text/") or content_type in {
|
||||
"application/json",
|
||||
"application/javascript",
|
||||
"application/xml",
|
||||
"application/x-yaml",
|
||||
}:
|
||||
return False
|
||||
if not raw:
|
||||
return False
|
||||
if b"\x00" in raw[:4096]:
|
||||
return True
|
||||
try:
|
||||
raw[:4096].decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _iso_from_timestamp(value: float) -> str:
|
||||
return datetime.fromtimestamp(value, tz=timezone.utc).isoformat()
|
||||
|
||||
|
||||
def _safe_scope(value: str | None) -> str:
|
||||
raw = (value or "interactive").strip()
|
||||
allowed = [char if char.isalnum() or char in ("-", "_") else "-" for char in raw]
|
||||
cleaned = "".join(allowed).strip("-_")
|
||||
return cleaned or "interactive"
|
||||
|
||||
|
||||
class _LimitedReadStream:
|
||||
def __init__(self, stream: object, *, max_bytes: int | None = None) -> None:
|
||||
self.stream = stream
|
||||
self.max_bytes = max_bytes
|
||||
self.bytes_read = 0
|
||||
|
||||
def read(self, size: int = -1) -> bytes:
|
||||
chunk = self.stream.read(size) # type: ignore[attr-defined]
|
||||
if not chunk:
|
||||
return b""
|
||||
self.bytes_read += len(chunk)
|
||||
if self.max_bytes is not None and self.bytes_read > self.max_bytes:
|
||||
raise UserFileSizeError(_size_error(self.max_bytes))
|
||||
return chunk
|
||||
|
||||
|
||||
def _size_error(max_bytes: int) -> str:
    """Build the "file too large" message for a limit of ``max_bytes`` bytes."""
    limit = _human_size(max_bytes)
    return f"File too large (max {limit})"
|
||||
|
||||
|
||||
def _human_size(size: int) -> str:
|
||||
units = ("B", "KB", "MB", "GB", "TB")
|
||||
value = float(size)
|
||||
for unit in units:
|
||||
if value < 1024 or unit == units[-1]:
|
||||
return f"{value:.0f}{unit}" if unit == "B" else f"{value:.1f}{unit}"
|
||||
value /= 1024
|
||||
return f"{size}B"
|
||||
@ -32,16 +32,23 @@ When there is an active task, do not force every new user message into that task
|
||||
|
||||
- Choose `revise_task` when the user asks to change, correct, refine, expand, reformat, or redo the latest active task result.
|
||||
- Choose `continue_task` for neutral follow-up questions or additional next steps that still belong to the active task.
|
||||
- Choose `new_task` when the user asks for clearly unrelated work.
|
||||
- Choose `simple_chat` for unrelated lightweight conversation. This starts a new topic and the previous task will be accepted automatically.
|
||||
- Choose `new_task` when the user asks for clearly unrelated work that needs Task capabilities. This starts a new topic and the previous task will be accepted automatically.
|
||||
- Choose `close_task` when the user says the task is satisfactory or finished, such as "可以了", "就这样", or "that's good".
|
||||
- Choose `abandon_task` when the user says to stop, cancel, or no longer do the active task.
|
||||
|
||||
Do not classify unrelated lightweight conversation as `revise_task` merely because
|
||||
the active task is awaiting acceptance. A revision must ask to change or correct
|
||||
the active task result.
|
||||
|
||||
Examples with an active weather task:
|
||||
|
||||
- "再详细一点" -> `revise_task`
|
||||
- "加上明后天穿衣建议" -> `revise_task`
|
||||
- "顺便查一下深圳" -> `continue_task`
|
||||
- "帮我写一个采购合同" -> `new_task`
|
||||
- "吃饭没" -> `simple_chat`
|
||||
- "我在冰岛" -> `simple_chat`
|
||||
- "可以了" -> `close_task`
|
||||
- "不用了" -> `abandon_task`
|
||||
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
"""Skill learning loop helpers."""
|
||||
|
||||
from .case_selection import select_replay_cases
|
||||
from .evidence import EvidencePacket, EvidenceSelector
|
||||
from .eval import SkillDraftEvaluator
|
||||
from .missing_skill import (
|
||||
@ -9,11 +10,15 @@ from .missing_skill import (
|
||||
MissingSkillSynthesizer,
|
||||
)
|
||||
from .pipeline import SkillLearningPipelineService
|
||||
from .preservation import check_preservation
|
||||
from .replay import ReplayArmRequest, ReplayRunner, ReplayToolExecutor, ReplayToolPolicy, classify_tool_mode
|
||||
from .service import RunReceiptContext, SkillLearningService
|
||||
from .surrogate import SurrogateToolEvaluator
|
||||
from .synthesizer import SkillDraftSynthesizer
|
||||
from .worker import SkillLearningWorker, SkillLearningWorkerConfig, SkillLearningWorkerResult
|
||||
|
||||
__all__ = [
|
||||
"select_replay_cases",
|
||||
"EvidencePacket",
|
||||
"EvidenceSelector",
|
||||
"SkillDraftEvaluator",
|
||||
@ -23,6 +28,13 @@ __all__ = [
|
||||
"MissingSkillSynthesizer",
|
||||
"RunReceiptContext",
|
||||
"SkillLearningPipelineService",
|
||||
"check_preservation",
|
||||
"ReplayToolExecutor",
|
||||
"ReplayToolPolicy",
|
||||
"ReplayArmRequest",
|
||||
"ReplayRunner",
|
||||
"classify_tool_mode",
|
||||
"SurrogateToolEvaluator",
|
||||
"SkillDraftSynthesizer",
|
||||
"SkillLearningService",
|
||||
"SkillLearningWorker",
|
||||
|
||||
109
app-instance/backend/beaver/skills/learning/case_selection.py
Normal file
109
app-instance/backend/beaver/skills/learning/case_selection.py
Normal file
@ -0,0 +1,109 @@
|
||||
"""Historical replay case selection for skill draft evaluation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from beaver.memory.runs import RunRecord
|
||||
from beaver.memory.skills import SkillLearningCandidate
|
||||
|
||||
MAX_REPLAY_CASES = 10
|
||||
|
||||
|
||||
def select_replay_cases(candidate: SkillLearningCandidate, runs: list[RunRecord]) -> list[dict[str, Any]]:
    """Pick up to ``MAX_REPLAY_CASES`` accepted runs to replay for ``candidate``.

    Only runs passing ``_is_accepted`` are considered. The selection strategy
    depends on ``candidate.kind``: ``revise_skill`` and ``merge_skills`` have
    dedicated selectors, everything else uses the new-skill selector.
    """
    accepted = [record for record in runs if _is_accepted(record)]
    selectors = {
        "revise_skill": _select_revise,
        "merge_skills": _select_merge,
    }
    choose = selectors.get(candidate.kind, _select_new)
    chosen = choose(candidate, accepted)
    return [_case_payload(candidate, record) for record in chosen[:MAX_REPLAY_CASES]]
|
||||
|
||||
|
||||
def _select_revise(candidate: SkillLearningCandidate, runs: list[RunRecord]) -> list[RunRecord]:
    """Select runs that activated the skill being revised.

    The target is the first related skill name (empty string if there is
    none). When the candidate evidence carries a ``skill_version``, only
    receipts with that exact version count.
    """
    related = candidate.related_skill_names
    target = related[0] if related else ""
    version = str(candidate.evidence.get("skill_version") or "")

    def _activated_target(record: RunRecord) -> bool:
        for receipt in record.activated_skills:
            if receipt.skill_name != target:
                continue
            if not version or receipt.skill_version == version:
                return True
        return False

    return _recent_diverse([record for record in runs if _activated_target(record)])
|
||||
|
||||
|
||||
def _select_merge(candidate: SkillLearningCandidate, runs: list[RunRecord]) -> list[RunRecord]:
    """Select runs in which every skill to be merged was activated.

    With no related skill names nothing matches.
    """
    required = set(candidate.related_skill_names)
    matches: list[RunRecord] = []
    if required:
        for record in runs:
            activated = {receipt.skill_name for receipt in record.activated_skills}
            if required <= activated:
                matches.append(record)
    return _recent_diverse(matches)
|
||||
|
||||
|
||||
def _select_new(candidate: SkillLearningCandidate, runs: list[RunRecord]) -> list[RunRecord]:
    """Select runs for a new-skill candidate.

    The candidate's source runs are used when it names any; otherwise runs
    whose task text contains the evidence ``theme`` (case-insensitive) are
    used. An empty theme matches nothing.
    """
    wanted_ids = set(candidate.source_run_ids)
    if wanted_ids:
        return _recent_diverse([record for record in runs if record.run_id in wanted_ids])

    theme = str(candidate.evidence.get("theme") or "").lower().strip()
    if not theme:
        return _recent_diverse([])
    return _recent_diverse([record for record in runs if theme in record.task_text.lower()])
|
||||
|
||||
|
||||
def _case_payload(candidate: SkillLearningCandidate, record: RunRecord) -> dict[str, Any]:
    """Build the replay-case dict for one historical run.

    The baseline skills are the first related skill for ``revise_skill``, all
    related skills for ``merge_skills``, and none for any other kind.
    """
    kind = candidate.kind
    if kind == "revise_skill":
        baseline_skill_names = list(candidate.related_skill_names[:1])
    elif kind == "merge_skills":
        baseline_skill_names = list(candidate.related_skill_names)
    else:
        baseline_skill_names = []

    payload: dict[str, Any] = {
        "run_id": record.run_id,
        "task_id": record.task_id,
        "session_id": record.session_id,
        "task_text": record.task_text,
    }
    payload["baseline_skill_names"] = baseline_skill_names
    payload["candidate_skill_name"] = candidate.draft_skill_name
    payload["accepted_score"] = _score(record)
    return payload
|
||||
|
||||
|
||||
def _recent_diverse(runs: list[RunRecord]) -> list[RunRecord]:
    """Return at most ``MAX_REPLAY_CASES`` runs, newest first.

    Runs are ordered by ``(started_at, run_id)`` descending. When there are
    more runs than the cap, only the newest run per task (task id, or task
    text when the id is empty) is taken in the first pass; if that leaves
    fewer than the cap, the remaining runs fill up in the same order.
    """
    newest_first = sorted(runs, key=lambda item: (item.started_at, item.run_id), reverse=True)
    # De-duplicating by task only matters when there is a surplus of runs.
    surplus = len(newest_first) > MAX_REPLAY_CASES
    picked: list[RunRecord] = []
    tasks_seen: set[str] = set()
    for record in newest_first:
        task_key = record.task_id or record.task_text
        if task_key in tasks_seen and surplus:
            continue
        tasks_seen.add(task_key)
        picked.append(record)
        if len(picked) >= MAX_REPLAY_CASES:
            break

    wanted = min(len(newest_first), MAX_REPLAY_CASES)
    if len(picked) < wanted:
        picked_ids = {record.run_id for record in picked}
        for record in newest_first:
            if record.run_id not in picked_ids:
                picked.append(record)
    return picked[:MAX_REPLAY_CASES]
|
||||
|
||||
|
||||
def _is_accepted(record: RunRecord) -> bool:
|
||||
feedback = record.feedback or {}
|
||||
acceptance = feedback.get("acceptance_type")
|
||||
if acceptance is None and feedback.get("feedback_type") == "satisfied":
|
||||
acceptance = "accept"
|
||||
return bool(record.success) and acceptance == "accept"
|
||||
|
||||
|
||||
def _score(record: RunRecord) -> float:
|
||||
validation = record.validation_result or {}
|
||||
value = validation.get("score") if isinstance(validation, dict) else None
|
||||
if value is not None:
|
||||
try:
|
||||
return max(0.0, min(1.0, float(value)))
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
return 0.8 if record.success else 0.4
|
||||
@ -4,17 +4,28 @@ from __future__ import annotations
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
from beaver.engine.context import SkillContext
|
||||
from beaver.engine.providers import ProviderBundle
|
||||
from beaver.memory.runs import RunMemoryStore
|
||||
from beaver.memory.skills import SkillDraftEvalReport, SkillLearningCandidate
|
||||
from beaver.skills.learning.case_selection import select_replay_cases
|
||||
from beaver.skills.learning.preservation import check_preservation
|
||||
from beaver.skills.learning.replay import ReplayArmRequest, ReplayRunner
|
||||
from beaver.skills.learning.surrogate import SurrogateToolEvaluator
|
||||
from beaver.skills.specs import SkillDraft
|
||||
|
||||
|
||||
class SkillDraftEvaluator:
|
||||
"""Builds a bounded eval report without writing user-visible sessions."""
|
||||
|
||||
def __init__(self, run_store: RunMemoryStore) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
run_store: RunMemoryStore,
|
||||
*,
|
||||
surrogate_evaluator: SurrogateToolEvaluator | None = None,
|
||||
) -> None:
|
||||
self.run_store = run_store
|
||||
self.surrogate_evaluator = surrogate_evaluator or SurrogateToolEvaluator()
|
||||
|
||||
async def evaluate(
|
||||
self,
|
||||
@ -22,11 +33,30 @@ class SkillDraftEvaluator:
|
||||
candidate: SkillLearningCandidate,
|
||||
draft: SkillDraft,
|
||||
provider_bundle: ProviderBundle | None,
|
||||
replay_runner: ReplayRunner | None = None,
|
||||
) -> SkillDraftEvalReport:
|
||||
if provider_bundle is None or provider_bundle.main_provider is None:
|
||||
return self._skipped(candidate, draft)
|
||||
|
||||
runs_by_id = {record.run_id: record for record in self.run_store.list_runs()}
|
||||
runs = self.run_store.list_runs()
|
||||
replay_cases = select_replay_cases(candidate, runs)
|
||||
if replay_runner is not None and replay_cases:
|
||||
return await self._evaluate_replay(
|
||||
candidate=candidate,
|
||||
draft=draft,
|
||||
replay_cases=replay_cases,
|
||||
provider_bundle=provider_bundle,
|
||||
replay_runner=replay_runner,
|
||||
)
|
||||
return self._evaluate_heuristic(candidate, draft, runs)
|
||||
|
||||
def _evaluate_heuristic(
|
||||
self,
|
||||
candidate: SkillLearningCandidate,
|
||||
draft: SkillDraft,
|
||||
runs: list,
|
||||
) -> SkillDraftEvalReport:
|
||||
runs_by_id = {record.run_id: record for record in runs}
|
||||
cases: list[dict] = []
|
||||
for run_id in candidate.source_run_ids[:8]:
|
||||
record = runs_by_id.get(run_id)
|
||||
@ -78,6 +108,78 @@ class SkillDraftEvaluator:
|
||||
created_at=_utc_now(),
|
||||
)
|
||||
|
||||
async def _evaluate_replay(
    self,
    *,
    candidate: SkillLearningCandidate,
    draft: SkillDraft,
    replay_cases: list[dict],
    provider_bundle: ProviderBundle,
    replay_runner: ReplayRunner,
) -> SkillDraftEvalReport:
    """Replay every case with and without the draft and build the eval report.

    For each case two arms are run through ``replay_runner``: a ``baseline``
    arm pinned to the case's ``baseline_skill_names`` and a ``candidate`` arm
    pinned to a skill context built from ``draft``. The pair is scored by
    ``self.surrogate_evaluator`` and the per-case results are aggregated by
    ``_report_from_case_reports``.
    """
    case_reports: list[dict] = []
    legacy_cases: list[dict] = []

    for case in replay_cases:
        run_id = case["run_id"]
        task_text = str(case["task_text"])

        baseline_request = ReplayArmRequest(
            case_id=f"{run_id}:baseline",
            arm="baseline",
            task_text=task_text,
            pinned_skill_names=list(case.get("baseline_skill_names") or []),
            pinned_skill_contexts=[],
            provider_bundle=provider_bundle,
            model_settings={"max_tool_iterations": 4, "temperature": 0.0},
        )
        baseline = await replay_runner.run_arm(baseline_request)

        candidate_request = ReplayArmRequest(
            case_id=f"{run_id}:candidate",
            arm="candidate",
            task_text=task_text,
            pinned_skill_names=[],
            pinned_skill_contexts=[_draft_skill_context(draft)],
            provider_bundle=provider_bundle,
            model_settings={"max_tool_iterations": 4, "temperature": 0.0},
        )
        candidate_arm = await replay_runner.run_arm(candidate_request)

        surrogate = await self.surrogate_evaluator.evaluate(
            task_text=task_text,
            baseline=baseline,
            candidate=candidate_arm,
        )
        baseline_score = surrogate["baseline_score"]
        candidate_score = surrogate["candidate_score"]
        delta = round(candidate_score - baseline_score, 4)

        # Detailed per-case record: both arm outputs plus tool-mode statistics.
        case_reports.append(
            {
                "run_id": run_id,
                "task_id": case.get("task_id"),
                "session_id": case.get("session_id"),
                "baseline": baseline,
                "candidate": candidate_arm,
                "baseline_score": baseline_score,
                "candidate_score": candidate_score,
                "delta": delta,
                "execution_coverage": _arm_mode_coverage(baseline, candidate_arm, "executed"),
                "surrogate_coverage": _arm_mode_coverage(baseline, candidate_arm, "surrogate"),
                "blocked_tool_count": _arm_mode_count(baseline, candidate_arm, "blocked"),
                "confidence": surrogate["confidence"],
                "tool_calls": [*baseline.get("tool_calls", []), *candidate_arm.get("tool_calls", [])],
                "artifacts": [*baseline.get("artifacts", []), *candidate_arm.get("artifacts", [])],
                "side_effects": [*baseline.get("side_effects", []), *candidate_arm.get("side_effects", [])],
                "validator_notes": list(surrogate.get("notes") or []),
            }
        )
        # Compact record holding only the scores.
        legacy_cases.append(
            {
                "run_id": run_id,
                "session_id": case.get("session_id") or "",
                "baseline_score": baseline_score,
                "candidate_score": candidate_score,
                "delta": delta,
            }
        )

    preservation_report = _preservation_report(candidate, draft)
    return _report_from_case_reports(candidate, draft, case_reports, legacy_cases, preservation_report)
|
||||
|
||||
def _skipped(self, candidate: SkillLearningCandidate, draft: SkillDraft) -> SkillDraftEvalReport:
|
||||
return SkillDraftEvalReport(
|
||||
report_id=uuid4().hex,
|
||||
@ -115,6 +217,108 @@ def _candidate_score(baseline: float, draft: SkillDraft) -> float:
|
||||
return min(1.0, max(0.75, baseline + 0.05))
|
||||
|
||||
|
||||
def _draft_skill_context(draft: SkillDraft) -> SkillContext:
    """Wrap a draft as a ``SkillContext`` for the candidate replay arm.

    Tool hints come from the ``tools`` frontmatter entry when it is a list;
    items are stringified and blank ones dropped.
    """
    raw_hints = draft.proposed_frontmatter.get("tools")
    if isinstance(raw_hints, list):
        hints = [str(item) for item in raw_hints if str(item).strip()]
    else:
        hints = []
    return SkillContext(
        name=f"draft:{draft.skill_name}",
        content=draft.proposed_content,
        version=draft.draft_id,
        content_hash="draft",
        activation_reason="skill_replay_eval_candidate",
        tool_hints=hints,
    )
|
||||
|
||||
|
||||
def _preservation_report(candidate: SkillLearningCandidate, draft: SkillDraft) -> dict | None:
    """Run the preservation check for revise/merge candidates.

    Returns ``None`` for other candidate kinds, or when the candidate evidence
    carries no non-blank ``base_content``.
    """
    if candidate.kind not in {"revise_skill", "merge_skills"}:
        return None
    evidence = candidate.evidence
    if isinstance(evidence, dict):
        base_content = str(evidence.get("base_content") or "")
    else:
        base_content = ""
    if base_content.strip():
        return check_preservation(base_content=base_content, draft_content=draft.proposed_content)
    return None
|
||||
|
||||
|
||||
def _report_from_case_reports(
    candidate: SkillLearningCandidate,
    draft: SkillDraft,
    case_reports: list[dict],
    legacy_cases: list[dict],
    preservation_report: dict | None,
) -> SkillDraftEvalReport:
    """Aggregate per-case replay results into a ``SkillDraftEvalReport``.

    The report passes when the average candidate score is at least 0.75, the
    tool calls were not entirely blocked, and there is no regressed case
    combined with a non-positive average score delta.

    Raises:
        ZeroDivisionError: if ``legacy_cases`` is empty.
    """
    case_count = len(legacy_cases)
    baseline_avg = sum(item["baseline_score"] for item in legacy_cases) / case_count
    candidate_avg = sum(item["candidate_score"] for item in legacy_cases) / case_count

    regressions = [item for item in legacy_cases if item["candidate_score"] < item["baseline_score"]]
    improved = [item for item in legacy_cases if item["candidate_score"] > item["baseline_score"]]
    unchanged = case_count - len(regressions) - len(improved)

    execution, surrogate, blocked = _coverage(case_reports)
    case_confidences = [item.get("confidence") for item in case_reports]
    confidence = _confidence(execution, surrogate, blocked, case_confidences)

    score_delta = candidate_avg - baseline_avg
    net_regression = bool(regressions) and score_delta <= 0
    passed = candidate_avg >= 0.75 and not net_regression and blocked < 1.0

    tool_mode_summary = {"executed": execution, "surrogate": surrogate, "blocked": blocked}
    return SkillDraftEvalReport(
        report_id=uuid4().hex,
        skill_name=draft.skill_name,
        draft_id=draft.draft_id,
        candidate_id=candidate.candidate_id,
        passed=passed,
        baseline_score_avg=round(baseline_avg, 4),
        candidate_score_avg=round(candidate_avg, 4),
        score_delta=round(score_delta, 4),
        regression_count=len(regressions),
        improved_count=len(improved),
        unchanged_count=unchanged,
        cases=legacy_cases,
        status="completed",
        created_at=_utc_now(),
        eval_version="replay-v1",
        mode="replay",
        execution_coverage=execution,
        surrogate_coverage=surrogate,
        blocked_coverage=blocked,
        confidence=confidence,
        case_reports=case_reports,
        tool_mode_summary=tool_mode_summary,
        preservation_report=preservation_report,
    )
|
||||
|
||||
|
||||
def _coverage(case_reports: list[dict]) -> tuple[float, float, float]:
|
||||
counts = {"executed": 0, "surrogate": 0, "blocked": 0}
|
||||
for report in case_reports:
|
||||
for call in report.get("tool_calls") or []:
|
||||
if isinstance(call, dict) and call.get("mode") in counts:
|
||||
counts[str(call["mode"])] += 1
|
||||
total = sum(counts.values())
|
||||
if total == 0:
|
||||
return 1.0, 0.0, 0.0
|
||||
return (
|
||||
round(counts["executed"] / total, 4),
|
||||
round(counts["surrogate"] / total, 4),
|
||||
round(counts["blocked"] / total, 4),
|
||||
)
|
||||
|
||||
|
||||
def _confidence(execution: float, surrogate: float, blocked: float, case_confidences: list[object]) -> str:
|
||||
if blocked > 0.0:
|
||||
return "low"
|
||||
if execution >= 0.75 and surrogate <= 0.25:
|
||||
return "high"
|
||||
if execution >= 0.25 or "medium" in case_confidences:
|
||||
return "medium"
|
||||
return "low"
|
||||
|
||||
|
||||
def _arm_mode_coverage(baseline: dict, candidate: dict, mode: str) -> float:
|
||||
calls = [*baseline.get("tool_calls", []), *candidate.get("tool_calls", [])]
|
||||
if not calls:
|
||||
return 1.0 if mode == "executed" else 0.0
|
||||
return round(sum(1 for call in calls if isinstance(call, dict) and call.get("mode") == mode) / len(calls), 4)
|
||||
|
||||
|
||||
def _arm_mode_count(baseline: dict, candidate: dict, mode: str) -> int:
|
||||
calls = [*baseline.get("tool_calls", []), *candidate.get("tool_calls", [])]
|
||||
return sum(1 for call in calls if isinstance(call, dict) and call.get("mode") == mode)
|
||||
|
||||
|
||||
def _utc_now() -> str:
|
||||
from datetime import datetime, timezone
|
||||
|
||||
|
||||
@ -8,6 +8,7 @@ from beaver.engine.providers import ProviderBundle
|
||||
from beaver.memory.skills import SkillDraftEvalReport, SkillDraftSafetyReport, SkillLearningCandidate, SkillLearningStore
|
||||
from beaver.skills.drafts import DraftService
|
||||
from beaver.skills.learning.eval import SkillDraftEvaluator
|
||||
from beaver.skills.learning.replay import ReplayRunner
|
||||
from beaver.skills.learning.service import SkillLearningService
|
||||
from beaver.skills.learning.safety import SkillDraftSafetyChecker
|
||||
from beaver.skills.publisher import SkillPublisher
|
||||
@ -285,11 +286,17 @@ class SkillLearningPipelineService:
|
||||
draft_id: str,
|
||||
*,
|
||||
provider_bundle: ProviderBundle | None,
|
||||
replay_runner: ReplayRunner | None = None,
|
||||
) -> SkillDraftEvalReport:
|
||||
draft = self.get_draft(skill_name, draft_id)
|
||||
candidate = self.get_candidate(candidate_id)
|
||||
evaluator = self.evaluator or SkillDraftEvaluator(self.learning_service.run_store)
|
||||
report = await evaluator.evaluate(candidate=candidate, draft=draft, provider_bundle=provider_bundle)
|
||||
report = await evaluator.evaluate(
|
||||
candidate=candidate,
|
||||
draft=draft,
|
||||
provider_bundle=provider_bundle,
|
||||
replay_runner=replay_runner,
|
||||
)
|
||||
self.learning_store.write_eval_report(report)
|
||||
if report.status == "skipped_provider_unavailable":
|
||||
status = "draft_ready"
|
||||
@ -330,6 +337,14 @@ class SkillLearningPipelineService:
|
||||
eval_report = self.get_eval_report(draft.skill_name, draft.draft_id)
|
||||
if eval_report is not None and eval_report.status != "skipped_provider_unavailable" and not eval_report.passed:
|
||||
raise ValueError("Draft eval report did not pass")
|
||||
if eval_report is not None and eval_report.mode == "replay":
|
||||
if eval_report.confidence == "low":
|
||||
raise ValueError("Draft replay eval has low confidence and requires revision before publish")
|
||||
if eval_report.blocked_coverage >= 1.0:
|
||||
raise ValueError("Draft replay eval blocked all important tool calls")
|
||||
preservation = eval_report.preservation_report or {}
|
||||
if preservation.get("passed") is False:
|
||||
raise ValueError("Draft preservation check did not pass")
|
||||
|
||||
def _mark_candidate_by_draft(
|
||||
self,
|
||||
|
||||
53
app-instance/backend/beaver/skills/learning/preservation.py
Normal file
53
app-instance/backend/beaver/skills/learning/preservation.py
Normal file
@ -0,0 +1,53 @@
|
||||
"""Preservation checks for skill revision drafts."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
|
||||
def check_preservation(*, base_content: str, draft_content: str) -> dict[str, Any]:
    """Compare the sections of a base skill document with those of a draft.

    Every section of ``base_content`` is classified as dropped (heading absent
    from the draft) or preserved (heading present); preserved sections whose
    whitespace- and case-normalized bodies differ are also listed as changed.
    The check passes only when nothing was dropped.
    """
    base_sections = _sections(base_content)
    draft_sections = _sections(draft_content)

    preserved: list[str] = []
    changed: list[str] = []
    dropped: list[str] = []
    for heading, base_body in base_sections.items():
        if heading not in draft_sections:
            dropped.append(heading)
            continue
        preserved.append(heading)
        if _normalize(base_body) != _normalize(draft_sections[heading]):
            changed.append(heading)

    return {
        "passed": not dropped,
        "risk_level": "high" if dropped else "low",
        "preserved_sections": preserved,
        "changed_sections": changed,
        "dropped_sections": dropped,
    }
|
||||
|
||||
|
||||
def _sections(content: str) -> dict[str, str]:
|
||||
current = "body"
|
||||
sections: dict[str, list[str]] = {current: []}
|
||||
for line in (content or "").splitlines():
|
||||
match = re.match(r"^#{1,6}\s+(.+?)\s*$", line)
|
||||
if match:
|
||||
current = match.group(1).strip()
|
||||
sections.setdefault(current, [])
|
||||
continue
|
||||
sections.setdefault(current, []).append(line)
|
||||
return {
|
||||
heading: "\n".join(lines).strip()
|
||||
for heading, lines in sections.items()
|
||||
if "\n".join(lines).strip()
|
||||
}
|
||||
|
||||
|
||||
def _normalize(value: str) -> str:
|
||||
return re.sub(r"\s+", " ", value or "").strip().lower()
|
||||
203
app-instance/backend/beaver/skills/learning/replay.py
Normal file
203
app-instance/backend/beaver/skills/learning/replay.py
Normal file
@ -0,0 +1,203 @@
|
||||
"""Replay execution helpers for skill draft evaluation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Literal
|
||||
from uuid import uuid4
|
||||
|
||||
from beaver.tools.base import ToolContext, ToolResult, ToolSpec
|
||||
from beaver.tools.registry.tool_registry import ToolRegistry
|
||||
from beaver.tools.runtime.executor import ToolExecutor
|
||||
|
||||
ToolExecutionMode = Literal["executed", "surrogate", "blocked"]
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ReplayToolPolicy:
|
||||
safe_toolsets: set[str] = field(default_factory=lambda: {"filesystem", "user_files", "core", "web", "search"})
|
||||
surrogate_transports: set[str] = field(default_factory=lambda: {"mcp", "connector"})
|
||||
destructive_terms: tuple[str, ...] = (
|
||||
"delete",
|
||||
"remove",
|
||||
"destroy",
|
||||
"revoke",
|
||||
"permission",
|
||||
"credential",
|
||||
"payment",
|
||||
"pay",
|
||||
)
|
||||
external_write_terms: tuple[str, ...] = (
|
||||
"send",
|
||||
"post",
|
||||
"publish",
|
||||
"create",
|
||||
"update",
|
||||
"invite",
|
||||
"reply",
|
||||
"forward",
|
||||
)
|
||||
|
||||
|
||||
class ReplayToolExecutor:
    """Executor wrapper that classifies every tool call before running it.

    Each call is classified with ``classify_tool_mode`` and is then either
    forwarded to the wrapped executor ("executed"), answered with a canned
    successful result ("surrogate"), or answered with a canned failure
    ("blocked"). A trace dict describing the call and its outcome is appended
    to ``self.traces`` in every case.
    """

    def __init__(
        self,
        inner: ToolExecutor,
        *,
        registry: ToolRegistry,
        policy: ReplayToolPolicy | None = None,
    ) -> None:
        self.inner = inner
        self.registry = registry
        self.policy = policy if policy else ReplayToolPolicy()
        self.traces: list[dict[str, Any]] = []

    async def execute(
        self,
        tool_name: str,
        arguments: dict[str, Any] | None,
        *,
        context: ToolContext | None = None,
    ) -> ToolResult:
        """Classify one tool call, record its trace, and return a result."""
        registered = self.registry.get(tool_name)
        if registered is not None:
            spec = registered.spec
        else:
            # Unregistered tools get a placeholder spec so they can still be
            # classified and traced.
            spec = ToolSpec(
                name=tool_name,
                description="unregistered tool",
                input_schema={"type": "object", "properties": {}},
                toolset="unknown",
            )

        mode = classify_tool_mode(spec, self.policy)
        trace: dict[str, Any] = {
            "trace_id": uuid4().hex,
            "tool_name": tool_name,
            "mode": mode,
            "arguments": dict(arguments or {}),
            "schema": dict(spec.input_schema),
            "toolset": spec.toolset,
            "metadata": dict(spec.metadata),
            "classification_reason": _classification_reason(spec, mode),
        }

        if mode == "executed":
            result = await self.inner.execute(tool_name, arguments or {}, context=context)
            # Only the first 2000 characters of the output are kept in the trace.
            trace["result"] = {
                "success": result.success,
                "error": result.error,
                "content": result.content[:2000],
            }
            self.traces.append(trace)
            return result

        if mode == "surrogate":
            return self._record_synthetic(
                trace,
                tool_name,
                success=True,
                error="replay_surrogate",
                content="Tool call recorded for surrogate evaluation.",
            )

        return self._record_synthetic(
            trace,
            tool_name,
            success=False,
            error="replay_blocked",
            content="Tool call blocked by replay policy.",
        )

    def _record_synthetic(
        self,
        trace: dict[str, Any],
        tool_name: str,
        *,
        success: bool,
        error: str,
        content: str,
    ) -> ToolResult:
        """Store a canned outcome on ``trace``, record it, and build the result."""
        trace["result"] = {"success": success, "error": error, "content": content}
        self.traces.append(trace)
        return ToolResult(
            success=success,
            content=content,
            tool_name=tool_name,
            error=error,
            raw_output=trace,
        )

    async def execute_tool_call(self, tool_call: Any, *, context: ToolContext | None = None) -> ToolResult:
        """Normalize ``tool_call`` into name and arguments, then ``execute`` it."""
        tool_name, arguments = ToolExecutor._normalize_tool_call(tool_call)
        return await self.execute(tool_name, arguments, context=context)
|
||||
|
||||
|
||||
def classify_tool_mode(spec: ToolSpec, policy: ReplayToolPolicy | None = None) -> ToolExecutionMode:
    """Decide how a tool call should be handled during replay.

    Rules are applied in order, comparing lower-cased values:

    1. A tool name containing any destructive term is "blocked".
    2. A tool in a safe toolset is "executed".
    3. An external tool (``transport`` metadata in the surrogate transports,
       or toolset ``mcp`` / ``connector`` / ``external``) is "surrogate" when
       its name contains an external-write term, otherwise "executed".
    4. Anything else is "surrogate".

    Terms are matched as substrings of the tool name.
    """
    active_policy = policy or ReplayToolPolicy()
    tool_name = spec.name.lower()
    toolset = spec.toolset.lower()
    normalized_metadata = {
        str(key).lower(): str(value).lower() for key, value in spec.metadata.items()
    }

    def _name_mentions(terms: tuple[str, ...]) -> bool:
        return any(term in tool_name for term in terms)

    if _name_mentions(active_policy.destructive_terms):
        return "blocked"
    if toolset in active_policy.safe_toolsets:
        return "executed"

    is_external = (
        normalized_metadata.get("transport") in active_policy.surrogate_transports
        or toolset in {"mcp", "connector", "external"}
    )
    if is_external:
        return "surrogate" if _name_mentions(active_policy.external_write_terms) else "executed"
    return "surrogate"
|
||||
|
||||
|
||||
def _classification_reason(spec: ToolSpec, mode: ToolExecutionMode) -> str:
|
||||
return f"{spec.name} classified as {mode} from toolset={spec.toolset} metadata={spec.metadata}"
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ReplayArmRequest:
|
||||
case_id: str
|
||||
arm: str
|
||||
task_text: str
|
||||
pinned_skill_names: list[str] = field(default_factory=list)
|
||||
pinned_skill_contexts: list[Any] = field(default_factory=list)
|
||||
provider_bundle: Any | None = None
|
||||
model_settings: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
class ReplayRunner:
    """Runs a single replay arm through the agent loop with a replay executor."""

    def __init__(self, *, agent_loop: Any, policy: ReplayToolPolicy | None = None) -> None:
        self.agent_loop = agent_loop
        self.policy = policy if policy else ReplayToolPolicy()

    async def run_arm(self, request: ReplayArmRequest) -> dict[str, Any]:
        """Execute ``request`` and return a summary dict of the run.

        Tool calls are routed through a fresh ``ReplayToolExecutor``; its
        traces become ``tool_calls`` and the surrogate/blocked ones are also
        reported under ``side_effects``.
        """
        loaded = self.agent_loop.boot()
        replay_executor = ReplayToolExecutor(
            loaded.tool_executor,
            registry=loaded.tool_registry,
            policy=self.policy,
        )

        settings = request.model_settings
        # Missing or falsy settings fall back to 4 iterations / temperature 0.0.
        max_iterations = int(settings.get("max_tool_iterations") or 4)
        temperature = float(settings.get("temperature") or 0.0)

        result = await self.agent_loop.process_direct(
            request.task_text,
            provider_bundle=request.provider_bundle,
            include_skill_assembly=False,
            include_tools=True,
            pinned_skill_names=request.pinned_skill_names,
            pinned_skill_contexts=request.pinned_skill_contexts,
            max_tool_iterations=max_iterations,
            temperature=temperature,
            source="skill_replay_eval",
            tool_executor_override=replay_executor,
        )

        traces = replay_executor.traces
        return {
            "case_id": request.case_id,
            "arm": request.arm,
            "session_id": result.session_id,
            "run_id": result.run_id,
            "task_text": request.task_text,
            "finish_reason": result.finish_reason,
            "final_answer": result.output_text,
            "tool_calls": list(traces),
            "artifacts": [],
            "side_effects": _side_effects_from_traces(traces),
        }
|
||||
|
||||
|
||||
def _side_effects_from_traces(traces: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
effects: list[dict[str, Any]] = []
|
||||
for trace in traces:
|
||||
if trace.get("mode") in {"surrogate", "blocked"}:
|
||||
effects.append(
|
||||
{
|
||||
"tool_name": trace.get("tool_name"),
|
||||
"mode": trace.get("mode"),
|
||||
"arguments": trace.get("arguments"),
|
||||
"classification_reason": trace.get("classification_reason"),
|
||||
}
|
||||
)
|
||||
return effects
|
||||
@ -205,7 +205,13 @@ class SkillLearningService:
|
||||
)
|
||||
if candidate.kind == "merge_skills":
|
||||
target_name = self._suggest_skill_name(candidate, packet)
|
||||
payload = await self.synthesizer.synthesize_merge(candidate, packet, provider, model)
|
||||
payload = await self.synthesizer.synthesize_merge(
|
||||
candidate,
|
||||
packet,
|
||||
provider,
|
||||
model,
|
||||
base_skill=self._merged_base_skill_snapshot(candidate.related_skill_names),
|
||||
)
|
||||
return self.draft_service.create_merge_draft(
|
||||
skill_name=target_name,
|
||||
base_version=None,
|
||||
@ -217,7 +223,13 @@ class SkillLearningService:
|
||||
)
|
||||
target_skill = candidate.related_skill_names[0]
|
||||
base_version = candidate.evidence.get("skill_version")
|
||||
payload = await self.synthesizer.synthesize_revision(candidate, packet, provider, model)
|
||||
payload = await self.synthesizer.synthesize_revision(
|
||||
candidate,
|
||||
packet,
|
||||
provider,
|
||||
model,
|
||||
base_skill=self._base_skill_snapshot(target_skill, base_version),
|
||||
)
|
||||
return self.draft_service.create_revision_draft(
|
||||
skill_name=target_skill,
|
||||
base_version=base_version,
|
||||
@ -228,6 +240,46 @@ class SkillLearningService:
|
||||
evidence_refs=[{"run_id": item} for item in candidate.source_run_ids],
|
||||
)
|
||||
|
||||
def _base_skill_snapshot(self, skill_name: str, version: str | None) -> dict[str, Any] | None:
|
||||
loaded = self.draft_service.store.read_published_skill(skill_name, version)
|
||||
if loaded is None:
|
||||
return None
|
||||
return {
|
||||
"skill_name": loaded.version.skill_name,
|
||||
"version": loaded.version.version,
|
||||
"frontmatter": dict(loaded.version.frontmatter),
|
||||
"content": loaded.content,
|
||||
"summary": loaded.version.summary,
|
||||
"tool_hints": list(loaded.version.tool_hints),
|
||||
}
|
||||
|
||||
def _merged_base_skill_snapshot(self, skill_names: list[str]) -> dict[str, Any] | None:
|
||||
snapshots = [
|
||||
snapshot
|
||||
for name in skill_names
|
||||
if (snapshot := self._base_skill_snapshot(name, None)) is not None
|
||||
]
|
||||
if not snapshots:
|
||||
return None
|
||||
return {
|
||||
"skill_name": "merge:" + ",".join(str(item["skill_name"]) for item in snapshots),
|
||||
"version": "mixed",
|
||||
"frontmatter": {"merged_skills": [item["frontmatter"] for item in snapshots]},
|
||||
"content": "\n\n".join(
|
||||
f"<!-- base skill: {item['skill_name']} {item['version']} -->\n{item['content']}"
|
||||
for item in snapshots
|
||||
),
|
||||
"summary": "\n".join(str(item["summary"]) for item in snapshots if item.get("summary")),
|
||||
"tool_hints": list(
|
||||
dict.fromkeys(
|
||||
tool
|
||||
for item in snapshots
|
||||
for tool in item.get("tool_hints", [])
|
||||
if str(tool).strip()
|
||||
)
|
||||
),
|
||||
}
|
||||
|
||||
def rescore_skill_versions(self) -> list[SkillPerformanceSnapshot]:
|
||||
snapshots: list[SkillPerformanceSnapshot] = []
|
||||
grouped: dict[tuple[str, str], list[SkillEffectRecord]] = {}
|
||||
|
||||
53
app-instance/backend/beaver/skills/learning/surrogate.py
Normal file
53
app-instance/backend/beaver/skills/learning/surrogate.py
Normal file
@ -0,0 +1,53 @@
|
||||
"""Surrogate evaluation for replay tool calls that cannot execute safely."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
class SurrogateToolEvaluator:
    """Scores a baseline arm against a candidate arm from recorded tool calls."""

    async def evaluate(self, *, task_text: str, baseline: dict[str, Any], candidate: dict[str, Any]) -> dict[str, Any]:
        """Return both arm scores, their delta, call counts, and a confidence.

        Confidence is "low" when any call in either arm was blocked or when
        more than two calls were surrogates; otherwise it is "medium".
        """
        arms = (baseline, candidate)
        baseline_score = _score_arm(task_text, baseline)
        candidate_score = _score_arm(task_text, candidate)
        surrogate_count = sum(_mode_count(arm, "surrogate") for arm in arms)
        blocked_count = sum(_mode_count(arm, "blocked") for arm in arms)

        if blocked_count or surrogate_count > 2:
            confidence = "low"
        else:
            confidence = "medium"

        return {
            "baseline_score": baseline_score,
            "candidate_score": candidate_score,
            "delta": round(candidate_score - baseline_score, 4),
            "surrogate_tool_count": surrogate_count,
            "blocked_tool_count": blocked_count,
            "confidence": confidence,
            "notes": [
                "Surrogate score is based on intended tool calls, schemas, arguments, and task relevance.",
            ],
        }
|
||||
|
||||
|
||||
def _score_arm(task_text: str, arm: dict[str, Any]) -> float:
    """Return the mean per-call score of an arm, rounded to 4 decimals.

    Only dict entries of ``arm["tool_calls"]`` are scored. An arm with no
    scorable calls gets the neutral score 0.5.
    """
    total = 0
    scored = 0
    for item in arm.get("tool_calls") or []:
        if not isinstance(item, dict):
            continue
        total += _score_call(task_text, item)
        scored += 1
    if scored == 0:
        return 0.5
    return round(total / scored, 4)
|
||||
|
||||
|
||||
def _score_call(task_text: str, call: dict[str, Any]) -> float:
|
||||
if call.get("mode") == "blocked":
|
||||
return 0.2
|
||||
if call.get("mode") == "executed":
|
||||
result = call.get("result") if isinstance(call.get("result"), dict) else {}
|
||||
return 0.85 if result.get("success") is not False else 0.35
|
||||
arguments = dict(call.get("arguments") or {})
|
||||
if not arguments:
|
||||
return 0.45
|
||||
non_empty = sum(1 for value in arguments.values() if str(value).strip())
|
||||
completeness = non_empty / max(1, len(arguments))
|
||||
argument_text = " ".join(str(value).lower() for value in arguments.values())
|
||||
relevance = 0.15 if any(token and token in argument_text for token in task_text.lower().split()[:16]) else 0.0
|
||||
return round(min(0.9, 0.5 + 0.3 * completeness + relevance), 4)
|
||||
|
||||
|
||||
def _mode_count(arm: dict[str, Any], mode: str) -> int:
|
||||
return sum(1 for item in arm.get("tool_calls") or [] if isinstance(item, dict) and item.get("mode") == mode)
|
||||
@ -17,8 +17,9 @@ class SkillDraftSynthesizer:
|
||||
evidence_packet: EvidencePacket,
|
||||
provider: LLMProvider,
|
||||
model: str,
|
||||
base_skill: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
return await self._synthesize(candidate, evidence_packet, provider, model, "revise")
|
||||
return await self._synthesize(candidate, evidence_packet, provider, model, "revise", base_skill=base_skill)
|
||||
|
||||
async def synthesize_new_skill(
|
||||
self,
|
||||
@ -27,7 +28,7 @@ class SkillDraftSynthesizer:
|
||||
provider: LLMProvider,
|
||||
model: str,
|
||||
) -> dict[str, Any]:
|
||||
return await self._synthesize(candidate, evidence_packet, provider, model, "new")
|
||||
return await self._synthesize(candidate, evidence_packet, provider, model, "new", base_skill=None)
|
||||
|
||||
async def synthesize_merge(
|
||||
self,
|
||||
@ -35,8 +36,9 @@ class SkillDraftSynthesizer:
|
||||
evidence_packet: EvidencePacket,
|
||||
provider: LLMProvider,
|
||||
model: str,
|
||||
base_skill: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
return await self._synthesize(candidate, evidence_packet, provider, model, "merge")
|
||||
return await self._synthesize(candidate, evidence_packet, provider, model, "merge", base_skill=base_skill)
|
||||
|
||||
async def _synthesize(
|
||||
self,
|
||||
@ -45,15 +47,18 @@ class SkillDraftSynthesizer:
|
||||
provider: LLMProvider,
|
||||
model: str,
|
||||
action: str,
|
||||
*,
|
||||
base_skill: dict[str, Any] | None,
|
||||
) -> dict[str, Any]:
|
||||
prompt = self._build_prompt(candidate, evidence_packet, action)
|
||||
prompt = self._build_prompt(candidate, evidence_packet, action, base_skill=base_skill)
|
||||
response = await provider.chat(
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": (
|
||||
"You synthesize Beaver skill drafts from execution evidence. "
|
||||
"Return only JSON with keys: frontmatter, content, change_reason."
|
||||
"Return only JSON with keys: frontmatter, content, change_reason, "
|
||||
"preserved_sections, changed_sections, dropped_sections."
|
||||
),
|
||||
},
|
||||
{"role": "user", "content": prompt},
|
||||
@ -69,11 +74,30 @@ class SkillDraftSynthesizer:
|
||||
return self._fallback_payload(candidate, evidence_packet, action)
|
||||
|
||||
@staticmethod
|
||||
def _build_prompt(candidate: SkillLearningCandidate, evidence_packet: EvidencePacket, action: str) -> str:
|
||||
def _build_prompt(
|
||||
candidate: SkillLearningCandidate,
|
||||
evidence_packet: EvidencePacket,
|
||||
action: str,
|
||||
base_skill: dict[str, Any] | None = None,
|
||||
) -> str:
|
||||
tool_names = _coerce_string_list(evidence_packet.metadata.get("tool_names"))
|
||||
tool_section = ", ".join(tool_names) if tool_names else "none observed"
|
||||
selected_tool_names = _coerce_string_list(evidence_packet.metadata.get("selected_tool_names"))
|
||||
selected_tool_section = ", ".join(selected_tool_names) if selected_tool_names else "none recorded"
|
||||
base_section = ""
|
||||
if base_skill:
|
||||
base_section = (
|
||||
"\n\nBase skill snapshot:\n"
|
||||
f"- skill_name: {base_skill.get('skill_name')}\n"
|
||||
f"- version: {base_skill.get('version')}\n"
|
||||
f"- frontmatter: {json.dumps(base_skill.get('frontmatter') or {}, ensure_ascii=False, sort_keys=True)}\n"
|
||||
f"- tool_hints: {base_skill.get('tool_hints') or []}\n"
|
||||
f"- summary: {base_skill.get('summary') or ''}\n"
|
||||
"Base skill content:\n"
|
||||
f"{base_skill.get('content') or ''}\n"
|
||||
"Preserve existing instructions unless the evidence requires a change. "
|
||||
"If any section is changed or dropped, explain it in changed_sections or dropped_sections."
|
||||
)
|
||||
return (
|
||||
f"Action: {action}\n"
|
||||
f"Candidate kind: {candidate.kind}\n"
|
||||
@ -83,11 +107,13 @@ class SkillDraftSynthesizer:
|
||||
f"Run-selected tool names: {selected_tool_section}\n"
|
||||
f"Task summaries:\n- " + "\n- ".join(evidence_packet.task_summaries)
|
||||
+ "\n\nSession excerpts:\n" + "\n\n".join(evidence_packet.session_excerpts)
|
||||
+ base_section
|
||||
+ "\n\nReturn JSON only. The frontmatter object must include:"
|
||||
+ "\n- description: a concise skill description"
|
||||
+ "\n- tools: an explicit JSON array of exact tool names this skill needs. "
|
||||
+ "Prefer called tool names when the workflow depends on them; use run-selected tool names only when clearly required. "
|
||||
+ "Use [] only when no tool is required."
|
||||
+ "\nThe JSON may include preserved_sections, changed_sections, and dropped_sections arrays."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@ -111,6 +137,9 @@ class SkillDraftSynthesizer:
|
||||
"frontmatter": frontmatter,
|
||||
"content": content_value.strip(),
|
||||
"change_reason": str(payload.get("change_reason") or ""),
|
||||
"preserved_sections": _coerce_string_list(payload.get("preserved_sections")),
|
||||
"changed_sections": _coerce_string_list(payload.get("changed_sections")),
|
||||
"dropped_sections": _coerce_string_list(payload.get("dropped_sections")),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
@ -124,6 +153,9 @@ class SkillDraftSynthesizer:
|
||||
"frontmatter": frontmatter,
|
||||
"content": str(payload.get("content") or "").strip(),
|
||||
"change_reason": str(payload.get("change_reason") or ""),
|
||||
"preserved_sections": _coerce_string_list(payload.get("preserved_sections")),
|
||||
"changed_sections": _coerce_string_list(payload.get("changed_sections")),
|
||||
"dropped_sections": _coerce_string_list(payload.get("dropped_sections")),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
@ -138,6 +170,9 @@ class SkillDraftSynthesizer:
|
||||
},
|
||||
"content": f"# {title}\n\n## Evidence\n\n{content}\n",
|
||||
"change_reason": candidate.reason or f"Fallback {action} synthesis.",
|
||||
"preserved_sections": [],
|
||||
"changed_sections": [],
|
||||
"dropped_sections": [],
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -161,6 +161,9 @@ class MainAgentRouter:
|
||||
"Critical policy:\n"
|
||||
"- If there is an active Task, choose continue_task or revise_task unless the user's topic is completely unrelated "
|
||||
"to that Task or the user explicitly closes/abandons it.\n"
|
||||
"- With an active Task, choose simple_chat for unrelated lightweight conversation and new_task for unrelated work "
|
||||
"that needs Task capabilities. Either decision starts a new topic.\n"
|
||||
"- An unrelated lightweight conversation must not be classified as revise_task merely because the active Task is awaiting acceptance.\n"
|
||||
"- Choose revise_task when the active Task is awaiting feedback or needs revision and the user asks for changes "
|
||||
"such as '改一下', '加上', '删除', '换成', '再详细点', '格式改成', '不要', or equivalent wording.\n"
|
||||
"- Choose continue_task for neutral follow-up questions or additional next steps that do not imply dissatisfaction with the previous result.\n"
|
||||
|
||||
@ -180,8 +180,10 @@ class ObjectBackedTool(BaseTool):
|
||||
|
||||
if "current_session_id" not in arguments and hasattr(self.backend, "current_session_id"):
|
||||
arguments["current_session_id"] = context.session_id
|
||||
if "workspace" not in arguments and hasattr(self.backend, "workspace"):
|
||||
if "workspace" not in arguments and (hasattr(self.backend, "workspace") or self._backend_accepts_argument("workspace")):
|
||||
arguments["workspace"] = context.workspace
|
||||
if "services" not in arguments and self._backend_accepts_argument("services"):
|
||||
arguments["services"] = context.services
|
||||
if "metadata" not in arguments and self._backend_accepts_argument("metadata"):
|
||||
arguments["metadata"] = context.metadata
|
||||
|
||||
|
||||
@ -9,6 +9,15 @@ from .skill_view import SkillViewTool, skill_view
|
||||
from .session_search import SessionSearchTool, session_search
|
||||
from .terminal import ExecuteCodeTool, ProcessTool, TerminalTool
|
||||
from .utility import ClarifyTool, DelegateTool, SendMessageTool, SpawnTool, TodoTool
|
||||
from .user_files import (
|
||||
UserFilesCopyToWorkspaceTool,
|
||||
UserFilesDeleteTool,
|
||||
UserFilesListTool,
|
||||
UserFilesMkdirTool,
|
||||
UserFilesPublishOutputTool,
|
||||
UserFilesReadTool,
|
||||
UserFilesWriteTool,
|
||||
)
|
||||
from .web import WebFetchTool, WebSearchTool
|
||||
|
||||
__all__ = [
|
||||
@ -30,6 +39,13 @@ __all__ = [
|
||||
"SessionSearchTool",
|
||||
"TerminalTool",
|
||||
"TodoTool",
|
||||
"UserFilesCopyToWorkspaceTool",
|
||||
"UserFilesDeleteTool",
|
||||
"UserFilesListTool",
|
||||
"UserFilesMkdirTool",
|
||||
"UserFilesPublishOutputTool",
|
||||
"UserFilesReadTool",
|
||||
"UserFilesWriteTool",
|
||||
"ClarifyTool",
|
||||
"WebFetchTool",
|
||||
"WebSearchTool",
|
||||
|
||||
@ -14,7 +14,7 @@ from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
import json
|
||||
from pathlib import Path
|
||||
from pathlib import Path, PurePosixPath
|
||||
from typing import Any, Iterable
|
||||
|
||||
|
||||
@ -24,6 +24,7 @@ MAX_READ_CHARS = 120_000
|
||||
MAX_SEARCH_RESULTS = 200
|
||||
MAX_SEARCH_FILE_BYTES = 2_000_000
|
||||
MAX_SEARCH_FILES = 5_000
|
||||
USER_FILE_VIRTUAL_ROOTS = {"uploads", "outputs", "shared", "tasks"}
|
||||
SKIP_DIR_NAMES = {
|
||||
".git",
|
||||
".hg",
|
||||
@ -161,9 +162,28 @@ def _workspace_root(workspace: str | None) -> Path:
|
||||
return root
|
||||
|
||||
|
||||
def _virtual_user_file_error(user_path: str | None) -> str | None:
|
||||
raw = str(user_path or ".").replace("\\", "/").strip()
|
||||
if not raw or raw in {".", "./"}:
|
||||
return None
|
||||
try:
|
||||
parts = [part for part in PurePosixPath(raw.strip("/")).parts if part not in ("", ".")]
|
||||
except TypeError:
|
||||
return None
|
||||
if parts and parts[0] in USER_FILE_VIRTUAL_ROOTS:
|
||||
return (
|
||||
f"{user_path} is a personal agent file system path, not a workspace path. "
|
||||
"Use user_files_read or user_files_copy_to_workspace for reads; use "
|
||||
"user_files_write for shared/tasks files or user_files_publish_output for outputs."
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def _resolve_existing_path(workspace: str | None, user_path: str | None) -> tuple[Path, Path]:
|
||||
"""Resolve a user path and ensure the real target stays inside workspace."""
|
||||
|
||||
if error := _virtual_user_file_error(user_path):
|
||||
raise WorkspacePathError(error)
|
||||
root = _workspace_root(workspace)
|
||||
raw_path = Path(user_path or ".").expanduser()
|
||||
candidate = raw_path if raw_path.is_absolute() else root / raw_path
|
||||
@ -178,6 +198,8 @@ def _resolve_existing_path(workspace: str | None, user_path: str | None) -> tupl
|
||||
|
||||
|
||||
def _resolve_writable_path(workspace: str | None, user_path: str | None) -> tuple[Path, Path]:
|
||||
if error := _virtual_user_file_error(user_path):
|
||||
raise WorkspacePathError(error)
|
||||
root = _workspace_root(workspace)
|
||||
if not user_path or not str(user_path).strip():
|
||||
raise WorkspacePathError("path is required")
|
||||
|
||||
389
app-instance/backend/beaver/tools/builtins/user_files.py
Normal file
389
app-instance/backend/beaver/tools/builtins/user_files.py
Normal file
@ -0,0 +1,389 @@
|
||||
"""Agent-facing tools for the user-visible file system."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
import json
|
||||
import mimetypes
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from beaver.foundation.config.loader import load_config
|
||||
from beaver.services.user_file_resolver import UserFileStorageResolver, build_file_auth_context
|
||||
from beaver.services.user_files import AgentUserFilePolicy, UserFileError, UserFilePathError, UserFileService
|
||||
|
||||
|
||||
# 50 MiB. NOTE(review): presumably the size cap for files staged into the
# workspace; the code that enforces it is not in this view.
MAX_WORKSPACE_STAGE_BYTES = 50 * 1024**2
|
||||
|
||||
|
||||
def _json_result(success: bool, **payload: Any) -> str:
|
||||
return json.dumps({"success": success, **payload}, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
async def _service(workspace: str | None, services: dict[str, Any] | None = None) -> UserFileService:
    """Build the ``UserFileService`` used by the user-file tools.

    The Beaver config comes from ``services["beaver_config"]`` when present
    and is otherwise loaded for ``workspace``. The auth context is built with
    the backend id as username, falling back to the client id and then to
    ``"agent"``.

    Raises:
        UserFileError: if ``workspace`` is empty.
    """
    if not workspace:
        raise UserFileError("workspace is not configured for user file tools")

    provided = services or {}
    config = provided.get("beaver_config")
    if config is None:
        config = load_config(workspace=workspace)

    identity = config.backend_identity
    backend_id = identity.backend_id.strip() or identity.client_id.strip() or "agent"
    auth_context = build_file_auth_context(
        username=backend_id,
        config=config,
        user_id=provided.get("user_id"),
        auth_source="beaver-agent-runtime",
    )
    resolver = UserFileStorageResolver(
        config=config,
        workspace=Path(workspace),
        auth_context=auth_context,
    )
    return await resolver.service()
|
||||
|
||||
|
||||
def _agent_policy(services: dict[str, Any] | None = None, metadata: dict[str, Any] | None = None) -> AgentUserFilePolicy:
    """Build the agent file policy from ``services`` and ``metadata``.

    ``services`` takes precedence over ``metadata`` for both ids. The task id
    is ``None`` when absent or blank. The fallback scope is the run id, then
    the metadata session id, then ``"interactive"``.
    """
    service_values = services or {}
    meta_values = metadata or {}

    raw_task_id = service_values.get("task_id") or meta_values.get("task_id") or ""
    task_id = str(raw_task_id).strip() or None

    raw_scope = (
        service_values.get("run_id")
        or meta_values.get("run_id")
        or meta_values.get("session_id")
        or "interactive"
    )
    return AgentUserFilePolicy(task_id=task_id, fallback_scope=str(raw_scope))
|
||||
|
||||
|
||||
def _workspace_root(workspace: str | None) -> Path:
|
||||
if not workspace:
|
||||
raise UserFilePathError("workspace is not configured for user file tools")
|
||||
root = Path(workspace).expanduser().resolve()
|
||||
root.mkdir(parents=True, exist_ok=True)
|
||||
return root
|
||||
|
||||
|
||||
def _resolve_workspace_source(workspace: str | None, source_path: str) -> tuple[Path, Path]:
    """Resolve ``source_path`` to an existing file inside the workspace.

    Relative paths are taken relative to the workspace root; absolute paths
    are accepted only when they resolve inside it.

    Returns:
        ``(workspace_root, resolved_file)``.

    Raises:
        UserFilePathError: if the path is blank, resolves outside the
            workspace, or is not a regular file.
        OSError: strict resolution raises (e.g. ``FileNotFoundError``) when
            the path does not exist.
    """
    root = _workspace_root(workspace)
    if not (source_path and str(source_path).strip()):
        raise UserFilePathError("source_path is required")

    requested = Path(str(source_path)).expanduser()
    if not requested.is_absolute():
        requested = root / requested
    resolved = requested.resolve(strict=True)

    try:
        resolved.relative_to(root)
    except ValueError as exc:
        raise UserFilePathError("source_path escapes workspace") from exc

    if resolved.is_file():
        return root, resolved
    raise UserFilePathError("source_path must be a file")
|
||||
|
||||
|
||||
def _resolve_workspace_destination(workspace: str | None, target_path: str) -> tuple[Path, Path]:
    """Resolve a relative ``target_path`` to a location inside the workspace.

    The destination does not have to exist yet.

    Returns:
        ``(workspace_root, resolved_destination)``.

    Raises:
        UserFilePathError: if the path is blank, absolute, or resolves
            outside the workspace.
    """
    root = _workspace_root(workspace)
    if not (target_path and str(target_path).strip()):
        raise UserFilePathError("workspace_path is required")

    requested = Path(str(target_path)).expanduser()
    if requested.is_absolute():
        raise UserFilePathError("workspace_path must be relative")

    destination = (root / requested).resolve()
    try:
        destination.relative_to(root)
    except ValueError as exc:
        raise UserFilePathError("workspace_path escapes workspace") from exc
    return root, destination
|
||||
|
||||
|
||||
def _relative_path(root: Path, path: Path) -> str:
|
||||
return path.relative_to(root).as_posix()
|
||||
|
||||
|
||||
# JSON-schema style parameter definitions for the user-file tools.

# user_files_list: optional path; an empty path lists the virtual roots.
USER_FILES_LIST_PARAMETERS: dict[str, Any] = {
    "type": "object",
    "properties": {
        "path": {
            "type": "string",
            "default": "",
            "description": "User file path under uploads, outputs, shared, or tasks. Empty path lists the virtual roots.",
        }
    },
}

# user_files_read: required path plus a bounded preview size.
USER_FILES_READ_PARAMETERS: dict[str, Any] = {
    "type": "object",
    "properties": {
        "path": {"type": "string", "description": "User file path to read."},
        "max_bytes": {
            "type": "integer",
            "default": 120000,
            "minimum": 1,
            "maximum": 1000000,
            "description": "Maximum bytes to return in model context.",
        },
    },
    "required": ["path"],
}

# user_files_write: required path and text content, optional content type.
USER_FILES_WRITE_PARAMETERS: dict[str, Any] = {
    "type": "object",
    "properties": {
        "path": {"type": "string", "description": "User file path to create or replace."},
        "content": {"type": "string", "description": "Text content to write."},
        "content_type": {"type": "string", "default": "text/plain"},
    },
    "required": ["path", "content"],
}

# user_files_delete: required path of a file or directory.
USER_FILES_DELETE_PARAMETERS: dict[str, Any] = {
    "type": "object",
    "properties": {
        "path": {"type": "string", "description": "User file or directory path to delete."}
    },
    "required": ["path"],
}

# user_files_mkdir: required directory path.
USER_FILES_MKDIR_PARAMETERS: dict[str, Any] = {
    "type": "object",
    "properties": {
        "path": {"type": "string", "description": "User file directory path to create."}
    },
    "required": ["path"],
}

# user_files_copy_to_workspace: required source path, optional destination.
USER_FILES_COPY_TO_WORKSPACE_PARAMETERS: dict[str, Any] = {
    "type": "object",
    "properties": {
        "path": {
            "type": "string",
            "description": "Readable user file path under uploads, outputs, shared, or an authorized tasks namespace.",
        },
        "workspace_path": {
            "type": "string",
            "description": "Optional relative workspace destination. Defaults to user-files/tasks/{task_id}/<filename> or user-files/runs/<scope>/<filename>.",
        },
    },
    "required": ["path"],
}

# user_files_publish_output: required workspace source and outputs/ target.
USER_FILES_PUBLISH_OUTPUT_PARAMETERS: dict[str, Any] = {
    "type": "object",
    "properties": {
        "source_path": {
            "type": "string",
            "description": "Workspace file path to publish. Absolute paths are allowed only if they stay inside the workspace.",
        },
        "target_path": {
            "type": "string",
            "description": "Output path under outputs/, such as outputs/report.md.",
        },
        "content_type": {
            "type": "string",
            "description": "Optional content type. If omitted, Beaver guesses from the target filename.",
        },
    },
    "required": ["source_path", "target_path"],
}
|
||||
|
||||
|
||||
@dataclass(slots=True)
class UserFilesListTool:
    """Tool that lists entries of the user-visible file system."""

    name: str = "user_files_list"
    description: str = (
        "List files and folders in the personal agent file system. Use the virtual roots only: "
        "uploads for files the user provides to the agent, outputs for agent-generated results, "
        "shared for reusable user/agent reference material, and tasks for files bound to a specific task. "
        "An empty path lists the four roots; this tool never exposes MinIO buckets, credentials, or internal workspace paths."
    )
    toolset: str = "user_files"
    always_available: bool = True
    # Each instance gets its own top-level copy of the schema dict.
    parameters: dict[str, Any] = field(default_factory=lambda: USER_FILES_LIST_PARAMETERS.copy())

    async def execute(self, *, path: str = "", workspace: str | None = None, services: dict[str, Any] | None = None) -> str:
        """Browse ``path`` and return the listing as a JSON string.

        A ``UserFileError`` is reported as a JSON failure payload instead of
        being raised.
        """
        try:
            service = await _service(workspace, services)
            listing = await service.browse(path)
            return _json_result(True, **listing)
        except UserFileError as exc:
            return _json_result(False, error=str(exc), path=path)
|
||||
|
||||
|
||||
@dataclass(slots=True)
class UserFilesReadTool:
    """Tool that returns a bounded preview of a user-visible file."""

    name: str = "user_files_read"
    description: str = (
        "Read a bounded text preview from the personal agent file system. Use this to inspect user-provided "
        "files in uploads, long-lived shared material in shared, task files in tasks, or generated outputs in outputs. "
        "The path must stay under uploads, outputs, shared, or tasks; internal workspace and MinIO implementation paths are hidden."
    )
    toolset: str = "user_files"
    always_available: bool = True
    # Each instance gets its own top-level copy of the schema dict.
    parameters: dict[str, Any] = field(default_factory=lambda: USER_FILES_READ_PARAMETERS.copy())

    async def execute(
        self,
        *,
        path: str,
        max_bytes: int = 120000,
        workspace: str | None = None,
        services: dict[str, Any] | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> str:
        """Validate ``path`` for reading and return its preview as JSON.

        ``max_bytes`` is clamped to the range 1..1_000_000. A
        ``UserFileError`` is reported as a JSON failure payload instead of
        being raised.
        """
        try:
            path = _agent_policy(services, metadata).validate_read(path)
            limit = max(1, min(int(max_bytes), 1_000_000))
            service = await _service(workspace, services)
            preview = await service.preview(path, max_bytes=limit)
            return _json_result(True, **preview)
        except UserFileError as exc:
            return _json_result(False, error=str(exc), path=path)
|
||||
|
||||
|
||||
@dataclass(slots=True)
class UserFilesWriteTool:
    """Agent tool that writes text content to a policy-validated user file path."""

    name: str = "user_files_write"
    description: str = (
        "Create or replace a text file in the personal agent file system. Store agent-generated deliverables "
        "under outputs, reusable long-lived context under shared, and task-bound files under the current "
        "tasks/{task_id}/ namespace. Never write to uploads; uploaded files are immutable agent inputs. "
        "For modifications to uploaded files, copy them to the workspace, edit there, then publish to outputs."
    )
    toolset: str = "user_files"
    always_available: bool = False
    parameters: dict[str, Any] = field(default_factory=lambda: dict(USER_FILES_WRITE_PARAMETERS))

    async def execute(
        self,
        *,
        path: str,
        content: str,
        content_type: str = "text/plain",
        workspace: str | None = None,
        services: dict[str, Any] | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> str:
        """Validate ``path`` for writing, store ``content``, and return a JSON result string.

        The path returned by the policy's ``validate_write`` is the one handed
        to the service. A ``UserFileError`` becomes a failed JSON result.
        """
        try:
            policy = _agent_policy(services, metadata)
            path = policy.validate_write(path)
            file_service = await _service(workspace, services)
            written = await file_service.write_file(path, content, content_type=content_type)
            return _json_result(True, **written)
        except UserFileError as failure:
            return _json_result(False, error=str(failure), path=path)
|
||||
|
||||
|
||||
@dataclass(slots=True)
class UserFilesDeleteTool:
    """Agent tool for deletion requests; it never reports a successful delete."""

    name: str = "user_files_delete"
    description: str = (
        "Agent deletion is disabled for the personal agent file system. User-visible file deletion is owned by "
        "the Files page or user-side APIs; agents should use task/workspace cleanup instead."
    )
    toolset: str = "user_files"
    always_available: bool = False
    parameters: dict[str, Any] = field(default_factory=lambda: dict(USER_FILES_DELETE_PARAMETERS))

    async def execute(
        self,
        *,
        path: str,
        workspace: str | None = None,
        services: dict[str, Any] | None = None,
    ) -> str:
        """Run the policy's delete validation and return a failed JSON result.

        If the policy raises ``UserFileError`` its message is returned as the
        error; otherwise the result still reports ``deleted=False``.
        """
        try:
            policy = _agent_policy(services)
            policy.validate_delete(path)
            return _json_result(False, path=path, deleted=False)
        except UserFileError as failure:
            return _json_result(False, error=str(failure), path=path)
|
||||
|
||||
|
||||
@dataclass(slots=True)
class UserFilesMkdirTool:
    """Agent tool that creates a folder at a policy-validated user file path."""

    name: str = "user_files_mkdir"
    description: str = (
        "Create a subfolder in the personal agent file system under uploads, outputs, shared, or tasks. "
        "Use folders to organize agent outputs, reusable shared material, or current task-specific files. "
        "Do not create folders under uploads because uploads is user-owned input storage."
    )
    toolset: str = "user_files"
    always_available: bool = False
    parameters: dict[str, Any] = field(default_factory=lambda: dict(USER_FILES_MKDIR_PARAMETERS))

    async def execute(
        self,
        *,
        path: str,
        workspace: str | None = None,
        services: dict[str, Any] | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> str:
        """Validate ``path`` with the policy's ``validate_mkdir`` and create the folder.

        Returns a JSON result string; a ``UserFileError`` becomes a failed result.
        """
        try:
            policy = _agent_policy(services, metadata)
            path = policy.validate_mkdir(path)
            file_service = await _service(workspace, services)
            created = await file_service.mkdir(path)
            return _json_result(True, **created)
        except UserFileError as failure:
            return _json_result(False, error=str(failure), path=path)
|
||||
|
||||
|
||||
@dataclass(slots=True)
class UserFilesCopyToWorkspaceTool:
    """Agent tool that downloads a readable user file and writes it into the workspace."""

    name: str = "user_files_copy_to_workspace"
    description: str = (
        "Copy a readable file from the personal agent file system into the internal workspace before editing, "
        "running, or validating it. Use this for user-uploaded files under uploads: the original upload remains "
        "unchanged, and the returned workspace_path can be used with workspace tools like read_file or patch_file."
    )
    toolset: str = "user_files"
    always_available: bool = False
    parameters: dict[str, Any] = field(default_factory=lambda: dict(USER_FILES_COPY_TO_WORKSPACE_PARAMETERS))

    async def execute(
        self,
        *,
        path: str,
        workspace_path: str | None = None,
        workspace: str | None = None,
        services: dict[str, Any] | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> str:
        """Stage the file at ``path`` into the workspace and return a JSON result string.

        When ``workspace_path`` is not given, the file is placed under
        ``user-files/<task namespace>/<file name>``. Files larger than
        ``MAX_WORKSPACE_STAGE_BYTES`` are rejected. ``UserFileError`` and
        ``OSError`` are both reported as failed JSON results.
        """
        try:
            policy = _agent_policy(services, metadata)
            path = policy.validate_read(path)
            file_service = await _service(workspace, services)
            downloaded = await file_service.download(path)
            if downloaded.size > MAX_WORKSPACE_STAGE_BYTES:
                raise UserFilePathError(f"File is too large to copy to workspace (max {MAX_WORKSPACE_STAGE_BYTES} bytes)")
            # Fallback location, computed before the destination is resolved.
            default_path = f"user-files/{policy.task_namespace}/{Path(path).name}"
            requested = workspace_path or default_path
            root, destination = _resolve_workspace_destination(workspace, requested)
            destination.parent.mkdir(parents=True, exist_ok=True)
            payload = downloaded.content
            destination.write_bytes(payload)
            return _json_result(
                True,
                path=path,
                workspace_path=_relative_path(root, destination),
                bytes=len(payload),
                content_type=downloaded.content_type,
            )
        except (UserFileError, OSError) as failure:
            return _json_result(False, error=str(failure), path=path)
|
||||
|
||||
|
||||
@dataclass(slots=True)
class UserFilesPublishOutputTool:
    """Agent tool that uploads a workspace file to a target under ``outputs/``."""

    name: str = "user_files_publish_output"
    description: str = (
        "Publish a validated workspace file to the personal agent file system under outputs/. Use this after "
        "staging and editing files in the workspace. Publishing never writes to uploads, and it hides MinIO "
        "bucket, namespace, and credential details from the agent."
    )
    toolset: str = "user_files"
    always_available: bool = False
    parameters: dict[str, Any] = field(default_factory=lambda: dict(USER_FILES_PUBLISH_OUTPUT_PARAMETERS))

    async def execute(
        self,
        *,
        source_path: str,
        target_path: str,
        content_type: str | None = None,
        workspace: str | None = None,
        services: dict[str, Any] | None = None,
    ) -> str:
        """Publish the workspace file ``source_path`` to ``target_path`` and return a JSON result string.

        ``target_path`` is stripped of surrounding whitespace and slashes and
        must start with ``outputs/``. Targets containing a ``..`` segment are
        rejected here: a prefix test alone would accept
        ``outputs/../uploads/x``, and this tool does not go through the agent
        policy that guards the other write paths. When ``content_type`` is
        omitted it is guessed from the target filename, falling back to
        ``application/octet-stream``. ``UserFileError`` and ``OSError`` are
        reported as failed JSON results.
        """
        try:
            root, source = _resolve_workspace_source(workspace, source_path)
            normalized_target = target_path.strip().strip("/")
            if not normalized_target.startswith("outputs/"):
                raise UserFilePathError("Published output target must be under outputs/")
            if ".." in normalized_target.split("/"):
                # Parent-directory segments could escape the outputs root.
                raise UserFilePathError("Published output target must not contain '..' segments")
            guessed_type, _ = mimetypes.guess_type(normalized_target)
            raw = source.read_bytes()
            entry = await (await _service(workspace, services)).write_file(
                normalized_target,
                raw,
                content_type=content_type or guessed_type or "application/octet-stream",
            )
            return _json_result(
                True,
                source_path=_relative_path(root, source),
                target_path=normalized_target,
                bytes=len(raw),
                **entry,
            )
        except (UserFileError, OSError) as exc:
            return _json_result(False, error=str(exc), source_path=source_path, target_path=target_path)
|
||||
@ -886,11 +886,9 @@ tool_hints:
|
||||
<div class="panel">
|
||||
<h3>Beaver 平台部署层</h3>
|
||||
<ul>
|
||||
<li><span class="code">auth-portal</span>:用户入口、登录注册。</li>
|
||||
<li><span class="code">authz-service</span>:授权服务和后端身份。</li>
|
||||
<li><span class="code">deploy-control</span>:创建和管理单用户实例。</li>
|
||||
<li><span class="code">router-proxy</span>:按 Host 转发到实例。</li>
|
||||
<li>每个用户一个 <span class="code">app-instance</span>,Docker 容器隔离 + per-instance workspace。</li>
|
||||
<li><span class="code">Keycloak OIDC</span>:用户入口、SSO 登录与登出。</li>
|
||||
<li><span class="code">app-instance</span>:单体应用容器,内置前端、后端和 Nginx。</li>
|
||||
<li>外部编排系统负责部署和访问入口,当前应用不维护实例注册表。</li>
|
||||
</ul>
|
||||
</div>
|
||||
<div class="panel">
|
||||
|
||||
@ -0,0 +1,104 @@
|
||||
# User File System MinIO/AuthZ Setup
|
||||
|
||||
The user file system is exposed through Beaver APIs and `user_files_*` tools. MinIO remains an implementation detail.
|
||||
|
||||
The ordinary Files page should only call Beaver's `/api/user-files/*` routes and render the virtual roots `uploads/`, `outputs/`, `shared/`, and `tasks/`. It should not show bucket names, endpoint fields, access keys, secret keys, object prefixes, or MinIO administration actions.
|
||||
|
||||
## AuthZ Settings
|
||||
|
||||
Each backend identity can store MinIO settings in AuthZ:
|
||||
|
||||
```bash
|
||||
curl -X POST "$AUTHZ_URL/backends/$BACKEND_ID/settings/minio" \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "Authorization: Bearer $AUTHZ_ADMIN_TOKEN" \
|
||||
-d '{
|
||||
"endpoint": "minio.example.internal:9000",
|
||||
"access_key": "user-access-key",
|
||||
"secret_key": "user-secret-key",
|
||||
"bucket": "beaver-user-files",
|
||||
"namespace": "users/{backend_id}",
|
||||
"secure": false,
|
||||
"region": null
|
||||
}'
|
||||
```
|
||||
|
||||
Public reads return masked settings. Internal reads require `AUTHZ_INTERNAL_TOKEN` and return the secret key for protected MCP services.
|
||||
|
||||
Deployed personal files use a shared bucket with a backend-scoped namespace. For backend `alice`, Beaver maps:
|
||||
|
||||
- `uploads/report.pdf` to `users/alice/uploads/report.pdf`
|
||||
- `outputs/summary.md` to `users/alice/outputs/summary.md`
|
||||
- `tasks/task-123/result.json` to `users/alice/tasks/task-123/result.json`
|
||||
|
||||
The MinIO policy for Alice's access key must be limited to `beaver-user-files/users/alice/*`. The frontend must still only show Beaver virtual paths, not the shared bucket or namespace.
|
||||
|
||||
Check the public, masked view:
|
||||
|
||||
```bash
|
||||
curl "$AUTHZ_URL/backends/$BACKEND_ID/settings/minio" \
|
||||
-H "Authorization: Bearer $AUTHZ_ADMIN_TOKEN"
|
||||
```
|
||||
|
||||
Check the internal protected view used by MCP services:
|
||||
|
||||
```bash
|
||||
curl "$AUTHZ_URL/internal/backends/$BACKEND_ID/settings/minio" \
|
||||
-H "Authorization: Bearer $AUTHZ_INTERNAL_TOKEN"
|
||||
```
|
||||
|
||||
## Protected MinIO MCP
|
||||
|
||||
Run the MinIO MCP service in protected mode:
|
||||
|
||||
```bash
|
||||
bw-minio-mcp serve \
|
||||
--host 0.0.0.0 \
|
||||
--port 8001 \
|
||||
--authz-url "$AUTHZ_URL" \
|
||||
--authz-token "$AUTHZ_INTERNAL_TOKEN" \
|
||||
--resource-server-url "$MINIO_MCP_PUBLIC_URL/mcp" \
|
||||
--state-root /var/lib/bw-minio-mcp
|
||||
```
|
||||
|
||||
In protected mode, the MCP service does not use static MinIO credentials at startup. Each authenticated tool call resolves the backend identity from the bearer token, loads that backend's MinIO settings from AuthZ, and constructs a per-call provider.
|
||||
|
||||
Outside protected mode, `bw-minio-mcp serve` requires explicit `--endpoint`, `--access-key`, and `--secret-key` values. It intentionally has no embedded production fallback credentials.
|
||||
|
||||
## Beaver Runtime
|
||||
|
||||
Beaver should register the MinIO MCP endpoint with backend-token auth when raw object tools are needed:
|
||||
|
||||
```json
|
||||
{
|
||||
"tools": {
|
||||
"mcpServers": {
|
||||
"minio_mcp": {
|
||||
"url": "https://minio-mcp.example.internal/mcp",
|
||||
"auth": "oauth_backend_token",
|
||||
"authAudience": "mcp:minio_mcp"
|
||||
}
|
||||
}
|
||||
},
|
||||
"authz": {
|
||||
"baseUrl": "https://authz.example.internal",
|
||||
"backendId": "backend-user-id"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Product-level file interactions should still go through Beaver's user file system:
|
||||
|
||||
- Frontend: `/api/user-files/status`, `/api/user-files/browse`, `/api/user-files/upload`, `/api/user-files/preview`, `/api/user-files/download`, `/api/user-files/delete`, and `/api/user-files/mkdir`.
|
||||
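- Agent tools: `user_files_list`, `user_files_read`, `user_files_write`, `user_files_mkdir`, `user_files_copy_to_workspace`, and `user_files_publish_output`. `user_files_delete` is also registered, but agent-side deletion is disabled; deleting user-visible files is owned by the Files page or user-side APIs.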
|
||||
- Storage boundary: only `uploads/`, `outputs/`, `shared/`, and `tasks/` are valid user paths.
|
||||
|
||||
The local workspace browser APIs and generic filesystem tools are retained for runtime/development compatibility, but they are not the user-visible file boundary.
|
||||
|
||||
## Verification Checklist
|
||||
|
||||
- The Files page root renders exactly `uploads`, `outputs`, `shared`, and `tasks`.
|
||||
- The Files page source does not call `/api/workspace/browse`.
|
||||
- `/api/user-files/status` does not return local workspace paths or MinIO bucket details.
|
||||
- AuthZ public settings responses mask `secret_key`.
|
||||
- Protected `BW_MinIO_Mcp` returns a clear configuration error if a backend has no MinIO settings instead of falling back to another user's credentials.
|
||||
@ -0,0 +1,12 @@
|
||||
# User File System Tooling Boundary
|
||||
|
||||
The `personal-user-filesystem` change adds `user_files_*` tools for files that users can upload, inspect, and receive from agents. These tools enforce the same virtual roots as the web API:
|
||||
|
||||
- `uploads/`
|
||||
- `outputs/`
|
||||
- `shared/`
|
||||
- `tasks/`
|
||||
|
||||
The existing local workspace filesystem tools remain registered for internal runtime and development workflows. They are workspace-scoped, but they are not the user-visible file boundary. Agents should use `user_files_*` tools when reading user-provided files or writing user-facing outputs.
|
||||
|
||||
Follow-up for stronger isolation: add a runtime policy switch that disables or narrows local workspace filesystem tools for ordinary personal-agent tasks, while keeping `user_files_*` available.
|
||||
@ -11,7 +11,9 @@ dependencies = [
|
||||
"httpx>=0.28.0,<1.0.0",
|
||||
"json-repair>=0.39.0,<1.0.0",
|
||||
"litellm>=1.79.0,<2.0.0",
|
||||
"minio>=7.2.0,<8.0.0",
|
||||
"openai>=1.79.0,<2.0.0",
|
||||
"PyJWT[crypto]>=2.9.0,<3.0.0",
|
||||
"pydantic>=2.12.0,<3.0.0",
|
||||
"python-multipart>=0.0.20,<1.0.0",
|
||||
"typer>=0.20.0,<1.0.0",
|
||||
@ -21,6 +23,24 @@ dependencies = [
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"pytest>=9.0.0,<10.0.0",
|
||||
"pytest-asyncio>=1.0.0,<2.0.0",
|
||||
]
|
||||
telegram = [
|
||||
"python-telegram-bot>=22.0,<23.0",
|
||||
]
|
||||
feishu = [
|
||||
"lark-oapi>=1.4.22,<2.0.0",
|
||||
]
|
||||
qqbot = [
|
||||
"aiohttp>=3.9.0,<4.0.0",
|
||||
]
|
||||
weixin = [
|
||||
"aiohttp>=3.9.0,<4.0.0",
|
||||
]
|
||||
channels = [
|
||||
"python-telegram-bot>=22.0,<23.0",
|
||||
"lark-oapi>=1.4.22,<2.0.0",
|
||||
"aiohttp>=3.9.0,<4.0.0",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
@ -0,0 +1,71 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from beaver.engine.loader import EngineLoader
|
||||
from beaver.engine.loop import AgentLoop
|
||||
from beaver.engine.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
from beaver.engine.providers.factory import ProviderBundle
|
||||
from beaver.skills.learning.replay import ReplayToolExecutor, ReplayToolPolicy
|
||||
|
||||
|
||||
class ToolCallingProvider(LLMProvider):
    """Stub provider: the first chat call requests ``read_file``, later calls answer "done"."""

    def __init__(self) -> None:
        super().__init__()
        # Number of chat() invocations seen so far.
        self.calls = 0

    async def chat(
        self,
        messages: list[dict],
        tools: list[dict] | None = None,
        model: str | None = None,
        max_tokens: int | None = None,
        temperature: float = 0.7,
        thinking_enabled: bool | None = None,
    ) -> LLMResponse:
        """Return a single tool-call response on the first call and plain text afterwards."""
        self.calls += 1
        if self.calls != 1:
            return LLMResponse(content="done")
        readme_request = ToolCallRequest(
            id="call-1",
            name="read_file",
            arguments={"path": "README.md"},
        )
        return LLMResponse(content="", tool_calls=[readme_request])

    def get_default_model(self) -> str:
        """Return the fixed model name used by this stub."""
        return "stub"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_process_direct_uses_replay_tool_executor(tmp_path: Path) -> None:
    """process_direct must route tool calls through the override executor and record a trace."""
    agent_loop = AgentLoop(loader=EngineLoader(workspace=tmp_path))
    booted = agent_loop.boot()
    stub_provider = ToolCallingProvider()
    stub_runtime = SimpleNamespace(model="stub", provider_name="stub")
    replay = ReplayToolExecutor(
        booted.tool_executor,
        registry=booted.tool_registry,
        policy=ReplayToolPolicy(),
    )
    # The runtime is a SimpleNamespace stand-in, hence the arg-type ignore.
    bundle = ProviderBundle(main_runtime=stub_runtime, main_provider=stub_provider)  # type: ignore[arg-type]

    outcome = await agent_loop.process_direct(
        "Read the README.",
        provider_bundle=bundle,
        include_skill_assembly=False,
        pinned_skill_names=[],
        tool_executor_override=replay,
        max_tool_iterations=2,
        source="skill_replay_eval",
    )

    assert outcome.output_text == "done"
    assert replay.traces
    first_trace = replay.traces[0]
    assert first_trace["tool_name"] == "read_file"
|
||||
@ -1,5 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
from beaver.coordinator.models import AgentDescriptor, ExecutionGraph, ExecutionNode
|
||||
from beaver.coordinator.registry import AgentRegistry, RegisteredAgent, TargetResolver
|
||||
from beaver.tasks import TaskRecord
|
||||
@ -20,22 +22,64 @@ def _task() -> TaskRecord:
|
||||
)
|
||||
|
||||
|
||||
def test_registry_seeds_builtin_agents_and_filters_disabled(tmp_path) -> None:
|
||||
def test_registry_starts_empty_and_filters_disabled(tmp_path) -> None:
|
||||
registry = AgentRegistry(tmp_path)
|
||||
|
||||
assert {agent.agent_id for agent in registry.list_active_agents()} >= {
|
||||
"researcher",
|
||||
"implementer",
|
||||
"reviewer",
|
||||
"tester",
|
||||
"documenter",
|
||||
}
|
||||
assert registry.list_agents() == []
|
||||
|
||||
registry.upsert_agent(
|
||||
RegisteredAgent(
|
||||
agent_id="tester",
|
||||
name="tester",
|
||||
display_name="Tester",
|
||||
role="testing",
|
||||
description="Runs checks.",
|
||||
system_prompt="test",
|
||||
)
|
||||
)
|
||||
registry.disable_agent("tester")
|
||||
|
||||
assert "tester" not in {agent.agent_id for agent in registry.list_active_agents()}
|
||||
|
||||
|
||||
def test_registry_drops_legacy_builtin_agents(tmp_path) -> None:
    """A registry file holding a ``builtin`` and a ``workspace`` agent loads only the latter."""
    legacy_builtin = {
        "agent_id": "researcher",
        "name": "researcher",
        "display_name": "Researcher",
        "role": "research",
        "description": "legacy builtin",
        "system_prompt": "research",
        "source": "builtin",
    }
    user_agent = {
        "agent_id": "workspace-agent",
        "name": "workspace-agent",
        "display_name": "Workspace Agent",
        "role": "workspace",
        "description": "user configured",
        "system_prompt": "work",
        "source": "workspace",
    }
    payload = {"version": 1, "agents": [legacy_builtin, user_agent]}

    registry_file = tmp_path / "agents" / "registry.json"
    registry_file.parent.mkdir(parents=True)
    registry_file.write_text(json.dumps(payload) + "\n", encoding="utf-8")

    loaded = AgentRegistry(tmp_path)

    remaining_ids = [entry.agent_id for entry in loaded.list_agents()]
    assert remaining_ids == ["workspace-agent"]
|
||||
|
||||
|
||||
def test_resolver_selects_registered_agent_by_role_and_capabilities(tmp_path) -> None:
|
||||
registry = AgentRegistry(tmp_path)
|
||||
registry.upsert_agent(
|
||||
@ -88,4 +132,3 @@ def test_resolver_falls_back_to_ephemeral_agent_when_no_match(tmp_path) -> None:
|
||||
assert resolved.nodes[0].agent.name == "rare"
|
||||
assert resolved.nodes[0].agent.metadata["resolution"] == "fallback_ephemeral"
|
||||
assert reports[0].fallback_used is True
|
||||
|
||||
|
||||
@ -0,0 +1,84 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from beaver.interfaces.web.app import create_app
|
||||
from beaver.services.agent_service import AgentService
|
||||
|
||||
|
||||
def test_channel_connection_api_creates_updates_lists_and_revokes(tmp_path) -> None:
    """Exercise create, patch, list and revoke on the channel-connections API.

    The config file is serialized with ``json.dumps`` rather than string
    interpolation so that a workspace path containing backslashes or quotes
    still produces valid JSON.
    """
    import json

    config_path = tmp_path / "config.json"
    config_path.write_text(
        json.dumps({"agents": {"defaults": {"workspace": str(tmp_path)}}, "providers": {}}),
        encoding="utf-8",
    )
    service = AgentService(config_path=config_path)
    app = create_app(service=service, manage_service_lifecycle=False)

    try:
        with TestClient(app) as client:
            created = client.post(
                "/api/channel-connections",
                json={
                    "kind": "telegram",
                    "mode": "polling",
                    "displayName": "Telegram Main",
                    "authType": "token",
                    "secrets": {"botToken": "token-1"},
                    "config": {"maxMessageChars": 4096, "requireMentionInGroups": True},
                },
            )
            assert created.status_code == 200
            body = created.json()
            connection_id = body["connection"]["connection_id"]
            assert body["connection"]["kind"] == "telegram"
            assert body["connection"]["status"] == "draft"
            # The credential reference must not leak into API responses.
            assert "credentials_ref" not in body["connection"]
            assert body["connection"]["runtime_config"] == {
                "max_message_chars": 4096,
                "require_mention_in_groups": True,
            }
            assert body["credentials"] == {"botToken": "***"}

            patched = client.patch(
                f"/api/channel-connections/{connection_id}",
                json={
                    "displayName": "Telegram Ops",
                    "config": {"maxMessageChars": 2048},
                    "secrets": {"botToken": "token-2"},
                },
            )
            assert patched.status_code == 200
            assert patched.json()["connection"]["display_name"] == "Telegram Ops"
            assert patched.json()["connection"]["runtime_config"] == {"max_message_chars": 2048}
            assert patched.json()["credentials"] == {"botToken": "***"}

            listed = client.get("/api/channel-connections")
            assert listed.status_code == 200
            assert listed.json()[0]["connection_id"] == connection_id
            assert "credentials_ref" not in listed.json()[0]

            revoked = client.post(f"/api/channel-connections/{connection_id}/revoke")
            assert revoked.status_code == 200
            assert revoked.json()["connection"]["status"] == "revoked"
    finally:
        service.close()
|
||||
|
||||
|
||||
def test_channel_connectors_api_lists_registered_connectors(tmp_path) -> None:
    """GET /api/channel-connectors returns the registered connector kinds.

    The config file is serialized with ``json.dumps`` rather than string
    interpolation so that a workspace path containing backslashes or quotes
    still produces valid JSON.
    """
    import json

    config_path = tmp_path / "config.json"
    config_path.write_text(
        json.dumps({"agents": {"defaults": {"workspace": str(tmp_path)}}, "providers": {}}),
        encoding="utf-8",
    )
    service = AgentService(config_path=config_path)
    app = create_app(service=service, manage_service_lifecycle=False)

    try:
        with TestClient(app) as client:
            response = client.get("/api/channel-connectors")
    finally:
        service.close()

    assert response.status_code == 200
    assert response.json() == [{"kind": "feishu"}, {"kind": "telegram"}, {"kind": "weixin"}]
|
||||
@ -0,0 +1,63 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from beaver.interfaces.channels.connections import (
|
||||
ChannelConnectionStore,
|
||||
CredentialStore,
|
||||
PairingTokenStore,
|
||||
)
|
||||
|
||||
|
||||
def test_channel_connection_store_creates_updates_lists_and_revokes(tmp_path) -> None:
    """Walk one connection through create, status update and revoke, then check each snapshot."""
    connections = ChannelConnectionStore(tmp_path / "connections.json")
    new_connection = {
        "kind": "telegram",
        "mode": "polling",
        "display_name": "Telegram Main",
        "account_id": "telegram:bot-main",
        "owner_user_id": "user-1",
        "auth_type": "token",
        "runtime_config": {"max_message_chars": 4096},
        "capabilities": ["receive_text", "send_text"],
    }

    created = connections.create(**new_connection)
    connection_id = created.connection_id
    after_update = connections.update_status(connection_id, status="connected", last_error=None)
    after_revoke = connections.revoke(connection_id)

    # The record returned by create() still shows its initial state.
    assert created.connection_id
    assert created.channel_id.startswith("telegram-")
    assert created.status == "draft"
    assert after_update.status == "connected"
    assert after_revoke.status == "revoked"
    assert connections.get(created.connection_id).status == "revoked"
    stored_ids = [record.connection_id for record in connections.list()]
    assert stored_ids == [created.connection_id]
|
||||
|
||||
|
||||
def test_credential_store_saves_values_by_reference_and_redacts_views(tmp_path) -> None:
    """put() returns a ``cred_`` reference; get() drops the empty value; redacted() masks secrets."""
    credentials = CredentialStore(tmp_path / "credentials.json")
    submitted = {"botToken": "secret-token", "empty": ""}

    reference = credentials.put(kind="telegram", values=submitted)

    assert reference.startswith("cred_")
    stored_values = credentials.get(reference)
    assert stored_values == {"botToken": "secret-token"}
    masked_values = credentials.redacted(reference)
    assert masked_values == {"botToken": "***"}
|
||||
|
||||
|
||||
def test_pairing_token_store_uses_one_time_expiring_tokens(tmp_path) -> None:
    """A pairing token can be consumed once; a second consume returns None."""
    pairing = PairingTokenStore(tmp_path / "pairing.json")

    issued = pairing.create(kind="terminal", ttl_seconds=60, scope="channel:pair")
    first_use = pairing.consume(issued.token, expected_kind="terminal")
    second_use = pairing.consume(issued.token, expected_kind="terminal")

    # The originally issued session object still reports its initial status.
    assert issued.status == "pending"
    assert first_use is not None
    assert first_use.status == "consumed"
    assert second_use is None
|
||||
|
||||
|
||||
def test_pairing_token_store_rejects_expired_tokens(tmp_path) -> None:
    """A token created with a negative TTL cannot be consumed."""
    pairing = PairingTokenStore(tmp_path / "pairing.json")

    expired = pairing.create(kind="weixin", ttl_seconds=-1, scope="channel:pair")
    outcome = pairing.consume(expired.token, expected_kind="weixin")

    assert outcome is None
|
||||
@ -0,0 +1,164 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
from beaver.foundation.config.schema import ChannelConfig
|
||||
from beaver.interfaces.channels.connections import (
|
||||
ChannelConnectionStore,
|
||||
ChannelConnectorRegistry,
|
||||
ChannelRuntimeSpec,
|
||||
CredentialStore,
|
||||
ValidationResult,
|
||||
)
|
||||
|
||||
|
||||
class FakeConnector:
    """Test connector of kind ``fake`` that records validate and revoke calls."""

    kind = "fake"

    def __init__(self) -> None:
        # Connection ids passed to validate() and revoke(), in call order.
        self.validated: list[str] = []
        self.revoked: list[str] = []

    async def validate(self, connection_id: str) -> ValidationResult:
        """Record the id and report a successful, connected validation."""
        self.validated.append(connection_id)
        outcome = ValidationResult(
            ok=True,
            status="connected",
            account_id="fake-account",
            display_name="Fake",
        )
        return outcome

    async def materialize_runtime(self, connection_id: str) -> ChannelRuntimeSpec:
        """Return a fixed webhook runtime spec regardless of ``connection_id``."""
        spec = ChannelRuntimeSpec(
            channel_id="fake-channel",
            kind="fake",
            mode="webhook",
            account_id="fake-account",
            display_name="Fake",
            config={"enabled": True},
        )
        return spec

    async def revoke(self, connection_id: str) -> None:
        """Record the id of the revoked connection."""
        self.revoked.append(connection_id)
        return None
|
||||
|
||||
|
||||
def test_connector_registry_dispatches_by_kind(tmp_path) -> None:
    """validate() and materialize_runtime() reach the connector registered for the connection's kind."""

    async def scenario() -> None:
        connections = ChannelConnectionStore(tmp_path / "connections.json")
        credentials = CredentialStore(tmp_path / "credentials.json")
        fake = FakeConnector()
        connectors = ChannelConnectorRegistry(connection_store=connections, credential_store=credentials)
        connectors.register(fake)

        record = connections.create(
            kind="fake",
            mode="webhook",
            display_name="Fake",
            account_id="fake-account",
            owner_user_id=None,
            auth_type="token",
        )
        validation = await connectors.validate(record.connection_id)
        runtime_spec = await connectors.materialize_runtime(record.connection_id)

        assert validation.ok is True
        assert fake.validated == [record.connection_id]
        assert runtime_spec.channel_id == "fake-channel"

    asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_connector_registry_materializes_channel_configs_with_credentials(tmp_path) -> None:
    """Materialized channel configs carry the secrets stored under the connection's credential ref."""

    async def scenario() -> None:
        connections = ChannelConnectionStore(tmp_path / "connections.json")
        credentials = CredentialStore(tmp_path / "credentials.json")
        secret_ref = credentials.put(kind="telegram", values={"botToken": "token-1"})
        record = connections.create(
            kind="fake",
            mode="webhook",
            display_name="Connected",
            account_id="connected",
            owner_user_id=None,
            auth_type="token",
            credentials_ref=secret_ref,
        )
        connections.update_status(record.connection_id, status="connected", last_error=None)

        class CredentialAwareConnector(FakeConnector):
            """Fake connector whose runtime spec points at the stored credential ref."""

            async def materialize_runtime(self, connection_id: str) -> ChannelRuntimeSpec:
                persisted = connections.get(connection_id)
                return ChannelRuntimeSpec(
                    channel_id="fake-channel",
                    kind="fake",
                    mode="webhook",
                    account_id="fake-account",
                    display_name="Fake",
                    config={"enabled": True},
                    secrets_ref=persisted.credentials_ref,
                )

        connectors = ChannelConnectorRegistry(connection_store=connections, credential_store=credentials)
        connectors.register(CredentialAwareConnector())

        channel_configs = await connectors.materialize_channel_configs()

        assert isinstance(channel_configs["fake-channel"], ChannelConfig)
        assert channel_configs["fake-channel"].enabled is True
        assert channel_configs["fake-channel"].secrets == {"botToken": "token-1"}

    asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_connector_registry_materializes_only_connected_connections(tmp_path) -> None:
    """Only the connection marked ``connected`` yields a runtime spec; the draft stays a draft."""

    async def scenario() -> None:
        connections = ChannelConnectionStore(tmp_path / "connections.json")
        credentials = CredentialStore(tmp_path / "credentials.json")
        connectors = ChannelConnectorRegistry(connection_store=connections, credential_store=credentials)
        connectors.register(FakeConnector())

        # Fields shared by both connections; only name and account differ.
        shared_fields = {"kind": "fake", "mode": "webhook", "owner_user_id": None, "auth_type": "token"}
        draft_record = connections.create(display_name="Draft", account_id="draft", **shared_fields)
        live_record = connections.create(display_name="Connected", account_id="connected", **shared_fields)
        connections.update_status(live_record.connection_id, status="connected", last_error=None)

        runtime_specs = await connectors.materialize_connected_runtime_specs()

        materialized_channels = [item.channel_id for item in runtime_specs]
        assert materialized_channels == ["fake-channel"]
        assert connections.get(draft_record.connection_id).status == "draft"

    asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_connector_registry_revoke_calls_connector_and_updates_store(tmp_path) -> None:
    """revoke() notifies the connector and leaves the stored connection in status ``revoked``."""

    async def scenario() -> None:
        connections = ChannelConnectionStore(tmp_path / "connections.json")
        credentials = CredentialStore(tmp_path / "credentials.json")
        fake = FakeConnector()
        connectors = ChannelConnectorRegistry(connection_store=connections, credential_store=credentials)
        connectors.register(fake)

        record = connections.create(
            kind="fake",
            mode="webhook",
            display_name="Fake",
            account_id="fake-account",
            owner_user_id=None,
            auth_type="token",
        )
        target_id = record.connection_id
        connections.update_status(target_id, status="connected", last_error=None)

        await connectors.revoke(target_id)

        assert fake.revoked == [record.connection_id]
        assert connections.get(record.connection_id).status == "revoked"

    asyncio.run(scenario())
|
||||
414
app-instance/backend/tests/unit/test_channel_runtime.py
Normal file
414
app-instance/backend/tests/unit/test_channel_runtime.py
Normal file
@ -0,0 +1,414 @@
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from beaver.foundation.config.schema import ChannelConfig
|
||||
from beaver.foundation.events import ChannelIdentity, InboundMessage, OutboundMessage
|
||||
from beaver.foundation.events import MessageBus
|
||||
from beaver.interfaces.channels.generic_webhook import GenericWebhookAdapter
|
||||
from beaver.interfaces.channels.runtime import ChannelRuntime
|
||||
from beaver.interfaces.channels.state import ChannelDedupeStore, ChannelEventLog
|
||||
from beaver.interfaces.web.app import _self_restart_enabled, create_app
|
||||
from beaver.services.agent_service import AgentService
|
||||
|
||||
|
||||
def test_channel_identity_builds_stable_session_id() -> None:
    """Check the exact session id and dedupe key produced for a fully populated identity."""
    routing = {
        "channel_id": "webhook-dev",
        "kind": "webhook",
        "account_id": "local",
        "peer_id": "demo-user",
        "thread_id": "main",
        "peer_type": "dm",
        "message_id": "msg-1",
    }
    identity = ChannelIdentity(**routing)

    session_id = identity.session_id()
    dedupe_key = identity.dedupe_key()

    assert session_id == "webhook-dev:local:demo-user:main"
    assert dedupe_key == "webhook-dev:local:demo-user:main:msg-1"
|
||||
|
||||
|
||||
def test_channel_identity_requires_routing_fields() -> None:
    """An identity with an empty account_id reports that field as required."""
    incomplete = ChannelIdentity(
        channel_id="webhook-dev",
        kind="webhook",
        account_id="",
        peer_id="demo",
    )

    problem = incomplete.validation_error()

    assert problem == "account_id is required"
|
||||
|
||||
|
||||
def test_messages_carry_channel_identity() -> None:
    """Inbound and outbound messages keep the very same identity object they were given."""
    who = ChannelIdentity(
        channel_id="webhook-dev",
        kind="webhook",
        account_id="local",
        peer_id="demo-user",
        message_id="msg-1",
    )

    incoming = InboundMessage(channel="webhook-dev", content="hello", channel_identity=who)
    outgoing = OutboundMessage(
        channel="webhook-dev",
        content="ok",
        session_id=who.session_id(),
        finish_reason="stop",
        channel_identity=who,
    )

    # Identity (not just equality) is the contract being checked.
    assert incoming.channel_identity is who
    assert outgoing.channel_identity is who
|
||||
|
||||
|
||||
def test_dedupe_store_tracks_processing_and_done(tmp_path) -> None:
    """A repeated mark_processing is flagged as a duplicate, and mark_done stores a truncated reply."""
    store = ChannelDedupeStore(tmp_path / "dedupe.json", retention_hours=48)
    claim = {
        "dedupe_key": "webhook-dev:local:demo:msg-1",
        "session_id": "webhook-dev:local:demo",
        "message_id": "msg-1",
    }

    first = store.mark_processing(**claim)
    second = store.mark_processing(**claim)

    assert first.created is True
    assert second.created is False
    assert second.record is not None
    assert second.record["status"] == "processing"

    # The reply is far longer than max_reply_chars so truncation is observable.
    store.mark_done(
        dedupe_key="webhook-dev:local:demo:msg-1",
        run_id="run-1",
        reply="hello" * 10000,
        max_reply_chars=20,
    )

    finished = store.get("webhook-dev:local:demo:msg-1")
    assert finished is not None
    assert finished["status"] == "done"
    assert finished["reply"] == "hellohellohellohello"
|
||||
|
||||
|
||||
def test_channel_event_log_writes_recent_events(tmp_path) -> None:
    """A recorded event is returned by recent() with a text preview and no raw payload key."""
    event_log = ChannelEventLog(tmp_path / "events.jsonl")
    event_log.record(
        channel_id="webhook-dev",
        kind="inbound_accepted",
        session_id="webhook-dev:local:demo",
        message_id="msg-1",
        status="ok",
        text="hello world",
    )

    recent = event_log.recent(channel_id="webhook-dev", limit=10)

    assert len(recent) == 1
    only = recent[0]
    assert only["kind"] == "inbound_accepted"
    assert only["text_preview"] == "hello world"
    # The serialized event must not mention the raw channel payload anywhere.
    assert "raw_channel_payload" not in json.dumps(only)
|
||||
|
||||
|
||||
class FakeAgentService:
    """Minimal agent-service stand-in that answers every inbound message with an echo."""

    # Class-level flag read by code that checks whether the service is running.
    is_running = True

    async def handle_inbound_message(self, inbound):
        """Return an outbound reply whose content is the inbound content prefixed with "echo:"."""
        echoed = f"echo:{inbound.content}"
        return OutboundMessage(
            message_id=inbound.message_id,
            channel=inbound.channel,
            content=echoed,
            session_id=inbound.session_id,
            finish_reason="stop",
            run_id="run-1",
            channel_identity=inbound.channel_identity,
        )
|
||||
|
||||
|
||||
class SlowFakeAgentService(FakeAgentService):
    """Echo service that pauses briefly before replying."""

    async def handle_inbound_message(self, inbound):
        """Sleep 0.05 seconds, then delegate to the parent echo implementation."""
        await asyncio.sleep(0.05)
        reply = await super().handle_inbound_message(inbound)
        return reply
|
||||
|
||||
|
||||
def test_channel_runtime_accept_inbound_normalizes_session_and_dedupes(tmp_path) -> None:
    """accept_inbound rewrites the session id from the identity and rejects a repeated message."""

    async def scenario() -> None:
        message_bus = MessageBus()
        runtime = ChannelRuntime(
            service=FakeAgentService(),
            bus=message_bus,
            workspace=tmp_path,
            channels={},
        )
        who = ChannelIdentity(
            channel_id="webhook-dev",
            kind="webhook",
            account_id="local",
            peer_id="demo",
            message_id="msg-1",
        )

        # The first message carries a deliberately wrong session id.
        first_message = InboundMessage(
            channel="webhook-dev",
            content="hello",
            session_id="wrong",
            channel_identity=who,
        )
        first = await runtime.accept_inbound(first_message)

        repeat_message = InboundMessage(
            channel="webhook-dev",
            content="hello",
            channel_identity=who,
        )
        second = await runtime.accept_inbound(repeat_message)

        delivered = await message_bus.consume_inbound()
        assert first.accepted is True
        assert delivered.session_id == "webhook-dev:local:demo"
        assert second.accepted is False
        assert second.duplicate is True

    asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_generic_webhook_adapter_waits_for_outbound_reply(tmp_path) -> None:
    """A webhook payload handled through a started runtime yields the echoed reply."""

    async def scenario() -> None:
        message_bus = MessageBus()
        runtime = ChannelRuntime(
            service=FakeAgentService(),
            bus=message_bus,
            workspace=tmp_path,
            channels={},
        )
        webhook = GenericWebhookAdapter(
            channel_id="webhook-dev",
            kind="webhook",
            mode="webhook",
            account_id="local",
            display_name="Webhook Dev",
            inbound_sink=runtime,
            response_timeout_seconds=1,
        )
        runtime.manager.register(webhook)
        await runtime.start()

        body = {
            "peer_id": "demo",
            "message_id": "msg-1",
            "text": "hello",
            "peer_type": "dm",
        }
        try:
            result = await webhook.handle_webhook_payload(body)
        finally:
            # Always stop the runtime, even when the webhook call fails.
            await runtime.stop()

        assert result["ok"] is True
        assert result["reply"] == "echo:hello"
        assert result["session_id"] == "webhook-dev:local:demo"

    asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_generic_webhook_records_unclaimed_outbound_after_timeout(tmp_path) -> None:
    """When the webhook wait times out first, the late reply is logged as "outbound_unclaimed"."""

    async def scenario() -> None:
        message_bus = MessageBus()
        runtime = ChannelRuntime(
            service=SlowFakeAgentService(),
            bus=message_bus,
            workspace=tmp_path,
            channels={},
        )
        webhook = GenericWebhookAdapter(
            channel_id="webhook-dev",
            kind="webhook",
            mode="webhook",
            account_id="local",
            display_name="Webhook Dev",
            inbound_sink=runtime,
            response_timeout_seconds=1,
        )
        # Shrink the wait below the slow service's 0.05s delay so the wait expires first.
        webhook.response_timeout_seconds = 0.01
        runtime.manager.register(webhook)
        await runtime.start()

        body = {
            "peer_id": "demo",
            "message_id": "msg-1",
            "text": "hello",
            "peer_type": "dm",
        }
        try:
            result = await webhook.handle_webhook_payload(body)
            # Give the slow service time to finish before reading the event log.
            await asyncio.sleep(0.1)
            logged = runtime.recent_events("webhook-dev", limit=20)
        finally:
            await runtime.stop()

        assert result["pending"] is True
        assert any(entry["kind"] == "outbound_unclaimed" for entry in logged)

    asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_channel_runtime_starts_enabled_generic_webhook_and_reports_status(tmp_path) -> None:
    """An enabled webhook channel reports "running" with its URL; a disabled one reports "disabled"."""

    async def scenario() -> None:
        enabled_cfg = ChannelConfig(
            enabled=True,
            kind="webhook",
            mode="webhook",
            account_id="local",
            display_name="Webhook Dev",
            config={"response_timeout_seconds": 1800},
        )
        disabled_cfg = ChannelConfig(
            enabled=False,
            kind="webhook",
            mode="webhook",
            account_id="local",
        )
        runtime = ChannelRuntime(
            service=FakeAgentService(),
            workspace=tmp_path,
            channels={"webhook-dev": enabled_cfg, "off": disabled_cfg},
        )
        await runtime.start()
        try:
            reported = runtime.statuses()
        finally:
            await runtime.stop()

        status_of = {entry["channel_id"]: entry for entry in reported}
        assert status_of["webhook-dev"]["state"] == "running"
        assert status_of["webhook-dev"]["webhook_url"] == "/api/channels/webhook-dev/webhook"
        assert status_of["off"]["state"] == "disabled"

    asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_channel_runtime_builds_platform_adapters_without_starting_networks(tmp_path) -> None:
    """_build_adapter returns adapters whose id, kind and mode mirror the given config."""
    runtime = ChannelRuntime(
        service=FakeAgentService(),
        workspace=tmp_path,
        channels={},
    )

    # (channel id, kind, mode, account id) for each platform under test.
    platforms = [
        ("telegram-main", "telegram", "polling", "bot-main"),
        ("feishu-main", "feishu", "websocket", "tenant-main"),
        ("qq-main", "qqbot", "websocket", "qq-main"),
        ("weixin-main", "weixin", "polling", "wx-main"),
    ]

    for channel_id, kind, mode, account_id in platforms:
        cfg = ChannelConfig(enabled=True, kind=kind, mode=mode, account_id=account_id)
        built = runtime._build_adapter(channel_id, cfg)
        assert built.channel_id == channel_id
        assert built.kind == cfg.kind
        assert built.mode == cfg.mode
|
||||
|
||||
|
||||
def test_channel_runtime_reports_platform_capabilities(tmp_path) -> None:
    """statuses() lists the expected capability names for telegram and weixin channels."""
    telegram_cfg = ChannelConfig(enabled=True, kind="telegram", mode="polling", account_id="bot-main")
    weixin_cfg = ChannelConfig(enabled=True, kind="weixin", mode="polling", account_id="wx-main")
    runtime = ChannelRuntime(
        service=FakeAgentService(),
        workspace=tmp_path,
        channels={"telegram-main": telegram_cfg, "weixin-main": weixin_cfg},
    )

    status_of = {entry["channel_id"]: entry for entry in runtime.statuses()}

    expected_telegram = ["receive_text", "send_text", "receive_media", "groups"]
    expected_weixin = ["receive_text", "send_text", "receive_media", "direct_messages"]
    assert status_of["telegram-main"]["capabilities"] == expected_telegram
    assert status_of["weixin-main"]["capabilities"] == expected_weixin
|
||||
|
||||
|
||||
def test_channel_runtime_platform_start_failure_does_not_stop_other_channels(tmp_path) -> None:
    """A telegram channel without secrets ends in "error" while the other channel still reports its state."""

    async def scenario() -> None:
        # Empty secrets: the assertions below expect an error mentioning "botToken".
        telegram_cfg = ChannelConfig(
            enabled=True,
            kind="telegram",
            mode="polling",
            account_id="bot-main",
            secrets={},
        )
        weixin_cfg = ChannelConfig(
            enabled=False,
            kind="weixin",
            mode="polling",
            account_id="wx-main",
        )
        runtime = ChannelRuntime(
            service=FakeAgentService(),
            workspace=tmp_path,
            channels={"telegram-main": telegram_cfg, "off": weixin_cfg},
        )

        await runtime.start()
        try:
            status_of = {entry["channel_id"]: entry for entry in runtime.statuses()}
        finally:
            await runtime.stop()

        assert status_of["telegram-main"]["state"] == "error"
        assert "botToken" in status_of["telegram-main"]["last_error"]
        assert status_of["off"]["state"] == "disabled"

    asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_web_app_status_exposes_configured_channel(tmp_path) -> None:
    """GET /api/status lists the configured webhook channel and the self-restart control.

    The service is closed in a ``finally`` block so that a failure while
    creating the app or issuing the request does not leave it open.
    """
    config_path = tmp_path / "config.json"
    workspace = tmp_path / "workspace"
    workspace.mkdir()
    config_path.write_text(
        json.dumps(
            {
                "agents": {"defaults": {"workspace": str(workspace), "model": "openai/gpt-5"}},
                "providers": {},
                "channels": {
                    "webhook-dev": {
                        "enabled": True,
                        "kind": "webhook",
                        "mode": "webhook",
                        "accountId": "local",
                        "displayName": "Webhook Dev",
                    }
                },
            }
        ),
        encoding="utf-8",
    )
    service = AgentService(config_path=config_path)
    try:
        app = create_app(service=service, manage_service_lifecycle=False)
        with TestClient(app) as client:
            payload = client.get("/api/status").json()
    finally:
        service.close()

    first_channel = payload["channels"][0]
    assert first_channel["channel_id"] == "webhook-dev"
    assert first_channel["state"] == "running"
    assert first_channel["webhook_url"] == "/api/channels/webhook-dev/webhook"
    assert payload["runtime_controls"]["self_restart"] is True
|
||||
|
||||
|
||||
def test_self_restart_env_defaults_enabled(monkeypatch) -> None:
    """With BEAVER_ENABLE_SELF_RESTART unset, self restart is reported as enabled."""
    monkeypatch.delenv("BEAVER_ENABLE_SELF_RESTART", raising=False)

    enabled = _self_restart_enabled()

    assert enabled is True
|
||||
|
||||
|
||||
def test_self_restart_env_can_disable(monkeypatch) -> None:
    """Setting BEAVER_ENABLE_SELF_RESTART to "0" turns self restart off."""
    monkeypatch.setenv("BEAVER_ENABLE_SELF_RESTART", "0")

    enabled = _self_restart_enabled()

    assert enabled is False
|
||||
@ -0,0 +1,119 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
from beaver.foundation.config.schema import ChannelConfig
|
||||
from beaver.foundation.events import MessageBus, OutboundMessage
|
||||
from beaver.interfaces.channels.runtime import ChannelRuntime
|
||||
|
||||
|
||||
class FakeService:
    """Agent-service stub that replies "ok" to every inbound message."""

    async def handle_inbound_message(self, inbound):
        """Build an "ok" reply on the inbound message's channel and session."""
        reply = OutboundMessage(
            channel=inbound.channel,
            content="ok",
            session_id=inbound.session_id,
            finish_reason="stop",
        )
        return reply
|
||||
|
||||
|
||||
def test_runtime_add_channel_starts_new_channel_after_runtime_start(tmp_path) -> None:
    """A channel added to an already started runtime is registered and reported as running."""

    async def scenario() -> None:
        runtime = ChannelRuntime(service=FakeService(), workspace=tmp_path, channels={}, bus=MessageBus())
        await runtime.start()
        try:
            new_cfg = ChannelConfig(enabled=True, kind="webhook", mode="webhook", account_id="acct")
            await runtime.add_channel("webhook-dev", new_cfg)
            assert "webhook-dev" in runtime.adapters
            assert runtime.states["webhook-dev"]["state"] == "running"
        finally:
            await runtime.stop()

    asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_runtime_add_channel_noops_for_same_config(tmp_path) -> None:
    """Adding the same config twice keeps the original adapter instance."""

    async def scenario() -> None:
        same_cfg = ChannelConfig(enabled=True, kind="webhook", mode="webhook", account_id="acct")
        runtime = ChannelRuntime(service=FakeService(), workspace=tmp_path, channels={}, bus=MessageBus())
        await runtime.start()
        try:
            await runtime.add_channel("webhook-dev", same_cfg)
            original_adapter = runtime.adapters["webhook-dev"]
            await runtime.add_channel("webhook-dev", same_cfg)
            # Object identity shows the adapter was not rebuilt.
            assert runtime.adapters["webhook-dev"] is original_adapter
        finally:
            await runtime.stop()

    asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_runtime_replacement_failure_keeps_old_channel(tmp_path) -> None:
    """If replacing a channel raises ValueError, the previous adapter, config and state survive."""

    async def scenario() -> None:
        working_cfg = ChannelConfig(enabled=True, kind="webhook", mode="webhook", account_id="acct")
        broken_cfg = ChannelConfig(enabled=True, kind="missing", mode="http", account_id="acct")
        runtime = ChannelRuntime(service=FakeService(), workspace=tmp_path, channels={}, bus=MessageBus())
        await runtime.start()
        try:
            await runtime.add_channel("webhook-dev", working_cfg)
            previous_adapter = runtime.adapters["webhook-dev"]

            rejected = False
            try:
                await runtime.add_channel("webhook-dev", broken_cfg)
            except ValueError:
                rejected = True
            if not rejected:
                raise AssertionError("Expected ValueError")

            assert runtime.adapters["webhook-dev"] is previous_adapter
            assert runtime.channel_configs["webhook-dev"] == working_cfg
            assert runtime.states["webhook-dev"]["state"] == "running"
        finally:
            await runtime.stop()

    asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_runtime_remove_channel_stops_and_unregisters(tmp_path) -> None:
    """Removing a channel drops it from adapters and the manager and marks its state "removed"."""

    async def scenario() -> None:
        runtime = ChannelRuntime(service=FakeService(), workspace=tmp_path, channels={}, bus=MessageBus())
        await runtime.start()
        try:
            cfg = ChannelConfig(enabled=True, kind="webhook", mode="webhook", account_id="acct")
            await runtime.add_channel("webhook-dev", cfg)
            await runtime.remove_channel("webhook-dev")

            assert "webhook-dev" not in runtime.adapters
            assert "webhook-dev" not in runtime.manager.channels
            assert runtime.states["webhook-dev"]["state"] == "removed"
        finally:
            await runtime.stop()

    asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_runtime_builds_external_connector_channel(tmp_path, monkeypatch) -> None:
    """An "external_connector" channel yields an adapter exposing kind, mode and platform_kind."""

    async def scenario() -> None:
        monkeypatch.setenv("EXTERNAL_CONNECTOR_TOKEN", "connector-token")
        runtime = ChannelRuntime(service=FakeService(), workspace=tmp_path, channels={}, bus=MessageBus())
        await runtime.start()
        try:
            connector_cfg = ChannelConfig(
                enabled=True,
                kind="external_connector",
                mode="http",
                account_id="weixin:me",
                display_name="Weixin Main",
                config={
                    "platformKind": "weixin",
                    "connectionId": "conn_1",
                    "sidecarBaseUrl": "http://external-connector:8787",
                },
            )
            await runtime.add_channel("weixin-main", connector_cfg)

            built = runtime.adapters["weixin-main"]
            assert built.kind == "external_connector"
            assert built.mode == "http"
            assert getattr(built, "platform_kind") == "weixin"
        finally:
            await runtime.stop()

    asyncio.run(scenario())
|
||||
@ -1,4 +1,5 @@
|
||||
import json
|
||||
import asyncio
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
@ -46,6 +47,44 @@ def test_load_config_reads_current_instance_shape(tmp_path) -> None:
|
||||
assert target["extra_headers"] == {"X-Test": "1"}
|
||||
|
||||
|
||||
def test_config_loader_reads_channels(tmp_path) -> None:
    """load_config exposes a channel written with camelCase keys through snake_case fields."""
    config_path = tmp_path / "config.json"
    channel_section = {
        "webhook-dev": {
            "enabled": True,
            "kind": "webhook",
            "mode": "webhook",
            "accountId": "local",
            "displayName": "Webhook Dev",
            "config": {
                "responseTimeoutSeconds": 1800,
                "dedupeRetentionHours": 48,
            },
            "secrets": {"ignored_for_status": "secret-value"},
        }
    }
    document = {
        "agents": {"defaults": {"model": "openai/gpt-5"}},
        "channels": channel_section,
    }
    config_path.write_text(json.dumps(document), encoding="utf-8")

    loaded = load_config(config_path=config_path)

    webhook = loaded.channels["webhook-dev"]
    assert webhook.enabled is True
    assert webhook.kind == "webhook"
    assert webhook.mode == "webhook"
    assert webhook.account_id == "local"
    assert webhook.display_name == "Webhook Dev"
    assert webhook.config["response_timeout_seconds"] == 1800
    assert webhook.config["dedupe_retention_hours"] == 48
    # Secret keys are expected to come back exactly as written.
    assert webhook.secrets == {"ignored_for_status": "secret-value"}
|
||||
|
||||
|
||||
def test_provider_resolution_ignores_custom_and_disabled_overrides(tmp_path) -> None:
|
||||
config_path = tmp_path / "config.json"
|
||||
config_path.write_text(
|
||||
@ -163,6 +202,58 @@ def test_reload_agent_config_updates_booted_loop_config(tmp_path) -> None:
|
||||
service.close()
|
||||
|
||||
|
||||
def test_reload_agent_config_keeps_running_service_when_old_mcp_close_fails(tmp_path) -> None:
    """Reloading config still applies the new provider target when the old MCP manager's close() raises.

    The test also installs a loop exception handler and asserts that nothing
    was reported to it during the reload.
    """

    async def run_case() -> None:
        workspace = tmp_path / "workspace"
        config_path = tmp_path / "config.json"

        def write_config(model: str, api_base: str) -> None:
            # Same document shape for both the initial and the reloaded config.
            config_path.write_text(
                json.dumps(
                    {
                        "agents": {"defaults": {"workspace": str(workspace), "model": model}},
                        "providers": {"openai": {"apiKey": "sk-test", "apiBase": api_base}},
                    }
                ),
                encoding="utf-8",
            )

        write_config("old-model", "https://old.example.com/v1")
        service = AgentService(config_path=config_path)
        await service.start()

        class FailingMCPManager:
            async def close(self) -> None:
                raise RuntimeError("Attempted to exit cancel scope in a different task than it was entered in")

        booted = service.create_loop().boot()
        booted.mcp_manager = FailingMCPManager()
        write_config("new-model", "https://new.example.com/v1")

        loop = asyncio.get_running_loop()
        loop_errors: list[dict[str, object]] = []
        original_handler = loop.get_exception_handler()

        def collect(_loop, context) -> None:
            loop_errors.append(context)

        loop.set_exception_handler(collect)
        try:
            _reload_agent_config(service, config_path)
            # Yield once so any task scheduled by the reload gets a chance to run.
            await asyncio.sleep(0)

            target = service.create_loop().boot().config.resolve_provider_target()
            assert service.is_running is True
            assert target["model"] == "new-model"
            assert target["api_base"] == "https://new.example.com/v1"
            assert loop_errors == []
        finally:
            loop.set_exception_handler(original_handler)
            await service.shutdown(force=True)

    asyncio.run(run_case())
|
||||
|
||||
|
||||
def test_agent_defaults_include_runtime_controls(tmp_path) -> None:
|
||||
config_path = tmp_path / "config.json"
|
||||
config_path.write_text(
|
||||
@ -245,6 +336,67 @@ def test_agent_config_api_accepts_zero_temperature_and_iterations(tmp_path) -> N
|
||||
service.close()
|
||||
|
||||
|
||||
def test_channel_config_api_persists_and_masks_secrets(tmp_path) -> None:
    """The channel config API masks secrets in responses and keeps the stored secret on blank input.

    Posting an empty ``botToken`` must leave the original token in the saved
    file, while both GET and POST responses show only the masked form.
    The service is closed in ``finally`` so a failing assertion cannot leak it.
    """
    config_path = tmp_path / "config.json"
    config_path.write_text(
        json.dumps(
            {
                "agents": {"defaults": {"model": "openai/gpt-5"}},
                "channels": {
                    "telegram-main": {
                        "enabled": False,
                        "kind": "telegram",
                        "mode": "polling",
                        "accountId": "bot-main",
                        "displayName": "Telegram Main",
                        "secrets": {"botToken": "1234567890abcdef"},
                        "config": {"requireMentionInGroups": True},
                    }
                },
            }
        ),
        encoding="utf-8",
    )
    service = AgentService(config_path=config_path)
    try:
        app = create_app(service=service, manage_service_lifecycle=False)

        with TestClient(app) as client:
            before = client.get("/api/channels/telegram-main/config")
            response = client.post(
                "/api/channels/telegram-main/config",
                json={
                    "enabled": True,
                    "kind": "telegram",
                    "mode": "polling",
                    "account_id": "bot-main",
                    "display_name": "Telegram Primary",
                    "secrets": {"botToken": ""},
                    "config": {
                        "requireMentionInGroups": False,
                        "allowFrom": ["1001", "1002"],
                        "maxMessageChars": 3000,
                    },
                },
            )

        saved = json.loads(config_path.read_text(encoding="utf-8"))
        channel = saved["channels"]["telegram-main"]

        assert before.status_code == 200
        assert before.json()["secrets"] == {"botToken": "1234••••cdef"}
        assert response.status_code == 200
        assert response.json()["ok"] is True
        assert response.json()["restart_required"] is True
        assert response.json()["channel"]["display_name"] == "Telegram Primary"
        assert response.json()["channel"]["secrets"] == {"botToken": "1234••••cdef"}
        assert channel["enabled"] is True
        assert channel["displayName"] == "Telegram Primary"
        # The blank token in the POST body must not overwrite the stored secret.
        assert channel["secrets"]["botToken"] == "1234567890abcdef"
        assert channel["config"]["allowFrom"] == ["1001", "1002"]
        assert load_config(config_path=config_path).channels["telegram-main"].enabled is True
    finally:
        service.close()
|
||||
|
||||
|
||||
def test_openai_compatible_qwen_config_keeps_openai_provider() -> None:
|
||||
bundle = make_provider_bundle(
|
||||
model="qwen-plus",
|
||||
@ -320,4 +472,5 @@ def test_load_config_adds_managed_local_mcp_servers(tmp_path) -> None:
|
||||
assert local.kind == "local"
|
||||
assert local.category == "filesystem"
|
||||
assert local.managed is True
|
||||
assert local.display_name == "个人智能体文件系统工具"
|
||||
assert "beaver.interfaces.mcp.tools_server" in local.args
|
||||
|
||||
@ -0,0 +1,51 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from beaver.interfaces.channels.connections import MessageDedupeStore
|
||||
|
||||
|
||||
def test_message_dedupe_store_completes_and_dedupes_completed(tmp_path) -> None:
    """A redelivery of a completed event is not processed again and reports HTTP 200."""
    store = MessageDedupeStore(tmp_path / "message_dedupe.json")
    event = {"connection_id": "conn_1", "event_id": "evt_1"}

    initial = store.begin(delivery_attempt=1, **event)
    store.complete(initial.dedupe_key, message_id="msg_1")
    redelivery = store.begin(delivery_attempt=2, **event)

    assert initial.should_process is True
    assert redelivery.should_process is False
    assert redelivery.status == "completed"
    assert redelivery.http_status == 200
|
||||
|
||||
|
||||
def test_message_dedupe_store_returns_conflict_for_active_processing(tmp_path) -> None:
    """A redelivery while the first attempt is still processing gets 409 with a 5 second retry hint."""
    store = MessageDedupeStore(tmp_path / "message_dedupe.json", processing_ttl_seconds=60)
    event = {"connection_id": "conn_1", "event_id": "evt_1"}

    store.begin(delivery_attempt=1, **event)
    redelivery = store.begin(delivery_attempt=2, **event)

    assert redelivery.should_process is False
    assert redelivery.status == "processing"
    assert redelivery.http_status == 409
    assert redelivery.retry_after_seconds == 5
|
||||
|
||||
|
||||
def test_message_dedupe_store_reprocesses_stale_processing(tmp_path) -> None:
    """With a zero processing TTL, a second delivery of the same event is processed again."""
    store = MessageDedupeStore(tmp_path / "message_dedupe.json", processing_ttl_seconds=0)
    event = {"connection_id": "conn_1", "event_id": "evt_1"}

    store.begin(delivery_attempt=1, **event)
    second = store.begin(delivery_attempt=2, **event)

    assert second.should_process is True
    assert second.status == "processing"
    assert second.record.delivery_attempts == 2
|
||||
|
||||
|
||||
def test_message_dedupe_store_reprocesses_failed_records(tmp_path) -> None:
    """An event marked failed is processed again on redelivery and its last_error is cleared."""
    store = MessageDedupeStore(tmp_path / "message_dedupe.json")
    event = {"connection_id": "conn_1", "event_id": "evt_1"}

    initial = store.begin(delivery_attempt=1, **event)
    store.fail(initial.dedupe_key, error="runtime rejected")
    second = store.begin(delivery_attempt=2, **event)

    assert second.should_process is True
    assert second.record.delivery_attempts == 2
    assert second.record.last_error is None
|
||||
@ -26,3 +26,26 @@ def test_context_builder_injects_current_date_and_time() -> None:
|
||||
assert "Local UTC offset: +08:00" in system_prompt
|
||||
assert '"today", "tomorrow", "now", "this week", and "next month"' in system_prompt
|
||||
assert result.messages[-1] == {"role": "user", "content": "今天几号?"}
|
||||
|
||||
|
||||
def test_context_builder_uses_simplified_main_agent_prompt_by_default() -> None:
    """With a default ContextBuildInput the system prompt contains the Simplified Chinese phrases."""
    builder = ContextBuilder()
    prompt = builder.build_system_prompt(ContextBuildInput())

    expected_phrases = (
        "你是海狸 (Beaver)",
        "博维资讯系统有限公司研发的 AI 助手",
        "使用简体中文进行面向用户的回复",
    )
    for phrase in expected_phrases:
        assert phrase in prompt
|
||||
|
||||
|
||||
def test_context_builder_uses_traditional_main_agent_prompt_for_zh_hant() -> None:
    """With prompt_locale "zh-Hant" the system prompt contains the Traditional Chinese phrases."""
    builder = ContextBuilder()
    prompt = builder.build_system_prompt(ContextBuildInput(prompt_locale="zh-Hant"))

    expected_phrases = (
        "你是海狸 (Beaver)",
        "博維資訊系統有限公司研發的 AI 助手",
        "使用繁體中文進行面向使用者的回覆",
    )
    for phrase in expected_phrases:
        assert phrase in prompt
|
||||
|
||||
|
||||
def test_context_builder_uses_english_main_agent_prompt_for_en() -> None:
    """With prompt_locale "en" the system prompt contains the English identity and language lines."""
    builder = ContextBuilder()
    prompt = builder.build_system_prompt(ContextBuildInput(prompt_locale="en"))

    expected_phrases = (
        "You are Beaver, an AI assistant developed by Boway Information Systems Co., Ltd.",
        "Use English for user-facing replies",
    )
    for phrase in expected_phrases:
        assert phrase in prompt
|
||||
|
||||
@ -0,0 +1,193 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from beaver.interfaces.channels.connections import ChannelConnectionStore
|
||||
from beaver.interfaces.web.app import create_app
|
||||
from beaver.services.agent_service import AgentService
|
||||
|
||||
|
||||
def _app(tmp_path, monkeypatch):
    """Create a web app and its agent service backed by a minimal config under ``tmp_path``.

    Sets ``BEAVER_BRIDGE_TOKEN`` to ``"bridge-token"`` for the test and writes
    a config whose default workspace is ``tmp_path`` itself.

    Returns:
        A ``(app, service)`` tuple; the caller is responsible for closing the service.
    """
    # Local import: this module does not import json at file level.
    import json

    monkeypatch.setenv("BEAVER_BRIDGE_TOKEN", "bridge-token")
    config_path = tmp_path / "config.json"
    # Serialize with json.dumps so paths containing backslashes or quotes
    # still produce valid JSON (plain string interpolation would not escape them).
    config_path.write_text(
        json.dumps({"agents": {"defaults": {"workspace": str(tmp_path)}}, "providers": {}}),
        encoding="utf-8",
    )
    service = AgentService(config_path=config_path)
    app = create_app(service=service, manage_service_lifecycle=False)
    return app, service
|
||||
|
||||
|
||||
def _connected_connection(tmp_path):
    """Create a weixin sidecar connection under ``tmp_path``, mark it connected, and return it."""
    store_path = tmp_path / "state" / "channel_connections" / "connections.json"
    connections = ChannelConnectionStore(store_path)
    created = connections.create(
        kind="weixin",
        mode="sidecar",
        display_name="Weixin Main",
        account_id="weixin:me",
        owner_user_id=None,
        auth_type="connector_session",
    )
    connections.update_status(created.connection_id, status="connected", last_error=None)
    # The object returned is the one from create(), not a re-read after the status update.
    return created
|
||||
|
||||
|
||||
def _connection_with_status(tmp_path, status: str):
    """Create a feishu sidecar connection under ``tmp_path``, set it to ``status``, and return it."""
    store_path = tmp_path / "state" / "channel_connections" / "connections.json"
    connections = ChannelConnectionStore(store_path)
    created = connections.create(
        kind="feishu",
        mode="sidecar",
        display_name="Feishu Main",
        account_id="feishu:app-1",
        owner_user_id=None,
        auth_type="connector_session",
    )
    connections.update_status(created.connection_id, status=status, last_error=None)
    # The object returned is the one from create(), not a re-read after the status update.
    return created
|
||||
|
||||
|
||||
def _payload(connection, *, event_id: str = "evt-1", delivery_attempt: int = 1) -> dict:
|
||||
return {
|
||||
"eventId": event_id,
|
||||
"timestamp": "2026-06-02T09:30:00Z",
|
||||
"deliveryAttempt": delivery_attempt,
|
||||
"connectionId": connection.connection_id,
|
||||
"channelId": connection.channel_id,
|
||||
"kind": "weixin",
|
||||
"accountId": "weixin:me",
|
||||
"peerId": "peer-1",
|
||||
"peerType": "dm",
|
||||
"userId": "sender-1",
|
||||
"threadId": None,
|
||||
"messageId": "msg-1",
|
||||
"messageType": "text",
|
||||
"content": "hello",
|
||||
"metadata": {},
|
||||
}
|
||||
|
||||
|
||||
def test_bridge_endpoint_accepts_valid_event(tmp_path, monkeypatch) -> None:
    """A correctly authorized event for a connected connection is accepted with HTTP 200."""
    app, service = _app(tmp_path, monkeypatch)
    auth = {"Authorization": "Bearer bridge-token"}
    try:
        with TestClient(app) as client:
            connection = _connected_connection(tmp_path)
            body = _payload(connection)
            reply = client.post("/api/channel-connector-bridge/events", headers=auth, json=body)
            assert reply.status_code == 200
            assert reply.json()["accepted"] is True
    finally:
        service.close()
|
||||
|
||||
|
||||
def test_bridge_endpoint_rejects_invalid_token(tmp_path, monkeypatch) -> None:
    """A request carrying the wrong bearer token is answered with HTTP 401."""
    app, service = _app(tmp_path, monkeypatch)
    bad_auth = {"Authorization": "Bearer wrong"}
    try:
        with TestClient(app) as client:
            connection = _connected_connection(tmp_path)
            body = _payload(connection)
            reply = client.post("/api/channel-connector-bridge/events", headers=bad_auth, json=body)
            assert reply.status_code == 401
    finally:
        service.close()
|
||||
|
||||
|
||||
def test_bridge_endpoint_rejects_connection_identity_mismatch(tmp_path, monkeypatch) -> None:
    """An event whose channel, kind and account differ from the stored connection gets HTTP 403."""
    app, service = _app(tmp_path, monkeypatch)
    endpoint = "/api/channel-connector-bridge/events"
    auth = {"Authorization": "Bearer bridge-token"}
    try:
        with TestClient(app) as http:
            connection = _connected_connection(tmp_path)
            # Keep the real connectionId but forge the identity fields around it.
            forged = {
                **_payload(connection),
                "channelId": "forged-channel",
                "kind": "feishu",
                "accountId": "feishu:attacker",
            }
            reply = http.post(endpoint, headers=auth, json=forged)
            assert reply.status_code == 403
            assert "does not match connection" in reply.json()["detail"]
    finally:
        service.close()
|
||||
|
||||
|
||||
def test_bridge_endpoint_rejects_inactive_connection(tmp_path, monkeypatch) -> None:
    """An event for a connection still in ``pairing`` status gets HTTP 409."""
    app, service = _app(tmp_path, monkeypatch)
    endpoint = "/api/channel-connector-bridge/events"
    auth = {"Authorization": "Bearer bridge-token"}
    try:
        with TestClient(app) as http:
            connection = _connection_with_status(tmp_path, "pairing")
            event = _payload(connection)
            # Override the default Weixin fixture fields with Feishu ones.
            event.update({"kind": "feishu", "accountId": "feishu:app-1"})
            reply = http.post(endpoint, headers=auth, json=event)
            assert reply.status_code == 409
            assert "not connected" in reply.json()["detail"]
    finally:
        service.close()
|
||||
|
||||
|
||||
def test_bridge_endpoint_rejects_empty_or_oversized_content(tmp_path, monkeypatch) -> None:
    """Whitespace-only content gets HTTP 400 and 20001-character content gets HTTP 413."""
    app, service = _app(tmp_path, monkeypatch)
    endpoint = "/api/channel-connector-bridge/events"
    auth = {"Authorization": "Bearer bridge-token"}
    try:
        with TestClient(app) as http:
            connection = _connection_with_status(tmp_path, "connected")

            def post_content(event_id: str, content: str):
                """Post a Feishu-flavoured event carrying ``content``."""
                event = {
                    **_payload(connection, event_id=event_id),
                    "kind": "feishu",
                    "accountId": "feishu:app-1",
                    "content": content,
                }
                return http.post(endpoint, headers=auth, json=event)

            # Send both requests before asserting on either.
            blank = post_content("blank", " ")
            too_long = post_content("too-long", "x" * 20001)
            assert blank.status_code == 400
            assert too_long.status_code == 413
    finally:
        service.close()
|
||||
|
||||
|
||||
def test_bridge_endpoint_dedupes_repeated_event(tmp_path, monkeypatch) -> None:
    """Re-delivering the same ``eventId`` is reported back as a duplicate."""
    app, service = _app(tmp_path, monkeypatch)
    endpoint = "/api/channel-connector-bridge/events"
    auth = {"Authorization": "Bearer bridge-token"}
    try:
        with TestClient(app) as http:
            connection = _connected_connection(tmp_path)
            original = http.post(endpoint, headers=auth, json=_payload(connection))
            # Same default event id, only the delivery attempt counter differs.
            retry = http.post(endpoint, headers=auth, json=_payload(connection, delivery_attempt=2))
            assert original.status_code == 200
            assert retry.status_code in {200, 409}
            assert retry.json()["duplicate"] is True
    finally:
        service.close()
|
||||
@ -0,0 +1,114 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
from beaver.foundation.events import ChannelIdentity, OutboundMessage
|
||||
from beaver.interfaces.channels.external_connector import ExternalConnectorChannel, _request_id
|
||||
|
||||
|
||||
class FakeSidecarClient:
    """Sidecar test double that records outbound payloads instead of sending them."""

    def __init__(self) -> None:
        # Payloads passed to ``send``, in call order.
        self.sent: list[dict] = []

    async def send(self, payload: dict) -> dict:
        """Record ``payload`` and return a canned success receipt."""
        receipt = {"ok": True, "providerMessageId": "provider-1"}
        self.sent.append(payload)
        return receipt
|
||||
|
||||
|
||||
def test_external_connector_channel_sends_with_target_and_request_id() -> None:
    """Sending an outbound message forwards exactly one sidecar request.

    The expected request carries a request id built from channel, session
    and message id, the peer target taken from the channel identity, and
    metadata that includes the inbound ``contextToken``.
    """

    async def scenario() -> None:
        sidecar = FakeSidecarClient()
        adapter = ExternalConnectorChannel(
            channel_id="weixin-main",
            platform_kind="weixin",
            connection_id="conn_1",
            account_id="weixin:me",
            display_name="Weixin Main",
            sidecar_client=sidecar,
        )
        identity = ChannelIdentity(
            channel_id="weixin-main",
            kind="weixin",
            account_id="weixin:me",
            peer_id="peer-1",
            peer_type="dm",
            thread_id=None,
            user_id="sender-1",
            message_id="in-msg-1",
        )
        outbound = OutboundMessage(
            channel="weixin-main",
            content="reply",
            session_id="s1",
            finish_reason="stop",
            message_id="out-msg-1",
            channel_identity=identity,
            metadata={"inbound_metadata": {"contextToken": "ctx-1"}},
        )

        await adapter.send(outbound)

        expected_request = {
            "requestId": "out_weixin-main:s1:out-msg-1",
            "connectionId": "conn_1",
            "channelId": "weixin-main",
            "kind": "weixin",
            "target": {"peerId": "peer-1", "peerType": "dm", "threadId": None},
            "content": "reply",
            "metadata": {"inboundMessageId": "in-msg-1", "sessionId": "s1", "contextToken": "ctx-1"},
        }
        assert sidecar.sent == [expected_request]

    asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_external_connector_request_id_falls_back_when_message_id_is_none_or_blank() -> None:
    """``None`` and ``""`` message ids yield the same prefixed request id."""
    identity = ChannelIdentity(
        channel_id="weixin-main",
        kind="weixin",
        account_id="weixin:me",
        peer_id="peer-1",
        peer_type="dm",
        message_id="in-msg-1",
    )

    def reply_with(message_id) -> OutboundMessage:
        """Build the shared reply, varying only ``message_id``."""
        return OutboundMessage(
            channel="weixin-main",
            content="same reply",
            session_id="s1",
            finish_reason="stop",
            message_id=message_id,
            channel_identity=identity,
        )

    # None is deliberately passed where a str is declared.
    first = reply_with(None)  # type: ignore[arg-type]
    second = reply_with("")

    assert _request_id(first) == _request_id(second)
    assert _request_id(first).startswith("out_weixin-main:s1:")
|
||||
|
||||
|
||||
def test_external_connector_channel_requires_identity() -> None:
    """Sending a message that has no ``channel_identity`` raises ``ValueError``."""

    async def scenario() -> None:
        adapter = ExternalConnectorChannel(
            channel_id="weixin-main",
            platform_kind="weixin",
            connection_id="conn_1",
            account_id="weixin:me",
            display_name="Weixin Main",
            sidecar_client=FakeSidecarClient(),
        )
        outbound = OutboundMessage(channel="weixin-main", content="reply", session_id="s1", finish_reason="stop")

        raised = None
        try:
            await adapter.send(outbound)
        except ValueError as exc:
            raised = exc

        if raised is None:
            raise AssertionError("Expected ValueError")
        assert "channel_identity is required" in str(raised)

    asyncio.run(scenario())
|
||||
@ -0,0 +1,320 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from beaver.interfaces.channels.connections import (
|
||||
ChannelConnectionStore,
|
||||
ChannelConnectorRegistry,
|
||||
CredentialStore,
|
||||
FeishuConnector,
|
||||
WeixinConnector,
|
||||
)
|
||||
from beaver.interfaces.web.app import create_app
|
||||
from beaver.services.agent_service import AgentService
|
||||
|
||||
|
||||
class FakeSidecarClient:
    """In-memory sidecar double for connector session tests.

    ``start_session`` always creates session ``cs_1`` in ``qr_ready`` state.
    Tests overwrite ``sessions[...]`` to simulate later status changes.
    """

    def __init__(self) -> None:
        # Session dicts keyed by session id; served by ``get_session``.
        self.sessions: dict[str, dict] = {}
        # Payloads passed to ``start_session``, in call order.
        self.started: list[dict] = []
        # Connection ids passed to ``logout``, in call order.
        self.logged_out: list[str] = []

    async def start_session(self, payload: dict) -> dict:
        """Record ``payload`` and register a QR-pending session for its kind."""
        self.started.append(payload)
        session_id = "cs_1"
        self.sessions[session_id] = {
            "sessionId": session_id,
            "kind": payload["kind"],
            "status": "qr_ready",
            "qrImage": "data:image/png;base64,abc",
            "accountId": None,
            "displayName": None,
            "metadata": {},
        }
        # Return the very object that was stored, not a copy.
        return self.sessions[session_id]

    async def get_session(self, session_id: str) -> dict:
        """Return the stored session; unknown ids raise ``KeyError``."""
        return self.sessions[session_id]

    async def logout(self, connection_id: str) -> dict:
        """Record the logout request and report success."""
        self.logged_out.append(connection_id)
        return {"ok": True}
|
||||
|
||||
|
||||
class ImmediateConnectedSidecarClient(FakeSidecarClient):
    """Sidecar double whose sessions are ``connected`` as soon as they start.

    The returned metadata carries ``appSecret`` and ``tenantAccessToken``
    next to ``stateRef``, so tests can check how those values are handled.
    """

    async def start_session(self, payload: dict) -> dict:
        """Record ``payload`` and register an already-connected session."""
        self.started.append(payload)
        kind = payload["kind"]
        session_id = "cs_connected"
        self.sessions[session_id] = {
            "sessionId": session_id,
            "kind": kind,
            "status": "connected",
            "qrImage": None,
            "accountId": f"{kind}:me",
            "displayName": "Connected Account",
            "metadata": {"stateRef": "state-1", "appSecret": "secret-1", "tenantAccessToken": "token-1"},
        }
        return self.sessions[session_id]
|
||||
|
||||
|
||||
def test_weixin_connector_starts_connector_session(tmp_path, monkeypatch) -> None:
    """Starting a Weixin session calls the sidecar and stores a ``pairing`` connection."""

    async def scenario() -> None:
        callback_base = "http://app-instance-jaychen:8080"
        monkeypatch.setenv("EXTERNAL_CONNECTOR_CALLBACK_BASE_URL", callback_base)
        connections = ChannelConnectionStore(tmp_path / "connections.json")
        credentials = CredentialStore(tmp_path / "credentials.json")
        sidecar = FakeSidecarClient()
        connector = WeixinConnector(
            connection_store=connections,
            credential_store=credentials,
            sidecar_client=sidecar,
            sidecar_base_url="http://external-connector:8787",
        )

        view = await connector.start_session(display_name="Weixin Main", owner_user_id="user-1", options={})

        assert view["sessionId"] == "cs_1"
        assert view["connectionId"].startswith("conn_")
        # The sidecar request carries the kind, a generated connection id
        # and the callback base URL taken from the environment.
        request = sidecar.started[0]
        assert request["kind"] == "weixin"
        assert request["connectionId"].startswith("conn_")
        assert request["callbackBaseUrl"] == callback_base
        stored = connections.list()[0]
        assert stored.kind == "weixin"
        assert stored.status == "pairing"

    asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_feishu_connector_start_session_connected_updates_connection(tmp_path) -> None:
    """A session that starts already connected updates the stored connection at once."""

    async def scenario() -> None:
        connections = ChannelConnectionStore(tmp_path / "connections.json")
        credentials = CredentialStore(tmp_path / "credentials.json")
        connector = FeishuConnector(
            connection_store=connections,
            credential_store=credentials,
            sidecar_client=ImmediateConnectedSidecarClient(),
            sidecar_base_url="http://external-connector:8787",
        )

        view = await connector.start_session(display_name="Feishu Main", owner_user_id=None, options={})
        stored = connections.get(view["connectionId"])

        assert view["status"] == "connected"
        assert stored.status == "connected"
        # Account id and display name come from the sidecar session,
        # replacing the "Feishu Main" name passed to start_session.
        assert stored.account_id == "feishu:me"
        assert stored.display_name == "Connected Account"
        assert stored.credentials_ref is not None

    asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_weixin_connector_poll_connected_materializes_external_runtime(tmp_path) -> None:
    """Polling a now-connected session updates the connection and yields a runtime spec."""

    async def scenario() -> None:
        sidecar_url = "http://external-connector:8787"
        connections = ChannelConnectionStore(tmp_path / "connections.json")
        credentials = CredentialStore(tmp_path / "credentials.json")
        sidecar = FakeSidecarClient()
        connector = WeixinConnector(
            connection_store=connections,
            credential_store=credentials,
            sidecar_client=sidecar,
            sidecar_base_url=sidecar_url,
        )
        await connector.start_session(display_name="Weixin Main", owner_user_id=None, options={})
        pending = connections.list()[0]
        # Simulate the user completing pairing on the sidecar side.
        sidecar.sessions["cs_1"] = {
            "sessionId": "cs_1",
            "kind": "weixin",
            "status": "connected",
            "accountId": "weixin:me",
            "displayName": "Me",
            "metadata": {"stateRef": "state-1"},
        }

        polled = await connector.poll_session("cs_1")
        refreshed = connections.get(pending.connection_id)
        spec = await connector.materialize_runtime(pending.connection_id)

        assert polled["status"] == "connected"
        assert refreshed.status == "connected"
        assert refreshed.account_id == "weixin:me"
        assert spec.kind == "external_connector"
        assert spec.mode == "http"
        assert spec.config["platformKind"] == "weixin"
        assert spec.config["sidecarBaseUrl"] == sidecar_url

    asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_feishu_connector_uses_feishu_kind(tmp_path) -> None:
    """The Feishu connector sends ``kind="feishu"`` and passes options through unchanged."""

    async def scenario() -> None:
        sidecar = FakeSidecarClient()
        connector = FeishuConnector(
            connection_store=ChannelConnectionStore(tmp_path / "connections.json"),
            credential_store=CredentialStore(tmp_path / "credentials.json"),
            sidecar_client=sidecar,
            sidecar_base_url="http://external-connector:8787",
        )

        await connector.start_session(display_name="Feishu Main", owner_user_id=None, options={"domain": "feishu"})

        request = sidecar.started[0]
        assert request["kind"] == "feishu"
        assert request["options"] == {"domain": "feishu"}

    asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_feishu_connector_persists_policy_options_in_runtime_config(tmp_path) -> None:
    """Policy options given at session start are stored in the connection's runtime config."""

    async def scenario() -> None:
        connections = ChannelConnectionStore(tmp_path / "connections.json")
        credentials = CredentialStore(tmp_path / "credentials.json")
        sidecar = FakeSidecarClient()
        connector = FeishuConnector(
            connection_store=connections,
            credential_store=credentials,
            sidecar_client=sidecar,
            sidecar_base_url="http://external-connector:8787",
        )
        policy_options = {
            "domain": "feishu",
            "requireMentionInGroups": True,
            "allowFrom": ["ou_1"],
            "groupAllowFrom": ["oc_1"],
            "maxMessageChars": 1234,
        }

        await connector.start_session(
            display_name="Feishu Main",
            owner_user_id=None,
            options=policy_options,
        )
        stored = connections.list()[0]

        assert sidecar.started[0]["options"]["requireMentionInGroups"] is True
        runtime_config = stored.runtime_config
        assert runtime_config["requireMentionInGroups"] is True
        assert runtime_config["allowFrom"] == ["ou_1"]
        assert runtime_config["groupAllowFrom"] == ["oc_1"]
        assert runtime_config["maxMessageChars"] == 1234

    asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_feishu_connector_materializes_policy_for_external_runtime(tmp_path) -> None:
    """Policy options given at session start appear in the materialized runtime spec config."""

    async def scenario() -> None:
        connector = FeishuConnector(
            connection_store=ChannelConnectionStore(tmp_path / "connections.json"),
            credential_store=CredentialStore(tmp_path / "credentials.json"),
            sidecar_client=ImmediateConnectedSidecarClient(),
            sidecar_base_url="http://external-connector:8787",
        )
        policy_options = {"requireMentionInGroups": True, "allowFrom": ["ou_1"], "groupAllowFrom": ["oc_1"]}

        view = await connector.start_session(
            display_name="Feishu Main",
            owner_user_id=None,
            options=policy_options,
        )
        spec = await connector.materialize_runtime(view["connectionId"])

        config = spec.config
        assert config["platformKind"] == "feishu"
        assert config["requireMentionInGroups"] is True
        assert config["allowFrom"] == ["ou_1"]
        assert config["groupAllowFrom"] == ["oc_1"]

    asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_connector_session_api_starts_and_polls_connected_session(tmp_path, monkeypatch) -> None:
    """Start a Weixin session over HTTP, then poll it once the sidecar reports connected.

    The poll response and the connection store must both show ``connected``,
    and the connection's channel must be registered on the app's channel
    runtime.
    """
    import json  # local import: only needed to serialize the config file

    monkeypatch.setenv("EXTERNAL_CONNECTOR_TOKEN", "connector-token")
    config_path = tmp_path / "config.json"
    # Serialize with json.dumps rather than %-formatting so paths that
    # contain backslashes or quotes still produce valid JSON.
    config_path.write_text(
        json.dumps({"agents": {"defaults": {"workspace": str(tmp_path)}}, "providers": {}}),
        encoding="utf-8",
    )
    service = AgentService(config_path=config_path)
    app = create_app(service=service, manage_service_lifecycle=False)
    client = FakeSidecarClient()

    try:
        with TestClient(app) as http:
            state_dir = tmp_path / "state" / "channel_connections"
            connection_store = ChannelConnectionStore(state_dir / "connections.json")
            credential_store = CredentialStore(state_dir / "credentials.json")
            registry = ChannelConnectorRegistry(connection_store=connection_store, credential_store=credential_store)
            registry.register(
                WeixinConnector(
                    connection_store=connection_store,
                    credential_store=credential_store,
                    sidecar_client=client,
                    sidecar_base_url="http://external-connector:8787",
                )
            )
            # Replace the app's registry with one wired to the fake sidecar.
            app.state.channel_connector_registry = registry

            started = http.post(
                "/api/channel-connector-sessions",
                json={"kind": "weixin", "displayName": "Weixin Main", "options": {}},
            )
            # Check the status before indexing the body, so a failed start
            # is reported as a status mismatch rather than a KeyError.
            assert started.status_code == 200
            session_id = started.json()["session"]["sessionId"]
            connection_id = started.json()["connection"]["connection_id"]

            # Simulate the user completing pairing on the sidecar side.
            client.sessions[session_id] = {
                "sessionId": session_id,
                "kind": "weixin",
                "status": "connected",
                "accountId": "weixin:me",
                "displayName": "Me",
                "metadata": {},
            }
            polled = http.get(f"/api/channel-connector-sessions/{session_id}")

            assert polled.status_code == 200
            assert polled.json()["connection"]["status"] == "connected"
            assert connection_store.get(connection_id).status == "connected"
            assert polled.json()["connection"]["channel_id"] in app.state.channel_runtime.adapters
    finally:
        service.close()
|
||||
|
||||
|
||||
def test_connector_session_api_activates_immediate_connected_session(tmp_path, monkeypatch) -> None:
    """A Feishu session that is connected at start is activated by the start call itself.

    The start response must show a ``connected`` connection whose channel is
    registered on the app's channel runtime, and session metadata equal to
    ``{"stateRef": "state-1"}`` only.
    """
    import json  # local import: only needed to serialize the config file

    monkeypatch.setenv("EXTERNAL_CONNECTOR_TOKEN", "connector-token")
    config_path = tmp_path / "config.json"
    # Serialize with json.dumps rather than %-formatting so paths that
    # contain backslashes or quotes still produce valid JSON.
    config_path.write_text(
        json.dumps({"agents": {"defaults": {"workspace": str(tmp_path)}}, "providers": {}}),
        encoding="utf-8",
    )
    service = AgentService(config_path=config_path)
    app = create_app(service=service, manage_service_lifecycle=False)
    client = ImmediateConnectedSidecarClient()

    try:
        with TestClient(app) as http:
            state_dir = tmp_path / "state" / "channel_connections"
            connection_store = ChannelConnectionStore(state_dir / "connections.json")
            credential_store = CredentialStore(state_dir / "credentials.json")
            registry = ChannelConnectorRegistry(connection_store=connection_store, credential_store=credential_store)
            registry.register(
                FeishuConnector(
                    connection_store=connection_store,
                    credential_store=credential_store,
                    sidecar_client=client,
                    sidecar_base_url="http://external-connector:8787",
                )
            )
            # Replace the app's registry with one wired to the fake sidecar.
            app.state.channel_connector_registry = registry

            started = http.post(
                "/api/channel-connector-sessions",
                json={"kind": "feishu", "displayName": "Feishu Main", "options": {}},
            )

            assert started.status_code == 200
            connection = started.json()["connection"]
            assert connection["status"] == "connected"
            assert connection["channel_id"] in app.state.channel_runtime.adapters
            # The fake sidecar also returns appSecret and tenantAccessToken;
            # the API response is expected to expose only stateRef.
            assert started.json()["session"]["metadata"] == {"stateRef": "state-1"}
    finally:
        service.close()
|
||||
154
app-instance/backend/tests/unit/test_feishu_channel_adapter.py
Normal file
154
app-instance/backend/tests/unit/test_feishu_channel_adapter.py
Normal file
@ -0,0 +1,154 @@
|
||||
import asyncio
|
||||
|
||||
from beaver.foundation.events import OutboundMessage
|
||||
from beaver.interfaces.channels.platforms.feishu import FeishuAdapter
|
||||
|
||||
|
||||
class FakeSink:
    """Inbound sink double that collects every accepted message."""

    def __init__(self) -> None:
        # Messages passed to ``accept_inbound``, in arrival order.
        self.messages = []

    async def accept_inbound(self, message):
        """Store ``message`` for later inspection by the test."""
        collected = self.messages
        collected.append(message)
|
||||
|
||||
|
||||
class FakeFeishuClient:
    """Feishu client double that records text sends instead of performing them."""

    def __init__(self) -> None:
        # One dict per ``send_text`` call, in call order.
        self.sent = []

    async def send_text(self, *, receive_id_type: str, receive_id: str, text: str):
        """Record the keyword arguments of one text send."""
        record = dict(receive_id_type=receive_id_type, receive_id=receive_id, text=text)
        self.sent.append(record)
|
||||
|
||||
|
||||
def test_feishu_normalizes_direct_text_event() -> None:
    """A p2p text event is delivered to the sink as a normalized direct message."""

    async def scenario() -> None:
        sink = FakeSink()
        adapter = FeishuAdapter(
            channel_id="feishu-main",
            kind="feishu",
            mode="websocket",
            account_id="tenant-main",
            display_name=None,
            inbound_sink=sink,
            secrets={"appId": "app", "appSecret": "secret"},
            config={},
            client=FakeFeishuClient(),
        )
        # The message content is itself a JSON-encoded string.
        incoming = {
            "message_id": "m1",
            "chat_id": "oc_chat",
            "chat_type": "p2p",
            "message_type": "text",
            "content": '{"text":"hello"}',
        }
        sender = {"sender_id": {"open_id": "ou_user"}}

        await adapter.handle_event_payload({"event": {"message": incoming, "sender": sender}})

        received = sink.messages[0]
        assert received.content == "hello"
        assert received.session_id == "feishu-main:tenant-main:oc_chat"
        assert received.channel_identity.peer_type == "dm"
        assert received.channel_identity.user_id == "ou_user"

    asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_feishu_group_mention_gate() -> None:
    """With ``requireMentionInGroups`` set, only one of two group messages reaches the sink.

    The first message has no mentions and the second mentions the
    configured bot open id.
    """

    def group_event(message_id: str, mentions: list) -> dict:
        """Build a group text event with the given mentions list."""
        return {
            "event": {
                "message": {
                    "message_id": message_id,
                    "chat_id": "oc_group",
                    "chat_type": "group",
                    "message_type": "text",
                    "content": '{"text":"hello"}',
                    "mentions": mentions,
                },
                "sender": {"sender_id": {"open_id": "ou_user"}},
            }
        }

    async def scenario() -> None:
        sink = FakeSink()
        adapter = FeishuAdapter(
            channel_id="feishu-main",
            kind="feishu",
            mode="websocket",
            account_id="tenant-main",
            display_name=None,
            inbound_sink=sink,
            secrets={"appId": "app", "appSecret": "secret"},
            config={"requireMentionInGroups": True, "botOpenId": "ou_bot"},
            client=FakeFeishuClient(),
        )

        await adapter.handle_event_payload(group_event("m1", []))
        await adapter.handle_event_payload(group_event("m2", [{"id": {"open_id": "ou_bot"}}]))

        assert len(sink.messages) == 1

    asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_feishu_sends_text_to_chat_id() -> None:
    """A reply to an inbound p2p message is sent as text to that message's chat id."""

    async def scenario() -> None:
        sink = FakeSink()
        feishu = FakeFeishuClient()
        adapter = FeishuAdapter(
            channel_id="feishu-main",
            kind="feishu",
            mode="websocket",
            account_id="tenant-main",
            display_name=None,
            inbound_sink=sink,
            secrets={"appId": "app", "appSecret": "secret"},
            config={},
            client=feishu,
        )
        incoming = {
            "message_id": "m1",
            "chat_id": "oc_chat",
            "chat_type": "p2p",
            "message_type": "text",
            "content": '{"text":"hello"}',
        }
        sender = {"sender_id": {"open_id": "ou_user"}}
        await adapter.handle_event_payload({"event": {"message": incoming, "sender": sender}})

        # Reply using the session id and identity of the received message.
        inbound = sink.messages[0]
        reply = OutboundMessage(
            channel="feishu-main",
            content="ok",
            session_id=inbound.session_id,
            finish_reason="stop",
            channel_identity=inbound.channel_identity,
        )
        await adapter.send(reply)

        assert feishu.sent == [{"receive_id_type": "chat_id", "receive_id": "oc_chat", "text": "ok"}]

    asyncio.run(scenario())
|
||||
@ -6,7 +6,7 @@ import os
|
||||
from pathlib import Path
|
||||
|
||||
from beaver.tools import ObjectBackedTool, ToolContext
|
||||
from beaver.tools.builtins import ListDirectoryTool, ReadFileTool, SearchFilesTool
|
||||
from beaver.tools.builtins import ListDirectoryTool, PatchFileTool, ReadFileTool, SearchFilesTool, WriteFileTool
|
||||
|
||||
|
||||
def _run_tool(tool, arguments: dict, workspace: Path):
|
||||
@ -127,3 +127,23 @@ def test_read_file_rejects_binary_files(tmp_path: Path) -> None:
|
||||
assert payload["success"] is False
|
||||
assert "binary" in payload["error"]
|
||||
|
||||
|
||||
def test_workspace_tools_reject_user_file_virtual_paths(tmp_path: Path) -> None:
    """Read, list, write and patch tools all fail on user-file style paths.

    Each result must be unsuccessful and its error must mention both
    "personal agent file system path" and "user_files_read".
    """
    workspace = tmp_path / "workspace"
    workspace.mkdir()

    # Tool classes are instantiated one at a time, just before each run.
    attempts = [
        (ReadFileTool, {"path": "uploads/get_helm.sh"}),
        (ListDirectoryTool, {"path": "outputs"}),
        (WriteFileTool, {"path": "shared/profile.json", "content": "{}"}),
        (PatchFileTool, {"path": "tasks/task-123/draft.md", "old_text": "a", "new_text": "b"}),
    ]
    # Run every tool before asserting on any of the results.
    results = [_run_tool(tool_cls(), arguments, workspace) for tool_cls, arguments in attempts]

    for result in results:
        payload = _payload(result)
        assert result.success is False
        assert payload["success"] is False
        error_text = payload["error"]
        assert "personal agent file system path" in error_text
        assert "user_files_read" in error_text
|
||||
|
||||
@ -2,9 +2,10 @@ import asyncio
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from beaver.foundation.events import InboundMessage, MessageBus
|
||||
from beaver.foundation.events import InboundMessage, MessageBus, OutboundMessage
|
||||
from beaver.interfaces.channels import ChannelManager, MemoryChannelAdapter
|
||||
from beaver.interfaces.gateway.main import run_gateway
|
||||
from beaver.interfaces.channels.runtime import ChannelRuntime
|
||||
from beaver.services.agent_service import AgentService
|
||||
|
||||
|
||||
@ -52,22 +53,15 @@ class InvalidService:
|
||||
is_running = True
|
||||
|
||||
|
||||
def test_gateway_routes_memory_channel_roundtrip() -> None:
|
||||
def test_gateway_routes_memory_channel_roundtrip(tmp_path) -> None:
|
||||
async def run() -> None:
|
||||
bus = MessageBus()
|
||||
channel = MemoryChannelAdapter(bus)
|
||||
stop_event = asyncio.Event()
|
||||
task = asyncio.create_task(
|
||||
run_gateway(
|
||||
service=FakeService(),
|
||||
manage_service_lifecycle=False,
|
||||
bus=bus,
|
||||
channels=[channel],
|
||||
stop_event=stop_event,
|
||||
)
|
||||
)
|
||||
runtime = ChannelRuntime(service=FakeService(), bus=bus, channels={}, workspace=tmp_path)
|
||||
channel = MemoryChannelAdapter(runtime)
|
||||
runtime.manager.register(channel)
|
||||
await runtime.start()
|
||||
|
||||
await channel.publish_text("hello", session_id="s1")
|
||||
await channel.publish_text("hello", peer_id="s1", message_id="m1")
|
||||
for _ in range(40):
|
||||
if channel.sent_messages:
|
||||
break
|
||||
@ -76,38 +70,73 @@ def test_gateway_routes_memory_channel_roundtrip() -> None:
|
||||
assert channel.sent_messages
|
||||
message = channel.sent_messages[0]
|
||||
assert message.content == "echo:hello"
|
||||
assert message.session_id == "s1"
|
||||
assert message.session_id == "memory-dev:memory:s1"
|
||||
assert message.finish_reason == "stop"
|
||||
assert message.metadata["task_id"] == "task-1"
|
||||
assert message.metadata["task_status"] == "awaiting_acceptance"
|
||||
assert message.metadata["evidence_status"] == "recorded"
|
||||
assert message.metadata["validation_result"] is None
|
||||
|
||||
stop_event.set()
|
||||
await asyncio.wait_for(task, timeout=2)
|
||||
await runtime.stop()
|
||||
|
||||
asyncio.run(run())
|
||||
|
||||
|
||||
def test_gateway_delivers_cancelled_outbound_to_channel() -> None:
|
||||
def test_channel_manager_dispatches_by_channel_id() -> None:
|
||||
class CaptureChannel:
|
||||
channel_id = "webhook-dev"
|
||||
kind = "webhook"
|
||||
mode = "webhook"
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.sent = []
|
||||
|
||||
async def start(self) -> None:
|
||||
pass
|
||||
|
||||
async def stop(self) -> None:
|
||||
pass
|
||||
|
||||
async def send(self, message: Any) -> None:
|
||||
self.sent.append(message)
|
||||
|
||||
async def run() -> None:
|
||||
bus = MessageBus()
|
||||
channel = MemoryChannelAdapter(bus)
|
||||
stop_event = asyncio.Event()
|
||||
task = asyncio.create_task(
|
||||
run_gateway(
|
||||
service=SlowService(),
|
||||
manage_service_lifecycle=False,
|
||||
bus=bus,
|
||||
channels=[channel],
|
||||
stop_event=stop_event,
|
||||
channel = CaptureChannel()
|
||||
manager = ChannelManager(bus)
|
||||
manager.register(channel)
|
||||
await bus.publish_outbound(
|
||||
OutboundMessage(
|
||||
channel="webhook-dev",
|
||||
content="ok",
|
||||
session_id="webhook-dev:local:demo",
|
||||
finish_reason="stop",
|
||||
)
|
||||
)
|
||||
|
||||
await channel.publish_text("slow", session_id="s1")
|
||||
await asyncio.sleep(0.05)
|
||||
stop_event = asyncio.Event()
|
||||
stop_event.set()
|
||||
await asyncio.wait_for(task, timeout=3)
|
||||
|
||||
await manager.dispatch_outbound(stop_event)
|
||||
|
||||
assert channel.sent[0].content == "ok"
|
||||
|
||||
asyncio.run(run())
|
||||
|
||||
|
||||
def test_gateway_delivers_cancelled_outbound_to_channel(tmp_path) -> None:
|
||||
async def run() -> None:
|
||||
bus = MessageBus()
|
||||
runtime = ChannelRuntime(service=SlowService(), bus=bus, channels={}, workspace=tmp_path)
|
||||
channel = MemoryChannelAdapter(runtime)
|
||||
runtime.manager.register(channel)
|
||||
await runtime.start()
|
||||
|
||||
await channel.publish_text("slow", peer_id="s1", message_id="m1")
|
||||
for _ in range(40):
|
||||
if any(event["kind"] == "direct_run_started" for event in runtime.events.recent(limit=20)):
|
||||
break
|
||||
await asyncio.sleep(0.05)
|
||||
await runtime.stop()
|
||||
|
||||
assert channel.sent_messages
|
||||
assert channel.sent_messages[0].finish_reason == "cancelled"
|
||||
@ -118,13 +147,27 @@ def test_gateway_delivers_cancelled_outbound_to_channel() -> None:
|
||||
def test_gateway_rejects_channel_manager_and_channels_together() -> None:
|
||||
async def run() -> None:
|
||||
bus = MessageBus()
|
||||
class CaptureChannel:
|
||||
channel_id = "memory-dev"
|
||||
kind = "memory"
|
||||
mode = "webhook"
|
||||
|
||||
async def start(self) -> None:
|
||||
pass
|
||||
|
||||
async def stop(self) -> None:
|
||||
pass
|
||||
|
||||
async def send(self, message: Any) -> None:
|
||||
pass
|
||||
|
||||
try:
|
||||
await run_gateway(
|
||||
service=FakeService(),
|
||||
manage_service_lifecycle=False,
|
||||
bus=bus,
|
||||
channel_manager=ChannelManager(bus),
|
||||
channels=[MemoryChannelAdapter(bus)],
|
||||
channels=[CaptureChannel()],
|
||||
stop_event=asyncio.Event(),
|
||||
)
|
||||
except ValueError as exc:
|
||||
@ -212,10 +255,16 @@ def test_channel_manager_keeps_unknown_channel_outbound_undeliverable() -> None:
|
||||
asyncio.run(run())
|
||||
|
||||
|
||||
def test_memory_channel_adapts_old_style_payload_to_stable_session_id() -> None:
|
||||
def test_memory_channel_adapts_payload_to_channel_identity_session_id(tmp_path) -> None:
|
||||
async def run() -> None:
|
||||
bus = MessageBus()
|
||||
channel = MemoryChannelAdapter(bus, name="telegram")
|
||||
runtime = ChannelRuntime(service=FakeService(), bus=bus, channels={}, workspace=tmp_path)
|
||||
channel = MemoryChannelAdapter(
|
||||
runtime,
|
||||
channel_id="telegram-main",
|
||||
kind="telegram",
|
||||
account_id="bot-main",
|
||||
)
|
||||
inbound = await channel.publish_external_text(
|
||||
"hello",
|
||||
chat_id="chat-1",
|
||||
@ -225,8 +274,10 @@ def test_memory_channel_adapts_old_style_payload_to_stable_session_id() -> None:
|
||||
|
||||
queued = await bus.consume_inbound()
|
||||
assert queued is inbound
|
||||
assert queued.channel == "telegram"
|
||||
assert queued.session_id == "telegram:chat-1"
|
||||
assert queued.channel == "telegram-main"
|
||||
assert queued.session_id == "telegram-main:bot-main:chat-1"
|
||||
assert queued.channel_identity is not None
|
||||
assert queued.channel_identity.kind == "telegram"
|
||||
assert queued.metadata["chat_id"] == "chat-1"
|
||||
assert queued.metadata["message_id"] == "message-1"
|
||||
assert queued.metadata["raw_channel_payload"] == {"platform": "telegram", "text": "hello"}
|
||||
@ -236,7 +287,9 @@ def test_memory_channel_adapts_old_style_payload_to_stable_session_id() -> None:
|
||||
|
||||
def test_channel_manager_start_cancellation_rolls_back_started_channels() -> None:
|
||||
class StartedChannel:
|
||||
name = "started"
|
||||
channel_id = "started"
|
||||
kind = "memory"
|
||||
mode = "webhook"
|
||||
|
||||
def __init__(self, bus: MessageBus) -> None:
|
||||
self.bus = bus
|
||||
@ -252,7 +305,9 @@ def test_channel_manager_start_cancellation_rolls_back_started_channels() -> Non
|
||||
pass
|
||||
|
||||
class BlockingChannel:
|
||||
name = "blocking"
|
||||
channel_id = "blocking"
|
||||
kind = "memory"
|
||||
mode = "webhook"
|
||||
|
||||
def __init__(self, bus: MessageBus) -> None:
|
||||
self.bus = bus
|
||||
|
||||
@ -6,6 +6,34 @@ from beaver.interfaces.web.app import create_app
|
||||
from beaver.interfaces.web.schemas import WebChatRequest, WebChatResponse
|
||||
|
||||
|
||||
def test_platform_channel_modules_import_without_live_clients() -> None:
    """Import every platform adapter module and check its ``KIND`` constant.

    The imports are function-local, so the four platform modules are only
    imported when this test runs.
    """
    from beaver.interfaces.channels.platforms.feishu import FeishuAdapter
    from beaver.interfaces.channels.platforms.qqbot import QQBotAdapter
    from beaver.interfaces.channels.platforms.telegram import TelegramAdapter
    from beaver.interfaces.channels.platforms.weixin import WeixinAdapter

    assert FeishuAdapter.KIND == "feishu"
    assert QQBotAdapter.KIND == "qqbot"
    assert TelegramAdapter.KIND == "telegram"
    assert WeixinAdapter.KIND == "weixin"
|
||||
|
||||
|
||||
def test_platform_channel_optional_extras_are_declared() -> None:
    """Check that ``pyproject.toml`` declares the per-platform optional extras.

    The file is looked up two directories above this test module and read
    from the ``[project.optional-dependencies]`` table.
    """
    import tomllib
    from pathlib import Path

    pyproject = Path(__file__).resolve().parents[2] / "pyproject.toml"
    data = tomllib.loads(pyproject.read_text(encoding="utf-8"))
    extras = data["project"]["optional-dependencies"]

    # Each platform extra pins its own client library.
    assert "python-telegram-bot>=22.0,<23.0" in extras["telegram"]
    assert "lark-oapi>=1.4.22,<2.0.0" in extras["feishu"]
    assert "aiohttp>=3.9.0,<4.0.0" in extras["qqbot"]
    assert "aiohttp>=3.9.0,<4.0.0" in extras["weixin"]
    # The "channels" extra is expected to carry the same pins.
    assert "python-telegram-bot>=22.0,<23.0" in extras["channels"]
    assert "lark-oapi>=1.4.22,<2.0.0" in extras["channels"]
|
||||
|
||||
|
||||
def test_agent_loop_boots(tmp_path) -> None:
|
||||
loop = AgentLoop(loader=EngineLoader(workspace=tmp_path))
|
||||
loaded = loop.boot()
|
||||
@ -32,10 +60,14 @@ def test_message_bus_imports() -> None:
|
||||
|
||||
def test_channel_imports() -> None:
|
||||
bus = MessageBus()
|
||||
channel = MemoryChannelAdapter(bus)
|
||||
class Sink:
|
||||
async def accept_inbound(self, message):
|
||||
await bus.publish_inbound(message)
|
||||
|
||||
channel = MemoryChannelAdapter(Sink())
|
||||
manager = ChannelManager(bus)
|
||||
manager.register(channel)
|
||||
assert manager.channels["memory"] is channel
|
||||
assert manager.channels["memory-dev"] is channel
|
||||
|
||||
|
||||
def test_web_schema_imports() -> None:
|
||||
|
||||
@ -30,10 +30,14 @@ EXPECTED_INITIAL_SKILL_TOOLS = {
|
||||
"mcp_outlook_mcp_calendar_find_meeting_times",
|
||||
"mcp_outlook_mcp_calendar_delta_sync",
|
||||
],
|
||||
"skills-admin": ["skills_list", "skill_manage", "skill_view"],
|
||||
"skills-admin": ["skills_list", "skill_view"],
|
||||
"terminal-operation": ["terminal", "process", "execute_code"],
|
||||
"utility-tools": ["clarify", "delegate", "send_message", "spawn", "todo"],
|
||||
"web-operation": ["web_fetch", "web_search"],
|
||||
"multi-search-engine": ["web_fetch"],
|
||||
}
|
||||
|
||||
# Expected tool names per skill for skills checked outside the initial set
# (skill name -> list of tool names).
EXPECTED_NON_INITIAL_SKILL_TOOLS = {
    "skills-authoring-admin": ["skill_manage"],
}
|
||||
|
||||
|
||||
@ -48,6 +52,23 @@ def test_initial_skill_tool_hints_match_runtime_tool_names() -> None:
|
||||
assert version["tool_hints"] == expected_tools
|
||||
|
||||
|
||||
def test_skill_authoring_admin_is_seeded_but_not_initial() -> None:
    """Check that ``skills-authoring-admin`` is listed as disabled, not published.

    Also checks, for every entry of ``EXPECTED_NON_INITIAL_SKILL_TOOLS``, that
    the ``v0001`` SKILL.md frontmatter and ``version.json`` agree on the
    expected tool list.
    """
    index_dir = REPO_ROOT / "skills" / "_index"
    published = json.loads((index_dir / "published.json").read_text(encoding="utf-8"))
    disabled = json.loads((index_dir / "disabled.json").read_text(encoding="utf-8"))

    assert "skills-authoring-admin" not in published["items"]
    assert "skills-authoring-admin" in disabled["items"]

    for skill_name, expected_tools in EXPECTED_NON_INITIAL_SKILL_TOOLS.items():
        skill_dir = REPO_ROOT / "skills" / skill_name / "versions" / "v0001"
        frontmatter, _body = parse_frontmatter((skill_dir / "SKILL.md").read_text(encoding="utf-8"))
        version = json.loads((skill_dir / "version.json").read_text(encoding="utf-8"))

        # The tool list is stored in three places; all must match.
        assert frontmatter["tools"] == expected_tools
        assert version["frontmatter"]["tools"] == expected_tools
        assert version["tool_hints"] == expected_tools
|
||||
|
||||
|
||||
def test_default_runtime_registers_skill_view_tool(tmp_path: Path) -> None:
|
||||
loaded = EngineLoader(workspace=tmp_path).load()
|
||||
try:
|
||||
|
||||
110
app-instance/backend/tests/unit/test_keycloak_auth.py
Normal file
110
app-instance/backend/tests/unit/test_keycloak_auth.py
Normal file
@ -0,0 +1,110 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
|
||||
import jwt
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
from beaver.interfaces.web.keycloak_auth import (
|
||||
KeycloakAuthConfig,
|
||||
KeycloakIdentity,
|
||||
KeycloakTokenVerifier,
|
||||
extract_bearer_token,
|
||||
)
|
||||
|
||||
|
||||
def _verifier() -> KeycloakTokenVerifier:
    """Return a ``KeycloakTokenVerifier`` configured with fixed test endpoints.

    The issuer and client id match the ``iss`` / ``aud`` / ``azp`` defaults
    produced by ``_claims``.
    """
    return KeycloakTokenVerifier(
        config=KeycloakAuthConfig(
            issuer="https://keycloak.bwgdi.com/realms/beaver",
            client_id="beaver-agnet",
            token_url="https://keycloak.bwgdi.com/realms/beaver/protocol/openid-connect/token",
            jwks_url="https://keycloak.bwgdi.com/realms/beaver/protocol/openid-connect/certs",
        )
    )
|
||||
|
||||
|
||||
def _claims(**overrides):
|
||||
now = int(time.time())
|
||||
payload = {
|
||||
"sub": "user-123",
|
||||
"preferred_username": "alice",
|
||||
"email": "alice@example.com",
|
||||
"name": "Alice Example",
|
||||
"iss": "https://keycloak.bwgdi.com/realms/beaver",
|
||||
"aud": "beaver-agnet",
|
||||
"azp": "beaver-agnet",
|
||||
"iat": now,
|
||||
"exp": now + 300,
|
||||
"nonce": "nonce-1",
|
||||
"realm_access": {"roles": ["user", "admin"]},
|
||||
"resource_access": {"beaver-agnet": {"roles": ["agent-user"]}},
|
||||
}
|
||||
payload.update(overrides)
|
||||
return payload
|
||||
|
||||
|
||||
def test_extract_bearer_token_accepts_case_insensitive_prefix() -> None:
    """The ``Bearer`` scheme is accepted in both capitalised and lowercase form."""
    assert extract_bearer_token("Bearer abc.def") == "abc.def"
    assert extract_bearer_token("bearer xyz") == "xyz"
|
||||
|
||||
|
||||
def test_extract_bearer_token_rejects_missing_or_invalid_header() -> None:
    """A missing header and a non-Bearer scheme both raise a 401 ``HTTPException``."""
    with pytest.raises(HTTPException) as missing:
        extract_bearer_token(None)
    with pytest.raises(HTTPException) as invalid:
        extract_bearer_token("Basic abc")

    assert missing.value.status_code == 401
    assert invalid.value.status_code == 401
|
||||
|
||||
|
||||
def test_validate_claims_accepts_audience_and_extracts_roles() -> None:
    """Default claims with a matching nonce yield the expected ``KeycloakIdentity``.

    The expected realm roles mirror ``realm_access.roles`` and the expected
    client roles mirror ``resource_access["beaver-agnet"].roles`` from
    ``_claims``.
    """
    identity = _verifier().validate_claims(_claims(), expected_nonce="nonce-1")

    assert identity == KeycloakIdentity(
        user_id="user-123",
        username="alice",
        email="alice@example.com",
        name="Alice Example",
        realm_roles=("user", "admin"),
        client_roles=("agent-user",),
    )
|
||||
|
||||
|
||||
def test_validate_claims_accepts_azp_when_audience_differs() -> None:
    """Claims are accepted when ``aud`` is another value but ``azp`` is the client id."""
    identity = _verifier().validate_claims(_claims(aud="account", azp="beaver-agnet"))

    assert identity.user_id == "user-123"
|
||||
|
||||
|
||||
def test_validate_claims_rejects_wrong_nonce() -> None:
    """A nonce mismatch raises a 401 whose detail mentions the nonce."""
    with pytest.raises(HTTPException) as exc:
        _verifier().validate_claims(_claims(), expected_nonce="different")

    assert exc.value.status_code == 401
    assert "nonce" in exc.value.detail.lower()
|
||||
|
||||
|
||||
def test_validate_claims_rejects_wrong_audience_and_azp() -> None:
    """When neither ``aud`` nor ``azp`` is the client id, a 401 mentioning the audience is raised."""
    with pytest.raises(HTTPException) as exc:
        _verifier().validate_claims(_claims(aud="account", azp="other-client"))

    assert exc.value.status_code == 401
    assert "audience" in exc.value.detail.lower()
|
||||
|
||||
|
||||
def test_verify_raises_http_exception_for_bad_jwt(monkeypatch) -> None:
    """``verify`` reports an undecodable token as a 401 with an "invalid token" detail.

    ``jwt.decode`` is patched to always raise ``jwt.InvalidTokenError``.
    """
    verifier = _verifier()

    def fake_decode(*args, **kwargs):
        raise jwt.InvalidTokenError("bad token")

    monkeypatch.setattr(jwt, "decode", fake_decode)

    with pytest.raises(HTTPException) as exc:
        verifier.verify("bad-token")

    assert exc.value.status_code == 401
    assert "invalid token" in exc.value.detail.lower()
|
||||
@ -169,6 +169,175 @@ def test_thinking_mode_is_forced_disabled_even_when_requested_enabled(monkeypatc
|
||||
}
|
||||
|
||||
|
||||
def _capture_acompletion_kwargs(monkeypatch: pytest.MonkeyPatch) -> dict:
    """Stub out the LiteLLM completion call and return the dict recording its kwargs.

    ``beaver.engine.providers.litellm.acompletion`` is replaced by a coroutine
    that copies its keyword arguments into the returned dict and answers with
    a minimal response object (one choice, content ``"ok"``, no tool calls,
    no usage). The module's ``litellm`` attribute is replaced by an empty
    ``SimpleNamespace``.
    """
    captured: dict = {}

    class Message:
        content = "ok"
        reasoning_content = None
        tool_calls = []

    class Choice:
        message = Message()
        finish_reason = "stop"

    class Response:
        choices = [Choice()]
        usage = None

    async def fake_acompletion(**kwargs):
        captured.update(kwargs)
        return Response()

    monkeypatch.setattr("beaver.engine.providers.litellm.acompletion", fake_acompletion)
    monkeypatch.setattr("beaver.engine.providers.litellm.litellm", SimpleNamespace())
    return captured


def _run_mistral_chat(*, api_base: str, model: str, provider_name: str, **chat_kwargs) -> None:
    """Send one "reply ok" user message through a ``LiteLLMProvider``.

    ``model`` is used both as the provider's default model and as the model
    of the chat call. ``chat_kwargs`` are forwarded to ``provider.chat``
    unchanged, so omitting ``thinking_enabled`` leaves it unspecified.
    """
    provider = LiteLLMProvider(
        api_key="EMPTY",
        api_base=api_base,
        default_model=model,
        provider_name=provider_name,
    )
    asyncio.run(
        provider.chat(
            [{"role": "user", "content": "reply ok"}],
            model=model,
            **chat_kwargs,
        )
    )


def test_mistral_vllm_uses_reasoning_effort_instead_of_qwen_thinking_body(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """With the vllm provider and thinking enabled, only ``reasoning_effort: high`` is sent."""
    captured = _capture_acompletion_kwargs(monkeypatch)

    _run_mistral_chat(
        api_base="http://localhost:8000/v1",
        model="mistralai/Mistral-Medium-3.5-128B",
        provider_name="vllm",
        thinking_enabled=True,
    )

    assert captured["model"] == "hosted_vllm/mistralai/Mistral-Medium-3.5-128B"
    assert captured["extra_body"] == {"reasoning_effort": "high"}


def test_mistral_vllm_omits_reasoning_body_when_thinking_mode_is_unspecified(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """With the vllm provider and no thinking flag, no ``extra_body`` is sent."""
    captured = _capture_acompletion_kwargs(monkeypatch)

    _run_mistral_chat(
        api_base="http://localhost:8000/v1",
        model="mistralai/Mistral-Medium-3.5-128B",
        provider_name="vllm",
    )

    assert "extra_body" not in captured


def test_mistral_openai_compatible_private_vllm_uses_reasoning_effort(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """With the openai provider and thinking disabled, only ``reasoning_effort: none`` is sent."""
    captured = _capture_acompletion_kwargs(monkeypatch)

    _run_mistral_chat(
        api_base="http://172.19.207.103/v1",
        model="Mistral-Medium-3.5-128B",
        provider_name="openai",
        thinking_enabled=False,
    )

    assert captured["extra_body"] == {"reasoning_effort": "none"}
    # Spelled out separately to name the keys that must not appear.
    assert "chat_template_kwargs" not in captured["extra_body"]
    assert "thinking" not in captured["extra_body"]


def test_mistral_openai_compatible_private_vllm_omits_body_when_unspecified(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """With the openai provider and no thinking flag, no ``extra_body`` is sent."""
    captured = _capture_acompletion_kwargs(monkeypatch)

    _run_mistral_chat(
        api_base="http://172.19.207.103/v1",
        model="Mistral-Medium-3.5-128B",
        provider_name="openai",
    )

    assert "extra_body" not in captured
|
||||
|
||||
|
||||
def test_litellm_provider_sanitizes_tool_call_arguments(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
captured: dict = {}
|
||||
|
||||
|
||||
@ -149,6 +149,22 @@ def test_router_injects_intent_skill_guidance() -> None:
|
||||
assert "Weather and current external data" in prompt
|
||||
|
||||
|
||||
def test_router_prompt_treats_unrelated_lightweight_conversation_as_new_topic() -> None:
    """The router prompt carries the rule about unrelated small talk during an active task.

    The classification result is not inspected; only the content of the second
    message of the first provider call is checked.
    """
    provider = RouterProvider('{"action":"simple_chat","reason":"unrelated lightweight conversation"}')

    asyncio.run(
        MainAgentRouter().classify(
            "吃饭没",
            active_task=_task(),
            provider=provider,
        )
    )

    prompt = provider.calls[0]["messages"][1]["content"]
    assert "unrelated lightweight conversation" in prompt
    assert "must not be classified as revise_task merely because the active Task is awaiting acceptance" in prompt
|
||||
|
||||
|
||||
def test_router_closes_active_task_from_llm_decision() -> None:
|
||||
decision = asyncio.run(
|
||||
MainAgentRouter().classify(
|
||||
|
||||
22
app-instance/backend/tests/unit/test_mcp_tools_server.py
Normal file
22
app-instance/backend/tests/unit/test_mcp_tools_server.py
Normal file
@ -0,0 +1,22 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from beaver.interfaces.mcp.tools_server import _category_tools
|
||||
|
||||
|
||||
def test_local_filesystem_mcp_exposes_personal_user_file_tools_only(tmp_path) -> None:
    """The "filesystem" category exposes exactly the six ``user_files_*`` tools, in order.

    Every tool description must also mention "personal agent file system".
    """
    tools, _context = _category_tools("filesystem", tmp_path)

    names = [tool.spec.name for tool in tools]

    assert names == [
        "user_files_list",
        "user_files_read",
        "user_files_write",
        "user_files_mkdir",
        "user_files_copy_to_workspace",
        "user_files_publish_output",
    ]
    # Implied by the equality above; kept to name the tools that must stay hidden.
    assert "read_file" not in names
    assert "search_files" not in names
    assert "list_directory" not in names
    assert all("personal agent file system" in tool.spec.description for tool in tools)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user