feat: four mode agent
This commit is contained in:
119
custom_agent.py
119
custom_agent.py
@ -42,7 +42,6 @@ logger = logging.getLogger("custom-agent")
|
||||
|
||||
CUSTOM_ENV_PATH = Path(__file__).with_name(".env")
|
||||
load_dotenv(dotenv_path=CUSTOM_ENV_PATH)
|
||||
AGENT_NAME = os.getenv("CUSTOM_AGENT_NAME", "")
|
||||
|
||||
ROOM_LOCATOR_INSTRUCTIONS = """
|
||||
你是一个房间物品定位助手。
|
||||
@ -72,6 +71,95 @@ VOICE_INPUT_MODE = "voice"
|
||||
VISION_VOICE_INPUT_MODE = "vision_voice"
|
||||
AUTO_INPUT_MODE = "auto"
|
||||
VISION_FRAME_TOPIC = "vision.frame"
|
||||
DEFAULT_AGENT_PROFILE = "normal"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class AgentProfile:
    """Immutable set of defaults that together describe one agent mode.

    Each field is only a default: the matching ``CUSTOM_*`` environment
    variable, when set, takes precedence over the value stored here.
    """

    # Fallback agent name used when CUSTOM_AGENT_NAME is unset or empty.
    agent_name: str
    # Default for CUSTOM_LLM_PROVIDER.
    llm_provider: str
    # Default for CUSTOM_AGENT_INPUT_MODE.
    input_mode: str
|
||||
|
||||
|
||||
# Built-in agent profiles, keyed by canonical profile name.
# The four entries are the cross product of two LLM providers
# ("openai-compatible", "beaver") and two input modes (voice, vision + voice).
AGENT_PROFILES = {
    # Voice input.
    "normal": AgentProfile(
        agent_name="normal-agent",
        llm_provider="openai-compatible",
        input_mode=VOICE_INPUT_MODE,
    ),
    "beaver": AgentProfile(
        agent_name="beaver-agent",
        llm_provider="beaver",
        input_mode=VOICE_INPUT_MODE,
    ),
    # Vision + voice input.
    "vision-normal": AgentProfile(
        agent_name="vision-normal-agent",
        llm_provider="openai-compatible",
        input_mode=VISION_VOICE_INPUT_MODE,
    ),
    "vision-beaver": AgentProfile(
        agent_name="vision-beaver-agent",
        llm_provider="beaver",
        input_mode=VISION_VOICE_INPUT_MODE,
    ),
}
|
||||
# Alternative spellings accepted for CUSTOM_AGENT_PROFILE, mapped to the
# canonical profile name. Keys are written in the normalized form
# (lower-case, hyphen-separated) that lookups use.
AGENT_PROFILE_ALIASES = {
    # Aliases of the "normal" profile.
    "default": "normal",
    "openai": "normal",
    "openai-compatible": "normal",
    "llm": "normal",
    "text": "normal",
    "voice": "normal",
    # Aliases of the "vision-normal" profile.
    "vision": "vision-normal",
    "vision-llm": "vision-normal",
    "vision-openai": "vision-normal",
    "vision-openai-compatible": "vision-normal",
}
|
||||
|
||||
|
||||
def _normalize_agent_profile(value: str | None) -> str:
    """Resolve a user-supplied profile string to a key of ``AGENT_PROFILES``.

    The input is stripped, lower-cased and has underscores turned into
    hyphens, then passed through ``AGENT_PROFILE_ALIASES``.

    Args:
        value: Raw profile string, or ``None``.

    Returns:
        The canonical profile name. ``DEFAULT_AGENT_PROFILE`` is returned
        silently for ``None``/blank input, and with a warning when the
        value does not match any known profile or alias.
    """
    stripped = value.strip() if value else ""
    if stripped:
        key = stripped.lower().replace("_", "-")
        resolved = AGENT_PROFILE_ALIASES.get(key, key)
        if resolved in AGENT_PROFILES:
            return resolved

        # Log the caller's original text (not the normalized key) so the
        # offending configuration value is easy to find.
        logger.warning(
            "Invalid CUSTOM_AGENT_PROFILE=%r, using %s",
            value,
            DEFAULT_AGENT_PROFILE,
        )

    return DEFAULT_AGENT_PROFILE
|
||||
|
||||
|
||||
def _agent_profile_from_name(agent_name: str | None) -> str | None:
    """Find the profile whose ``agent_name`` equals the given agent name.

    The comparison is done after stripping, lower-casing and replacing
    underscores with hyphens in ``agent_name``.

    Args:
        agent_name: Agent name to look up, or ``None``.

    Returns:
        The key in ``AGENT_PROFILES`` of the first matching profile, or
        ``None`` when the name is missing, blank, or matches no profile.
    """
    wanted = (agent_name or "").strip().lower().replace("_", "-")
    if not wanted:
        return None

    matches = (
        key
        for key, candidate in AGENT_PROFILES.items()
        if candidate.agent_name == wanted
    )
    return next(matches, None)
|
||||
|
||||
|
||||
def _selected_agent_profile_name() -> str:
    """Choose the active profile name from the environment.

    Precedence:
        1. ``CUSTOM_AGENT_PROFILE``, when set to a non-blank value
           (normalized; an unrecognized value falls back to the default).
        2. The profile whose agent name equals ``CUSTOM_AGENT_NAME``.
        3. ``DEFAULT_AGENT_PROFILE``.

    Returns:
        A key of ``AGENT_PROFILES``.
    """
    explicit = os.getenv("CUSTOM_AGENT_PROFILE")
    if explicit and explicit.strip():
        # Pass the raw value through so any warning shows what was configured.
        return _normalize_agent_profile(explicit)

    inferred = _agent_profile_from_name(os.getenv("CUSTOM_AGENT_NAME"))
    return DEFAULT_AGENT_PROFILE if inferred is None else inferred
|
||||
|
||||
|
||||
# Resolve the active profile once, at import time.
AGENT_PROFILE_NAME = _selected_agent_profile_name()
AGENT_PROFILE = AGENT_PROFILES[AGENT_PROFILE_NAME]
# An explicit, non-empty CUSTOM_AGENT_NAME wins over the profile's own name.
AGENT_NAME = os.getenv("CUSTOM_AGENT_NAME") or AGENT_PROFILE.agent_name
|
||||
|
||||
DEFAULT_EMOTION = "neutral"
|
||||
EMOTION_LABELS = {
|
||||
@ -614,7 +702,25 @@ def _model_image_save_dir_from_env() -> Path | None:
|
||||
return Path(__file__).with_name("model_images")
|
||||
|
||||
|
||||
server = AgentServer()
|
||||
def _agent_server_from_env() -> AgentServer:
    """Create the ``AgentServer``, honouring ``CUSTOM_AGENT_HTTP_PORT``.

    Returns:
        ``AgentServer()`` with no arguments when the variable is unset.
        Otherwise ``AgentServer(port=...)`` with the configured port, or
        with port 0 (plus a logged warning) when the value is not an
        integer or lies outside 0-65535.
    """
    raw_port = os.getenv("CUSTOM_AGENT_HTTP_PORT")
    if raw_port is None:
        return AgentServer()

    try:
        port = int(raw_port)
    except ValueError:
        logger.warning("Invalid integer for CUSTOM_AGENT_HTTP_PORT=%r, using 0", raw_port)
        return AgentServer(port=0)

    if not 0 <= port <= 65535:
        logger.warning("Invalid CUSTOM_AGENT_HTTP_PORT=%r, using 0", raw_port)
        port = 0

    return AgentServer(port=port)
|
||||
|
||||
|
||||
server = _agent_server_from_env()
|
||||
|
||||
|
||||
def prewarm(proc: JobProcess) -> None:
|
||||
@ -640,10 +746,10 @@ async def entrypoint(ctx: JobContext) -> None:
|
||||
LLM_BASE_URL = os.getenv("CUSTOM_LLM_BASE_URL")
|
||||
LLM_MODEL = os.getenv("CUSTOM_LLM_MODEL", "qwen-max")
|
||||
LLM_API_KEY = os.getenv("CUSTOM_LLM_API_KEY")
|
||||
LLM_PROVIDER = os.getenv("CUSTOM_LLM_PROVIDER", "openai").strip().lower()
|
||||
LLM_PROVIDER = os.getenv("CUSTOM_LLM_PROVIDER", AGENT_PROFILE.llm_provider).strip().lower()
|
||||
TEXT_LLM_MODEL = os.getenv("CUSTOM_TEXT_LLM_MODEL", LLM_MODEL)
|
||||
VISION_LLM_MODEL = os.getenv("CUSTOM_VISION_LLM_MODEL", LLM_MODEL)
|
||||
INPUT_MODE = _normalize_input_mode(os.getenv("CUSTOM_AGENT_INPUT_MODE"))
|
||||
INPUT_MODE = _normalize_input_mode(os.getenv("CUSTOM_AGENT_INPUT_MODE", AGENT_PROFILE.input_mode))
|
||||
if LLM_PROVIDER not in {
|
||||
"openai",
|
||||
"openai-compatible",
|
||||
@ -656,7 +762,10 @@ async def entrypoint(ctx: JobContext) -> None:
|
||||
if LLM_PROVIDER in {"openai", "openai-compatible"} and not LLM_API_KEY:
|
||||
raise RuntimeError(f"CUSTOM_LLM_API_KEY is not set in {CUSTOM_ENV_PATH}")
|
||||
logger.info(
|
||||
"Using LLM provider=%s model=%s base_url=%s",
|
||||
"Using agent profile=%s agent_name=%s input_mode=%s llm_provider=%s model=%s base_url=%s",
|
||||
AGENT_PROFILE_NAME,
|
||||
AGENT_NAME or "<automatic>",
|
||||
INPUT_MODE,
|
||||
LLM_PROVIDER,
|
||||
LLM_MODEL,
|
||||
LLM_BASE_URL or "OpenAI default",
|
||||
|
||||
Reference in New Issue
Block a user