fix: prompt
This commit is contained in:
136
custom_agent.py
136
custom_agent.py
@ -6,6 +6,7 @@ from pathlib import Path
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from memory import MemoryRecallClient
|
||||
from tts import BlackboxTTS
|
||||
|
||||
from asr import BlackboxSTT
|
||||
from livekit.agents import (
|
||||
@ -27,9 +28,9 @@ from livekit.agents import (
|
||||
room_io,
|
||||
stt,
|
||||
)
|
||||
from livekit.agents.voice.generation import update_instructions as update_chat_instructions
|
||||
from livekit.plugins import openai, silero
|
||||
from livekit.plugins.turn_detector.multilingual import MultilingualModel
|
||||
from tts import BlackboxTTS
|
||||
|
||||
logger = logging.getLogger("custom-agent")
|
||||
|
||||
@ -47,9 +48,19 @@ ROOM_LOCATOR_INSTRUCTIONS = """
|
||||
如果用户的问题与房间物品定位无关,则正常回答用户问题。
|
||||
""".strip()
|
||||
|
||||
# System prompt used when the query is not about locating things in the room.
# NOTE: this is runtime text sent to the LLM; keep it free of stray markup.
GENERAL_INSTRUCTIONS = """
你是一个智能语音助手。
正常回答用户问题。
回答自然、简洁、准确。
""".strip()

# Mode identifiers produced by _select_mode() and compared against when
# choosing which instruction set to install for a turn.
ROOM_LOCATOR_MODE = "room_locator"
GENERAL_MODE = "general"
|
||||
|
||||
|
||||
class CustomAgent(Agent):
|
||||
def __init__(self, *, memory_client: MemoryRecallClient | None = None) -> None:
|
||||
super().__init__(instructions=ROOM_LOCATOR_INSTRUCTIONS)
|
||||
super().__init__(instructions=GENERAL_INSTRUCTIONS)
|
||||
self._memory_client = memory_client
|
||||
|
||||
async def on_enter(self) -> None:
|
||||
@ -63,9 +74,24 @@ class CustomAgent(Agent):
|
||||
model_settings: ModelSettings,
|
||||
) -> AsyncIterable[llm.ChatChunk | str | FlushSentinel]:
|
||||
llm_node_started_at = time.perf_counter()
|
||||
memory_context = await self._recall_room_memory(chat_ctx)
|
||||
if memory_context:
|
||||
chat_ctx = _with_memory_as_latest_user_message(chat_ctx, memory_context)
|
||||
|
||||
user_query = _latest_user_text(chat_ctx)
|
||||
mode = _select_mode(user_query)
|
||||
logger.info("Selected agent mode: %s", mode)
|
||||
|
||||
chat_ctx = chat_ctx.copy()
|
||||
update_chat_instructions(
|
||||
chat_ctx,
|
||||
instructions=ROOM_LOCATOR_INSTRUCTIONS
|
||||
if mode == ROOM_LOCATOR_MODE
|
||||
else GENERAL_INSTRUCTIONS,
|
||||
add_if_missing=True,
|
||||
)
|
||||
|
||||
if mode == ROOM_LOCATOR_MODE:
|
||||
memory_context = await self._recall_room_memory(chat_ctx)
|
||||
if memory_context:
|
||||
chat_ctx = _with_memory_as_latest_user_message(chat_ctx, memory_context)
|
||||
|
||||
llm_result = Agent.default.llm_node(self, chat_ctx, tools, model_settings)
|
||||
if not hasattr(llm_result, "__aiter__"):
|
||||
@ -124,6 +150,104 @@ class CustomAgent(Agent):
|
||||
return ""
|
||||
|
||||
|
||||
def _select_mode(user_query: str) -> str:
    """Choose which instruction set should handle *user_query*.

    Args:
        user_query: Raw text of the latest user message; may be empty.

    Returns:
        ``ROOM_LOCATOR_MODE`` when the normalized query looks like a question
        about where something is in the room, otherwise ``GENERAL_MODE``.
    """
    normalized = _normalize_text(user_query)
    # An empty or whitespace-only query carries no locator intent, so it
    # falls through to the general mode without consulting the classifier.
    if normalized and _is_room_locator_query(normalized):
        return ROOM_LOCATOR_MODE
    return GENERAL_MODE
