"""Launch several ``custom_agent.py`` agent profiles and supervise them.

Each selected profile runs ``custom_agent.py`` in its own child process with
profile-specific ``CUSTOM_*`` environment variables. Child stdout/stderr lines
are re-printed with a ``[profile]`` prefix. When any child exits, or this
launcher receives SIGINT/SIGTERM, every child is stopped.
"""

from __future__ import annotations

import argparse
import asyncio
import contextlib
import os
import signal
import sys
from dataclasses import dataclass
from pathlib import Path

from dotenv import load_dotenv

# Load settings from the ``.env`` file next to this script; the children inherit
# them because ``_child_env`` starts from a copy of ``os.environ``.
CUSTOM_ENV_PATH = Path(__file__).with_name(".env")
load_dotenv(dotenv_path=CUSTOM_ENV_PATH)


@dataclass(frozen=True)
class AgentProfile:
    """Per-profile settings handed to a ``custom_agent.py`` child process."""

    agent_name: str  # exported as CUSTOM_AGENT_NAME
    llm_provider: str  # exported as CUSTOM_LLM_PROVIDER
    input_mode: str  # exported as CUSTOM_AGENT_INPUT_MODE


# Profile name -> settings. These names are what ``--profiles`` accepts.
AGENT_PROFILES = {
    "normal": AgentProfile(
        agent_name="normal-agent",
        llm_provider="openai-compatible",
        input_mode="voice",
    ),
    "beaver": AgentProfile(
        agent_name="beaver-agent",
        llm_provider="beaver",
        input_mode="voice",
    ),
    "vision-normal": AgentProfile(
        agent_name="vision-normal-agent",
        llm_provider="openai-compatible",
        input_mode="vision_voice",
    ),
    "vision-beaver": AgentProfile(
        agent_name="vision-beaver-agent",
        llm_provider="beaver",
        input_mode="vision_voice",
    ),
}

# Profiles started when neither --profiles nor CUSTOM_AGENT_PROFILES is given.
DEFAULT_PROFILES = ("normal", "beaver")

# Seconds to wait for children to exit after SIGTERM before sending SIGKILL.
TERMINATE_TIMEOUT = 10.0

# Seconds to let output relays flush after the children have exited. A process
# that inherited one of the pipes may keep it open, so this wait is bounded.
OUTPUT_DRAIN_TIMEOUT = 2.0

# Strong references to running output-relay tasks. The event loop keeps only
# weak references to tasks, so an unreferenced task can be garbage-collected
# before it finishes.
_OUTPUT_TASKS: set[asyncio.Task[None]] = set()


def _env_bool(name: str, default: bool) -> bool:
    """Read environment variable *name* as a boolean.

    Accepts 1/true/yes/on and 0/false/no/off (case-insensitive, surrounding
    whitespace ignored). Returns *default* when the variable is unset or holds
    any other value.
    """
    value = os.getenv(name)
    if value is None:
        return default
    normalized = value.strip().lower()
    if normalized in {"1", "true", "yes", "on"}:
        return True
    if normalized in {"0", "false", "no", "off"}:
        return False
    return default


def _parse_profiles(value: str) -> list[str]:
    """Parse a comma-separated profile list into known profile names.

    ``"all"`` selects every profile. Names are case-insensitive and ``_`` is
    accepted in place of ``-``. Empty items are skipped, and duplicates are
    removed while preserving first-seen order.

    Raises:
        ValueError: if a name is unknown or no profile remains.
    """
    if value.strip().lower() == "all":
        return list(AGENT_PROFILES)

    profiles: list[str] = []
    for raw_profile in value.split(","):
        profile = raw_profile.strip().lower().replace("_", "-")
        if not profile:
            continue
        if profile not in AGENT_PROFILES:
            valid = ", ".join([*AGENT_PROFILES, "all"])
            raise ValueError(f"unknown profile {raw_profile!r}; valid values: {valid}")
        profiles.append(profile)

    if not profiles:
        raise ValueError("at least one profile is required")
    return list(dict.fromkeys(profiles))


def _parse_http_port_base(value: str | None) -> int:
    """Parse the base HTTP port; ``None`` or blank means 0 (OS-assigned ports).

    Raises:
        ValueError: if *value* is not an integer in 0-65535.
    """
    if value is None or not value.strip():
        return 0
    try:
        port = int(value)
    except ValueError as exc:
        raise ValueError(f"invalid HTTP port base {value!r}") from exc
    if port < 0 or port > 65535:
        raise ValueError(f"invalid HTTP port base {value!r}; expected 0-65535")
    return port


