Refactor app instance to Keycloak SSO

This commit is contained in:
2026-06-15 15:54:39 +08:00
parent fc9fd93c36
commit 461d1300ad
246 changed files with 1350 additions and 52721 deletions

View File

@ -7,7 +7,6 @@ import asyncio
import io
import mimetypes
import os
import secrets
import shutil
import time
import zipfile
@ -17,6 +16,8 @@ from pathlib import Path
from types import SimpleNamespace
from typing import Any
import httpx
from beaver.engine.providers.registry import PROVIDERS, find_by_name
from beaver.foundation.config import default_config_path, load_config
from beaver.foundation.events import ChannelIdentity, InboundMessage
@ -69,6 +70,12 @@ from .files import (
workspace_file_preview,
workspace_file_path,
)
from .keycloak_auth import (
KeycloakAuthConfig,
KeycloakIdentity,
KeycloakTokenVerifier,
extract_bearer_token,
)
from .schemas import (
WebChatAcceptanceRequest,
WebChatAcceptanceResponse,
@ -556,17 +563,22 @@ def create_app(
shutdown_force=shutdown_force,
),
)
app.state.auth_tokens = {}
app.state.handoff_codes = {}
app.state.auth_file = Path(os.getenv("BEAVER_AUTH_FILE") or "")
app.state.keycloak_auth_config = KeycloakAuthConfig.from_env()
app.state.keycloak_token_verifier = KeycloakTokenVerifier(config=app.state.keycloak_auth_config)
max_file_size = 50 * 1024 * 1024
max_user_file_upload_size = _int_env("BEAVER_USER_FILES_MAX_UPLOAD_BYTES", 5 * 1024 * 1024 * 1024)
user_file_upload_part_size = _int_env("BEAVER_USER_FILES_UPLOAD_PART_SIZE", 10 * 1024 * 1024)
def _user_file_resolver(request: Request, authorization: str | None) -> UserFileStorageResolver:
username = _require_web_user(app, authorization)
identity = _require_web_identity(app, authorization)
loaded = get_agent_service(request).create_loop().boot()
auth_context = build_file_auth_context(username=username, config=loaded.config)
auth_context = build_file_auth_context(
username=identity.username,
config=loaded.config,
user_id=identity.user_id,
scopes=identity.realm_roles + identity.client_roles,
auth_source="keycloak",
)
return UserFileStorageResolver(config=loaded.config, workspace=loaded.workspace, auth_context=auth_context)
async def _user_file_service(request: Request, authorization: str | None) -> UserFileService:
@ -970,168 +982,72 @@ def create_app(
_schedule_self_restart()
return JSONResponse({"ok": True, "restarting": True}, status_code=202)
@app.post("/api/auth/login")
async def auth_login(request: Request, payload: dict[str, Any]) -> dict[str, Any]:
username = _clean_text(payload.get("username"))
password = str(payload.get("password") or "")
if not username or not password:
raise HTTPException(status_code=400, detail="Username and password are required")
@app.post("/api/auth/callback")
async def auth_callback(request: Request, payload: dict[str, Any]) -> dict[str, Any]:
code = _clean_text(payload.get("code"))
code_verifier = _clean_text(payload.get("code_verifier"))
redirect_uri = _clean_text(payload.get("redirect_uri"))
nonce = _clean_text(payload.get("nonce")) or None
if not code or not code_verifier or not redirect_uri:
raise HTTPException(status_code=400, detail="code, code_verifier, and redirect_uri are required")
users = _load_auth_users(_auth_file_path())
expected = users.get(username)
if expected is None or not secrets.compare_digest(expected, password):
raise HTTPException(status_code=401, detail="Invalid username or password")
keycloak_config: KeycloakAuthConfig = app.state.keycloak_auth_config
try:
async with httpx.AsyncClient(timeout=15.0, trust_env=False) as client:
response = await client.post(
keycloak_config.token_url,
data={
"grant_type": "authorization_code",
"client_id": keycloak_config.client_id,
"code": code,
"redirect_uri": redirect_uri,
"code_verifier": code_verifier,
},
headers={"Accept": "application/json"},
)
except httpx.HTTPError as exc:
raise HTTPException(status_code=502, detail=f"Keycloak token exchange failed: {exc}") from exc
if response.is_error:
raise HTTPException(status_code=401, detail=f"Keycloak token exchange rejected: {response.text}")
token_payload = response.json()
if not isinstance(token_payload, dict):
raise HTTPException(status_code=502, detail="Invalid Keycloak token response")
access_token = _clean_text(token_payload.get("access_token"))
id_token = _clean_text(token_payload.get("id_token"))
refresh_token = _clean_text(token_payload.get("refresh_token"))
if not access_token:
raise HTTPException(status_code=502, detail="Keycloak token response missing access_token")
token = _issue_web_token(app, username)
handoff_code, handoff_expires_at = _issue_handoff_code(app, username, token)
verifier: KeycloakTokenVerifier = app.state.keycloak_token_verifier
identity = verifier.verify(id_token, expected_nonce=nonce) if id_token else verifier.verify(access_token)
verifier.verify(access_token)
return {
"access_token": token,
"refresh_token": "",
"token_type": "bearer",
"user_id": username,
"username": username,
"access_token": access_token,
"id_token": id_token,
"refresh_token": refresh_token,
"expires_in": token_payload.get("expires_in"),
"token_type": token_payload.get("token_type") or "bearer",
"user_id": identity.user_id,
"username": identity.username,
"email": identity.email,
"role": "owner",
"handoff_code": handoff_code,
"handoff_expires_at": handoff_expires_at,
"backend_connection": _backend_connection_view(request),
"local_backend": _local_backend_view(),
}
@app.post("/api/auth/register")
async def auth_register(request: Request, payload: dict[str, Any]) -> dict[str, Any]:
username = _clean_text(payload.get("username"))
password = str(payload.get("password") or "")
email = _clean_text(payload.get("email")) or ""
if not username or not password:
raise HTTPException(status_code=400, detail="Username and password are required")
auth_file = _auth_file_path()
users = _load_auth_users_if_present(auth_file)
user_exists = username in users
if user_exists and not secrets.compare_digest(users[username], password):
raise HTTPException(
status_code=409,
detail="Username already exists. Use the existing password to finish setup or log in.",
)
agent_service = get_agent_service(request)
loaded = agent_service.create_loop().boot()
config = loaded.config
authz_base_url = _clean_text(payload.get("authz_base_url")) or (config.authz.base_url if config.authz.enabled else "")
backend_name = _clean_text(payload.get("backend_name")) or config.backend_identity.name or username
requested_backend_id = _clean_text(payload.get("backend_id")) or config.backend_identity.backend_id or None
public_base_url = (
_clean_text(payload.get("base_url"))
or config.backend_identity.public_base_url
or os.getenv("BEAVER_FRONTEND_PUBLIC_BASE_URL")
or str(request.base_url).rstrip("/")
)
frontend_base_url = _clean_text(payload.get("frontend_base_url")) or public_base_url
authz_user_registered = False
authz_backend_registered = False
local_backend: dict[str, Any] | None = None
if authz_base_url:
from beaver.integrations.authz import AuthzClient
try:
authz_payload = await AuthzClient(
authz_base_url,
timeout_seconds=config.authz.request_timeout_seconds,
).register_user(
username=username,
password=password,
email=email or None,
backend_name=backend_name,
backend_id=requested_backend_id,
base_url=public_base_url,
frontend_base_url=frontend_base_url,
)
except Exception as exc: # noqa: BLE001 - expose upstream setup failures to portal
raise HTTPException(status_code=502, detail=f"AuthZ registration failed: {exc}") from exc
backend = authz_payload.get("backend") if isinstance(authz_payload, dict) else {}
if isinstance(backend, dict):
backend_id = _clean_text(backend.get("backend_id")) or requested_backend_id
client_id = _clean_text(backend.get("client_id")) or backend_id
client_secret = _clean_text(backend.get("client_secret")) or config.backend_identity.client_secret
if backend_id and client_id and client_secret:
local_backend = _save_backend_identity(
agent_service,
config_path=config.config_path or default_config_path(workspace=loaded.workspace),
backend_id=backend_id,
client_id=client_id,
client_secret=client_secret,
name=_clean_text(backend.get("name")) or backend_name,
public_base_url=public_base_url,
authz_base_url=authz_base_url,
)
authz_backend_registered = True
authz_user_registered = bool(authz_payload)
if not user_exists:
users[username] = password
_save_auth_users(auth_file, users)
token = _issue_web_token(app, username)
handoff_code, handoff_expires_at = _issue_handoff_code(app, username, token)
backend_connection = {
**_backend_connection_view(request),
"public_base_url": public_base_url,
"api_base_url": public_base_url,
"frontend_base_url": frontend_base_url,
"registered": bool(local_backend),
}
if local_backend is not None:
backend_connection.update(
{
"backend_id": local_backend.get("backend_id"),
"client_id": local_backend.get("client_id"),
"name": local_backend.get("name"),
}
)
return {
"access_token": token,
"refresh_token": "",
"token_type": "bearer",
"user_id": username,
"username": username,
"email": email,
"role": "owner",
"handoff_code": handoff_code,
"handoff_expires_at": handoff_expires_at,
"existing_user": user_exists,
"authz": {
"enabled": bool(authz_base_url),
"base_url": authz_base_url or None,
"user_registered": authz_user_registered,
"backend_registered": authz_backend_registered,
},
"backend_connection": backend_connection,
"local_backend": local_backend or _local_backend_view(),
}
@app.post("/api/auth/handoff/consume")
async def auth_handoff_consume(payload: dict[str, Any]) -> dict[str, Any]:
return _consume_handoff_code(app, str(payload.get("code") or ""))
@app.get("/api/auth/me")
async def auth_me(authorization: str | None = Header(default=None)) -> dict[str, Any]:
username = _require_web_user(app, authorization)
identity = _require_web_identity(app, authorization)
return {
"id": username,
"username": username,
"email": os.getenv("BEAVER_BACKEND_IDENTITY__EMAIL", ""),
"id": identity.user_id,
"username": identity.username,
"email": identity.email,
"role": "owner",
"quota_tier": "single-user",
}
@app.post("/api/auth/logout")
async def auth_logout(authorization: str | None = Header(default=None)) -> dict[str, Any]:
if authorization and authorization.lower().startswith("bearer "):
token = authorization[7:].strip()
app.state.auth_tokens.pop(token, None)
async def auth_logout() -> dict[str, Any]:
return {"ok": True}
@app.post("/api/providers/{provider_name}/config", response_model=WebProviderConfigResponse)
@ -3288,82 +3204,6 @@ def _provider_enabled(provider_name: str, provider_cfg: Any) -> bool:
)
def _auth_file_path() -> Path:
    """Return the location of the web-auth users file.

    A non-empty ``BEAVER_AUTH_FILE`` environment variable takes precedence;
    otherwise the file is ``~/.beaver/web_auth_users.json``.
    """
    configured = os.getenv("BEAVER_AUTH_FILE")
    return Path(configured) if configured else Path.home() / ".beaver" / "web_auth_users.json"
