#!/usr/bin/env python3
"""Structured SOC triage combining Memory Gateway retrieval and Obsidian lookup.

Given alert attributes on the command line, the script builds one search
query, retrieves related historical cases / knowledge entries from the
Memory Gateway and matching Obsidian vault documents, then prints a
markdown-formatted triage report to stdout.
"""
from __future__ import annotations

import argparse
import json
import os
import urllib.error
import urllib.request
from pathlib import Path
from typing import Any

# Defaults are overridable through the environment so the script runs both on
# the PoC host and elsewhere.
DEFAULT_GATEWAY_URL = os.environ.get("SOC_MEMORY_GATEWAY_URL", "http://127.0.0.1:1934")
DEFAULT_GATEWAY_API_KEY = os.environ.get("SOC_MEMORY_GATEWAY_API_KEY", "")
DEFAULT_POC_ROOT = os.environ.get("SOC_MEMORY_POC_ROOT", "/home/tom/soc_memory_poc")
DEFAULT_VAULT_ROOT = str(Path(DEFAULT_POC_ROOT) / "obsidian-vault")

# URI prefixes of the two memory collections queried below.
CASE_URI = "viking://resources/soc-memory-poc/case"
KNOWLEDGE_URI = "viking://resources/soc-memory-poc/knowledge"


def post_json(url: str, payload: dict[str, Any], api_key: str = "") -> dict[str, Any]:
    """POST *payload* as JSON to *url* and return the decoded JSON response.

    Raises ``urllib.error.URLError`` (including ``HTTPError``) on transport
    failures; the caller in :func:`main` converts that into a report note.
    """
    data = json.dumps(payload).encode("utf-8")
    req = urllib.request.Request(url, data=data, method="POST")
    req.add_header("Content-Type", "application/json")
    if api_key:
        req.add_header("X-API-Key", api_key)
    with urllib.request.urlopen(req, timeout=30) as resp:
        return json.loads(resp.read().decode("utf-8"))


def canonicalize_uri(uri: str) -> str:
    """Strip any fragment after the ``.json`` segment so chunk URIs dedupe."""
    if ".json/" in uri:
        return uri.split(".json/", 1)[0] + ".json"
    return uri


def filter_results(results: list[dict[str, Any]], prefix: str) -> list[dict[str, Any]]:
    """Keep results under *prefix*, deduped by canonical URI (best score wins).

    Returns the surviving entries sorted by score, highest first.
    """
    deduped: dict[str, dict[str, Any]] = {}
    for item in results:
        canonical = canonicalize_uri(item.get("uri") or "")
        if not canonical.startswith(prefix):
            continue
        score = item.get("score") or 0
        payload = dict(item)
        payload["uri"] = canonical
        # Retain only the highest-scoring entry per canonical URI.
        if canonical not in deduped or score > (deduped[canonical].get("score") or 0):
            deduped[canonical] = payload
    return sorted(deduped.values(), key=lambda entry: entry.get("score") or 0, reverse=True)


def gateway_search(query: str, uri: str, limit: int, gateway_url: str, api_key: str) -> list[dict[str, Any]]:
    """Search the Memory Gateway under *uri* and return up to *limit* hits.

    Over-fetches (5x, minimum 10) because dedup in :func:`filter_results`
    can collapse several raw chunk hits into one canonical entry.
    """
    payload = {"query": query, "limit": max(limit * 5, 10), "uri": uri}
    raw = post_json(gateway_url.rstrip("/") + "/api/search", payload, api_key=api_key)
    return filter_results(raw.get("results", []), uri)[:limit]


def obsidian_search(query: str, scenario: str, limit: int, vault_root: str) -> dict[str, Any]:
    """Score vault documents against *query*, optionally filtered by scenario.

    Imports the project-local ``search_obsidian_docs`` lazily so the rest of
    the script works even when that module is unavailable (the caller wraps
    this in a broad try/except).
    """
    from search_obsidian_docs import load_docs, score_doc, tokenize

    docs = load_docs(vault_root)
    tokens = tokenize(query)
    results: list[dict[str, Any]] = []
    for doc in docs:
        doc_scenario = doc.get("frontmatter", {}).get("scenario", "")
        if scenario and doc_scenario != scenario:
            continue
        score, matched_terms = score_doc(query, tokens, doc)
        if score <= 0:
            continue
        results.append(
            {
                "score": score,
                "title": doc["title"],
                "file_name": doc["file_name"],
                "relative_path": doc["relative_path"],
                "directory": doc["directory"],
                "absolute_path": str(Path(vault_root) / doc["relative_path"]),
                "summary": doc.get("summary", ""),
                "matched_terms": matched_terms,
            }
        )
    results.sort(key=lambda item: item["score"], reverse=True)
    return {"matched_docs": results[:limit]}


