"""HTTP API for memory-gateway2: request models and the ``create_app`` factory."""
from __future__ import annotations

from typing import Any, Literal

from fastapi import APIRouter, FastAPI, File, Form, HTTPException, UploadFile
from pydantic import BaseModel, Field

from .config import GatewayConfig
from .db import init_db
from .everos_client import EverOSClient
from .repository import MemoryRepository
from .service import MemoryGatewayService


class SearchMemoriesRequest(BaseModel):
    # Request body for POST /memories/search.

    # Caller credentials; both must be non-empty strings.
    user_id: str = Field(..., min_length=1)
    user_key: str = Field(..., min_length=1)
    # Optional; passed through to service.search_memories unchanged.
    conversation_id: str | None = Field(default=None)
    # Search text; must be non-empty.
    query: str = Field(..., min_length=1)
    # Sources to search. Defaults to ["current_chat", "resources"]; the
    # factory builds a fresh list for every instance.
    scope: list[Literal["current_chat", "resources", "all_user_memory"]] = Field(
        default_factory=lambda: list(("current_chat", "resources")),
    )
    # Number of results requested, constrained to 1..100 inclusive.
    top_k: int = Field(8, ge=1, le=100)
    # Application / project identifiers; "default" when omitted.
    app_id: str = Field(default="default")
    project_id: str = Field(default="default")


class MemoryOverrideRequest(BaseModel):
    # Request body for PATCH /memories/{memory_id}.

    # Caller credentials; both must be non-empty strings.
    user_id: str = Field(..., min_length=1)
    user_key: str = Field(..., min_length=1)
    # Optional; forwarded to service.upsert_override as session_id.
    session_id: str | None = Field(default=None)
    # Override text handed to service.upsert_override; must be non-empty.
    override_text: str = Field(..., min_length=1)


class MemoryDeleteRequest(BaseModel):
    # Request body for DELETE /memories/{memory_id}.

    # Caller credentials; both must be non-empty strings.
    user_id: str = Field(..., min_length=1)
    user_key: str = Field(..., min_length=1)
    # Both optional; forwarded to service.delete_memory unchanged.
    session_id: str | None = Field(default=None)
    reason: str | None = Field(default=None)


class UserCreateRequest(BaseModel):
    # Request body for POST /users: the identifier of the user to create,
    # which must be a non-empty string.
    user_id: str = Field(..., min_length=1)


def create_app(
    *,
    config: GatewayConfig | None = None,
    everos_client: Any | None = None,
) -> FastAPI:
    """Build and wire the memory-gateway2 FastAPI application.

    Args:
        config: Gateway settings. Loaded with ``GatewayConfig.from_env()``
            only when omitted (``None``).
        everos_client: Client handed to ``MemoryGatewayService``. An
            ``EverOSClient`` for ``cfg.everos_base_url`` is built only when
            omitted (``None``), so callers such as tests can inject a stub.

    Returns:
        The application with every route mounted. The config, repository,
        EverOS client and service are also stored on ``app.state``.
    """
    # Explicit ``is None`` checks: a truthiness test (``x or default``) would
    # silently replace an injected object that evaluates falsy (e.g. one
    # defining ``__bool__``/``__len__``) with a freshly built default.
    cfg = config if config is not None else GatewayConfig.from_env()

    # Initialise the database at cfg.database_path before the repository is
    # created on the same path (presumably creates the schema; see .db.init_db).
    init_db(cfg.database_path)
    repository = MemoryRepository(cfg.database_path)
    client = (
        everos_client
        if everos_client is not None
        else EverOSClient(cfg.everos_base_url)
    )
    service = MemoryGatewayService(cfg, repository, client)

    app = FastAPI(title="memory-gateway2", version="0.1.0")
    app.state.config = cfg
    app.state.repository = repository
    app.state.everos_client = client
    app.state.gateway_service = service
    # NOTE(review): a client constructed here is never closed on shutdown. If
    # EverOSClient holds connections, close it from a lifespan handler (but
    # not an injected client, which the caller owns).

    app.include_router(_build_router(service))
    return app


