diff --git a/README.md b/README.md index b0f81e1..c0967d5 100644 --- a/README.md +++ b/README.md @@ -486,7 +486,7 @@ curl -X POST http://127.0.0.1:8010/memories/search \ -d '{ "user_id": "u_123", "user_key": "uk_xxx", - "conversation_id": "c_456", + "session_id": "chat:c_456", "query": "图片里的蓝色圆形在哪里?", "scope": ["current_chat", "resources"], "method": "hybrid", @@ -502,7 +502,7 @@ curl -X POST http://127.0.0.1:8010/memories/search \ | scope | 行为 | |---|---| -| `current_chat` | 搜索 `chat:{conversation_id}`,需要传 `conversation_id` | +| `current_chat` | 搜索指定聊天 session,推荐传完整 `session_id`,如 `chat:c_456`;也兼容旧字段 `conversation_id`,会转换为 `chat:{conversation_id}` | | `resources` | 搜索当前用户已提取且未删除的资源 session | | `all_user_memory` | 搜索用户全部记忆,不加 session 过滤 | diff --git a/core/api.py b/core/api.py index baec889..806366c 100644 --- a/core/api.py +++ b/core/api.py @@ -10,7 +10,7 @@ from urllib.parse import parse_qsl, quote, urlsplit, urlunsplit import httpx from fastapi import APIRouter, FastAPI, File, Form, HTTPException, Request, UploadFile from pydantic import ValidationError -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, Field, field_validator, model_validator from starlette.datastructures import UploadFile as StarletteUploadFile from starlette.responses import Response @@ -43,6 +43,7 @@ class SearchMemoriesRequest(BaseModel): user_id: str = Field(min_length=1) user_key: str = Field(min_length=1) agent_id: str | None = Field(default=None, min_length=1) + session_id: str | None = Field(default=None, min_length=1) conversation_id: str | None = None query: str = Field(min_length=1) scope: list[Literal["current_chat", "resources", "all_user_memory"]] = Field( @@ -64,6 +65,16 @@ class SearchMemoriesRequest(BaseModel): raise ValueError("top_k must be -1 or in 1..100") return value + @model_validator(mode="after") + def validate_current_chat_session_aliases(self) -> SearchMemoriesRequest: + if self.session_id and self.conversation_id: + expected_session_id = f"chat:{self.conversation_id}" + if self.session_id != expected_session_id: + raise ValueError( + "session_id must match chat:{conversation_id} when both are provided" + ) + return self + class AddMemoryMessage(BaseModel): sender_id: str = Field(min_length=1) @@ -487,6 +498,7 @@ def create_app( user_id=request.user_id, agent_id=request.agent_id, query=request.query, + session_id=request.session_id, conversation_id=request.conversation_id, scope=request.scope, method=request.method, diff --git a/core/service.py b/core/service.py index cebbce8..3169334 100644 --- a/core/service.py +++ b/core/service.py @@ -502,6 +502,7 @@ class MemoryGatewayService: user_id: str, agent_id: str | None, query: str, + session_id: str | None, conversation_id: str | None, scope: list[str], method: str, @@ -515,8 +516,11 @@ class MemoryGatewayService: ) -> dict[str, Any]: results: list[dict[str, Any]] = [] session_resource_map: dict[str, dict[str, Any]] = {} + current_chat_session_id = session_id + if current_chat_session_id is None and conversation_id: + current_chat_session_id = f"chat:{conversation_id}" - if "current_chat" in scope and conversation_id: + if "current_chat" in scope and current_chat_session_id: payload = self._search_payload( user_id=user_id, agent_id=agent_id, @@ -530,7 +534,7 @@ class MemoryGatewayService: project_id=project_id, filters=_combine_filters( filters, - {"session_id": f"chat:{conversation_id}"}, + {"session_id": current_chat_session_id}, ), ) results.extend( diff --git a/tests/test_gateway.py b/tests/test_gateway.py index 47d7b54..41e141a 100644 --- a/tests/test_gateway.py +++ b/tests/test_gateway.py @@ -1181,6 +1181,63 @@ async def test_search_rejects_invalid_upstream_options( assert backend.search_calls == [] +@pytest.mark.asyncio +async def test_search_current_chat_accepts_session_id( + config: GatewayConfig, +) -> None: + backend = FakeBackendClient() + async with app_client(config, backend) as client: + user_key = await create_user(client) + response = await client.post( + "/memories/search", + json={ + "user_id": "u_123", + "user_key": user_key, + "session_id": "chat:c_1", + "query": "hello", + "scope": ["current_chat"], + }, + ) + + assert response.status_code == 200, response.text + assert backend.search_calls == [ + { + "user_id": "u_123", + "query": "hello", + "method": "hybrid", + "top_k": 8, + "include_profile": True, + "enable_llm_rerank": True, + "app_id": "default", + "project_id": "default", + "filters": {"session_id": "chat:c_1"}, + } + ] + + +@pytest.mark.asyncio +async def test_search_rejects_conflicting_session_and_conversation_ids( + config: GatewayConfig, +) -> None: + backend = FakeBackendClient() + async with app_client(config, backend) as client: + user_key = await create_user(client) + response = await client.post( + "/memories/search", + json={ + "user_id": "u_123", + "user_key": user_key, + "session_id": "chat:c_1", + "conversation_id": "c_2", + "query": "hello", + "scope": ["current_chat"], + }, + ) + + assert response.status_code == 422, response.text + assert backend.search_calls == [] + + @pytest.mark.asyncio async def test_search_combines_custom_and_scope_filters( config: GatewayConfig,