perf: improve speed

This commit is contained in:
0Xiao0
2026-05-15 10:44:31 +08:00
parent b18c5b40da
commit fba51a5257
3 changed files with 258 additions and 24 deletions

View File

@ -1,5 +1,6 @@
import logging import logging
import os import os
import time
from collections.abc import AsyncIterable from collections.abc import AsyncIterable
from pathlib import Path from pathlib import Path
@ -61,11 +62,40 @@ class CustomAgent(Agent):
tools: list[llm.Tool], tools: list[llm.Tool],
model_settings: ModelSettings, model_settings: ModelSettings,
) -> AsyncIterable[llm.ChatChunk | str | FlushSentinel]: ) -> AsyncIterable[llm.ChatChunk | str | FlushSentinel]:
llm_node_started_at = time.perf_counter()
memory_context = await self._recall_room_memory(chat_ctx) memory_context = await self._recall_room_memory(chat_ctx)
if memory_context: if memory_context:
chat_ctx = _with_memory_as_latest_user_message(chat_ctx, memory_context) chat_ctx = _with_memory_as_latest_user_message(chat_ctx, memory_context)
return Agent.default.llm_node(self, chat_ctx, tools, model_settings) llm_result = Agent.default.llm_node(self, chat_ctx, tools, model_settings)
if not hasattr(llm_result, "__aiter__"):
elapsed = time.perf_counter() - llm_node_started_at
logger.info("LLM node completed without streaming in %.3fs", elapsed)
return llm_result
async def _instrumented_stream() -> AsyncIterable[llm.ChatChunk | str | FlushSentinel]:
first_chunk_at: float | None = None
chunk_count = 0
try:
async for chunk in llm_result:
chunk_count += 1
if first_chunk_at is None:
first_chunk_at = time.perf_counter()
logger.info(
"LLM first chunk after %.3fs",
first_chunk_at - llm_node_started_at,
)
yield chunk
finally:
finished_at = time.perf_counter()
logger.info(
"LLM stream completed in %.3fs (first_chunk=%.3fs, chunks=%s)",
finished_at - llm_node_started_at,
(first_chunk_at - llm_node_started_at) if first_chunk_at else -1.0,
chunk_count,
)
return _instrumented_stream()
async def _recall_room_memory(self, chat_ctx: ChatContext) -> str: async def _recall_room_memory(self, chat_ctx: ChatContext) -> str:
if self._memory_client is None: if self._memory_client is None:
@ -75,10 +105,22 @@ class CustomAgent(Agent):
if not user_query: if not user_query:
return "" return ""
started_at = time.perf_counter()
try: try:
return await self._memory_client.recall(user_query) recalled = await self._memory_client.recall(user_query)
elapsed = time.perf_counter() - started_at
logger.info(
"Memory recall completed in %.3fs (query_len=%s, memory_len=%s)",
elapsed,
len(user_query),
len(recalled),
)
return recalled
except Exception: except Exception:
logger.exception("Unexpected memory recall failure") logger.exception(
"Unexpected memory recall failure after %.3fs",
time.perf_counter() - started_at,
)
return "" return ""
@ -140,8 +182,8 @@ async def entrypoint(ctx: JobContext) -> None:
TTS_NUM_CHANNELS = _env_int("CUSTOM_TTS_NUM_CHANNELS", 1) TTS_NUM_CHANNELS = _env_int("CUSTOM_TTS_NUM_CHANNELS", 1)
OUTPUT_SAMPLE_RATE = _env_int("CUSTOM_OUTPUT_SAMPLE_RATE", TTS_SAMPLE_RATE) OUTPUT_SAMPLE_RATE = _env_int("CUSTOM_OUTPUT_SAMPLE_RATE", TTS_SAMPLE_RATE)
MEMORY_URL = os.getenv("CUSTOM_MEMORY_URL", "").strip() MEMORY_URL = os.getenv("CUSTOM_MEMORY_URL", "").strip()
MEMORY_TIMEOUT = _env_float("CUSTOM_MEMORY_TIMEOUT", 10.0) MEMORY_TIMEOUT = _env_float("CUSTOM_MEMORY_TIMEOUT", 2.0)
MEMORY_MAX_CHARS = _env_int("CUSTOM_MEMORY_MAX_CHARS", 8000) MEMORY_MAX_CHARS = _env_int("CUSTOM_MEMORY_MAX_CHARS", 2000)
MEMORY_API_KEY = os.getenv("CUSTOM_MEMORY_API_KEY") or None MEMORY_API_KEY = os.getenv("CUSTOM_MEMORY_API_KEY") or None
blackbox_stt = BlackboxSTT( blackbox_stt = BlackboxSTT(
@ -199,7 +241,7 @@ async def entrypoint(ctx: JobContext) -> None:
"false_interruption_timeout": 1.0, "false_interruption_timeout": 1.0,
}, },
), ),
preemptive_generation=False, preemptive_generation=_env_bool("CUSTOM_PREEMPTIVE_GENERATION", True),
aec_warmup_duration=3.0, aec_warmup_duration=3.0,
tts_text_transforms=[ tts_text_transforms=[
"filter_emoji", "filter_emoji",
@ -211,6 +253,17 @@ async def entrypoint(ctx: JobContext) -> None:
def _on_metrics_collected(ev: MetricsCollectedEvent) -> None: def _on_metrics_collected(ev: MetricsCollectedEvent) -> None:
metrics.log_metrics(ev.metrics) metrics.log_metrics(ev.metrics)
@session.on("conversation_item_added")
def _on_conversation_item_added(event) -> None:
item = getattr(event, "item", None)
if not isinstance(item, ChatMessage):
return
if item.role == "user" and item.metrics:
logger.info("User turn metrics: %s", item.metrics)
elif item.role == "assistant" and item.metrics:
logger.info("Assistant turn metrics: %s", item.metrics)
memory_client = ( memory_client = (
MemoryRecallClient( MemoryRecallClient(
url=MEMORY_URL, url=MEMORY_URL,

193
memory.py
View File

@ -3,6 +3,7 @@ from __future__ import annotations
import asyncio import asyncio
import json import json
import logging import logging
import re
from typing import Any from typing import Any
import aiohttp import aiohttp
@ -11,6 +12,23 @@ from livekit.agents import APIConnectionError, APIStatusError, APITimeoutError,
logger = logging.getLogger("memory-recall") logger = logging.getLogger("memory-recall")
_LOCATION_STOPWORDS = {
"哪里",
"在哪",
"在哪里",
"哪儿",
"位置",
"什么地方",
"帮我找",
"帮我寻找",
"找一下",
"",
"请问",
"",
"",
"",
}
class MemoryRecallClient: class MemoryRecallClient:
def __init__( def __init__(
@ -100,27 +118,31 @@ def _format_room_graph_memory(payload: Any, query: str) -> str:
if not objects and not relations and not summary: if not objects and not relations and not summary:
return "" return ""
objects_text = json.dumps(objects, ensure_ascii=False, indent=2) query_terms = _query_terms(query)
relations_text = json.dumps(relations, ensure_ascii=False, indent=2) relevant_objects, relevant_relations = _relevant_room_graph(
objects=objects,
relations=relations,
query_terms=query_terms,
)
objects_text = json.dumps(
relevant_objects or _compact_items(objects, limit=12),
ensure_ascii=False,
separators=(",", ":"),
)
relations_text = json.dumps(
relevant_relations or _compact_items(relations, limit=24),
ensure_ascii=False,
separators=(",", ":"),
)
prompt = f""" prompt = f"""
你是一个物品定位助手。 你是一个物品定位助手。
我的房间内有以下物品信息: 目标物品:{query}
相关物品:{objects_text}
{objects_text} 相关空间关系:{relations_text}
房间概览:{summary}
这些物品之间的空间关系如下:
{relations_text}
房间概览如下:
{summary}
现在我要找的目标物品是:{query}
请根据上面的 objects、relations 和 summary告诉我它在哪里。
回答要求: 回答要求:
1. 只说明它和其他物品的位置关系。 1. 只说明它和其他物品的位置关系。
@ -131,5 +153,140 @@ def _format_room_graph_memory(payload: Any, query: str) -> str:
6. 不要回答 right-near-low、left-far-high 这类区域标签,只回答“在……的左边/右边/上方/下方/前面/后面/附近”等相对关系。 6. 不要回答 right-near-low、left-far-high 这类区域标签,只回答“在……的左边/右边/上方/下方/前面/后面/附近”等相对关系。
7. 如果用户当前输入不是找物品或问位置,可以忽略这段房间记忆。 7. 如果用户当前输入不是找物品或问位置,可以忽略这段房间记忆。
""".strip() """.strip()
logger.info(
"Formatted room memory: query_terms=%s, objects=%s/%s, relations=%s/%s, chars=%s",
query_terms,
len(relevant_objects),
len(objects) if isinstance(objects, list) else 0,
len(relevant_relations),
len(relations) if isinstance(relations, list) else 0,
len(prompt),
)
return prompt return prompt
def _query_terms(query: str) -> list[str]:
normalized = re.sub(r"[\s?。!,、,.!]", "", query)
for word in _LOCATION_STOPWORDS:
normalized = normalized.replace(word, "")
terms = [normalized] if normalized else []
for token in re.findall(r"[\u4e00-\u9fffA-Za-z0-9_-]{2,}", query):
if token not in _LOCATION_STOPWORDS and token not in terms:
terms.append(token)
return terms[:4]
def _relevant_room_graph(
*,
objects: Any,
relations: Any,
query_terms: list[str],
) -> tuple[list[Any], list[Any]]:
if not isinstance(objects, list) or not isinstance(relations, list) or not query_terms:
return [], []
matched_ids: set[str] = set()
matched_objects: list[Any] = []
object_by_id: dict[str, Any] = {}
for obj in objects:
obj_id = _object_id(obj)
if obj_id:
object_by_id[obj_id] = obj
obj_text = _compact_text(obj)
if any(term and term in obj_text for term in query_terms):
matched_objects.append(obj)
if obj_id:
matched_ids.add(obj_id)
relevant_relations: list[Any] = []
related_ids: set[str] = set(matched_ids)
for relation in relations:
relation_text = _compact_text(relation)
relation_ids = _ids_in_value(relation)
if (
any(term and term in relation_text for term in query_terms)
or bool(matched_ids.intersection(relation_ids))
):
relevant_relations.append(relation)
related_ids.update(relation_ids)
relevant_objects = list(matched_objects)
seen_object_keys = {_object_key(obj) for obj in relevant_objects}
for obj_id in related_ids:
obj = object_by_id.get(obj_id)
key = _object_key(obj)
if obj is not None and key not in seen_object_keys:
relevant_objects.append(obj)
seen_object_keys.add(key)
return _compact_items(relevant_objects, limit=16), _compact_items(relevant_relations, limit=32)
def _compact_items(items: Any, *, limit: int) -> list[Any]:
if not isinstance(items, list):
return []
return [_compact_item(item) for item in items[:limit]]
def _compact_item(item: Any) -> Any:
if not isinstance(item, dict):
return item
preferred_keys = (
"id",
"name",
"label",
"class",
"category",
"type",
"text",
"source",
"target",
"subject",
"object",
"relation",
"predicate",
"description",
)
compact = {key: item[key] for key in preferred_keys if key in item and item[key] not in (None, "")}
return compact or item
def _object_id(obj: Any) -> str | None:
if not isinstance(obj, dict):
return None
for key in ("id", "object_id", "uuid", "name", "label"):
value = obj.get(key)
if isinstance(value, (str, int)):
return str(value)
return None
def _object_key(obj: Any) -> str:
return _object_id(obj) or _compact_text(obj)
def _ids_in_value(value: Any) -> set[str]:
ids: set[str] = set()
if isinstance(value, dict):
for key, item in value.items():
if key in {"id", "object_id", "source", "target", "subject", "object", "from", "to"}:
if isinstance(item, (str, int)):
ids.add(str(item))
elif isinstance(item, dict):
obj_id = _object_id(item)
if obj_id:
ids.add(obj_id)
ids.update(_ids_in_value(item))
elif isinstance(value, list):
for item in value:
ids.update(_ids_in_value(item))
return ids
def _compact_text(value: Any) -> str:
return json.dumps(value, ensure_ascii=False, separators=(",", ":"))

24
tts.py
View File

@ -3,6 +3,7 @@ from __future__ import annotations
import asyncio import asyncio
import logging import logging
import os import os
import time
import wave import wave
from collections.abc import Mapping from collections.abc import Mapping
from io import BytesIO from io import BytesIO
@ -88,6 +89,7 @@ class BlackboxTTSStream(tts.ChunkedStream):
self._tts: BlackboxTTS = tts self._tts: BlackboxTTS = tts
async def _run(self, output_emitter: tts.AudioEmitter) -> None: async def _run(self, output_emitter: tts.AudioEmitter) -> None:
started_at = time.perf_counter()
form = aiohttp.FormData(default_to_multipart=True) form = aiohttp.FormData(default_to_multipart=True)
form.add_field("text", self.input_text) form.add_field("text", self.input_text)
form.add_field("model_name", self._tts._model_name) form.add_field("model_name", self._tts._model_name)
@ -131,6 +133,9 @@ class BlackboxTTSStream(tts.ChunkedStream):
content_type = resp.headers.get("Content-Type", "audio/wav") content_type = resp.headers.get("Content-Type", "audio/wav")
logged_wav_format = False logged_wav_format = False
wav_header_probe = bytearray() wav_header_probe = bytearray()
first_audio_at: float | None = None
chunk_count = 0
total_bytes = 0
output_emitter.initialize( output_emitter.initialize(
request_id=utils.shortuuid(), request_id=utils.shortuuid(),
sample_rate=self._tts.sample_rate, sample_rate=self._tts.sample_rate,
@ -140,6 +145,16 @@ class BlackboxTTSStream(tts.ChunkedStream):
async for data, _ in resp.content.iter_chunks(): async for data, _ in resp.content.iter_chunks():
if data: if data:
chunk_count += 1
total_bytes += len(data)
if first_audio_at is None:
first_audio_at = time.perf_counter()
logger.info(
"TTS first audio chunk after %.3fs (text_len=%s, bytes=%s)",
first_audio_at - started_at,
len(self.input_text),
len(data),
)
if not logged_wav_format: if not logged_wav_format:
wav_header_probe.extend(data) wav_header_probe.extend(data)
logged_wav_format = _log_wav_format( logged_wav_format = _log_wav_format(
@ -156,6 +171,15 @@ class BlackboxTTSStream(tts.ChunkedStream):
logged_wav_format = True logged_wav_format = True
output_emitter.push(data) output_emitter.push(data)
output_emitter.flush() output_emitter.flush()
finished_at = time.perf_counter()
logger.info(
"TTS stream completed in %.3fs (first_chunk=%.3fs, chunks=%s, bytes=%s, text_len=%s)",
finished_at - started_at,
(first_audio_at - started_at) if first_audio_at else -1.0,
chunk_count,
total_bytes,
len(self.input_text),
)
except asyncio.TimeoutError as e: except asyncio.TimeoutError as e:
raise APITimeoutError("TTS blackbox request timed out") from e raise APITimeoutError("TTS blackbox request timed out") from e
except aiohttp.ClientError as e: except aiohttp.ClientError as e: