"""Preservation checks for skill revision drafts."""

from __future__ import annotations

import re
from typing import Any

# ATX heading: one to six '#', whitespace, then the heading text.
_HEADING_RE = re.compile(r"^#{1,6}\s+(.+?)\s*$")

# Code-fence line: up to three spaces of indentation, then a run of at least
# three backticks or tildes, then an optional info string.
_FENCE_RE = re.compile(r"^ {0,3}(`{3,}|~{3,})(.*)$")

# Lower-cased heading names whose loss is reported separately by
# check_plugin_merge_preservation.
_IMPORTANT_SECTIONS = frozenset({"safety", "required tools", "required tool", "tools"})


def check_preservation(*, base_content: str, draft_content: str) -> dict[str, Any]:
    """Report which sections of ``base_content`` survive in ``draft_content``.

    Both texts are split into sections by heading (see ``_sections``).
    Headings are matched by exact text; bodies are compared after whitespace
    and case normalization.

    Args:
        base_content: The reference text whose sections must be kept.
        draft_content: The revised text being checked.

    Returns:
        A dict with:
        ``passed`` - True when no base section is missing from the draft;
        ``risk_level`` - ``"high"`` if any section was dropped, else ``"low"``;
        ``preserved_sections`` - base headings also present in the draft;
        ``changed_sections`` - preserved headings whose body differs;
        ``dropped_sections`` - base headings absent from the draft.
    """
    base_sections = _sections(base_content)
    draft_sections = _sections(draft_content)

    preserved: list[str] = []
    changed: list[str] = []
    dropped: list[str] = []
    for heading, body in base_sections.items():
        draft_body = draft_sections.get(heading)
        if draft_body is None:
            dropped.append(heading)
            continue
        preserved.append(heading)
        if _normalize(body) != _normalize(draft_body):
            changed.append(heading)

    return {
        "passed": not dropped,
        "risk_level": "high" if dropped else "low",
        "preserved_sections": preserved,
        "changed_sections": changed,
        "dropped_sections": dropped,
    }


def check_plugin_merge_preservation(
    *,
    local_content: str,
    upstream_content: str,
    draft_content: str,
    merge_decisions: dict[str, Any],
) -> dict[str, Any]:
    """Check a merged draft against both its local and upstream sources.

    Runs ``check_preservation`` twice (local -> draft, upstream -> draft) and
    combines the results with the conflict lists found in ``merge_decisions``.

    Args:
        local_content: The local version of the text.
        upstream_content: The upstream version of the text.
        draft_content: The merged draft being checked.
        merge_decisions: Mapping that may carry ``"unresolved_conflicts"`` and
            ``"resolved_conflicts"`` iterables; missing or empty values (and a
            ``None`` mapping) are treated as no conflicts.

    Returns:
        A dict with ``mode`` (always ``"plugin_three_way"``), ``passed``,
        ``risk_level``, the two nested reports under ``local`` and
        ``upstream``, ``unresolved_conflicts``, ``safety_sections_missing``
        and ``resolved_conflicts``. ``passed`` is True only when both nested
        reports passed, nothing is unresolved, and no important section was
        dropped.
    """
    decisions = merge_decisions or {}
    local = check_preservation(base_content=local_content, draft_content=draft_content)
    upstream = check_preservation(base_content=upstream_content, draft_content=draft_content)

    unresolved = [str(item) for item in decisions.get("unresolved_conflicts") or []]
    resolved = [str(item) for item in decisions.get("resolved_conflicts") or []]
    # Upstream is passed first, so its dropped sections are listed first.
    safety_sections_missing = _important_sections_missing(upstream, local)

    passed = (
        bool(local.get("passed"))
        and bool(upstream.get("passed"))
        and not unresolved
        and not safety_sections_missing
    )
    return {
        "mode": "plugin_three_way",
        "passed": passed,
        "risk_level": "high" if not passed else "low",
        "local": local,
        "upstream": upstream,
        "unresolved_conflicts": unresolved,
        "safety_sections_missing": safety_sections_missing,
        "resolved_conflicts": resolved,
    }


def _sections(content: str) -> dict[str, str]:
    """Split ``content`` into ``{heading text: stripped body}``.

    Text before the first heading is stored under the key ``"body"``. Bodies
    of repeated headings are concatenated, and sections whose body is empty
    after stripping are omitted.

    Lines inside a fenced code block are always body text, so a ``# comment``
    in a code sample is not mistaken for a section heading. An unclosed fence
    extends to the end of the text.
    """
    current = "body"
    sections: dict[str, list[str]] = {current: []}
    open_fence = ""  # marker that opened the current code fence; "" when outside
    for line in (content or "").splitlines():
        fence = _FENCE_RE.match(line)
        if open_fence:
            marker, rest = fence.groups() if fence else ("", "")
            # A closing fence uses the same character, is at least as long as
            # the opening one, and carries no info string.
            if (
                marker[:1] == open_fence[0]
                and len(marker) >= len(open_fence)
                and not rest.strip()
            ):
                open_fence = ""
            sections[current].append(line)
            continue
        # A backtick fence may not have backticks in its info string;
        # otherwise the line is inline code such as ```x```.
        if fence and not (fence.group(1)[0] == "`" and "`" in fence.group(2)):
            open_fence = fence.group(1)
            sections[current].append(line)
            continue
        heading = _HEADING_RE.match(line)
        if heading:
            current = heading.group(1).strip()
            sections.setdefault(current, [])
            continue
        sections[current].append(line)

    bodies = ((name, "\n".join(lines).strip()) for name, lines in sections.items())
    return {name: body for name, body in bodies if body}


def _normalize(value: str) -> str:
    """Collapse whitespace runs to single spaces, strip, and lower-case."""
    return re.sub(r"\s+", " ", value or "").strip().lower()


def _important_sections_missing(*reports: dict[str, Any]) -> list[str]:
    """Collect dropped sections whose name is in ``_IMPORTANT_SECTIONS``.

    Names are compared after stripping and lower-casing. The result keeps the
    original spelling, in first-seen order across ``reports``, without
    duplicates.
    """
    missing: list[str] = []
    for report in reports:
        for section in report.get("dropped_sections") or []:
            name = str(section)
            if name.strip().lower() in _IMPORTANT_SECTIONS and name not in missing:
                missing.append(name)
    return missing