feat: support session_id in current chat search
This commit is contained in:
14
core/api.py
14
core/api.py
@ -10,7 +10,7 @@ from urllib.parse import parse_qsl, quote, urlsplit, urlunsplit
|
||||
import httpx
|
||||
from fastapi import APIRouter, FastAPI, File, Form, HTTPException, Request, UploadFile
|
||||
from pydantic import ValidationError
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
from starlette.datastructures import UploadFile as StarletteUploadFile
|
||||
from starlette.responses import Response
|
||||
|
||||
@ -43,6 +43,7 @@ class SearchMemoriesRequest(BaseModel):
|
||||
user_id: str = Field(min_length=1)
|
||||
user_key: str = Field(min_length=1)
|
||||
agent_id: str | None = Field(default=None, min_length=1)
|
||||
session_id: str | None = Field(default=None, min_length=1)
|
||||
conversation_id: str | None = None
|
||||
query: str = Field(min_length=1)
|
||||
scope: list[Literal["current_chat", "resources", "all_user_memory"]] = Field(
|
||||
@ -64,6 +65,16 @@ class SearchMemoriesRequest(BaseModel):
|
||||
raise ValueError("top_k must be -1 or in 1..100")
|
||||
return value
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_current_chat_session_aliases(self) -> SearchMemoriesRequest:
|
||||
if self.session_id and self.conversation_id:
|
||||
expected_session_id = f"chat:{self.conversation_id}"
|
||||
if self.session_id != expected_session_id:
|
||||
raise ValueError(
|
||||
"session_id must match chat:{conversation_id} when both are provided"
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
class AddMemoryMessage(BaseModel):
|
||||
sender_id: str = Field(min_length=1)
|
||||
@ -487,6 +498,7 @@ def create_app(
|
||||
user_id=request.user_id,
|
||||
agent_id=request.agent_id,
|
||||
query=request.query,
|
||||
session_id=request.session_id,
|
||||
conversation_id=request.conversation_id,
|
||||
scope=request.scope,
|
||||
method=request.method,
|
||||
|
||||
@ -502,6 +502,7 @@ class MemoryGatewayService:
|
||||
user_id: str,
|
||||
agent_id: str | None,
|
||||
query: str,
|
||||
session_id: str | None,
|
||||
conversation_id: str | None,
|
||||
scope: list[str],
|
||||
method: str,
|
||||
@ -515,8 +516,11 @@ class MemoryGatewayService:
|
||||
) -> dict[str, Any]:
|
||||
results: list[dict[str, Any]] = []
|
||||
session_resource_map: dict[str, dict[str, Any]] = {}
|
||||
current_chat_session_id = session_id
|
||||
if current_chat_session_id is None and conversation_id:
|
||||
current_chat_session_id = f"chat:{conversation_id}"
|
||||
|
||||
if "current_chat" in scope and conversation_id:
|
||||
if "current_chat" in scope and current_chat_session_id:
|
||||
payload = self._search_payload(
|
||||
user_id=user_id,
|
||||
agent_id=agent_id,
|
||||
@ -530,7 +534,7 @@ class MemoryGatewayService:
|
||||
project_id=project_id,
|
||||
filters=_combine_filters(
|
||||
filters,
|
||||
{"session_id": f"chat:{conversation_id}"},
|
||||
{"session_id": current_chat_session_id},
|
||||
),
|
||||
)
|
||||
results.extend(
|
||||
|
||||
Reference in New Issue
Block a user