"""Keycloak OIDC token verification for the Beaver web app.""" from __future__ import annotations from dataclasses import dataclass import os import time from typing import Any import jwt from jwt import PyJWKClient try: from fastapi import HTTPException except ModuleNotFoundError: # pragma: no cover class HTTPException(Exception): # type: ignore[override] def __init__(self, status_code: int, detail: str) -> None: super().__init__(detail) self.status_code = status_code self.detail = detail DEFAULT_KEYCLOAK_ISSUER = "https://keycloak.bwgdi.com/realms/beaver" DEFAULT_KEYCLOAK_CLIENT_ID = "beaver-agnet" @dataclass(frozen=True, slots=True) class KeycloakAuthConfig: issuer: str client_id: str token_url: str jwks_url: str @classmethod def from_env(cls) -> "KeycloakAuthConfig": issuer = _clean_base_url(os.getenv("BEAVER_KEYCLOAK_ISSUER") or DEFAULT_KEYCLOAK_ISSUER) client_id = (os.getenv("BEAVER_KEYCLOAK_CLIENT_ID") or DEFAULT_KEYCLOAK_CLIENT_ID).strip() token_url = ( os.getenv("BEAVER_KEYCLOAK_TOKEN_URL", "").strip() or f"{issuer}/protocol/openid-connect/token" ) jwks_url = ( os.getenv("BEAVER_KEYCLOAK_JWKS_URL", "").strip() or f"{issuer}/protocol/openid-connect/certs" ) return cls(issuer=issuer, client_id=client_id, token_url=token_url, jwks_url=jwks_url) @dataclass(frozen=True, slots=True) class KeycloakIdentity: user_id: str username: str email: str = "" name: str = "" realm_roles: tuple[str, ...] = () client_roles: tuple[str, ...] = () def extract_bearer_token(authorization: str | None) -> str: if not authorization: raise HTTPException(status_code=401, detail="Missing Authorization header") prefix = "bearer " if not authorization.lower().startswith(prefix): raise HTTPException(status_code=401, detail="Invalid Authorization header") token = authorization[len(prefix):].strip() if not token: raise HTTPException(status_code=401, detail="Invalid token") return token class KeycloakTokenVerifier: def __init__(self, *, config: KeycloakAuthConfig) -> None: self.config = config self._jwks_client = PyJWKClient(config.jwks_url) def verify(self, token: str, *, expected_nonce: str | None = None) -> KeycloakIdentity: try: signing_key = self._jwks_client.get_signing_key_from_jwt(token).key claims = jwt.decode( token, signing_key, algorithms=["RS256"], issuer=self.config.issuer, options={ "require": ["exp", "iat", "iss"], "verify_aud": False, }, ) except Exception as exc: # noqa: BLE001 - normalize JWT/JWKS failures for HTTP callers raise HTTPException(status_code=401, detail=f"Invalid token: {exc}") from exc return self.validate_claims(claims, expected_nonce=expected_nonce) def validate_claims(self, claims: dict[str, Any], *, expected_nonce: str | None = None) -> KeycloakIdentity: now = int(time.time()) issuer = str(claims.get("iss") or "") if issuer != self.config.issuer: raise HTTPException(status_code=401, detail="Invalid token issuer") exp = _int_claim(claims, "exp") iat = _int_claim(claims, "iat") if exp <= now: raise HTTPException(status_code=401, detail="Token expired") if iat > now + 120: raise HTTPException(status_code=401, detail="Token issued in the future") if not _matches_client(claims.get("aud"), self.config.client_id) and claims.get("azp") != self.config.client_id: raise HTTPException(status_code=401, detail="Invalid token audience") if expected_nonce is not None and claims.get("nonce") != expected_nonce: raise HTTPException(status_code=401, detail="Invalid token nonce") user_id = str(claims.get("sub") or "").strip() if not user_id: raise HTTPException(status_code=401, detail="Token subject is required") username = ( str(claims.get("preferred_username") or "").strip() or str(claims.get("email") or "").strip() or user_id ) return KeycloakIdentity( user_id=user_id, username=username, email=str(claims.get("email") or "").strip(), name=str(claims.get("name") or "").strip(), realm_roles=_roles_from(claims.get("realm_access")), client_roles=_roles_from((claims.get("resource_access") or {}).get(self.config.client_id) if isinstance(claims.get("resource_access"), dict) else None), ) def _clean_base_url(value: str) -> str: return value.strip().rstrip("/") def _int_claim(claims: dict[str, Any], key: str) -> int: try: return int(claims[key]) except (KeyError, TypeError, ValueError) as exc: raise HTTPException(status_code=401, detail=f"Token {key} claim is required") from exc def _matches_client(audience: Any, client_id: str) -> bool: if isinstance(audience, str): return audience == client_id if isinstance(audience, list): return client_id in {str(item) for item in audience} return False def _roles_from(value: Any) -> tuple[str, ...]: if not isinstance(value, dict): return () roles = value.get("roles") if not isinstance(roles, list): return () return tuple(str(role) for role in roles if str(role).strip())