Refactor app instance to Keycloak SSO
This commit is contained in:
@ -7,7 +7,6 @@ import asyncio
|
||||
import io
|
||||
import mimetypes
|
||||
import os
|
||||
import secrets
|
||||
import shutil
|
||||
import time
|
||||
import zipfile
|
||||
@ -17,6 +16,8 @@ from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from beaver.engine.providers.registry import PROVIDERS, find_by_name
|
||||
from beaver.foundation.config import default_config_path, load_config
|
||||
from beaver.foundation.events import ChannelIdentity, InboundMessage
|
||||
@ -69,6 +70,12 @@ from .files import (
|
||||
workspace_file_preview,
|
||||
workspace_file_path,
|
||||
)
|
||||
from .keycloak_auth import (
|
||||
KeycloakAuthConfig,
|
||||
KeycloakIdentity,
|
||||
KeycloakTokenVerifier,
|
||||
extract_bearer_token,
|
||||
)
|
||||
from .schemas import (
|
||||
WebChatAcceptanceRequest,
|
||||
WebChatAcceptanceResponse,
|
||||
@ -556,17 +563,22 @@ def create_app(
|
||||
shutdown_force=shutdown_force,
|
||||
),
|
||||
)
|
||||
app.state.auth_tokens = {}
|
||||
app.state.handoff_codes = {}
|
||||
app.state.auth_file = Path(os.getenv("BEAVER_AUTH_FILE") or "")
|
||||
app.state.keycloak_auth_config = KeycloakAuthConfig.from_env()
|
||||
app.state.keycloak_token_verifier = KeycloakTokenVerifier(config=app.state.keycloak_auth_config)
|
||||
max_file_size = 50 * 1024 * 1024
|
||||
max_user_file_upload_size = _int_env("BEAVER_USER_FILES_MAX_UPLOAD_BYTES", 5 * 1024 * 1024 * 1024)
|
||||
user_file_upload_part_size = _int_env("BEAVER_USER_FILES_UPLOAD_PART_SIZE", 10 * 1024 * 1024)
|
||||
|
||||
def _user_file_resolver(request: Request, authorization: str | None) -> UserFileStorageResolver:
|
||||
username = _require_web_user(app, authorization)
|
||||
identity = _require_web_identity(app, authorization)
|
||||
loaded = get_agent_service(request).create_loop().boot()
|
||||
auth_context = build_file_auth_context(username=username, config=loaded.config)
|
||||
auth_context = build_file_auth_context(
|
||||
username=identity.username,
|
||||
config=loaded.config,
|
||||
user_id=identity.user_id,
|
||||
scopes=identity.realm_roles + identity.client_roles,
|
||||
auth_source="keycloak",
|
||||
)
|
||||
return UserFileStorageResolver(config=loaded.config, workspace=loaded.workspace, auth_context=auth_context)
|
||||
|
||||
async def _user_file_service(request: Request, authorization: str | None) -> UserFileService:
|
||||
@ -970,168 +982,72 @@ def create_app(
|
||||
_schedule_self_restart()
|
||||
return JSONResponse({"ok": True, "restarting": True}, status_code=202)
|
||||
|
||||
@app.post("/api/auth/login")
|
||||
async def auth_login(request: Request, payload: dict[str, Any]) -> dict[str, Any]:
|
||||
username = _clean_text(payload.get("username"))
|
||||
password = str(payload.get("password") or "")
|
||||
if not username or not password:
|
||||
raise HTTPException(status_code=400, detail="Username and password are required")
|
||||
@app.post("/api/auth/callback")
|
||||
async def auth_callback(request: Request, payload: dict[str, Any]) -> dict[str, Any]:
|
||||
code = _clean_text(payload.get("code"))
|
||||
code_verifier = _clean_text(payload.get("code_verifier"))
|
||||
redirect_uri = _clean_text(payload.get("redirect_uri"))
|
||||
nonce = _clean_text(payload.get("nonce")) or None
|
||||
if not code or not code_verifier or not redirect_uri:
|
||||
raise HTTPException(status_code=400, detail="code, code_verifier, and redirect_uri are required")
|
||||
|
||||
users = _load_auth_users(_auth_file_path())
|
||||
expected = users.get(username)
|
||||
if expected is None or not secrets.compare_digest(expected, password):
|
||||
raise HTTPException(status_code=401, detail="Invalid username or password")
|
||||
keycloak_config: KeycloakAuthConfig = app.state.keycloak_auth_config
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=15.0, trust_env=False) as client:
|
||||
response = await client.post(
|
||||
keycloak_config.token_url,
|
||||
data={
|
||||
"grant_type": "authorization_code",
|
||||
"client_id": keycloak_config.client_id,
|
||||
"code": code,
|
||||
"redirect_uri": redirect_uri,
|
||||
"code_verifier": code_verifier,
|
||||
},
|
||||
headers={"Accept": "application/json"},
|
||||
)
|
||||
except httpx.HTTPError as exc:
|
||||
raise HTTPException(status_code=502, detail=f"Keycloak token exchange failed: {exc}") from exc
|
||||
if response.is_error:
|
||||
raise HTTPException(status_code=401, detail=f"Keycloak token exchange rejected: {response.text}")
|
||||
token_payload = response.json()
|
||||
if not isinstance(token_payload, dict):
|
||||
raise HTTPException(status_code=502, detail="Invalid Keycloak token response")
|
||||
access_token = _clean_text(token_payload.get("access_token"))
|
||||
id_token = _clean_text(token_payload.get("id_token"))
|
||||
refresh_token = _clean_text(token_payload.get("refresh_token"))
|
||||
if not access_token:
|
||||
raise HTTPException(status_code=502, detail="Keycloak token response missing access_token")
|
||||
|
||||
token = _issue_web_token(app, username)
|
||||
handoff_code, handoff_expires_at = _issue_handoff_code(app, username, token)
|
||||
verifier: KeycloakTokenVerifier = app.state.keycloak_token_verifier
|
||||
identity = verifier.verify(id_token, expected_nonce=nonce) if id_token else verifier.verify(access_token)
|
||||
verifier.verify(access_token)
|
||||
return {
|
||||
"access_token": token,
|
||||
"refresh_token": "",
|
||||
"token_type": "bearer",
|
||||
"user_id": username,
|
||||
"username": username,
|
||||
"access_token": access_token,
|
||||
"id_token": id_token,
|
||||
"refresh_token": refresh_token,
|
||||
"expires_in": token_payload.get("expires_in"),
|
||||
"token_type": token_payload.get("token_type") or "bearer",
|
||||
"user_id": identity.user_id,
|
||||
"username": identity.username,
|
||||
"email": identity.email,
|
||||
"role": "owner",
|
||||
"handoff_code": handoff_code,
|
||||
"handoff_expires_at": handoff_expires_at,
|
||||
"backend_connection": _backend_connection_view(request),
|
||||
"local_backend": _local_backend_view(),
|
||||
}
|
||||
|
||||
@app.post("/api/auth/register")
|
||||
async def auth_register(request: Request, payload: dict[str, Any]) -> dict[str, Any]:
|
||||
username = _clean_text(payload.get("username"))
|
||||
password = str(payload.get("password") or "")
|
||||
email = _clean_text(payload.get("email")) or ""
|
||||
if not username or not password:
|
||||
raise HTTPException(status_code=400, detail="Username and password are required")
|
||||
|
||||
auth_file = _auth_file_path()
|
||||
users = _load_auth_users_if_present(auth_file)
|
||||
user_exists = username in users
|
||||
if user_exists and not secrets.compare_digest(users[username], password):
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail="Username already exists. Use the existing password to finish setup or log in.",
|
||||
)
|
||||
|
||||
agent_service = get_agent_service(request)
|
||||
loaded = agent_service.create_loop().boot()
|
||||
config = loaded.config
|
||||
authz_base_url = _clean_text(payload.get("authz_base_url")) or (config.authz.base_url if config.authz.enabled else "")
|
||||
backend_name = _clean_text(payload.get("backend_name")) or config.backend_identity.name or username
|
||||
requested_backend_id = _clean_text(payload.get("backend_id")) or config.backend_identity.backend_id or None
|
||||
public_base_url = (
|
||||
_clean_text(payload.get("base_url"))
|
||||
or config.backend_identity.public_base_url
|
||||
or os.getenv("BEAVER_FRONTEND_PUBLIC_BASE_URL")
|
||||
or str(request.base_url).rstrip("/")
|
||||
)
|
||||
frontend_base_url = _clean_text(payload.get("frontend_base_url")) or public_base_url
|
||||
|
||||
authz_user_registered = False
|
||||
authz_backend_registered = False
|
||||
local_backend: dict[str, Any] | None = None
|
||||
|
||||
if authz_base_url:
|
||||
from beaver.integrations.authz import AuthzClient
|
||||
|
||||
try:
|
||||
authz_payload = await AuthzClient(
|
||||
authz_base_url,
|
||||
timeout_seconds=config.authz.request_timeout_seconds,
|
||||
).register_user(
|
||||
username=username,
|
||||
password=password,
|
||||
email=email or None,
|
||||
backend_name=backend_name,
|
||||
backend_id=requested_backend_id,
|
||||
base_url=public_base_url,
|
||||
frontend_base_url=frontend_base_url,
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001 - expose upstream setup failures to portal
|
||||
raise HTTPException(status_code=502, detail=f"AuthZ registration failed: {exc}") from exc
|
||||
|
||||
backend = authz_payload.get("backend") if isinstance(authz_payload, dict) else {}
|
||||
if isinstance(backend, dict):
|
||||
backend_id = _clean_text(backend.get("backend_id")) or requested_backend_id
|
||||
client_id = _clean_text(backend.get("client_id")) or backend_id
|
||||
client_secret = _clean_text(backend.get("client_secret")) or config.backend_identity.client_secret
|
||||
if backend_id and client_id and client_secret:
|
||||
local_backend = _save_backend_identity(
|
||||
agent_service,
|
||||
config_path=config.config_path or default_config_path(workspace=loaded.workspace),
|
||||
backend_id=backend_id,
|
||||
client_id=client_id,
|
||||
client_secret=client_secret,
|
||||
name=_clean_text(backend.get("name")) or backend_name,
|
||||
public_base_url=public_base_url,
|
||||
authz_base_url=authz_base_url,
|
||||
)
|
||||
authz_backend_registered = True
|
||||
authz_user_registered = bool(authz_payload)
|
||||
|
||||
if not user_exists:
|
||||
users[username] = password
|
||||
_save_auth_users(auth_file, users)
|
||||
|
||||
token = _issue_web_token(app, username)
|
||||
handoff_code, handoff_expires_at = _issue_handoff_code(app, username, token)
|
||||
backend_connection = {
|
||||
**_backend_connection_view(request),
|
||||
"public_base_url": public_base_url,
|
||||
"api_base_url": public_base_url,
|
||||
"frontend_base_url": frontend_base_url,
|
||||
"registered": bool(local_backend),
|
||||
}
|
||||
if local_backend is not None:
|
||||
backend_connection.update(
|
||||
{
|
||||
"backend_id": local_backend.get("backend_id"),
|
||||
"client_id": local_backend.get("client_id"),
|
||||
"name": local_backend.get("name"),
|
||||
}
|
||||
)
|
||||
return {
|
||||
"access_token": token,
|
||||
"refresh_token": "",
|
||||
"token_type": "bearer",
|
||||
"user_id": username,
|
||||
"username": username,
|
||||
"email": email,
|
||||
"role": "owner",
|
||||
"handoff_code": handoff_code,
|
||||
"handoff_expires_at": handoff_expires_at,
|
||||
"existing_user": user_exists,
|
||||
"authz": {
|
||||
"enabled": bool(authz_base_url),
|
||||
"base_url": authz_base_url or None,
|
||||
"user_registered": authz_user_registered,
|
||||
"backend_registered": authz_backend_registered,
|
||||
},
|
||||
"backend_connection": backend_connection,
|
||||
"local_backend": local_backend or _local_backend_view(),
|
||||
}
|
||||
|
||||
@app.post("/api/auth/handoff/consume")
|
||||
async def auth_handoff_consume(payload: dict[str, Any]) -> dict[str, Any]:
|
||||
return _consume_handoff_code(app, str(payload.get("code") or ""))
|
||||
|
||||
@app.get("/api/auth/me")
|
||||
async def auth_me(authorization: str | None = Header(default=None)) -> dict[str, Any]:
|
||||
username = _require_web_user(app, authorization)
|
||||
identity = _require_web_identity(app, authorization)
|
||||
return {
|
||||
"id": username,
|
||||
"username": username,
|
||||
"email": os.getenv("BEAVER_BACKEND_IDENTITY__EMAIL", ""),
|
||||
"id": identity.user_id,
|
||||
"username": identity.username,
|
||||
"email": identity.email,
|
||||
"role": "owner",
|
||||
"quota_tier": "single-user",
|
||||
}
|
||||
|
||||
@app.post("/api/auth/logout")
|
||||
async def auth_logout(authorization: str | None = Header(default=None)) -> dict[str, Any]:
|
||||
if authorization and authorization.lower().startswith("bearer "):
|
||||
token = authorization[7:].strip()
|
||||
app.state.auth_tokens.pop(token, None)
|
||||
async def auth_logout() -> dict[str, Any]:
|
||||
return {"ok": True}
|
||||
|
||||
@app.post("/api/providers/{provider_name}/config", response_model=WebProviderConfigResponse)
|
||||
@ -3288,82 +3204,6 @@ def _provider_enabled(provider_name: str, provider_cfg: Any) -> bool:
|
||||
)
|
||||
|
||||
|
||||
def _auth_file_path() -> Path:
|
||||
raw = os.getenv("BEAVER_AUTH_FILE")
|
||||
if raw:
|
||||
return Path(raw)
|
||||
return Path.home() / ".beaver" / "web_auth_users.json"
|
||||
|
||||
|
||||
def _load_auth_users(path: Path) -> dict[str, str]:
|
||||
if not path.exists():
|
||||
raise HTTPException(status_code=500, detail=f"Auth file not found: {path}")
|
||||
try:
|
||||
raw = json.loads(path.read_text(encoding="utf-8"))
|
||||
except json.JSONDecodeError as exc:
|
||||
raise HTTPException(status_code=500, detail=f"Invalid auth file: {path}") from exc
|
||||
|
||||
users: dict[str, str] = {}
|
||||
if isinstance(raw, dict):
|
||||
entries = raw.get("users") or raw.get("accounts")
|
||||
if isinstance(entries, list):
|
||||
for entry in entries:
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
username = _clean_text(entry.get("username"))
|
||||
password = entry.get("password")
|
||||
if username and isinstance(password, str):
|
||||
users[username] = password
|
||||
for key, value in raw.items():
|
||||
if key in {"users", "accounts"}:
|
||||
continue
|
||||
username = _clean_text(key)
|
||||
if username and isinstance(value, str):
|
||||
users[username] = value
|
||||
if not users:
|
||||
raise HTTPException(status_code=500, detail=f"No valid users found in auth file: {path}")
|
||||
return users
|
||||
|
||||
|
||||
def _load_auth_users_if_present(path: Path) -> dict[str, str]:
|
||||
if not path.exists():
|
||||
return {}
|
||||
return _load_auth_users(path)
|
||||
|
||||
|
||||
def _save_auth_users(path: Path, users: dict[str, str]) -> None:
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
payload = {
|
||||
"users": [
|
||||
{"username": username, "password": password}
|
||||
for username, password in sorted(users.items())
|
||||
]
|
||||
}
|
||||
path.write_text(json.dumps(payload, ensure_ascii=False, indent=2) + "\n", encoding="utf-8")
|
||||
|
||||
|
||||
def _issue_web_token(app: FastAPI, username: str) -> str:
|
||||
token = secrets.token_urlsafe(32)
|
||||
app.state.auth_tokens[token] = username
|
||||
return token
|
||||
|
||||
|
||||
def _handoff_ttl_seconds() -> int:
|
||||
raw = os.getenv("BEAVER_HANDOFF_CODE_TTL_SECONDS", "90").strip()
|
||||
try:
|
||||
return max(15, int(raw))
|
||||
except ValueError:
|
||||
return 90
|
||||
|
||||
|
||||
def _handoff_replay_window_seconds() -> int:
|
||||
raw = os.getenv("BEAVER_HANDOFF_REPLAY_WINDOW_SECONDS", "15").strip()
|
||||
try:
|
||||
return max(1, int(raw))
|
||||
except ValueError:
|
||||
return 15
|
||||
|
||||
|
||||
def _int_env(name: str, default: int) -> int:
|
||||
raw = os.getenv(name, "").strip()
|
||||
if not raw:
|
||||
@ -3385,81 +3225,10 @@ def _human_upload_size(size: int) -> str:
|
||||
return f"{size}B"
|
||||
|
||||
|
||||
def _prune_handoff_codes(app: FastAPI) -> None:
|
||||
now = time.time()
|
||||
replay_window = _handoff_replay_window_seconds()
|
||||
expired = []
|
||||
for code, payload in list(app.state.handoff_codes.items()):
|
||||
expires_at = float(payload.get("expires_at") or 0)
|
||||
consumed_at = payload.get("consumed_at")
|
||||
if expires_at <= now:
|
||||
expired.append(code)
|
||||
elif consumed_at is not None and now - float(consumed_at) > replay_window:
|
||||
expired.append(code)
|
||||
for code in expired:
|
||||
app.state.handoff_codes.pop(code, None)
|
||||
|
||||
|
||||
def _issue_handoff_code(app: FastAPI, username: str, access_token: str, refresh_token: str = "") -> tuple[str, int]:
|
||||
_prune_handoff_codes(app)
|
||||
code = secrets.token_urlsafe(24)
|
||||
expires_at = int(time.time()) + _handoff_ttl_seconds()
|
||||
app.state.handoff_codes[code] = {
|
||||
"username": username,
|
||||
"access_token": access_token,
|
||||
"refresh_token": refresh_token,
|
||||
"expires_at": expires_at,
|
||||
"consumed_at": None,
|
||||
}
|
||||
return code, expires_at
|
||||
|
||||
|
||||
def _consume_handoff_code(app: FastAPI, code: str) -> dict[str, Any]:
|
||||
if not code.strip():
|
||||
raise HTTPException(status_code=400, detail="Handoff code is required")
|
||||
_prune_handoff_codes(app)
|
||||
payload = app.state.handoff_codes.get(code)
|
||||
if payload is None:
|
||||
raise HTTPException(status_code=401, detail="Invalid or expired handoff code")
|
||||
now = time.time()
|
||||
expires_at = float(payload.get("expires_at") or 0)
|
||||
if expires_at <= now:
|
||||
app.state.handoff_codes.pop(code, None)
|
||||
raise HTTPException(status_code=410, detail="Handoff code expired")
|
||||
consumed_at = payload.get("consumed_at")
|
||||
if consumed_at is None:
|
||||
payload["consumed_at"] = now
|
||||
elif now - float(consumed_at) > _handoff_replay_window_seconds():
|
||||
app.state.handoff_codes.pop(code, None)
|
||||
raise HTTPException(status_code=410, detail="Handoff code already used")
|
||||
username = str(payload.get("username") or "").strip()
|
||||
access_token = str(payload.get("access_token") or "").strip()
|
||||
if not username or not access_token:
|
||||
app.state.handoff_codes.pop(code, None)
|
||||
raise HTTPException(status_code=401, detail="Invalid handoff payload")
|
||||
return {
|
||||
"access_token": access_token,
|
||||
"refresh_token": str(payload.get("refresh_token") or ""),
|
||||
"token_type": "bearer",
|
||||
"user_id": username,
|
||||
"username": username,
|
||||
"role": "owner",
|
||||
}
|
||||
|
||||
|
||||
def _require_web_user(app: FastAPI, authorization: str | None) -> str:
|
||||
if not authorization:
|
||||
raise HTTPException(status_code=401, detail="Missing Authorization header")
|
||||
prefix = "bearer "
|
||||
if not authorization.lower().startswith(prefix):
|
||||
raise HTTPException(status_code=401, detail="Invalid Authorization header")
|
||||
token = authorization[len(prefix):].strip()
|
||||
if not token:
|
||||
raise HTTPException(status_code=401, detail="Invalid token")
|
||||
username = app.state.auth_tokens.get(token)
|
||||
if not username:
|
||||
raise HTTPException(status_code=401, detail="Invalid or expired token")
|
||||
return username
|
||||
def _require_web_identity(app: FastAPI, authorization: str | None) -> KeycloakIdentity:
|
||||
token = extract_bearer_token(authorization)
|
||||
verifier: KeycloakTokenVerifier = app.state.keycloak_token_verifier
|
||||
return verifier.verify(token)
|
||||
|
||||
|
||||
def _backend_connection_view(request: Request) -> dict[str, Any]:
|
||||
|
||||
Reference in New Issue
Block a user