feat: support session_id in current chat search #1
@ -486,7 +486,7 @@ curl -X POST http://127.0.0.1:8010/memories/search \
|
|||||||
-d '{
|
-d '{
|
||||||
"user_id": "u_123",
|
"user_id": "u_123",
|
||||||
"user_key": "uk_xxx",
|
"user_key": "uk_xxx",
|
||||||
"conversation_id": "c_456",
|
"session_id": "chat:c_456",
|
||||||
"query": "图片里的蓝色圆形在哪里?",
|
"query": "图片里的蓝色圆形在哪里?",
|
||||||
"scope": ["current_chat", "resources"],
|
"scope": ["current_chat", "resources"],
|
||||||
"method": "hybrid",
|
"method": "hybrid",
|
||||||
@ -502,7 +502,7 @@ curl -X POST http://127.0.0.1:8010/memories/search \
|
|||||||
|
|
||||||
| scope | 行为 |
|
| scope | 行为 |
|
||||||
|---|---|
|
|---|---|
|
||||||
| `current_chat` | 搜索 `chat:{conversation_id}`,需要传 `conversation_id` |
|
| `current_chat` | 搜索指定聊天 session,推荐传完整 `session_id`,如 `chat:c_456`;也兼容旧字段 `conversation_id`,会转换为 `chat:{conversation_id}` |
|
||||||
| `resources` | 搜索当前用户已提取且未删除的资源 session |
|
| `resources` | 搜索当前用户已提取且未删除的资源 session |
|
||||||
| `all_user_memory` | 搜索用户全部记忆,不加 session 过滤 |
|
| `all_user_memory` | 搜索用户全部记忆,不加 session 过滤 |
|
||||||
|
|
||||||
|
|||||||
14
core/api.py
14
core/api.py
@ -10,7 +10,7 @@ from urllib.parse import parse_qsl, quote, urlsplit, urlunsplit
|
|||||||
import httpx
|
import httpx
|
||||||
from fastapi import APIRouter, FastAPI, File, Form, HTTPException, Request, UploadFile
|
from fastapi import APIRouter, FastAPI, File, Form, HTTPException, Request, UploadFile
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||||
from starlette.datastructures import UploadFile as StarletteUploadFile
|
from starlette.datastructures import UploadFile as StarletteUploadFile
|
||||||
from starlette.responses import Response
|
from starlette.responses import Response
|
||||||
|
|
||||||
@ -43,6 +43,7 @@ class SearchMemoriesRequest(BaseModel):
|
|||||||
user_id: str = Field(min_length=1)
|
user_id: str = Field(min_length=1)
|
||||||
user_key: str = Field(min_length=1)
|
user_key: str = Field(min_length=1)
|
||||||
agent_id: str | None = Field(default=None, min_length=1)
|
agent_id: str | None = Field(default=None, min_length=1)
|
||||||
|
session_id: str | None = Field(default=None, min_length=1)
|
||||||
conversation_id: str | None = None
|
conversation_id: str | None = None
|
||||||
query: str = Field(min_length=1)
|
query: str = Field(min_length=1)
|
||||||
scope: list[Literal["current_chat", "resources", "all_user_memory"]] = Field(
|
scope: list[Literal["current_chat", "resources", "all_user_memory"]] = Field(
|
||||||
@ -64,6 +65,16 @@ class SearchMemoriesRequest(BaseModel):
|
|||||||
raise ValueError("top_k must be -1 or in 1..100")
|
raise ValueError("top_k must be -1 or in 1..100")
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def validate_current_chat_session_aliases(self) -> SearchMemoriesRequest:
|
||||||
|
if self.session_id and self.conversation_id:
|
||||||
|
expected_session_id = f"chat:{self.conversation_id}"
|
||||||
|
if self.session_id != expected_session_id:
|
||||||
|
raise ValueError(
|
||||||
|
"session_id must match chat:{conversation_id} when both are provided"
|
||||||
|
)
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
class AddMemoryMessage(BaseModel):
|
class AddMemoryMessage(BaseModel):
|
||||||
sender_id: str = Field(min_length=1)
|
sender_id: str = Field(min_length=1)
|
||||||
@ -487,6 +498,7 @@ def create_app(
|
|||||||
user_id=request.user_id,
|
user_id=request.user_id,
|
||||||
agent_id=request.agent_id,
|
agent_id=request.agent_id,
|
||||||
query=request.query,
|
query=request.query,
|
||||||
|
session_id=request.session_id,
|
||||||
conversation_id=request.conversation_id,
|
conversation_id=request.conversation_id,
|
||||||
scope=request.scope,
|
scope=request.scope,
|
||||||
method=request.method,
|
method=request.method,
|
||||||
|
|||||||
@ -502,6 +502,7 @@ class MemoryGatewayService:
|
|||||||
user_id: str,
|
user_id: str,
|
||||||
agent_id: str | None,
|
agent_id: str | None,
|
||||||
query: str,
|
query: str,
|
||||||
|
session_id: str | None,
|
||||||
conversation_id: str | None,
|
conversation_id: str | None,
|
||||||
scope: list[str],
|
scope: list[str],
|
||||||
method: str,
|
method: str,
|
||||||
@ -515,8 +516,11 @@ class MemoryGatewayService:
|
|||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
results: list[dict[str, Any]] = []
|
results: list[dict[str, Any]] = []
|
||||||
session_resource_map: dict[str, dict[str, Any]] = {}
|
session_resource_map: dict[str, dict[str, Any]] = {}
|
||||||
|
current_chat_session_id = session_id
|
||||||
|
if current_chat_session_id is None and conversation_id:
|
||||||
|
current_chat_session_id = f"chat:{conversation_id}"
|
||||||
|
|
||||||
if "current_chat" in scope and conversation_id:
|
if "current_chat" in scope and current_chat_session_id:
|
||||||
payload = self._search_payload(
|
payload = self._search_payload(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
agent_id=agent_id,
|
agent_id=agent_id,
|
||||||
@ -530,7 +534,7 @@ class MemoryGatewayService:
|
|||||||
project_id=project_id,
|
project_id=project_id,
|
||||||
filters=_combine_filters(
|
filters=_combine_filters(
|
||||||
filters,
|
filters,
|
||||||
{"session_id": f"chat:{conversation_id}"},
|
{"session_id": current_chat_session_id},
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
results.extend(
|
results.extend(
|
||||||
|
|||||||
@ -1181,6 +1181,63 @@ async def test_search_rejects_invalid_upstream_options(
|
|||||||
assert backend.search_calls == []
|
assert backend.search_calls == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_search_current_chat_accepts_session_id(
|
||||||
|
config: GatewayConfig,
|
||||||
|
) -> None:
|
||||||
|
backend = FakeBackendClient()
|
||||||
|
async with app_client(config, backend) as client:
|
||||||
|
user_key = await create_user(client)
|
||||||
|
response = await client.post(
|
||||||
|
"/memories/search",
|
||||||
|
json={
|
||||||
|
"user_id": "u_123",
|
||||||
|
"user_key": user_key,
|
||||||
|
"session_id": "chat:c_1",
|
||||||
|
"query": "hello",
|
||||||
|
"scope": ["current_chat"],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200, response.text
|
||||||
|
assert backend.search_calls == [
|
||||||
|
{
|
||||||
|
"user_id": "u_123",
|
||||||
|
"query": "hello",
|
||||||
|
"method": "hybrid",
|
||||||
|
"top_k": 8,
|
||||||
|
"include_profile": True,
|
||||||
|
"enable_llm_rerank": True,
|
||||||
|
"app_id": "default",
|
||||||
|
"project_id": "default",
|
||||||
|
"filters": {"session_id": "chat:c_1"},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_search_rejects_conflicting_session_and_conversation_ids(
|
||||||
|
config: GatewayConfig,
|
||||||
|
) -> None:
|
||||||
|
backend = FakeBackendClient()
|
||||||
|
async with app_client(config, backend) as client:
|
||||||
|
user_key = await create_user(client)
|
||||||
|
response = await client.post(
|
||||||
|
"/memories/search",
|
||||||
|
json={
|
||||||
|
"user_id": "u_123",
|
||||||
|
"user_key": user_key,
|
||||||
|
"session_id": "chat:c_1",
|
||||||
|
"conversation_id": "c_2",
|
||||||
|
"query": "hello",
|
||||||
|
"scope": ["current_chat"],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 422, response.text
|
||||||
|
assert backend.search_calls == []
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_search_combines_custom_and_scope_filters(
|
async def test_search_combines_custom_and_scope_filters(
|
||||||
config: GatewayConfig,
|
config: GatewayConfig,
|
||||||
|
|||||||
Reference in New Issue
Block a user