def _load_auth_users(path: Path) -> dict[str, str]:
    """Parse the auth file at *path* into a ``{username: password}`` mapping.

    Two layouts are accepted inside a top-level JSON object, and may be mixed:
    a ``"users"`` (or ``"accounts"``) list of ``{"username", "password"}``
    objects, and plain ``"name": "password"`` pairs. Plain pairs are applied
    after list entries, so they win on a duplicate username.

    Raises:
        HTTPException: 500 when the file is missing, is not valid JSON, or
            yields no usable entries.
    """
    if not path.exists():
        raise HTTPException(status_code=500, detail=f"Auth file not found: {path}")
    try:
        document = json.loads(path.read_text(encoding="utf-8"))
    except json.JSONDecodeError as exc:
        raise HTTPException(status_code=500, detail=f"Invalid auth file: {path}") from exc

    accounts: dict[str, str] = {}
    if isinstance(document, dict):
        listed = document.get("users") or document.get("accounts")
        for record in listed if isinstance(listed, list) else ():
            if not isinstance(record, dict):
                continue
            name = _clean_text(record.get("username"))
            secret = record.get("password")
            if name and isinstance(secret, str):
                accounts[name] = secret
        for raw_name, secret in document.items():
            # The list-style keys were handled above; they are not usernames.
            if raw_name in {"users", "accounts"}:
                continue
            name = _clean_text(raw_name)
            if name and isinstance(secret, str):
                accounts[name] = secret

    if not accounts:
        raise HTTPException(status_code=500, detail=f"No valid users found in auth file: {path}")
    return accounts
def _load_auth_users_if_present(path: Path) -> dict[str, str]:
    """Load users from *path*, treating a missing file as "no users yet".

    An existing file is still parsed by ``_load_auth_users``, so its errors
    for malformed or empty files propagate unchanged.
    """
    return _load_auth_users(path) if path.exists() else {}
def _save_auth_users(path: Path, users: dict[str, str]) -> None:
    """Write *users* to *path* as a ``{"users": [...]}`` JSON document.

    Parent directories are created as needed. Entries are ordered by username
    and the file ends with a trailing newline.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    records = []
    for name in sorted(users):
        records.append({"username": name, "password": users[name]})
    serialized = json.dumps({"users": records}, ensure_ascii=False, indent=2)
    path.write_text(serialized + "\n", encoding="utf-8")
def _issue_web_token(app: FastAPI, username: str) -> str:
    """Mint a random bearer token and record it for *username*.

    The token is stored in ``app.state.auth_tokens`` and returned to the caller.
    """
    fresh_token = secrets.token_urlsafe(32)
    issued = app.state.auth_tokens
    issued[fresh_token] = username
    return fresh_token
def _handoff_ttl_seconds() -> int:
    """Return the handoff-code lifetime in seconds.

    Reads ``BEAVER_HANDOFF_CODE_TTL_SECONDS`` (default 90). Values below 15 are
    raised to 15; a non-integer value falls back to 90.
    """
    configured = os.getenv("BEAVER_HANDOFF_CODE_TTL_SECONDS", "90").strip()
    try:
        seconds = int(configured)
    except ValueError:
        return 90
    return max(15, seconds)
def _handoff_replay_window_seconds() -> int:
    """Return how long, in seconds, a consumed handoff code may be replayed.

    Reads ``BEAVER_HANDOFF_REPLAY_WINDOW_SECONDS`` (default 15). Values below 1
    are raised to 1; a non-integer value falls back to 15.
    """
    configured = os.getenv("BEAVER_HANDOFF_REPLAY_WINDOW_SECONDS", "15").strip()
    try:
        seconds = int(configured)
    except ValueError:
        return 15
    return max(1, seconds)
def _int_env(name: str, default: int) -> int:
raw = os.getenv(name, "").strip()
if not raw:
@ -3385,81 +3225,10 @@ def _human_upload_size(size: int) -> str:
return f"{size}B"
def _prune_handoff_codes(app: FastAPI) -> None:
    """Drop handoff codes that can no longer be redeemed.

    A code is removed from ``app.state.handoff_codes`` when its ``expires_at``
    has passed, or when it was consumed longer ago than the replay window.
    """
    now = time.time()
    replay_window = _handoff_replay_window_seconds()

    def _is_stale(entry: dict[str, Any]) -> bool:
        if float(entry.get("expires_at") or 0) <= now:
            return True
        used_at = entry.get("consumed_at")
        return used_at is not None and now - float(used_at) > replay_window

    # Snapshot the items first so removal never mutates a dict being iterated.
    stale_codes = [code for code, entry in list(app.state.handoff_codes.items()) if _is_stale(entry)]
    for code in stale_codes:
        app.state.handoff_codes.pop(code, None)
def _issue_handoff_code(app: FastAPI, username: str, access_token: str, refresh_token: str = "") -> tuple[str, int]:
    """Create a one-time handoff code carrying the given tokens.

    Stale codes are pruned first. The new entry is stored in
    ``app.state.handoff_codes`` with ``consumed_at`` unset.

    Returns:
        ``(code, expires_at)`` where ``expires_at`` is a Unix timestamp in
        whole seconds.
    """
    _prune_handoff_codes(app)
    handoff_code = secrets.token_urlsafe(24)
    deadline = int(time.time()) + _handoff_ttl_seconds()
    record = dict(
        username=username,
        access_token=access_token,
        refresh_token=refresh_token,
        expires_at=deadline,
        consumed_at=None,
    )
    app.state.handoff_codes[handoff_code] = record
    return handoff_code, deadline
def _consume_handoff_code(app: FastAPI, code: str) -> dict[str, Any]:
    """Exchange a handoff code for the tokens stored with it.

    The first redemption stamps ``consumed_at``. Later redemptions of the same
    code are still honoured until the replay window has elapsed, after which
    the code is discarded.

    Raises:
        HTTPException: 400 for a blank code, 401 for an unknown code or an
            entry missing its username/access token, 410 for an expired or
            already-used code.
    """
    if not code.strip():
        raise HTTPException(status_code=400, detail="Handoff code is required")
    _prune_handoff_codes(app)
    registry = app.state.handoff_codes
    entry = registry.get(code)
    if entry is None:
        raise HTTPException(status_code=401, detail="Invalid or expired handoff code")

    now = time.time()
    if float(entry.get("expires_at") or 0) <= now:
        registry.pop(code, None)
        raise HTTPException(status_code=410, detail="Handoff code expired")

    first_used_at = entry.get("consumed_at")
    if first_used_at is None:
        entry["consumed_at"] = now
    elif now - float(first_used_at) > _handoff_replay_window_seconds():
        registry.pop(code, None)
        raise HTTPException(status_code=410, detail="Handoff code already used")

    username = str(entry.get("username") or "").strip()
    access_token = str(entry.get("access_token") or "").strip()
    if not (username and access_token):
        registry.pop(code, None)
        raise HTTPException(status_code=401, detail="Invalid handoff payload")

    return {
        "access_token": access_token,
        "refresh_token": str(entry.get("refresh_token") or ""),
        "token_type": "bearer",
        "user_id": username,
        "username": username,
        "role": "owner",
    }
def _require_web_user(app: FastAPI, authorization: str | None) -> str:
    """Resolve an ``Authorization: Bearer ...`` header to its username.

    The token is looked up in ``app.state.auth_tokens``.

    Raises:
        HTTPException: 401 when the header is missing, is not a bearer
            header, carries an empty token, or the token is unknown.
    """
    if not authorization:
        raise HTTPException(status_code=401, detail="Missing Authorization header")
    scheme = "bearer "
    if not authorization.lower().startswith(scheme):
        raise HTTPException(status_code=401, detail="Invalid Authorization header")
    bearer = authorization[len(scheme):].strip()
    if not bearer:
        raise HTTPException(status_code=401, detail="Invalid token")
    owner = app.state.auth_tokens.get(bearer)
    if owner:
        return owner
    raise HTTPException(status_code=401, detail="Invalid or expired token")
def _require_web_identity(app: FastAPI, authorization: str | None) -> KeycloakIdentity:
    """Authenticate a request from its ``Authorization`` header.

    The bearer token is extracted first, then checked by the verifier stored
    on ``app.state.keycloak_token_verifier``; errors from either step propagate.
    """
    bearer = extract_bearer_token(authorization)
    token_verifier: KeycloakTokenVerifier = app.state.keycloak_token_verifier
    identity = token_verifier.verify(bearer)
    return identity
def _backend_connection_view(request: Request) -> dict[str, Any]:

View File

@ -0,0 +1,152 @@
"""Keycloak OIDC token verification for the Beaver web app."""
from __future__ import annotations
from dataclasses import dataclass
import os
import time
from typing import Any
import jwt
from jwt import PyJWKClient
try:
    from fastapi import HTTPException
except ModuleNotFoundError:  # pragma: no cover
    # FastAPI is not installed: define a minimal stand-in exposing the same
    # ``status_code`` and ``detail`` attributes, so this module can still be
    # imported and raise errors of the same shape.
    class HTTPException(Exception):  # type: ignore[override]
        """Fallback error type carrying an HTTP status code and detail message."""

        def __init__(self, status_code: int, detail: str) -> None:
            super().__init__(detail)
            self.status_code = status_code
            self.detail = detail
# Fallbacks used by KeycloakAuthConfig.from_env() when BEAVER_KEYCLOAK_ISSUER /
# BEAVER_KEYCLOAK_CLIENT_ID are unset or empty.
DEFAULT_KEYCLOAK_ISSUER = "https://keycloak.bwgdi.com/realms/beaver"
# NOTE(review): "agnet" looks like a transposition of "agent" — confirm this
# matches the client ID actually registered in the Keycloak realm.
DEFAULT_KEYCLOAK_CLIENT_ID = "beaver-agnet"
@dataclass(frozen=True, slots=True)
class KeycloakAuthConfig:
    """Keycloak endpoints and client identity used for token verification."""

    # Expected ``iss`` claim, stored without a trailing slash.
    issuer: str
    # Client ID that a token's ``aud`` or ``azp`` claim must match.
    client_id: str
    # Endpoint used to exchange an authorization code for tokens.
    token_url: str
    # Endpoint serving the realm's signing keys (JWKS).
    jwks_url: str

    @classmethod
    def from_env(cls) -> "KeycloakAuthConfig":
        """Build the configuration from ``BEAVER_KEYCLOAK_*`` environment variables.

        Every variable is stripped *before* the fallback is applied, so an
        unset, empty, or whitespace-only value consistently selects the
        default. The token and JWKS URLs default to the standard
        ``/protocol/openid-connect/...`` paths under the issuer.
        """
        issuer = (
            _clean_base_url(os.getenv("BEAVER_KEYCLOAK_ISSUER", ""))
            or _clean_base_url(DEFAULT_KEYCLOAK_ISSUER)
        )
        client_id = (
            os.getenv("BEAVER_KEYCLOAK_CLIENT_ID", "").strip()
            or DEFAULT_KEYCLOAK_CLIENT_ID.strip()
        )
        token_url = (
            os.getenv("BEAVER_KEYCLOAK_TOKEN_URL", "").strip()
            or f"{issuer}/protocol/openid-connect/token"
        )
        jwks_url = (
            os.getenv("BEAVER_KEYCLOAK_JWKS_URL", "").strip()
            or f"{issuer}/protocol/openid-connect/certs"
        )
        return cls(issuer=issuer, client_id=client_id, token_url=token_url, jwks_url=jwks_url)
@dataclass(frozen=True, slots=True)
class KeycloakIdentity:
    """Identity extracted from a validated Keycloak token."""

    # The token's ``sub`` claim.
    user_id: str
    # ``preferred_username``, falling back to ``email`` and then to ``sub``.
    username: str
    # The ``email`` claim, or "" when absent.
    email: str = ""
    # The ``name`` claim, or "" when absent.
    name: str = ""
    # Roles listed under the ``realm_access`` claim.
    realm_roles: tuple[str, ...] = ()
    # Roles listed under ``resource_access`` for the configured client ID.
    client_roles: tuple[str, ...] = ()
def extract_bearer_token(authorization: str | None) -> str:
if not authorization:
raise HTTPException(status_code=401, detail="Missing Authorization header")
prefix = "bearer "
if not authorization.lower().startswith(prefix):
raise HTTPException(status_code=401, detail="Invalid Authorization header")
token = authorization[len(prefix):].strip()
if not token:
raise HTTPException(status_code=401, detail="Invalid token")
return token
class KeycloakTokenVerifier:
    """Verify Keycloak-issued JWTs and map their claims to a ``KeycloakIdentity``."""

    # Tolerated clock skew, in seconds, for an ``iat`` claim lying in the future.
    _IAT_FUTURE_SKEW_SECONDS = 120

    def __init__(self, *, config: KeycloakAuthConfig) -> None:
        """Keep *config* and create a JWKS client for ``config.jwks_url``."""
        self.config = config
        self._jwks_client = PyJWKClient(config.jwks_url)

    def verify(self, token: str, *, expected_nonce: str | None = None) -> KeycloakIdentity:
        """Check the signature and claims of *token* and return its identity.

        Args:
            token: Encoded JWT (access token or ID token).
            expected_nonce: When given, the token's ``nonce`` claim must equal it.

        Raises:
            HTTPException: 401 when the signing key cannot be resolved, the
                signature or registered claims are invalid, or
                ``validate_claims`` rejects the payload.
        """
        try:
            signing_key = self._jwks_client.get_signing_key_from_jwt(token).key
            claims = jwt.decode(
                token,
                signing_key,
                algorithms=["RS256"],
                issuer=self.config.issuer,
                options={
                    "require": ["exp", "iat", "iss"],
                    # Audience is checked in validate_claims, which also accepts
                    # a matching ``azp``.
                    "verify_aud": False,
                    # PyJWT releases that validate ``iat`` do so with zero
                    # leeway here, rejecting a token issued even one second
                    # "in the future" when this host's clock lags Keycloak's.
                    # validate_claims performs the iat check with an explicit
                    # skew allowance, so the library check is turned off.
                    "verify_iat": False,
                },
            )
        except Exception as exc:  # noqa: BLE001 - normalize JWT/JWKS failures for HTTP callers
            raise HTTPException(status_code=401, detail=f"Invalid token: {exc}") from exc
        return self.validate_claims(claims, expected_nonce=expected_nonce)

    def validate_claims(self, claims: dict[str, Any], *, expected_nonce: str | None = None) -> KeycloakIdentity:
        """Validate decoded *claims* and build the resulting identity.

        Checks, in order: issuer, expiry, ``iat`` not further in the future
        than the allowed skew, audience (``aud`` contains the client ID or
        ``azp`` equals it), optional nonce, and a non-empty ``sub``.

        Raises:
            HTTPException: 401 describing the first failed check.
        """
        now = int(time.time())
        issuer = str(claims.get("iss") or "")
        if issuer != self.config.issuer:
            raise HTTPException(status_code=401, detail="Invalid token issuer")
        exp = _int_claim(claims, "exp")
        iat = _int_claim(claims, "iat")
        if exp <= now:
            raise HTTPException(status_code=401, detail="Token expired")
        if iat > now + self._IAT_FUTURE_SKEW_SECONDS:
            raise HTTPException(status_code=401, detail="Token issued in the future")

        client_id = self.config.client_id
        if not _matches_client(claims.get("aud"), client_id) and claims.get("azp") != client_id:
            raise HTTPException(status_code=401, detail="Invalid token audience")
        if expected_nonce is not None and claims.get("nonce") != expected_nonce:
            raise HTTPException(status_code=401, detail="Invalid token nonce")

        user_id = str(claims.get("sub") or "").strip()
        if not user_id:
            raise HTTPException(status_code=401, detail="Token subject is required")
        email = str(claims.get("email") or "").strip()
        username = str(claims.get("preferred_username") or "").strip() or email or user_id

        resource_access = claims.get("resource_access")
        client_access = resource_access.get(client_id) if isinstance(resource_access, dict) else None
        return KeycloakIdentity(
            user_id=user_id,
            username=username,
            email=email,
            name=str(claims.get("name") or "").strip(),
            realm_roles=_roles_from(claims.get("realm_access")),
            client_roles=_roles_from(client_access),
        )
def _clean_base_url(value: str) -> str:
    """Trim surrounding whitespace, then drop any trailing slashes."""
    trimmed = value.strip()
    return trimmed.rstrip("/")
def _int_claim(claims: dict[str, Any], key: str) -> int:
    """Return ``claims[key]`` converted to ``int``.

    Raises:
        HTTPException: 401 when the claim is missing or cannot be converted.
            This includes non-finite floats such as ``inf``, for which
            ``int()`` raises ``OverflowError`` rather than ``ValueError``.
    """
    try:
        return int(claims[key])
    except (KeyError, TypeError, ValueError, OverflowError) as exc:
        raise HTTPException(status_code=401, detail=f"Token {key} claim is required") from exc
def _matches_client(audience: Any, client_id: str) -> bool:
    """Tell whether an ``aud`` claim names *client_id*.

    A string must equal the client ID; a list must contain it (items are
    compared by their string form). Any other type never matches.
    """
    if isinstance(audience, list):
        return any(str(item) == client_id for item in audience)
    return isinstance(audience, str) and audience == client_id
def _roles_from(value: Any) -> tuple[str, ...]:
    """Extract role names from a ``{"roles": [...]}`` mapping.

    Returns an empty tuple unless *value* is a dict whose ``"roles"`` entry
    is a list. Entries are converted with ``str`` and those that are blank
    after stripping are skipped.
    """
    listed = value.get("roles") if isinstance(value, dict) else None
    if not isinstance(listed, list):
        return ()
    return tuple(text for text in map(str, listed) if text.strip())