Files
livekit_agents/start_agent_profiles.py
2026-06-04 15:54:09 +08:00

265 lines
7.3 KiB
Python

from __future__ import annotations
import argparse
import asyncio
import os
import signal
import sys
from dataclasses import dataclass
from pathlib import Path
from dotenv import load_dotenv
# The .env file is looked up in this script's own directory, and it is loaded at
# import time so the os.getenv() defaults read later in this module see its values.
CUSTOM_ENV_PATH = Path(__file__).parent / ".env"
load_dotenv(dotenv_path=CUSTOM_ENV_PATH)
@dataclass(frozen=True)
class AgentProfile:
    """Settings that distinguish one launchable agent worker from another.

    Instances are immutable (``frozen=True``). ``_child_env`` exports each
    field to the profile's child process as an environment variable.
    """

    # Exported as CUSTOM_AGENT_NAME.
    agent_name: str
    # Exported as CUSTOM_LLM_PROVIDER; AGENT_PROFILES uses "openai-compatible"
    # and "beaver".
    llm_provider: str
    # Exported as CUSTOM_AGENT_INPUT_MODE; AGENT_PROFILES uses "voice" and
    # "vision_voice".
    input_mode: str
# Registry of launchable profiles, keyed by the names accepted by --profiles /
# CUSTOM_AGENT_PROFILES. Insertion order matters: it is the order used for
# "all" and for the list of valid names shown in error messages.
AGENT_PROFILES: dict[str, AgentProfile] = {
    # Voice-only agents.
    "normal": AgentProfile(
        agent_name="normal-agent",
        llm_provider="openai-compatible",
        input_mode="voice",
    ),
    "beaver": AgentProfile(
        agent_name="beaver-agent",
        llm_provider="beaver",
        input_mode="voice",
    ),
    # Vision + voice agents.
    "vision-normal": AgentProfile(
        agent_name="vision-normal-agent",
        llm_provider="openai-compatible",
        input_mode="vision_voice",
    ),
    "vision-beaver": AgentProfile(
        agent_name="vision-beaver-agent",
        llm_provider="beaver",
        input_mode="vision_voice",
    ),
}
# Profiles started when neither --profiles nor CUSTOM_AGENT_PROFILES is given.
DEFAULT_PROFILES = ("normal", "beaver")
def _env_bool(name: str, default: bool) -> bool:
    """Read environment variable *name* as a boolean flag.

    ``1``/``true``/``yes``/``on`` mean True and ``0``/``false``/``no``/``off``
    mean False. Matching ignores case and surrounding whitespace. An unset or
    unrecognised value yields *default*.
    """
    raw = os.getenv(name)
    if raw is None:
        return default
    spellings = {
        "1": True,
        "true": True,
        "yes": True,
        "on": True,
        "0": False,
        "false": False,
        "no": False,
        "off": False,
    }
    return spellings.get(raw.strip().lower(), default)
def _parse_profiles(value: str) -> list[str]:
    """Turn a comma-separated profile list into unique profile names.

    Tokens are matched case-insensitively, and ``_`` is accepted in place of
    ``-`` (so ``Vision_Beaver`` selects ``vision-beaver``). Empty tokens are
    ignored. Duplicates are dropped, keeping first-seen order.

    The token ``all`` selects every profile in ``AGENT_PROFILES`` order. It may
    be the whole value or appear anywhere in the list; any other tokens must
    still be valid profile names.

    Raises:
        ValueError: for an unknown profile name, or when no profile is given.
    """
    selected: list[str] = []
    wants_all = False
    for raw_profile in value.split(","):
        profile = raw_profile.strip().lower().replace("_", "-")
        if not profile:
            continue
        if profile == "all":
            # "all" is advertised as a valid value (see the error message
            # below), so honour it inside a list as well, not only on its own.
            wants_all = True
            continue
        if profile not in AGENT_PROFILES:
            valid = ", ".join([*AGENT_PROFILES, "all"])
            raise ValueError(f"unknown profile {raw_profile!r}; valid values: {valid}")
        selected.append(profile)
    if wants_all:
        return list(AGENT_PROFILES)
    if not selected:
        raise ValueError("at least one profile is required")
    # dict.fromkeys de-duplicates while preserving first-seen order.
    return list(dict.fromkeys(selected))
def _parse_http_port_base(value: str | None) -> int:
if value is None or not value.strip():
return 0
try:
port = int(value)
except ValueError as exc:
raise ValueError(f"invalid HTTP port base {value!r}") from exc
if port < 0 or port > 65535:
raise ValueError(f"invalid HTTP port base {value!r}; expected 0-65535")
return port
def _profile_http_port(http_port_base: int, index: int) -> int:
    """Return the HTTP port for the *index*-th profile.

    A base of 0 is passed through unchanged for every profile. Otherwise
    profiles get consecutive ports starting at *http_port_base*.

    Raises:
        ValueError: if the resulting port would exceed 65535.
    """
    if not http_port_base:
        return 0
    candidate = http_port_base + index
    if candidate <= 65535:
        return candidate
    raise ValueError("HTTP port range exceeds 65535")
