467 lines
14 KiB
Python
Executable File
467 lines
14 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
import mimetypes
|
|
import os
|
|
import sys
|
|
import uuid
|
|
from pathlib import Path
|
|
from typing import Any
|
|
from urllib.error import HTTPError, URLError
|
|
from urllib.parse import urlencode
|
|
from urllib.request import Request, urlopen
|
|
|
|
|
|
class GatewayError(RuntimeError):
    """Failure reported by the Memory Gateway client or CLI.

    Raised for HTTP error statuses, connection problems, unexpected response
    bodies, and invalid local input (missing credentials, missing upload file,
    malformed messages JSON, unsupported command). ``main`` catches it and
    prints it as a JSON ``{"error": ...}`` object on stderr.
    """
|
|
|
|
|
|
class Settings:
    """Connection settings for the Memory Gateway CLI.

    Attributes:
        base_url: gateway root URL with trailing ``/`` characters stripped.
        user_id: user id for authenticated calls, or ``None``.
        user_key: user key for authenticated calls, or ``None``.
        timeout: per-request timeout in seconds.
    """

    def __init__(
        self,
        base_url: str,
        user_id: str | None,
        user_key: str | None,
        timeout: float = 120.0,
    ) -> None:
        self.base_url = base_url.rstrip("/")
        self.user_id = user_id
        self.user_key = user_key
        self.timeout = timeout

    @classmethod
    def from_env(cls) -> Settings:
        """Build settings from the ``MEMORY_GATEWAY_*`` environment variables.

        Reads ``MEMORY_GATEWAY_BASE_URL`` (default ``http://127.0.0.1:8010``),
        ``MEMORY_GATEWAY_USER_ID``, ``MEMORY_GATEWAY_USER_KEY`` and
        ``MEMORY_GATEWAY_TIMEOUT_SECONDS`` (default 120). A variable that is
        set but empty is treated the same as an unset one.

        Raises:
            ValueError: if ``MEMORY_GATEWAY_TIMEOUT_SECONDS`` is not a finite
                number greater than zero.
        """
        raw_timeout = os.environ.get("MEMORY_GATEWAY_TIMEOUT_SECONDS") or "120"
        try:
            timeout = float(raw_timeout)
        except ValueError:
            timeout = float("nan")  # rejected by the range check below
        # A zero/negative/NaN/inf timeout would only fail later inside the
        # socket layer with a far less helpful message.
        if not 0 < timeout < float("inf"):
            raise ValueError(
                "MEMORY_GATEWAY_TIMEOUT_SECONDS must be a positive number of seconds, "
                f"got {raw_timeout!r}"
            )
        return cls(
            base_url=os.environ.get("MEMORY_GATEWAY_BASE_URL") or "http://127.0.0.1:8010",
            user_id=os.environ.get("MEMORY_GATEWAY_USER_ID") or None,
            user_key=os.environ.get("MEMORY_GATEWAY_USER_KEY") or None,
            timeout=timeout,
        )
