Files

467 lines
14 KiB
Python
Executable File

#!/usr/bin/env python3
from __future__ import annotations

import argparse
import json
import mimetypes
import os
import sys
import uuid
from http.client import HTTPException
from pathlib import Path
from typing import Any
from urllib.error import HTTPError, URLError
from urllib.parse import quote, urlencode
from urllib.request import Request, urlopen
class GatewayError(RuntimeError):
    """Error raised by the Memory Gateway client and CLI.

    Covers missing credentials, invalid local input, and failed or malformed
    gateway responses; ``main()`` prints it to stderr as a JSON error object.
    """
class Settings:
    """Connection settings for the Memory Gateway CLI.

    ``base_url`` is stored without trailing slashes so request paths can be
    appended to it directly.
    """

    def __init__(
        self,
        base_url: str,
        user_id: str | None,
        user_key: str | None,
        timeout: float = 120.0,
    ) -> None:
        self.base_url = base_url.rstrip("/")
        self.user_id = user_id
        self.user_key = user_key
        self.timeout = timeout  # seconds (see MEMORY_GATEWAY_TIMEOUT_SECONDS)

    @classmethod
    def from_env(cls) -> Settings:
        """Build settings from ``MEMORY_GATEWAY_*`` environment variables.

        ``MEMORY_GATEWAY_BASE_URL`` defaults to ``http://127.0.0.1:8010`` and
        ``MEMORY_GATEWAY_TIMEOUT_SECONDS`` to ``120``; ``MEMORY_GATEWAY_USER_ID``
        and ``MEMORY_GATEWAY_USER_KEY`` are ``None`` when unset.

        Raises:
            ValueError: if ``MEMORY_GATEWAY_TIMEOUT_SECONDS`` is not a positive,
                finite number; the message names the offending variable.
        """
        raw_timeout = os.environ.get("MEMORY_GATEWAY_TIMEOUT_SECONDS", "120")
        try:
            timeout = float(raw_timeout)
        except ValueError as exc:
            raise ValueError(
                f"MEMORY_GATEWAY_TIMEOUT_SECONDS must be a number, got {raw_timeout!r}"
            ) from exc
        # Reject 0, negatives, NaN and infinity: none is a sensible request
        # timeout (NaN fails both comparisons, so it is rejected too).
        if not 0 < timeout < float("inf"):
            raise ValueError(
                "MEMORY_GATEWAY_TIMEOUT_SECONDS must be a positive finite number, "
                f"got {raw_timeout!r}"
            )
        return cls(
            base_url=os.environ.get(
                "MEMORY_GATEWAY_BASE_URL",
                "http://127.0.0.1:8010",
            ),
            user_id=os.environ.get("MEMORY_GATEWAY_USER_ID"),
            user_key=os.environ.get("MEMORY_GATEWAY_USER_KEY"),
            timeout=timeout,
        )
