feat: support session_id in current chat search

This commit is contained in:
2026-06-23 11:39:29 +08:00
parent f77454b4cc
commit d7e061b780
4 changed files with 78 additions and 5 deletions

View File

@ -10,7 +10,7 @@ from urllib.parse import parse_qsl, quote, urlsplit, urlunsplit
import httpx
from fastapi import APIRouter, FastAPI, File, Form, HTTPException, Request, UploadFile
from pydantic import ValidationError
from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, Field, field_validator, model_validator
from starlette.datastructures import UploadFile as StarletteUploadFile
from starlette.responses import Response
@ -43,6 +43,7 @@ class SearchMemoriesRequest(BaseModel):
user_id: str = Field(min_length=1)
user_key: str = Field(min_length=1)
agent_id: str | None = Field(default=None, min_length=1)
session_id: str | None = Field(default=None, min_length=1)
conversation_id: str | None = None
query: str = Field(min_length=1)
scope: list[Literal["current_chat", "resources", "all_user_memory"]] = Field(
@ -64,6 +65,16 @@ class SearchMemoriesRequest(BaseModel):
raise ValueError("top_k must be -1 or in 1..100")
return value
@model_validator(mode="after")
def validate_current_chat_session_aliases(self) -> SearchMemoriesRequest:
if self.session_id and self.conversation_id:
expected_session_id = f"chat:{self.conversation_id}"
if self.session_id != expected_session_id:
raise ValueError(
"session_id must match chat:{conversation_id} when both are provided"
)
return self
class AddMemoryMessage(BaseModel):
sender_id: str = Field(min_length=1)
@ -487,6 +498,7 @@ def create_app(
user_id=request.user_id,
agent_id=request.agent_id,
query=request.query,
session_id=request.session_id,
conversation_id=request.conversation_id,
scope=request.scope,
method=request.method,

View File

@ -502,6 +502,7 @@ class MemoryGatewayService:
user_id: str,
agent_id: str | None,
query: str,
session_id: str | None,
conversation_id: str | None,
scope: list[str],
method: str,
@ -515,8 +516,11 @@ class MemoryGatewayService:
) -> dict[str, Any]:
results: list[dict[str, Any]] = []
session_resource_map: dict[str, dict[str, Any]] = {}
current_chat_session_id = session_id
if current_chat_session_id is None and conversation_id:
current_chat_session_id = f"chat:{conversation_id}"
if "current_chat" in scope and conversation_id:
if "current_chat" in scope and current_chat_session_id:
payload = self._search_payload(
user_id=user_id,
agent_id=agent_id,
@ -530,7 +534,7 @@ class MemoryGatewayService:
project_id=project_id,
filters=_combine_filters(
filters,
{"session_id": f"chat:{conversation_id}"},
{"session_id": current_chat_session_id},
),
)
results.extend(