397 lines
14 KiB
Python
397 lines
14 KiB
Python
import logging
|
|
import os
|
|
import time
|
|
from collections.abc import AsyncIterable
|
|
from pathlib import Path
|
|
|
|
from dotenv import load_dotenv
|
|
from memory import MemoryRecallClient
|
|
|
|
from asr import BlackboxSTT
|
|
from livekit.agents import (
|
|
Agent,
|
|
AgentServer,
|
|
AgentSession,
|
|
ChatContext,
|
|
ChatMessage,
|
|
FlushSentinel,
|
|
JobContext,
|
|
JobProcess,
|
|
MetricsCollectedEvent,
|
|
ModelSettings,
|
|
RecordingOptions,
|
|
TurnHandlingOptions,
|
|
cli,
|
|
llm,
|
|
metrics,
|
|
room_io,
|
|
stt,
|
|
)
|
|
from livekit.plugins import openai, silero
|
|
from livekit.plugins.turn_detector.multilingual import MultilingualModel
|
|
from tts import BlackboxTTS
|
|
|
|
logger = logging.getLogger("custom-agent")

# The agent's configuration lives in a .env file that sits next to this module,
# so resolve it relative to __file__ instead of the current working directory.
CUSTOM_ENV_PATH = Path(__file__).parent / ".env"
load_dotenv(dotenv_path=CUSTOM_ENV_PATH)

# Agent name passed to `server.rtc_session`; empty string when not configured.
AGENT_NAME = os.getenv("CUSTOM_AGENT_NAME", "")
|
|
|
|
# System prompt for the room item-location assistant (kept in Chinese because
# the assistant answers in Chinese). English gloss:
#   You are a room item-location assistant.
#   When the user asks where an item in the room is:
#   - answer in a single Chinese sentence
#   - describe the target item's position relative to other items
#   - do not use Markdown, emoji, lists, headings or coordinate/region labels
#   - do not explain your reasoning
#   If the question is unrelated to locating room items, answer it normally.
ROOM_LOCATOR_INSTRUCTIONS = "\n".join(
    (
        "你是一个房间物品定位助手。",
        "当用户询问房间内某个物品的位置时:",
        "- 只用一句中文回答",
        "- 描述目标物品和其他物品的相对位置关系",
        "- 不要使用 Markdown、emoji、列表、标题、坐标区域标签",
        "- 不要解释推理过程",
        "如果用户的问题与房间物品定位无关,则正常回答用户问题。",
    )
)
|
|
|
|
class CustomAgent(Agent):
    """Voice agent that answers room item-location questions.

    Before every LLM call it optionally recalls room memory for the latest user
    utterance; when something is recalled, the latest user message is replaced
    with it (see ``_with_memory_as_latest_user_message``). LLM latency (time to
    first chunk, total time, chunk count) is logged for each call.
    """

    def __init__(self, *, memory_client: MemoryRecallClient | None = None) -> None:
        """Create the agent with the room-locator system prompt.

        Args:
            memory_client: Optional recall client. When ``None``, the chat context
                is passed to the LLM unchanged.
        """
        super().__init__(instructions=ROOM_LOCATOR_INSTRUCTIONS)
        self._memory_client = memory_client

    async def on_enter(self) -> None:
        """Deliberately do nothing, so the agent waits for the user to speak first.

        To greet on entry instead, call e.g.
        ``self.session.generate_reply(instructions="greet the user and introduce yourself")``.
        """

    async def llm_node(
        self,
        chat_ctx: ChatContext,
        tools: list[llm.Tool],
        model_settings: ModelSettings,
    ) -> AsyncIterable[llm.ChatChunk | str | FlushSentinel]:
        """Run the default LLM node with memory injection and latency logging.

        Args:
            chat_ctx: Conversation so far; copied (not mutated) when memory is injected.
            tools: Tools forwarded unchanged to the default LLM node.
            model_settings: Settings forwarded unchanged to the default LLM node.

        Returns:
            The default node's stream wrapped to log first-chunk and total latency,
            or the default node's result as-is when it is not an async iterable.
        """
        import inspect

        llm_node_started_at = time.perf_counter()

        memory_context = await self._recall_room_memory(chat_ctx)
        if memory_context:
            chat_ctx = _with_memory_as_latest_user_message(chat_ctx, memory_context)

        llm_result = Agent.default.llm_node(self, chat_ctx, tools, model_settings)
        # If the default node hands back an awaitable instead of a stream, resolve
        # it here: returning an un-awaited coroutine from this `async def` would
        # give the framework a coroutine object rather than text or a stream.
        if inspect.isawaitable(llm_result):
            llm_result = await llm_result

        if not hasattr(llm_result, "__aiter__"):
            elapsed = time.perf_counter() - llm_node_started_at
            logger.info("LLM node completed without streaming in %.3fs", elapsed)
            return llm_result

        async def _instrumented_stream() -> AsyncIterable[llm.ChatChunk | str | FlushSentinel]:
            first_chunk_at: float | None = None
            chunk_count = 0
            try:
                async for chunk in llm_result:
                    chunk_count += 1
                    if first_chunk_at is None:
                        first_chunk_at = time.perf_counter()
                        logger.info(
                            "LLM first chunk after %.3fs",
                            first_chunk_at - llm_node_started_at,
                        )
                    yield chunk
            finally:
                finished_at = time.perf_counter()
                logger.info(
                    "LLM stream completed in %.3fs (first_chunk=%.3fs, chunks=%s)",
                    finished_at - llm_node_started_at,
                    # -1.0 marks "no chunk was ever produced".
                    (first_chunk_at - llm_node_started_at) if first_chunk_at is not None else -1.0,
                    chunk_count,
                )
                # When our consumer stops early (this generator is closed at the
                # `yield`), `async for` does not close the inner stream. Close it
                # explicitly so the underlying LLM request is released now rather
                # than whenever the inner generator is garbage-collected. This is
                # a no-op if the inner stream already finished.
                inner_aclose = getattr(llm_result, "aclose", None)
                if inner_aclose is not None:
                    await inner_aclose()

        return _instrumented_stream()

    async def _recall_room_memory(self, chat_ctx: ChatContext) -> str:
        """Fetch room memory relevant to the latest user utterance (best effort).

        Returns "" when no memory client is configured, when the latest user
        message has no text, when recall yields nothing, or when recall raises
        (the failure is logged). ``asyncio.CancelledError`` is not an
        ``Exception`` subclass, so cancellation still propagates.
        """
        if self._memory_client is None:
            return ""

        user_query = _latest_user_text(chat_ctx)
        if not user_query:
            return ""

        started_at = time.perf_counter()
        try:
            # `or ""` normalizes a client that signals "no memory" with None.
            recalled = await self._memory_client.recall(user_query) or ""
            elapsed = time.perf_counter() - started_at
            logger.info(
                "Memory recall completed in %.3fs (query_len=%s, memory_len=%s)",
                elapsed,
                len(user_query),
                len(recalled),
            )
            return recalled
        except Exception:
            # Memory is an optional enhancement: log and continue without it.
            logger.exception(
                "Unexpected memory recall failure after %.3fs",
                time.perf_counter() - started_at,
            )
            return ""