def _profile_http_port(http_port_base: int, index: int) -> int:
    """Return the HTTP port for the profile at *index* (0 stays 0).

    Raises:
        ValueError: if ``http_port_base + index`` exceeds 65535.
    """
    if http_port_base == 0:
        return 0
    port = http_port_base + index
    if port > 65535:
        raise ValueError("HTTP port range exceeds 65535")
    return port


def _child_env(profile_name: str, *, http_port: int) -> dict[str, str]:
    """Build a child environment: this process's env plus the profile's vars."""
    profile = AGENT_PROFILES[profile_name]
    env = os.environ.copy()
    env.update(
        {
            "CUSTOM_AGENT_PROFILE": profile_name,
            "CUSTOM_AGENT_NAME": profile.agent_name,
            "CUSTOM_LLM_PROVIDER": profile.llm_provider,
            "CUSTOM_AGENT_INPUT_MODE": profile.input_mode,
            "CUSTOM_AGENT_HTTP_PORT": str(http_port),
        }
    )
    return env


async def _pipe_output(prefix: str, stream: asyncio.StreamReader) -> None:
    """Re-print each line of *stream* as ``[prefix] line`` until EOF."""
    while True:
        try:
            line = await stream.readline()
        except ValueError:
            # readline() raises ValueError when a line exceeds the reader's
            # buffer limit, after discarding the buffered data. Keep relaying:
            # if this task died, nothing would drain the pipe and the child
            # could block forever on its next write.
            print(f"[{prefix}] <line exceeded stream buffer limit; truncated>", flush=True)
            continue
        if not line:
            break
        text = line.decode("utf-8", errors="replace").rstrip()
        print(f"[{prefix}] {text}", flush=True)


def _relay_output(prefix: str, stream: asyncio.StreamReader | None) -> None:
    """Start a tracked background task that relays *stream* (if any)."""
    if stream is None:
        return
    task = asyncio.create_task(_pipe_output(prefix, stream))
    _OUTPUT_TASKS.add(task)
    task.add_done_callback(_OUTPUT_TASKS.discard)


async def _drain_output(timeout: float = OUTPUT_DRAIN_TIMEOUT) -> None:
    """Let output relays reach EOF for up to *timeout* s, then cancel the rest."""
    tasks = set(_OUTPUT_TASKS)
    if not tasks:  # asyncio.wait() rejects an empty collection
        return
    _, pending = await asyncio.wait(tasks, timeout=timeout)
    for task in pending:
        task.cancel()
    await asyncio.gather(*pending, return_exceptions=True)


async def _start_profile(
    profile_name: str,
    *,
    mode: str,
    http_port: int,
    reload: bool,
) -> asyncio.subprocess.Process:
    """Spawn ``custom_agent.py <mode>`` for one profile and relay its output.

    In ``dev`` mode ``--no-reload`` is appended unless *reload* is true.
    """
    profile = AGENT_PROFILES[profile_name]
    script_path = Path(__file__).with_name("custom_agent.py")
    args = [sys.executable, str(script_path), mode]
    if mode == "dev" and not reload:
        args.append("--no-reload")

    process = await asyncio.create_subprocess_exec(
        *args,
        env=_child_env(profile_name, http_port=http_port),
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )
    print(
        f"started {profile_name}: pid={process.pid} agent={profile.agent_name} "
        f"llm={profile.llm_provider} input={profile.input_mode} http_port={http_port}",
        flush=True,
    )
    _relay_output(profile_name, process.stdout)
    _relay_output(profile_name, process.stderr)
    return process


async def _terminate(
    processes: list[asyncio.subprocess.Process],
    *,
    timeout: float = TERMINATE_TIMEOUT,
) -> None:
    """Terminate every still-running process; SIGKILL any left after *timeout*."""
    for process in processes:
        if process.returncode is None:
            # The child may already be gone by the time we signal it.
            with contextlib.suppress(ProcessLookupError):
                process.terminate()
    try:
        await asyncio.wait_for(
            asyncio.gather(*(process.wait() for process in processes)),
            timeout=timeout,
        )
    except asyncio.TimeoutError:
        for process in processes:
            if process.returncode is None:
                with contextlib.suppress(ProcessLookupError):
                    process.kill()
        await asyncio.gather(*(process.wait() for process in processes))


