feat(tasks): add skill-templated task graph execution
This commit is contained in:
@ -5,10 +5,11 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
from dataclasses import dataclass, field
|
||||
from html import unescape
|
||||
from html.parser import HTMLParser
|
||||
import json
|
||||
import re
|
||||
from typing import Any
|
||||
from urllib.parse import quote_plus, urlparse
|
||||
from urllib.parse import quote_plus, urljoin, urlparse
|
||||
|
||||
import httpx
|
||||
|
||||
@ -24,6 +25,10 @@ def _strip_html(value: str) -> str:
|
||||
return re.sub(r"\s+", " ", text).strip()
|
||||
|
||||
|
||||
def _compact_text(value: str) -> str:
|
||||
return re.sub(r"\s+", " ", unescape(value)).strip()
|
||||
|
||||
|
||||
def _safe_url(url: str) -> str:
|
||||
parsed = urlparse(url)
|
||||
if parsed.scheme not in {"http", "https"} or not parsed.netloc:
|
||||
@ -31,6 +36,77 @@ def _safe_url(url: str) -> str:
|
||||
return url
|
||||
|
||||
|
||||
class _HtmlMetadataParser(HTMLParser):
|
||||
def __init__(self, base_url: str) -> None:
|
||||
super().__init__(convert_charrefs=True)
|
||||
self.base_url = base_url
|
||||
self.title = ""
|
||||
self.links: list[dict[str, str]] = []
|
||||
self._in_title = False
|
||||
self._current_href: str | None = None
|
||||
self._current_text: list[str] = []
|
||||
self._skip_depth = 0
|
||||
self._seen_urls: set[str] = set()
|
||||
|
||||
def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None:
|
||||
lowered = tag.lower()
|
||||
if lowered in {"script", "style"}:
|
||||
self._skip_depth += 1
|
||||
return
|
||||
if self._skip_depth:
|
||||
return
|
||||
if lowered == "title":
|
||||
self._in_title = True
|
||||
return
|
||||
if lowered == "a":
|
||||
href = dict(attrs).get("href")
|
||||
if href:
|
||||
self._current_href = urljoin(self.base_url, href)
|
||||
self._current_text = []
|
||||
|
||||
def handle_endtag(self, tag: str) -> None:
|
||||
lowered = tag.lower()
|
||||
if lowered in {"script", "style"} and self._skip_depth:
|
||||
self._skip_depth -= 1
|
||||
return
|
||||
if self._skip_depth:
|
||||
return
|
||||
if lowered == "title":
|
||||
self._in_title = False
|
||||
self.title = _compact_text(self.title)
|
||||
return
|
||||
if lowered == "a" and self._current_href:
|
||||
parsed = urlparse(self._current_href)
|
||||
if parsed.scheme in {"http", "https"} and self._current_href not in self._seen_urls:
|
||||
text = _compact_text(" ".join(self._current_text))
|
||||
self.links.append({"text": text, "url": self._current_href})
|
||||
self._seen_urls.add(self._current_href)
|
||||
self._current_href = None
|
||||
self._current_text = []
|
||||
|
||||
def handle_data(self, data: str) -> None:
|
||||
if self._skip_depth:
|
||||
return
|
||||
if self._in_title:
|
||||
self.title += data
|
||||
if self._current_href:
|
||||
self._current_text.append(data)
|
||||
|
||||
|
||||
def _extract_html_metadata(html: str, base_url: str, *, max_links: int = 80) -> dict[str, Any]:
|
||||
parser = _HtmlMetadataParser(base_url)
|
||||
parser.feed(html)
|
||||
links = parser.links[:max_links]
|
||||
pdf_links = [
|
||||
link for link in links if urlparse(link["url"]).path.lower().endswith(".pdf")
|
||||
][:30]
|
||||
return {
|
||||
"title": parser.title,
|
||||
"links": links,
|
||||
"pdf_links": pdf_links,
|
||||
}
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class WebFetchTool:
|
||||
name: str = "web_fetch"
|
||||
@ -61,13 +137,20 @@ class WebFetchTool:
|
||||
response.raise_for_status()
|
||||
content_type = response.headers.get("content-type", "")
|
||||
raw = response.text
|
||||
text = _strip_html(raw) if "html" in content_type.lower() else raw
|
||||
is_html = "html" in content_type.lower()
|
||||
text = _strip_html(raw) if is_html else raw
|
||||
metadata = _extract_html_metadata(raw, str(response.url)) if is_html else {
|
||||
"title": "",
|
||||
"links": [],
|
||||
"pdf_links": [],
|
||||
}
|
||||
truncated = len(text) > limit
|
||||
return _json_result(
|
||||
True,
|
||||
url=str(response.url),
|
||||
status_code=response.status_code,
|
||||
content_type=content_type,
|
||||
**metadata,
|
||||
content=text[:limit],
|
||||
truncated=truncated,
|
||||
)
|
||||
@ -97,6 +180,15 @@ class WebSearchTool:
|
||||
if not str(query).strip():
|
||||
raise ValueError("query is required")
|
||||
bounded = max(1, min(int(limit or 5), 10))
|
||||
errors: list[str] = []
|
||||
try:
|
||||
ddgs_results = await asyncio.to_thread(_search_ddgs, query, bounded)
|
||||
except Exception as exc:
|
||||
ddgs_results = []
|
||||
errors.append(str(exc))
|
||||
if ddgs_results:
|
||||
return _json_result(True, **_search_result_payload(query, "ddgs", ddgs_results))
|
||||
|
||||
headers = {"User-Agent": "Mozilla/5.0 Beaver/1.0"}
|
||||
timeout = httpx.Timeout(connect=5, read=8, write=5, pool=5)
|
||||
async with httpx.AsyncClient(timeout=timeout, follow_redirects=True, trust_env=True) as client:
|
||||
@ -118,7 +210,6 @@ class WebSearchTool:
|
||||
)
|
||||
),
|
||||
]
|
||||
errors: list[str] = []
|
||||
try:
|
||||
for completed in asyncio.as_completed(tasks):
|
||||
try:
|
||||
@ -127,7 +218,7 @@ class WebSearchTool:
|
||||
errors.append(str(exc))
|
||||
continue
|
||||
if results:
|
||||
return _json_result(True, query=query, engine=engine, results=results)
|
||||
return _json_result(True, **_search_result_payload(query, engine, results))
|
||||
detail = "; ".join(error for error in errors if error) or "no search results"
|
||||
return _json_result(False, query=query, error=detail)
|
||||
finally:
|
||||
@ -182,6 +273,62 @@ def _parse_bing_results(html: str, limit: int) -> list[dict[str, str]]:
|
||||
return results
|
||||
|
||||
|
||||
def _search_ddgs(query: str, limit: int) -> list[dict[str, str]]:
|
||||
from ddgs import DDGS # type: ignore[import-not-found]
|
||||
|
||||
rows = DDGS().text(query, max_results=limit)
|
||||
results: list[dict[str, str]] = []
|
||||
for row in rows or []:
|
||||
title = _compact_text(str(row.get("title") or ""))
|
||||
result_url = str(row.get("href") or row.get("url") or "").strip()
|
||||
snippet = _compact_text(str(row.get("body") or row.get("snippet") or ""))
|
||||
if title and result_url:
|
||||
results.append({"title": title, "url": result_url, "snippet": snippet})
|
||||
if len(results) >= limit:
|
||||
break
|
||||
return results
|
||||
|
||||
|
||||
def _search_result_payload(query: str, engine: str, results: list[dict[str, str]]) -> dict[str, Any]:
|
||||
quality, reason = _assess_search_quality(query, results)
|
||||
payload: dict[str, Any] = {
|
||||
"query": query,
|
||||
"engine": engine,
|
||||
"quality": quality,
|
||||
"results": results,
|
||||
}
|
||||
if reason:
|
||||
payload["low_relevance_reason"] = reason
|
||||
return payload
|
||||
|
||||
|
||||
def _search_terms(value: str) -> set[str]:
|
||||
return {
|
||||
term
|
||||
for term in re.findall(r"[a-z0-9]+", value.lower())
|
||||
if len(term) > 2
|
||||
}
|
||||
|
||||
|
||||
def _assess_search_quality(query: str, results: list[dict[str, str]]) -> tuple[str, str | None]:
|
||||
terms = _search_terms(query)
|
||||
if not terms:
|
||||
return "high", None
|
||||
required_overlap = min(2, len(terms))
|
||||
for result in results:
|
||||
haystack = " ".join(
|
||||
[
|
||||
result.get("title", ""),
|
||||
result.get("snippet", ""),
|
||||
urlparse(result.get("url", "")).netloc,
|
||||
urlparse(result.get("url", "")).path,
|
||||
]
|
||||
)
|
||||
if len(terms & _search_terms(haystack)) >= required_overlap:
|
||||
return "high", None
|
||||
return "low", "results do not overlap enough with query terms"
|
||||
|
||||
|
||||
def _parse_duckduckgo_results(html: str, limit: int) -> list[dict[str, str]]:
|
||||
results: list[dict[str, str]] = []
|
||||
pattern = re.compile(
|
||||
|
||||
@ -37,6 +37,14 @@ class ToolExecutor:
|
||||
) -> ToolResult:
|
||||
"""按工具名执行一次调用。"""
|
||||
|
||||
allowed = context.metadata.get("allowed_tool_names") if context is not None else None
|
||||
if isinstance(allowed, list) and tool_name not in allowed:
|
||||
return ToolResult(
|
||||
success=False,
|
||||
content=f"Tool {tool_name} is not allowed for this node.",
|
||||
tool_name=tool_name,
|
||||
error="tool_not_allowed",
|
||||
)
|
||||
tool = self.registry.get(tool_name)
|
||||
if tool is None:
|
||||
return ToolResult(
|
||||
|
||||
Reference in New Issue
Block a user