Compare commits

...

7 Commits

Author SHA1 Message Date
2064db15dc add env 2026-05-22 15:36:38 +08:00
f272053a95 fix: prompt 2026-05-22 14:46:10 +08:00
fba51a5257 perf: improve speed 2026-05-15 10:44:31 +08:00
b18c5b40da fix: tts parameters 2026-05-14 15:33:20 +08:00
89011fed81 fix: memory recall fuction prompt 2026-05-14 11:18:04 +08:00
3a2f5c4252 feat: memory recall fuction 2026-05-14 10:16:08 +08:00
746053fd58 fix 2026-05-13 15:35:04 +08:00
4 changed files with 716 additions and 56 deletions

70
.env Normal file
View File

@ -0,0 +1,70 @@
# LiveKit connection
LIVEKIT_URL=ws://localhost:7880
LIVEKIT_API_KEY=
LIVEKIT_API_SECRET=
CUSTOM_AGENT_NAME=my-agent
# ASR blackbox
CUSTOM_ASR_URL=http://localhost:5000/asr-blackbox
CUSTOM_ASR_MODEL=qwen
CUSTOM_ASR_LANGUAGE=Chinese
CUSTOM_ASR_OUTPUT_LANGUAGE=zh
CUSTOM_ASR_HOTWORDS=
CUSTOM_ASR_ITN=
CUSTOM_ASR_CHUNK_MODE=
# OpenAI-compatible LLM
# CUSTOM_LLM_BASE_URL=https://oai.bwgdi.com/v1
# CUSTOM_LLM_MODEL=Qwen3.6-35B
# CUSTOM_LLM_API_KEY=
# CUSTOM_LLM_VERIFY_SSL=false
CUSTOM_LLM_BASE_URL=http://localhost/v1
CUSTOM_LLM_MODEL=Qwen-VL
CUSTOM_LLM_API_KEY=
CUSTOM_LLM_VERIFY_SSL=false
# CUSTOM_LLM_BASE_URL=https://api.deepseek.com
# CUSTOM_LLM_MODEL=deepseek-v4-flash
# CUSTOM_LLM_API_KEY=
# CUSTOM_LLM_VERIFY_SSL=false
# TTS blackbox
CUSTOM_TTS_URL=http://localhost:5000/tts-blackbox
CUSTOM_TTS_MODEL=voxcpmtts
# CUSTOM_TTS_PROMPT_WAV=/home/verachen/Workspace/livekit/agents/2food.wav
CUSTOM_TTS_STREAMING=true
# CUSTOM_TTS_PROMPT_TEXT=澳门有乜嘢好食嘅
# VoxCPM TTS parameters
VOXCPM_CFG_VALUE=2.0
VOXCPM_INFERENCE_TIMESTEPS=10
VOXCPM_DO_NORMALIZE=true
VOXCPM_DENOISE=true
VOXCPM_RETRY_BADCASE=true
VOXCPM_RETRY_BADCASE_MAX_TIMES=3
VOXCPM_RETRY_BADCASE_RATIO_THRESHOLD=6.0
# MeloTTS parameters
CUSTOM_TTS_SPEED=1.0
# CosyVoice parameters
CUSTOM_TTS_SPK_ID=
CUSTOM_TTS_MODE=
CUSTOM_TTS_INSTRUCT_TEXT=
# GPT-SoVITS parameters
CUSTOM_TTS_TEXT_LANG=zh
CUSTOM_TTS_PROMPT_LANG=zh
CUSTOM_TTS_TEXT_SPLIT_METHOD=cut0
CUSTOM_TTS_BATCH_SIZE=1
CUSTOM_TTS_MEDIA_TYPE=wav
CUSTOM_TTS_REF_AUDIO_PATH=
CUSTOM_MEMORY_URL=http://localhost:8766/api/room_graph
CUSTOM_MEMORY_TIMEOUT=2
CUSTOM_MEMORY_MAX_CHARS=2000
CUSTOM_MEMORY_API_KEY=
CUSTOM_PREEMPTIVE_GENERATION=true

View File

@ -1,28 +1,36 @@
import logging
import os
import time
from collections.abc import AsyncIterable
from pathlib import Path
from typing import Optional
from dotenv import load_dotenv
from memory import MemoryRecallClient
from tts import BlackboxTTS
from asr import BlackboxSTT
from livekit.agents import (
Agent,
AgentServer,
AgentSession,
ChatContext,
ChatMessage,
FlushSentinel,
JobContext,
JobProcess,
MetricsCollectedEvent,
ModelSettings,
RecordingOptions,
TurnHandlingOptions,
cli,
llm,
metrics,
room_io,
stt,
)
from livekit.agents.voice.generation import update_instructions as update_chat_instructions
from livekit.plugins import openai, silero
from livekit.plugins.turn_detector.multilingual import MultilingualModel
from tts import BlackboxTTS
logger = logging.getLogger("custom-agent")
@ -30,19 +38,237 @@ CUSTOM_ENV_PATH = Path(__file__).with_name(".env")
load_dotenv(dotenv_path=CUSTOM_ENV_PATH)
AGENT_NAME = os.getenv("CUSTOM_AGENT_NAME", "")
ROOM_LOCATOR_INSTRUCTIONS = """
你是一个房间物品定位助手。
当用户询问房间内某个物品的位置时:
- 只用一句中文回答
- 描述目标物品和其他物品的相对位置关系
- 不要使用 Markdown、emoji、列表、标题、坐标区域标签
- 不要解释推理过程
如果用户的问题与房间物品定位无关,则正常回答用户问题。
""".strip()
GENERAL_INSTRUCTIONS = """
你是一个智能语音助手。
正常回答用户问题。
回答自然、简洁、准确。
""".strip()
ROOM_LOCATOR_MODE = "room_locator"
GENERAL_MODE = "general"
class CustomAgent(Agent):
def __init__(self) -> None:
super().__init__(
instructions="Your name is Kelly, built by LiveKit. You are a helpful assistant."
"Keep your responses concise and friendly."
"You are interacting with the user via a local ASR and LLM pipeline.",
)
def __init__(self, *, memory_client: MemoryRecallClient | None = None) -> None:
super().__init__(instructions=GENERAL_INSTRUCTIONS)
self._memory_client = memory_client
async def on_enter(self) -> None:
# self.session.generate_reply(instructions="greet the user and introduce yourself")
pass
async def llm_node(
self,
chat_ctx: ChatContext,
tools: list[llm.Tool],
model_settings: ModelSettings,
) -> AsyncIterable[llm.ChatChunk | str | FlushSentinel]:
llm_node_started_at = time.perf_counter()
user_query = _latest_user_text(chat_ctx)
mode = _select_mode(user_query)
logger.info("Selected agent mode: %s", mode)
chat_ctx = chat_ctx.copy()
update_chat_instructions(
chat_ctx,
instructions=ROOM_LOCATOR_INSTRUCTIONS
if mode == ROOM_LOCATOR_MODE
else GENERAL_INSTRUCTIONS,
add_if_missing=True,
)
if mode == ROOM_LOCATOR_MODE:
memory_context = await self._recall_room_memory(chat_ctx)
if memory_context:
chat_ctx = _with_memory_as_latest_user_message(chat_ctx, memory_context)
llm_result = Agent.default.llm_node(self, chat_ctx, tools, model_settings)
if not hasattr(llm_result, "__aiter__"):
elapsed = time.perf_counter() - llm_node_started_at
logger.info("LLM node completed without streaming in %.3fs", elapsed)
return llm_result
async def _instrumented_stream() -> AsyncIterable[llm.ChatChunk | str | FlushSentinel]:
first_chunk_at: float | None = None
chunk_count = 0
try:
async for chunk in llm_result:
chunk_count += 1
if first_chunk_at is None:
first_chunk_at = time.perf_counter()
logger.info(
"LLM first chunk after %.3fs",
first_chunk_at - llm_node_started_at,
)
yield chunk
finally:
finished_at = time.perf_counter()
logger.info(
"LLM stream completed in %.3fs (first_chunk=%.3fs, chunks=%s)",
finished_at - llm_node_started_at,
(first_chunk_at - llm_node_started_at) if first_chunk_at else -1.0,
chunk_count,
)
return _instrumented_stream()
async def _recall_room_memory(self, chat_ctx: ChatContext) -> str:
if self._memory_client is None:
return ""
user_query = _latest_user_text(chat_ctx)
if not user_query:
return ""
started_at = time.perf_counter()
try:
recalled = await self._memory_client.recall(user_query)
elapsed = time.perf_counter() - started_at
logger.info(
"Memory recall completed in %.3fs (query_len=%s, memory_len=%s)",
elapsed,
len(user_query),
len(recalled),
)
return recalled
except Exception:
logger.exception(
"Unexpected memory recall failure after %.3fs",
time.perf_counter() - started_at,
)
return ""
def _select_mode(user_query: str) -> str:
normalized = _normalize_text(user_query)
if not normalized:
return GENERAL_MODE
if _is_room_locator_query(normalized):
return ROOM_LOCATOR_MODE
return GENERAL_MODE
def _is_room_locator_query(normalized_text: str) -> bool:
room_context_hints = (
"房间",
"屋里",
"屋子",
"室内",
"客厅",
"卧室",
"书房",
"厨房",
"餐厅",
"沙发",
"",
"",
"",
"",
"",
"",
"电视",
"空调",
"书架",
"",
"冰箱",
"茶几",
"电脑",
"",
"",
"相机",
"植物",
)
spatial_hints = (
"在哪里",
"在哪",
"位置",
"方位",
"旁边",
"左边",
"右边",
"前面",
"后面",
"上面",
"下面",
"附近",
"对面",
"靠近",
"挨着",
"隔着",
)
software_hints = (
"python",
"代码",
"函数",
"class",
"bug",
"日志",
"logging",
"api",
"server",
"agent",
"prompt",
"模型",
"数据库",
"git",
"uv",
"ruff",
"mypy",
)
if any(hint in normalized_text for hint in software_hints):
return False
has_spatial_hint = any(hint in normalized_text for hint in spatial_hints)
has_room_context_hint = any(hint in normalized_text for hint in room_context_hints)
if has_spatial_hint and has_room_context_hint:
return True
if has_spatial_hint and len(normalized_text) <= 12:
return True
return False
def _normalize_text(text: str) -> str:
return "".join(text.split()).lower()
def _latest_user_text(chat_ctx: ChatContext) -> str:
for item in reversed(chat_ctx.items):
if isinstance(item, ChatMessage) and item.role == "user":
return (item.text_content or "").strip()
return ""
def _with_memory_as_latest_user_message(chat_ctx: ChatContext, memory_context: str) -> ChatContext:
chat_ctx = chat_ctx.copy()
for index in range(len(chat_ctx.items) - 1, -1, -1):
item = chat_ctx.items[index]
if isinstance(item, ChatMessage) and item.role == "user":
user_msg = item.model_copy(deep=True)
user_msg.content = [memory_context]
chat_ctx.items[index] = user_msg
return chat_ctx
chat_ctx.items.append(ChatMessage(role="user", content=[memory_context]))
return chat_ctx
server = AgentServer()
@ -66,19 +292,23 @@ async def entrypoint(ctx: JobContext) -> None:
ASR_LANGUAGE = os.getenv("CUSTOM_ASR_LANGUAGE", "auto")
ASR_OUTPUT_LANGUAGE = os.getenv("CUSTOM_ASR_OUTPUT_LANGUAGE", "zh")
MINIMAX_BASE_URL = os.getenv("MINIMAX_LLM_BASE_URL", "https://oai.bwgdi.com/v1")
MINIMAX_MODEL = os.getenv("MINIMAX_LLM_MODEL", "qwen-max")
MINIMAX_API_KEY = os.getenv("MINIMAX_API_KEY")
if not MINIMAX_API_KEY:
raise RuntimeError(f"MINIMAX_API_KEY is not set in {CUSTOM_ENV_PATH}")
LLM_BASE_URL = os.getenv("CUSTOM_LLM_BASE_URL")
LLM_MODEL = os.getenv("CUSTOM_LLM_MODEL", "qwen-max")
LLM_API_KEY = os.getenv("CUSTOM_LLM_API_KEY")
if not LLM_API_KEY:
raise RuntimeError(f"CUSTOM_LLM_API_KEY is not set in {CUSTOM_ENV_PATH}")
TTS_URL = os.getenv("CUSTOM_TTS_URL") or os.getenv(
"VOXCPM_TTS_URL", "http://localhost:5050/tts-blackbox"
"VOXCPM_TTS_URL", "http://localhost:5000/tts-blackbox"
)
TTS_MODEL = os.getenv("CUSTOM_TTS_MODEL") or os.getenv("VOXCPM_TTS_MODEL", "voxcpmtts")
TTS_SAMPLE_RATE = _env_int("CUSTOM_TTS_SAMPLE_RATE", 16000)
TTS_NUM_CHANNELS = _env_int("CUSTOM_TTS_NUM_CHANNELS", 1)
OUTPUT_SAMPLE_RATE = _env_int("CUSTOM_OUTPUT_SAMPLE_RATE", TTS_SAMPLE_RATE)
MEMORY_URL = os.getenv("CUSTOM_MEMORY_URL", "").strip()
MEMORY_TIMEOUT = _env_float("CUSTOM_MEMORY_TIMEOUT", 2.0)
MEMORY_MAX_CHARS = _env_int("CUSTOM_MEMORY_MAX_CHARS", 2000)
MEMORY_API_KEY = os.getenv("CUSTOM_MEMORY_API_KEY") or None
blackbox_stt = BlackboxSTT(
url=ASR_URL,
@ -94,22 +324,27 @@ async def entrypoint(ctx: JobContext) -> None:
import httpx
from openai import AsyncClient as OpenAIAsyncClient
# Create a custom HTTP client that disables SSL verification
http_client = httpx.AsyncClient(verify=False)
# OpenAI-compatible endpoints can be used by setting CUSTOM_LLM_BASE_URL.
http_client = httpx.AsyncClient(verify=_env_bool("CUSTOM_LLM_VERIFY_SSL", False))
# Create the OpenAI AsyncClient with the custom HTTP client
openai_client = OpenAIAsyncClient(
api_key=MINIMAX_API_KEY,
base_url=MINIMAX_BASE_URL,
http_client=http_client,
)
if LLM_BASE_URL:
openai_client = OpenAIAsyncClient(
api_key=LLM_API_KEY,
base_url=LLM_BASE_URL,
http_client=http_client,
)
else:
openai_client = OpenAIAsyncClient(
api_key=LLM_API_KEY,
http_client=http_client,
)
session: AgentSession = AgentSession(
# 1. Custom ASR blackbox with StreamAdapter
stt=stt_stream,
# 2. Minimax LLM - Using OpenAI plugin with local base_url
# 2. OpenAI-compatible LLM, e.g. MiniMax, Qwen, or OpenAI.
llm=openai.LLM(
model=MINIMAX_MODEL,
model=LLM_MODEL,
client=openai_client,
),
# 3. TTS blackbox
@ -117,7 +352,7 @@ async def entrypoint(ctx: JobContext) -> None:
url=TTS_URL,
model_name=TTS_MODEL,
params=_tts_params_from_env(TTS_MODEL),
prompt_wav_path=os.getenv("CUSTOM_TTS_PROMPT_WAV") or os.getenv("VOXCPM_PROMPT_WAV"),
prompt_wav_path=_tts_prompt_wav_from_env(TTS_MODEL),
sample_rate=TTS_SAMPLE_RATE,
num_channels=TTS_NUM_CHANNELS,
),
@ -130,7 +365,7 @@ async def entrypoint(ctx: JobContext) -> None:
"false_interruption_timeout": 1.0,
},
),
preemptive_generation=False,
preemptive_generation=_env_bool("CUSTOM_PREEMPTIVE_GENERATION", True),
aec_warmup_duration=3.0,
tts_text_transforms=[
"filter_emoji",
@ -142,8 +377,30 @@ async def entrypoint(ctx: JobContext) -> None:
def _on_metrics_collected(ev: MetricsCollectedEvent) -> None:
metrics.log_metrics(ev.metrics)
@session.on("conversation_item_added")
def _on_conversation_item_added(event) -> None:
item = getattr(event, "item", None)
if not isinstance(item, ChatMessage):
return
if item.role == "user" and item.metrics:
logger.info("User turn metrics: %s", item.metrics)
elif item.role == "assistant" and item.metrics:
logger.info("Assistant turn metrics: %s", item.metrics)
memory_client = (
MemoryRecallClient(
url=MEMORY_URL,
timeout=MEMORY_TIMEOUT,
max_chars=MEMORY_MAX_CHARS,
api_key=MEMORY_API_KEY,
)
if MEMORY_URL
else None
)
await session.start(
agent=CustomAgent(),
agent=CustomAgent(memory_client=memory_client),
room=ctx.room,
room_options=room_io.RoomOptions(
audio_output=room_io.AudioOutputOptions(
@ -160,49 +417,55 @@ def _tts_params_from_env(model_name: str) -> dict[str, str]:
model_name = model_name.lower()
if model_name == "voxcpmtts":
params.update(
{
"streaming": os.getenv("CUSTOM_TTS_STREAMING", "false"),
"prompt_text": os.getenv(
"CUSTOM_TTS_PROMPT_TEXT",
os.getenv("VOXCPM_PROMPT_TEXT", "澳门有乜嘢好食嘅"),
),
"cfg_value": os.getenv("VOXCPM_CFG_VALUE", "2.0"),
"inference_timesteps": os.getenv("VOXCPM_INFERENCE_TIMESTEPS", "10"),
"do_normalize": os.getenv("VOXCPM_DO_NORMALIZE", "true"),
"denoise": os.getenv("VOXCPM_DENOISE", "true"),
"retry_badcase": os.getenv("VOXCPM_RETRY_BADCASE", "true"),
"retry_badcase_max_times": os.getenv("VOXCPM_RETRY_BADCASE_MAX_TIMES", "3"),
"retry_badcase_ratio_threshold": os.getenv(
"VOXCPM_RETRY_BADCASE_RATIO_THRESHOLD", "6.0"
),
}
_set_if_present(params, "streaming", os.getenv("CUSTOM_TTS_STREAMING"))
_set_if_present(
params,
"prompt_text",
os.getenv("CUSTOM_TTS_PROMPT_TEXT") or os.getenv("VOXCPM_PROMPT_TEXT"),
)
_set_if_present(params, "cfg_value", os.getenv("VOXCPM_CFG_VALUE"))
_set_if_present(params, "inference_timesteps", os.getenv("VOXCPM_INFERENCE_TIMESTEPS"))
_set_if_present(params, "do_normalize", os.getenv("VOXCPM_DO_NORMALIZE"))
_set_if_present(params, "denoise", os.getenv("VOXCPM_DENOISE"))
_set_if_present(params, "retry_badcase", os.getenv("VOXCPM_RETRY_BADCASE"))
_set_if_present(
params,
"retry_badcase_max_times",
os.getenv("VOXCPM_RETRY_BADCASE_MAX_TIMES"),
)
_set_if_present(
params,
"retry_badcase_ratio_threshold",
os.getenv("VOXCPM_RETRY_BADCASE_RATIO_THRESHOLD"),
)
elif model_name == "melotts":
params["speed"] = os.getenv("CUSTOM_TTS_SPEED", "1.0")
_set_if_present(params, "speed", os.getenv("CUSTOM_TTS_SPEED"))
elif model_name == "cosyvoicetts":
_set_if_present(params, "spk_id", os.getenv("CUSTOM_TTS_SPK_ID"))
_set_if_present(params, "model", os.getenv("CUSTOM_TTS_MODE"))
_set_if_present(params, "prompt_text", os.getenv("CUSTOM_TTS_PROMPT_TEXT"))
_set_if_present(params, "instruct_text", os.getenv("CUSTOM_TTS_INSTRUCT_TEXT"))
elif model_name == "sovitstts":
params.update(
{
"text_lang": os.getenv("CUSTOM_TTS_TEXT_LANG", "zh"),
"prompt_lang": os.getenv("CUSTOM_TTS_PROMPT_LANG", "zh"),
"text_split_method": os.getenv("CUSTOM_TTS_TEXT_SPLIT_METHOD", "cut0"),
"batch_size": os.getenv("CUSTOM_TTS_BATCH_SIZE", "1"),
"media_type": os.getenv("CUSTOM_TTS_MEDIA_TYPE", "wav"),
"streaming_mode": os.getenv("CUSTOM_TTS_STREAMING", "false"),
}
)
_set_if_present(params, "text_lang", os.getenv("CUSTOM_TTS_TEXT_LANG"))
_set_if_present(params, "prompt_lang", os.getenv("CUSTOM_TTS_PROMPT_LANG"))
_set_if_present(params, "text_split_method", os.getenv("CUSTOM_TTS_TEXT_SPLIT_METHOD"))
_set_if_present(params, "batch_size", os.getenv("CUSTOM_TTS_BATCH_SIZE"))
_set_if_present(params, "media_type", os.getenv("CUSTOM_TTS_MEDIA_TYPE"))
_set_if_present(params, "streaming_mode", os.getenv("CUSTOM_TTS_STREAMING"))
_set_if_present(params, "ref_audio_path", os.getenv("CUSTOM_TTS_REF_AUDIO_PATH"))
_set_if_present(params, "prompt_text", os.getenv("CUSTOM_TTS_PROMPT_TEXT"))
return params
def _set_if_present(params: dict[str, str], key: str, value: Optional[str]) -> None:
def _tts_prompt_wav_from_env(model_name: str) -> str | None:
if model_name.lower() != "voxcpmtts":
return None
return os.getenv("CUSTOM_TTS_PROMPT_WAV") or os.getenv("VOXCPM_PROMPT_WAV") or None
def _set_if_present(params: dict[str, str], key: str, value: str | None) -> None:
if value:
params[key] = value
@ -218,6 +481,17 @@ def _env_int(name: str, default: int) -> int:
return default
def _env_float(name: str, default: float) -> float:
value = os.getenv(name)
if not value:
return default
try:
return float(value)
except ValueError:
logger.warning("Invalid float for %s=%r, using %s", name, value, default)
return default
def _env_bool(name: str, default: bool) -> bool:
value = os.getenv(name)
if value is None:

292
memory.py Normal file
View File

@ -0,0 +1,292 @@
from __future__ import annotations
import asyncio
import json
import logging
import re
from typing import Any
import aiohttp
from livekit.agents import APIConnectionError, APIStatusError, APITimeoutError, utils
logger = logging.getLogger("memory-recall")
_LOCATION_STOPWORDS = {
"哪里",
"在哪",
"在哪里",
"哪儿",
"位置",
"什么地方",
"帮我找",
"帮我寻找",
"找一下",
"",
"请问",
"",
"",
"",
}
class MemoryRecallClient:
def __init__(
self,
*,
url: str,
timeout: float = 5.0,
max_chars: int = 2000,
api_key: str | None = None,
http_session: aiohttp.ClientSession | None = None,
) -> None:
self._url = url
self._timeout = timeout
self._max_chars = max_chars
self._api_key = api_key
self._http_session = http_session
self._cached_payload: Any | None = None
def _ensure_session(self) -> aiohttp.ClientSession:
if self._http_session is None:
self._http_session = utils.http_context.http_session()
return self._http_session
async def recall(self, query: str) -> str:
query = query.strip()
if not query:
return ""
headers = {}
if self._api_key:
headers["Authorization"] = f"Bearer {self._api_key}"
try:
async with self._ensure_session().get(
self._url,
headers=headers,
timeout=aiohttp.ClientTimeout(total=self._timeout),
) as resp:
if resp.status != 200:
error_text = await resp.text()
raise APIStatusError(
message=f"Memory recall error: {error_text}",
status_code=resp.status,
request_id=None,
body=error_text,
)
try:
data = await resp.json()
except aiohttp.ContentTypeError:
data = await resp.text()
self._cached_payload = data
return self._format_memory(data, query)
except asyncio.TimeoutError:
logger.warning(
"Memory recall timed out after %.1fs, using cached room graph", self._timeout
)
return self._format_cached_memory(query)
except aiohttp.ClientError as e:
logger.warning("Memory recall connection error: %s, using cached room graph", e)
return self._format_cached_memory(query)
except (APIConnectionError, APIStatusError, APITimeoutError) as e:
logger.warning("Memory recall failed: %s, using cached room graph", e)
return self._format_cached_memory(query)
def _format_memory(self, data: Any, query: str) -> str:
memory = _format_room_graph_memory(data, query)
if len(memory) > self._max_chars:
memory = memory[: self._max_chars].rstrip()
return memory
def _format_cached_memory(self, query: str) -> str:
if self._cached_payload is None:
return ""
return self._format_memory(self._cached_payload, query)
def _format_room_graph_memory(payload: Any, query: str) -> str:
if not isinstance(payload, dict):
logger.warning("Unsupported room graph response: %s", payload)
return ""
objects = payload.get("objects", [])
relations = payload.get("relations", [])
summary = payload.get("summary", "")
if not objects and not relations and not summary:
return ""
query_terms = _query_terms(query)
relevant_objects, relevant_relations = _relevant_room_graph(
objects=objects,
relations=relations,
query_terms=query_terms,
)
objects_text = json.dumps(
relevant_objects or _compact_items(objects, limit=12),
ensure_ascii=False,
separators=(",", ":"),
)
relations_text = json.dumps(
relevant_relations or _compact_items(relations, limit=24),
ensure_ascii=False,
separators=(",", ":"),
)
prompt = f"""
你是一个物品定位助手。
目标物品:{query}
相关物品:{objects_text}
相关空间关系:{relations_text}
房间概览:{summary}
回答要求:
1. 只说明它和其他物品的位置关系。
2. 不要编造不存在的关系。
3. 如果信息不足,请说“根据当前房间记忆,无法确定准确位置”。
4. 回答尽量简短,例如:“黑色背包在透明塑料盒的左边,在显示器的左边。”
5. 不要输出 Markdown、emoji、标题、列表、项目符号、坐标区域标签、水平/深度/高度分析或解释过程。
6. 不要回答 right-near-low、left-far-high 这类区域标签,只回答“在……的左边/右边/上方/下方/前面/后面/附近”等相对关系。
7. 如果用户当前输入不是找物品或问位置,可以忽略这段房间记忆。
""".strip()
logger.info(
"Formatted room memory: query_terms=%s, objects=%s/%s, relations=%s/%s, chars=%s",
query_terms,
len(relevant_objects),
len(objects) if isinstance(objects, list) else 0,
len(relevant_relations),
len(relations) if isinstance(relations, list) else 0,
len(prompt),
)
return prompt
def _query_terms(query: str) -> list[str]:
normalized = re.sub(r"[\s?。!,、,.!]", "", query)
for word in _LOCATION_STOPWORDS:
normalized = normalized.replace(word, "")
terms = [normalized] if normalized else []
for token in re.findall(r"[\u4e00-\u9fffA-Za-z0-9_-]{2,}", query):
if token not in _LOCATION_STOPWORDS and token not in terms:
terms.append(token)
return terms[:4]
def _relevant_room_graph(
*,
objects: Any,
relations: Any,
query_terms: list[str],
) -> tuple[list[Any], list[Any]]:
if not isinstance(objects, list) or not isinstance(relations, list) or not query_terms:
return [], []
matched_ids: set[str] = set()
matched_objects: list[Any] = []
object_by_id: dict[str, Any] = {}
for obj in objects:
obj_id = _object_id(obj)
if obj_id:
object_by_id[obj_id] = obj
obj_text = _compact_text(obj)
if any(term and term in obj_text for term in query_terms):
matched_objects.append(obj)
if obj_id:
matched_ids.add(obj_id)
relevant_relations: list[Any] = []
related_ids: set[str] = set(matched_ids)
for relation in relations:
relation_text = _compact_text(relation)
relation_ids = _ids_in_value(relation)
if (
any(term and term in relation_text for term in query_terms)
or bool(matched_ids.intersection(relation_ids))
):
relevant_relations.append(relation)
related_ids.update(relation_ids)
relevant_objects = list(matched_objects)
seen_object_keys = {_object_key(obj) for obj in relevant_objects}
for obj_id in related_ids:
obj = object_by_id.get(obj_id)
key = _object_key(obj)
if obj is not None and key not in seen_object_keys:
relevant_objects.append(obj)
seen_object_keys.add(key)
return _compact_items(relevant_objects, limit=16), _compact_items(relevant_relations, limit=32)
def _compact_items(items: Any, *, limit: int) -> list[Any]:
if not isinstance(items, list):
return []
return [_compact_item(item) for item in items[:limit]]
def _compact_item(item: Any) -> Any:
if not isinstance(item, dict):
return item
preferred_keys = (
"id",
"name",
"label",
"class",
"category",
"type",
"text",
"source",
"target",
"subject",
"object",
"relation",
"predicate",
"description",
)
compact = {key: item[key] for key in preferred_keys if key in item and item[key] not in (None, "")}
return compact or item
def _object_id(obj: Any) -> str | None:
if not isinstance(obj, dict):
return None
for key in ("id", "object_id", "uuid", "name", "label"):
value = obj.get(key)
if isinstance(value, (str, int)):
return str(value)
return None
def _object_key(obj: Any) -> str:
return _object_id(obj) or _compact_text(obj)
def _ids_in_value(value: Any) -> set[str]:
ids: set[str] = set()
if isinstance(value, dict):
for key, item in value.items():
if key in {"id", "object_id", "source", "target", "subject", "object", "from", "to"}:
if isinstance(item, (str, int)):
ids.add(str(item))
elif isinstance(item, dict):
obj_id = _object_id(item)
if obj_id:
ids.add(obj_id)
ids.update(_ids_in_value(item))
elif isinstance(value, list):
for item in value:
ids.update(_ids_in_value(item))
return ids
def _compact_text(value: Any) -> str:
return json.dumps(value, ensure_ascii=False, separators=(",", ":"))

24
tts.py
View File

@ -3,6 +3,7 @@ from __future__ import annotations
import asyncio
import logging
import os
import time
import wave
from collections.abc import Mapping
from io import BytesIO
@ -88,6 +89,7 @@ class BlackboxTTSStream(tts.ChunkedStream):
self._tts: BlackboxTTS = tts
async def _run(self, output_emitter: tts.AudioEmitter) -> None:
started_at = time.perf_counter()
form = aiohttp.FormData(default_to_multipart=True)
form.add_field("text", self.input_text)
form.add_field("model_name", self._tts._model_name)
@ -131,6 +133,9 @@ class BlackboxTTSStream(tts.ChunkedStream):
content_type = resp.headers.get("Content-Type", "audio/wav")
logged_wav_format = False
wav_header_probe = bytearray()
first_audio_at: float | None = None
chunk_count = 0
total_bytes = 0
output_emitter.initialize(
request_id=utils.shortuuid(),
sample_rate=self._tts.sample_rate,
@ -140,6 +145,16 @@ class BlackboxTTSStream(tts.ChunkedStream):
async for data, _ in resp.content.iter_chunks():
if data:
chunk_count += 1
total_bytes += len(data)
if first_audio_at is None:
first_audio_at = time.perf_counter()
logger.info(
"TTS first audio chunk after %.3fs (text_len=%s, bytes=%s)",
first_audio_at - started_at,
len(self.input_text),
len(data),
)
if not logged_wav_format:
wav_header_probe.extend(data)
logged_wav_format = _log_wav_format(
@ -156,6 +171,15 @@ class BlackboxTTSStream(tts.ChunkedStream):
logged_wav_format = True
output_emitter.push(data)
output_emitter.flush()
finished_at = time.perf_counter()
logger.info(
"TTS stream completed in %.3fs (first_chunk=%.3fs, chunks=%s, bytes=%s, text_len=%s)",
finished_at - started_at,
(first_audio_at - started_at) if first_audio_at else -1.0,
chunk_count,
total_bytes,
len(self.input_text),
)
except asyncio.TimeoutError as e:
raise APITimeoutError("TTS blackbox request timed out") from e
except aiohttp.ClientError as e: