diff --git a/custom_agent.py b/custom_agent.py index 8f274fd..bb98a1e 100644 --- a/custom_agent.py +++ b/custom_agent.py @@ -2,6 +2,7 @@ import base64 import json import logging import os +import re import time from collections.abc import AsyncIterable from dataclasses import dataclass @@ -57,6 +58,12 @@ GENERAL_INSTRUCTIONS = """ 回答自然、简洁、准确。 """.strip() +EMOTION_INSTRUCTIONS = """ +每次回复必须先输出一个情绪标签,格式严格为:<emotion=标签> +emotion 只能从 neutral、happy、sad、angry、surprised、fearful、calm、concerned 中选择。 +情绪标签之后直接输出给用户的正常回复,不要解释标签。 +""".strip() + ROOM_LOCATOR_MODE = "room_locator" GENERAL_MODE = "general" VOICE_INPUT_MODE = "voice" @@ -64,6 +71,25 @@ VISION_VOICE_INPUT_MODE = "vision_voice" AUTO_INPUT_MODE = "auto" VISION_FRAME_TOPIC = "vision.frame" +DEFAULT_EMOTION = "neutral" +EMOTION_LABELS = { + "neutral", + "happy", + "sad", + "angry", + "surprised", + "fearful", + "calm", + "concerned", +} +EMOTION_PREFIX_RE = re.compile(r"^\s*<emotion=([a-z_]+)>\s*", re.IGNORECASE) +TTS_EMOTION_MARKUP_RE = re.compile(r"<\s*emotion\s*=\s*[^>]{1,80}>\s*", re.IGNORECASE) +TTS_EMOTION_LINE_RE = re.compile( + r"^\s*(?:emotion|情绪)\s*[::=]\s*[\w\u4e00-\u9fff-]{1,40}\s*[,,。.!!\s-]*", + re.IGNORECASE, +) +MAX_EMOTION_PREFIX_CHARS = 80 + @dataclass class VisionFrame: @@ -111,13 +137,16 @@ class CustomAgent(Agent): vision_llm: llm.LLM | None = None, model_image_save_dir: Path | None = None, ) -> None: - super().__init__(instructions=GENERAL_INSTRUCTIONS) + super().__init__(instructions=_with_emotion_instructions(GENERAL_INSTRUCTIONS)) self._memory_client = memory_client self._vision_store = vision_store self._input_mode = input_mode self._text_llm = text_llm self._vision_llm = vision_llm self._model_image_save_dir = model_image_save_dir + self.current_emotion = DEFAULT_EMOTION + self._emotion_prefix_buffer = "" + self._emotion_prefix_done = True async def on_enter(self) -> None: # self.session.generate_reply(instructions="greet the user and introduce yourself") @@ -144,9 +173,9 @@ class CustomAgent(Agent): chat_ctx = 
chat_ctx.copy() update_chat_instructions( chat_ctx, - instructions=ROOM_LOCATOR_INSTRUCTIONS - if mode == ROOM_LOCATOR_MODE - else GENERAL_INSTRUCTIONS, + instructions=_with_emotion_instructions( + ROOM_LOCATOR_INSTRUCTIONS if mode == ROOM_LOCATOR_MODE else GENERAL_INSTRUCTIONS + ), add_if_missing=True, ) @@ -173,6 +202,8 @@ class CustomAgent(Agent): async def _instrumented_stream() -> AsyncIterable[llm.ChatChunk | str | FlushSentinel]: first_chunk_at: float | None = None chunk_count = 0 + self._emotion_prefix_buffer = "" + self._emotion_prefix_done = False try: async for chunk in llm_result: chunk_count += 1 @@ -182,7 +213,8 @@ class CustomAgent(Agent): "LLM first chunk after %.3fs", first_chunk_at - llm_node_started_at, ) - yield chunk + async for output_chunk in self._observe_emotion_prefix(chunk): + yield output_chunk finally: finished_at = time.perf_counter() logger.info( @@ -194,6 +226,9 @@ class CustomAgent(Agent): return _instrumented_stream() + def tts_node(self, text: AsyncIterable[str], model_settings: ModelSettings): + return Agent.default.tts_node(self, _strip_emotion_for_tts(text), model_settings) + def _consume_vision_frame(self) -> VisionFrame | None: if self._input_mode == VOICE_INPUT_MODE or self._vision_store is None: return None @@ -255,6 +290,51 @@ class CustomAgent(Agent): yield chunk return _stream() + async def _observe_emotion_prefix( + self, chunk: llm.ChatChunk | str | FlushSentinel + ) -> AsyncIterable[llm.ChatChunk | str | FlushSentinel]: + if isinstance(chunk, str): + self._consume_emotion_prefix(chunk) + yield chunk + return + + if isinstance(chunk, llm.ChatChunk) and chunk.delta and chunk.delta.content: + self._consume_emotion_prefix(chunk.delta.content) + yield chunk + return + + yield chunk + + def _consume_emotion_prefix(self, content: str) -> None: + if self._emotion_prefix_done: + return + + self._emotion_prefix_buffer += content + match = EMOTION_PREFIX_RE.match(self._emotion_prefix_buffer) + if match: + emotion = 
match.group(1).lower() + if emotion not in EMOTION_LABELS: + logger.warning("LLM returned unsupported emotion=%s, using neutral", emotion) + emotion = DEFAULT_EMOTION + + self.current_emotion = emotion + self._emotion_prefix_done = True + self._emotion_prefix_buffer = "" + logger.info("LLM emotion selected: %s", emotion) + return + + candidate = self._emotion_prefix_buffer.lstrip().lower() + might_still_be_prefix = ( + not candidate + or "<emotion".startswith(candidate) + or (candidate.startswith("<emotion") and ">" not in candidate) + ) + if might_still_be_prefix and len(candidate) <= MAX_EMOTION_PREFIX_CHARS: + return + + self._emotion_prefix_done = True + self._emotion_prefix_buffer = "" + logger.warning("LLM response did not start with an emotion prefix") async def _recall_room_memory(self, chat_ctx: ChatContext) -> str: if self._memory_client is None: @@ -294,6 +374,69 @@ def _select_mode(user_query: str) -> str: return GENERAL_MODE +def _with_emotion_instructions(instructions: str) -> str: + return f"{instructions}\n\n{EMOTION_INSTRUCTIONS}" + + +async def _strip_emotion_for_tts(text: AsyncIterable[str]) -> AsyncIterable[str]: + prefix_buffer = "" + scanning_prefix = True + + async for chunk in text: + if not chunk: + continue + + if scanning_prefix: + prefix_buffer += chunk + cleaned, done = _strip_leading_tts_emotion(prefix_buffer) + if not done: + continue + + scanning_prefix = False + prefix_buffer = "" + if cleaned: + yield _strip_inline_tts_emotion(cleaned) + continue + + cleaned = _strip_inline_tts_emotion(chunk) + if cleaned: + yield cleaned + + if scanning_prefix and prefix_buffer: + cleaned, _ = _strip_leading_tts_emotion(prefix_buffer, force=True) + cleaned = _strip_inline_tts_emotion(cleaned) + if cleaned: + yield cleaned + + +def _strip_leading_tts_emotion(text: str, *, force: bool = False) -> tuple[str, bool]: + match = TTS_EMOTION_MARKUP_RE.match(text) + if match: + return text[match.end() :], True + + match = TTS_EMOTION_LINE_RE.match(text) + if match: + return text[match.end() :], True + + candidate = 
text.lstrip().lower() + might_still_be_emotion = ( + not candidate + or "<emotion".startswith(candidate) + or (candidate.startswith("<emotion") and ">" not in candidate) + or "emotion".startswith(candidate) + or (candidate.startswith("emotion") and len(candidate) <= MAX_EMOTION_PREFIX_CHARS) + or "情绪".startswith(candidate) + ) + if not force and might_still_be_emotion and len(candidate) <= MAX_EMOTION_PREFIX_CHARS: + return "", False + + return text, True + + +def _strip_inline_tts_emotion(text: str) -> str: + return TTS_EMOTION_MARKUP_RE.sub("", text) + + def _is_room_locator_query(normalized_text: str) -> bool: room_context_hints = ( "房间", @@ -500,6 +643,7 @@ async def entrypoint(ctx: JobContext) -> None: INPUT_MODE = _normalize_input_mode(os.getenv("CUSTOM_AGENT_INPUT_MODE")) if not LLM_API_KEY: raise RuntimeError(f"CUSTOM_LLM_API_KEY is not set in {CUSTOM_ENV_PATH}") + logger.info("Using LLM model=%s base_url=%s", LLM_MODEL, LLM_BASE_URL or "OpenAI default") TTS_URL = os.getenv("CUSTOM_TTS_URL") or os.getenv( "VOXCPM_TTS_URL", "http://localhost:5000/tts-blackbox"