Merge pull request 'feat: support session_id in current chat search' (#1) from codex/support-search-session-id into main

Reviewed-on: #1
This commit is contained in:
2026-06-23 03:52:10 +00:00
4 changed files with 78 additions and 5 deletions

View File

@ -486,7 +486,7 @@ curl -X POST http://127.0.0.1:8010/memories/search \
-d '{ -d '{
"user_id": "u_123", "user_id": "u_123",
"user_key": "uk_xxx", "user_key": "uk_xxx",
"conversation_id": "c_456", "session_id": "chat:c_456",
"query": "图片里的蓝色圆形在哪里?", "query": "图片里的蓝色圆形在哪里?",
"scope": ["current_chat", "resources"], "scope": ["current_chat", "resources"],
"method": "hybrid", "method": "hybrid",
@ -502,7 +502,7 @@ curl -X POST http://127.0.0.1:8010/memories/search \
| scope | 行为 | | scope | 行为 |
|---|---| |---|---|
| `current_chat` | 搜索 `chat:{conversation_id}`需要传 `conversation_id` | | `current_chat` | 搜索指定聊天 session推荐传完整 `session_id`,如 `chat:c_456`;也兼容旧字段 `conversation_id`会转换为 `chat:{conversation_id}` |
| `resources` | 搜索当前用户已提取且未删除的资源 session | | `resources` | 搜索当前用户已提取且未删除的资源 session |
| `all_user_memory` | 搜索用户全部记忆,不加 session 过滤 | | `all_user_memory` | 搜索用户全部记忆,不加 session 过滤 |

View File

@ -10,7 +10,7 @@ from urllib.parse import parse_qsl, quote, urlsplit, urlunsplit
import httpx import httpx
from fastapi import APIRouter, FastAPI, File, Form, HTTPException, Request, UploadFile from fastapi import APIRouter, FastAPI, File, Form, HTTPException, Request, UploadFile
from pydantic import ValidationError from pydantic import ValidationError
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field, field_validator, model_validator
from starlette.datastructures import UploadFile as StarletteUploadFile from starlette.datastructures import UploadFile as StarletteUploadFile
from starlette.responses import Response from starlette.responses import Response
@ -43,6 +43,7 @@ class SearchMemoriesRequest(BaseModel):
user_id: str = Field(min_length=1) user_id: str = Field(min_length=1)
user_key: str = Field(min_length=1) user_key: str = Field(min_length=1)
agent_id: str | None = Field(default=None, min_length=1) agent_id: str | None = Field(default=None, min_length=1)
session_id: str | None = Field(default=None, min_length=1)
conversation_id: str | None = None conversation_id: str | None = None
query: str = Field(min_length=1) query: str = Field(min_length=1)
scope: list[Literal["current_chat", "resources", "all_user_memory"]] = Field( scope: list[Literal["current_chat", "resources", "all_user_memory"]] = Field(
@ -64,6 +65,16 @@ class SearchMemoriesRequest(BaseModel):
raise ValueError("top_k must be -1 or in 1..100") raise ValueError("top_k must be -1 or in 1..100")
return value return value
@model_validator(mode="after")
def validate_current_chat_session_aliases(self) -> SearchMemoriesRequest:
if self.session_id and self.conversation_id:
expected_session_id = f"chat:{self.conversation_id}"
if self.session_id != expected_session_id:
raise ValueError(
"session_id must match chat:{conversation_id} when both are provided"
)
return self
class AddMemoryMessage(BaseModel): class AddMemoryMessage(BaseModel):
sender_id: str = Field(min_length=1) sender_id: str = Field(min_length=1)
@ -487,6 +498,7 @@ def create_app(
user_id=request.user_id, user_id=request.user_id,
agent_id=request.agent_id, agent_id=request.agent_id,
query=request.query, query=request.query,
session_id=request.session_id,
conversation_id=request.conversation_id, conversation_id=request.conversation_id,
scope=request.scope, scope=request.scope,
method=request.method, method=request.method,

View File

@ -502,6 +502,7 @@ class MemoryGatewayService:
user_id: str, user_id: str,
agent_id: str | None, agent_id: str | None,
query: str, query: str,
session_id: str | None,
conversation_id: str | None, conversation_id: str | None,
scope: list[str], scope: list[str],
method: str, method: str,
@ -515,8 +516,11 @@ class MemoryGatewayService:
) -> dict[str, Any]: ) -> dict[str, Any]:
results: list[dict[str, Any]] = [] results: list[dict[str, Any]] = []
session_resource_map: dict[str, dict[str, Any]] = {} session_resource_map: dict[str, dict[str, Any]] = {}
current_chat_session_id = session_id
if current_chat_session_id is None and conversation_id:
current_chat_session_id = f"chat:{conversation_id}"
if "current_chat" in scope and conversation_id: if "current_chat" in scope and current_chat_session_id:
payload = self._search_payload( payload = self._search_payload(
user_id=user_id, user_id=user_id,
agent_id=agent_id, agent_id=agent_id,
@ -530,7 +534,7 @@ class MemoryGatewayService:
project_id=project_id, project_id=project_id,
filters=_combine_filters( filters=_combine_filters(
filters, filters,
{"session_id": f"chat:{conversation_id}"}, {"session_id": current_chat_session_id},
), ),
) )
results.extend( results.extend(

View File

@ -1181,6 +1181,63 @@ async def test_search_rejects_invalid_upstream_options(
assert backend.search_calls == [] assert backend.search_calls == []
@pytest.mark.asyncio
async def test_search_current_chat_accepts_session_id(
config: GatewayConfig,
) -> None:
backend = FakeBackendClient()
async with app_client(config, backend) as client:
user_key = await create_user(client)
response = await client.post(
"/memories/search",
json={
"user_id": "u_123",
"user_key": user_key,
"session_id": "chat:c_1",
"query": "hello",
"scope": ["current_chat"],
},
)
assert response.status_code == 200, response.text
assert backend.search_calls == [
{
"user_id": "u_123",
"query": "hello",
"method": "hybrid",
"top_k": 8,
"include_profile": True,
"enable_llm_rerank": True,
"app_id": "default",
"project_id": "default",
"filters": {"session_id": "chat:c_1"},
}
]
@pytest.mark.asyncio
async def test_search_rejects_conflicting_session_and_conversation_ids(
config: GatewayConfig,
) -> None:
backend = FakeBackendClient()
async with app_client(config, backend) as client:
user_key = await create_user(client)
response = await client.post(
"/memories/search",
json={
"user_id": "u_123",
"user_key": user_key,
"session_id": "chat:c_1",
"conversation_id": "c_2",
"query": "hello",
"scope": ["current_chat"],
},
)
assert response.status_code == 422, response.text
assert backend.search_calls == []
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_search_combines_custom_and_scope_filters( async def test_search_combines_custom_and_scope_filters(
config: GatewayConfig, config: GatewayConfig,