def _exit_status(returncode: int | None) -> int:
    """Map a child return code to this launcher's exit status.

    A negative return code means the child was killed by signal
    ``-returncode``; report it as ``128 + signum`` like a shell does, rather
    than letting ``SystemExit`` wrap the negative value into 0-255.
    """
    if returncode is None:
        return 0
    if returncode < 0:
        return 128 - returncode
    return returncode


async def _run(
    profiles: list[str],
    *,
    mode: str,
    http_port_base: int,
    reload: bool,
) -> int:
    """Start *profiles* and supervise them until one exits or we are signalled.

    Returns the exit status of a child that exited on its own (the first
    nonzero one if several finished together), or 0 when stopped by signal.
    All children are stopped and their output flushed before returning, even
    if startup fails or the task is cancelled.
    """
    # Resolve every port before spawning anything so an out-of-range port
    # cannot leave earlier profiles running unsupervised.
    http_ports = [_profile_http_port(http_port_base, index) for index in range(len(profiles))]

    # Install signal handlers before spawning so SIGINT/SIGTERM during startup
    # still leads to an orderly shutdown of whatever was started.
    stop_event = asyncio.Event()
    loop = asyncio.get_running_loop()
    for sig in (signal.SIGINT, signal.SIGTERM):
        # Not supported on every platform (e.g. Windows event loops); there,
        # cancellation of this task still reaches the ``finally`` below.
        with contextlib.suppress(NotImplementedError):
            loop.add_signal_handler(sig, stop_event.set)

    processes: list[asyncio.subprocess.Process] = []
    exit_code = 0
    try:
        for profile, http_port in zip(profiles, http_ports):
            processes.append(
                await _start_profile(
                    profile,
                    mode=mode,
                    http_port=http_port,
                    reload=reload,
                )
            )

        wait_tasks = {asyncio.create_task(process.wait()): process for process in processes}
        stop_task = asyncio.create_task(stop_event.wait())
        try:
            done, _ = await asyncio.wait(
                [*wait_tasks, stop_task],
                return_when=asyncio.FIRST_COMPLETED,
            )
        finally:
            stop_task.cancel()

        if stop_task not in done:
            for task in done:
                process = wait_tasks[task]
                returncode = task.result()
                print(
                    f"agent profile process exited: pid={process.pid} code={returncode}",
                    flush=True,
                )
                # Keep the first failure so a simultaneous clean exit
                # cannot mask it.
                if exit_code == 0:
                    exit_code = _exit_status(returncode)
    finally:
        await _terminate(processes)
        await _drain_output()
    return exit_code


def main() -> None:
    """Parse CLI arguments and run the selected profiles until shutdown."""
    parser = argparse.ArgumentParser(description="Start multiple custom LiveKit agent profiles.")
    parser.add_argument(
        "mode",
        nargs="?",
        default="dev",
        choices=("console", "dev", "start", "connect"),
        help="custom_agent.py CLI mode to run for each profile",
    )
    parser.add_argument(
        "--profiles",
        default=os.getenv("CUSTOM_AGENT_PROFILES", ",".join(DEFAULT_PROFILES)),
        help="comma-separated profiles to start, or 'all'",
    )
    parser.add_argument(
        "--http-port-base",
        default=os.getenv("CUSTOM_AGENT_HTTP_PORT_BASE", "0"),
        help=(
            "base HTTP health-check port for profile workers; "
            "0 lets the OS assign free ports"
        ),
    )
    parser.add_argument(
        "--reload",
        action="store_true",
        default=_env_bool("CUSTOM_AGENT_DEV_RELOAD", False),
        help="enable auto-reload in dev mode",
    )
    args = parser.parse_args()

    try:
        profiles = _parse_profiles(args.profiles)
        http_port_base = _parse_http_port_base(args.http_port_base)
        # Check the highest per-profile port up front so an overflow is a
        # usage error, not a traceback from inside the event loop.
        _profile_http_port(http_port_base, len(profiles) - 1)
    except ValueError as exc:
        parser.error(str(exc))

    raise SystemExit(
        asyncio.run(
            _run(
                profiles,
                mode=args.mode,
                http_port_base=http_port_base,
                reload=args.reload,
            )
        )
    )


if __name__ == "__main__":
    main()