feat: extend memory search and attachment mapping
This commit is contained in:
48
core/api.py
48
core/api.py
@ -8,14 +8,19 @@ from typing import Any, Literal
|
||||
from urllib.parse import parse_qsl, quote, urlsplit, urlunsplit
|
||||
|
||||
from fastapi import APIRouter, FastAPI, File, Form, HTTPException, Request, UploadFile
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from starlette.responses import Response
|
||||
|
||||
from .config import GatewayConfig
|
||||
from .db import init_db
|
||||
from .backend_client import BackendClient
|
||||
from .repository import MemoryRepository
|
||||
from .service import MemoryGatewayService, UnsupportedContentType, UploadTooLarge
|
||||
from .service import (
|
||||
InvalidAttachment,
|
||||
MemoryGatewayService,
|
||||
UnsupportedContentType,
|
||||
UploadTooLarge,
|
||||
)
|
||||
|
||||
|
||||
API_LOGGER = logging.getLogger("memory_gateway.api")
|
||||
@ -34,15 +39,28 @@ SENSITIVE_FIELD_NAMES = {
|
||||
class SearchMemoriesRequest(BaseModel):
|
||||
user_id: str = Field(min_length=1)
|
||||
user_key: str = Field(min_length=1)
|
||||
agent_id: str | None = Field(default=None, min_length=1)
|
||||
conversation_id: str | None = None
|
||||
query: str = Field(min_length=1)
|
||||
scope: list[Literal["current_chat", "resources", "all_user_memory"]] = Field(
|
||||
default_factory=lambda: ["current_chat", "resources"]
|
||||
)
|
||||
top_k: int = Field(default=8, ge=1, le=100)
|
||||
method: Literal["keyword", "vector", "hybrid", "agentic"] = "hybrid"
|
||||
top_k: int = 8
|
||||
radius: float | None = Field(default=None, ge=0, le=1)
|
||||
include_profile: bool = True
|
||||
enable_llm_rerank: bool = True
|
||||
filters: dict[str, Any] | None = None
|
||||
app_id: str = "default"
|
||||
project_id: str = "default"
|
||||
|
||||
@field_validator("top_k")
|
||||
@classmethod
|
||||
def validate_top_k(cls, value: int) -> int:
|
||||
if value != -1 and not 1 <= value <= 100:
|
||||
raise ValueError("top_k must be -1 or in 1..100")
|
||||
return value
|
||||
|
||||
|
||||
class AddMemoryMessage(BaseModel):
|
||||
sender_id: str = Field(min_length=1)
|
||||
@ -367,10 +385,16 @@ def create_app(
|
||||
require_user(request.user_id, request.user_key)
|
||||
return await service.search_memories(
|
||||
user_id=request.user_id,
|
||||
agent_id=request.agent_id,
|
||||
query=request.query,
|
||||
conversation_id=request.conversation_id,
|
||||
scope=request.scope,
|
||||
method=request.method,
|
||||
top_k=request.top_k,
|
||||
radius=request.radius,
|
||||
include_profile=request.include_profile,
|
||||
enable_llm_rerank=request.enable_llm_rerank,
|
||||
filters=request.filters,
|
||||
app_id=request.app_id,
|
||||
project_id=request.project_id,
|
||||
)
|
||||
@ -380,12 +404,18 @@ def create_app(
|
||||
request: AddMemoryRequest,
|
||||
) -> dict[str, Any]:
|
||||
require_user(request.user_id, request.user_key)
|
||||
return await service.add_memory(
|
||||
session_id=request.session_id,
|
||||
app_id=request.app_id,
|
||||
project_id=request.project_id,
|
||||
messages=[message.model_dump() for message in request.messages],
|
||||
)
|
||||
try:
|
||||
return await service.add_memory(
|
||||
user_id=request.user_id,
|
||||
session_id=request.session_id,
|
||||
app_id=request.app_id,
|
||||
project_id=request.project_id,
|
||||
messages=[message.model_dump() for message in request.messages],
|
||||
)
|
||||
except UploadTooLarge as exc:
|
||||
raise HTTPException(status_code=413, detail=str(exc)) from exc
|
||||
except InvalidAttachment as exc:
|
||||
raise HTTPException(status_code=422, detail=str(exc)) from exc
|
||||
|
||||
@router.post("/memories/flush")
|
||||
async def flush_memory(
|
||||
|
||||
Reference in New Issue
Block a user