feat: add emotion prompt

This commit is contained in:
0Xiao0
2026-06-01 09:46:04 +08:00
parent e097323176
commit 7efd9eba98

View File

@ -2,6 +2,7 @@ import base64
import json import json
import logging import logging
import os import os
import re
import time import time
from collections.abc import AsyncIterable from collections.abc import AsyncIterable
from dataclasses import dataclass from dataclasses import dataclass
@ -57,6 +58,12 @@ GENERAL_INSTRUCTIONS = """
回答自然、简洁、准确。 回答自然、简洁、准确。
""".strip() """.strip()
# Prompt fragment (Chinese) appended to the mode instructions by
# _with_emotion_instructions(). In English it says:
#   - Every reply must begin with one emotion tag, formatted strictly as
#     <emotion=neutral>.
#   - The emotion must be one of neutral, happy, sad, angry, surprised,
#     fearful, calm, concerned.
#   - The normal user-facing reply follows the tag directly, and the tag must
#     not be explained.
# The text is sent to the model verbatim, so it must not be edited casually.
EMOTION_INSTRUCTIONS = """
每次回复必须先输出一个情绪标签,格式严格为:<emotion=neutral>
emotion 只能从 neutral、happy、sad、angry、surprised、fearful、calm、concerned 中选择。
情绪标签之后直接输出给用户的正常回复,不要解释标签。
""".strip()
ROOM_LOCATOR_MODE = "room_locator" ROOM_LOCATOR_MODE = "room_locator"
GENERAL_MODE = "general" GENERAL_MODE = "general"
VOICE_INPUT_MODE = "voice" VOICE_INPUT_MODE = "voice"
@ -64,6 +71,25 @@ VISION_VOICE_INPUT_MODE = "vision_voice"
AUTO_INPUT_MODE = "auto" AUTO_INPUT_MODE = "auto"
VISION_FRAME_TOPIC = "vision.frame" VISION_FRAME_TOPIC = "vision.frame"
# Emotion the agent starts with, and the label substituted when the model
# returns one that is not in EMOTION_LABELS.
DEFAULT_EMOTION = "neutral"
# Labels accepted from the model; the same list is spelled out in
# EMOTION_INSTRUCTIONS, so the two must be kept in sync.
EMOTION_LABELS = {
"neutral",
"happy",
"sad",
"angry",
"surprised",
"fearful",
"calm",
"concerned",
}
# Strict form of the tag at the very start of an LLM reply: optional
# whitespace, "<emotion=label>", optional whitespace. Group 1 is the label
# (letters and underscores only); matching is case-insensitive.
EMOTION_PREFIX_RE = re.compile(r"^\s*<emotion=([a-z_]+)>\s*", re.IGNORECASE)
# Lenient form used when cleaning text for speech: tolerates whitespace around
# "emotion" and "=", accepts any 1-80 characters other than ">" as the value,
# and also consumes whitespace after the closing ">".
TTS_EMOTION_MARKUP_RE = re.compile(r"<\s*emotion\s*=\s*[^>]{1,80}>\s*", re.IGNORECASE)
# Plain-text variant at the start of the text: the keyword "emotion" or "情绪",
# a ":" or "=", a label of 1-40 word/CJK/hyphen characters, then any trailing
# punctuation or whitespace from the final character class.
TTS_EMOTION_LINE_RE = re.compile(
r"^\s*(?:emotion|情绪)\s*[:=]\s*[\w\u4e00-\u9fff-]{1,40}\s*[,,。.!\s-]*",
re.IGNORECASE,
)
# Upper bound on how many characters are held back while deciding whether a
# reply starts with an emotion marker; beyond this the text is released as-is.
MAX_EMOTION_PREFIX_CHARS = 80
@dataclass @dataclass
class VisionFrame: class VisionFrame:
@ -111,13 +137,16 @@ class CustomAgent(Agent):
vision_llm: llm.LLM | None = None, vision_llm: llm.LLM | None = None,
model_image_save_dir: Path | None = None, model_image_save_dir: Path | None = None,
) -> None: ) -> None:
super().__init__(instructions=GENERAL_INSTRUCTIONS) super().__init__(instructions=_with_emotion_instructions(GENERAL_INSTRUCTIONS))
self._memory_client = memory_client self._memory_client = memory_client
self._vision_store = vision_store self._vision_store = vision_store
self._input_mode = input_mode self._input_mode = input_mode
self._text_llm = text_llm self._text_llm = text_llm
self._vision_llm = vision_llm self._vision_llm = vision_llm
self._model_image_save_dir = model_image_save_dir self._model_image_save_dir = model_image_save_dir
self.current_emotion = DEFAULT_EMOTION
self._emotion_prefix_buffer = ""
self._emotion_prefix_done = True
async def on_enter(self) -> None: async def on_enter(self) -> None:
# self.session.generate_reply(instructions="greet the user and introduce yourself") # self.session.generate_reply(instructions="greet the user and introduce yourself")
@ -144,9 +173,9 @@ class CustomAgent(Agent):
chat_ctx = chat_ctx.copy() chat_ctx = chat_ctx.copy()
update_chat_instructions( update_chat_instructions(
chat_ctx, chat_ctx,
instructions=ROOM_LOCATOR_INSTRUCTIONS instructions=_with_emotion_instructions(
if mode == ROOM_LOCATOR_MODE ROOM_LOCATOR_INSTRUCTIONS if mode == ROOM_LOCATOR_MODE else GENERAL_INSTRUCTIONS
else GENERAL_INSTRUCTIONS, ),
add_if_missing=True, add_if_missing=True,
) )
@ -173,6 +202,8 @@ class CustomAgent(Agent):
async def _instrumented_stream() -> AsyncIterable[llm.ChatChunk | str | FlushSentinel]: async def _instrumented_stream() -> AsyncIterable[llm.ChatChunk | str | FlushSentinel]:
first_chunk_at: float | None = None first_chunk_at: float | None = None
chunk_count = 0 chunk_count = 0
self._emotion_prefix_buffer = ""
self._emotion_prefix_done = False
try: try:
async for chunk in llm_result: async for chunk in llm_result:
chunk_count += 1 chunk_count += 1
@ -182,7 +213,8 @@ class CustomAgent(Agent):
"LLM first chunk after %.3fs", "LLM first chunk after %.3fs",
first_chunk_at - llm_node_started_at, first_chunk_at - llm_node_started_at,
) )
yield chunk async for output_chunk in self._observe_emotion_prefix(chunk):
yield output_chunk
finally: finally:
finished_at = time.perf_counter() finished_at = time.perf_counter()
logger.info( logger.info(
@ -194,6 +226,9 @@ class CustomAgent(Agent):
return _instrumented_stream() return _instrumented_stream()
def tts_node(self, text: AsyncIterable[str], model_settings: ModelSettings):
    """Run the default TTS node on text that has had emotion markers removed.

    Args:
        text: Stream of text chunks destined for speech synthesis.
        model_settings: Settings forwarded unchanged to the default node.

    Returns:
        Whatever ``Agent.default.tts_node`` returns for the filtered stream.
    """
    # The emotion tag is control metadata for the agent and must not be read
    # aloud, so it is filtered out before the text reaches the synthesizer.
    speakable_text = _strip_emotion_for_tts(text)
    return Agent.default.tts_node(self, speakable_text, model_settings)
def _consume_vision_frame(self) -> VisionFrame | None: def _consume_vision_frame(self) -> VisionFrame | None:
if self._input_mode == VOICE_INPUT_MODE or self._vision_store is None: if self._input_mode == VOICE_INPUT_MODE or self._vision_store is None:
return None return None
@ -255,6 +290,51 @@ class CustomAgent(Agent):
yield chunk yield chunk
return _stream() return _stream()
async def _observe_emotion_prefix(
    self, chunk: llm.ChatChunk | str | FlushSentinel
) -> AsyncIterable[llm.ChatChunk | str | FlushSentinel]:
    """Feed a chunk's text to the emotion-prefix parser and pass the chunk on.

    The chunk is always yielded exactly once and unmodified; this method only
    observes the text. Plain strings are inspected as-is, ``llm.ChatChunk``
    objects are inspected only when they carry non-empty delta content, and
    any other chunk is forwarded without inspection.

    Args:
        chunk: One item from the LLM output stream.

    Yields:
        The same ``chunk`` object that was passed in.
    """
    if isinstance(chunk, str):
        observed_text = chunk
    elif isinstance(chunk, llm.ChatChunk) and chunk.delta and chunk.delta.content:
        observed_text = chunk.delta.content
    else:
        observed_text = None

    if observed_text is not None:
        self._consume_emotion_prefix(observed_text)
    yield chunk
def _consume_emotion_prefix(self, content: str) -> None:
    """Accumulate streamed reply text until the leading emotion tag is resolved.

    Text is appended to ``self._emotion_prefix_buffer`` until one of three
    things happens:

    * the buffer matches ``EMOTION_PREFIX_RE``: the label (or
      ``DEFAULT_EMOTION`` if it is not in ``EMOTION_LABELS``) is stored in
      ``self.current_emotion`` and parsing ends;
    * the buffer can no longer turn into a tag, or has grown past
      ``MAX_EMOTION_PREFIX_CHARS``: parsing ends with a warning;
    * neither: the method returns and waits for more text.

    Once parsing has ended, later calls return immediately.

    Args:
        content: The next piece of reply text.
    """
    if self._emotion_prefix_done:
        return

    self._emotion_prefix_buffer += content
    buffered = self._emotion_prefix_buffer
    found = EMOTION_PREFIX_RE.match(buffered)

    if found is None:
        head = buffered.lstrip().lower()
        # Still undecided while the text is empty, is a proper prefix of
        # "<emotion=", or is an opened tag that has not been closed yet.
        tag_incomplete = (
            not head
            or "<emotion=".startswith(head)
            or (head.startswith("<emotion=") and ">" not in head)
        )
        if tag_incomplete and len(head) <= MAX_EMOTION_PREFIX_CHARS:
            return
        # NOTE(review): current_emotion is left at its previous value on this
        # path - confirm whether it should fall back to DEFAULT_EMOTION.
        self._emotion_prefix_done = True
        self._emotion_prefix_buffer = ""
        logger.warning("LLM response did not start with an emotion prefix")
        return

    label = found.group(1).lower()
    if label not in EMOTION_LABELS:
        logger.warning("LLM returned unsupported emotion=%s, using neutral", label)
        label = DEFAULT_EMOTION
    self.current_emotion = label
    self._emotion_prefix_done = True
    self._emotion_prefix_buffer = ""
    logger.info("LLM emotion selected: %s", label)
async def _recall_room_memory(self, chat_ctx: ChatContext) -> str: async def _recall_room_memory(self, chat_ctx: ChatContext) -> str:
if self._memory_client is None: if self._memory_client is None:
@ -294,6 +374,69 @@ def _select_mode(user_query: str) -> str:
return GENERAL_MODE return GENERAL_MODE
def _with_emotion_instructions(instructions: str) -> str:
    """Return *instructions* followed by a blank line and EMOTION_INSTRUCTIONS."""
    return "{}\n\n{}".format(instructions, EMOTION_INSTRUCTIONS)
async def _strip_emotion_for_tts(text: AsyncIterable[str]) -> AsyncIterable[str]:
    """Yield the chunks of *text* with emotion markers removed.

    The start of the stream is buffered until ``_strip_leading_tts_emotion``
    reports that it has decided whether a leading marker is present. After
    that, chunks are passed through one at a time with any complete
    ``<emotion=...>`` tag removed. Empty input chunks are skipped and empty
    results are never yielded.

    Known limitation: once the leading marker has been resolved, chunks are
    cleaned independently, so a tag split across two later chunks is not
    removed.

    Args:
        text: Stream of text chunks headed for speech synthesis.

    Yields:
        Non-empty text chunks with emotion markers stripped.
    """
    prefix_buffer = ""
    scanning_prefix = True

    async for chunk in text:
        if not chunk:
            continue

        if scanning_prefix:
            prefix_buffer += chunk
            cleaned, done = _strip_leading_tts_emotion(prefix_buffer)
            if not done:
                # Not enough text yet to tell a marker from ordinary speech.
                continue
            scanning_prefix = False
            prefix_buffer = ""
        else:
            cleaned = chunk

        # Guard every yield: what remains after the leading marker may itself
        # consist only of tags, and an empty chunk must not be emitted.
        cleaned = _strip_inline_tts_emotion(cleaned)
        if cleaned:
            yield cleaned

    if scanning_prefix and prefix_buffer:
        # The stream ended while still undecided; force a decision so the
        # buffered text is either stripped or spoken rather than dropped.
        cleaned, _ = _strip_leading_tts_emotion(prefix_buffer, force=True)
        cleaned = _strip_inline_tts_emotion(cleaned)
        if cleaned:
            yield cleaned
def _strip_leading_tts_emotion(text: str, *, force: bool = False) -> tuple[str, bool]:
    """Remove a leading emotion marker from the text buffered so far.

    Two marker shapes are recognised: a tag matched by
    ``TTS_EMOTION_MARKUP_RE`` and a plain-text label matched by
    ``TTS_EMOTION_LINE_RE``.

    Args:
        text: Everything received from the stream so far.
        force: When true, decide now instead of asking for more text. Used
            once the stream has ended.

    Returns:
        ``(remaining_text, done)``. ``done`` is ``False`` (with an empty
        string) when more text is needed before deciding. Otherwise
        ``remaining_text`` is *text* with the leading marker removed, or
        *text* unchanged when no marker is present.
    """
    match = TTS_EMOTION_MARKUP_RE.match(text)
    if match:
        # A tag is terminated by ">", so a match here is always complete.
        return text[match.end() :], True

    candidate = text.lstrip().lower()
    may_wait = not force and len(candidate) <= MAX_EMOTION_PREFIX_CHARS

    match = TTS_EMOTION_LINE_RE.match(text)
    if match:
        # A plain-text label has no terminator. If the match runs to the very
        # end of the buffer, the label may continue in the next chunk
        # ("emotion: ha" + "ppy"), so wait rather than leak its tail to TTS.
        if may_wait and match.end() == len(text):
            return "", False
        return text[match.end() :], True

    if may_wait and _might_become_tts_emotion(candidate):
        return "", False
    return text, True


def _might_become_tts_emotion(candidate: str) -> bool:
    """Tell whether a left-stripped, lower-cased buffer could still grow into a marker."""
    if not candidate or "<emotion=".startswith(candidate):
        return True
    if candidate.startswith("<emotion") and ">" not in candidate:
        return True
    for keyword in ("emotion", "情绪"):
        if keyword.startswith(candidate):
            # Still typing the keyword itself.
            return True
        if candidate.startswith(keyword):
            # Keyword complete: only keep waiting while what follows is
            # nothing yet or a ":" / "=" separator. Ordinary prose such as
            # "emotional ..." is released immediately.
            rest = candidate[len(keyword) :].lstrip()
            if not rest or rest[0] in ":=":
                return True
    return False
def _strip_inline_tts_emotion(text: str) -> str:
    """Return *text* with every match of TTS_EMOTION_MARKUP_RE deleted."""
    return re.sub(TTS_EMOTION_MARKUP_RE, "", text)
def _is_room_locator_query(normalized_text: str) -> bool: def _is_room_locator_query(normalized_text: str) -> bool:
room_context_hints = ( room_context_hints = (
"房间", "房间",
@ -500,6 +643,7 @@ async def entrypoint(ctx: JobContext) -> None:
INPUT_MODE = _normalize_input_mode(os.getenv("CUSTOM_AGENT_INPUT_MODE")) INPUT_MODE = _normalize_input_mode(os.getenv("CUSTOM_AGENT_INPUT_MODE"))
if not LLM_API_KEY: if not LLM_API_KEY:
raise RuntimeError(f"CUSTOM_LLM_API_KEY is not set in {CUSTOM_ENV_PATH}") raise RuntimeError(f"CUSTOM_LLM_API_KEY is not set in {CUSTOM_ENV_PATH}")
logger.info("Using LLM model=%s base_url=%s", LLM_MODEL, LLM_BASE_URL or "OpenAI default")
TTS_URL = os.getenv("CUSTOM_TTS_URL") or os.getenv( TTS_URL = os.getenv("CUSTOM_TTS_URL") or os.getenv(
"VOXCPM_TTS_URL", "http://localhost:5000/tts-blackbox" "VOXCPM_TTS_URL", "http://localhost:5000/tts-blackbox"