fix: prompt
This commit is contained in:
130
custom_agent.py
130
custom_agent.py
@ -6,6 +6,7 @@ from pathlib import Path
|
|||||||
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from memory import MemoryRecallClient
|
from memory import MemoryRecallClient
|
||||||
|
from tts import BlackboxTTS
|
||||||
|
|
||||||
from asr import BlackboxSTT
|
from asr import BlackboxSTT
|
||||||
from livekit.agents import (
|
from livekit.agents import (
|
||||||
@ -27,9 +28,9 @@ from livekit.agents import (
|
|||||||
room_io,
|
room_io,
|
||||||
stt,
|
stt,
|
||||||
)
|
)
|
||||||
|
from livekit.agents.voice.generation import update_instructions as update_chat_instructions
|
||||||
from livekit.plugins import openai, silero
|
from livekit.plugins import openai, silero
|
||||||
from livekit.plugins.turn_detector.multilingual import MultilingualModel
|
from livekit.plugins.turn_detector.multilingual import MultilingualModel
|
||||||
from tts import BlackboxTTS
|
|
||||||
|
|
||||||
logger = logging.getLogger("custom-agent")
|
logger = logging.getLogger("custom-agent")
|
||||||
|
|
||||||
@ -47,9 +48,19 @@ ROOM_LOCATOR_INSTRUCTIONS = """
|
|||||||
如果用户的问题与房间物品定位无关,则正常回答用户问题。
|
如果用户的问题与房间物品定位无关,则正常回答用户问题。
|
||||||
""".strip()
|
""".strip()
|
||||||
|
|
||||||
|
GENERAL_INSTRUCTIONS = """
|
||||||
|
你是一个智能语音助手。
|
||||||
|
正常回答用户问题。
|
||||||
|
回答自然、简洁、准确。
|
||||||
|
""".strip()
|
||||||
|
|
||||||
|
ROOM_LOCATOR_MODE = "room_locator"
|
||||||
|
GENERAL_MODE = "general"
|
||||||
|
|
||||||
|
|
||||||
class CustomAgent(Agent):
|
class CustomAgent(Agent):
|
||||||
def __init__(self, *, memory_client: MemoryRecallClient | None = None) -> None:
|
def __init__(self, *, memory_client: MemoryRecallClient | None = None) -> None:
|
||||||
super().__init__(instructions=ROOM_LOCATOR_INSTRUCTIONS)
|
super().__init__(instructions=GENERAL_INSTRUCTIONS)
|
||||||
self._memory_client = memory_client
|
self._memory_client = memory_client
|
||||||
|
|
||||||
async def on_enter(self) -> None:
|
async def on_enter(self) -> None:
|
||||||
@ -63,6 +74,21 @@ class CustomAgent(Agent):
|
|||||||
model_settings: ModelSettings,
|
model_settings: ModelSettings,
|
||||||
) -> AsyncIterable[llm.ChatChunk | str | FlushSentinel]:
|
) -> AsyncIterable[llm.ChatChunk | str | FlushSentinel]:
|
||||||
llm_node_started_at = time.perf_counter()
|
llm_node_started_at = time.perf_counter()
|
||||||
|
|
||||||
|
user_query = _latest_user_text(chat_ctx)
|
||||||
|
mode = _select_mode(user_query)
|
||||||
|
logger.info("Selected agent mode: %s", mode)
|
||||||
|
|
||||||
|
chat_ctx = chat_ctx.copy()
|
||||||
|
update_chat_instructions(
|
||||||
|
chat_ctx,
|
||||||
|
instructions=ROOM_LOCATOR_INSTRUCTIONS
|
||||||
|
if mode == ROOM_LOCATOR_MODE
|
||||||
|
else GENERAL_INSTRUCTIONS,
|
||||||
|
add_if_missing=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if mode == ROOM_LOCATOR_MODE:
|
||||||
memory_context = await self._recall_room_memory(chat_ctx)
|
memory_context = await self._recall_room_memory(chat_ctx)
|
||||||
if memory_context:
|
if memory_context:
|
||||||
chat_ctx = _with_memory_as_latest_user_message(chat_ctx, memory_context)
|
chat_ctx = _with_memory_as_latest_user_message(chat_ctx, memory_context)
|
||||||
@ -124,6 +150,104 @@ class CustomAgent(Agent):
|
|||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
def _select_mode(user_query: str) -> str:
|
||||||
|
normalized = _normalize_text(user_query)
|
||||||
|
if not normalized:
|
||||||
|
return GENERAL_MODE
|
||||||
|
|
||||||
|
if _is_room_locator_query(normalized):
|
||||||
|
return ROOM_LOCATOR_MODE
|
||||||
|
|
||||||
|
return GENERAL_MODE
|
||||||
|
|
||||||
|
|
||||||
|
def _is_room_locator_query(normalized_text: str) -> bool:
|
||||||
|
room_context_hints = (
|
||||||
|
"房间",
|
||||||
|
"屋里",
|
||||||
|
"屋子",
|
||||||
|
"室内",
|
||||||
|
"客厅",
|
||||||
|
"卧室",
|
||||||
|
"书房",
|
||||||
|
"厨房",
|
||||||
|
"餐厅",
|
||||||
|
"沙发",
|
||||||
|
"桌",
|
||||||
|
"椅",
|
||||||
|
"床",
|
||||||
|
"门",
|
||||||
|
"窗",
|
||||||
|
"柜",
|
||||||
|
"电视",
|
||||||
|
"空调",
|
||||||
|
"书架",
|
||||||
|
"灯",
|
||||||
|
"冰箱",
|
||||||
|
"茶几",
|
||||||
|
"电脑",
|
||||||
|
"包",
|
||||||
|
"瓶",
|
||||||
|
"相机",
|
||||||
|
"植物",
|
||||||
|
)
|
||||||
|
spatial_hints = (
|
||||||
|
"在哪里",
|
||||||
|
"在哪",
|
||||||
|
"位置",
|
||||||
|
"方位",
|
||||||
|
"旁边",
|
||||||
|
"左边",
|
||||||
|
"右边",
|
||||||
|
"前面",
|
||||||
|
"后面",
|
||||||
|
"上面",
|
||||||
|
"下面",
|
||||||
|
"附近",
|
||||||
|
"对面",
|
||||||
|
"靠近",
|
||||||
|
"挨着",
|
||||||
|
"隔着",
|
||||||
|
)
|
||||||
|
software_hints = (
|
||||||
|
"python",
|
||||||
|
"代码",
|
||||||
|
"函数",
|
||||||
|
"class",
|
||||||
|
"bug",
|
||||||
|
"日志",
|
||||||
|
"logging",
|
||||||
|
"api",
|
||||||
|
"server",
|
||||||
|
"agent",
|
||||||
|
"prompt",
|
||||||
|
"模型",
|
||||||
|
"数据库",
|
||||||
|
"git",
|
||||||
|
"uv",
|
||||||
|
"ruff",
|
||||||
|
"mypy",
|
||||||
|
)
|
||||||
|
|
||||||
|
if any(hint in normalized_text for hint in software_hints):
|
||||||
|
return False
|
||||||
|
|
||||||
|
has_spatial_hint = any(hint in normalized_text for hint in spatial_hints)
|
||||||
|
has_room_context_hint = any(hint in normalized_text for hint in room_context_hints)
|
||||||
|
|
||||||
|
if has_spatial_hint and has_room_context_hint:
|
||||||
|
return True
|
||||||
|
|
||||||
|
if has_spatial_hint and len(normalized_text) <= 12:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_text(text: str) -> str:
|
||||||
|
return "".join(text.split()).lower()
|
||||||
|
|
||||||
|
|
||||||
def _latest_user_text(chat_ctx: ChatContext) -> str:
|
def _latest_user_text(chat_ctx: ChatContext) -> str:
|
||||||
for item in reversed(chat_ctx.items):
|
for item in reversed(chat_ctx.items):
|
||||||
if isinstance(item, ChatMessage) and item.role == "user":
|
if isinstance(item, ChatMessage) and item.role == "user":
|
||||||
@ -175,7 +299,7 @@ async def entrypoint(ctx: JobContext) -> None:
|
|||||||
raise RuntimeError(f"CUSTOM_LLM_API_KEY is not set in {CUSTOM_ENV_PATH}")
|
raise RuntimeError(f"CUSTOM_LLM_API_KEY is not set in {CUSTOM_ENV_PATH}")
|
||||||
|
|
||||||
TTS_URL = os.getenv("CUSTOM_TTS_URL") or os.getenv(
|
TTS_URL = os.getenv("CUSTOM_TTS_URL") or os.getenv(
|
||||||
"VOXCPM_TTS_URL", "http://localhost:5050/tts-blackbox"
|
"VOXCPM_TTS_URL", "http://localhost:5000/tts-blackbox"
|
||||||
)
|
)
|
||||||
TTS_MODEL = os.getenv("CUSTOM_TTS_MODEL") or os.getenv("VOXCPM_TTS_MODEL", "voxcpmtts")
|
TTS_MODEL = os.getenv("CUSTOM_TTS_MODEL") or os.getenv("VOXCPM_TTS_MODEL", "voxcpmtts")
|
||||||
TTS_SAMPLE_RATE = _env_int("CUSTOM_TTS_SAMPLE_RATE", 16000)
|
TTS_SAMPLE_RATE = _env_int("CUSTOM_TTS_SAMPLE_RATE", 16000)
|
||||||
|
|||||||
Reference in New Issue
Block a user