|
|
|
|
|
|
def _latest_user_text(chat_ctx: ChatContext) -> str:
|
|
for item in reversed(chat_ctx.items):
|
|
if isinstance(item, ChatMessage) and item.role == "user":
|
|
return (item.text_content or "").strip()
|
|
return ""
|
|
|
|
|
|
def _with_memory_as_latest_user_message(chat_ctx: ChatContext, memory_context: str) -> ChatContext:
    """Return a copy of ``chat_ctx`` whose latest user message is ``memory_context``.

    The newest user ``ChatMessage`` is deep-copied and its content replaced by
    ``[memory_context]``, so the caller's context and message objects are left
    untouched. If the context has no user message, a new user message carrying
    ``memory_context`` is appended instead.

    NOTE(review): the user's own text is *replaced*, not augmented. This assumes
    the recall service returns a prompt that already embeds the user query;
    confirm against MemoryRecallClient.
    """
    patched_ctx = chat_ctx.copy()
    entries = patched_ctx.items

    target_index = next(
        (
            position
            for position in reversed(range(len(entries)))
            if isinstance(entries[position], ChatMessage) and entries[position].role == "user"
        ),
        None,
    )

    if target_index is None:
        entries.append(ChatMessage(role="user", content=[memory_context]))
        return patched_ctx

    replacement = entries[target_index].model_copy(deep=True)
    replacement.content = [memory_context]
    entries[target_index] = replacement
    return patched_ctx
|
|
|
|
|
|
server = AgentServer()


def prewarm(proc: JobProcess) -> None:
    """Process setup hook: load the Silero VAD into ``proc.userdata["vad"]``.

    ``entrypoint`` reads the model back from ``ctx.proc.userdata["vad"]``, so it
    is loaded here once rather than on every session.
    """
    vad_model = silero.VAD.load()
    proc.userdata["vad"] = vad_model