class MemoryGatewayClient:
    """HTTP client for the Memory Gateway REST API.

    Every endpoint except ``health`` and ``create_user`` sends the configured
    ``user_id``/``user_key`` pair: in the JSON body, in the multipart form
    fields (uploads), or in the query string (resource GET/DELETE calls).
    All failures surface as ``GatewayError``.
    """

    def __init__(
        self,
        base_url: str,
        *,
        user_id: str | None = None,
        user_key: str | None = None,
        timeout: float = 120.0,
    ) -> None:
        self.base_url = base_url.rstrip("/")
        self.user_id = user_id
        self.user_key = user_key
        self.timeout = timeout

    def _credentials(self) -> dict[str, str]:
        """Return the credential fields, or raise GatewayError if either is missing."""
        if not self.user_id or not self.user_key:
            raise GatewayError(
                "user credentials are required; set MEMORY_GATEWAY_USER_ID and "
                "MEMORY_GATEWAY_USER_KEY or pass --user-id and --user-key"
            )
        return {"user_id": self.user_id, "user_key": self.user_key}

    @staticmethod
    def _path_segment(value: str) -> str:
        """Percent-encode *value* so it occupies exactly one URL path segment.

        Without this, an id containing ``/``, ``?`` or ``#`` would address a
        different endpoint, and one containing spaces would be rejected by
        ``http.client`` before the request is sent.
        """
        return quote(str(value), safe="")

    def _request(
        self,
        method: str,
        path: str,
        *,
        query: dict[str, Any] | None = None,
        json_body: dict[str, Any] | None = None,
        body: bytes | None = None,
        headers: dict[str, str] | None = None,
    ) -> dict[str, Any]:
        """Send one HTTP request and return the decoded JSON-object response.

        ``json_body``, when given, replaces ``body`` and sets a JSON
        Content-Type. An empty response body yields ``{}``.

        Raises:
            GatewayError: on an unusable base URL, an HTTP error status, a
                connection/timeout/transport failure, or a response that is
                not a JSON object.
        """
        url = f"{self.base_url}{path}"
        if query:
            url = f"{url}?{urlencode(query, doseq=True)}"
        request_headers = dict(headers or {})
        request_body = body
        if json_body is not None:
            request_body = json.dumps(json_body, ensure_ascii=False).encode("utf-8")
            request_headers["Content-Type"] = "application/json"
        try:
            request = Request(
                url,
                data=request_body,
                headers=request_headers,
                method=method,
            )
        except ValueError as exc:
            # e.g. a base URL without a scheme. The full URL is deliberately
            # not echoed: its query string may carry user_key.
            raise GatewayError(
                f"invalid Memory Gateway base URL: {self.base_url!r}"
            ) from exc
        try:
            with urlopen(request, timeout=self.timeout) as response:
                raw = response.read()
        except HTTPError as exc:
            try:
                raw = exc.read()
            except (OSError, HTTPException):
                # The error body is optional context; fall back to the reason.
                raw = b""
            detail = _error_detail(raw, exc.reason)
            raise GatewayError(f"Memory Gateway returned HTTP {exc.code}: {detail}") from exc
        except URLError as exc:
            raise GatewayError(f"cannot connect to Memory Gateway: {exc.reason}") from exc
        except TimeoutError as exc:
            # A timeout while reading the body is not wrapped in URLError.
            raise GatewayError(
                f"Memory Gateway did not respond within {self.timeout} seconds"
            ) from exc
        except (OSError, HTTPException) as exc:
            # Connection resets, truncated responses, malformed status lines, ...
            raise GatewayError(f"Memory Gateway request failed: {exc}") from exc
        if not raw:
            return {}
        try:
            value = json.loads(raw.decode("utf-8"))
        except (UnicodeDecodeError, json.JSONDecodeError) as exc:
            raise GatewayError("Memory Gateway returned a non-JSON response") from exc
        if not isinstance(value, dict):
            raise GatewayError("Memory Gateway returned an unexpected JSON response")
        return value

    def health(self) -> dict[str, Any]:
        """GET /health (no credentials are sent)."""
        return self._request("GET", "/health")

    def create_user(self, user_id: str) -> dict[str, Any]:
        """POST /users with *user_id* (no credentials are sent)."""
        return self._request("POST", "/users", json_body={"user_id": user_id})

    def upload_resource(
        self,
        file_path: Path,
        *,
        app_id: str = "default",
        project_id: str = "default",
        title: str | None = None,
        description: str | None = None,
    ) -> dict[str, Any]:
        """Upload *file_path* to POST /resources as multipart/form-data.

        The MIME type is guessed from the file name, falling back to
        ``application/octet-stream``; ``title``/``description`` are only sent
        when not ``None``.

        Raises:
            GatewayError: if the path is not an existing file, cannot be read,
                or the request fails.
        """
        if not file_path.is_file():
            raise GatewayError(f"upload file does not exist: {file_path}")
        fields: dict[str, str] = {
            **self._credentials(),
            "app_id": app_id,
            "project_id": project_id,
        }
        if title is not None:
            fields["title"] = title
        if description is not None:
            fields["description"] = description
        boundary = f"memory-gateway-{uuid.uuid4().hex}"
        mime_type = mimetypes.guess_type(file_path.name)[0] or "application/octet-stream"
        try:
            body = _multipart_body(
                boundary,
                fields,
                field_name="file",
                file_path=file_path,
                mime_type=mime_type,
            )
        except OSError as exc:
            # e.g. permission denied, or the file vanished after the check above.
            raise GatewayError(f"cannot read upload file {file_path}: {exc}") from exc
        return self._request(
            "POST",
            "/resources",
            body=body,
            headers={"Content-Type": f"multipart/form-data; boundary={boundary}"},
        )

    def list_resources(self) -> dict[str, Any]:
        """GET /resources, passing credentials in the query string."""
        return self._request("GET", "/resources", query=self._credentials())

    def get_resource(self, resource_id: str) -> dict[str, Any]:
        """GET /resources/{resource_id}, passing credentials in the query string."""
        return self._request(
            "GET",
            f"/resources/{self._path_segment(resource_id)}",
            query=self._credentials(),
        )

    def delete_resource(self, resource_id: str) -> dict[str, Any]:
        """DELETE /resources/{resource_id}, passing credentials in the query string."""
        return self._request(
            "DELETE",
            f"/resources/{self._path_segment(resource_id)}",
            query=self._credentials(),
        )

    def search(
        self,
        query: str,
        *,
        conversation_id: str | None = None,
        scopes: list[str] | None = None,
        top_k: int = 8,
        app_id: str = "default",
        project_id: str = "default",
    ) -> dict[str, Any]:
        """POST /memories/search.

        When *scopes* is empty or ``None`` it defaults to
        ``["current_chat", "resources"]`` if a *conversation_id* is given,
        else ``["resources"]``.

        Raises:
            GatewayError: if ``current_chat`` is requested without a
                *conversation_id*, or the request fails.
        """
        selected_scopes = scopes or (
            ["current_chat", "resources"] if conversation_id else ["resources"]
        )
        if "current_chat" in selected_scopes and not conversation_id:
            raise GatewayError(
                "conversation_id is required when search scope includes current_chat"
            )
        payload: dict[str, Any] = {
            **self._credentials(),
            "query": query,
            "scope": selected_scopes,
            "top_k": top_k,
            "app_id": app_id,
            "project_id": project_id,
        }
        if conversation_id is not None:
            payload["conversation_id"] = conversation_id
        return self._request("POST", "/memories/search", json_body=payload)

    def add_memory(
        self,
        session_id: str,
        messages: list[dict[str, Any]],
        *,
        app_id: str = "default",
        project_id: str = "default",
    ) -> dict[str, Any]:
        """POST /memories/add with *messages* for *session_id*."""
        return self._request(
            "POST",
            "/memories/add",
            json_body={
                **self._credentials(),
                "session_id": session_id,
                "messages": messages,
                "app_id": app_id,
                "project_id": project_id,
            },
        )

    def flush_memory(
        self,
        session_id: str,
        *,
        app_id: str = "default",
        project_id: str = "default",
    ) -> dict[str, Any]:
        """POST /memories/flush for *session_id*."""
        return self._request(
            "POST",
            "/memories/flush",
            json_body={
                **self._credentials(),
                "session_id": session_id,
                "app_id": app_id,
                "project_id": project_id,
            },
        )

    def override_memory(
        self,
        memory_id: str,
        session_id: str,
        override_text: str,
    ) -> dict[str, Any]:
        """PATCH /memories/{memory_id} with *override_text*."""
        return self._request(
            "PATCH",
            f"/memories/{self._path_segment(memory_id)}",
            json_body={
                **self._credentials(),
                "session_id": session_id,
                "override_text": override_text,
            },
        )

    def delete_memory(
        self,
        memory_id: str,
        session_id: str,
        *,
        reason: str | None = None,
    ) -> dict[str, Any]:
        """DELETE /memories/{memory_id}; *reason* is only sent when not ``None``."""
        payload: dict[str, Any] = {
            **self._credentials(),
            "session_id": session_id,
        }
        if reason is not None:
            payload["reason"] = reason
        return self._request(
            "DELETE",
            f"/memories/{self._path_segment(memory_id)}",
            json_body=payload,
        )
