Refactor code structure for improved readability and maintainability
This commit is contained in:
@ -18,16 +18,16 @@ def load_config(config_path: Optional[str] = None) -> Config:
|
||||
|
||||
if not config_file.exists():
|
||||
# 返回默认配置
|
||||
return Config()
|
||||
return _apply_env_overrides(Config())
|
||||
|
||||
try:
|
||||
with open(config_file, "r", encoding="utf-8") as f:
|
||||
data = yaml.safe_load(f)
|
||||
|
||||
if data is None:
|
||||
return Config()
|
||||
return _apply_env_overrides(Config())
|
||||
|
||||
return Config(
|
||||
config = Config(
|
||||
server=ServerConfig(**data.get("server", {})),
|
||||
openviking=OpenVikingConfig(**data.get("openviking", {})),
|
||||
evermemos=EverMemOSConfig(**data.get("evermemos", {})),
|
||||
@ -37,9 +37,10 @@ def load_config(config_path: Optional[str] = None) -> Config:
|
||||
obsidian=ObsidianConfig(**data.get("obsidian", {})),
|
||||
storage=StorageConfig(**data.get("storage", {})),
|
||||
)
|
||||
return _apply_env_overrides(config)
|
||||
except (ValidationError, yaml.YAMLError) as e:
|
||||
print(f"配置文件解析错误: {e}")
|
||||
return Config()
|
||||
return _apply_env_overrides(Config())
|
||||
|
||||
|
||||
def get_config() -> Config:
|
||||
@ -57,3 +58,42 @@ def set_config(config: Config) -> None:
|
||||
|
||||
|
||||
_config: Optional[Config] = None
|
||||
|
||||
|
||||
def _apply_env_overrides(config: Config) -> Config:
|
||||
openviking_updates = _backend_env_updates("OPENVIKING")
|
||||
evermemos_updates = _backend_env_updates("EVERMEMOS")
|
||||
if openviking_updates:
|
||||
config.openviking = config.openviking.model_copy(update=openviking_updates)
|
||||
if evermemos_updates:
|
||||
config.evermemos = config.evermemos.model_copy(update=evermemos_updates)
|
||||
return config
|
||||
|
||||
|
||||
def _backend_env_updates(prefix: str) -> dict:
|
||||
updates = {}
|
||||
env_map = {
|
||||
"ENABLED": "enabled",
|
||||
"MODE": "mode",
|
||||
"BASE_URL": "url",
|
||||
"URL": "url",
|
||||
"API_KEY": "api_key",
|
||||
"TOKEN": "api_key",
|
||||
"TIMEOUT": "timeout",
|
||||
"TIMEOUT_SECONDS": "timeout",
|
||||
"VERIFY_SSL": "verify_ssl",
|
||||
"INGEST_PATH": "ingest_path",
|
||||
}
|
||||
for env_name, field_name in env_map.items():
|
||||
value = os.environ.get(f"{prefix}_{env_name}")
|
||||
if value is None:
|
||||
continue
|
||||
if field_name == "enabled":
|
||||
updates[field_name] = value.lower() in {"1", "true", "yes", "on"}
|
||||
elif field_name == "timeout":
|
||||
updates[field_name] = int(value)
|
||||
elif field_name == "verify_ssl":
|
||||
updates[field_name] = value.lower() not in {"0", "false", "no", "off"}
|
||||
else:
|
||||
updates[field_name] = value
|
||||
return updates
|
||||
|
||||
Reference in New Issue
Block a user