"""Minimal EverMemOS-style consolidation worker. This worker is deliberately deterministic for the POC. It extracts stable candidate memories from session episodes, deduplicates them against existing records, promotes eligible records, and sends high-risk/high-value candidates to Obsidian review rather than blindly polluting long-term memory. """ from __future__ import annotations import hashlib import re from dataclasses import dataclass, field from memory_gateway.namespace import default_namespace_for_context from memory_gateway.obsidian_review import write_review_draft from memory_gateway.repositories import MetadataRepository from memory_gateway.schemas import ( AccessContext, EpisodeRecord, MemoryRecord, MemoryType, SourceType, Visibility, ) _SENTENCE_RE = re.compile(r"(?<=[。!?.!?])\s+|\n+") _NOISE_RE = re.compile(r"\s+") @dataclass class ConsolidationResult: session_id: str episodes: int candidates: list[MemoryRecord] = field(default_factory=list) promoted: list[MemoryRecord] = field(default_factory=list) duplicates: list[dict] = field(default_factory=list) review_drafts: list[str] = field(default_factory=list) conflicts: list[dict] = field(default_factory=list) class EverMemOSWorker: def __init__(self, repo: MetadataRepository) -> None: self.repo = repo def consolidate_session( self, session_id: str, ctx: AccessContext, min_importance: float = 0.6, target_namespace: str | None = None, ) -> ConsolidationResult: episodes = self.repo.list_session_episodes(session_id) result = ConsolidationResult(session_id=session_id, episodes=len(episodes)) existing = list(self.repo.list_memories()) seen_fingerprints = {self._fingerprint(memory.content): memory for memory in existing} for episode in episodes: for candidate in self._extract_candidates(episode, ctx, min_importance, target_namespace): result.candidates.append(candidate) fingerprint = self._fingerprint(candidate.content) duplicate = seen_fingerprints.get(fingerprint) if duplicate: result.duplicates.append({"candidate_id": candidate.id, "existing_id": duplicate.id}) continue conflict_ids = self._find_conflicts(candidate, existing) if conflict_ids: draft = write_review_draft(candidate, reason="conflict", conflict_ids=conflict_ids) result.review_drafts.append(str(draft)) result.conflicts.append({"candidate_id": candidate.id, "conflict_ids": conflict_ids}) continue if candidate.importance >= 0.85: draft = write_review_draft(candidate, reason="high_value") result.review_drafts.append(str(draft)) continue if candidate.importance >= min_importance and candidate.confidence >= 0.55: self.repo.upsert_memory(candidate) result.promoted.append(candidate) seen_fingerprints[fingerprint] = candidate existing.append(candidate) return result def _extract_candidates( self, episode: EpisodeRecord, ctx: AccessContext, min_importance: float, target_namespace: str | None, ) -> list[MemoryRecord]: text = episode.summary or episode.content parts = [self._normalize(part) for part in _SENTENCE_RE.split(text) if self._normalize(part)] candidates: list[MemoryRecord] = [] for part in parts: if len(part) < 20: continue memory_type = self._classify_type(part, episode.tags) importance = self._estimate_importance(part, episode.tags, min_importance) confidence = 0.65 if episode.summary else 0.58 visibility = Visibility.WORKSPACE_SHARED if "workspace" in episode.tags and ctx.workspace_id else Visibility.PRIVATE memory_ctx = AccessContext( user_id=ctx.user_id, agent_id=ctx.agent_id, workspace_id=ctx.workspace_id, session_id=ctx.session_id, ) candidates.append( MemoryRecord( 
                    user_id=ctx.user_id,
                    agent_id=ctx.agent_id,
                    workspace_id=ctx.workspace_id,
                    session_id=episode.session_id,
                    namespace=target_namespace or default_namespace_for_context(memory_ctx, visibility),
                    memory_type=memory_type,
                    content=part,
                    summary=part[:180],
                    tags=list(set(episode.tags + ["promoted-from-session", "evermemos-candidate"])),
                    importance=importance,
                    confidence=confidence,
                    visibility=visibility,
                    source=SourceType.EVERMEMOS,
                    source_ref=episode.id,
                )
            )
        return candidates

    def _classify_type(self, text: str, tags: list[str]) -> MemoryType:
        lowered = text.lower()
        # Chinese signal words: 偏好 "preference", 决定/决策 "decision",
        # 步骤 "steps", 流程 "procedure/workflow", 经验 "experience".
        if "preference" in tags or "偏好" in text:
            return MemoryType.PREFERENCE
        if "decision" in tags or "决定" in text or "决策" in text:
            return MemoryType.DECISION
        if "procedure" in tags or "步骤" in text or "流程" in text:
            return MemoryType.PROCEDURE
        if "经验" in text or "worked" in lowered or "failed" in lowered:
            return MemoryType.EXPERIENCE
        return MemoryType.SUMMARY

    def _estimate_importance(self, text: str, tags: list[str], min_importance: float) -> float:
        importance = max(min_importance, 0.6)
        # Mixed-language signal words: 必须 "must", 不要 "don't", 偏好 "preference",
        # 长期 "long-term", 决策 "decision", 结论 "conclusion", 重要 "important".
        signal_words = ["必须", "不要", "偏好", "长期", "决策", "结论", "重要", "preference", "decision", "must"]
        if any(word in text.lower() for word in signal_words):
            importance += 0.15
        if "review" in tags or "high-value" in tags:
            importance += 0.2
        return min(1.0, importance)

    def _find_conflicts(self, candidate: MemoryRecord, existing: list[MemoryRecord]) -> list[str]:
        candidate_text = candidate.content.lower()
        # Polarity cues: 不要 "don't", 不再 "no longer", 禁止 "forbid";
        # 需要 "need", 必须 "must", 启用 "enable".
        negation_signals = ["不要", "不再", "禁止", "not ", "never", "disable"]
        positive_signals = ["需要", "必须", "启用", "prefer", "always", "enable"]
        has_negative = any(signal in candidate_text for signal in negation_signals)
        has_positive = any(signal in candidate_text for signal in positive_signals)
        if not has_negative and not has_positive:
            return []
        candidate_tokens = self._tokens(candidate.content)
        conflicts = []
        for memory in existing:
            if memory.user_id != candidate.user_id:
                continue
            if memory.memory_type != candidate.memory_type:
                continue
            # Require some lexical overlap before calling two memories related.
            overlap = candidate_tokens.intersection(self._tokens(memory.content))
            if len(overlap) < 2:
                continue
            memory_text = memory.content.lower()
            memory_negative = any(signal in memory_text for signal in negation_signals)
            memory_positive = any(signal in memory_text for signal in positive_signals)
            # Opposite polarity on overlapping content is treated as a conflict.
            if has_negative != memory_negative or has_positive != memory_positive:
                conflicts.append(memory.id)
        return conflicts

    def _tokens(self, text: str) -> set[str]:
        return {
            token
            for token in re.split(r"[^a-zA-Z0-9\u4e00-\u9fff]+", text.lower())
            if len(token) >= 2
        }

    def _normalize(self, text: str) -> str:
        return _NOISE_RE.sub(" ", text).strip(" -_*#\t")

    def _fingerprint(self, text: str) -> str:
        # SHA-1 is fine here: this is a dedup fingerprint, not a security hash.
        normalized = self._normalize(text).lower()
        return hashlib.sha1(normalized.encode("utf-8")).hexdigest()
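

if __name__ == "__main__":
    # Minimal smoke-test sketch of the consolidation flow, not part of the worker
    # itself. It assumes MetadataRepository can be built with no arguments for a
    # local POC; the real repository may need connection settings. The ids in the
    # AccessContext below are illustrative placeholders, not real fixtures.
    repo = MetadataRepository()  # assumed no-arg constructor
    worker = EverMemOSWorker(repo)
    ctx = AccessContext(
        user_id="user-1",  # hypothetical ids for the sketch
        agent_id="agent-1",
        workspace_id="ws-1",
        session_id="session-1",
    )
    outcome = worker.consolidate_session("session-1", ctx)
    print(
        f"episodes={outcome.episodes} candidates={len(outcome.candidates)} "
        f"promoted={len(outcome.promoted)} duplicates={len(outcome.duplicates)} "
        f"review_drafts={len(outcome.review_drafts)}"
    )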