def _child_env(profile_name: str, *, http_port: int) -> dict[str, str]:
    """Build the environment for *profile_name*'s child process.

    The result is a copy of this process's environment with the profile's
    settings and *http_port* overlaid as ``CUSTOM_*`` variables.

    Raises:
        KeyError: if *profile_name* is not in ``AGENT_PROFILES``.
    """
    selected = AGENT_PROFILES[profile_name]
    overrides = {
        "CUSTOM_AGENT_PROFILE": profile_name,
        "CUSTOM_AGENT_NAME": selected.agent_name,
        "CUSTOM_LLM_PROVIDER": selected.llm_provider,
        "CUSTOM_AGENT_INPUT_MODE": selected.input_mode,
        "CUSTOM_AGENT_HTTP_PORT": str(http_port),
    }
    return {**os.environ, **overrides}
async def _pipe_output(prefix: str, stream: asyncio.StreamReader) -> None:
    """Echo each line of *stream* to stdout as ``[prefix] line`` until EOF.

    Lines are decoded as UTF-8, with undecodable bytes replaced, and trailing
    whitespace is stripped.

    A line longer than the reader's buffer limit makes ``readline()`` raise
    ``ValueError``, and the oversized data is lost. That case is reported and
    forwarding continues. Letting the exception end this task would stop
    draining the child's pipe, and the child would block once the OS pipe
    buffer filled.
    """
    while True:
        try:
            line = await stream.readline()
        except ValueError:
            print(f"[{prefix}] <output line too long; dropped>", flush=True)
            continue
        if not line:
            return  # EOF
        text = line.decode("utf-8", errors="replace").rstrip()
        print(f"[{prefix}] {text}", flush=True)
# Strong references to the output-forwarding tasks. The event loop keeps only
# weak references to tasks, so an unreferenced task may be garbage-collected
# before it finishes. Each task removes itself from the set when it is done.
_PIPE_TASKS: set[asyncio.Task[None]] = set()


async def _start_profile(
    profile_name: str,
    *,
    mode: str,
    http_port: int,
    reload: bool,
) -> asyncio.subprocess.Process:
    """Spawn ``custom_agent.py <mode>`` for *profile_name* and forward its output.

    The child runs under the current interpreter with the environment from
    ``_child_env``. In ``dev`` mode, ``--no-reload`` is appended unless *reload*
    is true. The child's stdout and stderr are piped. Background tasks echo them
    line by line, prefixed with ``[profile_name]``.

    Returns:
        The started process.

    Raises:
        KeyError: if *profile_name* is not in ``AGENT_PROFILES``.
    """
    profile = AGENT_PROFILES[profile_name]
    script_path = Path(__file__).with_name("custom_agent.py")
    args = [sys.executable, str(script_path), mode]
    if mode == "dev" and not reload:
        args.append("--no-reload")
    process = await asyncio.create_subprocess_exec(
        *args,
        env=_child_env(profile_name, http_port=http_port),
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )
    print(
        f"started {profile_name}: pid={process.pid} agent={profile.agent_name} "
        f"llm={profile.llm_provider} input={profile.input_mode} http_port={http_port}",
        flush=True,
    )
    for stream in (process.stdout, process.stderr):
        if stream is None:
            continue
        task = asyncio.create_task(_pipe_output(profile_name, stream))
        _PIPE_TASKS.add(task)
        task.add_done_callback(_PIPE_TASKS.discard)
    return process
async def _terminate(
    processes: list[asyncio.subprocess.Process],
    *,
    grace_period: float = 10.0,
) -> None:
    """Stop every still-running process and wait until all have exited.

    Each running process is sent ``terminate()`` (SIGTERM on POSIX). Any
    process still running after *grace_period* seconds is sent ``kill()``
    (SIGKILL on POSIX). Processes that already exited are only waited on.
    """
    for process in processes:
        if process.returncode is None:
            try:
                process.terminate()
            except ProcessLookupError:
                pass  # exited between the returncode check and the signal
    try:
        await asyncio.wait_for(
            asyncio.gather(*(process.wait() for process in processes)),
            timeout=grace_period,
        )
    except asyncio.TimeoutError:
        for process in processes:
            if process.returncode is None:
                try:
                    process.kill()
                except ProcessLookupError:
                    pass  # exited just before the escalation
        await asyncio.gather(*(process.wait() for process in processes))
