fix: prompt

This commit is contained in:
0Xiao0
2026-05-22 14:46:10 +08:00
parent fba51a5257
commit f272053a95

View File

@ -6,6 +6,7 @@ from pathlib import Path
from dotenv import load_dotenv
from memory import MemoryRecallClient
from tts import BlackboxTTS
from asr import BlackboxSTT
from livekit.agents import (
@ -27,9 +28,9 @@ from livekit.agents import (
room_io,
stt,
)
from livekit.agents.voice.generation import update_instructions as update_chat_instructions
from livekit.plugins import openai, silero
from livekit.plugins.turn_detector.multilingual import MultilingualModel
from tts import BlackboxTTS
logger = logging.getLogger("custom-agent")
@ -47,9 +48,19 @@ ROOM_LOCATOR_INSTRUCTIONS = """
如果用户的问题与房间物品定位无关,则正常回答用户问题。
""".strip()
GENERAL_INSTRUCTIONS = """
你是一个智能语音助手。
正常回答用户问题。
回答自然、简洁、准确。
""".strip()
ROOM_LOCATOR_MODE = "room_locator"
GENERAL_MODE = "general"
class CustomAgent(Agent):
def __init__(self, *, memory_client: MemoryRecallClient | None = None) -> None:
super().__init__(instructions=ROOM_LOCATOR_INSTRUCTIONS)
super().__init__(instructions=GENERAL_INSTRUCTIONS)
self._memory_client = memory_client
async def on_enter(self) -> None:
@ -63,9 +74,24 @@ class CustomAgent(Agent):
model_settings: ModelSettings,
) -> AsyncIterable[llm.ChatChunk | str | FlushSentinel]:
llm_node_started_at = time.perf_counter()
memory_context = await self._recall_room_memory(chat_ctx)
if memory_context:
chat_ctx = _with_memory_as_latest_user_message(chat_ctx, memory_context)
user_query = _latest_user_text(chat_ctx)
mode = _select_mode(user_query)
logger.info("Selected agent mode: %s", mode)
chat_ctx = chat_ctx.copy()
update_chat_instructions(
chat_ctx,
instructions=ROOM_LOCATOR_INSTRUCTIONS
if mode == ROOM_LOCATOR_MODE
else GENERAL_INSTRUCTIONS,
add_if_missing=True,
)
if mode == ROOM_LOCATOR_MODE:
memory_context = await self._recall_room_memory(chat_ctx)
if memory_context:
chat_ctx = _with_memory_as_latest_user_message(chat_ctx, memory_context)
llm_result = Agent.default.llm_node(self, chat_ctx, tools, model_settings)
if not hasattr(llm_result, "__aiter__"):
@ -124,6 +150,104 @@ class CustomAgent(Agent):
return ""
def _select_mode(user_query: str) -> str:
normalized = _normalize_text(user_query)
if not normalized:
return GENERAL_MODE
if _is_room_locator_query(normalized):
return ROOM_LOCATOR_MODE
return GENERAL_MODE
def _is_room_locator_query(normalized_text: str) -> bool:
room_context_hints = (
"房间",
"屋里",
"屋子",
"室内",
"客厅",
"卧室",
"书房",
"厨房",
"餐厅",
"沙发",
"",
"",
"",
"",
"",
"",
"电视",
"空调",
"书架",
"",
"冰箱",
"茶几",
"电脑",
"",
"",
"相机",
"植物",
)
spatial_hints = (
"在哪里",
"在哪",
"位置",
"方位",
"旁边",
"左边",
"右边",
"前面",
"后面",
"上面",
"下面",
"附近",
"对面",
"靠近",
"挨着",
"隔着",
)
software_hints = (
"python",
"代码",
"函数",
"class",
"bug",
"日志",
"logging",
"api",
"server",
"agent",
"prompt",
"模型",
"数据库",
"git",
"uv",
"ruff",
"mypy",
)
if any(hint in normalized_text for hint in software_hints):
return False
has_spatial_hint = any(hint in normalized_text for hint in spatial_hints)
has_room_context_hint = any(hint in normalized_text for hint in room_context_hints)
if has_spatial_hint and has_room_context_hint:
return True
if has_spatial_hint and len(normalized_text) <= 12:
return True
return False
def _normalize_text(text: str) -> str:
return "".join(text.split()).lower()
def _latest_user_text(chat_ctx: ChatContext) -> str:
for item in reversed(chat_ctx.items):
if isinstance(item, ChatMessage) and item.role == "user":
@ -175,7 +299,7 @@ async def entrypoint(ctx: JobContext) -> None:
raise RuntimeError(f"CUSTOM_LLM_API_KEY is not set in {CUSTOM_ENV_PATH}")
TTS_URL = os.getenv("CUSTOM_TTS_URL") or os.getenv(
"VOXCPM_TTS_URL", "http://localhost:5050/tts-blackbox"
"VOXCPM_TTS_URL", "http://localhost:5000/tts-blackbox"
)
TTS_MODEL = os.getenv("CUSTOM_TTS_MODEL") or os.getenv("VOXCPM_TTS_MODEL", "voxcpmtts")
TTS_SAMPLE_RATE = _env_int("CUSTOM_TTS_SAMPLE_RATE", 16000)