"""配置加载模块""" import os from pathlib import Path from typing import Optional import yaml from pydantic import ValidationError from .types import Config, ServerConfig, OpenVikingConfig, EverMemOSConfig, MemoryConfig, LoggingConfig, LLMConfig, ObsidianConfig, StorageConfig def load_config(config_path: Optional[str] = None) -> Config: """加载配置文件""" if config_path is None: config_path = os.environ.get("MEMORY_GATEWAY_CONFIG", "config.yaml") config_file = Path(config_path) if not config_file.exists(): # 返回默认配置 return _apply_env_overrides(Config()) try: with open(config_file, "r", encoding="utf-8") as f: data = yaml.safe_load(f) if data is None: return _apply_env_overrides(Config()) config = Config( server=ServerConfig(**data.get("server", {})), openviking=OpenVikingConfig(**data.get("openviking", {})), evermemos=EverMemOSConfig(**data.get("evermemos", {})), memory=MemoryConfig(**data.get("memory", {})), logging=LoggingConfig(**data.get("logging", {})), llm=LLMConfig(**data.get("llm", {})), obsidian=ObsidianConfig(**data.get("obsidian", {})), storage=StorageConfig(**data.get("storage", {})), ) return _apply_env_overrides(config) except (ValidationError, yaml.YAMLError) as e: print(f"配置文件解析错误: {e}") return _apply_env_overrides(Config()) def get_config() -> Config: """获取全局配置(单例)""" global _config if _config is None: _config = load_config() return _config def set_config(config: Config) -> None: """设置全局配置""" global _config _config = config _config: Optional[Config] = None def _apply_env_overrides(config: Config) -> Config: openviking_updates = _backend_env_updates("OPENVIKING") evermemos_updates = _backend_env_updates("EVERMEMOS") if openviking_updates: config.openviking = config.openviking.model_copy(update=openviking_updates) if evermemos_updates: config.evermemos = config.evermemos.model_copy(update=evermemos_updates) return config def _backend_env_updates(prefix: str) -> dict: updates = {} env_map = { "ENABLED": "enabled", "MODE": "mode", "BASE_URL": "url", "URL": "url", "API_KEY": "api_key", "TOKEN": "api_key", "TIMEOUT": "timeout", "TIMEOUT_SECONDS": "timeout", "VERIFY_SSL": "verify_ssl", "INGEST_PATH": "ingest_path", } for env_name, field_name in env_map.items(): value = os.environ.get(f"{prefix}_{env_name}") if value is None: continue if field_name == "enabled": updates[field_name] = value.lower() in {"1", "true", "yes", "on"} elif field_name == "timeout": updates[field_name] = int(value) elif field_name == "verify_ssl": updates[field_name] = value.lower() not in {"0", "false", "no", "off"} else: updates[field_name] = value return updates