第一次提交
This commit is contained in:
6
app-instance/backend/nanobot/channels/__init__.py
Normal file
6
app-instance/backend/nanobot/channels/__init__.py
Normal file
@ -0,0 +1,6 @@
|
||||
"""Chat channels module with plugin architecture."""
|
||||
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.channels.manager import ChannelManager
|
||||
|
||||
__all__ = ["BaseChannel", "ChannelManager"]
|
||||
131
app-instance/backend/nanobot/channels/base.py
Normal file
131
app-instance/backend/nanobot/channels/base.py
Normal file
@ -0,0 +1,131 @@
|
||||
"""Base channel interface for chat platforms."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.bus.events import InboundMessage, OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
|
||||
|
||||
class BaseChannel(ABC):
|
||||
"""
|
||||
Abstract base class for chat channel implementations.
|
||||
|
||||
Each channel (Telegram, Discord, etc.) should implement this interface
|
||||
to integrate with the nanobot message bus.
|
||||
"""
|
||||
|
||||
name: str = "base"
|
||||
|
||||
def __init__(self, config: Any, bus: MessageBus):
|
||||
"""
|
||||
Initialize the channel.
|
||||
|
||||
Args:
|
||||
config: Channel-specific configuration.
|
||||
bus: The message bus for communication.
|
||||
"""
|
||||
self.config = config
|
||||
self.bus = bus
|
||||
self._running = False
|
||||
|
||||
@abstractmethod
|
||||
async def start(self) -> None:
|
||||
"""
|
||||
Start the channel and begin listening for messages.
|
||||
|
||||
This should be a long-running async task that:
|
||||
1. Connects to the chat platform
|
||||
2. Listens for incoming messages
|
||||
3. Forwards messages to the bus via _handle_message()
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def stop(self) -> None:
|
||||
"""Stop the channel and clean up resources."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
"""
|
||||
Send a message through this channel.
|
||||
|
||||
Args:
|
||||
msg: The message to send.
|
||||
"""
|
||||
pass
|
||||
|
||||
def is_allowed(self, sender_id: str) -> bool:
|
||||
"""
|
||||
Check if a sender is allowed to use this bot.
|
||||
|
||||
Args:
|
||||
sender_id: The sender's identifier.
|
||||
|
||||
Returns:
|
||||
True if allowed, False otherwise.
|
||||
"""
|
||||
allow_list = getattr(self.config, "allow_from", [])
|
||||
|
||||
# If no allow list, allow everyone
|
||||
if not allow_list:
|
||||
return True
|
||||
|
||||
sender_str = str(sender_id)
|
||||
if sender_str in allow_list:
|
||||
return True
|
||||
if "|" in sender_str:
|
||||
for part in sender_str.split("|"):
|
||||
if part and part in allow_list:
|
||||
return True
|
||||
return False
|
||||
|
||||
async def _handle_message(
|
||||
self,
|
||||
sender_id: str,
|
||||
chat_id: str,
|
||||
content: str,
|
||||
media: list[str] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
session_key: str | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Handle an incoming message from the chat platform.
|
||||
|
||||
This method checks permissions and forwards to the bus.
|
||||
|
||||
Args:
|
||||
sender_id: The sender's identifier.
|
||||
chat_id: The chat/channel identifier.
|
||||
content: Message text content.
|
||||
media: Optional list of media URLs.
|
||||
metadata: Optional channel-specific metadata.
|
||||
session_key: Optional session key override (e.g. thread-scoped sessions).
|
||||
"""
|
||||
if not self.is_allowed(sender_id):
|
||||
logger.warning(
|
||||
"Access denied for sender {} on channel {}. "
|
||||
"Add them to allowFrom list in config to grant access.",
|
||||
sender_id, self.name,
|
||||
)
|
||||
return
|
||||
|
||||
msg = InboundMessage(
|
||||
channel=self.name,
|
||||
sender_id=str(sender_id),
|
||||
chat_id=str(chat_id),
|
||||
content=content,
|
||||
media=media or [],
|
||||
metadata=metadata or {},
|
||||
session_key_override=session_key,
|
||||
)
|
||||
|
||||
await self.bus.publish_inbound(msg)
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
"""Check if the channel is running."""
|
||||
return self._running
|
||||
247
app-instance/backend/nanobot/channels/dingtalk.py
Normal file
247
app-instance/backend/nanobot/channels/dingtalk.py
Normal file
@ -0,0 +1,247 @@
|
||||
"""DingTalk/DingDing channel implementation using Stream Mode."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
import httpx
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.config.schema import DingTalkConfig
|
||||
|
||||
try:
|
||||
from dingtalk_stream import (
|
||||
DingTalkStreamClient,
|
||||
Credential,
|
||||
CallbackHandler,
|
||||
CallbackMessage,
|
||||
AckMessage,
|
||||
)
|
||||
from dingtalk_stream.chatbot import ChatbotMessage
|
||||
|
||||
DINGTALK_AVAILABLE = True
|
||||
except ImportError:
|
||||
DINGTALK_AVAILABLE = False
|
||||
# Fallback so class definitions don't crash at module level
|
||||
CallbackHandler = object # type: ignore[assignment,misc]
|
||||
CallbackMessage = None # type: ignore[assignment,misc]
|
||||
AckMessage = None # type: ignore[assignment,misc]
|
||||
ChatbotMessage = None # type: ignore[assignment,misc]
|
||||
|
||||
|
||||
class NanobotDingTalkHandler(CallbackHandler):
|
||||
"""
|
||||
Standard DingTalk Stream SDK Callback Handler.
|
||||
Parses incoming messages and forwards them to the Nanobot channel.
|
||||
"""
|
||||
|
||||
def __init__(self, channel: "DingTalkChannel"):
|
||||
super().__init__()
|
||||
self.channel = channel
|
||||
|
||||
async def process(self, message: CallbackMessage):
|
||||
"""Process incoming stream message."""
|
||||
try:
|
||||
# Parse using SDK's ChatbotMessage for robust handling
|
||||
chatbot_msg = ChatbotMessage.from_dict(message.data)
|
||||
|
||||
# Extract text content; fall back to raw dict if SDK object is empty
|
||||
content = ""
|
||||
if chatbot_msg.text:
|
||||
content = chatbot_msg.text.content.strip()
|
||||
if not content:
|
||||
content = message.data.get("text", {}).get("content", "").strip()
|
||||
|
||||
if not content:
|
||||
logger.warning(
|
||||
"Received empty or unsupported message type: {}",
|
||||
chatbot_msg.message_type,
|
||||
)
|
||||
return AckMessage.STATUS_OK, "OK"
|
||||
|
||||
sender_id = chatbot_msg.sender_staff_id or chatbot_msg.sender_id
|
||||
sender_name = chatbot_msg.sender_nick or "Unknown"
|
||||
|
||||
logger.info("Received DingTalk message from {} ({}): {}", sender_name, sender_id, content)
|
||||
|
||||
# Forward to Nanobot via _on_message (non-blocking).
|
||||
# Store reference to prevent GC before task completes.
|
||||
task = asyncio.create_task(
|
||||
self.channel._on_message(content, sender_id, sender_name)
|
||||
)
|
||||
self.channel._background_tasks.add(task)
|
||||
task.add_done_callback(self.channel._background_tasks.discard)
|
||||
|
||||
return AckMessage.STATUS_OK, "OK"
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error processing DingTalk message: {}", e)
|
||||
# Return OK to avoid retry loop from DingTalk server
|
||||
return AckMessage.STATUS_OK, "Error"
|
||||
|
||||
|
||||
class DingTalkChannel(BaseChannel):
|
||||
"""
|
||||
DingTalk channel using Stream Mode.
|
||||
|
||||
Uses WebSocket to receive events via `dingtalk-stream` SDK.
|
||||
Uses direct HTTP API to send messages (SDK is mainly for receiving).
|
||||
|
||||
Note: Currently only supports private (1:1) chat. Group messages are
|
||||
received but replies are sent back as private messages to the sender.
|
||||
"""
|
||||
|
||||
name = "dingtalk"
|
||||
|
||||
def __init__(self, config: DingTalkConfig, bus: MessageBus):
|
||||
super().__init__(config, bus)
|
||||
self.config: DingTalkConfig = config
|
||||
self._client: Any = None
|
||||
self._http: httpx.AsyncClient | None = None
|
||||
|
||||
# Access Token management for sending messages
|
||||
self._access_token: str | None = None
|
||||
self._token_expiry: float = 0
|
||||
|
||||
# Hold references to background tasks to prevent GC
|
||||
self._background_tasks: set[asyncio.Task] = set()
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the DingTalk bot with Stream Mode."""
|
||||
try:
|
||||
if not DINGTALK_AVAILABLE:
|
||||
logger.error(
|
||||
"DingTalk Stream SDK not installed. Run: pip install dingtalk-stream"
|
||||
)
|
||||
return
|
||||
|
||||
if not self.config.client_id or not self.config.client_secret:
|
||||
logger.error("DingTalk client_id and client_secret not configured")
|
||||
return
|
||||
|
||||
self._running = True
|
||||
self._http = httpx.AsyncClient()
|
||||
|
||||
logger.info(
|
||||
"Initializing DingTalk Stream Client with Client ID: {}...",
|
||||
self.config.client_id,
|
||||
)
|
||||
credential = Credential(self.config.client_id, self.config.client_secret)
|
||||
self._client = DingTalkStreamClient(credential)
|
||||
|
||||
# Register standard handler
|
||||
handler = NanobotDingTalkHandler(self)
|
||||
self._client.register_callback_handler(ChatbotMessage.TOPIC, handler)
|
||||
|
||||
logger.info("DingTalk bot started with Stream Mode")
|
||||
|
||||
# Reconnect loop: restart stream if SDK exits or crashes
|
||||
while self._running:
|
||||
try:
|
||||
await self._client.start()
|
||||
except Exception as e:
|
||||
logger.warning("DingTalk stream error: {}", e)
|
||||
if self._running:
|
||||
logger.info("Reconnecting DingTalk stream in 5 seconds...")
|
||||
await asyncio.sleep(5)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to start DingTalk channel: {}", e)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the DingTalk bot."""
|
||||
self._running = False
|
||||
# Close the shared HTTP client
|
||||
if self._http:
|
||||
await self._http.aclose()
|
||||
self._http = None
|
||||
# Cancel outstanding background tasks
|
||||
for task in self._background_tasks:
|
||||
task.cancel()
|
||||
self._background_tasks.clear()
|
||||
|
||||
async def _get_access_token(self) -> str | None:
|
||||
"""Get or refresh Access Token."""
|
||||
if self._access_token and time.time() < self._token_expiry:
|
||||
return self._access_token
|
||||
|
||||
url = "https://api.dingtalk.com/v1.0/oauth2/accessToken"
|
||||
data = {
|
||||
"appKey": self.config.client_id,
|
||||
"appSecret": self.config.client_secret,
|
||||
}
|
||||
|
||||
if not self._http:
|
||||
logger.warning("DingTalk HTTP client not initialized, cannot refresh token")
|
||||
return None
|
||||
|
||||
try:
|
||||
resp = await self._http.post(url, json=data)
|
||||
resp.raise_for_status()
|
||||
res_data = resp.json()
|
||||
self._access_token = res_data.get("accessToken")
|
||||
# Expire 60s early to be safe
|
||||
self._token_expiry = time.time() + int(res_data.get("expireIn", 7200)) - 60
|
||||
return self._access_token
|
||||
except Exception as e:
|
||||
logger.error("Failed to get DingTalk access token: {}", e)
|
||||
return None
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
"""Send a message through DingTalk."""
|
||||
token = await self._get_access_token()
|
||||
if not token:
|
||||
return
|
||||
|
||||
# oToMessages/batchSend: sends to individual users (private chat)
|
||||
# https://open.dingtalk.com/document/orgapp/robot-batch-send-messages
|
||||
url = "https://api.dingtalk.com/v1.0/robot/oToMessages/batchSend"
|
||||
|
||||
headers = {"x-acs-dingtalk-access-token": token}
|
||||
|
||||
data = {
|
||||
"robotCode": self.config.client_id,
|
||||
"userIds": [msg.chat_id], # chat_id is the user's staffId
|
||||
"msgKey": "sampleMarkdown",
|
||||
"msgParam": json.dumps({
|
||||
"text": msg.content,
|
||||
"title": "Nanobot Reply",
|
||||
}, ensure_ascii=False),
|
||||
}
|
||||
|
||||
if not self._http:
|
||||
logger.warning("DingTalk HTTP client not initialized, cannot send")
|
||||
return
|
||||
|
||||
try:
|
||||
resp = await self._http.post(url, json=data, headers=headers)
|
||||
if resp.status_code != 200:
|
||||
logger.error("DingTalk send failed: {}", resp.text)
|
||||
else:
|
||||
logger.debug("DingTalk message sent to {}", msg.chat_id)
|
||||
except Exception as e:
|
||||
logger.error("Error sending DingTalk message: {}", e)
|
||||
|
||||
async def _on_message(self, content: str, sender_id: str, sender_name: str) -> None:
|
||||
"""Handle incoming message (called by NanobotDingTalkHandler).
|
||||
|
||||
Delegates to BaseChannel._handle_message() which enforces allow_from
|
||||
permission checks before publishing to the bus.
|
||||
"""
|
||||
try:
|
||||
logger.info("DingTalk inbound: {} from {}", content, sender_name)
|
||||
await self._handle_message(
|
||||
sender_id=sender_id,
|
||||
chat_id=sender_id, # For private chat, chat_id == sender_id
|
||||
content=str(content),
|
||||
metadata={
|
||||
"sender_name": sender_name,
|
||||
"platform": "dingtalk",
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Error publishing DingTalk message: {}", e)
|
||||
301
app-instance/backend/nanobot/channels/discord.py
Normal file
301
app-instance/backend/nanobot/channels/discord.py
Normal file
@ -0,0 +1,301 @@
|
||||
"""Discord channel implementation using Discord Gateway websocket."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
import websockets
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.config.schema import DiscordConfig
|
||||
|
||||
|
||||
DISCORD_API_BASE = "https://discord.com/api/v10"
|
||||
MAX_ATTACHMENT_BYTES = 20 * 1024 * 1024 # 20MB
|
||||
MAX_MESSAGE_LEN = 2000 # Discord message character limit
|
||||
|
||||
|
||||
def _split_message(content: str, max_len: int = MAX_MESSAGE_LEN) -> list[str]:
|
||||
"""Split content into chunks within max_len, preferring line breaks."""
|
||||
if not content:
|
||||
return []
|
||||
if len(content) <= max_len:
|
||||
return [content]
|
||||
chunks: list[str] = []
|
||||
while content:
|
||||
if len(content) <= max_len:
|
||||
chunks.append(content)
|
||||
break
|
||||
cut = content[:max_len]
|
||||
pos = cut.rfind('\n')
|
||||
if pos <= 0:
|
||||
pos = cut.rfind(' ')
|
||||
if pos <= 0:
|
||||
pos = max_len
|
||||
chunks.append(content[:pos])
|
||||
content = content[pos:].lstrip()
|
||||
return chunks
|
||||
|
||||
|
||||
class DiscordChannel(BaseChannel):
|
||||
"""Discord channel using Gateway websocket."""
|
||||
|
||||
name = "discord"
|
||||
|
||||
def __init__(self, config: DiscordConfig, bus: MessageBus):
|
||||
super().__init__(config, bus)
|
||||
self.config: DiscordConfig = config
|
||||
self._ws: websockets.WebSocketClientProtocol | None = None
|
||||
self._seq: int | None = None
|
||||
self._heartbeat_task: asyncio.Task | None = None
|
||||
self._typing_tasks: dict[str, asyncio.Task] = {}
|
||||
self._http: httpx.AsyncClient | None = None
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the Discord gateway connection."""
|
||||
if not self.config.token:
|
||||
logger.error("Discord bot token not configured")
|
||||
return
|
||||
|
||||
self._running = True
|
||||
self._http = httpx.AsyncClient(timeout=30.0)
|
||||
|
||||
while self._running:
|
||||
try:
|
||||
logger.info("Connecting to Discord gateway...")
|
||||
async with websockets.connect(self.config.gateway_url) as ws:
|
||||
self._ws = ws
|
||||
await self._gateway_loop()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning("Discord gateway error: {}", e)
|
||||
if self._running:
|
||||
logger.info("Reconnecting to Discord gateway in 5 seconds...")
|
||||
await asyncio.sleep(5)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the Discord channel."""
|
||||
self._running = False
|
||||
if self._heartbeat_task:
|
||||
self._heartbeat_task.cancel()
|
||||
self._heartbeat_task = None
|
||||
for task in self._typing_tasks.values():
|
||||
task.cancel()
|
||||
self._typing_tasks.clear()
|
||||
if self._ws:
|
||||
await self._ws.close()
|
||||
self._ws = None
|
||||
if self._http:
|
||||
await self._http.aclose()
|
||||
self._http = None
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
"""Send a message through Discord REST API."""
|
||||
if not self._http:
|
||||
logger.warning("Discord HTTP client not initialized")
|
||||
return
|
||||
|
||||
url = f"{DISCORD_API_BASE}/channels/{msg.chat_id}/messages"
|
||||
headers = {"Authorization": f"Bot {self.config.token}"}
|
||||
|
||||
try:
|
||||
chunks = _split_message(msg.content or "")
|
||||
if not chunks:
|
||||
return
|
||||
|
||||
for i, chunk in enumerate(chunks):
|
||||
payload: dict[str, Any] = {"content": chunk}
|
||||
|
||||
# Only set reply reference on the first chunk
|
||||
if i == 0 and msg.reply_to:
|
||||
payload["message_reference"] = {"message_id": msg.reply_to}
|
||||
payload["allowed_mentions"] = {"replied_user": False}
|
||||
|
||||
if not await self._send_payload(url, headers, payload):
|
||||
break # Abort remaining chunks on failure
|
||||
finally:
|
||||
await self._stop_typing(msg.chat_id)
|
||||
|
||||
async def _send_payload(
|
||||
self, url: str, headers: dict[str, str], payload: dict[str, Any]
|
||||
) -> bool:
|
||||
"""Send a single Discord API payload with retry on rate-limit. Returns True on success."""
|
||||
for attempt in range(3):
|
||||
try:
|
||||
response = await self._http.post(url, headers=headers, json=payload)
|
||||
if response.status_code == 429:
|
||||
data = response.json()
|
||||
retry_after = float(data.get("retry_after", 1.0))
|
||||
logger.warning("Discord rate limited, retrying in {}s", retry_after)
|
||||
await asyncio.sleep(retry_after)
|
||||
continue
|
||||
response.raise_for_status()
|
||||
return True
|
||||
except Exception as e:
|
||||
if attempt == 2:
|
||||
logger.error("Error sending Discord message: {}", e)
|
||||
else:
|
||||
await asyncio.sleep(1)
|
||||
return False
|
||||
|
||||
async def _gateway_loop(self) -> None:
|
||||
"""Main gateway loop: identify, heartbeat, dispatch events."""
|
||||
if not self._ws:
|
||||
return
|
||||
|
||||
async for raw in self._ws:
|
||||
try:
|
||||
data = json.loads(raw)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("Invalid JSON from Discord gateway: {}", raw[:100])
|
||||
continue
|
||||
|
||||
op = data.get("op")
|
||||
event_type = data.get("t")
|
||||
seq = data.get("s")
|
||||
payload = data.get("d")
|
||||
|
||||
if seq is not None:
|
||||
self._seq = seq
|
||||
|
||||
if op == 10:
|
||||
# HELLO: start heartbeat and identify
|
||||
interval_ms = payload.get("heartbeat_interval", 45000)
|
||||
await self._start_heartbeat(interval_ms / 1000)
|
||||
await self._identify()
|
||||
elif op == 0 and event_type == "READY":
|
||||
logger.info("Discord gateway READY")
|
||||
elif op == 0 and event_type == "MESSAGE_CREATE":
|
||||
await self._handle_message_create(payload)
|
||||
elif op == 7:
|
||||
# RECONNECT: exit loop to reconnect
|
||||
logger.info("Discord gateway requested reconnect")
|
||||
break
|
||||
elif op == 9:
|
||||
# INVALID_SESSION: reconnect
|
||||
logger.warning("Discord gateway invalid session")
|
||||
break
|
||||
|
||||
async def _identify(self) -> None:
|
||||
"""Send IDENTIFY payload."""
|
||||
if not self._ws:
|
||||
return
|
||||
|
||||
identify = {
|
||||
"op": 2,
|
||||
"d": {
|
||||
"token": self.config.token,
|
||||
"intents": self.config.intents,
|
||||
"properties": {
|
||||
"os": "nanobot",
|
||||
"browser": "nanobot",
|
||||
"device": "nanobot",
|
||||
},
|
||||
},
|
||||
}
|
||||
await self._ws.send(json.dumps(identify))
|
||||
|
||||
async def _start_heartbeat(self, interval_s: float) -> None:
|
||||
"""Start or restart the heartbeat loop."""
|
||||
if self._heartbeat_task:
|
||||
self._heartbeat_task.cancel()
|
||||
|
||||
async def heartbeat_loop() -> None:
|
||||
while self._running and self._ws:
|
||||
payload = {"op": 1, "d": self._seq}
|
||||
try:
|
||||
await self._ws.send(json.dumps(payload))
|
||||
except Exception as e:
|
||||
logger.warning("Discord heartbeat failed: {}", e)
|
||||
break
|
||||
await asyncio.sleep(interval_s)
|
||||
|
||||
self._heartbeat_task = asyncio.create_task(heartbeat_loop())
|
||||
|
||||
async def _handle_message_create(self, payload: dict[str, Any]) -> None:
|
||||
"""Handle incoming Discord messages."""
|
||||
author = payload.get("author") or {}
|
||||
if author.get("bot"):
|
||||
return
|
||||
|
||||
sender_id = str(author.get("id", ""))
|
||||
channel_id = str(payload.get("channel_id", ""))
|
||||
content = payload.get("content") or ""
|
||||
|
||||
if not sender_id or not channel_id:
|
||||
return
|
||||
|
||||
if not self.is_allowed(sender_id):
|
||||
return
|
||||
|
||||
content_parts = [content] if content else []
|
||||
media_paths: list[str] = []
|
||||
media_dir = Path.home() / ".nanobot" / "media"
|
||||
|
||||
for attachment in payload.get("attachments") or []:
|
||||
url = attachment.get("url")
|
||||
filename = attachment.get("filename") or "attachment"
|
||||
size = attachment.get("size") or 0
|
||||
if not url or not self._http:
|
||||
continue
|
||||
if size and size > MAX_ATTACHMENT_BYTES:
|
||||
content_parts.append(f"[attachment: {filename} - too large]")
|
||||
continue
|
||||
try:
|
||||
media_dir.mkdir(parents=True, exist_ok=True)
|
||||
file_path = media_dir / f"{attachment.get('id', 'file')}_{filename.replace('/', '_')}"
|
||||
resp = await self._http.get(url)
|
||||
resp.raise_for_status()
|
||||
file_path.write_bytes(resp.content)
|
||||
media_paths.append(str(file_path))
|
||||
content_parts.append(f"[attachment: {file_path}]")
|
||||
except Exception as e:
|
||||
logger.warning("Failed to download Discord attachment: {}", e)
|
||||
content_parts.append(f"[attachment: {filename} - download failed]")
|
||||
|
||||
reply_to = (payload.get("referenced_message") or {}).get("id")
|
||||
|
||||
await self._start_typing(channel_id)
|
||||
|
||||
await self._handle_message(
|
||||
sender_id=sender_id,
|
||||
chat_id=channel_id,
|
||||
content="\n".join(p for p in content_parts if p) or "[empty message]",
|
||||
media=media_paths,
|
||||
metadata={
|
||||
"message_id": str(payload.get("id", "")),
|
||||
"guild_id": payload.get("guild_id"),
|
||||
"reply_to": reply_to,
|
||||
},
|
||||
)
|
||||
|
||||
async def _start_typing(self, channel_id: str) -> None:
|
||||
"""Start periodic typing indicator for a channel."""
|
||||
await self._stop_typing(channel_id)
|
||||
|
||||
async def typing_loop() -> None:
|
||||
url = f"{DISCORD_API_BASE}/channels/{channel_id}/typing"
|
||||
headers = {"Authorization": f"Bot {self.config.token}"}
|
||||
while self._running:
|
||||
try:
|
||||
await self._http.post(url, headers=headers)
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
except Exception as e:
|
||||
logger.debug("Discord typing indicator failed for {}: {}", channel_id, e)
|
||||
return
|
||||
await asyncio.sleep(8)
|
||||
|
||||
self._typing_tasks[channel_id] = asyncio.create_task(typing_loop())
|
||||
|
||||
async def _stop_typing(self, channel_id: str) -> None:
|
||||
"""Stop typing indicator for a channel."""
|
||||
task = self._typing_tasks.pop(channel_id, None)
|
||||
if task:
|
||||
task.cancel()
|
||||
404
app-instance/backend/nanobot/channels/email.py
Normal file
404
app-instance/backend/nanobot/channels/email.py
Normal file
@ -0,0 +1,404 @@
|
||||
"""Email channel implementation using IMAP polling + SMTP replies."""
|
||||
|
||||
import asyncio
|
||||
import html
|
||||
import imaplib
|
||||
import re
|
||||
import smtplib
|
||||
import ssl
|
||||
from datetime import date
|
||||
from email import policy
|
||||
from email.header import decode_header, make_header
|
||||
from email.message import EmailMessage
|
||||
from email.parser import BytesParser
|
||||
from email.utils import parseaddr
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.config.schema import EmailConfig
|
||||
|
||||
|
||||
class EmailChannel(BaseChannel):
|
||||
"""
|
||||
Email channel.
|
||||
|
||||
Inbound:
|
||||
- Poll IMAP mailbox for unread messages.
|
||||
- Convert each message into an inbound event.
|
||||
|
||||
Outbound:
|
||||
- Send responses via SMTP back to the sender address.
|
||||
"""
|
||||
|
||||
name = "email"
|
||||
_IMAP_MONTHS = (
|
||||
"Jan",
|
||||
"Feb",
|
||||
"Mar",
|
||||
"Apr",
|
||||
"May",
|
||||
"Jun",
|
||||
"Jul",
|
||||
"Aug",
|
||||
"Sep",
|
||||
"Oct",
|
||||
"Nov",
|
||||
"Dec",
|
||||
)
|
||||
|
||||
def __init__(self, config: EmailConfig, bus: MessageBus):
|
||||
super().__init__(config, bus)
|
||||
self.config: EmailConfig = config
|
||||
self._last_subject_by_chat: dict[str, str] = {}
|
||||
self._last_message_id_by_chat: dict[str, str] = {}
|
||||
self._processed_uids: set[str] = set() # Capped to prevent unbounded growth
|
||||
self._MAX_PROCESSED_UIDS = 100000
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start polling IMAP for inbound emails."""
|
||||
if not self.config.consent_granted:
|
||||
logger.warning(
|
||||
"Email channel disabled: consent_granted is false. "
|
||||
"Set channels.email.consentGranted=true after explicit user permission."
|
||||
)
|
||||
return
|
||||
|
||||
if not self._validate_config():
|
||||
return
|
||||
|
||||
self._running = True
|
||||
logger.info("Starting Email channel (IMAP polling mode)...")
|
||||
|
||||
poll_seconds = max(5, int(self.config.poll_interval_seconds))
|
||||
while self._running:
|
||||
try:
|
||||
inbound_items = await asyncio.to_thread(self._fetch_new_messages)
|
||||
for item in inbound_items:
|
||||
sender = item["sender"]
|
||||
subject = item.get("subject", "")
|
||||
message_id = item.get("message_id", "")
|
||||
|
||||
if subject:
|
||||
self._last_subject_by_chat[sender] = subject
|
||||
if message_id:
|
||||
self._last_message_id_by_chat[sender] = message_id
|
||||
|
||||
await self._handle_message(
|
||||
sender_id=sender,
|
||||
chat_id=sender,
|
||||
content=item["content"],
|
||||
metadata=item.get("metadata", {}),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Email polling error: {}", e)
|
||||
|
||||
await asyncio.sleep(poll_seconds)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop polling loop."""
|
||||
self._running = False
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
"""Send email via SMTP."""
|
||||
if not self.config.consent_granted:
|
||||
logger.warning("Skip email send: consent_granted is false")
|
||||
return
|
||||
|
||||
force_send = bool((msg.metadata or {}).get("force_send"))
|
||||
if not self.config.auto_reply_enabled and not force_send:
|
||||
logger.info("Skip automatic email reply: auto_reply_enabled is false")
|
||||
return
|
||||
|
||||
if not self.config.smtp_host:
|
||||
logger.warning("Email channel SMTP host not configured")
|
||||
return
|
||||
|
||||
to_addr = msg.chat_id.strip()
|
||||
if not to_addr:
|
||||
logger.warning("Email channel missing recipient address")
|
||||
return
|
||||
|
||||
base_subject = self._last_subject_by_chat.get(to_addr, "nanobot reply")
|
||||
subject = self._reply_subject(base_subject)
|
||||
if msg.metadata and isinstance(msg.metadata.get("subject"), str):
|
||||
override = msg.metadata["subject"].strip()
|
||||
if override:
|
||||
subject = override
|
||||
|
||||
email_msg = EmailMessage()
|
||||
email_msg["From"] = self.config.from_address or self.config.smtp_username or self.config.imap_username
|
||||
email_msg["To"] = to_addr
|
||||
email_msg["Subject"] = subject
|
||||
email_msg.set_content(msg.content or "")
|
||||
|
||||
in_reply_to = self._last_message_id_by_chat.get(to_addr)
|
||||
if in_reply_to:
|
||||
email_msg["In-Reply-To"] = in_reply_to
|
||||
email_msg["References"] = in_reply_to
|
||||
|
||||
try:
|
||||
await asyncio.to_thread(self._smtp_send, email_msg)
|
||||
except Exception as e:
|
||||
logger.error("Error sending email to {}: {}", to_addr, e)
|
||||
raise
|
||||
|
||||
def _validate_config(self) -> bool:
|
||||
missing = []
|
||||
if not self.config.imap_host:
|
||||
missing.append("imap_host")
|
||||
if not self.config.imap_username:
|
||||
missing.append("imap_username")
|
||||
if not self.config.imap_password:
|
||||
missing.append("imap_password")
|
||||
if not self.config.smtp_host:
|
||||
missing.append("smtp_host")
|
||||
if not self.config.smtp_username:
|
||||
missing.append("smtp_username")
|
||||
if not self.config.smtp_password:
|
||||
missing.append("smtp_password")
|
||||
|
||||
if missing:
|
||||
logger.error("Email channel not configured, missing: {}", ', '.join(missing))
|
||||
return False
|
||||
return True
|
||||
|
||||
def _smtp_send(self, msg: EmailMessage) -> None:
|
||||
timeout = 30
|
||||
if self.config.smtp_use_ssl:
|
||||
with smtplib.SMTP_SSL(
|
||||
self.config.smtp_host,
|
||||
self.config.smtp_port,
|
||||
timeout=timeout,
|
||||
) as smtp:
|
||||
smtp.login(self.config.smtp_username, self.config.smtp_password)
|
||||
smtp.send_message(msg)
|
||||
return
|
||||
|
||||
with smtplib.SMTP(self.config.smtp_host, self.config.smtp_port, timeout=timeout) as smtp:
|
||||
if self.config.smtp_use_tls:
|
||||
smtp.starttls(context=ssl.create_default_context())
|
||||
smtp.login(self.config.smtp_username, self.config.smtp_password)
|
||||
smtp.send_message(msg)
|
||||
|
||||
def _fetch_new_messages(self) -> list[dict[str, Any]]:
|
||||
"""Poll IMAP and return parsed unread messages."""
|
||||
return self._fetch_messages(
|
||||
search_criteria=("UNSEEN",),
|
||||
mark_seen=self.config.mark_seen,
|
||||
dedupe=True,
|
||||
limit=0,
|
||||
)
|
||||
|
||||
def fetch_messages_between_dates(
|
||||
self,
|
||||
start_date: date,
|
||||
end_date: date,
|
||||
limit: int = 20,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Fetch messages in [start_date, end_date) by IMAP date search.
|
||||
|
||||
This is used for historical summarization tasks (e.g. "yesterday").
|
||||
"""
|
||||
if end_date <= start_date:
|
||||
return []
|
||||
|
||||
return self._fetch_messages(
|
||||
search_criteria=(
|
||||
"SINCE",
|
||||
self._format_imap_date(start_date),
|
||||
"BEFORE",
|
||||
self._format_imap_date(end_date),
|
||||
),
|
||||
mark_seen=False,
|
||||
dedupe=False,
|
||||
limit=max(1, int(limit)),
|
||||
)
|
||||
|
||||
def _fetch_messages(
|
||||
self,
|
||||
search_criteria: tuple[str, ...],
|
||||
mark_seen: bool,
|
||||
dedupe: bool,
|
||||
limit: int,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Fetch messages by arbitrary IMAP search criteria."""
|
||||
messages: list[dict[str, Any]] = []
|
||||
mailbox = self.config.imap_mailbox or "INBOX"
|
||||
|
||||
if self.config.imap_use_ssl:
|
||||
client = imaplib.IMAP4_SSL(self.config.imap_host, self.config.imap_port)
|
||||
else:
|
||||
client = imaplib.IMAP4(self.config.imap_host, self.config.imap_port)
|
||||
|
||||
try:
|
||||
client.login(self.config.imap_username, self.config.imap_password)
|
||||
status, _ = client.select(mailbox)
|
||||
if status != "OK":
|
||||
return messages
|
||||
|
||||
status, data = client.search(None, *search_criteria)
|
||||
if status != "OK" or not data:
|
||||
return messages
|
||||
|
||||
ids = data[0].split()
|
||||
if limit > 0 and len(ids) > limit:
|
||||
ids = ids[-limit:]
|
||||
for imap_id in ids:
|
||||
status, fetched = client.fetch(imap_id, "(BODY.PEEK[] UID)")
|
||||
if status != "OK" or not fetched:
|
||||
continue
|
||||
|
||||
raw_bytes = self._extract_message_bytes(fetched)
|
||||
if raw_bytes is None:
|
||||
continue
|
||||
|
||||
uid = self._extract_uid(fetched)
|
||||
if dedupe and uid and uid in self._processed_uids:
|
||||
continue
|
||||
|
||||
parsed = BytesParser(policy=policy.default).parsebytes(raw_bytes)
|
||||
sender = parseaddr(parsed.get("From", ""))[1].strip().lower()
|
||||
if not sender:
|
||||
continue
|
||||
|
||||
subject = self._decode_header_value(parsed.get("Subject", ""))
|
||||
date_value = parsed.get("Date", "")
|
||||
message_id = parsed.get("Message-ID", "").strip()
|
||||
body = self._extract_text_body(parsed)
|
||||
|
||||
if not body:
|
||||
body = "(empty email body)"
|
||||
|
||||
body = body[: self.config.max_body_chars]
|
||||
content = (
|
||||
f"Email received.\n"
|
||||
f"From: {sender}\n"
|
||||
f"Subject: {subject}\n"
|
||||
f"Date: {date_value}\n\n"
|
||||
f"{body}"
|
||||
)
|
||||
|
||||
metadata = {
|
||||
"message_id": message_id,
|
||||
"subject": subject,
|
||||
"date": date_value,
|
||||
"sender_email": sender,
|
||||
"uid": uid,
|
||||
}
|
||||
messages.append(
|
||||
{
|
||||
"sender": sender,
|
||||
"subject": subject,
|
||||
"message_id": message_id,
|
||||
"content": content,
|
||||
"metadata": metadata,
|
||||
}
|
||||
)
|
||||
|
||||
if dedupe and uid:
|
||||
self._processed_uids.add(uid)
|
||||
# mark_seen is the primary dedup; this set is a safety net
|
||||
if len(self._processed_uids) > self._MAX_PROCESSED_UIDS:
|
||||
# Evict a random half to cap memory; mark_seen is the primary dedup
|
||||
self._processed_uids = set(list(self._processed_uids)[len(self._processed_uids) // 2:])
|
||||
|
||||
if mark_seen:
|
||||
client.store(imap_id, "+FLAGS", "\\Seen")
|
||||
finally:
|
||||
try:
|
||||
client.logout()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return messages
|
||||
|
||||
@classmethod
|
||||
def _format_imap_date(cls, value: date) -> str:
|
||||
"""Format date for IMAP search (always English month abbreviations)."""
|
||||
month = cls._IMAP_MONTHS[value.month - 1]
|
||||
return f"{value.day:02d}-{month}-{value.year}"
|
||||
|
||||
@staticmethod
|
||||
def _extract_message_bytes(fetched: list[Any]) -> bytes | None:
|
||||
for item in fetched:
|
||||
if isinstance(item, tuple) and len(item) >= 2 and isinstance(item[1], (bytes, bytearray)):
|
||||
return bytes(item[1])
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _extract_uid(fetched: list[Any]) -> str:
|
||||
for item in fetched:
|
||||
if isinstance(item, tuple) and item and isinstance(item[0], (bytes, bytearray)):
|
||||
head = bytes(item[0]).decode("utf-8", errors="ignore")
|
||||
m = re.search(r"UID\s+(\d+)", head)
|
||||
if m:
|
||||
return m.group(1)
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def _decode_header_value(value: str) -> str:
|
||||
if not value:
|
||||
return ""
|
||||
try:
|
||||
return str(make_header(decode_header(value)))
|
||||
except Exception:
|
||||
return value
|
||||
|
||||
@classmethod
|
||||
def _extract_text_body(cls, msg: Any) -> str:
|
||||
"""Best-effort extraction of readable body text."""
|
||||
if msg.is_multipart():
|
||||
plain_parts: list[str] = []
|
||||
html_parts: list[str] = []
|
||||
for part in msg.walk():
|
||||
if part.get_content_disposition() == "attachment":
|
||||
continue
|
||||
content_type = part.get_content_type()
|
||||
try:
|
||||
payload = part.get_content()
|
||||
except Exception:
|
||||
payload_bytes = part.get_payload(decode=True) or b""
|
||||
charset = part.get_content_charset() or "utf-8"
|
||||
payload = payload_bytes.decode(charset, errors="replace")
|
||||
if not isinstance(payload, str):
|
||||
continue
|
||||
if content_type == "text/plain":
|
||||
plain_parts.append(payload)
|
||||
elif content_type == "text/html":
|
||||
html_parts.append(payload)
|
||||
if plain_parts:
|
||||
return "\n\n".join(plain_parts).strip()
|
||||
if html_parts:
|
||||
return cls._html_to_text("\n\n".join(html_parts)).strip()
|
||||
return ""
|
||||
|
||||
try:
|
||||
payload = msg.get_content()
|
||||
except Exception:
|
||||
payload_bytes = msg.get_payload(decode=True) or b""
|
||||
charset = msg.get_content_charset() or "utf-8"
|
||||
payload = payload_bytes.decode(charset, errors="replace")
|
||||
if not isinstance(payload, str):
|
||||
return ""
|
||||
if msg.get_content_type() == "text/html":
|
||||
return cls._html_to_text(payload).strip()
|
||||
return payload.strip()
|
||||
|
||||
@staticmethod
|
||||
def _html_to_text(raw_html: str) -> str:
|
||||
text = re.sub(r"<\s*br\s*/?>", "\n", raw_html, flags=re.IGNORECASE)
|
||||
text = re.sub(r"<\s*/\s*p\s*>", "\n", text, flags=re.IGNORECASE)
|
||||
text = re.sub(r"<[^>]+>", "", text)
|
||||
return html.unescape(text)
|
||||
|
||||
def _reply_subject(self, base_subject: str) -> str:
|
||||
subject = (base_subject or "").strip() or "nanobot reply"
|
||||
prefix = self.config.subject_prefix or "Re: "
|
||||
if subject.lower().startswith("re:"):
|
||||
return subject
|
||||
return f"{prefix}{subject}"
|
||||
733
app-instance/backend/nanobot/channels/feishu.py
Normal file
733
app-instance/backend/nanobot/channels/feishu.py
Normal file
@ -0,0 +1,733 @@
|
||||
"""Feishu/Lark channel implementation using lark-oapi SDK with WebSocket long connection."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import threading
|
||||
from collections import OrderedDict
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.config.schema import FeishuConfig
|
||||
|
||||
try:
|
||||
import lark_oapi as lark
|
||||
from lark_oapi.api.im.v1 import (
|
||||
CreateFileRequest,
|
||||
CreateFileRequestBody,
|
||||
CreateImageRequest,
|
||||
CreateImageRequestBody,
|
||||
CreateMessageRequest,
|
||||
CreateMessageRequestBody,
|
||||
CreateMessageReactionRequest,
|
||||
CreateMessageReactionRequestBody,
|
||||
Emoji,
|
||||
GetFileRequest,
|
||||
GetMessageResourceRequest,
|
||||
P2ImMessageReceiveV1,
|
||||
)
|
||||
FEISHU_AVAILABLE = True
|
||||
except ImportError:
|
||||
FEISHU_AVAILABLE = False
|
||||
lark = None
|
||||
Emoji = None
|
||||
|
||||
# Message type display mapping
|
||||
MSG_TYPE_MAP = {
|
||||
"image": "[image]",
|
||||
"audio": "[audio]",
|
||||
"file": "[file]",
|
||||
"sticker": "[sticker]",
|
||||
}
|
||||
|
||||
|
||||
def _extract_share_card_content(content_json: dict, msg_type: str) -> str:
|
||||
"""Extract text representation from share cards and interactive messages."""
|
||||
parts = []
|
||||
|
||||
if msg_type == "share_chat":
|
||||
parts.append(f"[shared chat: {content_json.get('chat_id', '')}]")
|
||||
elif msg_type == "share_user":
|
||||
parts.append(f"[shared user: {content_json.get('user_id', '')}]")
|
||||
elif msg_type == "interactive":
|
||||
parts.extend(_extract_interactive_content(content_json))
|
||||
elif msg_type == "share_calendar_event":
|
||||
parts.append(f"[shared calendar event: {content_json.get('event_key', '')}]")
|
||||
elif msg_type == "system":
|
||||
parts.append("[system message]")
|
||||
elif msg_type == "merge_forward":
|
||||
parts.append("[merged forward messages]")
|
||||
|
||||
return "\n".join(parts) if parts else f"[{msg_type}]"
|
||||
|
||||
|
||||
def _extract_interactive_content(content: dict) -> list[str]:
|
||||
"""Recursively extract text and links from interactive card content."""
|
||||
parts = []
|
||||
|
||||
if isinstance(content, str):
|
||||
try:
|
||||
content = json.loads(content)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return [content] if content.strip() else []
|
||||
|
||||
if not isinstance(content, dict):
|
||||
return parts
|
||||
|
||||
if "title" in content:
|
||||
title = content["title"]
|
||||
if isinstance(title, dict):
|
||||
title_content = title.get("content", "") or title.get("text", "")
|
||||
if title_content:
|
||||
parts.append(f"title: {title_content}")
|
||||
elif isinstance(title, str):
|
||||
parts.append(f"title: {title}")
|
||||
|
||||
for element in content.get("elements", []) if isinstance(content.get("elements"), list) else []:
|
||||
parts.extend(_extract_element_content(element))
|
||||
|
||||
card = content.get("card", {})
|
||||
if card:
|
||||
parts.extend(_extract_interactive_content(card))
|
||||
|
||||
header = content.get("header", {})
|
||||
if header:
|
||||
header_title = header.get("title", {})
|
||||
if isinstance(header_title, dict):
|
||||
header_text = header_title.get("content", "") or header_title.get("text", "")
|
||||
if header_text:
|
||||
parts.append(f"title: {header_text}")
|
||||
|
||||
return parts
|
||||
|
||||
|
||||
def _extract_element_content(element: dict) -> list[str]:
|
||||
"""Extract content from a single card element."""
|
||||
parts = []
|
||||
|
||||
if not isinstance(element, dict):
|
||||
return parts
|
||||
|
||||
tag = element.get("tag", "")
|
||||
|
||||
if tag in ("markdown", "lark_md"):
|
||||
content = element.get("content", "")
|
||||
if content:
|
||||
parts.append(content)
|
||||
|
||||
elif tag == "div":
|
||||
text = element.get("text", {})
|
||||
if isinstance(text, dict):
|
||||
text_content = text.get("content", "") or text.get("text", "")
|
||||
if text_content:
|
||||
parts.append(text_content)
|
||||
elif isinstance(text, str):
|
||||
parts.append(text)
|
||||
for field in element.get("fields", []):
|
||||
if isinstance(field, dict):
|
||||
field_text = field.get("text", {})
|
||||
if isinstance(field_text, dict):
|
||||
c = field_text.get("content", "")
|
||||
if c:
|
||||
parts.append(c)
|
||||
|
||||
elif tag == "a":
|
||||
href = element.get("href", "")
|
||||
text = element.get("text", "")
|
||||
if href:
|
||||
parts.append(f"link: {href}")
|
||||
if text:
|
||||
parts.append(text)
|
||||
|
||||
elif tag == "button":
|
||||
text = element.get("text", {})
|
||||
if isinstance(text, dict):
|
||||
c = text.get("content", "")
|
||||
if c:
|
||||
parts.append(c)
|
||||
url = element.get("url", "") or element.get("multi_url", {}).get("url", "")
|
||||
if url:
|
||||
parts.append(f"link: {url}")
|
||||
|
||||
elif tag == "img":
|
||||
alt = element.get("alt", {})
|
||||
parts.append(alt.get("content", "[image]") if isinstance(alt, dict) else "[image]")
|
||||
|
||||
elif tag == "note":
|
||||
for ne in element.get("elements", []):
|
||||
parts.extend(_extract_element_content(ne))
|
||||
|
||||
elif tag == "column_set":
|
||||
for col in element.get("columns", []):
|
||||
for ce in col.get("elements", []):
|
||||
parts.extend(_extract_element_content(ce))
|
||||
|
||||
elif tag == "plain_text":
|
||||
content = element.get("content", "")
|
||||
if content:
|
||||
parts.append(content)
|
||||
|
||||
else:
|
||||
for ne in element.get("elements", []):
|
||||
parts.extend(_extract_element_content(ne))
|
||||
|
||||
return parts
|
||||
|
||||
|
||||
def _extract_post_text(content_json: dict) -> str:
|
||||
"""Extract plain text from Feishu post (rich text) message content.
|
||||
|
||||
Supports two formats:
|
||||
1. Direct format: {"title": "...", "content": [...]}
|
||||
2. Localized format: {"zh_cn": {"title": "...", "content": [...]}}
|
||||
"""
|
||||
def extract_from_lang(lang_content: dict) -> str | None:
|
||||
if not isinstance(lang_content, dict):
|
||||
return None
|
||||
title = lang_content.get("title", "")
|
||||
content_blocks = lang_content.get("content", [])
|
||||
if not isinstance(content_blocks, list):
|
||||
return None
|
||||
text_parts = []
|
||||
if title:
|
||||
text_parts.append(title)
|
||||
for block in content_blocks:
|
||||
if not isinstance(block, list):
|
||||
continue
|
||||
for element in block:
|
||||
if isinstance(element, dict):
|
||||
tag = element.get("tag")
|
||||
if tag == "text":
|
||||
text_parts.append(element.get("text", ""))
|
||||
elif tag == "a":
|
||||
text_parts.append(element.get("text", ""))
|
||||
elif tag == "at":
|
||||
text_parts.append(f"@{element.get('user_name', 'user')}")
|
||||
return " ".join(text_parts).strip() if text_parts else None
|
||||
|
||||
# Try direct format first
|
||||
if "content" in content_json:
|
||||
result = extract_from_lang(content_json)
|
||||
if result:
|
||||
return result
|
||||
|
||||
# Try localized format
|
||||
for lang_key in ("zh_cn", "en_us", "ja_jp"):
|
||||
lang_content = content_json.get(lang_key)
|
||||
result = extract_from_lang(lang_content)
|
||||
if result:
|
||||
return result
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
class FeishuChannel(BaseChannel):
|
||||
"""
|
||||
Feishu/Lark channel using WebSocket long connection.
|
||||
|
||||
Uses WebSocket to receive events - no public IP or webhook required.
|
||||
|
||||
Requires:
|
||||
- App ID and App Secret from Feishu Open Platform
|
||||
- Bot capability enabled
|
||||
- Event subscription enabled (im.message.receive_v1)
|
||||
"""
|
||||
|
||||
name = "feishu"
|
||||
|
||||
def __init__(self, config: FeishuConfig, bus: MessageBus):
|
||||
super().__init__(config, bus)
|
||||
self.config: FeishuConfig = config
|
||||
self._client: Any = None
|
||||
self._ws_client: Any = None
|
||||
self._ws_thread: threading.Thread | None = None
|
||||
self._processed_message_ids: OrderedDict[str, None] = OrderedDict() # Ordered dedup cache
|
||||
self._loop: asyncio.AbstractEventLoop | None = None
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the Feishu bot with WebSocket long connection."""
|
||||
if not FEISHU_AVAILABLE:
|
||||
logger.error("Feishu SDK not installed. Run: pip install lark-oapi")
|
||||
return
|
||||
|
||||
if not self.config.app_id or not self.config.app_secret:
|
||||
logger.error("Feishu app_id and app_secret not configured")
|
||||
return
|
||||
|
||||
self._running = True
|
||||
self._loop = asyncio.get_running_loop()
|
||||
|
||||
# Create Lark client for sending messages
|
||||
self._client = lark.Client.builder() \
|
||||
.app_id(self.config.app_id) \
|
||||
.app_secret(self.config.app_secret) \
|
||||
.log_level(lark.LogLevel.INFO) \
|
||||
.build()
|
||||
|
||||
# Create event handler (only register message receive, ignore other events)
|
||||
event_handler = lark.EventDispatcherHandler.builder(
|
||||
self.config.encrypt_key or "",
|
||||
self.config.verification_token or "",
|
||||
).register_p2_im_message_receive_v1(
|
||||
self._on_message_sync
|
||||
).build()
|
||||
|
||||
# Create WebSocket client for long connection
|
||||
self._ws_client = lark.ws.Client(
|
||||
self.config.app_id,
|
||||
self.config.app_secret,
|
||||
event_handler=event_handler,
|
||||
log_level=lark.LogLevel.INFO
|
||||
)
|
||||
|
||||
# Start WebSocket client in a separate thread with reconnect loop
|
||||
def run_ws():
|
||||
while self._running:
|
||||
try:
|
||||
self._ws_client.start()
|
||||
except Exception as e:
|
||||
logger.warning("Feishu WebSocket error: {}", e)
|
||||
if self._running:
|
||||
import time; time.sleep(5)
|
||||
|
||||
self._ws_thread = threading.Thread(target=run_ws, daemon=True)
|
||||
self._ws_thread.start()
|
||||
|
||||
logger.info("Feishu bot started with WebSocket long connection")
|
||||
logger.info("No public IP required - using WebSocket to receive events")
|
||||
|
||||
# Keep running until stopped
|
||||
while self._running:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the Feishu bot."""
|
||||
self._running = False
|
||||
if self._ws_client:
|
||||
try:
|
||||
self._ws_client.stop()
|
||||
except Exception as e:
|
||||
logger.warning("Error stopping WebSocket client: {}", e)
|
||||
logger.info("Feishu bot stopped")
|
||||
|
||||
def _add_reaction_sync(self, message_id: str, emoji_type: str) -> None:
|
||||
"""Sync helper for adding reaction (runs in thread pool)."""
|
||||
try:
|
||||
request = CreateMessageReactionRequest.builder() \
|
||||
.message_id(message_id) \
|
||||
.request_body(
|
||||
CreateMessageReactionRequestBody.builder()
|
||||
.reaction_type(Emoji.builder().emoji_type(emoji_type).build())
|
||||
.build()
|
||||
).build()
|
||||
|
||||
response = self._client.im.v1.message_reaction.create(request)
|
||||
|
||||
if not response.success():
|
||||
logger.warning("Failed to add reaction: code={}, msg={}", response.code, response.msg)
|
||||
else:
|
||||
logger.debug("Added {} reaction to message {}", emoji_type, message_id)
|
||||
except Exception as e:
|
||||
logger.warning("Error adding reaction: {}", e)
|
||||
|
||||
async def _add_reaction(self, message_id: str, emoji_type: str = "THUMBSUP") -> None:
|
||||
"""
|
||||
Add a reaction emoji to a message (non-blocking).
|
||||
|
||||
Common emoji types: THUMBSUP, OK, EYES, DONE, OnIt, HEART
|
||||
"""
|
||||
if not self._client or not Emoji:
|
||||
return
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
await loop.run_in_executor(None, self._add_reaction_sync, message_id, emoji_type)
|
||||
|
||||
# Regex to match markdown tables (header + separator + data rows)
|
||||
_TABLE_RE = re.compile(
|
||||
r"((?:^[ \t]*\|.+\|[ \t]*\n)(?:^[ \t]*\|[-:\s|]+\|[ \t]*\n)(?:^[ \t]*\|.+\|[ \t]*\n?)+)",
|
||||
re.MULTILINE,
|
||||
)
|
||||
|
||||
_HEADING_RE = re.compile(r"^(#{1,6})\s+(.+)$", re.MULTILINE)
|
||||
|
||||
_CODE_BLOCK_RE = re.compile(r"(```[\s\S]*?```)", re.MULTILINE)
|
||||
|
||||
@staticmethod
|
||||
def _parse_md_table(table_text: str) -> dict | None:
|
||||
"""Parse a markdown table into a Feishu table element."""
|
||||
lines = [l.strip() for l in table_text.strip().split("\n") if l.strip()]
|
||||
if len(lines) < 3:
|
||||
return None
|
||||
split = lambda l: [c.strip() for c in l.strip("|").split("|")]
|
||||
headers = split(lines[0])
|
||||
rows = [split(l) for l in lines[2:]]
|
||||
columns = [{"tag": "column", "name": f"c{i}", "display_name": h, "width": "auto"}
|
||||
for i, h in enumerate(headers)]
|
||||
return {
|
||||
"tag": "table",
|
||||
"page_size": len(rows) + 1,
|
||||
"columns": columns,
|
||||
"rows": [{f"c{i}": r[i] if i < len(r) else "" for i in range(len(headers))} for r in rows],
|
||||
}
|
||||
|
||||
def _build_card_elements(self, content: str) -> list[dict]:
|
||||
"""Split content into div/markdown + table elements for Feishu card."""
|
||||
elements, last_end = [], 0
|
||||
for m in self._TABLE_RE.finditer(content):
|
||||
before = content[last_end:m.start()]
|
||||
if before.strip():
|
||||
elements.extend(self._split_headings(before))
|
||||
elements.append(self._parse_md_table(m.group(1)) or {"tag": "markdown", "content": m.group(1)})
|
||||
last_end = m.end()
|
||||
remaining = content[last_end:]
|
||||
if remaining.strip():
|
||||
elements.extend(self._split_headings(remaining))
|
||||
return elements or [{"tag": "markdown", "content": content}]
|
||||
|
||||
def _split_headings(self, content: str) -> list[dict]:
|
||||
"""Split content by headings, converting headings to div elements."""
|
||||
protected = content
|
||||
code_blocks = []
|
||||
for m in self._CODE_BLOCK_RE.finditer(content):
|
||||
code_blocks.append(m.group(1))
|
||||
protected = protected.replace(m.group(1), f"\x00CODE{len(code_blocks)-1}\x00", 1)
|
||||
|
||||
elements = []
|
||||
last_end = 0
|
||||
for m in self._HEADING_RE.finditer(protected):
|
||||
before = protected[last_end:m.start()].strip()
|
||||
if before:
|
||||
elements.append({"tag": "markdown", "content": before})
|
||||
text = m.group(2).strip()
|
||||
elements.append({
|
||||
"tag": "div",
|
||||
"text": {
|
||||
"tag": "lark_md",
|
||||
"content": f"**{text}**",
|
||||
},
|
||||
})
|
||||
last_end = m.end()
|
||||
remaining = protected[last_end:].strip()
|
||||
if remaining:
|
||||
elements.append({"tag": "markdown", "content": remaining})
|
||||
|
||||
for i, cb in enumerate(code_blocks):
|
||||
for el in elements:
|
||||
if el.get("tag") == "markdown":
|
||||
el["content"] = el["content"].replace(f"\x00CODE{i}\x00", cb)
|
||||
|
||||
return elements or [{"tag": "markdown", "content": content}]
|
||||
|
||||
_IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp", ".ico", ".tiff", ".tif"}
|
||||
_AUDIO_EXTS = {".opus"}
|
||||
_FILE_TYPE_MAP = {
|
||||
".opus": "opus", ".mp4": "mp4", ".pdf": "pdf", ".doc": "doc", ".docx": "doc",
|
||||
".xls": "xls", ".xlsx": "xls", ".ppt": "ppt", ".pptx": "ppt",
|
||||
}
|
||||
|
||||
def _upload_image_sync(self, file_path: str) -> str | None:
|
||||
"""Upload an image to Feishu and return the image_key."""
|
||||
try:
|
||||
with open(file_path, "rb") as f:
|
||||
request = CreateImageRequest.builder() \
|
||||
.request_body(
|
||||
CreateImageRequestBody.builder()
|
||||
.image_type("message")
|
||||
.image(f)
|
||||
.build()
|
||||
).build()
|
||||
response = self._client.im.v1.image.create(request)
|
||||
if response.success():
|
||||
image_key = response.data.image_key
|
||||
logger.debug("Uploaded image {}: {}", os.path.basename(file_path), image_key)
|
||||
return image_key
|
||||
else:
|
||||
logger.error("Failed to upload image: code={}, msg={}", response.code, response.msg)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error("Error uploading image {}: {}", file_path, e)
|
||||
return None
|
||||
|
||||
def _upload_file_sync(self, file_path: str) -> str | None:
|
||||
"""Upload a file to Feishu and return the file_key."""
|
||||
ext = os.path.splitext(file_path)[1].lower()
|
||||
file_type = self._FILE_TYPE_MAP.get(ext, "stream")
|
||||
file_name = os.path.basename(file_path)
|
||||
try:
|
||||
with open(file_path, "rb") as f:
|
||||
request = CreateFileRequest.builder() \
|
||||
.request_body(
|
||||
CreateFileRequestBody.builder()
|
||||
.file_type(file_type)
|
||||
.file_name(file_name)
|
||||
.file(f)
|
||||
.build()
|
||||
).build()
|
||||
response = self._client.im.v1.file.create(request)
|
||||
if response.success():
|
||||
file_key = response.data.file_key
|
||||
logger.debug("Uploaded file {}: {}", file_name, file_key)
|
||||
return file_key
|
||||
else:
|
||||
logger.error("Failed to upload file: code={}, msg={}", response.code, response.msg)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error("Error uploading file {}: {}", file_path, e)
|
||||
return None
|
||||
|
||||
def _download_image_sync(self, message_id: str, image_key: str) -> tuple[bytes | None, str | None]:
|
||||
"""Download an image from Feishu message by message_id and image_key."""
|
||||
try:
|
||||
request = GetMessageResourceRequest.builder() \
|
||||
.message_id(message_id) \
|
||||
.file_key(image_key) \
|
||||
.type("image") \
|
||||
.build()
|
||||
response = self._client.im.v1.message_resource.get(request)
|
||||
if response.success():
|
||||
file_data = response.file
|
||||
# GetMessageResourceRequest returns BytesIO, need to read bytes
|
||||
if hasattr(file_data, 'read'):
|
||||
file_data = file_data.read()
|
||||
return file_data, response.file_name
|
||||
else:
|
||||
logger.error("Failed to download image: code={}, msg={}", response.code, response.msg)
|
||||
return None, None
|
||||
except Exception as e:
|
||||
logger.error("Error downloading image {}: {}", image_key, e)
|
||||
return None, None
|
||||
|
||||
def _download_file_sync(
|
||||
self, message_id: str, file_key: str, resource_type: str = "file"
|
||||
) -> tuple[bytes | None, str | None]:
|
||||
"""Download a file/audio/media from a Feishu message by message_id and file_key."""
|
||||
try:
|
||||
request = (
|
||||
GetMessageResourceRequest.builder()
|
||||
.message_id(message_id)
|
||||
.file_key(file_key)
|
||||
.type(resource_type)
|
||||
.build()
|
||||
)
|
||||
response = self._client.im.v1.message_resource.get(request)
|
||||
if response.success():
|
||||
file_data = response.file
|
||||
if hasattr(file_data, "read"):
|
||||
file_data = file_data.read()
|
||||
return file_data, response.file_name
|
||||
else:
|
||||
logger.error("Failed to download {}: code={}, msg={}", resource_type, response.code, response.msg)
|
||||
return None, None
|
||||
except Exception:
|
||||
logger.exception("Error downloading {} {}", resource_type, file_key)
|
||||
return None, None
|
||||
|
||||
async def _download_and_save_media(
|
||||
self,
|
||||
msg_type: str,
|
||||
content_json: dict,
|
||||
message_id: str | None = None
|
||||
) -> tuple[str | None, str]:
|
||||
"""
|
||||
Download media from Feishu and save to local disk.
|
||||
|
||||
Returns:
|
||||
(file_path, content_text) - file_path is None if download failed
|
||||
"""
|
||||
loop = asyncio.get_running_loop()
|
||||
media_dir = Path.home() / ".nanobot" / "media"
|
||||
media_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
data, filename = None, None
|
||||
|
||||
if msg_type == "image":
|
||||
image_key = content_json.get("image_key")
|
||||
if image_key and message_id:
|
||||
data, filename = await loop.run_in_executor(
|
||||
None, self._download_image_sync, message_id, image_key
|
||||
)
|
||||
if not filename:
|
||||
filename = f"{image_key[:16]}.jpg"
|
||||
|
||||
elif msg_type in ("audio", "file", "media"):
|
||||
file_key = content_json.get("file_key")
|
||||
if file_key and message_id:
|
||||
data, filename = await loop.run_in_executor(
|
||||
None, self._download_file_sync, message_id, file_key, msg_type
|
||||
)
|
||||
if not filename:
|
||||
ext = {"audio": ".opus", "media": ".mp4"}.get(msg_type, "")
|
||||
filename = f"{file_key[:16]}{ext}"
|
||||
|
||||
if data and filename:
|
||||
file_path = media_dir / filename
|
||||
file_path.write_bytes(data)
|
||||
logger.debug("Downloaded {} to {}", msg_type, file_path)
|
||||
return str(file_path), f"[{msg_type}: {filename}]"
|
||||
|
||||
return None, f"[{msg_type}: download failed]"
|
||||
|
||||
def _send_message_sync(self, receive_id_type: str, receive_id: str, msg_type: str, content: str) -> bool:
|
||||
"""Send a single message (text/image/file/interactive) synchronously."""
|
||||
try:
|
||||
request = CreateMessageRequest.builder() \
|
||||
.receive_id_type(receive_id_type) \
|
||||
.request_body(
|
||||
CreateMessageRequestBody.builder()
|
||||
.receive_id(receive_id)
|
||||
.msg_type(msg_type)
|
||||
.content(content)
|
||||
.build()
|
||||
).build()
|
||||
response = self._client.im.v1.message.create(request)
|
||||
if not response.success():
|
||||
logger.error(
|
||||
"Failed to send Feishu {} message: code={}, msg={}, log_id={}",
|
||||
msg_type, response.code, response.msg, response.get_log_id()
|
||||
)
|
||||
return False
|
||||
logger.debug("Feishu {} message sent to {}", msg_type, receive_id)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error("Error sending Feishu {} message: {}", msg_type, e)
|
||||
return False
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
"""Send a message through Feishu, including media (images/files) if present."""
|
||||
if not self._client:
|
||||
logger.warning("Feishu client not initialized")
|
||||
return
|
||||
|
||||
try:
|
||||
receive_id_type = "chat_id" if msg.chat_id.startswith("oc_") else "open_id"
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
for file_path in msg.media:
|
||||
if not os.path.isfile(file_path):
|
||||
logger.warning("Media file not found: {}", file_path)
|
||||
continue
|
||||
ext = os.path.splitext(file_path)[1].lower()
|
||||
if ext in self._IMAGE_EXTS:
|
||||
key = await loop.run_in_executor(None, self._upload_image_sync, file_path)
|
||||
if key:
|
||||
await loop.run_in_executor(
|
||||
None, self._send_message_sync,
|
||||
receive_id_type, msg.chat_id, "image", json.dumps({"image_key": key}, ensure_ascii=False),
|
||||
)
|
||||
else:
|
||||
key = await loop.run_in_executor(None, self._upload_file_sync, file_path)
|
||||
if key:
|
||||
media_type = "audio" if ext in self._AUDIO_EXTS else "file"
|
||||
await loop.run_in_executor(
|
||||
None, self._send_message_sync,
|
||||
receive_id_type, msg.chat_id, media_type, json.dumps({"file_key": key}, ensure_ascii=False),
|
||||
)
|
||||
|
||||
if msg.content and msg.content.strip():
|
||||
card = {"config": {"wide_screen_mode": True}, "elements": self._build_card_elements(msg.content)}
|
||||
await loop.run_in_executor(
|
||||
None, self._send_message_sync,
|
||||
receive_id_type, msg.chat_id, "interactive", json.dumps(card, ensure_ascii=False),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error sending Feishu message: {}", e)
|
||||
|
||||
def _on_message_sync(self, data: "P2ImMessageReceiveV1") -> None:
|
||||
"""
|
||||
Sync handler for incoming messages (called from WebSocket thread).
|
||||
Schedules async handling in the main event loop.
|
||||
"""
|
||||
if self._loop and self._loop.is_running():
|
||||
asyncio.run_coroutine_threadsafe(self._on_message(data), self._loop)
|
||||
|
||||
async def _on_message(self, data: "P2ImMessageReceiveV1") -> None:
|
||||
"""Handle incoming message from Feishu."""
|
||||
try:
|
||||
event = data.event
|
||||
message = event.message
|
||||
sender = event.sender
|
||||
|
||||
# Deduplication check
|
||||
message_id = message.message_id
|
||||
if message_id in self._processed_message_ids:
|
||||
return
|
||||
self._processed_message_ids[message_id] = None
|
||||
|
||||
# Trim cache
|
||||
while len(self._processed_message_ids) > 1000:
|
||||
self._processed_message_ids.popitem(last=False)
|
||||
|
||||
# Skip bot messages
|
||||
if sender.sender_type == "bot":
|
||||
return
|
||||
|
||||
sender_id = sender.sender_id.open_id if sender.sender_id else "unknown"
|
||||
chat_id = message.chat_id
|
||||
chat_type = message.chat_type
|
||||
msg_type = message.message_type
|
||||
|
||||
# Add reaction
|
||||
await self._add_reaction(message_id, "THUMBSUP")
|
||||
|
||||
# Parse content
|
||||
content_parts = []
|
||||
media_paths = []
|
||||
|
||||
try:
|
||||
content_json = json.loads(message.content) if message.content else {}
|
||||
except json.JSONDecodeError:
|
||||
content_json = {}
|
||||
|
||||
if msg_type == "text":
|
||||
text = content_json.get("text", "")
|
||||
if text:
|
||||
content_parts.append(text)
|
||||
|
||||
elif msg_type == "post":
|
||||
text = _extract_post_text(content_json)
|
||||
if text:
|
||||
content_parts.append(text)
|
||||
|
||||
elif msg_type in ("image", "audio", "file", "media"):
|
||||
file_path, content_text = await self._download_and_save_media(msg_type, content_json, message_id)
|
||||
if file_path:
|
||||
media_paths.append(file_path)
|
||||
content_parts.append(content_text)
|
||||
|
||||
elif msg_type in ("share_chat", "share_user", "interactive", "share_calendar_event", "system", "merge_forward"):
|
||||
# Handle share cards and interactive messages
|
||||
text = _extract_share_card_content(content_json, msg_type)
|
||||
if text:
|
||||
content_parts.append(text)
|
||||
|
||||
else:
|
||||
content_parts.append(MSG_TYPE_MAP.get(msg_type, f"[{msg_type}]"))
|
||||
|
||||
content = "\n".join(content_parts) if content_parts else ""
|
||||
|
||||
if not content and not media_paths:
|
||||
return
|
||||
|
||||
# Forward to message bus
|
||||
reply_to = chat_id if chat_type == "group" else sender_id
|
||||
await self._handle_message(
|
||||
sender_id=sender_id,
|
||||
chat_id=reply_to,
|
||||
content=content,
|
||||
media=media_paths,
|
||||
metadata={
|
||||
"message_id": message_id,
|
||||
"chat_type": chat_type,
|
||||
"msg_type": msg_type,
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error processing Feishu message: {}", e)
|
||||
326
app-instance/backend/nanobot/channels/manager.py
Normal file
326
app-instance/backend/nanobot/channels/manager.py
Normal file
@ -0,0 +1,326 @@
|
||||
"""渠道管理器:统一管理多聊天渠道的生命周期与消息路由。
|
||||
|
||||
本模块处在“Agent 核心逻辑”和“外部 IM 平台”之间,承担两类关键职责:
|
||||
1. 渠道生命周期管理:
|
||||
- 按配置初始化可用渠道(Telegram/Slack/Discord/WhatsApp/...);
|
||||
- 统一启动与停止,避免各渠道在 CLI 层分散管理。
|
||||
2. 出站消息分发:
|
||||
- 从 MessageBus 的 outbound 队列读取消息;
|
||||
- 根据 `msg.channel` 路由到目标渠道对象并执行 `send(...)`;
|
||||
- 对进度消息(_progress/_tool_hint)按全局开关过滤。
|
||||
|
||||
设计原则:
|
||||
- 渠道失败隔离:单个渠道启动/发送失败不应拖垮其它渠道;
|
||||
- 配置驱动:是否启用由 `config.channels.*.enabled` 决定;
|
||||
- 统一入口:上层只需与 MessageBus 交互,不关心各渠道细节。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.config.schema import Config
|
||||
|
||||
|
||||
class ChannelManager:
|
||||
"""
|
||||
渠道协调器。
|
||||
|
||||
你可以把它看成一个“渠道运行时容器”:
|
||||
- `self.channels` 保存已启用渠道实例;
|
||||
- `_dispatch_outbound()` 作为中央分发协程持续消费 outbound 消息;
|
||||
- `start_all()/stop_all()` 负责渠道与分发协程的统一启停。
|
||||
|
||||
与 AgentLoop 的关系:
|
||||
- AgentLoop 只负责“生成 OutboundMessage”;
|
||||
- ChannelManager 负责“把 OutboundMessage 真的发出去”。
|
||||
"""
|
||||
|
||||
def __init__(self, config: Config, bus: MessageBus):
|
||||
# 全局配置(含渠道开关、进度消息开关等)
|
||||
self.config = config
|
||||
# 与 AgentLoop 共享同一 MessageBus,负责消费 outbound。
|
||||
self.bus = bus
|
||||
# name -> channel instance(只存启用且成功初始化的渠道)
|
||||
self.channels: dict[str, BaseChannel] = {}
|
||||
# 出站分发后台任务句柄(由 start_all 创建,stop_all 取消)
|
||||
self._dispatch_task: asyncio.Task | None = None
|
||||
|
||||
# 构造时即按配置初始化渠道实例(不启动网络连接,仅实例化)。
|
||||
self._init_channels()
|
||||
|
||||
def _init_channels(self) -> None:
|
||||
"""按配置初始化渠道实例。
|
||||
|
||||
注意:
|
||||
- 这里只做“实例化”,不会进入各渠道的 start() 主循环;
|
||||
- ImportError 会被捕获并记录 warning,允许缺依赖时降级运行;
|
||||
- 未启用渠道不会创建实例,也不会出现在 enabled_channels 列表里。
|
||||
"""
|
||||
|
||||
# Telegram 渠道:
|
||||
# - 需要 telegram 配置开启;
|
||||
# - 额外透传 groq_api_key(用于语音/转写等能力时按渠道内部策略使用)。
|
||||
if self.config.channels.telegram.enabled:
|
||||
try:
|
||||
from nanobot.channels.telegram import TelegramChannel
|
||||
self.channels["telegram"] = TelegramChannel(
|
||||
self.config.channels.telegram,
|
||||
self.bus,
|
||||
groq_api_key=self.config.providers.groq.api_key,
|
||||
)
|
||||
logger.info("Telegram channel enabled")
|
||||
except ImportError as e:
|
||||
logger.warning("Telegram channel not available: {}", e)
|
||||
|
||||
# WhatsApp 渠道(通过 bridge 连接)
|
||||
if self.config.channels.whatsapp.enabled:
|
||||
try:
|
||||
from nanobot.channels.whatsapp import WhatsAppChannel
|
||||
self.channels["whatsapp"] = WhatsAppChannel(
|
||||
self.config.channels.whatsapp, self.bus
|
||||
)
|
||||
logger.info("WhatsApp channel enabled")
|
||||
except ImportError as e:
|
||||
logger.warning("WhatsApp channel not available: {}", e)
|
||||
|
||||
# Discord 渠道
|
||||
if self.config.channels.discord.enabled:
|
||||
try:
|
||||
from nanobot.channels.discord import DiscordChannel
|
||||
self.channels["discord"] = DiscordChannel(
|
||||
self.config.channels.discord, self.bus
|
||||
)
|
||||
logger.info("Discord channel enabled")
|
||||
except ImportError as e:
|
||||
logger.warning("Discord channel not available: {}", e)
|
||||
|
||||
# 飞书 / Lark 渠道
|
||||
if self.config.channels.feishu.enabled:
|
||||
try:
|
||||
from nanobot.channels.feishu import FeishuChannel
|
||||
self.channels["feishu"] = FeishuChannel(
|
||||
self.config.channels.feishu, self.bus
|
||||
)
|
||||
logger.info("Feishu channel enabled")
|
||||
except ImportError as e:
|
||||
logger.warning("Feishu channel not available: {}", e)
|
||||
|
||||
# Mochat 渠道
|
||||
if self.config.channels.mochat.enabled:
|
||||
try:
|
||||
from nanobot.channels.mochat import MochatChannel
|
||||
|
||||
self.channels["mochat"] = MochatChannel(
|
||||
self.config.channels.mochat, self.bus
|
||||
)
|
||||
logger.info("Mochat channel enabled")
|
||||
except ImportError as e:
|
||||
logger.warning("Mochat channel not available: {}", e)
|
||||
|
||||
# 钉钉渠道
|
||||
if self.config.channels.dingtalk.enabled:
|
||||
try:
|
||||
from nanobot.channels.dingtalk import DingTalkChannel
|
||||
self.channels["dingtalk"] = DingTalkChannel(
|
||||
self.config.channels.dingtalk, self.bus
|
||||
)
|
||||
logger.info("DingTalk channel enabled")
|
||||
except ImportError as e:
|
||||
logger.warning("DingTalk channel not available: {}", e)
|
||||
|
||||
# Email 渠道(IMAP 收件 + SMTP 发件)
|
||||
if self.config.channels.email.enabled:
|
||||
try:
|
||||
from nanobot.channels.email import EmailChannel
|
||||
self.channels["email"] = EmailChannel(
|
||||
self.config.channels.email, self.bus
|
||||
)
|
||||
logger.info("Email channel enabled")
|
||||
except ImportError as e:
|
||||
logger.warning("Email channel not available: {}", e)
|
||||
|
||||
# Slack 渠道
|
||||
if self.config.channels.slack.enabled:
|
||||
try:
|
||||
from nanobot.channels.slack import SlackChannel
|
||||
self.channels["slack"] = SlackChannel(
|
||||
self.config.channels.slack, self.bus
|
||||
)
|
||||
logger.info("Slack channel enabled")
|
||||
except ImportError as e:
|
||||
logger.warning("Slack channel not available: {}", e)
|
||||
|
||||
# QQ 渠道
|
||||
if self.config.channels.qq.enabled:
|
||||
try:
|
||||
from nanobot.channels.qq import QQChannel
|
||||
self.channels["qq"] = QQChannel(
|
||||
self.config.channels.qq,
|
||||
self.bus,
|
||||
)
|
||||
logger.info("QQ channel enabled")
|
||||
except ImportError as e:
|
||||
logger.warning("QQ channel not available: {}", e)
|
||||
|
||||
# Matrix 渠道
|
||||
if self.config.channels.matrix.enabled:
|
||||
try:
|
||||
from nanobot.channels.matrix import MatrixChannel
|
||||
self.channels["matrix"] = MatrixChannel(
|
||||
self.config.channels.matrix,
|
||||
self.bus,
|
||||
groq_api_key=self.config.providers.groq.api_key,
|
||||
)
|
||||
logger.info("Matrix channel enabled")
|
||||
except ImportError as e:
|
||||
logger.warning("Matrix channel not available: {}", e)
|
||||
|
||||
async def _start_channel(self, name: str, channel: BaseChannel) -> None:
|
||||
"""启动单个渠道并隔离异常。
|
||||
|
||||
设计意图:
|
||||
- 不让一个渠道的启动失败影响其它渠道启动;
|
||||
- 错误统一记录日志,方便后续定位具体渠道问题。
|
||||
"""
|
||||
try:
|
||||
await channel.start()
|
||||
except Exception as e:
|
||||
logger.error("Failed to start channel {}: {}", name, e)
|
||||
|
||||
async def start_all(self) -> None:
|
||||
"""启动所有渠道与出站分发协程。
|
||||
|
||||
启动顺序:
|
||||
1. 启动 outbound 分发任务(先就绪,避免启动早期消息丢失);
|
||||
2. 并发启动所有渠道 start() 协程;
|
||||
3. `gather` 挂住,直到渠道协程返回(正常应长期运行)。
|
||||
"""
|
||||
if not self.channels:
|
||||
logger.warning("No channels enabled")
|
||||
return
|
||||
|
||||
# 启动出站分发协程:负责消费 bus.outbound 并调用 channel.send()。
|
||||
self._dispatch_task = asyncio.create_task(self._dispatch_outbound())
|
||||
|
||||
# 启动渠道主循环。
|
||||
tasks = []
|
||||
for name, channel in self.channels.items():
|
||||
logger.info("Starting {} channel...", name)
|
||||
tasks.append(asyncio.create_task(self._start_channel(name, channel)))
|
||||
|
||||
# 等待所有渠道任务(理论上它们应常驻直到 stop_all 被调用)。
|
||||
# return_exceptions=True 可避免一个任务异常导致 gather 整体中断。
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
async def stop_all(self) -> None:
|
||||
"""停止所有渠道并关闭出站分发任务。
|
||||
|
||||
停止顺序:
|
||||
1. 先取消分发协程,避免继续从队列取消息;
|
||||
2. 再逐个 stop 渠道,释放各自连接/资源;
|
||||
3. 各渠道停止异常仅记录,不影响其它渠道收尾。
|
||||
"""
|
||||
logger.info("Stopping all channels...")
|
||||
|
||||
# 停止分发协程。
|
||||
if self._dispatch_task:
|
||||
self._dispatch_task.cancel()
|
||||
try:
|
||||
await self._dispatch_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# 停止所有渠道实例。
|
||||
for name, channel in self.channels.items():
|
||||
try:
|
||||
await channel.stop()
|
||||
logger.info("Stopped {} channel", name)
|
||||
except Exception as e:
|
||||
logger.error("Error stopping {}: {}", name, e)
|
||||
|
||||
async def _dispatch_outbound(self) -> None:
|
||||
"""消费 outbound 队列并路由发送到对应渠道。
|
||||
|
||||
分发规则:
|
||||
- `msg.channel` 决定目标渠道实例;
|
||||
- 若渠道不存在,记录 warning(通常表示渠道未启用或名称不匹配);
|
||||
- 进度消息可被全局开关过滤(send_progress / send_tool_hints)。
|
||||
|
||||
循环模型:
|
||||
- 使用 `wait_for(..., timeout=1.0)` 做短超时轮询,
|
||||
便于 stop_all 取消后快速退出;
|
||||
- Timeout 属于正常空闲态,不视为错误。
|
||||
"""
|
||||
logger.info("Outbound dispatcher started")
|
||||
|
||||
while True:
|
||||
try:
|
||||
# 从总线获取一条待发送消息;短超时保证可取消性。
|
||||
msg = await asyncio.wait_for(
|
||||
self.bus.consume_outbound(),
|
||||
timeout=1.0
|
||||
)
|
||||
|
||||
# 进度消息过滤:
|
||||
# - _progress=True 且 _tool_hint=True 受 send_tool_hints 控制
|
||||
# - _progress=True 且非工具提示受 send_progress 控制
|
||||
# 这样可以在渠道侧按需静默“中间态”,只保留最终回复。
|
||||
if msg.metadata.get("_progress"):
|
||||
if msg.metadata.get("_tool_hint") and not self.config.channels.send_tool_hints:
|
||||
continue
|
||||
if not msg.metadata.get("_tool_hint") and not self.config.channels.send_progress:
|
||||
continue
|
||||
|
||||
# 按 channel 名路由发送。
|
||||
channel = self.channels.get(msg.channel)
|
||||
if channel:
|
||||
try:
|
||||
# 实际发送由各渠道实现(统一接口:BaseChannel.send)。
|
||||
await channel.send(msg)
|
||||
except Exception as e:
|
||||
# 单条发送失败不终止分发循环,避免“全局停摆”。
|
||||
logger.error("Error sending to {}: {}", msg.channel, e)
|
||||
else:
|
||||
logger.warning("Unknown channel: {}", msg.channel)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
# 队列暂时无消息:继续下一轮轮询。
|
||||
continue
|
||||
except asyncio.CancelledError:
|
||||
# stop_all 取消任务时走这里退出循环。
|
||||
break
|
||||
|
||||
def get_channel(self, name: str) -> BaseChannel | None:
|
||||
"""按名称获取渠道实例(未启用/不存在返回 None)。"""
|
||||
return self.channels.get(name)
|
||||
|
||||
def get_status(self) -> dict[str, Any]:
|
||||
"""返回所有已启用渠道的运行状态快照。
|
||||
|
||||
返回结构示例:
|
||||
{
|
||||
"telegram": {"enabled": True, "running": True},
|
||||
"slack": {"enabled": True, "running": False},
|
||||
}
|
||||
"""
|
||||
return {
|
||||
name: {
|
||||
# 出现在 self.channels 里即表示“配置层已启用且实例化成功”。
|
||||
"enabled": True,
|
||||
# running 由渠道实例自身维护,反映连接/主循环当前状态。
|
||||
"running": channel.is_running
|
||||
}
|
||||
for name, channel in self.channels.items()
|
||||
}
|
||||
|
||||
@property
|
||||
def enabled_channels(self) -> list[str]:
|
||||
"""返回当前已启用并成功初始化的渠道名称列表。"""
|
||||
return list(self.channels.keys())
|
||||
733
app-instance/backend/nanobot/channels/matrix.py
Normal file
733
app-instance/backend/nanobot/channels/matrix.py
Normal file
@ -0,0 +1,733 @@
|
||||
"""Matrix (Element) channel — inbound sync + outbound message/media delivery."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import mimetypes
|
||||
from pathlib import Path
|
||||
from typing import Any, TypeAlias
|
||||
|
||||
from loguru import logger
|
||||
|
||||
try:
|
||||
import nh3
|
||||
from mistune import create_markdown
|
||||
from nio import (
|
||||
AsyncClient,
|
||||
AsyncClientConfig,
|
||||
ContentRepositoryConfigError,
|
||||
DownloadError,
|
||||
InviteEvent,
|
||||
JoinError,
|
||||
MatrixRoom,
|
||||
MemoryDownloadResponse,
|
||||
RoomEncryptedMedia,
|
||||
RoomMessage,
|
||||
RoomMessageMedia,
|
||||
RoomMessageText,
|
||||
RoomSendError,
|
||||
RoomTypingError,
|
||||
SyncResponse,
|
||||
SyncError,
|
||||
UploadError,
|
||||
)
|
||||
from nio.crypto.attachments import decrypt_attachment
|
||||
from nio.exceptions import EncryptionError
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Matrix dependencies not installed. Run: pip install nanobot-ai[matrix]"
|
||||
) from e
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.config.paths import get_data_dir, get_media_dir
|
||||
from nanobot.providers.transcription import GroqTranscriptionProvider
|
||||
from nanobot.utils.helpers import safe_filename
|
||||
|
||||
TYPING_NOTICE_TIMEOUT_MS = 30_000
|
||||
# Must stay below TYPING_NOTICE_TIMEOUT_MS so the indicator doesn't expire mid-processing.
|
||||
TYPING_KEEPALIVE_INTERVAL_MS = 20_000
|
||||
MATRIX_HTML_FORMAT = "org.matrix.custom.html"
|
||||
_ATTACH_MARKER = "[attachment: {}]"
|
||||
_ATTACH_TOO_LARGE = "[attachment: {} - too large]"
|
||||
_ATTACH_FAILED = "[attachment: {} - download failed]"
|
||||
_ATTACH_UPLOAD_FAILED = "[attachment: {} - upload failed]"
|
||||
_DEFAULT_ATTACH_NAME = "attachment"
|
||||
_MSGTYPE_MAP = {"m.image": "image", "m.audio": "audio", "m.video": "video", "m.file": "file"}
|
||||
|
||||
MATRIX_MEDIA_EVENT_FILTER = (RoomMessageMedia, RoomEncryptedMedia)
|
||||
MatrixMediaEvent: TypeAlias = RoomMessageMedia | RoomEncryptedMedia
|
||||
|
||||
MATRIX_MARKDOWN = create_markdown(
|
||||
escape=True,
|
||||
plugins=["table", "strikethrough", "url", "superscript", "subscript"],
|
||||
)
|
||||
|
||||
MATRIX_ALLOWED_HTML_TAGS = {
|
||||
"p", "a", "strong", "em", "del", "code", "pre", "blockquote",
|
||||
"ul", "ol", "li", "h1", "h2", "h3", "h4", "h5", "h6",
|
||||
"hr", "br", "table", "thead", "tbody", "tr", "th", "td",
|
||||
"caption", "sup", "sub", "img",
|
||||
}
|
||||
MATRIX_ALLOWED_HTML_ATTRIBUTES: dict[str, set[str]] = {
|
||||
"a": {"href"}, "code": {"class"}, "ol": {"start"},
|
||||
"img": {"src", "alt", "title", "width", "height"},
|
||||
}
|
||||
MATRIX_ALLOWED_URL_SCHEMES = {"https", "http", "matrix", "mailto", "mxc"}
|
||||
|
||||
|
||||
def _filter_matrix_html_attribute(tag: str, attr: str, value: str) -> str | None:
|
||||
"""Filter attribute values to a safe Matrix-compatible subset."""
|
||||
if tag == "a" and attr == "href":
|
||||
return value if value.lower().startswith(("https://", "http://", "matrix:", "mailto:")) else None
|
||||
if tag == "img" and attr == "src":
|
||||
return value if value.lower().startswith("mxc://") else None
|
||||
if tag == "code" and attr == "class":
|
||||
classes = [c for c in value.split() if c.startswith("language-") and not c.startswith("language-_")]
|
||||
return " ".join(classes) if classes else None
|
||||
return value
|
||||
|
||||
|
||||
MATRIX_HTML_CLEANER = nh3.Cleaner(
|
||||
tags=MATRIX_ALLOWED_HTML_TAGS,
|
||||
attributes=MATRIX_ALLOWED_HTML_ATTRIBUTES,
|
||||
attribute_filter=_filter_matrix_html_attribute,
|
||||
url_schemes=MATRIX_ALLOWED_URL_SCHEMES,
|
||||
strip_comments=True,
|
||||
link_rel="noopener noreferrer",
|
||||
)
|
||||
|
||||
|
||||
def _render_markdown_html(text: str) -> str | None:
|
||||
"""Render markdown to sanitized HTML; returns None for plain text."""
|
||||
try:
|
||||
formatted = MATRIX_HTML_CLEANER.clean(MATRIX_MARKDOWN(text)).strip()
|
||||
except Exception:
|
||||
return None
|
||||
if not formatted:
|
||||
return None
|
||||
# Skip formatted_body for plain <p>text</p> to keep payload minimal.
|
||||
if formatted.startswith("<p>") and formatted.endswith("</p>"):
|
||||
inner = formatted[3:-4]
|
||||
if "<" not in inner and ">" not in inner:
|
||||
return None
|
||||
return formatted
|
||||
|
||||
|
||||
def _build_matrix_text_content(text: str) -> dict[str, object]:
|
||||
"""Build Matrix m.text payload with optional HTML formatted_body."""
|
||||
content: dict[str, object] = {"msgtype": "m.text", "body": text, "m.mentions": {}}
|
||||
if html := _render_markdown_html(text):
|
||||
content["format"] = MATRIX_HTML_FORMAT
|
||||
content["formatted_body"] = html
|
||||
return content
|
||||
|
||||
|
||||
class _NioLoguruHandler(logging.Handler):
|
||||
"""Route matrix-nio stdlib logs into Loguru."""
|
||||
|
||||
def emit(self, record: logging.LogRecord) -> None:
|
||||
try:
|
||||
level = logger.level(record.levelname).name
|
||||
except ValueError:
|
||||
level = record.levelno
|
||||
frame, depth = logging.currentframe(), 2
|
||||
while frame and frame.f_code.co_filename == logging.__file__:
|
||||
frame, depth = frame.f_back, depth + 1
|
||||
logger.opt(depth=depth, exception=record.exc_info).log(level, record.getMessage())
|
||||
|
||||
|
||||
def _configure_nio_logging_bridge() -> None:
|
||||
"""Bridge matrix-nio logs to Loguru (idempotent)."""
|
||||
nio_logger = logging.getLogger("nio")
|
||||
if not any(isinstance(h, _NioLoguruHandler) for h in nio_logger.handlers):
|
||||
nio_logger.handlers = [_NioLoguruHandler()]
|
||||
nio_logger.propagate = False
|
||||
|
||||
|
||||
class MatrixChannel(BaseChannel):
|
||||
"""Matrix (Element) channel using long-polling sync."""
|
||||
|
||||
name = "matrix"
|
||||
display_name = "Matrix"
|
||||
|
||||
def __init__(self, config: Any, bus: MessageBus, groq_api_key: str = ""):
|
||||
super().__init__(config, bus)
|
||||
self.groq_api_key = groq_api_key
|
||||
self.client: AsyncClient | None = None
|
||||
self._sync_task: asyncio.Task | None = None
|
||||
self._typing_tasks: dict[str, asyncio.Task] = {}
|
||||
self._restrict_to_workspace = False
|
||||
self._workspace: Path | None = None
|
||||
self._server_upload_limit_bytes: int | None = None
|
||||
self._server_upload_limit_checked = False
|
||||
self._sync_ready_logged = False
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start Matrix client and begin sync loop."""
|
||||
self._running = True
|
||||
_configure_nio_logging_bridge()
|
||||
|
||||
store_path = get_data_dir() / "matrix-store"
|
||||
store_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.client = AsyncClient(
|
||||
homeserver=self.config.homeserver, user=self.config.user_id,
|
||||
store_path=store_path,
|
||||
config=AsyncClientConfig(store_sync_tokens=True, encryption_enabled=self.config.e2ee_enabled),
|
||||
)
|
||||
self.client.user_id = self.config.user_id
|
||||
self.client.access_token = self.config.access_token
|
||||
self.client.device_id = self.config.device_id
|
||||
|
||||
self._register_event_callbacks()
|
||||
self._register_response_callbacks()
|
||||
|
||||
if not self.config.e2ee_enabled:
|
||||
logger.warning("Matrix E2EE disabled; encrypted rooms may be undecryptable.")
|
||||
|
||||
if self.config.device_id:
|
||||
try:
|
||||
self.client.load_store()
|
||||
except Exception:
|
||||
logger.exception("Matrix store load failed; restart may replay recent messages.")
|
||||
else:
|
||||
logger.warning("Matrix device_id empty; restart may replay recent messages.")
|
||||
|
||||
self._sync_task = asyncio.create_task(self._sync_loop())
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the Matrix channel with graceful sync shutdown."""
|
||||
self._running = False
|
||||
for room_id in list(self._typing_tasks):
|
||||
await self._stop_typing_keepalive(room_id, clear_typing=False)
|
||||
if self.client:
|
||||
self.client.stop_sync_forever()
|
||||
if self._sync_task:
|
||||
try:
|
||||
await asyncio.wait_for(asyncio.shield(self._sync_task),
|
||||
timeout=self.config.sync_stop_grace_seconds)
|
||||
except (asyncio.TimeoutError, asyncio.CancelledError):
|
||||
self._sync_task.cancel()
|
||||
try:
|
||||
await self._sync_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
if self.client:
|
||||
await self.client.close()
|
||||
|
||||
def _is_workspace_path_allowed(self, path: Path) -> bool:
|
||||
"""Check path is inside workspace (when restriction enabled)."""
|
||||
if not self._restrict_to_workspace or not self._workspace:
|
||||
return True
|
||||
try:
|
||||
path.resolve(strict=False).relative_to(self._workspace)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
def _collect_outbound_media_candidates(self, media: list[str]) -> list[Path]:
|
||||
"""Deduplicate and resolve outbound attachment paths."""
|
||||
seen: set[str] = set()
|
||||
candidates: list[Path] = []
|
||||
for raw in media:
|
||||
if not isinstance(raw, str) or not raw.strip():
|
||||
continue
|
||||
path = Path(raw.strip()).expanduser()
|
||||
try:
|
||||
key = str(path.resolve(strict=False))
|
||||
except OSError:
|
||||
key = str(path)
|
||||
if key not in seen:
|
||||
seen.add(key)
|
||||
candidates.append(path)
|
||||
return candidates
|
||||
|
||||
@staticmethod
|
||||
def _build_outbound_attachment_content(
|
||||
*, filename: str, mime: str, size_bytes: int,
|
||||
mxc_url: str, encryption_info: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Build Matrix content payload for an uploaded file/image/audio/video."""
|
||||
prefix = mime.split("/")[0]
|
||||
msgtype = {"image": "m.image", "audio": "m.audio", "video": "m.video"}.get(prefix, "m.file")
|
||||
content: dict[str, Any] = {
|
||||
"msgtype": msgtype, "body": filename, "filename": filename,
|
||||
"info": {"mimetype": mime, "size": size_bytes}, "m.mentions": {},
|
||||
}
|
||||
if encryption_info:
|
||||
content["file"] = {**encryption_info, "url": mxc_url}
|
||||
else:
|
||||
content["url"] = mxc_url
|
||||
return content
|
||||
|
||||
def _is_encrypted_room(self, room_id: str) -> bool:
|
||||
if not self.client:
|
||||
return False
|
||||
room = getattr(self.client, "rooms", {}).get(room_id)
|
||||
return bool(getattr(room, "encrypted", False))
|
||||
|
||||
async def _send_room_content(self, room_id: str, content: dict[str, Any]) -> None:
|
||||
"""Send m.room.message with E2EE options."""
|
||||
if not self.client:
|
||||
return
|
||||
kwargs: dict[str, Any] = {"room_id": room_id, "message_type": "m.room.message", "content": content}
|
||||
if self.config.e2ee_enabled:
|
||||
kwargs["ignore_unverified_devices"] = True
|
||||
await self.client.room_send(**kwargs)
|
||||
|
||||
async def _resolve_server_upload_limit_bytes(self) -> int | None:
|
||||
"""Query homeserver upload limit once per channel lifecycle."""
|
||||
if self._server_upload_limit_checked:
|
||||
return self._server_upload_limit_bytes
|
||||
self._server_upload_limit_checked = True
|
||||
if not self.client:
|
||||
return None
|
||||
try:
|
||||
response = await self.client.content_repository_config()
|
||||
except Exception:
|
||||
return None
|
||||
upload_size = getattr(response, "upload_size", None)
|
||||
if isinstance(upload_size, int) and upload_size > 0:
|
||||
self._server_upload_limit_bytes = upload_size
|
||||
return upload_size
|
||||
return None
|
||||
|
||||
async def _effective_media_limit_bytes(self) -> int:
|
||||
"""min(local config, server advertised) — 0 blocks all uploads."""
|
||||
local_limit = max(int(self.config.max_media_bytes), 0)
|
||||
server_limit = await self._resolve_server_upload_limit_bytes()
|
||||
if server_limit is None:
|
||||
return local_limit
|
||||
return min(local_limit, server_limit) if local_limit else 0
|
||||
|
||||
async def _upload_and_send_attachment(
|
||||
self, room_id: str, path: Path, limit_bytes: int,
|
||||
relates_to: dict[str, Any] | None = None,
|
||||
) -> str | None:
|
||||
"""Upload one local file to Matrix and send it as a media message. Returns failure marker or None."""
|
||||
if not self.client:
|
||||
return _ATTACH_UPLOAD_FAILED.format(path.name or _DEFAULT_ATTACH_NAME)
|
||||
|
||||
resolved = path.expanduser().resolve(strict=False)
|
||||
filename = safe_filename(resolved.name) or _DEFAULT_ATTACH_NAME
|
||||
fail = _ATTACH_UPLOAD_FAILED.format(filename)
|
||||
|
||||
if not resolved.is_file() or not self._is_workspace_path_allowed(resolved):
|
||||
return fail
|
||||
try:
|
||||
size_bytes = resolved.stat().st_size
|
||||
except OSError:
|
||||
return fail
|
||||
if limit_bytes <= 0 or size_bytes > limit_bytes:
|
||||
return _ATTACH_TOO_LARGE.format(filename)
|
||||
|
||||
mime = mimetypes.guess_type(filename, strict=False)[0] or "application/octet-stream"
|
||||
try:
|
||||
with resolved.open("rb") as f:
|
||||
upload_result = await self.client.upload(
|
||||
f, content_type=mime, filename=filename,
|
||||
encrypt=self.config.e2ee_enabled and self._is_encrypted_room(room_id),
|
||||
filesize=size_bytes,
|
||||
)
|
||||
except Exception:
|
||||
return fail
|
||||
|
||||
upload_response = upload_result[0] if isinstance(upload_result, tuple) else upload_result
|
||||
encryption_info = upload_result[1] if isinstance(upload_result, tuple) and isinstance(upload_result[1], dict) else None
|
||||
if isinstance(upload_response, UploadError):
|
||||
return fail
|
||||
mxc_url = getattr(upload_response, "content_uri", None)
|
||||
if not isinstance(mxc_url, str) or not mxc_url.startswith("mxc://"):
|
||||
return fail
|
||||
|
||||
content = self._build_outbound_attachment_content(
|
||||
filename=filename, mime=mime, size_bytes=size_bytes,
|
||||
mxc_url=mxc_url, encryption_info=encryption_info,
|
||||
)
|
||||
if relates_to:
|
||||
content["m.relates_to"] = relates_to
|
||||
try:
|
||||
await self._send_room_content(room_id, content)
|
||||
except Exception:
|
||||
return fail
|
||||
return None
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
"""Send outbound content; clear typing for non-progress messages."""
|
||||
if not self.client:
|
||||
return
|
||||
text = msg.content or ""
|
||||
candidates = self._collect_outbound_media_candidates(msg.media)
|
||||
relates_to = self._build_thread_relates_to(msg.metadata)
|
||||
is_progress = bool((msg.metadata or {}).get("_progress"))
|
||||
try:
|
||||
failures: list[str] = []
|
||||
if candidates:
|
||||
limit_bytes = await self._effective_media_limit_bytes()
|
||||
for path in candidates:
|
||||
if fail := await self._upload_and_send_attachment(
|
||||
room_id=msg.chat_id,
|
||||
path=path,
|
||||
limit_bytes=limit_bytes,
|
||||
relates_to=relates_to,
|
||||
):
|
||||
failures.append(fail)
|
||||
if failures:
|
||||
text = f"{text.rstrip()}\n{chr(10).join(failures)}" if text.strip() else "\n".join(failures)
|
||||
if text or not candidates:
|
||||
content = _build_matrix_text_content(text)
|
||||
if relates_to:
|
||||
content["m.relates_to"] = relates_to
|
||||
await self._send_room_content(msg.chat_id, content)
|
||||
finally:
|
||||
if not is_progress:
|
||||
await self._stop_typing_keepalive(msg.chat_id, clear_typing=True)
|
||||
|
||||
def _register_event_callbacks(self) -> None:
|
||||
self.client.add_event_callback(self._on_message, RoomMessageText)
|
||||
self.client.add_event_callback(self._on_media_message, MATRIX_MEDIA_EVENT_FILTER)
|
||||
self.client.add_event_callback(self._on_room_invite, InviteEvent)
|
||||
|
||||
def _register_response_callbacks(self) -> None:
|
||||
self.client.add_response_callback(self._on_sync_success, SyncResponse)
|
||||
self.client.add_response_callback(self._on_sync_error, SyncError)
|
||||
self.client.add_response_callback(self._on_join_error, JoinError)
|
||||
self.client.add_response_callback(self._on_send_error, RoomSendError)
|
||||
|
||||
def _log_response_error(self, label: str, response: Any) -> None:
|
||||
"""Log Matrix response errors — auth errors at ERROR level, rest at WARNING."""
|
||||
code = getattr(response, "status_code", None)
|
||||
is_auth = code in {"M_UNKNOWN_TOKEN", "M_FORBIDDEN", "M_UNAUTHORIZED"}
|
||||
is_fatal = is_auth or getattr(response, "soft_logout", False)
|
||||
(logger.error if is_fatal else logger.warning)("Matrix {} failed: {}", label, response)
|
||||
|
||||
async def _on_sync_success(self, response: SyncResponse) -> None:
|
||||
if self._sync_ready_logged:
|
||||
return
|
||||
rooms = getattr(response, "rooms", None)
|
||||
joined = len(getattr(rooms, "join", {}) or {})
|
||||
invited = len(getattr(rooms, "invite", {}) or {})
|
||||
logger.info(
|
||||
"Matrix sync ready: user={} device={} joined_rooms={} invited_rooms={}",
|
||||
self.config.user_id,
|
||||
self.config.device_id or "-",
|
||||
joined,
|
||||
invited,
|
||||
)
|
||||
self._sync_ready_logged = True
|
||||
|
||||
async def _on_sync_error(self, response: SyncError) -> None:
|
||||
self._log_response_error("sync", response)
|
||||
|
||||
async def _on_join_error(self, response: JoinError) -> None:
|
||||
self._log_response_error("join", response)
|
||||
|
||||
async def _on_send_error(self, response: RoomSendError) -> None:
|
||||
self._log_response_error("send", response)
|
||||
|
||||
async def _set_typing(self, room_id: str, typing: bool) -> None:
|
||||
"""Best-effort typing indicator update."""
|
||||
if not self.client:
|
||||
return
|
||||
try:
|
||||
response = await self.client.room_typing(room_id=room_id, typing_state=typing,
|
||||
timeout=TYPING_NOTICE_TIMEOUT_MS)
|
||||
if isinstance(response, RoomTypingError):
|
||||
logger.debug("Matrix typing failed for {}: {}", room_id, response)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def _start_typing_keepalive(self, room_id: str) -> None:
|
||||
"""Start periodic typing refresh (spec-recommended keepalive)."""
|
||||
await self._stop_typing_keepalive(room_id, clear_typing=False)
|
||||
await self._set_typing(room_id, True)
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
async def loop() -> None:
|
||||
try:
|
||||
while self._running:
|
||||
await asyncio.sleep(TYPING_KEEPALIVE_INTERVAL_MS / 1000)
|
||||
await self._set_typing(room_id, True)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
self._typing_tasks[room_id] = asyncio.create_task(loop())
|
||||
|
||||
async def _stop_typing_keepalive(self, room_id: str, *, clear_typing: bool) -> None:
|
||||
if task := self._typing_tasks.pop(room_id, None):
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
if clear_typing:
|
||||
await self._set_typing(room_id, False)
|
||||
|
||||
async def _sync_loop(self) -> None:
|
||||
while self._running:
|
||||
try:
|
||||
await self.client.sync_forever(timeout=30000, full_state=True)
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception:
|
||||
await asyncio.sleep(2)
|
||||
|
||||
async def _on_room_invite(self, room: MatrixRoom, event: InviteEvent) -> None:
|
||||
if self.is_allowed(event.sender):
|
||||
await self.client.join(room.room_id)
|
||||
|
||||
def _is_direct_room(self, room: MatrixRoom) -> bool:
|
||||
count = getattr(room, "member_count", None)
|
||||
return isinstance(count, int) and count <= 2
|
||||
|
||||
def _is_bot_mentioned(self, event: RoomMessage) -> bool:
|
||||
"""Check m.mentions payload for bot mention."""
|
||||
source = getattr(event, "source", None)
|
||||
if not isinstance(source, dict):
|
||||
return False
|
||||
mentions = (source.get("content") or {}).get("m.mentions")
|
||||
if not isinstance(mentions, dict):
|
||||
return False
|
||||
user_ids = mentions.get("user_ids")
|
||||
if isinstance(user_ids, list) and self.config.user_id in user_ids:
|
||||
return True
|
||||
return bool(self.config.allow_room_mentions and mentions.get("room") is True)
|
||||
|
||||
def _should_process_message(self, room: MatrixRoom, event: RoomMessage) -> bool:
|
||||
"""Apply sender and room policy checks."""
|
||||
if not self.is_allowed(event.sender):
|
||||
return False
|
||||
if self._is_direct_room(room):
|
||||
return True
|
||||
policy = self.config.group_policy
|
||||
if policy == "open":
|
||||
return True
|
||||
if policy == "allowlist":
|
||||
return room.room_id in (self.config.group_allow_from or [])
|
||||
if policy == "mention":
|
||||
return self._is_bot_mentioned(event)
|
||||
return False
|
||||
|
||||
def _media_dir(self) -> Path:
|
||||
return get_media_dir("matrix")
|
||||
|
||||
async def transcribe_audio(self, file_path: str) -> str:
|
||||
"""Best-effort audio transcription for inbound Matrix voice/audio messages."""
|
||||
try:
|
||||
return await GroqTranscriptionProvider(api_key=self.groq_api_key).transcribe(file_path)
|
||||
except Exception:
|
||||
logger.exception("Matrix audio transcription failed")
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def _event_source_content(event: RoomMessage) -> dict[str, Any]:
|
||||
source = getattr(event, "source", None)
|
||||
if not isinstance(source, dict):
|
||||
return {}
|
||||
content = source.get("content")
|
||||
return content if isinstance(content, dict) else {}
|
||||
|
||||
def _event_thread_root_id(self, event: RoomMessage) -> str | None:
|
||||
relates_to = self._event_source_content(event).get("m.relates_to")
|
||||
if not isinstance(relates_to, dict) or relates_to.get("rel_type") != "m.thread":
|
||||
return None
|
||||
root_id = relates_to.get("event_id")
|
||||
return root_id if isinstance(root_id, str) and root_id else None
|
||||
|
||||
def _thread_metadata(self, event: RoomMessage) -> dict[str, str] | None:
|
||||
if not (root_id := self._event_thread_root_id(event)):
|
||||
return None
|
||||
meta: dict[str, str] = {"thread_root_event_id": root_id}
|
||||
if isinstance(reply_to := getattr(event, "event_id", None), str) and reply_to:
|
||||
meta["thread_reply_to_event_id"] = reply_to
|
||||
return meta
|
||||
|
||||
@staticmethod
|
||||
def _build_thread_relates_to(metadata: dict[str, Any] | None) -> dict[str, Any] | None:
|
||||
if not metadata:
|
||||
return None
|
||||
root_id = metadata.get("thread_root_event_id")
|
||||
if not isinstance(root_id, str) or not root_id:
|
||||
return None
|
||||
reply_to = metadata.get("thread_reply_to_event_id") or metadata.get("event_id")
|
||||
if not isinstance(reply_to, str) or not reply_to:
|
||||
return None
|
||||
return {"rel_type": "m.thread", "event_id": root_id,
|
||||
"m.in_reply_to": {"event_id": reply_to}, "is_falling_back": True}
|
||||
|
||||
def _event_attachment_type(self, event: MatrixMediaEvent) -> str:
|
||||
msgtype = self._event_source_content(event).get("msgtype")
|
||||
return _MSGTYPE_MAP.get(msgtype, "file")
|
||||
|
||||
@staticmethod
|
||||
def _is_encrypted_media_event(event: MatrixMediaEvent) -> bool:
|
||||
return (isinstance(getattr(event, "key", None), dict)
|
||||
and isinstance(getattr(event, "hashes", None), dict)
|
||||
and isinstance(getattr(event, "iv", None), str))
|
||||
|
||||
def _event_declared_size_bytes(self, event: MatrixMediaEvent) -> int | None:
|
||||
info = self._event_source_content(event).get("info")
|
||||
size = info.get("size") if isinstance(info, dict) else None
|
||||
return size if isinstance(size, int) and size >= 0 else None
|
||||
|
||||
def _event_mime(self, event: MatrixMediaEvent) -> str | None:
|
||||
info = self._event_source_content(event).get("info")
|
||||
if isinstance(info, dict) and isinstance(m := info.get("mimetype"), str) and m:
|
||||
return m
|
||||
m = getattr(event, "mimetype", None)
|
||||
return m if isinstance(m, str) and m else None
|
||||
|
||||
def _event_filename(self, event: MatrixMediaEvent, attachment_type: str) -> str:
|
||||
body = getattr(event, "body", None)
|
||||
if isinstance(body, str) and body.strip():
|
||||
if candidate := safe_filename(Path(body).name):
|
||||
return candidate
|
||||
return _DEFAULT_ATTACH_NAME if attachment_type == "file" else attachment_type
|
||||
|
||||
def _build_attachment_path(self, event: MatrixMediaEvent, attachment_type: str,
|
||||
filename: str, mime: str | None) -> Path:
|
||||
safe_name = safe_filename(Path(filename).name) or _DEFAULT_ATTACH_NAME
|
||||
suffix = Path(safe_name).suffix
|
||||
if not suffix and mime:
|
||||
if guessed := mimetypes.guess_extension(mime, strict=False):
|
||||
safe_name, suffix = f"{safe_name}{guessed}", guessed
|
||||
stem = (Path(safe_name).stem or attachment_type)[:72]
|
||||
suffix = suffix[:16]
|
||||
event_id = safe_filename(str(getattr(event, "event_id", "") or "evt").lstrip("$"))
|
||||
event_prefix = (event_id[:24] or "evt").strip("_")
|
||||
return self._media_dir() / f"{event_prefix}_{stem}{suffix}"
|
||||
|
||||
async def _download_media_bytes(self, mxc_url: str) -> bytes | None:
|
||||
if not self.client:
|
||||
return None
|
||||
response = await self.client.download(mxc=mxc_url)
|
||||
if isinstance(response, DownloadError):
|
||||
logger.warning("Matrix download failed for {}: {}", mxc_url, response)
|
||||
return None
|
||||
body = getattr(response, "body", None)
|
||||
if isinstance(body, (bytes, bytearray)):
|
||||
return bytes(body)
|
||||
if isinstance(response, MemoryDownloadResponse):
|
||||
return bytes(response.body)
|
||||
if isinstance(body, (str, Path)):
|
||||
path = Path(body)
|
||||
if path.is_file():
|
||||
try:
|
||||
return path.read_bytes()
|
||||
except OSError:
|
||||
return None
|
||||
return None
|
||||
|
||||
def _decrypt_media_bytes(self, event: MatrixMediaEvent, ciphertext: bytes) -> bytes | None:
|
||||
key_obj, hashes, iv = getattr(event, "key", None), getattr(event, "hashes", None), getattr(event, "iv", None)
|
||||
key = key_obj.get("k") if isinstance(key_obj, dict) else None
|
||||
sha256 = hashes.get("sha256") if isinstance(hashes, dict) else None
|
||||
if not all(isinstance(v, str) for v in (key, sha256, iv)):
|
||||
return None
|
||||
try:
|
||||
return decrypt_attachment(ciphertext, key, sha256, iv)
|
||||
except (EncryptionError, ValueError, TypeError):
|
||||
logger.warning("Matrix decrypt failed for event {}", getattr(event, "event_id", ""))
|
||||
return None
|
||||
|
||||
async def _fetch_media_attachment(
|
||||
self, room: MatrixRoom, event: MatrixMediaEvent,
|
||||
) -> tuple[dict[str, Any] | None, str]:
|
||||
"""Download, decrypt if needed, and persist a Matrix attachment."""
|
||||
atype = self._event_attachment_type(event)
|
||||
mime = self._event_mime(event)
|
||||
filename = self._event_filename(event, atype)
|
||||
mxc_url = getattr(event, "url", None)
|
||||
fail = _ATTACH_FAILED.format(filename)
|
||||
|
||||
if not isinstance(mxc_url, str) or not mxc_url.startswith("mxc://"):
|
||||
return None, fail
|
||||
|
||||
limit_bytes = await self._effective_media_limit_bytes()
|
||||
declared = self._event_declared_size_bytes(event)
|
||||
if declared is not None and declared > limit_bytes:
|
||||
return None, _ATTACH_TOO_LARGE.format(filename)
|
||||
|
||||
downloaded = await self._download_media_bytes(mxc_url)
|
||||
if downloaded is None:
|
||||
return None, fail
|
||||
|
||||
encrypted = self._is_encrypted_media_event(event)
|
||||
data = downloaded
|
||||
if encrypted:
|
||||
if (data := self._decrypt_media_bytes(event, downloaded)) is None:
|
||||
return None, fail
|
||||
|
||||
if len(data) > limit_bytes:
|
||||
return None, _ATTACH_TOO_LARGE.format(filename)
|
||||
|
||||
path = self._build_attachment_path(event, atype, filename, mime)
|
||||
try:
|
||||
path.write_bytes(data)
|
||||
except OSError:
|
||||
return None, fail
|
||||
|
||||
attachment = {
|
||||
"type": atype, "mime": mime, "filename": filename,
|
||||
"event_id": str(getattr(event, "event_id", "") or ""),
|
||||
"encrypted": encrypted, "size_bytes": len(data),
|
||||
"path": str(path), "mxc_url": mxc_url,
|
||||
}
|
||||
return attachment, _ATTACH_MARKER.format(path)
|
||||
|
||||
def _base_metadata(self, room: MatrixRoom, event: RoomMessage) -> dict[str, Any]:
|
||||
"""Build common metadata for text and media handlers."""
|
||||
meta: dict[str, Any] = {"room": getattr(room, "display_name", room.room_id)}
|
||||
if isinstance(eid := getattr(event, "event_id", None), str) and eid:
|
||||
meta["event_id"] = eid
|
||||
if thread := self._thread_metadata(event):
|
||||
meta.update(thread)
|
||||
return meta
|
||||
|
||||
async def _on_message(self, room: MatrixRoom, event: RoomMessageText) -> None:
|
||||
if event.sender == self.config.user_id or not self._should_process_message(room, event):
|
||||
return
|
||||
await self._start_typing_keepalive(room.room_id)
|
||||
try:
|
||||
await self._handle_message(
|
||||
sender_id=event.sender, chat_id=room.room_id,
|
||||
content=event.body, metadata=self._base_metadata(room, event),
|
||||
)
|
||||
except Exception:
|
||||
await self._stop_typing_keepalive(room.room_id, clear_typing=True)
|
||||
raise
|
||||
|
||||
async def _on_media_message(self, room: MatrixRoom, event: MatrixMediaEvent) -> None:
|
||||
if event.sender == self.config.user_id or not self._should_process_message(room, event):
|
||||
return
|
||||
attachment, marker = await self._fetch_media_attachment(room, event)
|
||||
parts: list[str] = []
|
||||
if isinstance(body := getattr(event, "body", None), str) and body.strip():
|
||||
parts.append(body.strip())
|
||||
|
||||
if attachment and attachment.get("type") == "audio":
|
||||
transcription = await self.transcribe_audio(attachment["path"])
|
||||
if transcription:
|
||||
parts.append(f"[transcription: {transcription}]")
|
||||
else:
|
||||
parts.append(marker)
|
||||
elif marker:
|
||||
parts.append(marker)
|
||||
|
||||
await self._start_typing_keepalive(room.room_id)
|
||||
try:
|
||||
meta = self._base_metadata(room, event)
|
||||
meta["attachments"] = []
|
||||
if attachment:
|
||||
meta["attachments"] = [attachment]
|
||||
await self._handle_message(
|
||||
sender_id=event.sender, chat_id=room.room_id,
|
||||
content="\n".join(parts),
|
||||
media=[attachment["path"]] if attachment else [],
|
||||
metadata=meta,
|
||||
)
|
||||
except Exception:
|
||||
await self._stop_typing_keepalive(room.room_id, clear_typing=True)
|
||||
raise
|
||||
895
app-instance/backend/nanobot/channels/mochat.py
Normal file
895
app-instance/backend/nanobot/channels/mochat.py
Normal file
@ -0,0 +1,895 @@
|
||||
"""Mochat channel implementation using Socket.IO with HTTP polling fallback."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from collections import deque
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.config.schema import MochatConfig
|
||||
from nanobot.utils.helpers import get_data_path
|
||||
|
||||
try:
|
||||
import socketio
|
||||
SOCKETIO_AVAILABLE = True
|
||||
except ImportError:
|
||||
socketio = None
|
||||
SOCKETIO_AVAILABLE = False
|
||||
|
||||
try:
|
||||
import msgpack # noqa: F401
|
||||
MSGPACK_AVAILABLE = True
|
||||
except ImportError:
|
||||
MSGPACK_AVAILABLE = False
|
||||
|
||||
MAX_SEEN_MESSAGE_IDS = 2000
|
||||
CURSOR_SAVE_DEBOUNCE_S = 0.5
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Data classes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass
|
||||
class MochatBufferedEntry:
|
||||
"""Buffered inbound entry for delayed dispatch."""
|
||||
raw_body: str
|
||||
author: str
|
||||
sender_name: str = ""
|
||||
sender_username: str = ""
|
||||
timestamp: int | None = None
|
||||
message_id: str = ""
|
||||
group_id: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class DelayState:
|
||||
"""Per-target delayed message state."""
|
||||
entries: list[MochatBufferedEntry] = field(default_factory=list)
|
||||
lock: asyncio.Lock = field(default_factory=asyncio.Lock)
|
||||
timer: asyncio.Task | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class MochatTarget:
|
||||
"""Outbound target resolution result."""
|
||||
id: str
|
||||
is_panel: bool
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pure helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _safe_dict(value: Any) -> dict:
|
||||
"""Return *value* if it's a dict, else empty dict."""
|
||||
return value if isinstance(value, dict) else {}
|
||||
|
||||
|
||||
def _str_field(src: dict, *keys: str) -> str:
|
||||
"""Return the first non-empty str value found for *keys*, stripped."""
|
||||
for k in keys:
|
||||
v = src.get(k)
|
||||
if isinstance(v, str) and v.strip():
|
||||
return v.strip()
|
||||
return ""
|
||||
|
||||
|
||||
def _make_synthetic_event(
|
||||
message_id: str, author: str, content: Any,
|
||||
meta: Any, group_id: str, converse_id: str,
|
||||
timestamp: Any = None, *, author_info: Any = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Build a synthetic ``message.add`` event dict."""
|
||||
payload: dict[str, Any] = {
|
||||
"messageId": message_id, "author": author,
|
||||
"content": content, "meta": _safe_dict(meta),
|
||||
"groupId": group_id, "converseId": converse_id,
|
||||
}
|
||||
if author_info is not None:
|
||||
payload["authorInfo"] = _safe_dict(author_info)
|
||||
return {
|
||||
"type": "message.add",
|
||||
"timestamp": timestamp or datetime.utcnow().isoformat(),
|
||||
"payload": payload,
|
||||
}
|
||||
|
||||
|
||||
def normalize_mochat_content(content: Any) -> str:
|
||||
"""Normalize content payload to text."""
|
||||
if isinstance(content, str):
|
||||
return content.strip()
|
||||
if content is None:
|
||||
return ""
|
||||
try:
|
||||
return json.dumps(content, ensure_ascii=False)
|
||||
except TypeError:
|
||||
return str(content)
|
||||
|
||||
|
||||
def resolve_mochat_target(raw: str) -> MochatTarget:
|
||||
"""Resolve id and target kind from user-provided target string."""
|
||||
trimmed = (raw or "").strip()
|
||||
if not trimmed:
|
||||
return MochatTarget(id="", is_panel=False)
|
||||
|
||||
lowered = trimmed.lower()
|
||||
cleaned, forced_panel = trimmed, False
|
||||
for prefix in ("mochat:", "group:", "channel:", "panel:"):
|
||||
if lowered.startswith(prefix):
|
||||
cleaned = trimmed[len(prefix):].strip()
|
||||
forced_panel = prefix in {"group:", "channel:", "panel:"}
|
||||
break
|
||||
|
||||
if not cleaned:
|
||||
return MochatTarget(id="", is_panel=False)
|
||||
return MochatTarget(id=cleaned, is_panel=forced_panel or not cleaned.startswith("session_"))
|
||||
|
||||
|
||||
def extract_mention_ids(value: Any) -> list[str]:
|
||||
"""Extract mention ids from heterogeneous mention payload."""
|
||||
if not isinstance(value, list):
|
||||
return []
|
||||
ids: list[str] = []
|
||||
for item in value:
|
||||
if isinstance(item, str):
|
||||
if item.strip():
|
||||
ids.append(item.strip())
|
||||
elif isinstance(item, dict):
|
||||
for key in ("id", "userId", "_id"):
|
||||
candidate = item.get(key)
|
||||
if isinstance(candidate, str) and candidate.strip():
|
||||
ids.append(candidate.strip())
|
||||
break
|
||||
return ids
|
||||
|
||||
|
||||
def resolve_was_mentioned(payload: dict[str, Any], agent_user_id: str) -> bool:
|
||||
"""Resolve mention state from payload metadata and text fallback."""
|
||||
meta = payload.get("meta")
|
||||
if isinstance(meta, dict):
|
||||
if meta.get("mentioned") is True or meta.get("wasMentioned") is True:
|
||||
return True
|
||||
for f in ("mentions", "mentionIds", "mentionedUserIds", "mentionedUsers"):
|
||||
if agent_user_id and agent_user_id in extract_mention_ids(meta.get(f)):
|
||||
return True
|
||||
if not agent_user_id:
|
||||
return False
|
||||
content = payload.get("content")
|
||||
if not isinstance(content, str) or not content:
|
||||
return False
|
||||
return f"<@{agent_user_id}>" in content or f"@{agent_user_id}" in content
|
||||
|
||||
|
||||
def resolve_require_mention(config: MochatConfig, session_id: str, group_id: str) -> bool:
|
||||
"""Resolve mention requirement for group/panel conversations."""
|
||||
groups = config.groups or {}
|
||||
for key in (group_id, session_id, "*"):
|
||||
if key and key in groups:
|
||||
return bool(groups[key].require_mention)
|
||||
return bool(config.mention.require_in_groups)
|
||||
|
||||
|
||||
def build_buffered_body(entries: list[MochatBufferedEntry], is_group: bool) -> str:
|
||||
"""Build text body from one or more buffered entries."""
|
||||
if not entries:
|
||||
return ""
|
||||
if len(entries) == 1:
|
||||
return entries[0].raw_body
|
||||
lines: list[str] = []
|
||||
for entry in entries:
|
||||
if not entry.raw_body:
|
||||
continue
|
||||
if is_group:
|
||||
label = entry.sender_name.strip() or entry.sender_username.strip() or entry.author
|
||||
if label:
|
||||
lines.append(f"{label}: {entry.raw_body}")
|
||||
continue
|
||||
lines.append(entry.raw_body)
|
||||
return "\n".join(lines).strip()
|
||||
|
||||
|
||||
def parse_timestamp(value: Any) -> int | None:
|
||||
"""Parse event timestamp to epoch milliseconds."""
|
||||
if not isinstance(value, str) or not value.strip():
|
||||
return None
|
||||
try:
|
||||
return int(datetime.fromisoformat(value.replace("Z", "+00:00")).timestamp() * 1000)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Channel
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class MochatChannel(BaseChannel):
|
||||
"""Mochat channel using socket.io with fallback polling workers."""
|
||||
|
||||
name = "mochat"
|
||||
|
||||
def __init__(self, config: MochatConfig, bus: MessageBus):
|
||||
super().__init__(config, bus)
|
||||
self.config: MochatConfig = config
|
||||
self._http: httpx.AsyncClient | None = None
|
||||
self._socket: Any = None
|
||||
self._ws_connected = self._ws_ready = False
|
||||
|
||||
self._state_dir = get_data_path() / "mochat"
|
||||
self._cursor_path = self._state_dir / "session_cursors.json"
|
||||
self._session_cursor: dict[str, int] = {}
|
||||
self._cursor_save_task: asyncio.Task | None = None
|
||||
|
||||
self._session_set: set[str] = set()
|
||||
self._panel_set: set[str] = set()
|
||||
self._auto_discover_sessions = self._auto_discover_panels = False
|
||||
|
||||
self._cold_sessions: set[str] = set()
|
||||
self._session_by_converse: dict[str, str] = {}
|
||||
|
||||
self._seen_set: dict[str, set[str]] = {}
|
||||
self._seen_queue: dict[str, deque[str]] = {}
|
||||
self._delay_states: dict[str, DelayState] = {}
|
||||
|
||||
self._fallback_mode = False
|
||||
self._session_fallback_tasks: dict[str, asyncio.Task] = {}
|
||||
self._panel_fallback_tasks: dict[str, asyncio.Task] = {}
|
||||
self._refresh_task: asyncio.Task | None = None
|
||||
self._target_locks: dict[str, asyncio.Lock] = {}
|
||||
|
||||
# ---- lifecycle ---------------------------------------------------------
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start Mochat channel workers and websocket connection."""
|
||||
if not self.config.claw_token:
|
||||
logger.error("Mochat claw_token not configured")
|
||||
return
|
||||
|
||||
self._running = True
|
||||
self._http = httpx.AsyncClient(timeout=30.0)
|
||||
self._state_dir.mkdir(parents=True, exist_ok=True)
|
||||
await self._load_session_cursors()
|
||||
self._seed_targets_from_config()
|
||||
await self._refresh_targets(subscribe_new=False)
|
||||
|
||||
if not await self._start_socket_client():
|
||||
await self._ensure_fallback_workers()
|
||||
|
||||
self._refresh_task = asyncio.create_task(self._refresh_loop())
|
||||
while self._running:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop all workers and clean up resources."""
|
||||
self._running = False
|
||||
if self._refresh_task:
|
||||
self._refresh_task.cancel()
|
||||
self._refresh_task = None
|
||||
|
||||
await self._stop_fallback_workers()
|
||||
await self._cancel_delay_timers()
|
||||
|
||||
if self._socket:
|
||||
try:
|
||||
await self._socket.disconnect()
|
||||
except Exception:
|
||||
pass
|
||||
self._socket = None
|
||||
|
||||
if self._cursor_save_task:
|
||||
self._cursor_save_task.cancel()
|
||||
self._cursor_save_task = None
|
||||
await self._save_session_cursors()
|
||||
|
||||
if self._http:
|
||||
await self._http.aclose()
|
||||
self._http = None
|
||||
self._ws_connected = self._ws_ready = False
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
"""Send outbound message to session or panel."""
|
||||
if not self.config.claw_token:
|
||||
logger.warning("Mochat claw_token missing, skip send")
|
||||
return
|
||||
|
||||
parts = ([msg.content.strip()] if msg.content and msg.content.strip() else [])
|
||||
if msg.media:
|
||||
parts.extend(m for m in msg.media if isinstance(m, str) and m.strip())
|
||||
content = "\n".join(parts).strip()
|
||||
if not content:
|
||||
return
|
||||
|
||||
target = resolve_mochat_target(msg.chat_id)
|
||||
if not target.id:
|
||||
logger.warning("Mochat outbound target is empty")
|
||||
return
|
||||
|
||||
is_panel = (target.is_panel or target.id in self._panel_set) and not target.id.startswith("session_")
|
||||
try:
|
||||
if is_panel:
|
||||
await self._api_send("/api/claw/groups/panels/send", "panelId", target.id,
|
||||
content, msg.reply_to, self._read_group_id(msg.metadata))
|
||||
else:
|
||||
await self._api_send("/api/claw/sessions/send", "sessionId", target.id,
|
||||
content, msg.reply_to)
|
||||
except Exception as e:
|
||||
logger.error("Failed to send Mochat message: {}", e)
|
||||
|
||||
# ---- config / init helpers ---------------------------------------------
|
||||
|
||||
def _seed_targets_from_config(self) -> None:
|
||||
sessions, self._auto_discover_sessions = self._normalize_id_list(self.config.sessions)
|
||||
panels, self._auto_discover_panels = self._normalize_id_list(self.config.panels)
|
||||
self._session_set.update(sessions)
|
||||
self._panel_set.update(panels)
|
||||
for sid in sessions:
|
||||
if sid not in self._session_cursor:
|
||||
self._cold_sessions.add(sid)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_id_list(values: list[str]) -> tuple[list[str], bool]:
|
||||
cleaned = [str(v).strip() for v in values if str(v).strip()]
|
||||
return sorted({v for v in cleaned if v != "*"}), "*" in cleaned
|
||||
|
||||
# ---- websocket ---------------------------------------------------------
|
||||
|
||||
async def _start_socket_client(self) -> bool:
|
||||
if not SOCKETIO_AVAILABLE:
|
||||
logger.warning("python-socketio not installed, Mochat using polling fallback")
|
||||
return False
|
||||
|
||||
serializer = "default"
|
||||
if not self.config.socket_disable_msgpack:
|
||||
if MSGPACK_AVAILABLE:
|
||||
serializer = "msgpack"
|
||||
else:
|
||||
logger.warning("msgpack not installed but socket_disable_msgpack=false; using JSON")
|
||||
|
||||
client = socketio.AsyncClient(
|
||||
reconnection=True,
|
||||
reconnection_attempts=self.config.max_retry_attempts or None,
|
||||
reconnection_delay=max(0.1, self.config.socket_reconnect_delay_ms / 1000.0),
|
||||
reconnection_delay_max=max(0.1, self.config.socket_max_reconnect_delay_ms / 1000.0),
|
||||
logger=False, engineio_logger=False, serializer=serializer,
|
||||
)
|
||||
|
||||
@client.event
|
||||
async def connect() -> None:
|
||||
self._ws_connected, self._ws_ready = True, False
|
||||
logger.info("Mochat websocket connected")
|
||||
subscribed = await self._subscribe_all()
|
||||
self._ws_ready = subscribed
|
||||
await (self._stop_fallback_workers() if subscribed else self._ensure_fallback_workers())
|
||||
|
||||
@client.event
|
||||
async def disconnect() -> None:
|
||||
if not self._running:
|
||||
return
|
||||
self._ws_connected = self._ws_ready = False
|
||||
logger.warning("Mochat websocket disconnected")
|
||||
await self._ensure_fallback_workers()
|
||||
|
||||
@client.event
|
||||
async def connect_error(data: Any) -> None:
|
||||
logger.error("Mochat websocket connect error: {}", data)
|
||||
|
||||
@client.on("claw.session.events")
|
||||
async def on_session_events(payload: dict[str, Any]) -> None:
|
||||
await self._handle_watch_payload(payload, "session")
|
||||
|
||||
@client.on("claw.panel.events")
|
||||
async def on_panel_events(payload: dict[str, Any]) -> None:
|
||||
await self._handle_watch_payload(payload, "panel")
|
||||
|
||||
for ev in ("notify:chat.inbox.append", "notify:chat.message.add",
|
||||
"notify:chat.message.update", "notify:chat.message.recall",
|
||||
"notify:chat.message.delete"):
|
||||
client.on(ev, self._build_notify_handler(ev))
|
||||
|
||||
socket_url = (self.config.socket_url or self.config.base_url).strip().rstrip("/")
|
||||
socket_path = (self.config.socket_path or "/socket.io").strip().lstrip("/")
|
||||
|
||||
try:
|
||||
self._socket = client
|
||||
await client.connect(
|
||||
socket_url, transports=["websocket"], socketio_path=socket_path,
|
||||
auth={"token": self.config.claw_token},
|
||||
wait_timeout=max(1.0, self.config.socket_connect_timeout_ms / 1000.0),
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error("Failed to connect Mochat websocket: {}", e)
|
||||
try:
|
||||
await client.disconnect()
|
||||
except Exception:
|
||||
pass
|
||||
self._socket = None
|
||||
return False
|
||||
|
||||
def _build_notify_handler(self, event_name: str):
|
||||
async def handler(payload: Any) -> None:
|
||||
if event_name == "notify:chat.inbox.append":
|
||||
await self._handle_notify_inbox_append(payload)
|
||||
elif event_name.startswith("notify:chat.message."):
|
||||
await self._handle_notify_chat_message(payload)
|
||||
return handler
|
||||
|
||||
# ---- subscribe ---------------------------------------------------------
|
||||
|
||||
async def _subscribe_all(self) -> bool:
|
||||
ok = await self._subscribe_sessions(sorted(self._session_set))
|
||||
ok = await self._subscribe_panels(sorted(self._panel_set)) and ok
|
||||
if self._auto_discover_sessions or self._auto_discover_panels:
|
||||
await self._refresh_targets(subscribe_new=True)
|
||||
return ok
|
||||
|
||||
async def _subscribe_sessions(self, session_ids: list[str]) -> bool:
|
||||
if not session_ids:
|
||||
return True
|
||||
for sid in session_ids:
|
||||
if sid not in self._session_cursor:
|
||||
self._cold_sessions.add(sid)
|
||||
|
||||
ack = await self._socket_call("com.claw.im.subscribeSessions", {
|
||||
"sessionIds": session_ids, "cursors": self._session_cursor,
|
||||
"limit": self.config.watch_limit,
|
||||
})
|
||||
if not ack.get("result"):
|
||||
logger.error("Mochat subscribeSessions failed: {}", ack.get('message', 'unknown error'))
|
||||
return False
|
||||
|
||||
data = ack.get("data")
|
||||
items: list[dict[str, Any]] = []
|
||||
if isinstance(data, list):
|
||||
items = [i for i in data if isinstance(i, dict)]
|
||||
elif isinstance(data, dict):
|
||||
sessions = data.get("sessions")
|
||||
if isinstance(sessions, list):
|
||||
items = [i for i in sessions if isinstance(i, dict)]
|
||||
elif "sessionId" in data:
|
||||
items = [data]
|
||||
for p in items:
|
||||
await self._handle_watch_payload(p, "session")
|
||||
return True
|
||||
|
||||
async def _subscribe_panels(self, panel_ids: list[str]) -> bool:
|
||||
if not self._auto_discover_panels and not panel_ids:
|
||||
return True
|
||||
ack = await self._socket_call("com.claw.im.subscribePanels", {"panelIds": panel_ids})
|
||||
if not ack.get("result"):
|
||||
logger.error("Mochat subscribePanels failed: {}", ack.get('message', 'unknown error'))
|
||||
return False
|
||||
return True
|
||||
|
||||
async def _socket_call(self, event_name: str, payload: dict[str, Any]) -> dict[str, Any]:
|
||||
if not self._socket:
|
||||
return {"result": False, "message": "socket not connected"}
|
||||
try:
|
||||
raw = await self._socket.call(event_name, payload, timeout=10)
|
||||
except Exception as e:
|
||||
return {"result": False, "message": str(e)}
|
||||
return raw if isinstance(raw, dict) else {"result": True, "data": raw}
|
||||
|
||||
# ---- refresh / discovery -----------------------------------------------
|
||||
|
||||
async def _refresh_loop(self) -> None:
|
||||
interval_s = max(1.0, self.config.refresh_interval_ms / 1000.0)
|
||||
while self._running:
|
||||
await asyncio.sleep(interval_s)
|
||||
try:
|
||||
await self._refresh_targets(subscribe_new=self._ws_ready)
|
||||
except Exception as e:
|
||||
logger.warning("Mochat refresh failed: {}", e)
|
||||
if self._fallback_mode:
|
||||
await self._ensure_fallback_workers()
|
||||
|
||||
async def _refresh_targets(self, subscribe_new: bool) -> None:
|
||||
if self._auto_discover_sessions:
|
||||
await self._refresh_sessions_directory(subscribe_new)
|
||||
if self._auto_discover_panels:
|
||||
await self._refresh_panels(subscribe_new)
|
||||
|
||||
async def _refresh_sessions_directory(self, subscribe_new: bool) -> None:
|
||||
try:
|
||||
response = await self._post_json("/api/claw/sessions/list", {})
|
||||
except Exception as e:
|
||||
logger.warning("Mochat listSessions failed: {}", e)
|
||||
return
|
||||
|
||||
sessions = response.get("sessions")
|
||||
if not isinstance(sessions, list):
|
||||
return
|
||||
|
||||
new_ids: list[str] = []
|
||||
for s in sessions:
|
||||
if not isinstance(s, dict):
|
||||
continue
|
||||
sid = _str_field(s, "sessionId")
|
||||
if not sid:
|
||||
continue
|
||||
if sid not in self._session_set:
|
||||
self._session_set.add(sid)
|
||||
new_ids.append(sid)
|
||||
if sid not in self._session_cursor:
|
||||
self._cold_sessions.add(sid)
|
||||
cid = _str_field(s, "converseId")
|
||||
if cid:
|
||||
self._session_by_converse[cid] = sid
|
||||
|
||||
if not new_ids:
|
||||
return
|
||||
if self._ws_ready and subscribe_new:
|
||||
await self._subscribe_sessions(new_ids)
|
||||
if self._fallback_mode:
|
||||
await self._ensure_fallback_workers()
|
||||
|
||||
async def _refresh_panels(self, subscribe_new: bool) -> None:
|
||||
try:
|
||||
response = await self._post_json("/api/claw/groups/get", {})
|
||||
except Exception as e:
|
||||
logger.warning("Mochat getWorkspaceGroup failed: {}", e)
|
||||
return
|
||||
|
||||
raw_panels = response.get("panels")
|
||||
if not isinstance(raw_panels, list):
|
||||
return
|
||||
|
||||
new_ids: list[str] = []
|
||||
for p in raw_panels:
|
||||
if not isinstance(p, dict):
|
||||
continue
|
||||
pt = p.get("type")
|
||||
if isinstance(pt, int) and pt != 0:
|
||||
continue
|
||||
pid = _str_field(p, "id", "_id")
|
||||
if pid and pid not in self._panel_set:
|
||||
self._panel_set.add(pid)
|
||||
new_ids.append(pid)
|
||||
|
||||
if not new_ids:
|
||||
return
|
||||
if self._ws_ready and subscribe_new:
|
||||
await self._subscribe_panels(new_ids)
|
||||
if self._fallback_mode:
|
||||
await self._ensure_fallback_workers()
|
||||
|
||||
# ---- fallback workers --------------------------------------------------
|
||||
|
||||
async def _ensure_fallback_workers(self) -> None:
|
||||
if not self._running:
|
||||
return
|
||||
self._fallback_mode = True
|
||||
for sid in sorted(self._session_set):
|
||||
t = self._session_fallback_tasks.get(sid)
|
||||
if not t or t.done():
|
||||
self._session_fallback_tasks[sid] = asyncio.create_task(self._session_watch_worker(sid))
|
||||
for pid in sorted(self._panel_set):
|
||||
t = self._panel_fallback_tasks.get(pid)
|
||||
if not t or t.done():
|
||||
self._panel_fallback_tasks[pid] = asyncio.create_task(self._panel_poll_worker(pid))
|
||||
|
||||
async def _stop_fallback_workers(self) -> None:
|
||||
self._fallback_mode = False
|
||||
tasks = [*self._session_fallback_tasks.values(), *self._panel_fallback_tasks.values()]
|
||||
for t in tasks:
|
||||
t.cancel()
|
||||
if tasks:
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
self._session_fallback_tasks.clear()
|
||||
self._panel_fallback_tasks.clear()
|
||||
|
||||
async def _session_watch_worker(self, session_id: str) -> None:
|
||||
while self._running and self._fallback_mode:
|
||||
try:
|
||||
payload = await self._post_json("/api/claw/sessions/watch", {
|
||||
"sessionId": session_id, "cursor": self._session_cursor.get(session_id, 0),
|
||||
"timeoutMs": self.config.watch_timeout_ms, "limit": self.config.watch_limit,
|
||||
})
|
||||
await self._handle_watch_payload(payload, "session")
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning("Mochat watch fallback error ({}): {}", session_id, e)
|
||||
await asyncio.sleep(max(0.1, self.config.retry_delay_ms / 1000.0))
|
||||
|
||||
async def _panel_poll_worker(self, panel_id: str) -> None:
|
||||
sleep_s = max(1.0, self.config.refresh_interval_ms / 1000.0)
|
||||
while self._running and self._fallback_mode:
|
||||
try:
|
||||
resp = await self._post_json("/api/claw/groups/panels/messages", {
|
||||
"panelId": panel_id, "limit": min(100, max(1, self.config.watch_limit)),
|
||||
})
|
||||
msgs = resp.get("messages")
|
||||
if isinstance(msgs, list):
|
||||
for m in reversed(msgs):
|
||||
if not isinstance(m, dict):
|
||||
continue
|
||||
evt = _make_synthetic_event(
|
||||
message_id=str(m.get("messageId") or ""),
|
||||
author=str(m.get("author") or ""),
|
||||
content=m.get("content"),
|
||||
meta=m.get("meta"), group_id=str(resp.get("groupId") or ""),
|
||||
converse_id=panel_id, timestamp=m.get("createdAt"),
|
||||
author_info=m.get("authorInfo"),
|
||||
)
|
||||
await self._process_inbound_event(panel_id, evt, "panel")
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning("Mochat panel polling error ({}): {}", panel_id, e)
|
||||
await asyncio.sleep(sleep_s)
|
||||
|
||||
# ---- inbound event processing ------------------------------------------
|
||||
|
||||
async def _handle_watch_payload(self, payload: dict[str, Any], target_kind: str) -> None:
|
||||
if not isinstance(payload, dict):
|
||||
return
|
||||
target_id = _str_field(payload, "sessionId")
|
||||
if not target_id:
|
||||
return
|
||||
|
||||
lock = self._target_locks.setdefault(f"{target_kind}:{target_id}", asyncio.Lock())
|
||||
async with lock:
|
||||
prev = self._session_cursor.get(target_id, 0) if target_kind == "session" else 0
|
||||
pc = payload.get("cursor")
|
||||
if target_kind == "session" and isinstance(pc, int) and pc >= 0:
|
||||
self._mark_session_cursor(target_id, pc)
|
||||
|
||||
raw_events = payload.get("events")
|
||||
if not isinstance(raw_events, list):
|
||||
return
|
||||
if target_kind == "session" and target_id in self._cold_sessions:
|
||||
self._cold_sessions.discard(target_id)
|
||||
return
|
||||
|
||||
for event in raw_events:
|
||||
if not isinstance(event, dict):
|
||||
continue
|
||||
seq = event.get("seq")
|
||||
if target_kind == "session" and isinstance(seq, int) and seq > self._session_cursor.get(target_id, prev):
|
||||
self._mark_session_cursor(target_id, seq)
|
||||
if event.get("type") == "message.add":
|
||||
await self._process_inbound_event(target_id, event, target_kind)
|
||||
|
||||
async def _process_inbound_event(self, target_id: str, event: dict[str, Any], target_kind: str) -> None:
|
||||
payload = event.get("payload")
|
||||
if not isinstance(payload, dict):
|
||||
return
|
||||
|
||||
author = _str_field(payload, "author")
|
||||
if not author or (self.config.agent_user_id and author == self.config.agent_user_id):
|
||||
return
|
||||
if not self.is_allowed(author):
|
||||
return
|
||||
|
||||
message_id = _str_field(payload, "messageId")
|
||||
seen_key = f"{target_kind}:{target_id}"
|
||||
if message_id and self._remember_message_id(seen_key, message_id):
|
||||
return
|
||||
|
||||
raw_body = normalize_mochat_content(payload.get("content")) or "[empty message]"
|
||||
ai = _safe_dict(payload.get("authorInfo"))
|
||||
sender_name = _str_field(ai, "nickname", "email")
|
||||
sender_username = _str_field(ai, "agentId")
|
||||
|
||||
group_id = _str_field(payload, "groupId")
|
||||
is_group = bool(group_id)
|
||||
was_mentioned = resolve_was_mentioned(payload, self.config.agent_user_id)
|
||||
require_mention = target_kind == "panel" and is_group and resolve_require_mention(self.config, target_id, group_id)
|
||||
use_delay = target_kind == "panel" and self.config.reply_delay_mode == "non-mention"
|
||||
|
||||
if require_mention and not was_mentioned and not use_delay:
|
||||
return
|
||||
|
||||
entry = MochatBufferedEntry(
|
||||
raw_body=raw_body, author=author, sender_name=sender_name,
|
||||
sender_username=sender_username, timestamp=parse_timestamp(event.get("timestamp")),
|
||||
message_id=message_id, group_id=group_id,
|
||||
)
|
||||
|
||||
if use_delay:
|
||||
delay_key = seen_key
|
||||
if was_mentioned:
|
||||
await self._flush_delayed_entries(delay_key, target_id, target_kind, "mention", entry)
|
||||
else:
|
||||
await self._enqueue_delayed_entry(delay_key, target_id, target_kind, entry)
|
||||
return
|
||||
|
||||
await self._dispatch_entries(target_id, target_kind, [entry], was_mentioned)
|
||||
|
||||
# ---- dedup / buffering -------------------------------------------------
|
||||
|
||||
def _remember_message_id(self, key: str, message_id: str) -> bool:
|
||||
seen_set = self._seen_set.setdefault(key, set())
|
||||
seen_queue = self._seen_queue.setdefault(key, deque())
|
||||
if message_id in seen_set:
|
||||
return True
|
||||
seen_set.add(message_id)
|
||||
seen_queue.append(message_id)
|
||||
while len(seen_queue) > MAX_SEEN_MESSAGE_IDS:
|
||||
seen_set.discard(seen_queue.popleft())
|
||||
return False
|
||||
|
||||
async def _enqueue_delayed_entry(self, key: str, target_id: str, target_kind: str, entry: MochatBufferedEntry) -> None:
|
||||
state = self._delay_states.setdefault(key, DelayState())
|
||||
async with state.lock:
|
||||
state.entries.append(entry)
|
||||
if state.timer:
|
||||
state.timer.cancel()
|
||||
state.timer = asyncio.create_task(self._delay_flush_after(key, target_id, target_kind))
|
||||
|
||||
async def _delay_flush_after(self, key: str, target_id: str, target_kind: str) -> None:
|
||||
await asyncio.sleep(max(0, self.config.reply_delay_ms) / 1000.0)
|
||||
await self._flush_delayed_entries(key, target_id, target_kind, "timer", None)
|
||||
|
||||
async def _flush_delayed_entries(self, key: str, target_id: str, target_kind: str, reason: str, entry: MochatBufferedEntry | None) -> None:
|
||||
state = self._delay_states.setdefault(key, DelayState())
|
||||
async with state.lock:
|
||||
if entry:
|
||||
state.entries.append(entry)
|
||||
current = asyncio.current_task()
|
||||
if state.timer and state.timer is not current:
|
||||
state.timer.cancel()
|
||||
state.timer = None
|
||||
entries = state.entries[:]
|
||||
state.entries.clear()
|
||||
if entries:
|
||||
await self._dispatch_entries(target_id, target_kind, entries, reason == "mention")
|
||||
|
||||
async def _dispatch_entries(self, target_id: str, target_kind: str, entries: list[MochatBufferedEntry], was_mentioned: bool) -> None:
|
||||
if not entries:
|
||||
return
|
||||
last = entries[-1]
|
||||
is_group = bool(last.group_id)
|
||||
body = build_buffered_body(entries, is_group) or "[empty message]"
|
||||
await self._handle_message(
|
||||
sender_id=last.author, chat_id=target_id, content=body,
|
||||
metadata={
|
||||
"message_id": last.message_id, "timestamp": last.timestamp,
|
||||
"is_group": is_group, "group_id": last.group_id,
|
||||
"sender_name": last.sender_name, "sender_username": last.sender_username,
|
||||
"target_kind": target_kind, "was_mentioned": was_mentioned,
|
||||
"buffered_count": len(entries),
|
||||
},
|
||||
)
|
||||
|
||||
async def _cancel_delay_timers(self) -> None:
|
||||
for state in self._delay_states.values():
|
||||
if state.timer:
|
||||
state.timer.cancel()
|
||||
self._delay_states.clear()
|
||||
|
||||
# ---- notify handlers ---------------------------------------------------
|
||||
|
||||
async def _handle_notify_chat_message(self, payload: Any) -> None:
|
||||
if not isinstance(payload, dict):
|
||||
return
|
||||
group_id = _str_field(payload, "groupId")
|
||||
panel_id = _str_field(payload, "converseId", "panelId")
|
||||
if not group_id or not panel_id:
|
||||
return
|
||||
if self._panel_set and panel_id not in self._panel_set:
|
||||
return
|
||||
|
||||
evt = _make_synthetic_event(
|
||||
message_id=str(payload.get("_id") or payload.get("messageId") or ""),
|
||||
author=str(payload.get("author") or ""),
|
||||
content=payload.get("content"), meta=payload.get("meta"),
|
||||
group_id=group_id, converse_id=panel_id,
|
||||
timestamp=payload.get("createdAt"), author_info=payload.get("authorInfo"),
|
||||
)
|
||||
await self._process_inbound_event(panel_id, evt, "panel")
|
||||
|
||||
async def _handle_notify_inbox_append(self, payload: Any) -> None:
|
||||
if not isinstance(payload, dict) or payload.get("type") != "message":
|
||||
return
|
||||
detail = payload.get("payload")
|
||||
if not isinstance(detail, dict):
|
||||
return
|
||||
if _str_field(detail, "groupId"):
|
||||
return
|
||||
converse_id = _str_field(detail, "converseId")
|
||||
if not converse_id:
|
||||
return
|
||||
|
||||
session_id = self._session_by_converse.get(converse_id)
|
||||
if not session_id:
|
||||
await self._refresh_sessions_directory(self._ws_ready)
|
||||
session_id = self._session_by_converse.get(converse_id)
|
||||
if not session_id:
|
||||
return
|
||||
|
||||
evt = _make_synthetic_event(
|
||||
message_id=str(detail.get("messageId") or payload.get("_id") or ""),
|
||||
author=str(detail.get("messageAuthor") or ""),
|
||||
content=str(detail.get("messagePlainContent") or detail.get("messageSnippet") or ""),
|
||||
meta={"source": "notify:chat.inbox.append", "converseId": converse_id},
|
||||
group_id="", converse_id=converse_id, timestamp=payload.get("createdAt"),
|
||||
)
|
||||
await self._process_inbound_event(session_id, evt, "session")
|
||||
|
||||
# ---- cursor persistence ------------------------------------------------
|
||||
|
||||
def _mark_session_cursor(self, session_id: str, cursor: int) -> None:
|
||||
if cursor < 0 or cursor < self._session_cursor.get(session_id, 0):
|
||||
return
|
||||
self._session_cursor[session_id] = cursor
|
||||
if not self._cursor_save_task or self._cursor_save_task.done():
|
||||
self._cursor_save_task = asyncio.create_task(self._save_cursor_debounced())
|
||||
|
||||
async def _save_cursor_debounced(self) -> None:
|
||||
await asyncio.sleep(CURSOR_SAVE_DEBOUNCE_S)
|
||||
await self._save_session_cursors()
|
||||
|
||||
async def _load_session_cursors(self) -> None:
|
||||
if not self._cursor_path.exists():
|
||||
return
|
||||
try:
|
||||
data = json.loads(self._cursor_path.read_text("utf-8"))
|
||||
except Exception as e:
|
||||
logger.warning("Failed to read Mochat cursor file: {}", e)
|
||||
return
|
||||
cursors = data.get("cursors") if isinstance(data, dict) else None
|
||||
if isinstance(cursors, dict):
|
||||
for sid, cur in cursors.items():
|
||||
if isinstance(sid, str) and isinstance(cur, int) and cur >= 0:
|
||||
self._session_cursor[sid] = cur
|
||||
|
||||
async def _save_session_cursors(self) -> None:
|
||||
try:
|
||||
self._state_dir.mkdir(parents=True, exist_ok=True)
|
||||
self._cursor_path.write_text(json.dumps({
|
||||
"schemaVersion": 1, "updatedAt": datetime.utcnow().isoformat(),
|
||||
"cursors": self._session_cursor,
|
||||
}, ensure_ascii=False, indent=2) + "\n", "utf-8")
|
||||
except Exception as e:
|
||||
logger.warning("Failed to save Mochat cursor file: {}", e)
|
||||
|
||||
# ---- HTTP helpers ------------------------------------------------------
|
||||
|
||||
async def _post_json(self, path: str, payload: dict[str, Any]) -> dict[str, Any]:
|
||||
if not self._http:
|
||||
raise RuntimeError("Mochat HTTP client not initialized")
|
||||
url = f"{self.config.base_url.strip().rstrip('/')}{path}"
|
||||
response = await self._http.post(url, headers={
|
||||
"Content-Type": "application/json", "X-Claw-Token": self.config.claw_token,
|
||||
}, json=payload)
|
||||
if not response.is_success:
|
||||
raise RuntimeError(f"Mochat HTTP {response.status_code}: {response.text[:200]}")
|
||||
try:
|
||||
parsed = response.json()
|
||||
except Exception:
|
||||
parsed = response.text
|
||||
if isinstance(parsed, dict) and isinstance(parsed.get("code"), int):
|
||||
if parsed["code"] != 200:
|
||||
msg = str(parsed.get("message") or parsed.get("name") or "request failed")
|
||||
raise RuntimeError(f"Mochat API error: {msg} (code={parsed['code']})")
|
||||
data = parsed.get("data")
|
||||
return data if isinstance(data, dict) else {}
|
||||
return parsed if isinstance(parsed, dict) else {}
|
||||
|
||||
async def _api_send(self, path: str, id_key: str, id_val: str,
|
||||
content: str, reply_to: str | None, group_id: str | None = None) -> dict[str, Any]:
|
||||
"""Unified send helper for session and panel messages."""
|
||||
body: dict[str, Any] = {id_key: id_val, "content": content}
|
||||
if reply_to:
|
||||
body["replyTo"] = reply_to
|
||||
if group_id:
|
||||
body["groupId"] = group_id
|
||||
return await self._post_json(path, body)
|
||||
|
||||
@staticmethod
|
||||
def _read_group_id(metadata: dict[str, Any]) -> str | None:
|
||||
if not isinstance(metadata, dict):
|
||||
return None
|
||||
value = metadata.get("group_id") or metadata.get("groupId")
|
||||
return value.strip() if isinstance(value, str) and value.strip() else None
|
||||
132
app-instance/backend/nanobot/channels/qq.py
Normal file
132
app-instance/backend/nanobot/channels/qq.py
Normal file
@ -0,0 +1,132 @@
|
||||
"""QQ channel implementation using botpy SDK."""
|
||||
|
||||
import asyncio
|
||||
from collections import deque
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.config.schema import QQConfig
|
||||
|
||||
try:
|
||||
import botpy
|
||||
from botpy.message import C2CMessage
|
||||
|
||||
QQ_AVAILABLE = True
|
||||
except ImportError:
|
||||
QQ_AVAILABLE = False
|
||||
botpy = None
|
||||
C2CMessage = None
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from botpy.message import C2CMessage
|
||||
|
||||
|
||||
def _make_bot_class(channel: "QQChannel") -> "type[botpy.Client]":
|
||||
"""Create a botpy Client subclass bound to the given channel."""
|
||||
intents = botpy.Intents(public_messages=True, direct_message=True)
|
||||
|
||||
class _Bot(botpy.Client):
|
||||
def __init__(self):
|
||||
super().__init__(intents=intents)
|
||||
|
||||
async def on_ready(self):
|
||||
logger.info("QQ bot ready: {}", self.robot.name)
|
||||
|
||||
async def on_c2c_message_create(self, message: "C2CMessage"):
|
||||
await channel._on_message(message)
|
||||
|
||||
async def on_direct_message_create(self, message):
|
||||
await channel._on_message(message)
|
||||
|
||||
return _Bot
|
||||
|
||||
|
||||
class QQChannel(BaseChannel):
|
||||
"""QQ channel using botpy SDK with WebSocket connection."""
|
||||
|
||||
name = "qq"
|
||||
|
||||
def __init__(self, config: QQConfig, bus: MessageBus):
|
||||
super().__init__(config, bus)
|
||||
self.config: QQConfig = config
|
||||
self._client: "botpy.Client | None" = None
|
||||
self._processed_ids: deque = deque(maxlen=1000)
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the QQ bot."""
|
||||
if not QQ_AVAILABLE:
|
||||
logger.error("QQ SDK not installed. Run: pip install qq-botpy")
|
||||
return
|
||||
|
||||
if not self.config.app_id or not self.config.secret:
|
||||
logger.error("QQ app_id and secret not configured")
|
||||
return
|
||||
|
||||
self._running = True
|
||||
BotClass = _make_bot_class(self)
|
||||
self._client = BotClass()
|
||||
|
||||
logger.info("QQ bot started (C2C private message)")
|
||||
await self._run_bot()
|
||||
|
||||
async def _run_bot(self) -> None:
|
||||
"""Run the bot connection with auto-reconnect."""
|
||||
while self._running:
|
||||
try:
|
||||
await self._client.start(appid=self.config.app_id, secret=self.config.secret)
|
||||
except Exception as e:
|
||||
logger.warning("QQ bot error: {}", e)
|
||||
if self._running:
|
||||
logger.info("Reconnecting QQ bot in 5 seconds...")
|
||||
await asyncio.sleep(5)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the QQ bot."""
|
||||
self._running = False
|
||||
if self._client:
|
||||
try:
|
||||
await self._client.close()
|
||||
except Exception:
|
||||
pass
|
||||
logger.info("QQ bot stopped")
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
"""Send a message through QQ."""
|
||||
if not self._client:
|
||||
logger.warning("QQ client not initialized")
|
||||
return
|
||||
try:
|
||||
await self._client.api.post_c2c_message(
|
||||
openid=msg.chat_id,
|
||||
msg_type=0,
|
||||
content=msg.content,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Error sending QQ message: {}", e)
|
||||
|
||||
async def _on_message(self, data: "C2CMessage") -> None:
|
||||
"""Handle incoming message from QQ."""
|
||||
try:
|
||||
# Dedup by message ID
|
||||
if data.id in self._processed_ids:
|
||||
return
|
||||
self._processed_ids.append(data.id)
|
||||
|
||||
author = data.author
|
||||
user_id = str(getattr(author, 'id', None) or getattr(author, 'user_openid', 'unknown'))
|
||||
content = (data.content or "").strip()
|
||||
if not content:
|
||||
return
|
||||
|
||||
await self._handle_message(
|
||||
sender_id=user_id,
|
||||
chat_id=user_id,
|
||||
content=content,
|
||||
metadata={"message_id": data.id},
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Error handling QQ message")
|
||||
257
app-instance/backend/nanobot/channels/slack.py
Normal file
257
app-instance/backend/nanobot/channels/slack.py
Normal file
@ -0,0 +1,257 @@
|
||||
"""Slack channel implementation using Socket Mode."""
|
||||
|
||||
import asyncio
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
from slack_sdk.socket_mode.websockets import SocketModeClient
|
||||
from slack_sdk.socket_mode.request import SocketModeRequest
|
||||
from slack_sdk.socket_mode.response import SocketModeResponse
|
||||
from slack_sdk.web.async_client import AsyncWebClient
|
||||
|
||||
from slackify_markdown import slackify_markdown
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.config.schema import SlackConfig
|
||||
|
||||
|
||||
class SlackChannel(BaseChannel):
|
||||
"""Slack channel using Socket Mode."""
|
||||
|
||||
name = "slack"
|
||||
|
||||
def __init__(self, config: SlackConfig, bus: MessageBus):
|
||||
super().__init__(config, bus)
|
||||
self.config: SlackConfig = config
|
||||
self._web_client: AsyncWebClient | None = None
|
||||
self._socket_client: SocketModeClient | None = None
|
||||
self._bot_user_id: str | None = None
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the Slack Socket Mode client."""
|
||||
if not self.config.bot_token or not self.config.app_token:
|
||||
logger.error("Slack bot/app token not configured")
|
||||
return
|
||||
if self.config.mode != "socket":
|
||||
logger.error("Unsupported Slack mode: {}", self.config.mode)
|
||||
return
|
||||
|
||||
self._running = True
|
||||
|
||||
self._web_client = AsyncWebClient(token=self.config.bot_token)
|
||||
self._socket_client = SocketModeClient(
|
||||
app_token=self.config.app_token,
|
||||
web_client=self._web_client,
|
||||
)
|
||||
|
||||
self._socket_client.socket_mode_request_listeners.append(self._on_socket_request)
|
||||
|
||||
# Resolve bot user ID for mention handling
|
||||
try:
|
||||
auth = await self._web_client.auth_test()
|
||||
self._bot_user_id = auth.get("user_id")
|
||||
logger.info("Slack bot connected as {}", self._bot_user_id)
|
||||
except Exception as e:
|
||||
logger.warning("Slack auth_test failed: {}", e)
|
||||
|
||||
logger.info("Starting Slack Socket Mode client...")
|
||||
await self._socket_client.connect()
|
||||
|
||||
while self._running:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the Slack client."""
|
||||
self._running = False
|
||||
if self._socket_client:
|
||||
try:
|
||||
await self._socket_client.close()
|
||||
except Exception as e:
|
||||
logger.warning("Slack socket close failed: {}", e)
|
||||
self._socket_client = None
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
"""Send a message through Slack."""
|
||||
if not self._web_client:
|
||||
logger.warning("Slack client not running")
|
||||
return
|
||||
try:
|
||||
slack_meta = msg.metadata.get("slack", {}) if msg.metadata else {}
|
||||
thread_ts = slack_meta.get("thread_ts")
|
||||
channel_type = slack_meta.get("channel_type")
|
||||
# Only reply in thread for channel/group messages; DMs don't use threads
|
||||
use_thread = thread_ts and channel_type != "im"
|
||||
thread_ts_param = thread_ts if use_thread else None
|
||||
|
||||
if msg.content:
|
||||
await self._web_client.chat_postMessage(
|
||||
channel=msg.chat_id,
|
||||
text=self._to_mrkdwn(msg.content),
|
||||
thread_ts=thread_ts_param,
|
||||
)
|
||||
|
||||
for media_path in msg.media or []:
|
||||
try:
|
||||
await self._web_client.files_upload_v2(
|
||||
channel=msg.chat_id,
|
||||
file=media_path,
|
||||
thread_ts=thread_ts_param,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to upload file {}: {}", media_path, e)
|
||||
except Exception as e:
|
||||
logger.error("Error sending Slack message: {}", e)
|
||||
|
||||
async def _on_socket_request(
|
||||
self,
|
||||
client: SocketModeClient,
|
||||
req: SocketModeRequest,
|
||||
) -> None:
|
||||
"""Handle incoming Socket Mode requests."""
|
||||
if req.type != "events_api":
|
||||
return
|
||||
|
||||
# Acknowledge right away
|
||||
await client.send_socket_mode_response(
|
||||
SocketModeResponse(envelope_id=req.envelope_id)
|
||||
)
|
||||
|
||||
payload = req.payload or {}
|
||||
event = payload.get("event") or {}
|
||||
event_type = event.get("type")
|
||||
|
||||
# Handle app mentions or plain messages
|
||||
if event_type not in ("message", "app_mention"):
|
||||
return
|
||||
|
||||
sender_id = event.get("user")
|
||||
chat_id = event.get("channel")
|
||||
|
||||
# Ignore bot/system messages (any subtype = not a normal user message)
|
||||
if event.get("subtype"):
|
||||
return
|
||||
if self._bot_user_id and sender_id == self._bot_user_id:
|
||||
return
|
||||
|
||||
# Avoid double-processing: Slack sends both `message` and `app_mention`
|
||||
# for mentions in channels. Prefer `app_mention`.
|
||||
text = event.get("text") or ""
|
||||
if event_type == "message" and self._bot_user_id and f"<@{self._bot_user_id}>" in text:
|
||||
return
|
||||
|
||||
# Debug: log basic event shape
|
||||
logger.debug(
|
||||
"Slack event: type={} subtype={} user={} channel={} channel_type={} text={}",
|
||||
event_type,
|
||||
event.get("subtype"),
|
||||
sender_id,
|
||||
chat_id,
|
||||
event.get("channel_type"),
|
||||
text[:80],
|
||||
)
|
||||
if not sender_id or not chat_id:
|
||||
return
|
||||
|
||||
channel_type = event.get("channel_type") or ""
|
||||
|
||||
if not self._is_allowed(sender_id, chat_id, channel_type):
|
||||
return
|
||||
|
||||
if channel_type != "im" and not self._should_respond_in_channel(event_type, text, chat_id):
|
||||
return
|
||||
|
||||
text = self._strip_bot_mention(text)
|
||||
|
||||
thread_ts = event.get("thread_ts")
|
||||
if self.config.reply_in_thread and not thread_ts:
|
||||
thread_ts = event.get("ts")
|
||||
# Add :eyes: reaction to the triggering message (best-effort)
|
||||
try:
|
||||
if self._web_client and event.get("ts"):
|
||||
await self._web_client.reactions_add(
|
||||
channel=chat_id,
|
||||
name=self.config.react_emoji,
|
||||
timestamp=event.get("ts"),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug("Slack reactions_add failed: {}", e)
|
||||
|
||||
# Thread-scoped session key for channel/group messages
|
||||
session_key = f"slack:{chat_id}:{thread_ts}" if thread_ts and channel_type != "im" else None
|
||||
|
||||
try:
|
||||
await self._handle_message(
|
||||
sender_id=sender_id,
|
||||
chat_id=chat_id,
|
||||
content=text,
|
||||
metadata={
|
||||
"slack": {
|
||||
"event": event,
|
||||
"thread_ts": thread_ts,
|
||||
"channel_type": channel_type,
|
||||
},
|
||||
},
|
||||
session_key=session_key,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Error handling Slack message from {}", sender_id)
|
||||
|
||||
def _is_allowed(self, sender_id: str, chat_id: str, channel_type: str) -> bool:
|
||||
if channel_type == "im":
|
||||
if not self.config.dm.enabled:
|
||||
return False
|
||||
if self.config.dm.policy == "allowlist":
|
||||
return sender_id in self.config.dm.allow_from
|
||||
return True
|
||||
|
||||
# Group / channel messages
|
||||
if self.config.group_policy == "allowlist":
|
||||
return chat_id in self.config.group_allow_from
|
||||
return True
|
||||
|
||||
def _should_respond_in_channel(self, event_type: str, text: str, chat_id: str) -> bool:
|
||||
if self.config.group_policy == "open":
|
||||
return True
|
||||
if self.config.group_policy == "mention":
|
||||
if event_type == "app_mention":
|
||||
return True
|
||||
return self._bot_user_id is not None and f"<@{self._bot_user_id}>" in text
|
||||
if self.config.group_policy == "allowlist":
|
||||
return chat_id in self.config.group_allow_from
|
||||
return False
|
||||
|
||||
def _strip_bot_mention(self, text: str) -> str:
|
||||
if not text or not self._bot_user_id:
|
||||
return text
|
||||
return re.sub(rf"<@{re.escape(self._bot_user_id)}>\s*", "", text).strip()
|
||||
|
||||
_TABLE_RE = re.compile(r"(?m)^\|.*\|$(?:\n\|[\s:|-]*\|$)(?:\n\|.*\|$)*")
|
||||
|
||||
@classmethod
|
||||
def _to_mrkdwn(cls, text: str) -> str:
|
||||
"""Convert Markdown to Slack mrkdwn, including tables."""
|
||||
if not text:
|
||||
return ""
|
||||
text = cls._TABLE_RE.sub(cls._convert_table, text)
|
||||
return slackify_markdown(text)
|
||||
|
||||
@staticmethod
|
||||
def _convert_table(match: re.Match) -> str:
|
||||
"""Convert a Markdown table to a Slack-readable list."""
|
||||
lines = [ln.strip() for ln in match.group(0).strip().splitlines() if ln.strip()]
|
||||
if len(lines) < 2:
|
||||
return match.group(0)
|
||||
headers = [h.strip() for h in lines[0].strip("|").split("|")]
|
||||
start = 2 if re.fullmatch(r"[|\s:\-]+", lines[1]) else 1
|
||||
rows: list[str] = []
|
||||
for line in lines[start:]:
|
||||
cells = [c.strip() for c in line.strip("|").split("|")]
|
||||
cells = (cells + [""] * len(headers))[: len(headers)]
|
||||
parts = [f"**{headers[i]}**: {cells[i]}" for i in range(len(headers)) if cells[i]]
|
||||
if parts:
|
||||
rows.append(" · ".join(parts))
|
||||
return "\n".join(rows)
|
||||
|
||||
457
app-instance/backend/nanobot/channels/telegram.py
Normal file
457
app-instance/backend/nanobot/channels/telegram.py
Normal file
@ -0,0 +1,457 @@
|
||||
"""Telegram channel implementation using python-telegram-bot."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import re
|
||||
from loguru import logger
|
||||
from telegram import BotCommand, Update, ReplyParameters
|
||||
from telegram.ext import Application, CommandHandler, MessageHandler, filters, ContextTypes
|
||||
from telegram.request import HTTPXRequest
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.config.schema import TelegramConfig
|
||||
|
||||
|
||||
def _markdown_to_telegram_html(text: str) -> str:
|
||||
"""
|
||||
Convert markdown to Telegram-safe HTML.
|
||||
"""
|
||||
if not text:
|
||||
return ""
|
||||
|
||||
# 1. Extract and protect code blocks (preserve content from other processing)
|
||||
code_blocks: list[str] = []
|
||||
def save_code_block(m: re.Match) -> str:
|
||||
code_blocks.append(m.group(1))
|
||||
return f"\x00CB{len(code_blocks) - 1}\x00"
|
||||
|
||||
text = re.sub(r'```[\w]*\n?([\s\S]*?)```', save_code_block, text)
|
||||
|
||||
# 2. Extract and protect inline code
|
||||
inline_codes: list[str] = []
|
||||
def save_inline_code(m: re.Match) -> str:
|
||||
inline_codes.append(m.group(1))
|
||||
return f"\x00IC{len(inline_codes) - 1}\x00"
|
||||
|
||||
text = re.sub(r'`([^`]+)`', save_inline_code, text)
|
||||
|
||||
# 3. Headers # Title -> just the title text
|
||||
text = re.sub(r'^#{1,6}\s+(.+)$', r'\1', text, flags=re.MULTILINE)
|
||||
|
||||
# 4. Blockquotes > text -> just the text (before HTML escaping)
|
||||
text = re.sub(r'^>\s*(.*)$', r'\1', text, flags=re.MULTILINE)
|
||||
|
||||
# 5. Escape HTML special characters
|
||||
text = text.replace("&", "&").replace("<", "<").replace(">", ">")
|
||||
|
||||
# 6. Links [text](url) - must be before bold/italic to handle nested cases
|
||||
text = re.sub(r'\[([^\]]+)\]\(([^)]+)\)', r'<a href="\2">\1</a>', text)
|
||||
|
||||
# 7. Bold **text** or __text__
|
||||
text = re.sub(r'\*\*(.+?)\*\*', r'<b>\1</b>', text)
|
||||
text = re.sub(r'__(.+?)__', r'<b>\1</b>', text)
|
||||
|
||||
# 8. Italic _text_ (avoid matching inside words like some_var_name)
|
||||
text = re.sub(r'(?<![a-zA-Z0-9])_([^_]+)_(?![a-zA-Z0-9])', r'<i>\1</i>', text)
|
||||
|
||||
# 9. Strikethrough ~~text~~
|
||||
text = re.sub(r'~~(.+?)~~', r'<s>\1</s>', text)
|
||||
|
||||
# 10. Bullet lists - item -> • item
|
||||
text = re.sub(r'^[-*]\s+', '• ', text, flags=re.MULTILINE)
|
||||
|
||||
# 11. Restore inline code with HTML tags
|
||||
for i, code in enumerate(inline_codes):
|
||||
# Escape HTML in code content
|
||||
escaped = code.replace("&", "&").replace("<", "<").replace(">", ">")
|
||||
text = text.replace(f"\x00IC{i}\x00", f"<code>{escaped}</code>")
|
||||
|
||||
# 12. Restore code blocks with HTML tags
|
||||
for i, code in enumerate(code_blocks):
|
||||
# Escape HTML in code content
|
||||
escaped = code.replace("&", "&").replace("<", "<").replace(">", ">")
|
||||
text = text.replace(f"\x00CB{i}\x00", f"<pre><code>{escaped}</code></pre>")
|
||||
|
||||
return text
|
||||
|
||||
|
||||
def _split_message(content: str, max_len: int = 4000) -> list[str]:
|
||||
"""Split content into chunks within max_len, preferring line breaks."""
|
||||
if len(content) <= max_len:
|
||||
return [content]
|
||||
chunks: list[str] = []
|
||||
while content:
|
||||
if len(content) <= max_len:
|
||||
chunks.append(content)
|
||||
break
|
||||
cut = content[:max_len]
|
||||
pos = cut.rfind('\n')
|
||||
if pos == -1:
|
||||
pos = cut.rfind(' ')
|
||||
if pos == -1:
|
||||
pos = max_len
|
||||
chunks.append(content[:pos])
|
||||
content = content[pos:].lstrip()
|
||||
return chunks
|
||||
|
||||
|
||||
class TelegramChannel(BaseChannel):
|
||||
"""
|
||||
Telegram channel using long polling.
|
||||
|
||||
Simple and reliable - no webhook/public IP needed.
|
||||
"""
|
||||
|
||||
name = "telegram"
|
||||
|
||||
# Commands registered with Telegram's command menu
|
||||
BOT_COMMANDS = [
|
||||
BotCommand("start", "Start the bot"),
|
||||
BotCommand("new", "Start a new conversation"),
|
||||
BotCommand("help", "Show available commands"),
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: TelegramConfig,
|
||||
bus: MessageBus,
|
||||
groq_api_key: str = "",
|
||||
):
|
||||
super().__init__(config, bus)
|
||||
self.config: TelegramConfig = config
|
||||
self.groq_api_key = groq_api_key
|
||||
self._app: Application | None = None
|
||||
self._chat_ids: dict[str, int] = {} # Map sender_id to chat_id for replies
|
||||
self._typing_tasks: dict[str, asyncio.Task] = {} # chat_id -> typing loop task
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the Telegram bot with long polling."""
|
||||
if not self.config.token:
|
||||
logger.error("Telegram bot token not configured")
|
||||
return
|
||||
|
||||
self._running = True
|
||||
|
||||
# Build the application with larger connection pool to avoid pool-timeout on long runs
|
||||
req = HTTPXRequest(connection_pool_size=16, pool_timeout=5.0, connect_timeout=30.0, read_timeout=30.0)
|
||||
builder = Application.builder().token(self.config.token).request(req).get_updates_request(req)
|
||||
if self.config.proxy:
|
||||
builder = builder.proxy(self.config.proxy).get_updates_proxy(self.config.proxy)
|
||||
self._app = builder.build()
|
||||
self._app.add_error_handler(self._on_error)
|
||||
|
||||
# Add command handlers
|
||||
self._app.add_handler(CommandHandler("start", self._on_start))
|
||||
self._app.add_handler(CommandHandler("new", self._forward_command))
|
||||
self._app.add_handler(CommandHandler("help", self._on_help))
|
||||
|
||||
# Add message handler for text, photos, voice, documents
|
||||
self._app.add_handler(
|
||||
MessageHandler(
|
||||
(filters.TEXT | filters.PHOTO | filters.VOICE | filters.AUDIO | filters.Document.ALL)
|
||||
& ~filters.COMMAND,
|
||||
self._on_message
|
||||
)
|
||||
)
|
||||
|
||||
logger.info("Starting Telegram bot (polling mode)...")
|
||||
|
||||
# Initialize and start polling
|
||||
await self._app.initialize()
|
||||
await self._app.start()
|
||||
|
||||
# Get bot info and register command menu
|
||||
bot_info = await self._app.bot.get_me()
|
||||
logger.info("Telegram bot @{} connected", bot_info.username)
|
||||
|
||||
try:
|
||||
await self._app.bot.set_my_commands(self.BOT_COMMANDS)
|
||||
logger.debug("Telegram bot commands registered")
|
||||
except Exception as e:
|
||||
logger.warning("Failed to register bot commands: {}", e)
|
||||
|
||||
# Start polling (this runs until stopped)
|
||||
await self._app.updater.start_polling(
|
||||
allowed_updates=["message"],
|
||||
drop_pending_updates=True # Ignore old messages on startup
|
||||
)
|
||||
|
||||
# Keep running until stopped
|
||||
while self._running:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the Telegram bot."""
|
||||
self._running = False
|
||||
|
||||
# Cancel all typing indicators
|
||||
for chat_id in list(self._typing_tasks):
|
||||
self._stop_typing(chat_id)
|
||||
|
||||
if self._app:
|
||||
logger.info("Stopping Telegram bot...")
|
||||
await self._app.updater.stop()
|
||||
await self._app.stop()
|
||||
await self._app.shutdown()
|
||||
self._app = None
|
||||
|
||||
@staticmethod
|
||||
def _get_media_type(path: str) -> str:
|
||||
"""Guess media type from file extension."""
|
||||
ext = path.rsplit(".", 1)[-1].lower() if "." in path else ""
|
||||
if ext in ("jpg", "jpeg", "png", "gif", "webp"):
|
||||
return "photo"
|
||||
if ext == "ogg":
|
||||
return "voice"
|
||||
if ext in ("mp3", "m4a", "wav", "aac"):
|
||||
return "audio"
|
||||
return "document"
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
"""Send a message through Telegram."""
|
||||
if not self._app:
|
||||
logger.warning("Telegram bot not running")
|
||||
return
|
||||
|
||||
self._stop_typing(msg.chat_id)
|
||||
|
||||
try:
|
||||
chat_id = int(msg.chat_id)
|
||||
except ValueError:
|
||||
logger.error("Invalid chat_id: {}", msg.chat_id)
|
||||
return
|
||||
|
||||
reply_params = None
|
||||
if self.config.reply_to_message:
|
||||
reply_to_message_id = msg.metadata.get("message_id")
|
||||
if reply_to_message_id:
|
||||
reply_params = ReplyParameters(
|
||||
message_id=reply_to_message_id,
|
||||
allow_sending_without_reply=True
|
||||
)
|
||||
|
||||
# Send media files
|
||||
for media_path in (msg.media or []):
|
||||
try:
|
||||
media_type = self._get_media_type(media_path)
|
||||
sender = {
|
||||
"photo": self._app.bot.send_photo,
|
||||
"voice": self._app.bot.send_voice,
|
||||
"audio": self._app.bot.send_audio,
|
||||
}.get(media_type, self._app.bot.send_document)
|
||||
param = "photo" if media_type == "photo" else media_type if media_type in ("voice", "audio") else "document"
|
||||
with open(media_path, 'rb') as f:
|
||||
await sender(
|
||||
chat_id=chat_id,
|
||||
**{param: f},
|
||||
reply_parameters=reply_params
|
||||
)
|
||||
except Exception as e:
|
||||
filename = media_path.rsplit("/", 1)[-1]
|
||||
logger.error("Failed to send media {}: {}", media_path, e)
|
||||
await self._app.bot.send_message(
|
||||
chat_id=chat_id,
|
||||
text=f"[Failed to send: {filename}]",
|
||||
reply_parameters=reply_params
|
||||
)
|
||||
|
||||
# Send text content
|
||||
if msg.content and msg.content != "[empty message]":
|
||||
for chunk in _split_message(msg.content):
|
||||
try:
|
||||
html = _markdown_to_telegram_html(chunk)
|
||||
await self._app.bot.send_message(
|
||||
chat_id=chat_id,
|
||||
text=html,
|
||||
parse_mode="HTML",
|
||||
reply_parameters=reply_params
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("HTML parse failed, falling back to plain text: {}", e)
|
||||
try:
|
||||
await self._app.bot.send_message(
|
||||
chat_id=chat_id,
|
||||
text=chunk,
|
||||
reply_parameters=reply_params
|
||||
)
|
||||
except Exception as e2:
|
||||
logger.error("Error sending Telegram message: {}", e2)
|
||||
|
||||
async def _on_start(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
"""Handle /start command."""
|
||||
if not update.message or not update.effective_user:
|
||||
return
|
||||
|
||||
user = update.effective_user
|
||||
await update.message.reply_text(
|
||||
f"👋 Hi {user.first_name}! I'm nanobot.\n\n"
|
||||
"Send me a message and I'll respond!\n"
|
||||
"Type /help to see available commands."
|
||||
)
|
||||
|
||||
async def _on_help(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
"""Handle /help command, bypassing ACL so all users can access it."""
|
||||
if not update.message:
|
||||
return
|
||||
await update.message.reply_text(
|
||||
"🐈 nanobot commands:\n"
|
||||
"/new — Start a new conversation\n"
|
||||
"/help — Show available commands"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _sender_id(user) -> str:
|
||||
"""Build sender_id with username for allowlist matching."""
|
||||
sid = str(user.id)
|
||||
return f"{sid}|{user.username}" if user.username else sid
|
||||
|
||||
async def _forward_command(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
"""Forward slash commands to the bus for unified handling in AgentLoop."""
|
||||
if not update.message or not update.effective_user:
|
||||
return
|
||||
await self._handle_message(
|
||||
sender_id=self._sender_id(update.effective_user),
|
||||
chat_id=str(update.message.chat_id),
|
||||
content=update.message.text,
|
||||
)
|
||||
|
||||
async def _on_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
"""Handle incoming messages (text, photos, voice, documents)."""
|
||||
if not update.message or not update.effective_user:
|
||||
return
|
||||
|
||||
message = update.message
|
||||
user = update.effective_user
|
||||
chat_id = message.chat_id
|
||||
sender_id = self._sender_id(user)
|
||||
|
||||
# Store chat_id for replies
|
||||
self._chat_ids[sender_id] = chat_id
|
||||
|
||||
# Build content from text and/or media
|
||||
content_parts = []
|
||||
media_paths = []
|
||||
|
||||
# Text content
|
||||
if message.text:
|
||||
content_parts.append(message.text)
|
||||
if message.caption:
|
||||
content_parts.append(message.caption)
|
||||
|
||||
# Handle media files
|
||||
media_file = None
|
||||
media_type = None
|
||||
|
||||
if message.photo:
|
||||
media_file = message.photo[-1] # Largest photo
|
||||
media_type = "image"
|
||||
elif message.voice:
|
||||
media_file = message.voice
|
||||
media_type = "voice"
|
||||
elif message.audio:
|
||||
media_file = message.audio
|
||||
media_type = "audio"
|
||||
elif message.document:
|
||||
media_file = message.document
|
||||
media_type = "file"
|
||||
|
||||
# Download media if present
|
||||
if media_file and self._app:
|
||||
try:
|
||||
file = await self._app.bot.get_file(media_file.file_id)
|
||||
ext = self._get_extension(media_type, getattr(media_file, 'mime_type', None))
|
||||
|
||||
# Save to workspace/media/
|
||||
from pathlib import Path
|
||||
media_dir = Path.home() / ".nanobot" / "media"
|
||||
media_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
file_path = media_dir / f"{media_file.file_id[:16]}{ext}"
|
||||
await file.download_to_drive(str(file_path))
|
||||
|
||||
media_paths.append(str(file_path))
|
||||
|
||||
# Handle voice transcription
|
||||
if media_type == "voice" or media_type == "audio":
|
||||
from nanobot.providers.transcription import GroqTranscriptionProvider
|
||||
transcriber = GroqTranscriptionProvider(api_key=self.groq_api_key)
|
||||
transcription = await transcriber.transcribe(file_path)
|
||||
if transcription:
|
||||
logger.info("Transcribed {}: {}...", media_type, transcription[:50])
|
||||
content_parts.append(f"[transcription: {transcription}]")
|
||||
else:
|
||||
content_parts.append(f"[{media_type}: {file_path}]")
|
||||
else:
|
||||
content_parts.append(f"[{media_type}: {file_path}]")
|
||||
|
||||
logger.debug("Downloaded {} to {}", media_type, file_path)
|
||||
except Exception as e:
|
||||
logger.error("Failed to download media: {}", e)
|
||||
content_parts.append(f"[{media_type}: download failed]")
|
||||
|
||||
content = "\n".join(content_parts) if content_parts else "[empty message]"
|
||||
|
||||
logger.debug("Telegram message from {}: {}...", sender_id, content[:50])
|
||||
|
||||
str_chat_id = str(chat_id)
|
||||
|
||||
# Start typing indicator before processing
|
||||
self._start_typing(str_chat_id)
|
||||
|
||||
# Forward to the message bus
|
||||
await self._handle_message(
|
||||
sender_id=sender_id,
|
||||
chat_id=str_chat_id,
|
||||
content=content,
|
||||
media=media_paths,
|
||||
metadata={
|
||||
"message_id": message.message_id,
|
||||
"user_id": user.id,
|
||||
"username": user.username,
|
||||
"first_name": user.first_name,
|
||||
"is_group": message.chat.type != "private"
|
||||
}
|
||||
)
|
||||
|
||||
def _start_typing(self, chat_id: str) -> None:
|
||||
"""Start sending 'typing...' indicator for a chat."""
|
||||
# Cancel any existing typing task for this chat
|
||||
self._stop_typing(chat_id)
|
||||
self._typing_tasks[chat_id] = asyncio.create_task(self._typing_loop(chat_id))
|
||||
|
||||
def _stop_typing(self, chat_id: str) -> None:
|
||||
"""Stop the typing indicator for a chat."""
|
||||
task = self._typing_tasks.pop(chat_id, None)
|
||||
if task and not task.done():
|
||||
task.cancel()
|
||||
|
||||
async def _typing_loop(self, chat_id: str) -> None:
|
||||
"""Repeatedly send 'typing' action until cancelled."""
|
||||
try:
|
||||
while self._app:
|
||||
await self._app.bot.send_chat_action(chat_id=int(chat_id), action="typing")
|
||||
await asyncio.sleep(4)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.debug("Typing indicator stopped for {}: {}", chat_id, e)
|
||||
|
||||
async def _on_error(self, update: object, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
"""Log polling / handler errors instead of silently swallowing them."""
|
||||
logger.error("Telegram error: {}", context.error)
|
||||
|
||||
def _get_extension(self, media_type: str, mime_type: str | None) -> str:
|
||||
"""Get file extension based on media type."""
|
||||
if mime_type:
|
||||
ext_map = {
|
||||
"image/jpeg": ".jpg", "image/png": ".png", "image/gif": ".gif",
|
||||
"audio/ogg": ".ogg", "audio/mpeg": ".mp3", "audio/mp4": ".m4a",
|
||||
}
|
||||
if mime_type in ext_map:
|
||||
return ext_map[mime_type]
|
||||
|
||||
type_map = {"image": ".jpg", "voice": ".ogg", "audio": ".mp3", "file": ""}
|
||||
return type_map.get(media_type, "")
|
||||
148
app-instance/backend/nanobot/channels/whatsapp.py
Normal file
148
app-instance/backend/nanobot/channels/whatsapp.py
Normal file
@ -0,0 +1,148 @@
|
||||
"""WhatsApp channel implementation using Node.js bridge."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.config.schema import WhatsAppConfig
|
||||
|
||||
|
||||
class WhatsAppChannel(BaseChannel):
|
||||
"""
|
||||
WhatsApp channel that connects to a Node.js bridge.
|
||||
|
||||
The bridge uses @whiskeysockets/baileys to handle the WhatsApp Web protocol.
|
||||
Communication between Python and Node.js is via WebSocket.
|
||||
"""
|
||||
|
||||
name = "whatsapp"
|
||||
|
||||
def __init__(self, config: WhatsAppConfig, bus: MessageBus):
|
||||
super().__init__(config, bus)
|
||||
self.config: WhatsAppConfig = config
|
||||
self._ws = None
|
||||
self._connected = False
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the WhatsApp channel by connecting to the bridge."""
|
||||
import websockets
|
||||
|
||||
bridge_url = self.config.bridge_url
|
||||
|
||||
logger.info("Connecting to WhatsApp bridge at {}...", bridge_url)
|
||||
|
||||
self._running = True
|
||||
|
||||
while self._running:
|
||||
try:
|
||||
async with websockets.connect(bridge_url) as ws:
|
||||
self._ws = ws
|
||||
# Send auth token if configured
|
||||
if self.config.bridge_token:
|
||||
await ws.send(json.dumps({"type": "auth", "token": self.config.bridge_token}))
|
||||
self._connected = True
|
||||
logger.info("Connected to WhatsApp bridge")
|
||||
|
||||
# Listen for messages
|
||||
async for message in ws:
|
||||
try:
|
||||
await self._handle_bridge_message(message)
|
||||
except Exception as e:
|
||||
logger.error("Error handling bridge message: {}", e)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
self._connected = False
|
||||
self._ws = None
|
||||
logger.warning("WhatsApp bridge connection error: {}", e)
|
||||
|
||||
if self._running:
|
||||
logger.info("Reconnecting in 5 seconds...")
|
||||
await asyncio.sleep(5)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the WhatsApp channel."""
|
||||
self._running = False
|
||||
self._connected = False
|
||||
|
||||
if self._ws:
|
||||
await self._ws.close()
|
||||
self._ws = None
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
"""Send a message through WhatsApp."""
|
||||
if not self._ws or not self._connected:
|
||||
logger.warning("WhatsApp bridge not connected")
|
||||
return
|
||||
|
||||
try:
|
||||
payload = {
|
||||
"type": "send",
|
||||
"to": msg.chat_id,
|
||||
"text": msg.content
|
||||
}
|
||||
await self._ws.send(json.dumps(payload, ensure_ascii=False))
|
||||
except Exception as e:
|
||||
logger.error("Error sending WhatsApp message: {}", e)
|
||||
|
||||
async def _handle_bridge_message(self, raw: str) -> None:
|
||||
"""Handle a message from the bridge."""
|
||||
try:
|
||||
data = json.loads(raw)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("Invalid JSON from bridge: {}", raw[:100])
|
||||
return
|
||||
|
||||
msg_type = data.get("type")
|
||||
|
||||
if msg_type == "message":
|
||||
# Incoming message from WhatsApp
|
||||
# Deprecated by whatsapp: old phone number style typically: <phone>@s.whatspp.net
|
||||
pn = data.get("pn", "")
|
||||
# New LID sytle typically:
|
||||
sender = data.get("sender", "")
|
||||
content = data.get("content", "")
|
||||
|
||||
# Extract just the phone number or lid as chat_id
|
||||
user_id = pn if pn else sender
|
||||
sender_id = user_id.split("@")[0] if "@" in user_id else user_id
|
||||
logger.info("Sender {}", sender)
|
||||
|
||||
# Handle voice transcription if it's a voice message
|
||||
if content == "[Voice Message]":
|
||||
logger.info("Voice message received from {}, but direct download from bridge is not yet supported.", sender_id)
|
||||
content = "[Voice Message: Transcription not available for WhatsApp yet]"
|
||||
|
||||
await self._handle_message(
|
||||
sender_id=sender_id,
|
||||
chat_id=sender, # Use full LID for replies
|
||||
content=content,
|
||||
metadata={
|
||||
"message_id": data.get("id"),
|
||||
"timestamp": data.get("timestamp"),
|
||||
"is_group": data.get("isGroup", False)
|
||||
}
|
||||
)
|
||||
|
||||
elif msg_type == "status":
|
||||
# Connection status update
|
||||
status = data.get("status")
|
||||
logger.info("WhatsApp status: {}", status)
|
||||
|
||||
if status == "connected":
|
||||
self._connected = True
|
||||
elif status == "disconnected":
|
||||
self._connected = False
|
||||
|
||||
elif msg_type == "qr":
|
||||
# QR code for authentication
|
||||
logger.info("Scan QR code in the bridge terminal to connect WhatsApp")
|
||||
|
||||
elif msg_type == "error":
|
||||
logger.error("WhatsApp bridge error: {}", data.get('error'))
|
||||
Reference in New Issue
Block a user