Initial SOC memory POC implementation

This commit is contained in:
2026-04-27 17:13:06 +08:00
parent fc68581198
commit e6b1520bce
89 changed files with 7610 additions and 1 deletions

View File

@ -0,0 +1,42 @@
# retrieve_context_skill
这个 skill 用于根据当前 case 的关键信号,从 OpenViking 或 mock dataset 中召回最相关的上下文。
## 目标
输入当前 case 的场景、告警类型、IOC、描述,输出一组排序后的相关内容:
- 相似历史 case
- 相关 KB
- 相关 Playbook
- 关键 decision points
## 第一阶段输入
- `scenario`
- `alert_type`
- `summary`
- `entities`
- `observables`
- `top_k`
## 第一阶段输出
- `matched_cases`
- `matched_knowledge`
- `decision_points`
- `next_actions`
## 第一阶段检索策略
1. 先按 `scenario` 过滤
2. 再按 `alert_type`、IOC、关键词做匹配
3. 再按 evidence / tags 做轻量重排序
4. 输出 top-k
## 第一阶段不做
- 向量检索
- 图检索
- 个性化排序
- 多源复杂重排

View File

@ -0,0 +1,39 @@
# retrieve_context_skill
## 用途
在 SOC case 研判时,为 agent 检索最相关的历史 case 和知识上下文。
## 输入
- `scenario`: 场景,如 `phishing` 或 `o365_suspicious_login`
- `alert_type`: 告警类型
- `summary`: 当前 case 摘要
- `entities`: 用户、主机、邮箱等
- `observables`: 域名、IP、URL、Hash 等
- `top_k`: 期望返回条数
## 输出
- 相关历史 case 列表
- 相关 KB / Playbook 列表
- 关键 evidence / decision points
- 推荐下一步调查动作
## 默认检索顺序
1. `session/<session_id>`
2. `soc/case`
3. `soc/knowledge`
4. `agent/<agent_id>`
5. `user/<user_id>`
## Mock 阶段工作方式
在没有真实数据和完整 OpenViking 检索链路时,先使用 `evaluation/datasets/mock_cases/` 和 `evaluation/datasets/mock_kb/` 做本地检索验证。
## 成功标准
- 钓鱼 case 能召回钓鱼 playbook 和相似 phishing case
- O365 异常登录 case 能召回登录异常 KB 和相似 case
- 返回结果对人工 reviewer 看起来是“有帮助的上下文”,而不是泛资料堆积

View File

@ -0,0 +1,216 @@
"""Retrieval entrypoint for SOC Memory POC.
Supports two modes:
- local: retrieve from normalized mock datasets
- openviking: retrieve from OpenViking resource namespaces and filter results
"""
from __future__ import annotations
import asyncio
import json
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any
from memory_gateway.openviking_client import OpenVikingClient
# Default OpenViking resource namespaces searched by the "openviking" backend.
CASE_URI_PREFIX = "viking://resources/soc-memory-poc/case"
KNOWLEDGE_URI_PREFIX = "viking://resources/soc-memory-poc/knowledge"
def _load_json_dir(path: str | Path) -> list[dict[str, Any]]:
path = Path(path)
items: list[dict[str, Any]] = []
for file in sorted(path.rglob("*.json")):
with file.open("r", encoding="utf-8") as f:
items.append(json.load(f))
return items
@dataclass
class RetrievalQuery:
    """Normalized retrieval request for SOC context lookup."""
    # Scenario label, e.g. "phishing" or "o365_suspicious_login".
    scenario: str
    # Alert type string; empty when unknown.
    alert_type: str = ""
    # Free-text summary of the current case.
    summary: str = ""
    # Mapping of entity kind -> values (users, hosts, mailboxes, ...).
    entities: dict[str, list[str]] | None = None
    # Mapping of IOC kind -> values (domains, IPs, URLs, hashes, ...).
    observables: dict[str, list[str]] | None = None
    # Maximum number of results returned per result category.
    top_k: int = 3
def _flatten_values(data: dict[str, list[str]] | None) -> set[str]:
if not data:
return set()
values: set[str] = set()
for items in data.values():
values.update(str(item).lower() for item in items)
return values
def _score_case(query: RetrievalQuery, item: dict[str, Any]) -> int:
    """Heuristic relevance score of a normalized case *item* for *query*.

    Weights: 50 for a scenario match, 20 per matching ``alert_type:*`` pattern,
    8 per shared observable, 2 per long (>4 char) summary token found in the
    item's title or abstract.
    """
    total = 0
    # Strongest signal: same scenario.
    if item.get("scenario") == query.scenario:
        total += 50
    # Pattern tags of the form "alert_type:<type>".
    if query.alert_type:
        wanted = f"alert_type:{query.alert_type}"
        total += 20 * sum(1 for pattern in item.get("patterns", []) if pattern == wanted)
    # Overlapping IOCs, compared case-insensitively.
    shared = _flatten_values(query.observables) & _flatten_values(item.get("observables"))
    total += 8 * len(shared)
    # Lightweight keyword overlap between the summary and the item's text fields.
    texts = [item.get("title", "").lower(), item.get("abstract", "").lower()]
    for word in query.summary.lower().split():
        if len(word) > 4 and any(word in text for text in texts):
            total += 2
    return total
def _score_knowledge(query: RetrievalQuery, item: dict[str, Any]) -> int:
score = 0
if item.get("scenario") == query.scenario:
score += 40
title = item.get("title", "").lower()
abstract = item.get("abstract", "").lower()
for token in [t for t in query.summary.lower().split() if len(t) > 4]:
if token in title or token in abstract:
score += 2
if query.alert_type and query.alert_type in " ".join(item.get("related_refs", {}).get("cases", [])).lower():
score += 5
return score
def retrieve_context_local(
    query: RetrievalQuery,
    cases_dir: str | Path = "evaluation/datasets/normalized_cases",
    knowledge_dir: str | Path = "evaluation/datasets/normalized_kb",
) -> dict[str, Any]:
    """Rank mock-dataset cases and knowledge for *query* and assemble a response.

    Returns a dict holding the positively-scored top-k cases and knowledge
    items, plus decision points and next actions harvested from the matched
    knowledge entries (each truncated to ``query.top_k``).
    """

    def _rank(items: list[dict[str, Any]], scorer) -> list[dict[str, Any]]:
        # Score, sort descending (stable), then keep positive scores only.
        scored = [{"score": scorer(query, item), "item": item} for item in items]
        scored.sort(key=lambda entry: entry["score"], reverse=True)
        return [entry for entry in scored if entry["score"] > 0][: query.top_k]

    matched_cases = _rank(_load_json_dir(cases_dir), _score_case)
    matched_knowledge = _rank(_load_json_dir(knowledge_dir), _score_knowledge)

    decision_points: list[str] = []
    next_actions: list[str] = []
    for entry in matched_knowledge:
        knowledge_item = entry["item"]
        decision_points.extend(knowledge_item.get("decision_points", []))
        next_actions.extend(knowledge_item.get("investigation_guidance", []))

    return {
        "backend": "local",
        "query": asdict(query),
        "matched_cases": matched_cases,
        "matched_knowledge": matched_knowledge,
        "decision_points": decision_points[: query.top_k],
        "next_actions": next_actions[: query.top_k],
    }
def _canonicalize_resource_uri(uri: str) -> str:
if ".json/" in uri:
return uri.split(".json/", 1)[0] + ".json"
return uri
def _query_text(query: RetrievalQuery) -> str:
    """Build a flat search string from scenario, alert type, summary and sorted IOCs."""
    pieces: list[str] = [query.scenario, query.alert_type, query.summary]
    pieces += sorted(_flatten_values(query.observables))
    return " ".join(piece for piece in pieces if piece).strip()
