100 lines
3.1 KiB
Python
100 lines
3.1 KiB
Python
"""Configuration loading: read the YAML config file and apply environment-variable overrides."""
|
|
import os
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
|
|
import yaml
|
|
from pydantic import ValidationError
|
|
|
|
from .types import Config, ServerConfig, OpenVikingConfig, EverMemOSConfig, MemoryConfig, LoggingConfig, LLMConfig, ObsidianConfig, StorageConfig
|
|
|
|
|
|
def load_config(config_path: Optional[str] = None) -> Config:
    """Load the application configuration from a YAML file.

    Path resolution: *config_path* if given, otherwise the
    ``MEMORY_GATEWAY_CONFIG`` environment variable, otherwise ``config.yaml``
    (a relative path, i.e. resolved against the current working directory).

    A default ``Config()`` is used instead of file contents when:
      * the file does not exist;
      * the file is empty (YAML parses it to ``None``);
      * the YAML is malformed, its top level is not a mapping, one of the
        known sections is not a mapping, or pydantic validation fails. In
        these cases an error message is printed to stdout.

    Environment-variable overrides (``_apply_env_overrides``) are applied to
    whichever Config is returned.

    Args:
        config_path: Optional explicit path to the YAML file.

    Returns:
        The loaded (or default) configuration.

    Raises:
        OSError: if the file exists but cannot be opened or read.
        ValueError: may propagate from ``_apply_env_overrides`` when an
            override variable cannot be converted (e.g. a non-integer timeout).
    """
    if config_path is None:
        config_path = os.environ.get("MEMORY_GATEWAY_CONFIG", "config.yaml")

    config_file = Path(config_path)
    if not config_file.exists():
        # No config file: fall back to defaults (env overrides still apply).
        return _apply_env_overrides(Config())

    # Each known top-level YAML key is parsed into its own section model.
    section_models = {
        "server": ServerConfig,
        "openviking": OpenVikingConfig,
        "evermemos": EverMemOSConfig,
        "memory": MemoryConfig,
        "logging": LoggingConfig,
        "llm": LLMConfig,
        "obsidian": ObsidianConfig,
        "storage": StorageConfig,
    }

    try:
        with config_file.open("r", encoding="utf-8") as f:
            data = yaml.safe_load(f)

        if data is None:
            # An empty file parses to None: behave as if no file existed.
            return _apply_env_overrides(Config())

        if not isinstance(data, dict):
            # e.g. the document is a list or a bare scalar; data.get would crash.
            print(f"配置文件解析错误: top-level YAML must be a mapping, got {type(data).__name__}")
            return _apply_env_overrides(Config())

        sections = {}
        for name, model in section_models.items():
            raw = data.get(name)
            if raw is None:
                # Missing key, or a key with no body (``server:``) which YAML
                # parses to None -- both mean "use the section defaults".
                raw = {}
            if not isinstance(raw, dict):
                print(f"配置文件解析错误: section '{name}' must be a mapping, got {type(raw).__name__}")
                return _apply_env_overrides(Config())
            sections[name] = model(**raw)

        config = Config(**sections)
    except (ValidationError, yaml.YAMLError) as e:
        print(f"配置文件解析错误: {e}")
        return _apply_env_overrides(Config())

    return _apply_env_overrides(config)
|
|
|
|
|
|
def get_config() -> Config:
    """Return the shared Config instance, creating it on first access.

    While nothing is cached, this calls ``load_config()`` with its default
    path resolution and stores the result at module level. Every later call
    returns that same object until ``set_config`` replaces it.
    """
    global _config
    cached = _config
    if cached is None:
        # First access: populate the module-level cache.
        cached = _config = load_config()
    return cached
|
|
|
|
|
|
def set_config(config: Config) -> None:
    """Install *config* as the shared configuration.

    Later ``get_config()`` calls return this object directly instead of
    calling ``load_config()``.

    Args:
        config: The configuration object to cache globally.
    """
    global _config
    # Overwrite the module-level cache unconditionally.
    _config = config
|
|
|
|
|
|
# Module-level cache behind get_config()/set_config(): None until the first
# get_config() call loads a config or set_config() installs one.
_config: Optional[Config] = None
|
|
|
|
|
|
def _apply_env_overrides(config: Config) -> Config:
    """Overlay backend settings from environment variables onto *config*.

    ``OPENVIKING_*`` variables update ``config.openviking`` and
    ``EVERMEMOS_*`` variables update ``config.evermemos``. The variable
    parsing is done by ``_backend_env_updates``. A backend with no matching
    variables is left untouched. *config* is modified in place and also
    returned.
    """
    # Resolve overrides for both backends before touching config, so that if
    # reading either backend's variables raises, config is left unchanged.
    pending = [
        (attr, _backend_env_updates(prefix))
        for prefix, attr in (("OPENVIKING", "openviking"), ("EVERMEMOS", "evermemos"))
    ]
    for attr, changes in pending:
        if not changes:
            continue
        # NOTE(review): pydantic's model_copy(update=...) does not validate the
        # update values, so env-derived values are trusted as-is.
        setattr(config, attr, getattr(config, attr).model_copy(update=changes))
    return config
|
|
|
|
|
|
def _backend_env_updates(prefix: str) -> dict:
|
|
updates = {}
|
|
env_map = {
|
|
"ENABLED": "enabled",
|
|
"MODE": "mode",
|
|
"BASE_URL": "url",
|
|
"URL": "url",
|
|
"API_KEY": "api_key",
|
|
"TOKEN": "api_key",
|
|
"TIMEOUT": "timeout",
|
|
"TIMEOUT_SECONDS": "timeout",
|
|
"VERIFY_SSL": "verify_ssl",
|
|
"INGEST_PATH": "ingest_path",
|
|
}
|
|
for env_name, field_name in env_map.items():
|
|
value = os.environ.get(f"{prefix}_{env_name}")
|
|
if value is None:
|
|
continue
|
|
if field_name == "enabled":
|
|
updates[field_name] = value.lower() in {"1", "true", "yes", "on"}
|
|
elif field_name == "timeout":
|
|
updates[field_name] = int(value)
|
|
elif field_name == "verify_ssl":
|
|
updates[field_name] = value.lower() not in {"0", "false", "no", "off"}
|
|
else:
|
|
updates[field_name] = value
|
|
return updates
|