package rest

import (
	"context"
	"encoding/json"
	"net"
	"net/http"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/gorilla/mux"

	"github.com/ocdp/cluster-service/internal/adapter/input/http/dto"
	"github.com/ocdp/cluster-service/internal/domain/entity"
	"github.com/ocdp/cluster-service/internal/domain/service"
	"github.com/ocdp/cluster-service/internal/pkg/authz"
)

// AuthHandler serves the authentication and user-management HTTP endpoints:
// register, list/update/delete users, login, token refresh, and "me".
type AuthHandler struct {
	authService *service.AuthService
}

const (
	// loginRateLimitWindow is how long a key's failure count is kept,
	// measured from the first failure in the window.
	loginRateLimitWindow = time.Minute
	// loginRateLimitFailures is how many login attempts a key may have
	// counted within one window before further attempts get 429.
	loginRateLimitFailures = 5
)

// defaultLoginRateLimiter is the process-wide limiter used by Login. Its
// state lives in an in-memory map, so limits are per process and are lost
// on restart.
var defaultLoginRateLimiter = newLoginRateLimiter(loginRateLimitWindow, loginRateLimitFailures)

// loginRateLimiter counts failed login attempts per key inside a fixed
// window that starts at the key's first counted failure. All methods
// serialize on mu and tolerate a nil receiver.
type loginRateLimiter struct {
	mu       sync.Mutex
	window   time.Duration
	limit    int
	failures map[string]loginFailureState
	now      func() time.Time
	// nextSweep is the earliest time at which expired entries are purged
	// from failures again, so keys that never succeed do not pile up.
	nextSweep time.Time
}

// loginFailureState holds the failure count for one key and the instant its
// window ends.
type loginFailureState struct {
	count      int
	windowEnds time.Time
}

// newLoginRateLimiter returns a limiter that admits at most limit counted
// failures per key within each window.
func newLoginRateLimiter(window time.Duration, limit int) *loginRateLimiter {
	return &loginRateLimiter{
		window:   window,
		limit:    limit,
		failures: make(map[string]loginFailureState),
		now:      time.Now,
	}
}

// Allow reports whether key may attempt another login. A nil limiter or an
// empty key is never limited.
//
// Allow records nothing. Pairing it with a later RecordFailure leaves a gap
// in which concurrent requests can all pass the check, which is why Login
// uses acquire instead.
func (l *loginRateLimiter) Allow(key string) bool {
	if l == nil || key == "" {
		return true
	}
	l.mu.Lock()
	defer l.mu.Unlock()
	allowed, _ := l.allowLocked(key, l.now())
	return allowed
}

// acquire atomically checks whether key may attempt a login and, if so,
// counts the attempt as a failure before the credentials are verified.
// Callers must call Reset when the attempt succeeds. When the attempt is
// rejected, retryAfter is the time left until the key's window ends.
//
// Counting up front closes the race in which many concurrent requests pass
// the check before any of them records a failure.
func (l *loginRateLimiter) acquire(key string) (allowed bool, retryAfter time.Duration) {
	if l == nil || key == "" {
		return true, 0
	}
	l.mu.Lock()
	defer l.mu.Unlock()
	now := l.now()
	if ok, wait := l.allowLocked(key, now); !ok {
		return false, wait
	}
	l.recordFailureLocked(key, now)
	return true, 0
}

// RecordFailure counts one failed attempt for key. A new window starts if
// the key has no entry or its window has expired.
func (l *loginRateLimiter) RecordFailure(key string) {
	if l == nil || key == "" {
		return
	}
	l.mu.Lock()
	defer l.mu.Unlock()
	l.recordFailureLocked(key, l.now())
}

// Reset forgets every counted failure for key; Login calls it after a
// successful authentication.
func (l *loginRateLimiter) Reset(key string) {
	if l == nil || key == "" {
		return
	}
	l.mu.Lock()
	defer l.mu.Unlock()
	delete(l.failures, key)
}

// allowLocked reports whether key is under its limit at now. When it is
// not, it also returns how long remains in the key's window. l.mu must be
// held.
func (l *loginRateLimiter) allowLocked(key string, now time.Time) (bool, time.Duration) {
	state, ok := l.failures[key]
	if !ok || now.After(state.windowEnds) {
		return true, 0
	}
	if state.count < l.limit {
		return true, 0
	}
	return false, state.windowEnds.Sub(now)
}

// recordFailureLocked increments key's failure count, starting a fresh
// window when none is active. l.mu must be held.
func (l *loginRateLimiter) recordFailureLocked(key string, now time.Time) {
	l.pruneExpiredLocked(now)
	state, ok := l.failures[key]
	if !ok || now.After(state.windowEnds) {
		l.failures[key] = loginFailureState{count: 1, windowEnds: now.Add(l.window)}
		return
	}
	state.count++
	l.failures[key] = state
}

