feat: four mode agent
This commit is contained in:
10
.env.example
10
.env.example
@ -2,7 +2,10 @@
|
||||
LIVEKIT_URL=ws://localhost:7880
|
||||
LIVEKIT_API_KEY=
|
||||
LIVEKIT_API_SECRET=
|
||||
CUSTOM_AGENT_NAME=my-agent
|
||||
|
||||
CUSTOM_AGENT_PROFILE=normal
|
||||
# CUSTOM_AGENT_NAME=normal-agent
|
||||
CUSTOM_AGENT_PROFILES=normal,beaver,vision-normal,vision-beaver
|
||||
|
||||
# Beaver terminal text WebSocket
|
||||
BEAVER_WS_URL=ws://terminaltest.1localhost.nip.io:8088/api/channels/terminal-dev/ws
|
||||
@ -18,8 +21,9 @@ CUSTOM_ASR_HOTWORDS=
|
||||
CUSTOM_ASR_ITN=
|
||||
CUSTOM_ASR_CHUNK_MODE=
|
||||
|
||||
# LLM backend: openai/openai-compatible or hermes_gateway/openclaw.
|
||||
CUSTOM_LLM_PROVIDER=beaver
|
||||
# LLM backend: openai/openai-compatible, hermes_gateway/openclaw, or beaver.
|
||||
# Defaults come from CUSTOM_AGENT_PROFILE. Uncomment to override.
|
||||
# CUSTOM_LLM_PROVIDER=beaver
|
||||
CUSTOM_BEAVER_WARMUP_TEXT=初始化连接,请简短回复 ready
|
||||
|
||||
# OpenAI-compatible LLM
|
||||
|
||||
119
custom_agent.py
119
custom_agent.py
@ -42,7 +42,6 @@ logger = logging.getLogger("custom-agent")
|
||||
|
||||
CUSTOM_ENV_PATH = Path(__file__).with_name(".env")
|
||||
load_dotenv(dotenv_path=CUSTOM_ENV_PATH)
|
||||
AGENT_NAME = os.getenv("CUSTOM_AGENT_NAME", "")
|
||||
|
||||
ROOM_LOCATOR_INSTRUCTIONS = """
|
||||
你是一个房间物品定位助手。
|
||||
@ -72,6 +71,95 @@ VOICE_INPUT_MODE = "voice"
|
||||
VISION_VOICE_INPUT_MODE = "vision_voice"
|
||||
AUTO_INPUT_MODE = "auto"
|
||||
VISION_FRAME_TOPIC = "vision.frame"
|
||||
DEFAULT_AGENT_PROFILE = "normal"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class AgentProfile:
    """Immutable set of defaults that together describe one agent mode."""

    # Agent name used when CUSTOM_AGENT_NAME is not set.
    agent_name: str
    # Fallback for CUSTOM_LLM_PROVIDER.
    llm_provider: str
    # Fallback for CUSTOM_AGENT_INPUT_MODE.
    input_mode: str
|
||||
|
||||
|
||||
# Built-in profiles: each LLM backend (openai-compatible, beaver) paired with
# each input mode (voice only, vision + voice).
AGENT_PROFILES = {
    profile_key: AgentProfile(
        agent_name=worker_name,
        llm_provider=provider,
        input_mode=mode,
    )
    for profile_key, worker_name, provider, mode in (
        ("normal", "normal-agent", "openai-compatible", VOICE_INPUT_MODE),
        ("beaver", "beaver-agent", "beaver", VOICE_INPUT_MODE),
        ("vision-normal", "vision-normal-agent", "openai-compatible", VISION_VOICE_INPUT_MODE),
        ("vision-beaver", "vision-beaver-agent", "beaver", VISION_VOICE_INPUT_MODE),
    )
}
|
||||
# Alternative spellings accepted for a profile, grouped by the canonical
# profile key they resolve to. Keys are already lower-case and hyphenated.
AGENT_PROFILE_ALIASES = {
    **dict.fromkeys(
        ("default", "openai", "openai-compatible", "llm", "text", "voice"),
        "normal",
    ),
    **dict.fromkeys(
        ("vision", "vision-llm", "vision-openai", "vision-openai-compatible"),
        "vision-normal",
    ),
}
|
||||
|
||||
|
||||
def _normalize_agent_profile(value: str | None) -> str:
    """Resolve a user-supplied profile string to a key of ``AGENT_PROFILES``.

    The value is stripped, lower-cased and has underscores turned into
    hyphens, then run through ``AGENT_PROFILE_ALIASES``.

    Args:
        value: Raw profile text, possibly ``None`` or blank.

    Returns:
        A valid profile key. ``DEFAULT_AGENT_PROFILE`` is returned for missing
        or blank input, and (after logging a warning) for unknown input.
    """
    cleaned = value.strip() if value else ""
    if not cleaned:
        return DEFAULT_AGENT_PROFILE

    candidate = cleaned.lower().replace("_", "-")
    candidate = AGENT_PROFILE_ALIASES.get(candidate, candidate)
    if candidate not in AGENT_PROFILES:
        # Log the value exactly as supplied so the bad setting is easy to find.
        logger.warning(
            "Invalid CUSTOM_AGENT_PROFILE=%r, using %s",
            value,
            DEFAULT_AGENT_PROFILE,
        )
        return DEFAULT_AGENT_PROFILE
    return candidate
|
||||
|
||||
|
||||
def _agent_profile_from_name(agent_name: str | None) -> str | None:
    """Find the profile whose built-in agent name equals ``agent_name``.

    The name is stripped, lower-cased and has underscores turned into hyphens
    before it is compared with each profile's ``agent_name``.

    Args:
        agent_name: Agent name to look up, possibly ``None`` or blank.

    Returns:
        The matching key of ``AGENT_PROFILES``, or ``None`` when the name is
        missing, blank, or matches no profile.
    """
    wanted = (agent_name or "").strip().lower().replace("_", "-")
    if not wanted:
        return None
    return next(
        (key for key, candidate in AGENT_PROFILES.items() if candidate.agent_name == wanted),
        None,
    )
|
||||
|
||||
|
||||
def _selected_agent_profile_name() -> str:
    """Pick the active profile key from the environment.

    Precedence:
      1. a non-blank ``CUSTOM_AGENT_PROFILE`` (normalised, falling back to the
         default profile if it is not recognised);
      2. a ``CUSTOM_AGENT_NAME`` that equals one of the built-in agent names;
      3. ``DEFAULT_AGENT_PROFILE``.
    """
    explicit = os.getenv("CUSTOM_AGENT_PROFILE")
    if explicit and explicit.strip():
        return _normalize_agent_profile(explicit)

    from_name = _agent_profile_from_name(os.getenv("CUSTOM_AGENT_NAME"))
    return DEFAULT_AGENT_PROFILE if from_name is None else from_name
|
||||
|
||||
|
||||
# Resolve the active profile once, at import time.
AGENT_PROFILE_NAME = _selected_agent_profile_name()
AGENT_PROFILE = AGENT_PROFILES[AGENT_PROFILE_NAME]
# A non-empty CUSTOM_AGENT_NAME takes precedence over the profile's own name.
AGENT_NAME = os.getenv("CUSTOM_AGENT_NAME") or AGENT_PROFILE.agent_name
|
||||
|
||||
DEFAULT_EMOTION = "neutral"
|
||||
EMOTION_LABELS = {
|
||||
@ -614,7 +702,25 @@ def _model_image_save_dir_from_env() -> Path | None:
|
||||
return Path(__file__).with_name("model_images")
|
||||
|
||||
|
||||
server = AgentServer()
|
||||
def _agent_server_from_env() -> AgentServer:
    """Create the ``AgentServer``, honouring ``CUSTOM_AGENT_HTTP_PORT``.

    Returns:
        ``AgentServer()`` when the variable is unset or blank; otherwise
        ``AgentServer(port=...)`` with the configured port. A value that is
        not an integer, or lies outside 0-65535, is replaced by 0 after a
        warning is logged.
    """
    configured_port = os.getenv("CUSTOM_AGENT_HTTP_PORT")
    # A blank ``KEY=`` placeholder in .env arrives here as "". Treat it like an
    # unset variable instead of warning about an invalid integer and forcing
    # port 0.
    if configured_port is None or not configured_port.strip():
        return AgentServer()

    try:
        port = int(configured_port)
    except ValueError:
        logger.warning("Invalid integer for CUSTOM_AGENT_HTTP_PORT=%r, using 0", configured_port)
        port = 0

    if not 0 <= port <= 65535:
        logger.warning("Invalid CUSTOM_AGENT_HTTP_PORT=%r, using 0", configured_port)
        port = 0

    # NOTE(review): port 0 is presumably "let the OS pick a free port" -
    # confirm against the AgentServer documentation.
    return AgentServer(port=port)
|
||||
|
||||
|
||||
server = _agent_server_from_env()
|
||||
|
||||
|
||||
def prewarm(proc: JobProcess) -> None:
|
||||
@ -640,10 +746,10 @@ async def entrypoint(ctx: JobContext) -> None:
|
||||
LLM_BASE_URL = os.getenv("CUSTOM_LLM_BASE_URL")
|
||||
LLM_MODEL = os.getenv("CUSTOM_LLM_MODEL", "qwen-max")
|
||||
LLM_API_KEY = os.getenv("CUSTOM_LLM_API_KEY")
|
||||
LLM_PROVIDER = os.getenv("CUSTOM_LLM_PROVIDER", "openai").strip().lower()
|
||||
LLM_PROVIDER = os.getenv("CUSTOM_LLM_PROVIDER", AGENT_PROFILE.llm_provider).strip().lower()
|
||||
TEXT_LLM_MODEL = os.getenv("CUSTOM_TEXT_LLM_MODEL", LLM_MODEL)
|
||||
VISION_LLM_MODEL = os.getenv("CUSTOM_VISION_LLM_MODEL", LLM_MODEL)
|
||||
INPUT_MODE = _normalize_input_mode(os.getenv("CUSTOM_AGENT_INPUT_MODE"))
|
||||
INPUT_MODE = _normalize_input_mode(os.getenv("CUSTOM_AGENT_INPUT_MODE", AGENT_PROFILE.input_mode))
|
||||
if LLM_PROVIDER not in {
|
||||
"openai",
|
||||
"openai-compatible",
|
||||
@ -656,7 +762,10 @@ async def entrypoint(ctx: JobContext) -> None:
|
||||
if LLM_PROVIDER in {"openai", "openai-compatible"} and not LLM_API_KEY:
|
||||
raise RuntimeError(f"CUSTOM_LLM_API_KEY is not set in {CUSTOM_ENV_PATH}")
|
||||
logger.info(
|
||||
"Using LLM provider=%s model=%s base_url=%s",
|
||||
"Using agent profile=%s agent_name=%s input_mode=%s llm_provider=%s model=%s base_url=%s",
|
||||
AGENT_PROFILE_NAME,
|
||||
AGENT_NAME or "<automatic>",
|
||||
INPUT_MODE,
|
||||
LLM_PROVIDER,
|
||||
LLM_MODEL,
|
||||
LLM_BASE_URL or "OpenAI default",
|
||||
|
||||
264
start_agent_profiles.py
Normal file
264
start_agent_profiles.py
Normal file
@ -0,0 +1,264 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
CUSTOM_ENV_PATH = Path(__file__).with_name(".env")
|
||||
load_dotenv(dotenv_path=CUSTOM_ENV_PATH)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class AgentProfile:
    """Immutable settings the launcher hands to one agent worker process."""

    # Exported to the child process as CUSTOM_AGENT_NAME.
    agent_name: str
    # Exported to the child process as CUSTOM_LLM_PROVIDER.
    llm_provider: str
    # Exported to the child process as CUSTOM_AGENT_INPUT_MODE.
    input_mode: str
|
||||
|
||||
|
||||
# Profiles this launcher can start: each LLM backend (openai-compatible,
# beaver) paired with each input mode (voice, vision_voice).
AGENT_PROFILES = {
    profile_key: AgentProfile(
        agent_name=worker_name,
        llm_provider=provider,
        input_mode=mode,
    )
    for profile_key, worker_name, provider, mode in (
        ("normal", "normal-agent", "openai-compatible", "voice"),
        ("beaver", "beaver-agent", "beaver", "voice"),
        ("vision-normal", "vision-normal-agent", "openai-compatible", "vision_voice"),
        ("vision-beaver", "vision-beaver-agent", "beaver", "vision_voice"),
    )
}
# Profiles started when neither --profiles nor CUSTOM_AGENT_PROFILES is given.
DEFAULT_PROFILES = ("normal", "beaver")
|
||||
|
||||
|
||||
def _env_bool(name: str, default: bool) -> bool:
    """Read environment variable ``name`` as a boolean flag.

    Args:
        name: Environment variable to read.
        default: Result when the variable is unset or holds an unrecognised
            value.

    Returns:
        ``True`` for 1/true/yes/on and ``False`` for 0/false/no/off (case and
        surrounding whitespace ignored); otherwise ``default``.
    """
    raw = os.getenv(name)
    if raw is None:
        return default

    recognised = {
        "1": True,
        "true": True,
        "yes": True,
        "on": True,
        "0": False,
        "false": False,
        "no": False,
        "off": False,
    }
    return recognised.get(raw.strip().lower(), default)
|
||||
|
||||
|
||||
def _parse_profiles(value: str) -> list[str]:
    """Turn a comma-separated profile list into validated profile keys.

    Each entry is stripped, lower-cased and has underscores turned into
    hyphens. Blank entries are ignored. The entry ``all`` expands to every key
    of ``AGENT_PROFILES``; it may stand alone or appear among other entries,
    matching the "valid values" hint in the error message.

    Args:
        value: Text such as ``"normal,beaver"`` or ``"all"``.

    Returns:
        Profile keys in first-seen order, without duplicates.

    Raises:
        ValueError: If an entry is not a known profile, or no entry remains.
    """
    selected: list[str] = []
    for raw_profile in value.split(","):
        profile = raw_profile.strip().lower().replace("_", "-")
        if not profile:
            continue
        if profile == "all":
            selected.extend(AGENT_PROFILES)
            continue
        if profile not in AGENT_PROFILES:
            valid = ", ".join([*AGENT_PROFILES, "all"])
            raise ValueError(f"unknown profile {raw_profile!r}; valid values: {valid}")
        selected.append(profile)

    if not selected:
        raise ValueError("at least one profile is required")
    # dict.fromkeys de-duplicates while preserving first-seen order.
    return list(dict.fromkeys(selected))
|
||||
|
||||
|
||||
def _parse_http_port_base(value: str | None) -> int:
    """Validate the base HTTP port supplied on the CLI or via the environment.

    Args:
        value: Raw port text; ``None`` or blank means "not configured".

    Returns:
        The port as an integer in 0-65535, or 0 when not configured.

    Raises:
        ValueError: If the text is not an integer or is outside 0-65535.
    """
    if not (value and value.strip()):
        return 0

    try:
        base = int(value)
    except ValueError as exc:
        raise ValueError(f"invalid HTTP port base {value!r}") from exc

    if not 0 <= base <= 65535:
        raise ValueError(f"invalid HTTP port base {value!r}; expected 0-65535")
    return base
|
||||
|
||||
|
||||
def _profile_http_port(http_port_base: int, index: int) -> int:
    """Compute the HTTP port for the ``index``-th profile.

    Args:
        http_port_base: First port of the range; 0 is passed through
            unchanged for every profile.
        index: Zero-based position of the profile in the launch list.

    Returns:
        ``http_port_base + index``, or 0 when the base is 0.

    Raises:
        ValueError: If the computed port is above 65535.
    """
    if http_port_base == 0:
        return 0

    candidate = http_port_base + index
    if candidate <= 65535:
        return candidate
    raise ValueError("HTTP port range exceeds 65535")
|
||||
|
||||
|
||||
def _child_env(profile_name: str, *, http_port: int) -> dict[str, str]:
    """Build the environment for the worker process of one profile.

    Args:
        profile_name: Key of ``AGENT_PROFILES``.
        http_port: Port exported as ``CUSTOM_AGENT_HTTP_PORT``.

    Returns:
        A copy of the current environment in which the profile-specific
        ``CUSTOM_*`` variables replace any inherited values.
    """
    selected = AGENT_PROFILES[profile_name]
    overrides = {
        "CUSTOM_AGENT_PROFILE": profile_name,
        "CUSTOM_AGENT_NAME": selected.agent_name,
        "CUSTOM_LLM_PROVIDER": selected.llm_provider,
        "CUSTOM_AGENT_INPUT_MODE": selected.input_mode,
        "CUSTOM_AGENT_HTTP_PORT": str(http_port),
    }
    # Later entries win, so the overrides shadow inherited settings.
    return {**os.environ, **overrides}
|
||||
|
||||
|
||||
async def _pipe_output(prefix: str, stream: asyncio.StreamReader) -> None:
    """Forward ``stream`` to stdout line by line, tagging each line.

    Runs until the stream reports end-of-file.

    Args:
        prefix: Label printed in square brackets before every line.
        stream: Reader attached to a child's stdout or stderr.
    """
    while True:
        try:
            line = await stream.readline()
        except ValueError:
            # readline() raises ValueError when a line exceeds the reader's
            # limit, after discarding the oversized data from its buffer.
            # Keep draining: if this task died, the child could block on a
            # full pipe.
            print(f"[{prefix}] <overlong output line dropped>", flush=True)
            continue
        if not line:
            return
        text = line.decode("utf-8", errors="replace").rstrip()
        print(f"[{prefix}] {text}", flush=True)
|
||||
|
||||
|
||||
# Strong references to the output-forwarding tasks. The event loop keeps only
# weak references to tasks, so an unreferenced task can be garbage-collected
# before it finishes. Each task removes itself from the set when done.
_OUTPUT_TASKS: "set[asyncio.Task[None]]" = set()


async def _start_profile(
    profile_name: str,
    *,
    mode: str,
    http_port: int,
    reload: bool,
) -> asyncio.subprocess.Process:
    """Spawn ``custom_agent.py`` for one profile and forward its output.

    Args:
        profile_name: Key of ``AGENT_PROFILES``.
        mode: CLI mode passed to ``custom_agent.py`` as its first argument.
        http_port: Port exported to the child as ``CUSTOM_AGENT_HTTP_PORT``.
        reload: When false and ``mode`` is ``"dev"``, ``--no-reload`` is
            appended to the child's arguments.

    Returns:
        The started process. Its stdout and stderr are forwarded to this
        process's stdout, each line prefixed with ``[profile_name]``.
    """
    profile = AGENT_PROFILES[profile_name]
    script_path = Path(__file__).with_name("custom_agent.py")
    args = [sys.executable, str(script_path), mode]
    if mode == "dev" and not reload:
        args.append("--no-reload")

    process = await asyncio.create_subprocess_exec(
        *args,
        env=_child_env(profile_name, http_port=http_port),
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )
    print(
        f"started {profile_name}: pid={process.pid} agent={profile.agent_name} "
        f"llm={profile.llm_provider} input={profile.input_mode} http_port={http_port}",
        flush=True,
    )
    for stream in (process.stdout, process.stderr):
        if stream is not None:
            task = asyncio.create_task(_pipe_output(profile_name, stream))
            _OUTPUT_TASKS.add(task)
            task.add_done_callback(_OUTPUT_TASKS.discard)
    return process
|
||||
|
||||
|
||||
async def _terminate(processes: list[asyncio.subprocess.Process]) -> None:
    """Stop every process, escalating from terminate to kill after 10 seconds.

    Args:
        processes: Child processes to stop; ones that have already exited are
            left alone.
    """

    def _still_running() -> list[asyncio.subprocess.Process]:
        return [child for child in processes if child.returncode is None]

    for child in _still_running():
        child.terminate()

    all_exited = asyncio.gather(*(child.wait() for child in processes))
    try:
        await asyncio.wait_for(all_exited, timeout=10.0)
    except asyncio.TimeoutError:
        # The grace period is over: force-kill whatever is left and reap it.
        for child in _still_running():
            child.kill()
        await asyncio.gather(*(child.wait() for child in processes))
|
||||
|
||||
|
||||
async def _run(
    profiles: list[str],
    *,
    mode: str,
    http_port_base: int,
    reload: bool,
) -> int:
    """Start one worker per profile and supervise them until shutdown.

    Supervision ends when any worker exits or when SIGINT/SIGTERM is received;
    all remaining workers are then stopped.

    Args:
        profiles: Keys of ``AGENT_PROFILES`` to start, in order.
        mode: CLI mode forwarded to every worker.
        http_port_base: Base HTTP port; worker ``i`` gets ``base + i`` (or 0
            when the base is 0).
        reload: Whether dev-mode auto-reload stays enabled in the workers.

    Returns:
        0 when stopped by a signal; otherwise the first non-zero exit code
        among the workers that ended the run, or 0 if they all exited cleanly.

    Raises:
        ValueError: If the port range exceeds 65535 (raised before any worker
            is started).
    """
    # Resolve every port up front so a range error surfaces before spawning.
    ports = [_profile_http_port(http_port_base, index) for index, _ in enumerate(profiles)]

    processes: list[asyncio.subprocess.Process] = []
    exit_code = 0
    try:
        for profile, http_port in zip(profiles, ports):
            processes.append(
                await _start_profile(
                    profile,
                    mode=mode,
                    http_port=http_port,
                    reload=reload,
                )
            )

        stop_event = asyncio.Event()
        loop = asyncio.get_running_loop()
        for sig in (signal.SIGINT, signal.SIGTERM):
            try:
                loop.add_signal_handler(sig, stop_event.set)
            except NotImplementedError:
                # Some event loops have no signal-handler support. An
                # interrupt then cancels this coroutine and the ``finally``
                # below still stops the workers.
                pass

        wait_tasks = {asyncio.create_task(process.wait()): process for process in processes}
        stop_task = asyncio.create_task(stop_event.wait())
        done, pending = await asyncio.wait(
            [*wait_tasks, stop_task],
            return_when=asyncio.FIRST_COMPLETED,
        )
        for task in pending:
            task.cancel()

        if stop_task not in done:
            for task in done:
                process = wait_tasks[task]
                code = task.result() or 0
                print(f"agent profile process exited: pid={process.pid} code={code}", flush=True)
                # Keep the first failure; a later clean exit must not hide it.
                if exit_code == 0:
                    exit_code = code
    finally:
        # Also reached when a later spawn fails, so workers that did start are
        # never left running.
        await _terminate(processes)
    return exit_code
|
||||
|
||||
|
||||
def main() -> None:
    """Parse CLI and environment options, then run the selected profiles.

    Environment defaults: ``CUSTOM_AGENT_PROFILES`` for ``--profiles``,
    ``CUSTOM_AGENT_HTTP_PORT_BASE`` for ``--http-port-base`` and
    ``CUSTOM_AGENT_DEV_RELOAD`` for ``--reload`` / ``--no-reload``.

    Raises:
        SystemExit: Always. It carries the exit code returned by ``_run``, or
            argparse's usage-error status when an option value is invalid.
    """
    parser = argparse.ArgumentParser(description="Start multiple custom LiveKit agent profiles.")
    parser.add_argument(
        "mode",
        nargs="?",
        default="dev",
        choices=("console", "dev", "start", "connect"),
        help="custom_agent.py CLI mode to run for each profile",
    )
    parser.add_argument(
        "--profiles",
        default=os.getenv("CUSTOM_AGENT_PROFILES", ",".join(DEFAULT_PROFILES)),
        help="comma-separated profiles to start, or 'all'",
    )
    parser.add_argument(
        "--http-port-base",
        default=os.getenv("CUSTOM_AGENT_HTTP_PORT_BASE", "0"),
        help=(
            "base HTTP health-check port for profile workers; "
            "0 lets the OS assign free ports"
        ),
    )
    parser.add_argument(
        "--reload",
        action="store_true",
        default=_env_bool("CUSTOM_AGENT_DEV_RELOAD", False),
        help="enable auto-reload in dev mode",
    )
    # Without this flag, a truthy CUSTOM_AGENT_DEV_RELOAD could not be
    # overridden from the command line. SUPPRESS leaves the env-derived
    # default of --reload in place.
    parser.add_argument(
        "--no-reload",
        dest="reload",
        action="store_false",
        default=argparse.SUPPRESS,
        help="disable auto-reload in dev mode",
    )
    args = parser.parse_args()

    try:
        profiles = _parse_profiles(args.profiles)
        http_port_base = _parse_http_port_base(args.http_port_base)
        # Check the highest port now so an out-of-range base is reported as a
        # usage error instead of failing while workers are being started.
        _profile_http_port(http_port_base, len(profiles) - 1)
    except ValueError as exc:
        parser.error(str(exc))

    raise SystemExit(
        asyncio.run(
            _run(
                profiles,
                mode=args.mode,
                http_port_base=http_port_base,
                reload=args.reload,
            )
        )
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user