Files

100 lines
3.1 KiB
Python

"""配置加载模块"""
import os
from pathlib import Path
from typing import Optional
import yaml
from pydantic import ValidationError
from .types import Config, ServerConfig, OpenVikingConfig, EverMemOSConfig, MemoryConfig, LoggingConfig, LLMConfig, ObsidianConfig, StorageConfig
def load_config(config_path: Optional[str] = None) -> Config:
    """Load the gateway configuration from a YAML file.

    Args:
        config_path: Path to the YAML file. When ``None``, the path is read
            from the ``MEMORY_GATEWAY_CONFIG`` environment variable, falling
            back to ``config.yaml``.

    Returns:
        The parsed ``Config`` with backend environment overrides applied.
        A default ``Config()`` (also with overrides applied) is returned when
        the file does not exist, is empty, or cannot be parsed/validated; in
        the last case the error is printed.
    """
    if config_path is None:
        config_path = os.environ.get("MEMORY_GATEWAY_CONFIG", "config.yaml")
    config_file = Path(config_path)
    if not config_file.exists():
        # A missing file is not an error: run with defaults.
        return _apply_env_overrides(Config())
    try:
        with open(config_file, "r", encoding="utf-8") as f:
            data = yaml.safe_load(f)
        config = _build_config(data)
    except (ValidationError, yaml.YAMLError, TypeError) as e:
        print(f"配置文件解析错误: {e}")
        config = Config()
    # Applied outside the try so that problems with environment overrides are
    # not misreported as config-file parse errors.
    return _apply_env_overrides(config)


def _build_config(data: object) -> Config:
    """Build a ``Config`` from the document returned by ``yaml.safe_load``.

    ``None`` (an empty file) yields the default ``Config()``. Missing
    sections and sections written as an empty header (``server:``) fall back
    to that section's defaults.

    Raises:
        TypeError: if the document or one of its sections is not a mapping.
        ValidationError: if a section fails model validation.
    """
    if data is None:
        return Config()
    if not isinstance(data, dict):
        raise TypeError(
            f"top-level YAML must be a mapping, got {type(data).__name__}"
        )
    section_models = (
        ("server", ServerConfig),
        ("openviking", OpenVikingConfig),
        ("evermemos", EverMemOSConfig),
        ("memory", MemoryConfig),
        ("logging", LoggingConfig),
        ("llm", LLMConfig),
        ("obsidian", ObsidianConfig),
        ("storage", StorageConfig),
    )
    sections = {}
    for key, model in section_models:
        raw = data.get(key)
        if raw is None:
            # Absent key, or a bare ``key:`` line which YAML parses as null.
            raw = {}
        elif not isinstance(raw, dict):
            raise TypeError(
                f"section {key!r} must be a mapping, got {type(raw).__name__}"
            )
        sections[key] = model(**raw)
    return Config(**sections)
# Module-level cache for the shared Config; filled lazily by get_config()
# or explicitly by set_config().
_config: Optional[Config] = None


def get_config() -> Config:
    """Return the shared Config, loading it on first use.

    When no Config has been cached yet, load_config() is called with its
    default path resolution and the result is cached for subsequent calls.
    No locking is performed.
    """
    global _config
    if _config is not None:
        return _config
    _config = load_config()
    return _config


def set_config(config: Config) -> None:
    """Replace the cached Config that get_config() returns."""
    global _config
    _config = config
def _apply_env_overrides(config: Config) -> Config:
    """Overlay ``OPENVIKING_*`` / ``EVERMEMOS_*`` env vars onto ``config``.

    ``config`` is mutated in place (its ``openviking`` / ``evermemos``
    sections are replaced by updated copies) and also returned. A section is
    left untouched when no matching environment variables are set.

    NOTE(review): pydantic's ``model_copy(update=...)`` does not re-validate
    the updated values, so type conversion relies on _backend_env_updates.
    """
    # Collect both sets of overrides before touching config.
    pending = {
        "openviking": _backend_env_updates("OPENVIKING"),
        "evermemos": _backend_env_updates("EVERMEMOS"),
    }
    for section_name, overrides in pending.items():
        if not overrides:
            continue
        current = getattr(config, section_name)
        setattr(config, section_name, current.model_copy(update=overrides))
    return config
def _backend_env_updates(prefix: str) -> dict:
updates = {}
env_map = {
"ENABLED": "enabled",
"MODE": "mode",
"BASE_URL": "url",
"URL": "url",
"API_KEY": "api_key",
"TOKEN": "api_key",
"TIMEOUT": "timeout",
"TIMEOUT_SECONDS": "timeout",
"VERIFY_SSL": "verify_ssl",
"INGEST_PATH": "ingest_path",
}
for env_name, field_name in env_map.items():
value = os.environ.get(f"{prefix}_{env_name}")
if value is None:
continue
if field_name == "enabled":
updates[field_name] = value.lower() in {"1", "true", "yes", "on"}
elif field_name == "timeout":
updates[field_name] = int(value)
elif field_name == "verify_ssl":
updates[field_name] = value.lower() not in {"0", "false", "no", "off"}
else:
updates[field_name] = value
return updates