|
||||
|
||||
|
||||
def _is_room_locator_query(normalized_text: str) -> bool:
    """Heuristically decide whether a query asks where something is in a room.

    All matching is plain substring containment, so single-character hints
    such as ``"桌"`` also match longer words that contain them.

    Args:
        normalized_text: Query text that has already been lower-cased and
            stripped of whitespace. The ASCII software hints are lower-case,
            so un-normalized input would not match them reliably.

    Returns:
        ``False`` if any software/programming hint is present (these veto
        everything else). Otherwise ``True`` when a spatial hint appears
        together with a room/furniture/object hint, or when a spatial hint
        appears in a short query (at most 12 characters). ``False`` in all
        remaining cases.
    """
    room_context_hints = (
        "房间",
        "屋里",
        "屋子",
        "室内",
        "客厅",
        "卧室",
        "书房",
        "厨房",
        "餐厅",
        "沙发",
        "桌",
        "椅",
        "床",
        "门",
        "窗",
        "柜",
        "电视",
        "空调",
        "书架",
        "灯",
        "冰箱",
        "茶几",
        "电脑",
        "包",
        "瓶",
        "相机",
        "植物",
    )
    spatial_hints = (
        "在哪里",
        "在哪",
        "位置",
        "方位",
        "旁边",
        "左边",
        "右边",
        "前面",
        "后面",
        "上面",
        "下面",
        "附近",
        "对面",
        "靠近",
        "挨着",
        "隔着",
    )
    software_hints = (
        "python",
        "代码",
        "函数",
        "class",
        "bug",
        "日志",
        "logging",
        "api",
        "server",
        "agent",
        "prompt",
        "模型",
        "数据库",
        "git",
        "uv",
        "ruff",
        "mypy",
    )
    # Queries this short that contain a spatial hint are treated as locator
    # questions even when no known room object is named (e.g. "钥匙在哪").
    short_query_max_len = 12

    # Software topics take precedence: "where is the bug" is not about a room.
    if any(hint in normalized_text for hint in software_hints):
        return False

    # Without a spatial word there is nothing to locate.
    if not any(hint in normalized_text for hint in spatial_hints):
        return False

    if any(hint in normalized_text for hint in room_context_hints):
        return True

    return len(normalized_text) <= short_query_max_len
|
||||
|
||||
|
||||
def _normalize_text(text: str) -> str:
    """Return *text* lower-cased with all whitespace removed.

    ``str.split()`` with no separator splits on runs of any whitespace, so
    joining the pieces back together drops leading, trailing and interior
    whitespace alike.

    Args:
        text: Arbitrary user text; may be empty.

    Returns:
        The compacted, lower-cased string (``""`` for empty or
        whitespace-only input).
    """
    return "".join(text.split()).lower()
|
||||
|
||||
|
||||
def _latest_user_text(chat_ctx: ChatContext) -> str:
|
||||
for item in reversed(chat_ctx.items):
|
||||
if isinstance(item, ChatMessage) and item.role == "user":
|
||||
@ -175,7 +299,7 @@ async def entrypoint(ctx: JobContext) -> None:
|
||||
raise RuntimeError(f"CUSTOM_LLM_API_KEY is not set in {CUSTOM_ENV_PATH}")
|
||||
|
||||
TTS_URL = os.getenv("CUSTOM_TTS_URL") or os.getenv(
|
||||
"VOXCPM_TTS_URL", "http://localhost:5050/tts-blackbox"
|
||||
"VOXCPM_TTS_URL", "http://localhost:5000/tts-blackbox"
|
||||
)
|
||||
TTS_MODEL = os.getenv("CUSTOM_TTS_MODEL") or os.getenv("VOXCPM_TTS_MODEL", "voxcpmtts")
|
||||
TTS_SAMPLE_RATE = _env_int("CUSTOM_TTS_SAMPLE_RATE", 16000)
|
||||
|
||||
Reference in New Issue
Block a user