diff --git a/custom_agent.py b/custom_agent.py index 3786074..16a5728 100644 --- a/custom_agent.py +++ b/custom_agent.py @@ -6,6 +6,7 @@ from pathlib import Path from dotenv import load_dotenv from memory import MemoryRecallClient +from tts import BlackboxTTS from asr import BlackboxSTT from livekit.agents import ( @@ -27,9 +28,9 @@ from livekit.agents import ( room_io, stt, ) +from livekit.agents.voice.generation import update_instructions as update_chat_instructions from livekit.plugins import openai, silero from livekit.plugins.turn_detector.multilingual import MultilingualModel -from tts import BlackboxTTS logger = logging.getLogger("custom-agent") @@ -47,9 +48,19 @@ ROOM_LOCATOR_INSTRUCTIONS = """ 如果用户的问题与房间物品定位无关,则正常回答用户问题。 """.strip() +GENERAL_INSTRUCTIONS = """ +你是一个智能语音助手。 +正常回答用户问题。 +回答自然、简洁、准确。 +""".strip() + +ROOM_LOCATOR_MODE = "room_locator" +GENERAL_MODE = "general" + + class CustomAgent(Agent): def __init__(self, *, memory_client: MemoryRecallClient | None = None) -> None: - super().__init__(instructions=ROOM_LOCATOR_INSTRUCTIONS) + super().__init__(instructions=GENERAL_INSTRUCTIONS) self._memory_client = memory_client async def on_enter(self) -> None: @@ -63,9 +74,24 @@ class CustomAgent(Agent): model_settings: ModelSettings, ) -> AsyncIterable[llm.ChatChunk | str | FlushSentinel]: llm_node_started_at = time.perf_counter() - memory_context = await self._recall_room_memory(chat_ctx) - if memory_context: - chat_ctx = _with_memory_as_latest_user_message(chat_ctx, memory_context) + + user_query = _latest_user_text(chat_ctx) + mode = _select_mode(user_query) + logger.info("Selected agent mode: %s", mode) + + chat_ctx = chat_ctx.copy() + update_chat_instructions( + chat_ctx, + instructions=ROOM_LOCATOR_INSTRUCTIONS + if mode == ROOM_LOCATOR_MODE + else GENERAL_INSTRUCTIONS, + add_if_missing=True, + ) + + if mode == ROOM_LOCATOR_MODE: + memory_context = await self._recall_room_memory(chat_ctx) + if memory_context: + chat_ctx = _with_memory_as_latest_user_message(chat_ctx, memory_context) llm_result = Agent.default.llm_node(self, chat_ctx, tools, model_settings) if not hasattr(llm_result, "__aiter__"): @@ -124,6 +150,104 @@ class CustomAgent(Agent): return "" +def _select_mode(user_query: str) -> str: + normalized = _normalize_text(user_query) + if not normalized: + return GENERAL_MODE + + if _is_room_locator_query(normalized): + return ROOM_LOCATOR_MODE + + return GENERAL_MODE + + +def _is_room_locator_query(normalized_text: str) -> bool: + room_context_hints = ( + "房间", + "屋里", + "屋子", + "室内", + "客厅", + "卧室", + "书房", + "厨房", + "餐厅", + "沙发", + "桌", + "椅", + "床", + "门", + "窗", + "柜", + "电视", + "空调", + "书架", + "灯", + "冰箱", + "茶几", + "电脑", + "包", + "瓶", + "相机", + "植物", + ) + spatial_hints = ( + "在哪里", + "在哪", + "位置", + "方位", + "旁边", + "左边", + "右边", + "前面", + "后面", + "上面", + "下面", + "附近", + "对面", + "靠近", + "挨着", + "隔着", + ) + software_hints = ( + "python", + "代码", + "函数", + "class", + "bug", + "日志", + "logging", + "api", + "server", + "agent", + "prompt", + "模型", + "数据库", + "git", + "uv", + "ruff", + "mypy", + ) + + if any(hint in normalized_text for hint in software_hints): + return False + + has_spatial_hint = any(hint in normalized_text for hint in spatial_hints) + has_room_context_hint = any(hint in normalized_text for hint in room_context_hints) + + if has_spatial_hint and has_room_context_hint: + return True + + if has_spatial_hint and len(normalized_text) <= 12: + return True + + return False + + +def _normalize_text(text: str) -> str: + return "".join(text.split()).lower() + + def _latest_user_text(chat_ctx: ChatContext) -> str: for item in reversed(chat_ctx.items): if isinstance(item, ChatMessage) and item.role == "user": @@ -175,7 +299,7 @@ async def entrypoint(ctx: JobContext) -> None: raise RuntimeError(f"CUSTOM_LLM_API_KEY is not set in {CUSTOM_ENV_PATH}") TTS_URL = os.getenv("CUSTOM_TTS_URL") or os.getenv( - "VOXCPM_TTS_URL", "http://localhost:5050/tts-blackbox" + "VOXCPM_TTS_URL", "http://localhost:5000/tts-blackbox" ) TTS_MODEL = os.getenv("CUSTOM_TTS_MODEL") or os.getenv("VOXCPM_TTS_MODEL", "voxcpmtts") TTS_SAMPLE_RATE = _env_int("CUSTOM_TTS_SAMPLE_RATE", 16000)