def _error_detail(raw: bytes, fallback: Any) -> str:
    """Extract a readable error message from an HTTP error response body.

    Returns the ``detail`` field when *raw* is a JSON object with a truthy
    ``detail``; otherwise ``str(fallback)``. String details are returned
    verbatim, while structured details (lists/objects, e.g. per-field
    validation errors) are rendered as JSON instead of a Python repr.
    """
    if not raw:
        return str(fallback)
    try:
        body = json.loads(raw.decode("utf-8"))
    except (UnicodeDecodeError, json.JSONDecodeError):
        return str(fallback)
    if not isinstance(body, dict):
        return str(fallback)
    detail = body.get("detail")
    if not detail:
        return str(fallback)
    if isinstance(detail, str):
        return detail
    return json.dumps(detail, ensure_ascii=False)
def _multipart_body(
    boundary: str,
    fields: dict[str, str],
    *,
    field_name: str,
    file_path: Path,
    mime_type: str,
) -> bytes:
    """Encode *fields* plus the contents of *file_path* as a multipart/form-data body.

    Each entry of *fields* becomes a plain form part; the file is appended as
    a final part named *field_name* with the given *mime_type*, followed by
    the closing boundary. Field names and the filename are escaped in the
    Content-Disposition header (``"`` -> ``%22``, CR -> ``%0D``, LF -> ``%0A``,
    following the WHATWG form-encoding rules) so such characters cannot break
    out of the quoted header value.

    Raises:
        OSError: if *file_path* cannot be read.
    """
    header_escapes = str.maketrans({'"': "%22", "\r": "%0D", "\n": "%0A"})
    marker = boundary.encode("ascii")
    delimiter = b"--" + marker + b"\r\n"
    chunks: list[bytes] = []
    for name, value in fields.items():
        chunks.extend(
            [
                delimiter,
                (
                    "Content-Disposition: form-data; "
                    f'name="{name.translate(header_escapes)}"\r\n\r\n'
                ).encode("utf-8"),
                value.encode("utf-8"),
                b"\r\n",
            ]
        )
    chunks.extend(
        [
            delimiter,
            (
                "Content-Disposition: form-data; "
                f'name="{field_name.translate(header_escapes)}"; '
                f'filename="{file_path.name.translate(header_escapes)}"\r\n'
            ).encode("utf-8"),
            f"Content-Type: {mime_type}\r\n\r\n".encode("utf-8"),
            file_path.read_bytes(),
            b"\r\n--" + marker + b"--\r\n",
        ]
    )
    return b"".join(chunks)