server.setup_fnc = prewarm
|
|
|
|
|
|
@server.rtc_session(agent_name=AGENT_NAME)
async def entrypoint(ctx: JobContext) -> None:
    """Build the voice pipeline for one room session and start it.

    Wires the blackbox ASR (wrapped in a VAD-driven ``StreamAdapter``), an
    OpenAI-compatible LLM, the blackbox TTS and optional room-memory recall into
    an ``AgentSession`` and starts it in ``ctx.room``. Endpoints and tuning knobs
    come from ``CUSTOM_*`` (and some ``VOXCPM_*``) environment variables.

    Raises:
        RuntimeError: if ``CUSTOM_LLM_API_KEY`` is not set.
    """
    import httpx
    from openai import AsyncClient as OpenAIAsyncClient

    ctx.log_context_fields = {
        "room": ctx.room.name,
    }

    # Configuration for custom local endpoints. These can be set in your .env file.
    ASR_URL = os.getenv("CUSTOM_ASR_URL", "http://10.6.80.21:5003/asr-blackbox")
    ASR_MODEL = os.getenv("CUSTOM_ASR_MODEL", "sensevoice")
    ASR_LANGUAGE = os.getenv("CUSTOM_ASR_LANGUAGE", "auto")
    ASR_OUTPUT_LANGUAGE = os.getenv("CUSTOM_ASR_OUTPUT_LANGUAGE", "zh")

    LLM_BASE_URL = os.getenv("CUSTOM_LLM_BASE_URL")
    LLM_MODEL = os.getenv("CUSTOM_LLM_MODEL", "qwen-max")
    LLM_API_KEY = os.getenv("CUSTOM_LLM_API_KEY")
    if not LLM_API_KEY:
        raise RuntimeError(f"CUSTOM_LLM_API_KEY is not set in {CUSTOM_ENV_PATH}")

    TTS_URL = os.getenv("CUSTOM_TTS_URL") or os.getenv(
        "VOXCPM_TTS_URL", "http://localhost:5050/tts-blackbox"
    )
    TTS_MODEL = os.getenv("CUSTOM_TTS_MODEL") or os.getenv("VOXCPM_TTS_MODEL", "voxcpmtts")
    TTS_SAMPLE_RATE = _env_int("CUSTOM_TTS_SAMPLE_RATE", 16000)
    TTS_NUM_CHANNELS = _env_int("CUSTOM_TTS_NUM_CHANNELS", 1)
    # Room audio output defaults to the TTS sample rate unless overridden.
    OUTPUT_SAMPLE_RATE = _env_int("CUSTOM_OUTPUT_SAMPLE_RATE", TTS_SAMPLE_RATE)
    MEMORY_URL = os.getenv("CUSTOM_MEMORY_URL", "").strip()
    MEMORY_TIMEOUT = _env_float("CUSTOM_MEMORY_TIMEOUT", 2.0)
    MEMORY_MAX_CHARS = _env_int("CUSTOM_MEMORY_MAX_CHARS", 2000)
    MEMORY_API_KEY = os.getenv("CUSTOM_MEMORY_API_KEY") or None

    blackbox_stt = BlackboxSTT(
        url=ASR_URL,
        model_name=ASR_MODEL,
        language=ASR_LANGUAGE,
        output_language=ASR_OUTPUT_LANGUAGE,
        hotwords=os.getenv("CUSTOM_ASR_HOTWORDS"),
        itn=os.getenv("CUSTOM_ASR_ITN"),
        chunk_mode=os.getenv("CUSTOM_ASR_CHUNK_MODE"),
    )
    # StreamAdapter drives the blackbox ASR (presumably non-streaming) with the
    # shared Silero VAD loaded in prewarm, exposing it as a streaming STT.
    stt_stream = stt.StreamAdapter(stt=blackbox_stt, vad=ctx.proc.userdata["vad"])

    # OpenAI-compatible endpoints can be used by setting CUSTOM_LLM_BASE_URL.
    verify_ssl = _env_bool("CUSTOM_LLM_VERIFY_SSL", False)
    if not verify_ssl:
        # Verification stays off by default (e.g. self-signed local endpoints),
        # but make that security-relevant choice visible in the logs.
        logger.warning(
            "TLS certificate verification is disabled for the LLM endpoint; "
            "set CUSTOM_LLM_VERIFY_SSL=true to enable it"
        )
    http_client = httpx.AsyncClient(verify=verify_ssl)
    # This client is created per job; close it when the job shuts down so its
    # connection pool is not leaked across sessions. aclose() is idempotent.
    ctx.add_shutdown_callback(http_client.aclose)

    openai_client = OpenAIAsyncClient(
        api_key=LLM_API_KEY,
        # Unset/empty CUSTOM_LLM_BASE_URL -> None, i.e. the client's own default,
        # exactly as if base_url were omitted.
        base_url=LLM_BASE_URL or None,
        http_client=http_client,
    )

    session: AgentSession = AgentSession(
        # 1. Custom ASR blackbox with StreamAdapter
        stt=stt_stream,
        # 2. OpenAI-compatible LLM, e.g. MiniMax, Qwen, or OpenAI.
        llm=openai.LLM(
            model=LLM_MODEL,
            client=openai_client,
        ),
        # 3. TTS blackbox
        tts=BlackboxTTS(
            url=TTS_URL,
            model_name=TTS_MODEL,
            params=_tts_params_from_env(TTS_MODEL),
            prompt_wav_path=_tts_prompt_wav_from_env(TTS_MODEL),
            sample_rate=TTS_SAMPLE_RATE,
            num_channels=TTS_NUM_CHANNELS,
        ),
        # 4. Silero VAD
        vad=ctx.proc.userdata["vad"],
        turn_handling=TurnHandlingOptions(
            turn_detection=MultilingualModel(),
            interruption={
                "resume_false_interruption": True,
                "false_interruption_timeout": 1.0,
            },
        ),
        preemptive_generation=_env_bool("CUSTOM_PREEMPTIVE_GENERATION", True),
        aec_warmup_duration=3.0,
        tts_text_transforms=[
            "filter_emoji",
            "filter_markdown",
        ],
    )

    @session.on("metrics_collected")
    def _on_metrics_collected(ev: MetricsCollectedEvent) -> None:
        metrics.log_metrics(ev.metrics)

    @session.on("conversation_item_added")
    def _on_conversation_item_added(event) -> None:
        # Log per-turn metrics attached to user/assistant chat messages only.
        item = getattr(event, "item", None)
        if not isinstance(item, ChatMessage):
            return

        if item.role == "user" and item.metrics:
            logger.info("User turn metrics: %s", item.metrics)
        elif item.role == "assistant" and item.metrics:
            logger.info("Assistant turn metrics: %s", item.metrics)

    # Memory recall is optional: only enabled when CUSTOM_MEMORY_URL is non-blank.
    # NOTE(review): if MemoryRecallClient holds network resources, register its
    # close hook with ctx.add_shutdown_callback as well.
    memory_client = (
        MemoryRecallClient(
            url=MEMORY_URL,
            timeout=MEMORY_TIMEOUT,
            max_chars=MEMORY_MAX_CHARS,
            api_key=MEMORY_API_KEY,
        )
        if MEMORY_URL
        else None
    )

    await session.start(
        agent=CustomAgent(memory_client=memory_client),
        room=ctx.room,
        room_options=room_io.RoomOptions(
            audio_output=room_io.AudioOutputOptions(
                sample_rate=OUTPUT_SAMPLE_RATE,
                num_channels=TTS_NUM_CHANNELS,
            ),
        ),
        record=_recording_options_from_env(),
    )
