diff --git a/custom_agent.py b/custom_agent.py index 53c9b71..3786074 100644 --- a/custom_agent.py +++ b/custom_agent.py @@ -1,5 +1,6 @@ import logging import os +import time from collections.abc import AsyncIterable from pathlib import Path @@ -61,11 +62,40 @@ class CustomAgent(Agent): tools: list[llm.Tool], model_settings: ModelSettings, ) -> AsyncIterable[llm.ChatChunk | str | FlushSentinel]: + llm_node_started_at = time.perf_counter() memory_context = await self._recall_room_memory(chat_ctx) if memory_context: chat_ctx = _with_memory_as_latest_user_message(chat_ctx, memory_context) - return Agent.default.llm_node(self, chat_ctx, tools, model_settings) + llm_result = Agent.default.llm_node(self, chat_ctx, tools, model_settings) + if not hasattr(llm_result, "__aiter__"): + elapsed = time.perf_counter() - llm_node_started_at + logger.info("LLM node completed without streaming in %.3fs", elapsed) + return llm_result + + async def _instrumented_stream() -> AsyncIterable[llm.ChatChunk | str | FlushSentinel]: + first_chunk_at: float | None = None + chunk_count = 0 + try: + async for chunk in llm_result: + chunk_count += 1 + if first_chunk_at is None: + first_chunk_at = time.perf_counter() + logger.info( + "LLM first chunk after %.3fs", + first_chunk_at - llm_node_started_at, + ) + yield chunk + finally: + finished_at = time.perf_counter() + logger.info( + "LLM stream completed in %.3fs (first_chunk=%.3fs, chunks=%s)", + finished_at - llm_node_started_at, + (first_chunk_at - llm_node_started_at) if first_chunk_at else -1.0, + chunk_count, + ) + + return _instrumented_stream() async def _recall_room_memory(self, chat_ctx: ChatContext) -> str: if self._memory_client is None: @@ -75,10 +105,22 @@ class CustomAgent(Agent): if not user_query: return "" + started_at = time.perf_counter() try: - return await self._memory_client.recall(user_query) + recalled = await self._memory_client.recall(user_query) + elapsed = time.perf_counter() - started_at + logger.info( + "Memory recall completed in %.3fs (query_len=%s, memory_len=%s)", + elapsed, + len(user_query), + len(recalled), + ) + return recalled except Exception: - logger.exception("Unexpected memory recall failure") + logger.exception( + "Unexpected memory recall failure after %.3fs", + time.perf_counter() - started_at, + ) return "" @@ -140,8 +182,8 @@ async def entrypoint(ctx: JobContext) -> None: TTS_NUM_CHANNELS = _env_int("CUSTOM_TTS_NUM_CHANNELS", 1) OUTPUT_SAMPLE_RATE = _env_int("CUSTOM_OUTPUT_SAMPLE_RATE", TTS_SAMPLE_RATE) MEMORY_URL = os.getenv("CUSTOM_MEMORY_URL", "").strip() - MEMORY_TIMEOUT = _env_float("CUSTOM_MEMORY_TIMEOUT", 10.0) - MEMORY_MAX_CHARS = _env_int("CUSTOM_MEMORY_MAX_CHARS", 8000) + MEMORY_TIMEOUT = _env_float("CUSTOM_MEMORY_TIMEOUT", 2.0) + MEMORY_MAX_CHARS = _env_int("CUSTOM_MEMORY_MAX_CHARS", 2000) MEMORY_API_KEY = os.getenv("CUSTOM_MEMORY_API_KEY") or None blackbox_stt = BlackboxSTT( @@ -199,7 +241,7 @@ async def entrypoint(ctx: JobContext) -> None: "false_interruption_timeout": 1.0, }, ), - preemptive_generation=False, + preemptive_generation=_env_bool("CUSTOM_PREEMPTIVE_GENERATION", True), aec_warmup_duration=3.0, tts_text_transforms=[ "filter_emoji", @@ -211,6 +253,17 @@ async def entrypoint(ctx: JobContext) -> None: def _on_metrics_collected(ev: MetricsCollectedEvent) -> None: metrics.log_metrics(ev.metrics) + @session.on("conversation_item_added") + def _on_conversation_item_added(event) -> None: + item = getattr(event, "item", None) + if not isinstance(item, ChatMessage): + return + + if item.role == "user" and item.metrics: + logger.info("User turn metrics: %s", item.metrics) + elif item.role == "assistant" and item.metrics: + logger.info("Assistant turn metrics: %s", item.metrics) + memory_client = ( MemoryRecallClient( url=MEMORY_URL, diff --git a/memory.py b/memory.py index e497428..1721e2a 100644 --- a/memory.py +++ b/memory.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio import json import logging +import re from typing import Any import aiohttp @@ -11,6 +12,23 @@ from livekit.agents import APIConnectionError, APIStatusError, APITimeoutError, logger = logging.getLogger("memory-recall") +_LOCATION_STOPWORDS = { + "哪里", + "在哪", + "在哪里", + "哪儿", + "位置", + "什么地方", + "帮我找", + "帮我寻找", + "找一下", + "找", + "请问", + "请", + "吗", + "呢", +} + class MemoryRecallClient: def __init__( @@ -100,27 +118,31 @@ def _format_room_graph_memory(payload: Any, query: str) -> str: if not objects and not relations and not summary: return "" - objects_text = json.dumps(objects, ensure_ascii=False, indent=2) - relations_text = json.dumps(relations, ensure_ascii=False, indent=2) + query_terms = _query_terms(query) + relevant_objects, relevant_relations = _relevant_room_graph( + objects=objects, + relations=relations, + query_terms=query_terms, + ) + + objects_text = json.dumps( + relevant_objects or _compact_items(objects, limit=12), + ensure_ascii=False, + separators=(",", ":"), + ) + relations_text = json.dumps( + relevant_relations or _compact_items(relations, limit=24), + ensure_ascii=False, + separators=(",", ":"), + ) prompt = f""" 你是一个物品定位助手。 -我的房间内有以下物品信息: - -{objects_text} - -这些物品之间的空间关系如下: - -{relations_text} - -房间概览如下: - -{summary} - -现在我要找的目标物品是:{query} - -请根据上面的 objects、relations 和 summary,告诉我它在哪里。 +目标物品:{query} +相关物品:{objects_text} +相关空间关系:{relations_text} +房间概览:{summary} 回答要求: 1. 只说明它和其他物品的位置关系。 @@ -131,5 +153,140 @@ def _format_room_graph_memory(payload: Any, query: str) -> str: 6. 不要回答 right-near-low、left-far-high 这类区域标签,只回答“在……的左边/右边/上方/下方/前面/后面/附近”等相对关系。 7. 如果用户当前输入不是找物品或问位置,可以忽略这段房间记忆。 """.strip() - + + logger.info( + "Formatted room memory: query_terms=%s, objects=%s/%s, relations=%s/%s, chars=%s", + query_terms, + len(relevant_objects), + len(objects) if isinstance(objects, list) else 0, + len(relevant_relations), + len(relations) if isinstance(relations, list) else 0, + len(prompt), + ) return prompt + + +def _query_terms(query: str) -> list[str]: + normalized = re.sub(r"[\s??。!,、,.!]", "", query) + for word in _LOCATION_STOPWORDS: + normalized = normalized.replace(word, "") + + terms = [normalized] if normalized else [] + for token in re.findall(r"[\u4e00-\u9fffA-Za-z0-9_-]{2,}", query): + if token not in _LOCATION_STOPWORDS and token not in terms: + terms.append(token) + return terms[:4] + + +def _relevant_room_graph( + *, + objects: Any, + relations: Any, + query_terms: list[str], +) -> tuple[list[Any], list[Any]]: + if not isinstance(objects, list) or not isinstance(relations, list) or not query_terms: + return [], [] + + matched_ids: set[str] = set() + matched_objects: list[Any] = [] + object_by_id: dict[str, Any] = {} + + for obj in objects: + obj_id = _object_id(obj) + if obj_id: + object_by_id[obj_id] = obj + + obj_text = _compact_text(obj) + if any(term and term in obj_text for term in query_terms): + matched_objects.append(obj) + if obj_id: + matched_ids.add(obj_id) + + relevant_relations: list[Any] = [] + related_ids: set[str] = set(matched_ids) + for relation in relations: + relation_text = _compact_text(relation) + relation_ids = _ids_in_value(relation) + if ( + any(term and term in relation_text for term in query_terms) + or bool(matched_ids.intersection(relation_ids)) + ): + relevant_relations.append(relation) + related_ids.update(relation_ids) + + relevant_objects = list(matched_objects) + seen_object_keys = {_object_key(obj) for obj in relevant_objects} + for obj_id in related_ids: + obj = object_by_id.get(obj_id) + key = _object_key(obj) + if obj is not None and key not in seen_object_keys: + relevant_objects.append(obj) + seen_object_keys.add(key) + + return _compact_items(relevant_objects, limit=16), _compact_items(relevant_relations, limit=32) + + +def _compact_items(items: Any, *, limit: int) -> list[Any]: + if not isinstance(items, list): + return [] + return [_compact_item(item) for item in items[:limit]] + + +def _compact_item(item: Any) -> Any: + if not isinstance(item, dict): + return item + + preferred_keys = ( + "id", + "name", + "label", + "class", + "category", + "type", + "text", + "source", + "target", + "subject", + "object", + "relation", + "predicate", + "description", + ) + compact = {key: item[key] for key in preferred_keys if key in item and item[key] not in (None, "")} + return compact or item + + +def _object_id(obj: Any) -> str | None: + if not isinstance(obj, dict): + return None + for key in ("id", "object_id", "uuid", "name", "label"): + value = obj.get(key) + if isinstance(value, (str, int)): + return str(value) + return None + + +def _object_key(obj: Any) -> str: + return _object_id(obj) or _compact_text(obj) + + +def _ids_in_value(value: Any) -> set[str]: + ids: set[str] = set() + if isinstance(value, dict): + for key, item in value.items(): + if key in {"id", "object_id", "source", "target", "subject", "object", "from", "to"}: + if isinstance(item, (str, int)): + ids.add(str(item)) + elif isinstance(item, dict): + obj_id = _object_id(item) + if obj_id: + ids.add(obj_id) + ids.update(_ids_in_value(item)) + elif isinstance(value, list): + for item in value: + ids.update(_ids_in_value(item)) + return ids + + +def _compact_text(value: Any) -> str: + return json.dumps(value, ensure_ascii=False, separators=(",", ":")) diff --git a/tts.py b/tts.py index b374f03..a3cbc58 100644 --- a/tts.py +++ b/tts.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio import logging import os +import time import wave from collections.abc import Mapping from io import BytesIO @@ -88,6 +89,7 @@ class BlackboxTTSStream(tts.ChunkedStream): self._tts: BlackboxTTS = tts async def _run(self, output_emitter: tts.AudioEmitter) -> None: + started_at = time.perf_counter() form = aiohttp.FormData(default_to_multipart=True) form.add_field("text", self.input_text) form.add_field("model_name", self._tts._model_name) @@ -131,6 +133,9 @@ class BlackboxTTSStream(tts.ChunkedStream): content_type = resp.headers.get("Content-Type", "audio/wav") logged_wav_format = False wav_header_probe = bytearray() + first_audio_at: float | None = None + chunk_count = 0 + total_bytes = 0 output_emitter.initialize( request_id=utils.shortuuid(), sample_rate=self._tts.sample_rate, @@ -140,6 +145,16 @@ class BlackboxTTSStream(tts.ChunkedStream): async for data, _ in resp.content.iter_chunks(): if data: + chunk_count += 1 + total_bytes += len(data) + if first_audio_at is None: + first_audio_at = time.perf_counter() + logger.info( + "TTS first audio chunk after %.3fs (text_len=%s, bytes=%s)", + first_audio_at - started_at, + len(self.input_text), + len(data), + ) if not logged_wav_format: wav_header_probe.extend(data) logged_wav_format = _log_wav_format( @@ -156,6 +171,15 @@ class BlackboxTTSStream(tts.ChunkedStream): logged_wav_format = True output_emitter.push(data) output_emitter.flush() + finished_at = time.perf_counter() + logger.info( + "TTS stream completed in %.3fs (first_chunk=%.3fs, chunks=%s, bytes=%s, text_len=%s)", + finished_at - started_at, + (first_audio_at - started_at) if first_audio_at else -1.0, + chunk_count, + total_bytes, + len(self.input_text), + ) except asyncio.TimeoutError as e: raise APITimeoutError("TTS blackbox request timed out") from e except aiohttp.ClientError as e: