"""Retrieval entrypoint for SOC Memory POC. Supports two modes: - local: retrieve from normalized mock datasets - openviking: retrieve from OpenViking resource namespaces and filter results """ from __future__ import annotations import asyncio import json from dataclasses import asdict, dataclass from pathlib import Path from typing import Any from memory_gateway.openviking_client import OpenVikingClient CASE_URI_PREFIX = "viking://resources/soc-memory-poc/case" KNOWLEDGE_URI_PREFIX = "viking://resources/soc-memory-poc/knowledge" def _load_json_dir(path: str | Path) -> list[dict[str, Any]]: path = Path(path) items: list[dict[str, Any]] = [] for file in sorted(path.rglob("*.json")): with file.open("r", encoding="utf-8") as f: items.append(json.load(f)) return items @dataclass class RetrievalQuery: scenario: str alert_type: str = "" summary: str = "" entities: dict[str, list[str]] | None = None observables: dict[str, list[str]] | None = None top_k: int = 3 def _flatten_values(data: dict[str, list[str]] | None) -> set[str]: if not data: return set() values: set[str] = set() for items in data.values(): values.update(str(item).lower() for item in items) return values def _score_case(query: RetrievalQuery, item: dict[str, Any]) -> int: score = 0 if item.get("scenario") == query.scenario: score += 50 for pattern in item.get("patterns", []): if query.alert_type and pattern == f"alert_type:{query.alert_type}": score += 20 query_observables = _flatten_values(query.observables) item_observables = _flatten_values(item.get("observables")) score += 8 * len(query_observables & item_observables) summary = query.summary.lower() haystacks = [item.get("title", "").lower(), item.get("abstract", "").lower()] for token in [t for t in summary.split() if len(t) > 4]: if any(token in text for text in haystacks): score += 2 return score def _score_knowledge(query: RetrievalQuery, item: dict[str, Any]) -> int: score = 0 if item.get("scenario") == query.scenario: score += 40 title = item.get("title", "").lower() abstract = item.get("abstract", "").lower() for token in [t for t in query.summary.lower().split() if len(t) > 4]: if token in title or token in abstract: score += 2 if query.alert_type and query.alert_type in " ".join(item.get("related_refs", {}).get("cases", [])).lower(): score += 5 return score def retrieve_context_local( query: RetrievalQuery, cases_dir: str | Path = "evaluation/datasets/normalized_cases", knowledge_dir: str | Path = "evaluation/datasets/normalized_kb", ) -> dict[str, Any]: cases = _load_json_dir(cases_dir) knowledge = _load_json_dir(knowledge_dir) ranked_cases = sorted( ({"score": _score_case(query, item), "item": item} for item in cases), key=lambda x: x["score"], reverse=True, ) ranked_knowledge = sorted( ({"score": _score_knowledge(query, item), "item": item} for item in knowledge), key=lambda x: x["score"], reverse=True, ) matched_cases = [entry for entry in ranked_cases if entry["score"] > 0][: query.top_k] matched_knowledge = [entry for entry in ranked_knowledge if entry["score"] > 0][: query.top_k] decision_points: list[str] = [] next_actions: list[str] = [] for entry in matched_knowledge: item = entry["item"] decision_points.extend(item.get("decision_points", [])) next_actions.extend(item.get("investigation_guidance", [])) return { "backend": "local", "query": asdict(query), "matched_cases": matched_cases, "matched_knowledge": matched_knowledge, "decision_points": decision_points[: query.top_k], "next_actions": next_actions[: query.top_k], } def _canonicalize_resource_uri(uri: str) -> str: if ".json/" in uri: return uri.split(".json/", 1)[0] + ".json" return uri def _query_text(query: RetrievalQuery) -> str: parts = [query.scenario, query.alert_type, query.summary] parts.extend(sorted(_flatten_values(query.observables))) return " ".join(part for part in parts if part).strip() def _dedupe_openviking_results(results: list[dict[str, Any]], prefix: str) -> list[dict[str, Any]]: deduped: dict[str, dict[str, Any]] = {} for item in results: uri = item.get("uri") or "" if not uri.startswith(prefix): continue canonical_uri = _canonicalize_resource_uri(uri) score = item.get("score") or 0 existing = deduped.get(canonical_uri) payload = { "uri": canonical_uri, "abstract": item.get("abstract", ""), "score": score, "context_type": item.get("context_type"), "source_uri": uri, } if existing is None or score > existing.get("score", 0): deduped[canonical_uri] = payload return sorted(deduped.values(), key=lambda x: x["score"], reverse=True) async def retrieve_context_openviking( query: RetrievalQuery, case_uri: str = CASE_URI_PREFIX, knowledge_uri: str = KNOWLEDGE_URI_PREFIX, ) -> dict[str, Any]: client = OpenVikingClient() try: query_text = _query_text(query) case_result = await client.search(query=query_text, uri=case_uri, limit=max(query.top_k * 5, 10)) knowledge_result = await client.search(query=query_text, uri=knowledge_uri, limit=max(query.top_k * 5, 10)) matched_cases = _dedupe_openviking_results(case_result.results, case_uri)[: query.top_k] matched_knowledge = _dedupe_openviking_results(knowledge_result.results, knowledge_uri)[: query.top_k] return { "backend": "openviking", "query": asdict(query), "matched_cases": matched_cases, "matched_knowledge": matched_knowledge, "decision_points": [], "next_actions": [], } finally: await client.close() def main() -> None: import argparse parser = argparse.ArgumentParser(description="Retrieve SOC context from local datasets or OpenViking.") parser.add_argument("--backend", choices=["local", "openviking"], default="local", help="Retrieval backend") parser.add_argument("--scenario", required=True, help="Scenario, e.g. phishing or o365_suspicious_login") parser.add_argument("--alert-type", default="", help="Alert type") parser.add_argument("--summary", default="", help="Short case summary") parser.add_argument("--top-k", type=int, default=3, help="Number of results to return") parser.add_argument("--cases-dir", default="evaluation/datasets/normalized_cases", help="Normalized case dataset directory") parser.add_argument("--knowledge-dir", default="evaluation/datasets/normalized_kb", help="Normalized knowledge dataset directory") parser.add_argument("--case-uri", default=CASE_URI_PREFIX, help="OpenViking case URI prefix") parser.add_argument("--knowledge-uri", default=KNOWLEDGE_URI_PREFIX, help="OpenViking knowledge URI prefix") args = parser.parse_args() query = RetrievalQuery( scenario=args.scenario, alert_type=args.alert_type, summary=args.summary, top_k=args.top_k, ) if args.backend == "openviking": result = asyncio.run(retrieve_context_openviking(query, case_uri=args.case_uri, knowledge_uri=args.knowledge_uri)) else: result = retrieve_context_local(query, cases_dir=args.cases_dir, knowledge_dir=args.knowledge_dir) print(json.dumps(result, ensure_ascii=False, indent=2)) if __name__ == "__main__": main()