def _build_router(service: MemoryGatewayService) -> APIRouter:
    """Create the router holding every HTTP endpoint backed by *service*.

    Every endpoint except ``POST /users`` checks the caller's
    ``user_id`` / ``user_key`` pair before calling the service.
    """
    router = APIRouter()

    # NOTE(review): these handlers are ``async def`` but call the service's
    # non-awaited (synchronous) methods directly, so any blocking database
    # work runs on the event-loop thread. Plain ``def`` handlers would move it
    # to a threadpool -- only safe if MemoryRepository is thread-safe; confirm.

    def require_user(user_id: str, user_key: str) -> None:
        """Raise HTTP 401 unless *service* accepts the credentials."""
        if not service.authenticate_user(user_id, user_key):
            raise HTTPException(
                status_code=401, detail="invalid user credentials"
            )

    @router.post("/users")
    async def create_user(request: UserCreateRequest) -> dict[str, Any]:
        """Create a user. This route performs no credential check."""
        return service.create_user(request.user_id)

    @router.post("/resources")
    async def upload_resource(
        user_id: str = Form(...),
        user_key: str = Form(...),
        app_id: str = Form("default"),
        project_id: str = Form("default"),
        title: str | None = Form(None),
        description: str | None = Form(None),
        file: UploadFile = File(...),
    ) -> dict[str, Any]:
        """Upload a file (multipart form) as a resource for the caller."""
        require_user(user_id, user_key)
        return await service.upload_resource(
            user_id=user_id,
            app_id=app_id,
            project_id=project_id,
            file=file,
            title=title,
            description=description,
        )

    # NOTE(review): the GET/DELETE resource routes take ``user_key`` as a
    # query parameter, so the secret travels in the URL and may end up in
    # access logs. Consider also accepting it via a header.
    @router.get("/resources")
    async def list_resources(
        user_id: str,
        user_key: str,
    ) -> dict[str, Any]:
        """List the caller's resources."""
        require_user(user_id, user_key)
        return {"resources": service.list_resources(user_id)}

    @router.get("/resources/{resource_id}")
    async def get_resource(
        resource_id: str,
        user_id: str,
        user_key: str,
    ) -> dict[str, Any]:
        """Return one resource wrapped in a ``resources`` list.

        An unknown resource yields an empty ``resources`` list.
        """
        require_user(user_id, user_key)
        resource = service.get_resource_detail(resource_id, user_id)
        # NOTE(review): unlike the DELETE route below, a missing resource is
        # answered with 200 and an empty list rather than 404. Kept as-is
        # because clients may rely on it.
        return {"resources": [] if resource is None else [resource]}

    @router.delete("/resources/{resource_id}")
    async def delete_resource(
        resource_id: str,
        user_id: str,
        user_key: str,
    ) -> dict[str, Any]:
        """Delete one of the caller's resources; 404 when it is not found."""
        require_user(user_id, user_key)
        resource = service.delete_resource(resource_id, user_id)
        if resource is None:
            raise HTTPException(status_code=404, detail="resource not found")
        return resource

    @router.post("/memories/search")
    async def search_memories(
        request: SearchMemoriesRequest,
    ) -> dict[str, Any]:
        """Search the caller's memories within the requested scopes."""
        require_user(request.user_id, request.user_key)
        return await service.search_memories(
            user_id=request.user_id,
            query=request.query,
            conversation_id=request.conversation_id,
            scope=request.scope,
            top_k=request.top_k,
            app_id=request.app_id,
            project_id=request.project_id,
        )

    @router.patch("/memories/{memory_id}")
    async def patch_memory(
        memory_id: str,
        request: MemoryOverrideRequest,
    ) -> dict[str, Any]:
        """Store override text for a memory via the gateway service."""
        require_user(request.user_id, request.user_key)
        return service.upsert_override(
            user_id=request.user_id,
            memory_id=memory_id,
            session_id=request.session_id,
            override_text=request.override_text,
        )

    # NOTE(review): this DELETE route reads its credentials from a JSON
    # request body; some HTTP clients and proxies drop bodies on DELETE.
    @router.delete("/memories/{memory_id}")
    async def delete_memory(
        memory_id: str,
        request: MemoryDeleteRequest,
    ) -> dict[str, Any]:
        """Delete a memory via the gateway service."""
        require_user(request.user_id, request.user_key)
        return service.delete_memory(
            user_id=request.user_id,
            memory_id=memory_id,
            session_id=request.session_id,
            reason=request.reason,
        )

    return router