def build_query(args: argparse.Namespace) -> str:
    """Join all non-empty alert attributes and facts into one search string."""
    parts = [
        args.scenario,
        args.alert_type,
        args.user,
        args.host,
        args.sender,
        args.subject,
        args.attachment,
        args.url,
        args.ip,
        args.summary,
    ]
    parts.extend(args.fact)
    return " ".join(part.strip() for part in parts if part and part.strip())


def bullet(lines: list[str], fallback: str) -> str:
    """Render *lines* as a markdown bullet list, or a single *fallback* bullet."""
    if not lines:
        return f"- {fallback}"
    return "\n".join(f"- {line}" for line in lines)


def top_results(items: list[dict[str, Any]], limit: int = 3) -> list[dict[str, Any]]:
    """Return the first *limit* items (inputs are already sorted by score)."""
    return items[:limit]


def has_fact(args: argparse.Namespace, needle: str) -> bool:
    """Case-insensitively check whether *needle* appears in any alert text."""
    haystacks = [args.summary, args.subject, args.alert_type, *args.fact]
    lowered = needle.lower()
    return any(lowered in (item or "").lower() for item in haystacks)


def summarize_evidence(args: argparse.Namespace) -> list[str]:
    """Collect up to six evidence bullets from the provided alert attributes."""
    evidence: list[str] = []
    if args.subject:
        evidence.append(f"邮件主题/诱饵:{args.subject}")
    if args.attachment:
        evidence.append(f"恶意附件:{args.attachment}")
    if args.url:
        evidence.append(f"可疑链接:{args.url}")
    if args.sender:
        evidence.append(f"发件人:{args.sender}")
    if args.ip:
        evidence.append(f"相关 IP:{args.ip}")
    evidence.extend(args.fact[:4])
    return evidence[:6]


def uri_to_id(uri: str) -> str:
    """Extract the bare document id from a memory URI (last path segment)."""
    return uri.rsplit('/', 1)[-1].replace('.json', '')


def infer_assessment(args: argparse.Namespace, case_results: list[dict[str, Any]]) -> str:
    """Produce a one-paragraph verdict from the scenario and retrieved cases."""
    top_case = case_results[0] if case_results else None
    if args.scenario == "phishing":
        # Strongest signal: lure URL + attachment plus user interaction / DMARC failure.
        if args.url and args.attachment and (has_fact(args, "dmarc failed") or has_fact(args, "clicked")):
            base = "当前告警高度符合凭证收割型钓鱼攻击特征,属于高可信 True Positive,且存在凭证泄露风险。"
        elif args.url or args.attachment:
            base = "当前告警具备明显钓鱼迹象,尤其是附件与落地页组合,倾向于高风险钓鱼事件。"
        else:
            base = "当前告警呈现出邮件钓鱼模式,但仍需补充落地页、附件和用户交互证据进一步确认。"
    elif args.scenario == "o365_suspicious_login":
        if has_fact(args, "impossible travel") and (has_fact(args, "mfa fatigue") or has_fact(args, "inbox rule") or has_fact(args, "oauth")):
            base = "当前告警高度符合 O365 账号接管链路,属于高可信身份威胁事件。"
        else:
            base = "当前告警表现为异常身份登录,需要结合登录轨迹、MFA 和邮箱规则进一步确认是否账号接管。"
    else:
        base = "当前告警具备明显的可疑特征,需要结合历史案例和关联知识继续判断。"
    if top_case:
        return base + f" 最相近的历史案例为 `{uri_to_id(top_case.get('uri', ''))}`,说明当前 case 与既有攻击模式存在明显重合。"
    return base


def _format_memory_line(item: dict[str, Any]) -> str:
    """Render one memory hit as ``id(uri)— snippet`` with a 140-char snippet cap."""
    uri = item.get("uri", "")
    abstract = (item.get("abstract") or "").strip()
    snippet = abstract[:140] + "..." if len(abstract) > 140 else abstract
    return f"`{uri_to_id(uri)}`({uri})— {snippet}"


def format_memory_results(case_results: list[dict[str, Any]], knowledge_results: list[dict[str, Any]]) -> str:
    """Format the top two case hits and top two knowledge hits as bullets."""
    lines = [_format_memory_line(item) for item in top_results(case_results, 2)]
    lines.extend(_format_memory_line(item) for item in top_results(knowledge_results, 2))
    return bullet(lines, "未检索到直接关联的 Memory 条目")


def format_obsidian_results(obsidian_docs: list[dict[str, Any]]) -> str:
    """Format up to three matched vault documents as markdown bullets."""
    lines = []
    for doc in top_results(obsidian_docs, 3):
        reason = doc.get("summary") or ", ".join(doc.get("matched_terms", [])) or "与当前场景相关"
        lines.append(
            f"`{doc['file_name']}` — `obsidian-vault/{doc['relative_path']}` "
            f"(absolute: `{doc['absolute_path']}`)— {reason}"
        )
    return bullet(lines, "未找到直接关联的 Obsidian 文档")


def recommend_actions(args: argparse.Namespace, case_results: list[dict[str, Any]]) -> list[str]:
    """Return up to five scenario-specific response actions."""
    actions: list[str] = []
    if args.scenario == "phishing":
        actions.extend([
            "检查用户是否已点击链接或提交凭据,必要时立即重置账号并撤销会话。",
            "搜索同主题、同发件人、同 URL 或同附件的邮件是否已投递给其他用户。",
            "封锁相关域名、URL 和可疑 IP,并保留附件样本用于沙箱分析。",
            "如邮件面向财务或高价值角色,优先排查是否存在 BEC 或后续横向利用。",
        ])
    elif args.scenario == "o365_suspicious_login":
        actions.extend([
            "复核登录日志、MFA 记录和后续邮箱规则 / OAuth 变更。",
            "若确认账号接管迹象,立即重置凭据并撤销所有活跃会话。",
            "检查同源 IP、同设备指纹和同时间窗口内的其他用户活动。",
            "对邮箱转发、隐藏规则、恶意 OAuth 授权进行专项排查。",
        ])
    else:
        actions.append("基于当前高风险迹象继续扩充调查和处置。")
    if case_results:
        actions.append("对照最相近历史案例,复用已有 IOC 和调查路径。")
    return actions[:5]


def main() -> None:
    """Parse CLI arguments, run both retrievals, and print the triage report."""
    parser = argparse.ArgumentParser(description="Run a structured SOC triage using memory retrieval and Obsidian lookup.")
    parser.add_argument("--scenario", required=True, help="Scenario, e.g. phishing or o365_suspicious_login")
    parser.add_argument("--alert-type", default="", help="Alert type")
    parser.add_argument("--user", default="", help="Target user")
    parser.add_argument("--host", default="", help="Target host")
    parser.add_argument("--sender", default="", help="Sender email")
    parser.add_argument("--subject", default="", help="Email subject or short title")
    parser.add_argument("--attachment", default="", help="Attachment name")
    parser.add_argument("--url", default="", help="Suspicious URL")
    parser.add_argument("--ip", default="", help="Relevant IP")
    parser.add_argument("--summary", default="", help="One-sentence alert summary")
    parser.add_argument("--fact", action="append", default=[], help="Additional known fact; repeatable")
    parser.add_argument("--gateway-url", default=DEFAULT_GATEWAY_URL, help="Memory Gateway URL")
    parser.add_argument("--api-key", default=DEFAULT_GATEWAY_API_KEY, help="Memory Gateway API key")
    parser.add_argument("--vault-root", default=DEFAULT_VAULT_ROOT, help="Obsidian vault root")
    parser.add_argument("--limit", type=int, default=5, help="Search limit")
    args = parser.parse_args()

    query = build_query(args)
    case_results: list[dict[str, Any]] = []
    knowledge_results: list[dict[str, Any]] = []
    obsidian_docs: list[dict[str, Any]] = []
    memory_error = ""
    obsidian_error = ""

    # Both retrieval backends are best-effort: a failure is reported inside
    # the markdown output rather than aborting the triage.
    try:
        case_results = gateway_search(query, CASE_URI, args.limit, args.gateway_url, args.api_key)
        knowledge_results = gateway_search(query, KNOWLEDGE_URI, args.limit, args.gateway_url, args.api_key)
    except urllib.error.URLError as exc:
        memory_error = f"Memory Gateway 不可用:{exc}"
    try:
        obsidian_resp = obsidian_search(query, args.scenario, args.limit, args.vault_root)
        obsidian_docs = obsidian_resp.get("matched_docs", [])
    except Exception as exc:  # noqa: BLE001 - best-effort local lookup boundary
        obsidian_error = f"Obsidian 检索失败:{exc}"

    lines = [
        "## 研判结果",
        infer_assessment(args, case_results),
        "",
        "## 关键证据",
        bullet(summarize_evidence(args), "当前输入只提供了有限证据,需要继续补充调查信息"),
        "",
        "## 关联 Memory Retrieval",
    ]
    if memory_error:
        lines.append(f"- {memory_error}")
    else:
        lines.append(format_memory_results(case_results, knowledge_results))
    lines.extend([
        "",
        "## 关联 Obsidian 文档",
    ])
    if obsidian_error:
        lines.append(f"- {obsidian_error}")
    else:
        lines.append(format_obsidian_results(obsidian_docs))
    lines.extend([
        "",
        "## 建议动作",
        bullet(recommend_actions(args, case_results), "继续补充告警细节后再执行更精确的响应动作"),
    ])
    print("\n".join(lines))


if __name__ == "__main__":
    main()