"""OpenViking client wrapper used by Memory Gateway.""" from __future__ import annotations import json import logging import mimetypes import tempfile from pathlib import Path from typing import Any, Optional import httpx from .config import get_config from .types import MemoryEntry, ResourceEntry, SearchResult logger = logging.getLogger(__name__) class OpenVikingClient: """Thin async client for the OpenViking HTTP API.""" def __init__( self, base_url: Optional[str] = None, api_key: Optional[str] = None, timeout: int = 30, account: str = "default", user: str = "default", ): self.config = get_config() self.base_url = base_url or self.config.openviking.url self.api_key = api_key or self.config.openviking.api_key or "your-secret-root-key" self.timeout = timeout self.account = account self.user = user self._client: Optional[httpx.AsyncClient] = None def _get_headers(self) -> dict[str, str]: headers = {} if self.api_key: headers["X-API-Key"] = self.api_key headers["X-OpenViking-Account"] = self.account headers["X-OpenViking-User"] = self.user return headers async def _get_client(self) -> httpx.AsyncClient: if self._client is None: self._client = httpx.AsyncClient( base_url=self.base_url, headers=self._get_headers(), timeout=self.timeout, ) return self._client async def close(self): if self._client: await self._client.aclose() self._client = None async def health_check(self) -> dict[str, Any]: client = await self._get_client() try: response = await client.get("/health") response.raise_for_status() return response.json() except httpx.HTTPError as e: logger.error(f"OpenViking 健康检查失败: {e}") return {"status": "error", "message": str(e)} async def search( self, query: str, namespace: Optional[str] = None, limit: Optional[int] = None, uri: Optional[str] = None, ) -> SearchResult: """Semantic search against OpenViking resources/memories.""" client = await self._get_client() payload: dict[str, Any] = {"query": query} if limit: payload["limit"] = limit if uri: payload["uri"] = uri elif namespace: payload["uri"] = f"viking://{namespace}" try: response = await client.post("/api/v1/search/search", json=payload) response.raise_for_status() data = response.json() if data.get("status") != "ok": logger.warning(f"搜索返回错误: {data.get('error')}") return SearchResult(results=[], total=0) result = data.get("result", {}) memories = result.get("memories", []) resources = result.get("resources", []) all_results = [] for m in memories + resources: all_results.append( { "uri": m.get("uri"), "abstract": m.get("abstract"), "score": m.get("score"), "context_type": m.get("context_type"), } ) return SearchResult(results=all_results, total=result.get("total", len(all_results))) except httpx.HTTPError as e: logger.error(f"搜索失败: {e}") return SearchResult(results=[], total=0) async def add_memory( self, content: str, namespace: Optional[str] = None, memory_type: str = "general", ) -> dict[str, Any]: """Add memory via session commit flow.""" client = await self._get_client() ns = namespace or self.config.memory.default_namespace or "user/default/memories" try: response = await client.post("/api/v1/sessions", json={"mode": "interactive"}) response.raise_for_status() session_data = response.json() if session_data.get("status") != "ok": return session_data session_id = session_data["result"]["session_id"] commit_response = await client.post( f"/api/v1/sessions/{session_id}/commit", json={ "messages": [ { "role": "user", "content": f"[{ns}/{memory_type}] {content}", } ] }, ) commit_response.raise_for_status() return commit_response.json() except httpx.HTTPError as e: logger.error(f"添加记忆失败: {e}") raise async def _upload_temp_file(self, file_path: str | Path) -> str: client = await self._get_client() file_path = Path(file_path) mime_type = mimetypes.guess_type(file_path.name)[0] or "application/octet-stream" with file_path.open("rb") as f: response = await client.post( "/api/v1/resources/temp_upload", files={"file": (file_path.name, f, mime_type)}, ) response.raise_for_status() data = response.json() result = data.get("result", {}) if "temp_path" in result: return result["temp_path"] if "temp_file_id" in result: return result["temp_file_id"] raise KeyError(f"Unexpected temp upload response: {data}") async def add_resource( self, uri: str, content: str, resource_type: str = "text", wait: bool = False, ) -> dict[str, Any]: """Add a text/json resource by uploading a temporary file first. OpenViking HTTP API does not accept raw `uri + content` directly. The client must upload a temp file and then create the resource with `to`. """ client = await self._get_client() suffix_map = { "json": ".json", "text": ".txt", "markdown": ".md", "md": ".md", } suffix = suffix_map.get(resource_type, ".txt") with tempfile.NamedTemporaryFile("w", encoding="utf-8", suffix=suffix, delete=False) as tmp: tmp.write(content) tmp_path = Path(tmp.name) try: temp_ref = await self._upload_temp_file(tmp_path) payload = { "temp_path": temp_ref, "to": uri, "wait": wait, "source_name": Path(uri).name or tmp_path.name, "strict": False, } response = await client.post("/api/v1/resources", json=payload) if response.status_code >= 400: logger.error("添加资源失败响应: %s", response.text) response.raise_for_status() return response.json() except httpx.HTTPError as e: logger.error(f"添加资源失败: {e}") raise finally: tmp_path.unlink(missing_ok=True) async def list_memories( self, namespace: Optional[str] = None, memory_type: Optional[str] = None, limit: Optional[int] = None, ) -> list[MemoryEntry]: client = await self._get_client() ns = namespace or "user/default/memories" if memory_type: ns = f"{ns}/{memory_type}" try: response = await client.post( "/api/v1/search/search", json={"query": "", "uri": f"viking://{ns}", "limit": limit or 10}, ) response.raise_for_status() data = response.json() if data.get("status") == "ok": result = data.get("result", {}) memories = result.get("memories", []) return [ MemoryEntry( id=m.get("uri", ""), content=m.get("abstract", ""), namespace=ns, memory_type=memory_type or "general", ) for m in memories ] return [] except httpx.HTTPError as e: logger.error(f"列出记忆失败: {e}") return [] async def list_resources( self, namespace: Optional[str] = None, limit: Optional[int] = None, ) -> list[ResourceEntry]: client = await self._get_client() uri = f"viking://{namespace}" if namespace else "viking://resources" try: response = await client.post( "/api/v1/search/search", json={"query": "", "uri": uri, "limit": limit or 10}, ) response.raise_for_status() data = response.json() if data.get("status") == "ok": result = data.get("result", {}) resources = result.get("resources", []) return [ ResourceEntry( uri=r.get("uri", ""), content=r.get("abstract", ""), resource_type="text", ) for r in resources ] return [] except httpx.HTTPError as e: logger.error(f"列出资源失败: {e}") return [] _client: Optional[OpenVikingClient] = None async def get_openviking_client() -> OpenVikingClient: global _client if _client is None: _client = OpenVikingClient() return _client async def close_openviking_client(): global _client if _client: await _client.close() _client = None