Merge pull request 'feat: support session_id in current chat search' (#1) from codex/support-search-session-id into main
Reviewed-on: #1
This commit is contained in:
@ -486,7 +486,7 @@ curl -X POST http://127.0.0.1:8010/memories/search \
|
||||
-d '{
|
||||
"user_id": "u_123",
|
||||
"user_key": "uk_xxx",
|
||||
"conversation_id": "c_456",
|
||||
"session_id": "chat:c_456",
|
||||
"query": "图片里的蓝色圆形在哪里?",
|
||||
"scope": ["current_chat", "resources"],
|
||||
"method": "hybrid",
|
||||
@ -502,7 +502,7 @@ curl -X POST http://127.0.0.1:8010/memories/search \
|
||||
|
||||
| scope | 行为 |
|
||||
|---|---|
|
||||
| `current_chat` | 搜索 `chat:{conversation_id}`,需要传 `conversation_id` |
|
||||
| `current_chat` | 搜索指定聊天 session,推荐传完整 `session_id`,如 `chat:c_456`;也兼容旧字段 `conversation_id`,会转换为 `chat:{conversation_id}` |
|
||||
| `resources` | 搜索当前用户已提取且未删除的资源 session |
|
||||
| `all_user_memory` | 搜索用户全部记忆,不加 session 过滤 |
|
||||
|
||||
|
||||
14
core/api.py
14
core/api.py
@ -10,7 +10,7 @@ from urllib.parse import parse_qsl, quote, urlsplit, urlunsplit
|
||||
import httpx
|
||||
from fastapi import APIRouter, FastAPI, File, Form, HTTPException, Request, UploadFile
|
||||
from pydantic import ValidationError
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
from starlette.datastructures import UploadFile as StarletteUploadFile
|
||||
from starlette.responses import Response
|
||||
|
||||
@ -43,6 +43,7 @@ class SearchMemoriesRequest(BaseModel):
|
||||
user_id: str = Field(min_length=1)
|
||||
user_key: str = Field(min_length=1)
|
||||
agent_id: str | None = Field(default=None, min_length=1)
|
||||
session_id: str | None = Field(default=None, min_length=1)
|
||||
conversation_id: str | None = None
|
||||
query: str = Field(min_length=1)
|
||||
scope: list[Literal["current_chat", "resources", "all_user_memory"]] = Field(
|
||||
@ -64,6 +65,16 @@ class SearchMemoriesRequest(BaseModel):
|
||||
raise ValueError("top_k must be -1 or in 1..100")
|
||||
return value
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_current_chat_session_aliases(self) -> SearchMemoriesRequest:
|
||||
if self.session_id and self.conversation_id:
|
||||
expected_session_id = f"chat:{self.conversation_id}"
|
||||
if self.session_id != expected_session_id:
|
||||
raise ValueError(
|
||||
"session_id must match chat:{conversation_id} when both are provided"
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
class AddMemoryMessage(BaseModel):
|
||||
sender_id: str = Field(min_length=1)
|
||||
@ -487,6 +498,7 @@ def create_app(
|
||||
user_id=request.user_id,
|
||||
agent_id=request.agent_id,
|
||||
query=request.query,
|
||||
session_id=request.session_id,
|
||||
conversation_id=request.conversation_id,
|
||||
scope=request.scope,
|
||||
method=request.method,
|
||||
|
||||
@ -502,6 +502,7 @@ class MemoryGatewayService:
|
||||
user_id: str,
|
||||
agent_id: str | None,
|
||||
query: str,
|
||||
session_id: str | None,
|
||||
conversation_id: str | None,
|
||||
scope: list[str],
|
||||
method: str,
|
||||
@ -515,8 +516,11 @@ class MemoryGatewayService:
|
||||
) -> dict[str, Any]:
|
||||
results: list[dict[str, Any]] = []
|
||||
session_resource_map: dict[str, dict[str, Any]] = {}
|
||||
current_chat_session_id = session_id
|
||||
if current_chat_session_id is None and conversation_id:
|
||||
current_chat_session_id = f"chat:{conversation_id}"
|
||||
|
||||
if "current_chat" in scope and conversation_id:
|
||||
if "current_chat" in scope and current_chat_session_id:
|
||||
payload = self._search_payload(
|
||||
user_id=user_id,
|
||||
agent_id=agent_id,
|
||||
@ -530,7 +534,7 @@ class MemoryGatewayService:
|
||||
project_id=project_id,
|
||||
filters=_combine_filters(
|
||||
filters,
|
||||
{"session_id": f"chat:{conversation_id}"},
|
||||
{"session_id": current_chat_session_id},
|
||||
),
|
||||
)
|
||||
results.extend(
|
||||
|
||||
@ -1181,6 +1181,63 @@ async def test_search_rejects_invalid_upstream_options(
|
||||
assert backend.search_calls == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_current_chat_accepts_session_id(
|
||||
config: GatewayConfig,
|
||||
) -> None:
|
||||
backend = FakeBackendClient()
|
||||
async with app_client(config, backend) as client:
|
||||
user_key = await create_user(client)
|
||||
response = await client.post(
|
||||
"/memories/search",
|
||||
json={
|
||||
"user_id": "u_123",
|
||||
"user_key": user_key,
|
||||
"session_id": "chat:c_1",
|
||||
"query": "hello",
|
||||
"scope": ["current_chat"],
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200, response.text
|
||||
assert backend.search_calls == [
|
||||
{
|
||||
"user_id": "u_123",
|
||||
"query": "hello",
|
||||
"method": "hybrid",
|
||||
"top_k": 8,
|
||||
"include_profile": True,
|
||||
"enable_llm_rerank": True,
|
||||
"app_id": "default",
|
||||
"project_id": "default",
|
||||
"filters": {"session_id": "chat:c_1"},
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_rejects_conflicting_session_and_conversation_ids(
|
||||
config: GatewayConfig,
|
||||
) -> None:
|
||||
backend = FakeBackendClient()
|
||||
async with app_client(config, backend) as client:
|
||||
user_key = await create_user(client)
|
||||
response = await client.post(
|
||||
"/memories/search",
|
||||
json={
|
||||
"user_id": "u_123",
|
||||
"user_key": user_key,
|
||||
"session_id": "chat:c_1",
|
||||
"conversation_id": "c_2",
|
||||
"query": "hello",
|
||||
"scope": ["current_chat"],
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 422, response.text
|
||||
assert backend.search_calls == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_combines_custom_and_scope_filters(
|
||||
config: GatewayConfig,
|
||||
|
||||
Reference in New Issue
Block a user