|
|
|
|
|
|
def _tts_params_from_env(model_name: str) -> dict[str, str]:
|
|
params: dict[str, str] = {}
|
|
model_name = model_name.lower()
|
|
|
|
if model_name == "voxcpmtts":
|
|
_set_if_present(params, "streaming", os.getenv("CUSTOM_TTS_STREAMING"))
|
|
_set_if_present(
|
|
params,
|
|
"prompt_text",
|
|
os.getenv("CUSTOM_TTS_PROMPT_TEXT") or os.getenv("VOXCPM_PROMPT_TEXT"),
|
|
)
|
|
_set_if_present(params, "cfg_value", os.getenv("VOXCPM_CFG_VALUE"))
|
|
_set_if_present(params, "inference_timesteps", os.getenv("VOXCPM_INFERENCE_TIMESTEPS"))
|
|
_set_if_present(params, "do_normalize", os.getenv("VOXCPM_DO_NORMALIZE"))
|
|
_set_if_present(params, "denoise", os.getenv("VOXCPM_DENOISE"))
|
|
_set_if_present(params, "retry_badcase", os.getenv("VOXCPM_RETRY_BADCASE"))
|
|
_set_if_present(
|
|
params,
|
|
"retry_badcase_max_times",
|
|
os.getenv("VOXCPM_RETRY_BADCASE_MAX_TIMES"),
|
|
)
|
|
_set_if_present(
|
|
params,
|
|
"retry_badcase_ratio_threshold",
|
|
os.getenv("VOXCPM_RETRY_BADCASE_RATIO_THRESHOLD"),
|
|
)
|
|
elif model_name == "melotts":
|
|
_set_if_present(params, "speed", os.getenv("CUSTOM_TTS_SPEED"))
|
|
elif model_name == "cosyvoicetts":
|
|
_set_if_present(params, "spk_id", os.getenv("CUSTOM_TTS_SPK_ID"))
|
|
_set_if_present(params, "model", os.getenv("CUSTOM_TTS_MODE"))
|
|
_set_if_present(params, "prompt_text", os.getenv("CUSTOM_TTS_PROMPT_TEXT"))
|
|
_set_if_present(params, "instruct_text", os.getenv("CUSTOM_TTS_INSTRUCT_TEXT"))
|
|
elif model_name == "sovitstts":
|
|
_set_if_present(params, "text_lang", os.getenv("CUSTOM_TTS_TEXT_LANG"))
|
|
_set_if_present(params, "prompt_lang", os.getenv("CUSTOM_TTS_PROMPT_LANG"))
|
|
_set_if_present(params, "text_split_method", os.getenv("CUSTOM_TTS_TEXT_SPLIT_METHOD"))
|
|
_set_if_present(params, "batch_size", os.getenv("CUSTOM_TTS_BATCH_SIZE"))
|
|
_set_if_present(params, "media_type", os.getenv("CUSTOM_TTS_MEDIA_TYPE"))
|
|
_set_if_present(params, "streaming_mode", os.getenv("CUSTOM_TTS_STREAMING"))
|
|
_set_if_present(params, "ref_audio_path", os.getenv("CUSTOM_TTS_REF_AUDIO_PATH"))
|
|
_set_if_present(params, "prompt_text", os.getenv("CUSTOM_TTS_PROMPT_TEXT"))
|
|
|
|
return params
|
|
|
|
|
|
def _tts_prompt_wav_from_env(model_name: str) -> str | None:
|
|
if model_name.lower() != "voxcpmtts":
|
|
return None
|
|
|
|
return os.getenv("CUSTOM_TTS_PROMPT_WAV") or os.getenv("VOXCPM_PROMPT_WAV") or None
|
|
|
|
|
|
def _set_if_present(params: dict[str, str], key: str, value: str | None) -> None:
|
|
if value:
|
|
params[key] = value
|
|
|
|
|
|
def _env_int(name: str, default: int) -> int:
|
|
value = os.getenv(name)
|
|
if not value:
|
|
return default
|
|
try:
|
|
return int(value)
|
|
except ValueError:
|
|
logger.warning("Invalid integer for %s=%r, using %s", name, value, default)
|
|
return default
|
|
|
|
|
|
def _env_float(name: str, default: float) -> float:
|
|
value = os.getenv(name)
|
|
if not value:
|
|
return default
|
|
try:
|
|
return float(value)
|
|
except ValueError:
|
|
logger.warning("Invalid float for %s=%r, using %s", name, value, default)
|
|
return default
|
|
|
|
|
|
def _env_bool(name: str, default: bool) -> bool:
|
|
value = os.getenv(name)
|
|
if value is None:
|
|
return default
|
|
|
|
normalized = value.strip().lower()
|
|
if normalized in {"1", "true", "yes", "on"}:
|
|
return True
|
|
if normalized in {"0", "false", "no", "off"}:
|
|
return False
|
|
|
|
logger.warning("Invalid boolean for %s=%r, using %s", name, value, default)
|
|
return default
|
|
|
|
|
|
def _recording_options_from_env() -> RecordingOptions:
    """Build session recording options from ``CUSTOM_RECORD_*`` flags.

    Every category (audio, traces, logs, transcript) is off unless its flag is
    set to a truthy value understood by ``_env_bool``.
    """
    flag_env_names = (
        ("audio", "CUSTOM_RECORD_AUDIO"),
        ("traces", "CUSTOM_RECORD_TRACES"),
        ("logs", "CUSTOM_RECORD_LOGS"),
        ("transcript", "CUSTOM_RECORD_TRANSCRIPT"),
    )
    enabled = {option: _env_bool(env_name, False) for option, env_name in flag_env_names}
    return RecordingOptions(**enabled)
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: hand the configured AgentServer (prewarm hook plus the
    # `entrypoint` RTC session) to the LiveKit agents CLI.
    cli.run_app(server)
|