perf: improve speed

This commit is contained in:
0Xiao0
2026-05-15 10:44:31 +08:00
parent b18c5b40da
commit fba51a5257
3 changed files with 258 additions and 24 deletions

View File

@ -1,5 +1,6 @@
import logging
import os
import time
from collections.abc import AsyncIterable
from pathlib import Path
@ -61,11 +62,40 @@ class CustomAgent(Agent):
tools: list[llm.Tool],
model_settings: ModelSettings,
) -> AsyncIterable[llm.ChatChunk | str | FlushSentinel]:
llm_node_started_at = time.perf_counter()
memory_context = await self._recall_room_memory(chat_ctx)
if memory_context:
chat_ctx = _with_memory_as_latest_user_message(chat_ctx, memory_context)
return Agent.default.llm_node(self, chat_ctx, tools, model_settings)
llm_result = Agent.default.llm_node(self, chat_ctx, tools, model_settings)
if not hasattr(llm_result, "__aiter__"):
elapsed = time.perf_counter() - llm_node_started_at
logger.info("LLM node completed without streaming in %.3fs", elapsed)
return llm_result
async def _instrumented_stream() -> AsyncIterable[llm.ChatChunk | str | FlushSentinel]:
    """Relay chunks from ``llm_result`` unchanged while logging stream timing.

    Logs the delay until the first chunk and, once iteration ends for any
    reason (exhaustion, an error, or the consumer closing this generator
    early), the total elapsed time and the number of chunks relayed. All
    durations are measured from ``llm_node_started_at`` in the enclosing
    scope.
    """
    first_chunk_at: float | None = None
    chunk_count = 0
    try:
        async for chunk in llm_result:
            chunk_count += 1
            if first_chunk_at is None:
                first_chunk_at = time.perf_counter()
                logger.info(
                    "LLM first chunk after %.3fs",
                    first_chunk_at - llm_node_started_at,
                )
            yield chunk
    finally:
        finished_at = time.perf_counter()
        # -1.0 is the sentinel for "no chunk was ever received". Compare
        # against None explicitly so a timestamp is never judged by its
        # truthiness.
        first_chunk_delay = (
            first_chunk_at - llm_node_started_at
            if first_chunk_at is not None
            else -1.0
        )
        logger.info(
            "LLM stream completed in %.3fs (first_chunk=%.3fs, chunks=%s)",
            finished_at - llm_node_started_at,
            first_chunk_delay,
            chunk_count,
        )
        # When the consumer closes this wrapper before the upstream stream
        # is exhausted, `async for` does not close the upstream iterator.
        # Close it explicitly so its cleanup runs now instead of at garbage
        # collection. Closing an already-finished async generator is a
        # no-op; iterators without `aclose` are left untouched.
        close_upstream = getattr(llm_result, "aclose", None)
        if close_upstream is not None:
            await close_upstream()
return _instrumented_stream()
async def _recall_room_memory(self, chat_ctx: ChatContext) -> str:
if self._memory_client is None:
@ -75,10 +105,22 @@ class CustomAgent(Agent):
if not user_query:
return ""
started_at = time.perf_counter()
try:
return await self._memory_client.recall(user_query)
recalled = await self._memory_client.recall(user_query)
elapsed = time.perf_counter() - started_at
logger.info(
"Memory recall completed in %.3fs (query_len=%s, memory_len=%s)",
elapsed,
len(user_query),
len(recalled),
)
return recalled
except Exception:
logger.exception("Unexpected memory recall failure")
logger.exception(
"Unexpected memory recall failure after %.3fs",
time.perf_counter() - started_at,
)
return ""
@ -140,8 +182,8 @@ async def entrypoint(ctx: JobContext) -> None:
TTS_NUM_CHANNELS = _env_int("CUSTOM_TTS_NUM_CHANNELS", 1)
OUTPUT_SAMPLE_RATE = _env_int("CUSTOM_OUTPUT_SAMPLE_RATE", TTS_SAMPLE_RATE)
MEMORY_URL = os.getenv("CUSTOM_MEMORY_URL", "").strip()
MEMORY_TIMEOUT = _env_float("CUSTOM_MEMORY_TIMEOUT", 10.0)
MEMORY_MAX_CHARS = _env_int("CUSTOM_MEMORY_MAX_CHARS", 8000)
MEMORY_TIMEOUT = _env_float("CUSTOM_MEMORY_TIMEOUT", 2.0)
MEMORY_MAX_CHARS = _env_int("CUSTOM_MEMORY_MAX_CHARS", 2000)
MEMORY_API_KEY = os.getenv("CUSTOM_MEMORY_API_KEY") or None
blackbox_stt = BlackboxSTT(
@ -199,7 +241,7 @@ async def entrypoint(ctx: JobContext) -> None:
"false_interruption_timeout": 1.0,
},
),
preemptive_generation=False,
preemptive_generation=_env_bool("CUSTOM_PREEMPTIVE_GENERATION", True),
aec_warmup_duration=3.0,
tts_text_transforms=[
"filter_emoji",
@ -211,6 +253,17 @@ async def entrypoint(ctx: JobContext) -> None:
def _on_metrics_collected(ev: MetricsCollectedEvent) -> None:
metrics.log_metrics(ev.metrics)
@session.on("conversation_item_added")
def _on_conversation_item_added(event) -> None:
    """Log the metrics attached to user and assistant chat messages.

    Events whose ``item`` attribute is missing or is not a ``ChatMessage``
    are ignored, as are messages with no (truthy) metrics or with a role
    other than "user" or "assistant".
    """
    message = getattr(event, "item", None)
    if not isinstance(message, ChatMessage):
        return

    role = message.role
    if role not in ("user", "assistant"):
        return

    turn_metrics = message.metrics
    if not turn_metrics:
        return

    if role == "user":
        logger.info("User turn metrics: %s", turn_metrics)
    else:
        logger.info("Assistant turn metrics: %s", turn_metrics)
memory_client = (
MemoryRecallClient(
url=MEMORY_URL,