|
|
|
|
|
|
class MemoryGatewayClient:
    """JSON-over-HTTP client for the Memory Gateway REST API.

    Every endpoint except ``health`` and ``create_user`` authenticates with
    the ``user_id``/``user_key`` pair given at construction time. Those are
    sent as multipart form fields, query parameters, or JSON body fields,
    depending on the endpoint. All failures surface as ``GatewayError``.
    """

    def __init__(
        self,
        base_url: str,
        *,
        user_id: str | None = None,
        user_key: str | None = None,
        timeout: float = 120.0,
    ) -> None:
        self.base_url = base_url.rstrip("/")
        self.user_id = user_id
        self.user_key = user_key
        self.timeout = timeout

    def _credentials(self) -> dict[str, str]:
        """Return the auth fields that authenticated endpoints expect.

        Raises:
            GatewayError: if the user id or user key is missing or empty.
        """
        if not self.user_id or not self.user_key:
            raise GatewayError(
                "user credentials are required; set MEMORY_GATEWAY_USER_ID and "
                "MEMORY_GATEWAY_USER_KEY or pass --user-id and --user-key"
            )
        return {"user_id": self.user_id, "user_key": self.user_key}

    @staticmethod
    def _segment(value: str) -> str:
        """Percent-encode *value* so it forms exactly one URL path segment.

        Without escaping, an id containing ``/``, ``?``, ``#`` or spaces would
        change which endpoint is hit or truncate the path.
        """
        from urllib.parse import quote

        return quote(str(value), safe="")

    def _request(
        self,
        method: str,
        path: str,
        *,
        query: dict[str, Any] | None = None,
        json_body: dict[str, Any] | None = None,
        body: bytes | None = None,
        headers: dict[str, str] | None = None,
    ) -> dict[str, Any]:
        """Send one HTTP request and decode the JSON object it returns.

        Args:
            method: HTTP verb.
            path: path appended verbatim to ``base_url`` (should start with ``/``).
            query: optional query parameters; sequence values are repeated.
            json_body: optional JSON payload; takes precedence over *body*
                and sets ``Content-Type: application/json``.
            body: optional pre-encoded request body (e.g. multipart data).
            headers: optional extra request headers.

        Returns:
            The decoded JSON object, or ``{}`` for an empty response body.

        Raises:
            GatewayError: on a malformed base URL, an HTTP error status, a
                network/timeout failure, or a response that is not a JSON object.
        """
        url = f"{self.base_url}{path}"
        if query:
            url = f"{url}?{urlencode(query, doseq=True)}"
        request_headers = dict(headers or {})
        request_body = body
        if json_body is not None:
            request_body = json.dumps(json_body, ensure_ascii=False).encode("utf-8")
            request_headers["Content-Type"] = "application/json"
        try:
            request = Request(url, data=request_body, headers=request_headers, method=method)
        except ValueError as exc:
            # e.g. "unknown url type" when the base URL has no scheme. The
            # original message embeds the full URL (including any credential
            # query string), so report only the base URL.
            raise GatewayError(f"invalid Memory Gateway base URL: {self.base_url!r}") from exc
        try:
            with urlopen(request, timeout=self.timeout) as response:
                raw = response.read()
        except HTTPError as exc:
            try:
                raw = exc.read()
            finally:
                exc.close()
            detail = _error_detail(raw, exc.reason)
            raise GatewayError(f"Memory Gateway returned HTTP {exc.code}: {detail}") from exc
        except URLError as exc:
            raise GatewayError(f"cannot connect to Memory Gateway: {exc.reason}") from exc
        except OSError as exc:
            # Read timeouts and connection resets after connect are raised as
            # plain OSError subclasses (e.g. TimeoutError), not URLError.
            raise GatewayError(f"Memory Gateway request failed: {exc}") from exc
        if not raw:
            return {}
        try:
            value = json.loads(raw.decode("utf-8"))
        except (UnicodeDecodeError, json.JSONDecodeError) as exc:
            raise GatewayError("Memory Gateway returned a non-JSON response") from exc
        if not isinstance(value, dict):
            raise GatewayError("Memory Gateway returned an unexpected JSON response")
        return value

    def health(self) -> dict[str, Any]:
        """``GET /health`` (unauthenticated)."""
        return self._request("GET", "/health")

    def create_user(self, user_id: str) -> dict[str, Any]:
        """``POST /users`` with ``{"user_id": user_id}`` (unauthenticated)."""
        return self._request("POST", "/users", json_body={"user_id": user_id})

    def upload_resource(
        self,
        file_path: Path,
        *,
        app_id: str = "default",
        project_id: str = "default",
        title: str | None = None,
        description: str | None = None,
    ) -> dict[str, Any]:
        """Upload *file_path* to ``POST /resources`` as multipart form data.

        The credentials, ``app_id`` and ``project_id`` (plus ``title`` and
        ``description`` when given) are sent as form fields next to the
        ``file`` part, whose MIME type is guessed from the file name.

        Raises:
            GatewayError: if the file is missing or unreadable, or the request fails.
        """
        if not file_path.is_file():
            raise GatewayError(f"upload file does not exist: {file_path}")
        fields: dict[str, str] = {
            **self._credentials(),
            "app_id": app_id,
            "project_id": project_id,
        }
        if title is not None:
            fields["title"] = title
        if description is not None:
            fields["description"] = description
        boundary = f"memory-gateway-{uuid.uuid4().hex}"
        mime_type = mimetypes.guess_type(file_path.name)[0] or "application/octet-stream"
        try:
            body = _multipart_body(
                boundary,
                fields,
                field_name="file",
                file_path=file_path,
                mime_type=mime_type,
            )
        except OSError as exc:
            # e.g. permission denied, or the file vanished after is_file().
            raise GatewayError(f"cannot read upload file {file_path}: {exc}") from exc
        return self._request(
            "POST",
            "/resources",
            body=body,
            headers={"Content-Type": f"multipart/form-data; boundary={boundary}"},
        )

    def list_resources(self) -> dict[str, Any]:
        """``GET /resources`` with the credentials as query parameters."""
        return self._request("GET", "/resources", query=self._credentials())

    def get_resource(self, resource_id: str) -> dict[str, Any]:
        """``GET /resources/{resource_id}`` with the credentials as query parameters."""
        return self._request(
            "GET",
            f"/resources/{self._segment(resource_id)}",
            query=self._credentials(),
        )

    def delete_resource(self, resource_id: str) -> dict[str, Any]:
        """``DELETE /resources/{resource_id}`` with the credentials as query parameters."""
        return self._request(
            "DELETE",
            f"/resources/{self._segment(resource_id)}",
            query=self._credentials(),
        )

    def search(
        self,
        query: str,
        *,
        conversation_id: str | None = None,
        scopes: list[str] | None = None,
        top_k: int = 8,
        app_id: str = "default",
        project_id: str = "default",
    ) -> dict[str, Any]:
        """Search memories via ``POST /memories/search``.

        When *scopes* is empty or ``None`` it defaults to
        ``["current_chat", "resources"]`` if a *conversation_id* is given,
        otherwise ``["resources"]``.

        Raises:
            GatewayError: if ``current_chat`` is requested without a
                *conversation_id*, or the request fails.
        """
        selected_scopes = scopes or (
            ["current_chat", "resources"] if conversation_id else ["resources"]
        )
        if "current_chat" in selected_scopes and not conversation_id:
            raise GatewayError(
                "conversation_id is required when search scope includes current_chat"
            )
        payload: dict[str, Any] = {
            **self._credentials(),
            "query": query,
            "scope": selected_scopes,
            "top_k": top_k,
            "app_id": app_id,
            "project_id": project_id,
        }
        if conversation_id is not None:
            payload["conversation_id"] = conversation_id
        return self._request("POST", "/memories/search", json_body=payload)

    def add_memory(
        self,
        session_id: str,
        messages: list[dict[str, Any]],
        *,
        app_id: str = "default",
        project_id: str = "default",
    ) -> dict[str, Any]:
        """Send *messages* for *session_id* to ``POST /memories/add``."""
        return self._request(
            "POST",
            "/memories/add",
            json_body={
                **self._credentials(),
                "session_id": session_id,
                "messages": messages,
                "app_id": app_id,
                "project_id": project_id,
            },
        )

    def flush_memory(
        self,
        session_id: str,
        *,
        app_id: str = "default",
        project_id: str = "default",
    ) -> dict[str, Any]:
        """Call ``POST /memories/flush`` for *session_id*."""
        return self._request(
            "POST",
            "/memories/flush",
            json_body={
                **self._credentials(),
                "session_id": session_id,
                "app_id": app_id,
                "project_id": project_id,
            },
        )

    def override_memory(
        self,
        memory_id: str,
        session_id: str,
        override_text: str,
    ) -> dict[str, Any]:
        """Replace a memory's text via ``PATCH /memories/{memory_id}``."""
        return self._request(
            "PATCH",
            f"/memories/{self._segment(memory_id)}",
            json_body={
                **self._credentials(),
                "session_id": session_id,
                "override_text": override_text,
            },
        )

    def delete_memory(
        self,
        memory_id: str,
        session_id: str,
        *,
        reason: str | None = None,
    ) -> dict[str, Any]:
        """Delete a memory via ``DELETE /memories/{memory_id}`` with a JSON body.

        *reason* is included in the body only when it is not ``None``.
        """
        payload: dict[str, Any] = {
            **self._credentials(),
            "session_id": session_id,
        }
        if reason is not None:
            payload["reason"] = reason
        return self._request(
            "DELETE",
            f"/memories/{self._segment(memory_id)}",
            json_body=payload,
        )
|
|
|
|
|
|
def _error_detail(raw: bytes, fallback: Any) -> str:
|
|
try:
|
|
body = json.loads(raw.decode("utf-8"))
|
|
except (UnicodeDecodeError, json.JSONDecodeError):
|
|
return str(fallback)
|
|
if isinstance(body, dict) and body.get("detail"):
|
|
return str(body["detail"])
|
|
return str(fallback)
|
|
|
|
|
|
def _multipart_body(
|
|
boundary: str,
|
|
fields: dict[str, str],
|
|
*,
|
|
field_name: str,
|
|
file_path: Path,
|
|
mime_type: str,
|
|
) -> bytes:
|
|
marker = boundary.encode("ascii")
|
|
chunks: list[bytes] = []
|
|
for name, value in fields.items():
|
|
chunks.extend(
|
|
[
|
|
b"--" + marker + b"\r\n",
|
|
f'Content-Disposition: form-data; name="{name}"\r\n\r\n'.encode(),
|
|
value.encode("utf-8"),
|
|
b"\r\n",
|
|
]
|
|
)
|
|
chunks.extend(
|
|
[
|
|
b"--" + marker + b"\r\n",
|
|
(
|
|
f'Content-Disposition: form-data; name="{field_name}"; '
|
|
f'filename="{file_path.name}"\r\n'
|
|
).encode(),
|
|
f"Content-Type: {mime_type}\r\n\r\n".encode(),
|
|
file_path.read_bytes(),
|
|
b"\r\n--" + marker + b"--\r\n",
|
|
]
|
|
)
|
|
return b"".join(chunks)
|
|
|
|
|
|
def _load_json_array(value: str) -> list[dict[str, Any]]:
|
|
source = Path(value)
|
|
text = source.read_text(encoding="utf-8") if source.is_file() else value
|
|
try:
|
|
parsed = json.loads(text)
|
|
except json.JSONDecodeError as exc:
|
|
raise GatewayError(f"invalid messages JSON: {exc}") from exc
|
|
if not isinstance(parsed, list) or not all(isinstance(item, dict) for item in parsed):
|
|
raise GatewayError("messages JSON must be an array of objects")
|
|
return parsed
|
|
|
|
|
|
def build_parser() -> argparse.ArgumentParser:
|
|
parser = argparse.ArgumentParser(description="Memory Gateway agent CLI")
|
|
parser.add_argument("--base-url")
|
|
parser.add_argument("--user-id")
|
|
parser.add_argument("--user-key")
|
|
parser.add_argument("--timeout", type=float)
|
|
subparsers = parser.add_subparsers(dest="command", required=True)
|
|
|
|
subparsers.add_parser("health")
|
|
create_user = subparsers.add_parser("create-user")
|
|
create_user.add_argument("user_id")
|
|
|
|
upload = subparsers.add_parser("upload-resource")
|
|
upload.add_argument("file", type=Path)
|
|
_add_scope_arguments(upload)
|
|
upload.add_argument("--title")
|
|
upload.add_argument("--description")
|
|
|
|
subparsers.add_parser("list-resources")
|
|
get_resource = subparsers.add_parser("get-resource")
|
|
get_resource.add_argument("resource_id")
|
|
delete_resource = subparsers.add_parser("delete-resource")
|
|
delete_resource.add_argument("resource_id")
|
|
|
|
search = subparsers.add_parser("search")
|
|
search.add_argument("query")
|
|
search.add_argument("--conversation-id")
|
|
search.add_argument(
|
|
"--scope",
|
|
action="append",
|
|
choices=["current_chat", "resources", "all_user_memory"],
|
|
)
|
|
search.add_argument("--top-k", type=int, default=8)
|
|
_add_scope_arguments(search)
|
|
|
|
add = subparsers.add_parser("add-memory")
|
|
add.add_argument("--session-id", required=True)
|
|
add.add_argument(
|
|
"--messages",
|
|
required=True,
|
|
help="JSON array or path to a JSON file containing messages",
|
|
)
|
|
_add_scope_arguments(add)
|
|
|
|
flush = subparsers.add_parser("flush-memory")
|
|
flush.add_argument("--session-id", required=True)
|
|
_add_scope_arguments(flush)
|
|
|
|
override = subparsers.add_parser("override-memory")
|
|
override.add_argument("memory_id")
|
|
override.add_argument("--session-id", required=True)
|
|
override.add_argument("--text", required=True)
|
|
|
|
delete_memory = subparsers.add_parser("delete-memory")
|
|
delete_memory.add_argument("memory_id")
|
|
delete_memory.add_argument("--session-id", required=True)
|
|
delete_memory.add_argument("--reason")
|
|
return parser
|
|
|
|
|
|
def _add_scope_arguments(parser: argparse.ArgumentParser) -> None:
|
|
parser.add_argument("--app-id", default="default")
|
|
parser.add_argument("--project-id", default="default")
|
|
|
|
|
|
def _print_error(message: str) -> int:
    """Print *message* as a JSON ``{"error": ...}`` object on stderr and return 1."""
    print(json.dumps({"error": message}, ensure_ascii=False), file=sys.stderr)
    return 1


def main(argv: list[str] | None = None) -> int:
    """Run the Memory Gateway CLI and return the process exit status.

    Command-line options override the ``MEMORY_GATEWAY_*`` environment
    settings. On success the command's JSON result is pretty-printed to
    stdout and 0 is returned. On a gateway, configuration or local I/O error,
    a JSON ``{"error": ...}`` object is printed to stderr and 1 is returned.
    """
    # Parse first so --help and usage errors work even when the environment
    # holds an invalid value.
    args = build_parser().parse_args(argv)
    try:
        settings = Settings.from_env()
    except ValueError as exc:
        return _print_error(f"invalid environment configuration: {exc}")
    client = MemoryGatewayClient(
        args.base_url or settings.base_url,
        user_id=args.user_id or settings.user_id,
        user_key=args.user_key or settings.user_key,
        # Compare with None so an explicit --timeout is never silently
        # replaced by the environment value just because it is falsy.
        timeout=settings.timeout if args.timeout is None else args.timeout,
    )
    try:
        result = _run_command(client, args)
    except (GatewayError, OSError) as exc:
        # OSError covers local file access (e.g. an unreadable upload or
        # messages file) so the CLI still reports a JSON error, not a traceback.
        return _print_error(str(exc))
    print(json.dumps(result, ensure_ascii=False, indent=2))
    return 0