def _load_json_array(value: str) -> list[dict[str, Any]]:
    """Parse the CLI ``--messages`` value into a list of JSON objects.

    *value* is read as a UTF-8 file when it names an existing file; otherwise
    it is parsed as inline JSON text.

    Raises:
        GatewayError: if the file cannot be read or decoded, the text is not
            valid JSON, or the JSON is not an array of objects.
    """
    source = Path(value)
    try:
        is_file = source.is_file()
    except (OSError, ValueError):
        # Inline JSON longer than the filesystem's name limit makes stat()
        # fail with ENAMETOOLONG, which Path.is_file() re-raises on many
        # Python versions. Such a value cannot be a file, so parse it inline.
        is_file = False
    if is_file:
        try:
            text = source.read_text(encoding="utf-8")
        except (OSError, UnicodeDecodeError) as exc:
            raise GatewayError(f"cannot read messages file {source}: {exc}") from exc
    else:
        text = value
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError as exc:
        raise GatewayError(f"invalid messages JSON: {exc}") from exc
    if not isinstance(parsed, list) or not all(isinstance(item, dict) for item in parsed):
        raise GatewayError("messages JSON must be an array of objects")
    return parsed
def build_parser() -> argparse.ArgumentParser:
    """Assemble the argparse parser for every Memory Gateway subcommand.

    The global connection options (--base-url, --user-id, --user-key,
    --timeout) default to None; main() substitutes environment settings for
    any that are not given. The chosen subcommand is stored in ``command``.
    """
    root = argparse.ArgumentParser(description="Memory Gateway agent CLI")
    for connection_flag in ("--base-url", "--user-id", "--user-key"):
        root.add_argument(connection_flag)
    root.add_argument("--timeout", type=float)

    commands = root.add_subparsers(dest="command", required=True)

    commands.add_parser("health")
    commands.add_parser("create-user").add_argument("user_id")

    upload_cmd = commands.add_parser("upload-resource")
    upload_cmd.add_argument("file", type=Path)
    _add_scope_arguments(upload_cmd)
    upload_cmd.add_argument("--title")
    upload_cmd.add_argument("--description")

    commands.add_parser("list-resources")
    # Both resource commands take a single positional resource id.
    for resource_cmd in ("get-resource", "delete-resource"):
        commands.add_parser(resource_cmd).add_argument("resource_id")

    search_cmd = commands.add_parser("search")
    search_cmd.add_argument("query")
    search_cmd.add_argument("--conversation-id")
    search_cmd.add_argument(
        "--scope",
        action="append",
        choices=["current_chat", "resources", "all_user_memory"],
    )
    search_cmd.add_argument("--top-k", type=int, default=8)
    _add_scope_arguments(search_cmd)

    add_cmd = commands.add_parser("add-memory")
    add_cmd.add_argument("--session-id", required=True)
    add_cmd.add_argument(
        "--messages",
        required=True,
        help="JSON array or path to a JSON file containing messages",
    )
    _add_scope_arguments(add_cmd)

    flush_cmd = commands.add_parser("flush-memory")
    flush_cmd.add_argument("--session-id", required=True)
    _add_scope_arguments(flush_cmd)

    override_cmd = commands.add_parser("override-memory")
    override_cmd.add_argument("memory_id")
    override_cmd.add_argument("--session-id", required=True)
    override_cmd.add_argument("--text", required=True)

    forget_cmd = commands.add_parser("delete-memory")
    forget_cmd.add_argument("memory_id")
    forget_cmd.add_argument("--session-id", required=True)
    forget_cmd.add_argument("--reason")
    return root