// pruneExpiredLocked deletes entries whose window has ended, at most once
// per window. allowLocked and recordFailureLocked already treat an expired
// entry exactly like a missing one, so deleting it only frees memory.
// l.mu must be held.
func (l *loginRateLimiter) pruneExpiredLocked(now time.Time) {
	if now.Before(l.nextSweep) {
		return
	}
	for key, state := range l.failures {
		if now.After(state.windowEnds) {
			delete(l.failures, key) // deleting during range is safe in Go
		}
	}
	l.nextSweep = now.Add(l.window)
}

// NewAuthHandler returns an AuthHandler backed by authService.
func NewAuthHandler(authService *service.AuthService) *AuthHandler {
	return &AuthHandler{
		authService: authService,
	}
}

// Register creates a back-office user from a JSON dto.RegisterRequest and
// replies 201 with the created user.
//
// Per the API description, public self-registration is disabled and only
// admins may call this endpoint. NOTE(review): this handler does not check
// the caller's role itself; confirm the router/middleware enforces it.
//
// @Summary 管理员创建用户
// @Description 创建一个新的后台用户。公开自注册已禁用,只允许 admin 调用。
// @Tags Auth
// @Accept json
// @Produce json
// @Param request body dto.RegisterRequest true "注册信息"
// @Success 201 {object} dto.UserResponse
// @Failure 400 {object} dto.ErrorResponse
// @Router /auth/register [post]
func (h *AuthHandler) Register(w http.ResponseWriter, r *http.Request) {
	var req dto.RegisterRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		respondError(w, http.StatusBadRequest, "Invalid request body", err.Error())
		return
	}
	req.Normalize()

	user, err := h.authService.Register(r.Context(), req.Username, req.Password, req.Role, req.WorkspaceID, service.UserWorkspaceOptions{
		Namespace:        req.Namespace,
		DefaultClusterID: req.DefaultClusterID,
		QuotaCPU:         req.QuotaCPU,
		QuotaMemory:      req.QuotaMemory,
		QuotaGPU:         req.QuotaGPU,
		QuotaGPUMem:      req.QuotaGPUMem,
	}, req.IsActive, req.MustChangePassword)
	if err != nil {
		respondServiceError(w, err, "Registration failed")
		return
	}

	respondJSON(w, http.StatusCreated, h.convertUserResponse(r.Context(), user))
}

// ListUsers replies 200 with every user returned by the auth service, each
// enriched with its workspace details.
func (h *AuthHandler) ListUsers(w http.ResponseWriter, r *http.Request) {
	users, err := h.authService.ListUsers(r.Context())
	if err != nil {
		respondServiceError(w, err, "Failed to list users")
		return
	}

	responses := make([]*dto.UserResponse, 0, len(users))
	for _, user := range users {
		responses = append(responses, h.convertUserResponse(r.Context(), user))
	}
	respondJSON(w, http.StatusOK, responses)
}

// UpdateUser applies a JSON dto.UpdateUserRequest to the user named by the
// {user_id} route variable and replies 200 with the updated user.
func (h *AuthHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
	userID := mux.Vars(r)["user_id"]

	var req dto.UpdateUserRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		respondError(w, http.StatusBadRequest, "Invalid request body", err.Error())
		return
	}
	req.Normalize()

	user, err := h.authService.UpdateUser(r.Context(), userID, req.Role, req.WorkspaceID, service.UserWorkspaceOptions{
		Namespace:        req.Namespace,
		DefaultClusterID: req.DefaultClusterID,
		QuotaCPU:         req.QuotaCPU,
		QuotaMemory:      req.QuotaMemory,
		QuotaGPU:         req.QuotaGPU,
		QuotaGPUMem:      req.QuotaGPUMem,
	}, req.IsActive, req.MustChangePassword)
	if err != nil {
		respondServiceError(w, err, "Failed to update user")
		return
	}

	respondJSON(w, http.StatusOK, h.convertUserResponse(r.Context(), user))
}

// DeleteUser deletes the user named by the {user_id} route variable and
// replies 204 No Content.
func (h *AuthHandler) DeleteUser(w http.ResponseWriter, r *http.Request) {
	userID := mux.Vars(r)["user_id"]
	if err := h.authService.DeleteUser(r.Context(), userID); err != nil {
		respondServiceError(w, err, "Failed to delete user")
		return
	}
	w.WriteHeader(http.StatusNoContent)
}