|
|
|
|
|
|
def _run_command(client: MemoryGatewayClient, args: argparse.Namespace) -> dict[str, Any]:
|
|
if args.command == "health":
|
|
return client.health()
|
|
if args.command == "create-user":
|
|
return client.create_user(args.user_id)
|
|
if args.command == "upload-resource":
|
|
return client.upload_resource(
|
|
args.file,
|
|
app_id=args.app_id,
|
|
project_id=args.project_id,
|
|
title=args.title,
|
|
description=args.description,
|
|
)
|
|
if args.command == "list-resources":
|
|
return client.list_resources()
|
|
if args.command == "get-resource":
|
|
return client.get_resource(args.resource_id)
|
|
if args.command == "delete-resource":
|
|
return client.delete_resource(args.resource_id)
|
|
if args.command == "search":
|
|
return client.search(
|
|
args.query,
|
|
conversation_id=args.conversation_id,
|
|
scopes=args.scope,
|
|
top_k=args.top_k,
|
|
app_id=args.app_id,
|
|
project_id=args.project_id,
|
|
)
|
|
if args.command == "add-memory":
|
|
return client.add_memory(
|
|
args.session_id,
|
|
_load_json_array(args.messages),
|
|
app_id=args.app_id,
|
|
project_id=args.project_id,
|
|
)
|
|
if args.command == "flush-memory":
|
|
return client.flush_memory(
|
|
args.session_id,
|
|
app_id=args.app_id,
|
|
project_id=args.project_id,
|
|
)
|
|
if args.command == "override-memory":
|
|
return client.override_memory(args.memory_id, args.session_id, args.text)
|
|
if args.command == "delete-memory":
|
|
return client.delete_memory(
|
|
args.memory_id,
|
|
args.session_id,
|
|
reason=args.reason,
|
|
)
|
|
raise GatewayError(f"unsupported command: {args.command}")
|
|
|
|
|
|
if __name__ == "__main__":
    # Exit with main()'s return value as the process status (0 ok, 1 error).
    sys.exit(main())
|