perf: improve speed
This commit is contained in:
@ -1,5 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import time
|
||||||
from collections.abc import AsyncIterable
|
from collections.abc import AsyncIterable
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
@ -61,11 +62,40 @@ class CustomAgent(Agent):
|
|||||||
tools: list[llm.Tool],
|
tools: list[llm.Tool],
|
||||||
model_settings: ModelSettings,
|
model_settings: ModelSettings,
|
||||||
) -> AsyncIterable[llm.ChatChunk | str | FlushSentinel]:
|
) -> AsyncIterable[llm.ChatChunk | str | FlushSentinel]:
|
||||||
|
llm_node_started_at = time.perf_counter()
|
||||||
memory_context = await self._recall_room_memory(chat_ctx)
|
memory_context = await self._recall_room_memory(chat_ctx)
|
||||||
if memory_context:
|
if memory_context:
|
||||||
chat_ctx = _with_memory_as_latest_user_message(chat_ctx, memory_context)
|
chat_ctx = _with_memory_as_latest_user_message(chat_ctx, memory_context)
|
||||||
|
|
||||||
return Agent.default.llm_node(self, chat_ctx, tools, model_settings)
|
llm_result = Agent.default.llm_node(self, chat_ctx, tools, model_settings)
|
||||||
|
if not hasattr(llm_result, "__aiter__"):
|
||||||
|
elapsed = time.perf_counter() - llm_node_started_at
|
||||||
|
logger.info("LLM node completed without streaming in %.3fs", elapsed)
|
||||||
|
return llm_result
|
||||||
|
|
||||||
|
async def _instrumented_stream() -> AsyncIterable[llm.ChatChunk | str | FlushSentinel]:
|
||||||
|
first_chunk_at: float | None = None
|
||||||
|
chunk_count = 0
|
||||||
|
try:
|
||||||
|
async for chunk in llm_result:
|
||||||
|
chunk_count += 1
|
||||||
|
if first_chunk_at is None:
|
||||||
|
first_chunk_at = time.perf_counter()
|
||||||
|
logger.info(
|
||||||
|
"LLM first chunk after %.3fs",
|
||||||
|
first_chunk_at - llm_node_started_at,
|
||||||
|
)
|
||||||
|
yield chunk
|
||||||
|
finally:
|
||||||
|
finished_at = time.perf_counter()
|
||||||
|
logger.info(
|
||||||
|
"LLM stream completed in %.3fs (first_chunk=%.3fs, chunks=%s)",
|
||||||
|
finished_at - llm_node_started_at,
|
||||||
|
(first_chunk_at - llm_node_started_at) if first_chunk_at else -1.0,
|
||||||
|
chunk_count,
|
||||||
|
)
|
||||||
|
|
||||||
|
return _instrumented_stream()
|
||||||
|
|
||||||
async def _recall_room_memory(self, chat_ctx: ChatContext) -> str:
|
async def _recall_room_memory(self, chat_ctx: ChatContext) -> str:
|
||||||
if self._memory_client is None:
|
if self._memory_client is None:
|
||||||
@ -75,10 +105,22 @@ class CustomAgent(Agent):
|
|||||||
if not user_query:
|
if not user_query:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
started_at = time.perf_counter()
|
||||||
try:
|
try:
|
||||||
return await self._memory_client.recall(user_query)
|
recalled = await self._memory_client.recall(user_query)
|
||||||
|
elapsed = time.perf_counter() - started_at
|
||||||
|
logger.info(
|
||||||
|
"Memory recall completed in %.3fs (query_len=%s, memory_len=%s)",
|
||||||
|
elapsed,
|
||||||
|
len(user_query),
|
||||||
|
len(recalled),
|
||||||
|
)
|
||||||
|
return recalled
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Unexpected memory recall failure")
|
logger.exception(
|
||||||
|
"Unexpected memory recall failure after %.3fs",
|
||||||
|
time.perf_counter() - started_at,
|
||||||
|
)
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|
||||||
@ -140,8 +182,8 @@ async def entrypoint(ctx: JobContext) -> None:
|
|||||||
TTS_NUM_CHANNELS = _env_int("CUSTOM_TTS_NUM_CHANNELS", 1)
|
TTS_NUM_CHANNELS = _env_int("CUSTOM_TTS_NUM_CHANNELS", 1)
|
||||||
OUTPUT_SAMPLE_RATE = _env_int("CUSTOM_OUTPUT_SAMPLE_RATE", TTS_SAMPLE_RATE)
|
OUTPUT_SAMPLE_RATE = _env_int("CUSTOM_OUTPUT_SAMPLE_RATE", TTS_SAMPLE_RATE)
|
||||||
MEMORY_URL = os.getenv("CUSTOM_MEMORY_URL", "").strip()
|
MEMORY_URL = os.getenv("CUSTOM_MEMORY_URL", "").strip()
|
||||||
MEMORY_TIMEOUT = _env_float("CUSTOM_MEMORY_TIMEOUT", 10.0)
|
MEMORY_TIMEOUT = _env_float("CUSTOM_MEMORY_TIMEOUT", 2.0)
|
||||||
MEMORY_MAX_CHARS = _env_int("CUSTOM_MEMORY_MAX_CHARS", 8000)
|
MEMORY_MAX_CHARS = _env_int("CUSTOM_MEMORY_MAX_CHARS", 2000)
|
||||||
MEMORY_API_KEY = os.getenv("CUSTOM_MEMORY_API_KEY") or None
|
MEMORY_API_KEY = os.getenv("CUSTOM_MEMORY_API_KEY") or None
|
||||||
|
|
||||||
blackbox_stt = BlackboxSTT(
|
blackbox_stt = BlackboxSTT(
|
||||||
@ -199,7 +241,7 @@ async def entrypoint(ctx: JobContext) -> None:
|
|||||||
"false_interruption_timeout": 1.0,
|
"false_interruption_timeout": 1.0,
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
preemptive_generation=False,
|
preemptive_generation=_env_bool("CUSTOM_PREEMPTIVE_GENERATION", True),
|
||||||
aec_warmup_duration=3.0,
|
aec_warmup_duration=3.0,
|
||||||
tts_text_transforms=[
|
tts_text_transforms=[
|
||||||
"filter_emoji",
|
"filter_emoji",
|
||||||
@ -211,6 +253,17 @@ async def entrypoint(ctx: JobContext) -> None:
|
|||||||
def _on_metrics_collected(ev: MetricsCollectedEvent) -> None:
|
def _on_metrics_collected(ev: MetricsCollectedEvent) -> None:
|
||||||
metrics.log_metrics(ev.metrics)
|
metrics.log_metrics(ev.metrics)
|
||||||
|
|
||||||
|
@session.on("conversation_item_added")
|
||||||
|
def _on_conversation_item_added(event) -> None:
|
||||||
|
item = getattr(event, "item", None)
|
||||||
|
if not isinstance(item, ChatMessage):
|
||||||
|
return
|
||||||
|
|
||||||
|
if item.role == "user" and item.metrics:
|
||||||
|
logger.info("User turn metrics: %s", item.metrics)
|
||||||
|
elif item.role == "assistant" and item.metrics:
|
||||||
|
logger.info("Assistant turn metrics: %s", item.metrics)
|
||||||
|
|
||||||
memory_client = (
|
memory_client = (
|
||||||
MemoryRecallClient(
|
MemoryRecallClient(
|
||||||
url=MEMORY_URL,
|
url=MEMORY_URL,
|
||||||
|
|||||||
193
memory.py
193
memory.py
@ -3,6 +3,7 @@ from __future__ import annotations
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
@ -11,6 +12,23 @@ from livekit.agents import APIConnectionError, APIStatusError, APITimeoutError,
|
|||||||
|
|
||||||
logger = logging.getLogger("memory-recall")
|
logger = logging.getLogger("memory-recall")
|
||||||
|
|
||||||
|
_LOCATION_STOPWORDS = {
|
||||||
|
"哪里",
|
||||||
|
"在哪",
|
||||||
|
"在哪里",
|
||||||
|
"哪儿",
|
||||||
|
"位置",
|
||||||
|
"什么地方",
|
||||||
|
"帮我找",
|
||||||
|
"帮我寻找",
|
||||||
|
"找一下",
|
||||||
|
"找",
|
||||||
|
"请问",
|
||||||
|
"请",
|
||||||
|
"吗",
|
||||||
|
"呢",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class MemoryRecallClient:
|
class MemoryRecallClient:
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -100,27 +118,31 @@ def _format_room_graph_memory(payload: Any, query: str) -> str:
|
|||||||
if not objects and not relations and not summary:
|
if not objects and not relations and not summary:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
objects_text = json.dumps(objects, ensure_ascii=False, indent=2)
|
query_terms = _query_terms(query)
|
||||||
relations_text = json.dumps(relations, ensure_ascii=False, indent=2)
|
relevant_objects, relevant_relations = _relevant_room_graph(
|
||||||
|
objects=objects,
|
||||||
|
relations=relations,
|
||||||
|
query_terms=query_terms,
|
||||||
|
)
|
||||||
|
|
||||||
|
objects_text = json.dumps(
|
||||||
|
relevant_objects or _compact_items(objects, limit=12),
|
||||||
|
ensure_ascii=False,
|
||||||
|
separators=(",", ":"),
|
||||||
|
)
|
||||||
|
relations_text = json.dumps(
|
||||||
|
relevant_relations or _compact_items(relations, limit=24),
|
||||||
|
ensure_ascii=False,
|
||||||
|
separators=(",", ":"),
|
||||||
|
)
|
||||||
|
|
||||||
prompt = f"""
|
prompt = f"""
|
||||||
你是一个物品定位助手。
|
你是一个物品定位助手。
|
||||||
|
|
||||||
我的房间内有以下物品信息:
|
目标物品:{query}
|
||||||
|
相关物品:{objects_text}
|
||||||
{objects_text}
|
相关空间关系:{relations_text}
|
||||||
|
房间概览:{summary}
|
||||||
这些物品之间的空间关系如下:
|
|
||||||
|
|
||||||
{relations_text}
|
|
||||||
|
|
||||||
房间概览如下:
|
|
||||||
|
|
||||||
{summary}
|
|
||||||
|
|
||||||
现在我要找的目标物品是:{query}
|
|
||||||
|
|
||||||
请根据上面的 objects、relations 和 summary,告诉我它在哪里。
|
|
||||||
|
|
||||||
回答要求:
|
回答要求:
|
||||||
1. 只说明它和其他物品的位置关系。
|
1. 只说明它和其他物品的位置关系。
|
||||||
@ -131,5 +153,140 @@ def _format_room_graph_memory(payload: Any, query: str) -> str:
|
|||||||
6. 不要回答 right-near-low、left-far-high 这类区域标签,只回答“在……的左边/右边/上方/下方/前面/后面/附近”等相对关系。
|
6. 不要回答 right-near-low、left-far-high 这类区域标签,只回答“在……的左边/右边/上方/下方/前面/后面/附近”等相对关系。
|
||||||
7. 如果用户当前输入不是找物品或问位置,可以忽略这段房间记忆。
|
7. 如果用户当前输入不是找物品或问位置,可以忽略这段房间记忆。
|
||||||
""".strip()
|
""".strip()
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Formatted room memory: query_terms=%s, objects=%s/%s, relations=%s/%s, chars=%s",
|
||||||
|
query_terms,
|
||||||
|
len(relevant_objects),
|
||||||
|
len(objects) if isinstance(objects, list) else 0,
|
||||||
|
len(relevant_relations),
|
||||||
|
len(relations) if isinstance(relations, list) else 0,
|
||||||
|
len(prompt),
|
||||||
|
)
|
||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
|
|
||||||
|
def _query_terms(query: str) -> list[str]:
|
||||||
|
normalized = re.sub(r"[\s??。!,、,.!]", "", query)
|
||||||
|
for word in _LOCATION_STOPWORDS:
|
||||||
|
normalized = normalized.replace(word, "")
|
||||||
|
|
||||||
|
terms = [normalized] if normalized else []
|
||||||
|
for token in re.findall(r"[\u4e00-\u9fffA-Za-z0-9_-]{2,}", query):
|
||||||
|
if token not in _LOCATION_STOPWORDS and token not in terms:
|
||||||
|
terms.append(token)
|
||||||
|
return terms[:4]
|
||||||
|
|
||||||
|
|
||||||
|
def _relevant_room_graph(
|
||||||
|
*,
|
||||||
|
objects: Any,
|
||||||
|
relations: Any,
|
||||||
|
query_terms: list[str],
|
||||||
|
) -> tuple[list[Any], list[Any]]:
|
||||||
|
if not isinstance(objects, list) or not isinstance(relations, list) or not query_terms:
|
||||||
|
return [], []
|
||||||
|
|
||||||
|
matched_ids: set[str] = set()
|
||||||
|
matched_objects: list[Any] = []
|
||||||
|
object_by_id: dict[str, Any] = {}
|
||||||
|
|
||||||
|
for obj in objects:
|
||||||
|
obj_id = _object_id(obj)
|
||||||
|
if obj_id:
|
||||||
|
object_by_id[obj_id] = obj
|
||||||
|
|
||||||
|
obj_text = _compact_text(obj)
|
||||||
|
if any(term and term in obj_text for term in query_terms):
|
||||||
|
matched_objects.append(obj)
|
||||||
|
if obj_id:
|
||||||
|
matched_ids.add(obj_id)
|
||||||
|
|
||||||
|
relevant_relations: list[Any] = []
|
||||||
|
related_ids: set[str] = set(matched_ids)
|
||||||
|
for relation in relations:
|
||||||
|
relation_text = _compact_text(relation)
|
||||||
|
relation_ids = _ids_in_value(relation)
|
||||||
|
if (
|
||||||
|
any(term and term in relation_text for term in query_terms)
|
||||||
|
or bool(matched_ids.intersection(relation_ids))
|
||||||
|
):
|
||||||
|
relevant_relations.append(relation)
|
||||||
|
related_ids.update(relation_ids)
|
||||||
|
|
||||||
|
relevant_objects = list(matched_objects)
|
||||||
|
seen_object_keys = {_object_key(obj) for obj in relevant_objects}
|
||||||
|
for obj_id in related_ids:
|
||||||
|
obj = object_by_id.get(obj_id)
|
||||||
|
key = _object_key(obj)
|
||||||
|
if obj is not None and key not in seen_object_keys:
|
||||||
|
relevant_objects.append(obj)
|
||||||
|
seen_object_keys.add(key)
|
||||||
|
|
||||||
|
return _compact_items(relevant_objects, limit=16), _compact_items(relevant_relations, limit=32)
|
||||||
|
|
||||||
|
|
||||||
|
def _compact_items(items: Any, *, limit: int) -> list[Any]:
|
||||||
|
if not isinstance(items, list):
|
||||||
|
return []
|
||||||
|
return [_compact_item(item) for item in items[:limit]]
|
||||||
|
|
||||||
|
|
||||||
|
def _compact_item(item: Any) -> Any:
|
||||||
|
if not isinstance(item, dict):
|
||||||
|
return item
|
||||||
|
|
||||||
|
preferred_keys = (
|
||||||
|
"id",
|
||||||
|
"name",
|
||||||
|
"label",
|
||||||
|
"class",
|
||||||
|
"category",
|
||||||
|
"type",
|
||||||
|
"text",
|
||||||
|
"source",
|
||||||
|
"target",
|
||||||
|
"subject",
|
||||||
|
"object",
|
||||||
|
"relation",
|
||||||
|
"predicate",
|
||||||
|
"description",
|
||||||
|
)
|
||||||
|
compact = {key: item[key] for key in preferred_keys if key in item and item[key] not in (None, "")}
|
||||||
|
return compact or item
|
||||||
|
|
||||||
|
|
||||||
|
def _object_id(obj: Any) -> str | None:
|
||||||
|
if not isinstance(obj, dict):
|
||||||
|
return None
|
||||||
|
for key in ("id", "object_id", "uuid", "name", "label"):
|
||||||
|
value = obj.get(key)
|
||||||
|
if isinstance(value, (str, int)):
|
||||||
|
return str(value)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _object_key(obj: Any) -> str:
|
||||||
|
return _object_id(obj) or _compact_text(obj)
|
||||||
|
|
||||||
|
|
||||||
|
def _ids_in_value(value: Any) -> set[str]:
|
||||||
|
ids: set[str] = set()
|
||||||
|
if isinstance(value, dict):
|
||||||
|
for key, item in value.items():
|
||||||
|
if key in {"id", "object_id", "source", "target", "subject", "object", "from", "to"}:
|
||||||
|
if isinstance(item, (str, int)):
|
||||||
|
ids.add(str(item))
|
||||||
|
elif isinstance(item, dict):
|
||||||
|
obj_id = _object_id(item)
|
||||||
|
if obj_id:
|
||||||
|
ids.add(obj_id)
|
||||||
|
ids.update(_ids_in_value(item))
|
||||||
|
elif isinstance(value, list):
|
||||||
|
for item in value:
|
||||||
|
ids.update(_ids_in_value(item))
|
||||||
|
return ids
|
||||||
|
|
||||||
|
|
||||||
|
def _compact_text(value: Any) -> str:
|
||||||
|
return json.dumps(value, ensure_ascii=False, separators=(",", ":"))
|
||||||
|
|||||||
24
tts.py
24
tts.py
@ -3,6 +3,7 @@ from __future__ import annotations
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import time
|
||||||
import wave
|
import wave
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
@ -88,6 +89,7 @@ class BlackboxTTSStream(tts.ChunkedStream):
|
|||||||
self._tts: BlackboxTTS = tts
|
self._tts: BlackboxTTS = tts
|
||||||
|
|
||||||
async def _run(self, output_emitter: tts.AudioEmitter) -> None:
|
async def _run(self, output_emitter: tts.AudioEmitter) -> None:
|
||||||
|
started_at = time.perf_counter()
|
||||||
form = aiohttp.FormData(default_to_multipart=True)
|
form = aiohttp.FormData(default_to_multipart=True)
|
||||||
form.add_field("text", self.input_text)
|
form.add_field("text", self.input_text)
|
||||||
form.add_field("model_name", self._tts._model_name)
|
form.add_field("model_name", self._tts._model_name)
|
||||||
@ -131,6 +133,9 @@ class BlackboxTTSStream(tts.ChunkedStream):
|
|||||||
content_type = resp.headers.get("Content-Type", "audio/wav")
|
content_type = resp.headers.get("Content-Type", "audio/wav")
|
||||||
logged_wav_format = False
|
logged_wav_format = False
|
||||||
wav_header_probe = bytearray()
|
wav_header_probe = bytearray()
|
||||||
|
first_audio_at: float | None = None
|
||||||
|
chunk_count = 0
|
||||||
|
total_bytes = 0
|
||||||
output_emitter.initialize(
|
output_emitter.initialize(
|
||||||
request_id=utils.shortuuid(),
|
request_id=utils.shortuuid(),
|
||||||
sample_rate=self._tts.sample_rate,
|
sample_rate=self._tts.sample_rate,
|
||||||
@ -140,6 +145,16 @@ class BlackboxTTSStream(tts.ChunkedStream):
|
|||||||
|
|
||||||
async for data, _ in resp.content.iter_chunks():
|
async for data, _ in resp.content.iter_chunks():
|
||||||
if data:
|
if data:
|
||||||
|
chunk_count += 1
|
||||||
|
total_bytes += len(data)
|
||||||
|
if first_audio_at is None:
|
||||||
|
first_audio_at = time.perf_counter()
|
||||||
|
logger.info(
|
||||||
|
"TTS first audio chunk after %.3fs (text_len=%s, bytes=%s)",
|
||||||
|
first_audio_at - started_at,
|
||||||
|
len(self.input_text),
|
||||||
|
len(data),
|
||||||
|
)
|
||||||
if not logged_wav_format:
|
if not logged_wav_format:
|
||||||
wav_header_probe.extend(data)
|
wav_header_probe.extend(data)
|
||||||
logged_wav_format = _log_wav_format(
|
logged_wav_format = _log_wav_format(
|
||||||
@ -156,6 +171,15 @@ class BlackboxTTSStream(tts.ChunkedStream):
|
|||||||
logged_wav_format = True
|
logged_wav_format = True
|
||||||
output_emitter.push(data)
|
output_emitter.push(data)
|
||||||
output_emitter.flush()
|
output_emitter.flush()
|
||||||
|
finished_at = time.perf_counter()
|
||||||
|
logger.info(
|
||||||
|
"TTS stream completed in %.3fs (first_chunk=%.3fs, chunks=%s, bytes=%s, text_len=%s)",
|
||||||
|
finished_at - started_at,
|
||||||
|
(first_audio_at - started_at) if first_audio_at else -1.0,
|
||||||
|
chunk_count,
|
||||||
|
total_bytes,
|
||||||
|
len(self.input_text),
|
||||||
|
)
|
||||||
except asyncio.TimeoutError as e:
|
except asyncio.TimeoutError as e:
|
||||||
raise APITimeoutError("TTS blackbox request timed out") from e
|
raise APITimeoutError("TTS blackbox request timed out") from e
|
||||||
except aiohttp.ClientError as e:
|
except aiohttp.ClientError as e:
|
||||||
|
|||||||
Reference in New Issue
Block a user