// Login authenticates a username/password pair. On success it returns an
// access token, a refresh token, and the user's workspace details and
// permissions.
//
// Attempts are rate limited per (username, client address) key. Each
// attempt is counted before the credentials are checked, and the count is
// cleared on success. At most loginRateLimitFailures attempts per key are
// admitted within loginRateLimitWindow, even when requests arrive
// concurrently.
//
// @Summary 用户登录
// @Description 使用用户名和密码获取访问令牌
// @Tags Auth
// @Accept json
// @Produce json
// @Param request body dto.LoginRequest true "登录信息"
// @Success 200 {object} dto.AuthResponse
// @Failure 401 {object} dto.ErrorResponse
// @Router /auth/login [post]
func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request) {
	var req dto.LoginRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		respondError(w, http.StatusBadRequest, "Invalid request body", err.Error())
		return
	}

	rateLimitKey := loginRateLimitKey(r, req.Username)
	allowed, retryAfter := defaultLoginRateLimiter.acquire(rateLimitKey)
	if !allowed {
		// Retry-After is whole seconds; round up so clients never retry early.
		seconds := int((retryAfter + time.Second - 1) / time.Second)
		if seconds < 1 {
			seconds = 1
		}
		w.Header().Set("Retry-After", strconv.Itoa(seconds))
		respondError(w, http.StatusTooManyRequests, "Too many login attempts", "too many login attempts; retry later")
		return
	}

	accessToken, refreshToken, user, err := h.authService.Login(r.Context(), req.Username, req.Password)
	if err != nil {
		// acquire already counted this attempt. Every failure gets the same
		// generic message, so responses do not distinguish failure causes.
		respondError(w, http.StatusUnauthorized, "Invalid username or password", "invalid username or password")
		return
	}
	defaultLoginRateLimiter.Reset(rateLimitKey)

	respondJSON(w, http.StatusOK, h.authResponse(r.Context(), accessToken, refreshToken, user))
}

// loginRateLimitKey builds the limiter key "<username>|<client>". The
// username is trimmed and lower-cased. The client is taken from the first
// entry of X-Forwarded-For, else X-Real-IP, else the host part of
// r.RemoteAddr.
//
// NOTE(review): X-Forwarded-For and X-Real-IP are client-controlled unless a
// trusted proxy overwrites them. A caller able to set them directly can
// choose a fresh key per request and evade the limit. Confirm the edge
// proxy strips or rewrites these headers.
func loginRateLimitKey(r *http.Request, username string) string {
	client := strings.TrimSpace(r.Header.Get("X-Forwarded-For"))
	if idx := strings.Index(client, ","); idx >= 0 {
		client = strings.TrimSpace(client[:idx])
	}
	if client == "" {
		client = strings.TrimSpace(r.Header.Get("X-Real-IP"))
	}
	if client == "" {
		client = r.RemoteAddr
		if host, _, err := net.SplitHostPort(client); err == nil {
			client = host
		}
	}
	return strings.ToLower(strings.TrimSpace(username)) + "|" + client
}

// convertUserResponse maps a user and its workspace onto dto.UserResponse.
// Timestamps are formatted as RFC 3339.
func (h *AuthHandler) convertUserResponse(ctx context.Context, user *entity.User) *dto.UserResponse {
	// Best effort: a failed workspace lookup leaves the workspace fields
	// empty rather than failing the request (the workspace* helpers accept
	// a nil workspace).
	workspace, _ := h.authService.GetWorkspaceByID(ctx, user.WorkspaceID)
	return &dto.UserResponse{
		ID:                 user.ID,
		Username:           user.Username,
		Email:              user.Email,
		Role:               user.Role,
		WorkspaceID:        user.WorkspaceID,
		WorkspaceName:      workspaceName(workspace),
		Namespace:          workspaceNamespace(workspace),
		DefaultClusterID:   workspaceDefaultClusterID(workspace),
		QuotaCPU:           workspaceQuotaCPU(workspace),
		QuotaMemory:        workspaceQuotaMemory(workspace),
		QuotaGPU:           workspaceQuotaGPU(workspace),
		QuotaGPUMem:        workspaceQuotaGPUMem(workspace),
		IsActive:           user.IsActive,
		MustChangePassword: user.MustChangePassword,
		CreatedAt:          user.CreatedAt.Format(time.RFC3339),
		UpdatedAt:          user.UpdatedAt.Format(time.RFC3339),
	}
}