async def _run(
    profiles: list[str],
    *,
    mode: str,
    http_port_base: int,
    reload: bool,
) -> int:
    """Launch *profiles* as child processes and supervise them.

    Returns when any child exits or SIGINT/SIGTERM is received. Every remaining
    child is then stopped via ``_terminate``. That cleanup also runs if startup
    fails part-way or this coroutine is cancelled, so started children are not
    left orphaned.

    Returns:
        0 when stopped by a signal. Otherwise, the exit status of the child
        that exited first; a child killed by signal N is reported as 128 + N
        (the shell convention) instead of a negative number.

    Raises:
        ValueError: if the HTTP port range exceeds 65535. This is raised
            before any child is started.
    """
    # Resolve every port before spawning anything. Computing them lazily would
    # raise only after earlier profiles were already running, orphaning them.
    ports = [_profile_http_port(http_port_base, index) for index in range(len(profiles))]

    loop = asyncio.get_running_loop()
    stop_event = asyncio.Event()
    installed_signals: list[signal.Signals] = []
    for sig in (signal.SIGINT, signal.SIGTERM):
        try:
            loop.add_signal_handler(sig, stop_event.set)
        except NotImplementedError:
            # Event loops without signal-handler support (e.g. on Windows):
            # fall back to Python's default signal handling.
            continue
        installed_signals.append(sig)

    processes: list[asyncio.subprocess.Process] = []
    try:
        for profile, http_port in zip(profiles, ports):
            if stop_event.is_set():
                break  # stop requested while earlier profiles were starting
            processes.append(
                await _start_profile(
                    profile,
                    mode=mode,
                    http_port=http_port,
                    reload=reload,
                )
            )

        wait_tasks = {asyncio.create_task(process.wait()): process for process in processes}
        stop_task = asyncio.create_task(stop_event.wait())
        done, pending = await asyncio.wait(
            [*wait_tasks, stop_task],
            return_when=asyncio.FIRST_COMPLETED,
        )
        for task in pending:
            task.cancel()

        exit_code = 0
        if stop_task not in done:
            for task in done:
                process = wait_tasks[task]
                returncode = task.result()
                print(
                    f"agent profile process exited: pid={process.pid} code={returncode}",
                    flush=True,
                )
                # A negative returncode means "killed by signal -returncode".
                exit_code = returncode if returncode >= 0 else 128 - returncode
        return exit_code
    finally:
        for sig in installed_signals:
            loop.remove_signal_handler(sig)
        await _terminate(processes)
def main() -> None:
    """Parse CLI arguments and run the selected agent profiles until they stop.

    Defaults for ``--profiles``, ``--http-port-base`` and ``--reload`` come from
    CUSTOM_AGENT_PROFILES, CUSTOM_AGENT_HTTP_PORT_BASE and
    CUSTOM_AGENT_DEV_RELOAD. Invalid values are reported as argparse usage
    errors before any agent is started. Exits with the code returned by
    ``_run``.
    """
    parser = argparse.ArgumentParser(description="Start multiple custom LiveKit agent profiles.")
    parser.add_argument(
        "mode",
        nargs="?",
        default="dev",
        choices=("console", "dev", "start", "connect"),
        help="custom_agent.py CLI mode to run for each profile",
    )
    parser.add_argument(
        "--profiles",
        default=os.getenv("CUSTOM_AGENT_PROFILES", ",".join(DEFAULT_PROFILES)),
        help="comma-separated profiles to start, or 'all'",
    )
    parser.add_argument(
        "--http-port-base",
        default=os.getenv("CUSTOM_AGENT_HTTP_PORT_BASE", "0"),
        help=(
            "base HTTP health-check port for profile workers; "
            "0 lets the OS assign free ports"
        ),
    )
    # --reload / --no-reload share one destination, so a CUSTOM_AGENT_DEV_RELOAD
    # default of true can still be switched off from the command line.
    default_reload = _env_bool("CUSTOM_AGENT_DEV_RELOAD", False)
    reload_flags = parser.add_mutually_exclusive_group()
    reload_flags.add_argument(
        "--reload",
        dest="reload",
        action="store_true",
        default=default_reload,
        help="enable auto-reload in dev mode",
    )
    reload_flags.add_argument(
        "--no-reload",
        dest="reload",
        action="store_false",
        default=default_reload,
        help="disable auto-reload in dev mode (overrides CUSTOM_AGENT_DEV_RELOAD)",
    )
    args = parser.parse_args()
    try:
        profiles = _parse_profiles(args.profiles)
        http_port_base = _parse_http_port_base(args.http_port_base)
        # Check the highest port now, so an out-of-range base is a usage error
        # rather than a failure after some agents have already been started.
        # _parse_profiles guarantees at least one profile.
        _profile_http_port(http_port_base, len(profiles) - 1)
    except ValueError as exc:
        parser.error(str(exc))
    raise SystemExit(
        asyncio.run(
            _run(
                profiles,
                mode=args.mode,
                http_port_base=http_port_base,
                reload=args.reload,
            )
        )
    )
if __name__ == "__main__":
main()