feat(memory-gateway): merge memory mode with main
This commit is contained in:
@ -5,6 +5,7 @@ from __future__ import annotations
|
||||
import json
|
||||
import asyncio
|
||||
import io
|
||||
import logging
|
||||
import mimetypes
|
||||
import os
|
||||
import re
|
||||
@ -21,6 +22,13 @@ from typing import Any
|
||||
from beaver.engine.providers.registry import PROVIDERS, find_by_name
|
||||
from beaver.foundation.config import default_config_path, load_config
|
||||
from beaver.foundation.events import ChannelIdentity, InboundMessage
|
||||
from beaver.memory.gateway import (
|
||||
MemoryGatewayClient,
|
||||
MemoryGatewayClientError,
|
||||
MemoryGatewayCredentialStore,
|
||||
MemoryGatewayUserCredential,
|
||||
default_memory_gateway_users_path,
|
||||
)
|
||||
from beaver.interfaces.channels.runtime import ChannelRuntime
|
||||
from beaver.interfaces.channels.connections import (
|
||||
ChannelConnectionStore,
|
||||
@ -103,6 +111,8 @@ from .schemas import (
|
||||
WebStatusResponse,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
from fastapi import FastAPI, File, Form, Header, HTTPException, Request, UploadFile, WebSocket, WebSocketDisconnect
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
@ -624,6 +634,10 @@ def create_app(
|
||||
app.state.handoff_codes = {}
|
||||
app.state.skill_eval_tasks = {}
|
||||
app.state.auth_file = Path(os.getenv("BEAVER_AUTH_FILE") or "")
|
||||
app.state.memory_gateway_credential_store = MemoryGatewayCredentialStore(
|
||||
default_memory_gateway_users_path()
|
||||
)
|
||||
app.state.memory_gateway_client_factory = lambda config: MemoryGatewayClient(config)
|
||||
max_file_size = 50 * 1024 * 1024
|
||||
max_user_file_upload_size = _int_env("BEAVER_USER_FILES_MAX_UPLOAD_BYTES", 5 * 1024 * 1024 * 1024)
|
||||
user_file_upload_part_size = _int_env("BEAVER_USER_FILES_UPLOAD_PART_SIZE", 10 * 1024 * 1024)
|
||||
@ -1139,6 +1153,30 @@ def create_app(
|
||||
users[username] = password
|
||||
_save_auth_users(auth_file, users)
|
||||
|
||||
if config.memory.mode == "hybrid" and config.memory.gateway.is_configured:
|
||||
try:
|
||||
gateway_client = app.state.memory_gateway_client_factory(config.memory.gateway)
|
||||
gateway_payload = await gateway_client.create_user(username)
|
||||
gateway_user_id = _clean_text(gateway_payload.get("user_id"))
|
||||
gateway_user_key = _clean_text(gateway_payload.get("user_key"))
|
||||
if not gateway_user_id or not gateway_user_key:
|
||||
raise MemoryGatewayClientError("create_user", "invalid_response")
|
||||
app.state.memory_gateway_credential_store.save(
|
||||
username,
|
||||
MemoryGatewayUserCredential(
|
||||
user_id=gateway_user_id,
|
||||
user_key=gateway_user_key,
|
||||
),
|
||||
)
|
||||
except MemoryGatewayClientError as exc:
|
||||
logger.warning(
|
||||
"Memory Gateway user provisioning failed for Beaver user %s: operation=%s category=%s status_code=%s",
|
||||
username,
|
||||
exc.operation,
|
||||
exc.category,
|
||||
exc.status_code,
|
||||
)
|
||||
|
||||
token = _issue_web_token(app, username)
|
||||
handoff_code, handoff_expires_at = _issue_handoff_code(app, username, token)
|
||||
backend_connection = {
|
||||
@ -2571,7 +2609,11 @@ def create_app(
|
||||
503: {"model": WebErrorResponse},
|
||||
},
|
||||
)
|
||||
async def chat(request: Request, payload: WebChatRequest) -> WebChatResponse:
|
||||
async def chat(
|
||||
request: Request,
|
||||
payload: WebChatRequest,
|
||||
authorization: str | None = Header(default=None),
|
||||
) -> WebChatResponse:
|
||||
agent_service = get_agent_service(request)
|
||||
message = payload.message.strip()
|
||||
if not message:
|
||||
@ -2622,10 +2664,12 @@ def create_app(
|
||||
embedding_target = _model_dump(payload.embedding_target)
|
||||
|
||||
try:
|
||||
gateway_user_id = _optional_web_user(app, authorization)
|
||||
direct_kwargs = {
|
||||
"session_id": payload.session_id,
|
||||
"source": "web",
|
||||
"user_id": payload.user_id,
|
||||
"gateway_user_id": gateway_user_id,
|
||||
"title": payload.title,
|
||||
"execution_context": payload.execution_context,
|
||||
"prompt_locale": payload.prompt_locale,
|
||||
@ -2684,6 +2728,7 @@ def create_app(
|
||||
await websocket.send_json({"type": "error", "error": "AgentService is not ready"})
|
||||
await websocket.close(code=1011)
|
||||
return
|
||||
gateway_user_id = _web_user_from_token(app, websocket.query_params.get("token"))
|
||||
|
||||
while True:
|
||||
try:
|
||||
@ -2742,6 +2787,7 @@ def create_app(
|
||||
"session_id": session_id,
|
||||
"source": "websocket",
|
||||
"user_id": _clean_text(payload.get("user_id")) or None,
|
||||
"gateway_user_id": gateway_user_id,
|
||||
"title": _clean_text(payload.get("title")) or None,
|
||||
"execution_context": _clean_text(payload.get("execution_context")) or None,
|
||||
"prompt_locale": _clean_text(payload.get("prompt_locale")) or None,
|
||||
@ -3806,6 +3852,22 @@ def _require_web_user(app: FastAPI, authorization: str | None) -> str:
|
||||
return username
|
||||
|
||||
|
||||
def _optional_web_user(app: FastAPI, authorization: str | None) -> str | None:
|
||||
if not authorization:
|
||||
return None
|
||||
prefix = "bearer "
|
||||
if not authorization.lower().startswith(prefix):
|
||||
return None
|
||||
return _web_user_from_token(app, authorization[len(prefix):].strip())
|
||||
|
||||
|
||||
def _web_user_from_token(app: FastAPI, token: str | None) -> str | None:
|
||||
cleaned = _clean_text(token)
|
||||
if not cleaned:
|
||||
return None
|
||||
return app.state.auth_tokens.get(cleaned)
|
||||
|
||||
|
||||
def _backend_connection_view(request: Request) -> dict[str, Any]:
|
||||
public_base_url = (
|
||||
os.getenv("BEAVER_BACKEND_IDENTITY__PUBLIC_BASE_URL")
|
||||
|
||||
Reference in New Issue
Block a user