// authResponse builds the token response shared by Login and RefreshToken.
// The response carries the tokens, the user's identity, the user's
// workspace details, and the role's permissions.
func (h *AuthHandler) authResponse(ctx context.Context, accessToken, refreshToken string, user *entity.User) *dto.AuthResponse {
	// Best effort: a failed workspace lookup must not fail authentication;
	// the workspace* helpers accept a nil workspace.
	workspace, _ := h.authService.GetWorkspaceByID(ctx, user.WorkspaceID)
	return &dto.AuthResponse{
		AccessToken:      accessToken,
		RefreshToken:     refreshToken,
		UserID:           user.ID,
		Username:         user.Username,
		Role:             user.Role,
		WorkspaceID:      user.WorkspaceID,
		WorkspaceName:    workspaceName(workspace),
		Namespace:        workspaceNamespace(workspace),
		DefaultClusterID: workspaceDefaultClusterID(workspace),
		QuotaCPU:         workspaceQuotaCPU(workspace),
		QuotaMemory:      workspaceQuotaMemory(workspace),
		QuotaGPU:         workspaceQuotaGPU(workspace),
		QuotaGPUMem:      workspaceQuotaGPUMem(workspace),
		Permissions:      authz.PermissionsForRole(user.Role),
		// NOTE(review): hard-coded, while Me reports principal.PermissionVersion;
		// confirm the two stay in sync.
		PermissionVersion: 1,
	}
}

// RefreshToken exchanges a refresh token for a new access token. The
// response echoes the caller's refresh token back unchanged.
//
// @Summary 刷新访问令牌
// @Description 使用刷新令牌获取新的访问令牌
// @Tags Auth
// @Accept json
// @Produce json
// @Param request body dto.RefreshTokenRequest true "刷新令牌"
// @Success 200 {object} dto.AuthResponse
// @Failure 401 {object} dto.ErrorResponse
// @Router /auth/refresh [post]
func (h *AuthHandler) RefreshToken(w http.ResponseWriter, r *http.Request) {
	var req dto.RefreshTokenRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		respondError(w, http.StatusBadRequest, "Invalid request body", err.Error())
		return
	}

	newAccessToken, user, err := h.authService.RefreshToken(r.Context(), req.RefreshToken)
	if err != nil {
		respondError(w, http.StatusUnauthorized, "Token refresh failed", err.Error())
		return
	}

	respondJSON(w, http.StatusOK, h.authResponse(r.Context(), newAccessToken, req.RefreshToken, user))
}

// Me returns the identity, workspace details, and permissions carried by
// the caller's access token. It replies 401 in two cases:
//   - the Authorization header does not hold a "Bearer <token>" credential;
//   - the token fails verification.
func (h *AuthHandler) Me(w http.ResponseWriter, r *http.Request) {
	const bearerPrefix = "Bearer "

	// The auth scheme is case-insensitive (RFC 9110 §11.1), so "bearer" is
	// accepted too. Trimming first means surrounding whitespace can never
	// make a scheme-less header look like a token.
	header := strings.TrimSpace(r.Header.Get("Authorization"))
	var token string
	if len(header) > len(bearerPrefix) && strings.EqualFold(header[:len(bearerPrefix)], bearerPrefix) {
		token = strings.TrimSpace(header[len(bearerPrefix):])
	}
	if token == "" {
		respondError(w, http.StatusUnauthorized, "Unauthorized", "missing bearer token")
		return
	}

	principal, err := h.authService.VerifyAccessToken(r.Context(), token)
	if err != nil {
		respondError(w, http.StatusUnauthorized, "Unauthorized", err.Error())
		return
	}

	respondJSON(w, http.StatusOK, &dto.AuthResponse{
		UserID:            principal.UserID,
		Username:          principal.Username,
		Role:              principal.Role,
		WorkspaceID:       principal.WorkspaceID,
		WorkspaceName:     principal.WorkspaceName,
		Namespace:         principal.Namespace,
		DefaultClusterID:  principal.DefaultClusterID,
		QuotaCPU:          principal.QuotaCPU,
		QuotaMemory:       principal.QuotaMemory,
		QuotaGPU:          principal.QuotaGPU,
		QuotaGPUMem:       principal.QuotaGPUMem,
		Permissions:       principal.Permissions,
		PermissionVersion: principal.PermissionVersion,
	})
}
func workspaceName(workspace *entity.Workspace) string { if workspace == nil { return "" } return workspace.Name } func workspaceNamespace(workspace *entity.Workspace) string { if workspace == nil { return "" } return workspace.K8sNamespace } func workspaceDefaultClusterID(workspace *entity.Workspace) string { if workspace == nil { return "" } return workspace.DefaultClusterID } func workspaceQuotaCPU(workspace *entity.Workspace) string { if workspace == nil { return "" } return workspace.QuotaCPU } func workspaceQuotaMemory(workspace *entity.Workspace) string { if workspace == nil { return "" } return workspace.QuotaMemory } func workspaceQuotaGPU(workspace *entity.Workspace) string { if workspace == nil { return "" } return workspace.QuotaGPU } func workspaceQuotaGPUMem(workspace *entity.Workspace) string { if workspace == nil { return "" } return workspace.QuotaGPUMem }