perf: improve speed
This commit is contained in:
@ -1,5 +1,6 @@
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from collections.abc import AsyncIterable
|
||||
from pathlib import Path
|
||||
|
||||
@ -61,11 +62,40 @@ class CustomAgent(Agent):
|
||||
tools: list[llm.Tool],
|
||||
model_settings: ModelSettings,
|
||||
) -> AsyncIterable[llm.ChatChunk | str | FlushSentinel]:
|
||||
llm_node_started_at = time.perf_counter()
|
||||
memory_context = await self._recall_room_memory(chat_ctx)
|
||||
if memory_context:
|
||||
chat_ctx = _with_memory_as_latest_user_message(chat_ctx, memory_context)
|
||||
|
||||
return Agent.default.llm_node(self, chat_ctx, tools, model_settings)
|
||||
llm_result = Agent.default.llm_node(self, chat_ctx, tools, model_settings)
|
||||
if not hasattr(llm_result, "__aiter__"):
|
||||
elapsed = time.perf_counter() - llm_node_started_at
|
||||
logger.info("LLM node completed without streaming in %.3fs", elapsed)
|
||||
return llm_result
|
||||
|
||||
async def _instrumented_stream() -> AsyncIterable[llm.ChatChunk | str | FlushSentinel]:
    """Re-yield every chunk of ``llm_result`` while logging stream timings.

    Logs the delay until the first chunk arrives and, once iteration ends
    for any reason, the total elapsed time and the number of chunks seen.
    All durations are measured from ``llm_node_started_at``, which the
    enclosing scope captures before the LLM call is issued.

    Yields:
        Each chunk produced by ``llm_result``, unchanged and in order.
    """
    first_chunk_at: float | None = None
    chunk_count = 0
    try:
        async for chunk in llm_result:
            chunk_count += 1
            if first_chunk_at is None:
                first_chunk_at = time.perf_counter()
                logger.info(
                    "LLM first chunk after %.3fs",
                    first_chunk_at - llm_node_started_at,
                )
            yield chunk
    finally:
        # The summary is emitted from ``finally`` so it is logged even when
        # the stream raises or the consumer stops iterating early.
        finished_at = time.perf_counter()
        # Compare against None explicitly: a timestamp of 0.0 is a valid
        # float and must not be mistaken for "no chunk received".
        if first_chunk_at is not None:
            first_chunk_latency = first_chunk_at - llm_node_started_at
        else:
            # -1.0 is the sentinel logged when the stream produced nothing.
            first_chunk_latency = -1.0
        logger.info(
            "LLM stream completed in %.3fs (first_chunk=%.3fs, chunks=%s)",
            finished_at - llm_node_started_at,
            first_chunk_latency,
            chunk_count,
        )
|
||||
|
||||
return _instrumented_stream()
|
||||
|
||||
async def _recall_room_memory(self, chat_ctx: ChatContext) -> str:
|
||||
if self._memory_client is None:
|
||||
@ -75,10 +105,22 @@ class CustomAgent(Agent):
|
||||
if not user_query:
|
||||
return ""
|
||||
|
||||
started_at = time.perf_counter()
|
||||
try:
|
||||
return await self._memory_client.recall(user_query)
|
||||
recalled = await self._memory_client.recall(user_query)
|
||||
elapsed = time.perf_counter() - started_at
|
||||
logger.info(
|
||||
"Memory recall completed in %.3fs (query_len=%s, memory_len=%s)",
|
||||
elapsed,
|
||||
len(user_query),
|
||||
len(recalled),
|
||||
)
|
||||
return recalled
|
||||
except Exception:
|
||||
logger.exception("Unexpected memory recall failure")
|
||||
logger.exception(
|
||||
"Unexpected memory recall failure after %.3fs",
|
||||
time.perf_counter() - started_at,
|
||||
)
|
||||
return ""
|
||||
|
||||
|
||||
@ -140,8 +182,8 @@ async def entrypoint(ctx: JobContext) -> None:
|
||||
TTS_NUM_CHANNELS = _env_int("CUSTOM_TTS_NUM_CHANNELS", 1)
|
||||
OUTPUT_SAMPLE_RATE = _env_int("CUSTOM_OUTPUT_SAMPLE_RATE", TTS_SAMPLE_RATE)
|
||||
MEMORY_URL = os.getenv("CUSTOM_MEMORY_URL", "").strip()
|
||||
MEMORY_TIMEOUT = _env_float("CUSTOM_MEMORY_TIMEOUT", 10.0)
|
||||
MEMORY_MAX_CHARS = _env_int("CUSTOM_MEMORY_MAX_CHARS", 8000)
|
||||
MEMORY_TIMEOUT = _env_float("CUSTOM_MEMORY_TIMEOUT", 2.0)
|
||||
MEMORY_MAX_CHARS = _env_int("CUSTOM_MEMORY_MAX_CHARS", 2000)
|
||||
MEMORY_API_KEY = os.getenv("CUSTOM_MEMORY_API_KEY") or None
|
||||
|
||||
blackbox_stt = BlackboxSTT(
|
||||
@ -199,7 +241,7 @@ async def entrypoint(ctx: JobContext) -> None:
|
||||
"false_interruption_timeout": 1.0,
|
||||
},
|
||||
),
|
||||
preemptive_generation=False,
|
||||
preemptive_generation=_env_bool("CUSTOM_PREEMPTIVE_GENERATION", True),
|
||||
aec_warmup_duration=3.0,
|
||||
tts_text_transforms=[
|
||||
"filter_emoji",
|
||||
@ -211,6 +253,17 @@ async def entrypoint(ctx: JobContext) -> None:
|
||||
def _on_metrics_collected(ev: MetricsCollectedEvent) -> None:
|
||||
metrics.log_metrics(ev.metrics)
|
||||
|
||||
@session.on("conversation_item_added")
def _on_conversation_item_added(event) -> None:
    """Log per-turn metrics for user and assistant chat messages.

    Events whose ``item`` attribute is missing or is not a ``ChatMessage``
    are ignored, as are messages from other roles and messages whose
    ``metrics`` value is falsy.
    """
    message = getattr(event, "item", None)
    if not isinstance(message, ChatMessage):
        return

    # One log template per role that is reported; other roles are skipped.
    templates = {
        "user": "User turn metrics: %s",
        "assistant": "Assistant turn metrics: %s",
    }
    template = templates.get(message.role)
    if template is None:
        return

    if message.metrics:
        logger.info(template, message.metrics)
|
||||
|
||||
memory_client = (
|
||||
MemoryRecallClient(
|
||||
url=MEMORY_URL,
|
||||
|
||||
Reference in New Issue
Block a user