Files
beaver_project/app-instance/backend/beaver/interfaces/web/app.py

199 lines
6.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""FastAPI app factory for Beaver."""
from __future__ import annotations
from collections.abc import AsyncIterator, Callable
from contextlib import asynccontextmanager
from pathlib import Path
from types import SimpleNamespace
from typing import Any
from beaver.services.agent_service import AgentService
from .deps import get_agent_service
from .schemas import WebChatRequest, WebChatResponse, WebErrorResponse, WebStatusResponse
try:
    from fastapi import FastAPI, HTTPException, Request
except ModuleNotFoundError:  # pragma: no cover - fallback for skeleton-only environments

    class HTTPException(Exception):  # type: ignore[no-redef]
        """Minimal fallback exception mirroring FastAPI's ``HTTPException`` constructor.

        Accepts ``(status_code, detail=None, headers=None)`` like FastAPI. As in
        FastAPI/Starlette, a missing ``detail`` is filled with the standard HTTP
        reason phrase for ``status_code`` (e.g. 404 -> "Not Found").
        """

        def __init__(
            self,
            status_code: int,
            detail: Any = None,
            headers: dict[str, str] | None = None,
        ) -> None:
            if detail is None:
                # Local import: only needed on this rarely-used fallback path.
                from http import HTTPStatus

                detail = HTTPStatus(status_code).phrase
            super().__init__(detail)
            self.status_code = status_code
            self.detail = detail
            self.headers = headers

    class Request:  # type: ignore[no-redef]
        """Fallback request shim used only for import-time compatibility."""

        def __init__(self, app: Any) -> None:
            self.app = app

    class FastAPI:  # type: ignore[no-redef]
        """Small fallback shim so the package can import before dependencies are installed.

        It only records ``title``/``lifespan`` and provides an empty ``state``
        namespace. The route decorators register nothing and return the handler
        unchanged. Any other constructor keyword arguments are accepted and ignored.
        """

        def __init__(
            self,
            *,
            title: str,
            lifespan: Callable[..., Any] | None = None,
            **_kwargs: Any,
        ) -> None:
            self.title = title
            self.lifespan = lifespan
            self.state = SimpleNamespace()

        @staticmethod
        def _passthrough(func: Callable[..., Any]) -> Callable[..., Any]:
            """Identity decorator used by the fallback route registrars."""
            return func

        def get(self, _path: str, **_kwargs: Any) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
            return self._passthrough

        def post(self, _path: str, **_kwargs: Any) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
            return self._passthrough
@asynccontextmanager
async def _app_lifespan(
    app: FastAPI,
    *,
    workspace: str | Path | None,
    service: AgentService | None,
    manage_service_lifecycle: bool | None,
    shutdown_timeout_seconds: float | None,
    shutdown_force: bool,
) -> AsyncIterator[None]:
    """Attach the web app to an AgentService lifecycle.

    The attached service is ``service`` when one is given, otherwise a new
    ``AgentService(workspace=workspace)``. It is stored on
    ``app.state.agent_service``.

    Ownership: if ``manage_service_lifecycle`` is ``None``, the app owns the
    service only when it created it. Otherwise the flag decides.

    An owned service is started before control is yielded. If ``start()``
    raises, ``close()`` is called and the error propagates. On exit it is
    shut down with ``timeout_seconds=shutdown_timeout_seconds`` and
    ``force=shutdown_force``. A non-owned service is only attached, never
    started or shut down here.
    """
    # Explicit ``is None`` check rather than ``service or ...``: a caller-supplied
    # service that happens to be falsy must not be silently replaced.
    created_here = service is None
    attached_service = AgentService(workspace=workspace) if created_here else service
    owns_service = created_here if manage_service_lifecycle is None else manage_service_lifecycle
    app.state.agent_service = attached_service

    if owns_service:
        try:
            await attached_service.start()
        except BaseException:
            # BaseException so cancellation during start() also triggers cleanup.
            attached_service.close()
            raise

    try:
        yield
    finally:
        # Reaching here with owns_service set implies start() succeeded.
        if owns_service:
            await attached_service.shutdown(
                timeout_seconds=shutdown_timeout_seconds,
                force=shutdown_force,
            )
def create_app(
    *,
    workspace: str | Path | None = None,
    service: AgentService | None = None,
    manage_service_lifecycle: bool | None = None,
    shutdown_timeout_seconds: float | None = 5.0,
    shutdown_force: bool = True,
) -> FastAPI:
    """Create a Beaver web app hosted by AgentService running mode.

    Default ownership semantics:
    - no ``service`` passed: the app creates its own AgentService and takes
      over its lifecycle;
    - an external ``service`` passed: it is only attached by default and is
      not started or shut down automatically.

    Pass ``manage_service_lifecycle=True/False`` explicitly to override the
    default. ``shutdown_timeout_seconds`` and ``shutdown_force`` are handed to
    ``_app_lifespan`` unchanged.

    Routes registered:
    - ``GET /api/ping``: status, running flag and mode of the attached service;
    - ``POST /api/chat``: one message submitted via ``submit_direct``.
    """

    def lifespan(fastapi_app: FastAPI):
        # Bind the factory arguments; FastAPI only supplies the app instance.
        return _app_lifespan(
            fastapi_app,
            workspace=workspace,
            service=service,
            manage_service_lifecycle=manage_service_lifecycle,
            shutdown_timeout_seconds=shutdown_timeout_seconds,
            shutdown_force=shutdown_force,
        )

    def status_for_runtime_error(detail: str) -> int:
        # RuntimeErrors from submit_direct are classified by message text.
        # "Not ready" style markers are tested first, so they take precedence
        # over the broader conflict markers; anything unrecognised is 503.
        if any(marker in detail for marker in ("requires an active run() loop", "not ready")):
            return 503
        if any(marker in detail for marker in ("submit_direct", "running")):
            return 409
        return 503

    app = FastAPI(title="Beaver Backend", lifespan=lifespan)
    error_responses = {code: {"model": WebErrorResponse} for code in (400, 409, 503)}

    @app.get("/api/ping", response_model=WebStatusResponse)
    async def ping(request: Request) -> WebStatusResponse:
        # Mode: "running" when is_running, else "direct" when has_loop, else "idle".
        agent_service = get_agent_service(request)
        is_running = agent_service.is_running
        if is_running:
            mode = "running"
        elif agent_service.has_loop:
            mode = "direct"
        else:
            mode = "idle"
        return WebStatusResponse(status="ok", running=is_running, mode=mode)

    @app.post("/api/chat", response_model=WebChatResponse, responses=error_responses)
    async def chat(request: Request, payload: WebChatRequest) -> WebChatResponse:
        # Submit a single message through submit_direct and translate its
        # ValueError/RuntimeError failures into HTTP errors.
        agent_service = get_agent_service(request)
        text = payload.message.strip()
        if not text:
            raise HTTPException(status_code=400, detail="'message' is required")

        # Optional target overrides are forwarded as plain dicts (None if unset).
        target_overrides = {
            "fallback_target": _model_dump(payload.fallback_target),
            "auxiliary_target": _model_dump(payload.auxiliary_target),
            "embedding_target": _model_dump(payload.embedding_target),
        }

        try:
            result = await agent_service.submit_direct(
                text,
                session_id=payload.session_id,
                source="web",
                user_id=payload.user_id,
                title=payload.title,
                execution_context=payload.execution_context,
                model=payload.model,
                provider_name=payload.provider_name,
                embedding_model=payload.embedding_model,
                temperature=payload.temperature,
                max_tokens=payload.max_tokens,
                max_tool_iterations=payload.max_tool_iterations,
                **target_overrides,
            )
        except ValueError as exc:
            raise HTTPException(status_code=400, detail=str(exc)) from exc
        except RuntimeError as exc:
            reason = str(exc)
            raise HTTPException(status_code=status_for_runtime_error(reason), detail=reason) from exc

        return WebChatResponse(
            session_id=result.session_id,
            run_id=result.run_id,
            output_text=result.output_text,
            finish_reason=result.finish_reason,
            tool_iterations=result.tool_iterations,
            provider_name=result.provider_name,
            model=result.model,
            usage=result.usage,
        )

    return app
def _model_dump(value: Any) -> dict[str, Any] | None:
    """Export a request sub-model as a plain dict with ``None`` fields omitted.

    Minimal helper compatible with Pydantic v1 and v2:
    - ``None`` -> ``None``;
    - objects with ``model_dump`` (Pydantic v2) -> ``model_dump(exclude_none=True)``;
    - objects with ``dict`` (Pydantic v1) -> ``dict(exclude_none=True)``;
    - anything else is passed to ``dict(value)`` (a mapping or an iterable of
      key/value pairs), and its ``None`` values are dropped.

    Dropping ``None`` values in the fallback branch gives downstream consumers
    the same shape regardless of input type. Only top-level keys are filtered
    there.
    """
    if value is None:
        return None
    # model_dump is checked first so a v2 model is always exported via the v2 API.
    if hasattr(value, "model_dump"):
        return value.model_dump(exclude_none=True)
    if hasattr(value, "dict"):
        return value.dict(exclude_none=True)
    return {key: item for key, item in dict(value).items() if item is not None}