def _add_scope_arguments(parser: argparse.ArgumentParser) -> None:
    """Attach the --app-id and --project-id options, each defaulting to "default"."""
    for scope_flag in ("--app-id", "--project-id"):
        parser.add_argument(scope_flag, default="default")
def main(argv: list[str] | None = None) -> int:
    """Run the Memory Gateway CLI and return its process exit code.

    Non-empty command-line connection options override the ``MEMORY_GATEWAY_*``
    environment settings. On success the JSON result is printed to stdout
    and 0 is returned; on failure ``{"error": ...}`` is printed to stderr and
    1 is returned.
    """
    # Parse arguments before reading the environment so that --help and
    # usage errors still work when the environment configuration is broken.
    args = build_parser().parse_args(argv)
    # A zero, negative or non-finite timeout would otherwise fail deep inside
    # the socket layer with a traceback.
    if args.timeout is not None and not 0 < args.timeout < float("inf"):
        return _report_error("--timeout must be a positive finite number")
    try:
        settings = Settings.from_env()
    except (GatewayError, ValueError) as exc:
        return _report_error(f"invalid Memory Gateway configuration: {exc}")
    client = MemoryGatewayClient(
        args.base_url or settings.base_url,
        user_id=args.user_id or settings.user_id,
        user_key=args.user_key or settings.user_key,
        timeout=args.timeout or settings.timeout,
    )
    try:
        result = _run_command(client, args)
    except GatewayError as exc:
        return _report_error(str(exc))
    print(json.dumps(result, ensure_ascii=False, indent=2))
    return 0


def _report_error(message: str) -> int:
    """Print *message* to stderr as a JSON ``{"error": ...}`` object and return 1."""
    print(json.dumps({"error": message}, ensure_ascii=False), file=sys.stderr)
    return 1
def _run_command(client: MemoryGatewayClient, args: argparse.Namespace) -> dict[str, Any]:
    """Invoke the client method matching ``args.command`` and return its result.

    Each entry in the dispatch table is a zero-argument callable, so only the
    selected command's arguments are read (and, for add-memory, only then is
    the messages input loaded).

    Raises:
        GatewayError: for an unknown command, or whatever the client raises.
    """
    dispatch = {
        "health": lambda: client.health(),
        "create-user": lambda: client.create_user(args.user_id),
        "upload-resource": lambda: client.upload_resource(
            args.file,
            app_id=args.app_id,
            project_id=args.project_id,
            title=args.title,
            description=args.description,
        ),
        "list-resources": lambda: client.list_resources(),
        "get-resource": lambda: client.get_resource(args.resource_id),
        "delete-resource": lambda: client.delete_resource(args.resource_id),
        "search": lambda: client.search(
            args.query,
            conversation_id=args.conversation_id,
            scopes=args.scope,
            top_k=args.top_k,
            app_id=args.app_id,
            project_id=args.project_id,
        ),
        "add-memory": lambda: client.add_memory(
            args.session_id,
            _load_json_array(args.messages),
            app_id=args.app_id,
            project_id=args.project_id,
        ),
        "flush-memory": lambda: client.flush_memory(
            args.session_id,
            app_id=args.app_id,
            project_id=args.project_id,
        ),
        "override-memory": lambda: client.override_memory(
            args.memory_id, args.session_id, args.text
        ),
        "delete-memory": lambda: client.delete_memory(
            args.memory_id,
            args.session_id,
            reason=args.reason,
        ),
    }
    handler = dispatch.get(args.command)
    if handler is None:
        raise GatewayError(f"unsupported command: {args.command}")
    return handler()
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    sys.exit(main())