def _dedupe_openviking_results(results: list[dict[str, Any]], prefix: str) -> list[dict[str, Any]]:
deduped: dict[str, dict[str, Any]] = {}
for item in results:
uri = item.get("uri") or ""
if not uri.startswith(prefix):
continue
canonical_uri = _canonicalize_resource_uri(uri)
score = item.get("score") or 0
existing = deduped.get(canonical_uri)
payload = {
"uri": canonical_uri,
"abstract": item.get("abstract", ""),
"score": score,
"context_type": item.get("context_type"),
"source_uri": uri,
}
if existing is None or score > existing.get("score", 0):
deduped[canonical_uri] = payload
return sorted(deduped.values(), key=lambda x: x["score"], reverse=True)
async def retrieve_context_openviking(
    query: RetrievalQuery,
    case_uri: str = CASE_URI_PREFIX,
    knowledge_uri: str = KNOWLEDGE_URI_PREFIX,
) -> dict[str, Any]:
    """Search the OpenViking case and knowledge namespaces and return top-k matches.

    Decision points / next actions are left empty for this backend.
    The client is always closed, even when a search raises.
    """
    client = OpenVikingClient()
    try:
        text = _query_text(query)
        # Over-fetch so chunk-level duplicates can be collapsed before truncating.
        fetch_limit = max(query.top_k * 5, 10)
        case_hits = await client.search(query=text, uri=case_uri, limit=fetch_limit)
        knowledge_hits = await client.search(query=text, uri=knowledge_uri, limit=fetch_limit)
        return {
            "backend": "openviking",
            "query": asdict(query),
            "matched_cases": _dedupe_openviking_results(case_hits.results, case_uri)[: query.top_k],
            "matched_knowledge": _dedupe_openviking_results(knowledge_hits.results, knowledge_uri)[: query.top_k],
            "decision_points": [],
            "next_actions": [],
        }
    finally:
        await client.close()
def main() -> None:
    """CLI entrypoint: run a retrieval query against the chosen backend and print JSON.

    Fix: the CLI previously never populated ``entities``/``observables`` even
    though ``RetrievalQuery`` declares them and the local scorer uses
    observables for IOC overlap; both are now accepted as optional JSON dicts.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Retrieve SOC context from local datasets or OpenViking.")
    parser.add_argument("--backend", choices=["local", "openviking"], default="local", help="Retrieval backend")
    parser.add_argument("--scenario", required=True, help="Scenario, e.g. phishing or o365_suspicious_login")
    parser.add_argument("--alert-type", default="", help="Alert type")
    parser.add_argument("--summary", default="", help="Short case summary")
    parser.add_argument("--entities", default=None, help='JSON dict of entity lists, e.g. \'{"user": ["alice"]}\'')
    parser.add_argument("--observables", default=None, help='JSON dict of IOC lists, e.g. \'{"domain": ["evil.com"]}\'')
    parser.add_argument("--top-k", type=int, default=3, help="Number of results to return")
    parser.add_argument("--cases-dir", default="evaluation/datasets/normalized_cases", help="Normalized case dataset directory")
    parser.add_argument("--knowledge-dir", default="evaluation/datasets/normalized_kb", help="Normalized knowledge dataset directory")
    parser.add_argument("--case-uri", default=CASE_URI_PREFIX, help="OpenViking case URI prefix")
    parser.add_argument("--knowledge-uri", default=KNOWLEDGE_URI_PREFIX, help="OpenViking knowledge URI prefix")
    args = parser.parse_args()

    query = RetrievalQuery(
        scenario=args.scenario,
        alert_type=args.alert_type,
        summary=args.summary,
        # Unset options stay None, preserving the previous default behavior.
        entities=json.loads(args.entities) if args.entities else None,
        observables=json.loads(args.observables) if args.observables else None,
        top_k=args.top_k,
    )
    if args.backend == "openviking":
        result = asyncio.run(
            retrieve_context_openviking(query, case_uri=args.case_uri, knowledge_uri=args.knowledge_uri)
        )
    else:
        result = retrieve_context_local(query, cases_dir=args.cases_dir, knowledge_dir=args.knowledge_dir)
    print(json.dumps(result, ensure_ascii=False, indent=2))


if __name__ == "__main__":
    main()