feat: add emotion prompt
This commit is contained in:
154
custom_agent.py
154
custom_agent.py
@ -2,6 +2,7 @@ import base64
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from collections.abc import AsyncIterable
|
||||
from dataclasses import dataclass
|
||||
@ -57,6 +58,12 @@ GENERAL_INSTRUCTIONS = """
|
||||
回答自然、简洁、准确。
|
||||
""".strip()
|
||||
|
||||
EMOTION_INSTRUCTIONS = """
|
||||
每次回复必须先输出一个情绪标签,格式严格为:<emotion=neutral>
|
||||
emotion 只能从 neutral、happy、sad、angry、surprised、fearful、calm、concerned 中选择。
|
||||
情绪标签之后直接输出给用户的正常回复,不要解释标签。
|
||||
""".strip()
|
||||
|
||||
ROOM_LOCATOR_MODE = "room_locator"
|
||||
GENERAL_MODE = "general"
|
||||
VOICE_INPUT_MODE = "voice"
|
||||
@ -64,6 +71,25 @@ VISION_VOICE_INPUT_MODE = "vision_voice"
|
||||
AUTO_INPUT_MODE = "auto"
|
||||
VISION_FRAME_TOPIC = "vision.frame"
|
||||
|
||||
DEFAULT_EMOTION = "neutral"
|
||||
EMOTION_LABELS = {
|
||||
"neutral",
|
||||
"happy",
|
||||
"sad",
|
||||
"angry",
|
||||
"surprised",
|
||||
"fearful",
|
||||
"calm",
|
||||
"concerned",
|
||||
}
|
||||
EMOTION_PREFIX_RE = re.compile(r"^\s*<emotion=([a-z_]+)>\s*", re.IGNORECASE)
|
||||
TTS_EMOTION_MARKUP_RE = re.compile(r"<\s*emotion\s*=\s*[^>]{1,80}>\s*", re.IGNORECASE)
|
||||
TTS_EMOTION_LINE_RE = re.compile(
|
||||
r"^\s*(?:emotion|情绪)\s*[::=]\s*[\w\u4e00-\u9fff-]{1,40}\s*[,,。.!!\s-]*",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
MAX_EMOTION_PREFIX_CHARS = 80
|
||||
|
||||
|
||||
@dataclass
|
||||
class VisionFrame:
|
||||
@ -111,13 +137,16 @@ class CustomAgent(Agent):
|
||||
vision_llm: llm.LLM | None = None,
|
||||
model_image_save_dir: Path | None = None,
|
||||
) -> None:
|
||||
super().__init__(instructions=GENERAL_INSTRUCTIONS)
|
||||
super().__init__(instructions=_with_emotion_instructions(GENERAL_INSTRUCTIONS))
|
||||
self._memory_client = memory_client
|
||||
self._vision_store = vision_store
|
||||
self._input_mode = input_mode
|
||||
self._text_llm = text_llm
|
||||
self._vision_llm = vision_llm
|
||||
self._model_image_save_dir = model_image_save_dir
|
||||
self.current_emotion = DEFAULT_EMOTION
|
||||
self._emotion_prefix_buffer = ""
|
||||
self._emotion_prefix_done = True
|
||||
|
||||
async def on_enter(self) -> None:
|
||||
# self.session.generate_reply(instructions="greet the user and introduce yourself")
|
||||
@ -144,9 +173,9 @@ class CustomAgent(Agent):
|
||||
chat_ctx = chat_ctx.copy()
|
||||
update_chat_instructions(
|
||||
chat_ctx,
|
||||
instructions=ROOM_LOCATOR_INSTRUCTIONS
|
||||
if mode == ROOM_LOCATOR_MODE
|
||||
else GENERAL_INSTRUCTIONS,
|
||||
instructions=_with_emotion_instructions(
|
||||
ROOM_LOCATOR_INSTRUCTIONS if mode == ROOM_LOCATOR_MODE else GENERAL_INSTRUCTIONS
|
||||
),
|
||||
add_if_missing=True,
|
||||
)
|
||||
|
||||
@ -173,6 +202,8 @@ class CustomAgent(Agent):
|
||||
async def _instrumented_stream() -> AsyncIterable[llm.ChatChunk | str | FlushSentinel]:
|
||||
first_chunk_at: float | None = None
|
||||
chunk_count = 0
|
||||
self._emotion_prefix_buffer = ""
|
||||
self._emotion_prefix_done = False
|
||||
try:
|
||||
async for chunk in llm_result:
|
||||
chunk_count += 1
|
||||
@ -182,7 +213,8 @@ class CustomAgent(Agent):
|
||||
"LLM first chunk after %.3fs",
|
||||
first_chunk_at - llm_node_started_at,
|
||||
)
|
||||
yield chunk
|
||||
async for output_chunk in self._observe_emotion_prefix(chunk):
|
||||
yield output_chunk
|
||||
finally:
|
||||
finished_at = time.perf_counter()
|
||||
logger.info(
|
||||
@ -194,6 +226,9 @@ class CustomAgent(Agent):
|
||||
|
||||
return _instrumented_stream()
|
||||
|
||||
def tts_node(self, text: AsyncIterable[str], model_settings: ModelSettings):
|
||||
return Agent.default.tts_node(self, _strip_emotion_for_tts(text), model_settings)
|
||||
|
||||
def _consume_vision_frame(self) -> VisionFrame | None:
|
||||
if self._input_mode == VOICE_INPUT_MODE or self._vision_store is None:
|
||||
return None
|
||||
@ -255,6 +290,51 @@ class CustomAgent(Agent):
|
||||
yield chunk
|
||||
|
||||
return _stream()
|
||||
async def _observe_emotion_prefix(
|
||||
self, chunk: llm.ChatChunk | str | FlushSentinel
|
||||
) -> AsyncIterable[llm.ChatChunk | str | FlushSentinel]:
|
||||
if isinstance(chunk, str):
|
||||
self._consume_emotion_prefix(chunk)
|
||||
yield chunk
|
||||
return
|
||||
|
||||
if isinstance(chunk, llm.ChatChunk) and chunk.delta and chunk.delta.content:
|
||||
self._consume_emotion_prefix(chunk.delta.content)
|
||||
yield chunk
|
||||
return
|
||||
|
||||
yield chunk
|
||||
|
||||
def _consume_emotion_prefix(self, content: str) -> None:
|
||||
if self._emotion_prefix_done:
|
||||
return
|
||||
|
||||
self._emotion_prefix_buffer += content
|
||||
match = EMOTION_PREFIX_RE.match(self._emotion_prefix_buffer)
|
||||
if match:
|
||||
emotion = match.group(1).lower()
|
||||
if emotion not in EMOTION_LABELS:
|
||||
logger.warning("LLM returned unsupported emotion=%s, using neutral", emotion)
|
||||
emotion = DEFAULT_EMOTION
|
||||
|
||||
self.current_emotion = emotion
|
||||
self._emotion_prefix_done = True
|
||||
self._emotion_prefix_buffer = ""
|
||||
logger.info("LLM emotion selected: %s", emotion)
|
||||
return
|
||||
|
||||
candidate = self._emotion_prefix_buffer.lstrip().lower()
|
||||
might_still_be_prefix = (
|
||||
not candidate
|
||||
or "<emotion=".startswith(candidate)
|
||||
or (candidate.startswith("<emotion=") and ">" not in candidate)
|
||||
)
|
||||
if might_still_be_prefix and len(candidate) <= MAX_EMOTION_PREFIX_CHARS:
|
||||
return
|
||||
|
||||
self._emotion_prefix_done = True
|
||||
self._emotion_prefix_buffer = ""
|
||||
logger.warning("LLM response did not start with an emotion prefix")
|
||||
|
||||
async def _recall_room_memory(self, chat_ctx: ChatContext) -> str:
|
||||
if self._memory_client is None:
|
||||
@ -294,6 +374,69 @@ def _select_mode(user_query: str) -> str:
|
||||
return GENERAL_MODE
|
||||
|
||||
|
||||
def _with_emotion_instructions(instructions: str) -> str:
|
||||
return f"{instructions}\n\n{EMOTION_INSTRUCTIONS}"
|
||||
|
||||
|
||||
async def _strip_emotion_for_tts(text: AsyncIterable[str]) -> AsyncIterable[str]:
|
||||
prefix_buffer = ""
|
||||
scanning_prefix = True
|
||||
|
||||
async for chunk in text:
|
||||
if not chunk:
|
||||
continue
|
||||
|
||||
if scanning_prefix:
|
||||
prefix_buffer += chunk
|
||||
cleaned, done = _strip_leading_tts_emotion(prefix_buffer)
|
||||
if not done:
|
||||
continue
|
||||
|
||||
scanning_prefix = False
|
||||
prefix_buffer = ""
|
||||
if cleaned:
|
||||
yield _strip_inline_tts_emotion(cleaned)
|
||||
continue
|
||||
|
||||
cleaned = _strip_inline_tts_emotion(chunk)
|
||||
if cleaned:
|
||||
yield cleaned
|
||||
|
||||
if scanning_prefix and prefix_buffer:
|
||||
cleaned, _ = _strip_leading_tts_emotion(prefix_buffer, force=True)
|
||||
cleaned = _strip_inline_tts_emotion(cleaned)
|
||||
if cleaned:
|
||||
yield cleaned
|
||||
|
||||
|
||||
def _strip_leading_tts_emotion(text: str, *, force: bool = False) -> tuple[str, bool]:
|
||||
match = TTS_EMOTION_MARKUP_RE.match(text)
|
||||
if match:
|
||||
return text[match.end() :], True
|
||||
|
||||
match = TTS_EMOTION_LINE_RE.match(text)
|
||||
if match:
|
||||
return text[match.end() :], True
|
||||
|
||||
candidate = text.lstrip().lower()
|
||||
might_still_be_emotion = (
|
||||
not candidate
|
||||
or "<emotion=".startswith(candidate)
|
||||
or (candidate.startswith("<emotion") and ">" not in candidate)
|
||||
or "emotion".startswith(candidate)
|
||||
or (candidate.startswith("emotion") and len(candidate) <= MAX_EMOTION_PREFIX_CHARS)
|
||||
or "情绪".startswith(candidate)
|
||||
)
|
||||
if not force and might_still_be_emotion and len(candidate) <= MAX_EMOTION_PREFIX_CHARS:
|
||||
return "", False
|
||||
|
||||
return text, True
|
||||
|
||||
|
||||
def _strip_inline_tts_emotion(text: str) -> str:
|
||||
return TTS_EMOTION_MARKUP_RE.sub("", text)
|
||||
|
||||
|
||||
def _is_room_locator_query(normalized_text: str) -> bool:
|
||||
room_context_hints = (
|
||||
"房间",
|
||||
@ -500,6 +643,7 @@ async def entrypoint(ctx: JobContext) -> None:
|
||||
INPUT_MODE = _normalize_input_mode(os.getenv("CUSTOM_AGENT_INPUT_MODE"))
|
||||
if not LLM_API_KEY:
|
||||
raise RuntimeError(f"CUSTOM_LLM_API_KEY is not set in {CUSTOM_ENV_PATH}")
|
||||
logger.info("Using LLM model=%s base_url=%s", LLM_MODEL, LLM_BASE_URL or "OpenAI default")
|
||||
|
||||
TTS_URL = os.getenv("CUSTOM_TTS_URL") or os.getenv(
|
||||
"VOXCPM_TTS_URL", "http://localhost:5000/tts-blackbox"
|
||||
|
||||
Reference in New Issue
Block a user