fix: scale replicas in response, K8s metrics client, quota precheck, auth tests

- Add GetMetrics method to MetricsClient interface and implement cluster metrics API
- Add QuotaPrecheck service for validating resource quotas before deployment
- Add auth DTO with role/permission models and auth handler tests
- Add instance diagnostics: mounted NFS volumes, labels, annotations in pod diagnostics
- Update workspace handler with GetWorkspace endpoint and shared-user list
- Fix monitoring handler to use correct service method name
- Add tail_lines fallback in instance handler for snake_case query params
- Update nginx config for SSE log streaming support (no buffering)
- Add comprehensive test coverage: auth_service_test, auth_handler_test,
  auth_dto_test, metrics_client_test, quota_precheck_test
- Update error messages for quota validation and instance operations
- ModifyModal: fix YAML lineWidth:0, modified keys summary, delta-only submit
- InstanceCard: correctly disable scale-minus when replicas <= 0
- SidebarLayout: add hover transition for sidebar items
- Update todo.md and lessons.md with latest fixes
This commit is contained in:
Ivan087
2026-05-20 16:56:29 +08:00
parent 8f90cf0f0d
commit 33ddaf97db
59 changed files with 4805 additions and 457 deletions

View File

@ -79,6 +79,12 @@ func main() {
passwordHasher,
tokenGenerator,
)
authService.SetUserLifecycleCleanup(
repos.InstanceRepo,
repos.ClusterRepo,
repos.BindingRepo,
repos.TenantKubeClient,
)
clusterService := service.NewClusterService(
repos.ClusterRepo,
@ -106,10 +112,13 @@ func main() {
instanceService.SetDiagnosticsClient(repos.DiagnosticsClient)
instanceService.SetTenantProvisioning(repos.WorkspaceRepo, repos.TenantKubeClient)
instanceService.SetScaleClient(k8s.NewScaleClient())
instanceService.SetUserRepository(repos.UserRepo)
monitoringService := service.NewMonitoringService(
repos.ClusterRepo,
repos.MetricsClient,
repos.InstanceRepo,
repos.UserRepo,
)
workspaceService := service.NewWorkspaceService(
@ -243,8 +252,8 @@ func setupRouter(
api := router.PathPrefix("/api/v1").Subrouter()
// ===== 认证路由 =====
api.HandleFunc("/auth/login", authHandler.Login)
api.HandleFunc("/auth/refresh", authHandler.RefreshToken)
api.HandleFunc("/auth/login", authHandler.Login).Methods(http.MethodPost)
api.HandleFunc("/auth/refresh", authHandler.RefreshToken).Methods(http.MethodPost)
protected := api.PathPrefix("").Subrouter()
protected.Use(authMiddleware(authService))
@ -262,6 +271,8 @@ func setupRouter(
protected.HandleFunc("/clusters/{cluster_id}", clusterHandler.UpdateCluster).Methods(http.MethodPut)
protected.HandleFunc("/clusters/{cluster_id}", clusterHandler.DeleteCluster).Methods(http.MethodDelete)
protected.HandleFunc("/clusters/{cluster_id}/health", clusterHandler.GetClusterHealth).Methods(http.MethodGet)
protected.HandleFunc("/clusters/{cluster_id}/stats", monitoringHandler.GetClusterStats).Methods(http.MethodGet)
protected.HandleFunc("/clusters/{cluster_id}/kubeconfig", workspaceHandler.IssueClusterKubeconfig).Methods(http.MethodGet)
// ===== Registry 路由 =====
protected.HandleFunc("/registries", registryHandler.CreateRegistry).Methods(http.MethodPost)
@ -273,7 +284,9 @@ func setupRouter(
// ===== Artifact 路由 =====
protected.HandleFunc("/registries/{registry_id}/repositories", artifactHandler.ListRepositories).Methods(http.MethodGet)
protected.HandleFunc("/repositories/{repository_name:.+}/tags", artifactHandler.ListRepositoryTags).Methods(http.MethodGet)
protected.HandleFunc("/registries/{registry_id}/repositories/{repository_name:.+}/artifacts", artifactHandler.ListArtifacts).Methods(http.MethodGet)
protected.HandleFunc("/registries/{registry_id}/repositories/{repository_name:.+}/tags", artifactHandler.ListRepositoryTags).Methods(http.MethodGet)
protected.HandleFunc("/registries/{registry_id}/repositories/{repository_name:.+}/artifacts/{reference}", artifactHandler.GetArtifact).Methods(http.MethodGet)
protected.HandleFunc("/registries/{registry_id}/repositories/{repository_name:.+}/artifacts/{reference}/values-schema", artifactHandler.GetArtifactValuesSchema).Methods(http.MethodGet)
protected.HandleFunc("/registries/{registry_id}/repositories/{repository_name:.+}/artifacts/{reference}/values-yaml", artifactHandler.GetArtifactValuesYAML).Methods(http.MethodGet)
@ -293,6 +306,7 @@ func setupRouter(
// ===== Monitoring 路由 =====
protected.HandleFunc("/monitoring/clusters", monitoringHandler.ListClusterMonitoring).Methods(http.MethodGet)
protected.HandleFunc("/monitoring/clusters/{cluster_id}", monitoringHandler.GetClusterMonitoring).Methods(http.MethodGet)
protected.HandleFunc("/monitoring/clusters/{cluster_id}/metrics", monitoringHandler.GetClusterMonitoring).Methods(http.MethodGet)
protected.HandleFunc("/monitoring/clusters/{cluster_id}/nodes", monitoringHandler.GetNodeMetrics).Methods(http.MethodGet)
protected.HandleFunc("/monitoring/summary", monitoringHandler.GetMonitoringSummary).Methods(http.MethodGet)
@ -358,15 +372,16 @@ func loggingMiddleware(next http.Handler) http.Handler {
// corsMiddleware CORS 中间件
func corsMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// 设置 CORS 头
origin := r.Header.Get("Origin")
if origin == "" {
origin = "*"
if origin != "" {
w.Header().Add("Vary", "Origin")
if corsOriginAllowed(origin) {
w.Header().Set("Access-Control-Allow-Origin", origin)
w.Header().Set("Access-Control-Allow-Credentials", "true")
}
}
w.Header().Set("Access-Control-Allow-Origin", origin)
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Requested-With")
w.Header().Set("Access-Control-Allow-Credentials", "true")
w.Header().Set("Access-Control-Max-Age", "86400")
// 处理 OPTIONS 预检请求
@ -378,3 +393,47 @@ func corsMiddleware(next http.Handler) http.Handler {
next.ServeHTTP(w, r)
})
}
// corsOriginAllowed reports whether origin, once surrounding whitespace is
// trimmed, is exactly equal to one of the entries returned by
// corsAllowedOrigins. A blank origin is never allowed, and in that case the
// allow-list is not consulted at all.
func corsOriginAllowed(origin string) bool {
	candidate := strings.TrimSpace(origin)
	if candidate == "" {
		return false
	}
	allowList := corsAllowedOrigins()
	for i := range allowList {
		if allowList[i] == candidate {
			return true
		}
	}
	return false
}
// corsAllowedOrigins returns the exact origins that may receive CORS headers.
//
// The list is read from the comma-separated CORS_ALLOWED_ORIGINS environment
// variable, falling back to ALLOWED_ORIGINS when the former is blank. When
// neither is set, a built-in list of development origins is returned.
//
// Each configured entry is trimmed of surrounding whitespace and of trailing
// slashes: browsers serialize the Origin header as scheme://host[:port] with
// no path, so an entry written as "https://app.example.com/" could otherwise
// never match. Blank entries and the "*" wildcard are skipped, which means a
// configuration made up only of such entries yields an empty list rather than
// the built-in defaults.
func corsAllowedOrigins() []string {
	configured := strings.TrimSpace(os.Getenv("CORS_ALLOWED_ORIGINS"))
	if configured == "" {
		configured = strings.TrimSpace(os.Getenv("ALLOWED_ORIGINS"))
	}
	if configured == "" {
		// NOTE(review): the last entry looks environment-specific (a private
		// network address); confirm it still belongs in the defaults.
		return []string{
			"http://localhost:3000",
			"http://localhost:5173",
			"http://localhost:8080",
			"http://localhost:18080",
			"http://localhost:18081",
			"http://127.0.0.1:3000",
			"http://127.0.0.1:5173",
			"http://127.0.0.1:8080",
			"http://127.0.0.1:18080",
			"http://127.0.0.1:18081",
			"http://10.6.80.114:18080",
		}
	}

	entries := strings.Split(configured, ",")
	origins := make([]string, 0, len(entries))
	for _, entry := range entries {
		// Drop trailing slashes so "https://host/" matches the header value
		// "https://host" sent by browsers.
		origin := strings.TrimRight(strings.TrimSpace(entry), "/")
		if origin == "" || origin == "*" {
			continue
		}
		origins = append(origins, origin)
	}
	return origins
}

View File

@ -0,0 +1,50 @@
package main
import (
"net/http"
"net/http/httptest"
"testing"
)
// TestCORSMiddlewareAllowsDefaultLocalhostOrigin verifies that, with both
// origin environment variables blank, a GET from http://localhost:5173 gets
// that origin echoed in Access-Control-Allow-Origin together with
// Access-Control-Allow-Credentials: true.
func TestCORSMiddlewareAllowsDefaultLocalhostOrigin(t *testing.T) {
	for _, name := range []string{"CORS_ALLOWED_ORIGINS", "ALLOWED_ORIGINS"} {
		t.Setenv(name, "")
	}

	const origin = "http://localhost:5173"
	request := httptest.NewRequest(http.MethodGet, "/health", nil)
	request.Header.Set("Origin", origin)
	recorder := httptest.NewRecorder()

	okHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
	corsMiddleware(okHandler).ServeHTTP(recorder, request)

	headers := recorder.Header()
	if got := headers.Get("Access-Control-Allow-Origin"); got != origin {
		t.Fatalf("expected localhost origin to be allowed, got %q", got)
	}
	if got := headers.Get("Access-Control-Allow-Credentials"); got != "true" {
		t.Fatalf("expected credentials header for allowed origin, got %q", got)
	}
}
// TestCORSMiddlewareDoesNotReflectDisallowedOrigin verifies that an OPTIONS
// preflight from an origin outside the configured allow-list is answered with
// 204 without reaching the wrapped handler, and that neither the
// Access-Control-Allow-Origin nor the Access-Control-Allow-Credentials header
// is set on the response.
func TestCORSMiddlewareDoesNotReflectDisallowedOrigin(t *testing.T) {
	t.Setenv("CORS_ALLOWED_ORIGINS", "https://app.example.com")
	t.Setenv("ALLOWED_ORIGINS", "")

	request := httptest.NewRequest(http.MethodOptions, "/api/v1/auth/login", nil)
	request.Header.Set("Origin", "https://evil.example.com")
	recorder := httptest.NewRecorder()

	mustNotRun := http.HandlerFunc(func(http.ResponseWriter, *http.Request) {
		t.Fatal("preflight should not call next handler")
	})
	corsMiddleware(mustNotRun).ServeHTTP(recorder, request)

	if status := recorder.Code; status != http.StatusNoContent {
		t.Fatalf("expected preflight status %d, got %d", http.StatusNoContent, status)
	}
	headers := recorder.Header()
	if reflected := headers.Get("Access-Control-Allow-Origin"); reflected != "" {
		t.Fatalf("expected disallowed origin not to be reflected, got %q", reflected)
	}
	if credentials := headers.Get("Access-Control-Allow-Credentials"); credentials != "" {
		t.Fatalf("expected credentials header to be omitted for disallowed origin, got %q", credentials)
	}
}

View File

@ -1,19 +1,47 @@
package dto
import "strings"
// RegisterRequest 用户注册请求
type RegisterRequest struct {
Username string `json:"username" binding:"required"`
Password string `json:"password" binding:"required,min=6"`
Role string `json:"role,omitempty"`
WorkspaceID string `json:"workspaceId,omitempty"`
Namespace string `json:"namespace,omitempty"`
DefaultClusterID string `json:"defaultClusterId,omitempty"`
QuotaCPU string `json:"quotaCpu,omitempty"`
QuotaMemory string `json:"quotaMemory,omitempty"`
QuotaGPU string `json:"quotaGpu,omitempty"`
QuotaGPUMem string `json:"quotaGpuMemory,omitempty"`
IsActive *bool `json:"isActive,omitempty"`
MustChangePassword *bool `json:"mustChangePassword,omitempty"`
Username string `json:"username" binding:"required"`
Password string `json:"password" binding:"required,min=6"`
Role string `json:"role,omitempty"`
WorkspaceID string `json:"workspaceId,omitempty"`
WorkspaceIDSnake string `json:"workspace_id,omitempty"`
Namespace string `json:"namespace,omitempty"`
DefaultClusterID string `json:"defaultClusterId,omitempty"`
DefaultClusterIDSnake string `json:"default_cluster_id,omitempty"`
QuotaCPU string `json:"quotaCpu,omitempty"`
QuotaCPUSnake string `json:"quota_cpu,omitempty"`
QuotaMemory string `json:"quotaMemory,omitempty"`
QuotaMemorySnake string `json:"quota_memory,omitempty"`
QuotaGPU string `json:"quotaGpu,omitempty"`
QuotaGPUSnake string `json:"quota_gpu,omitempty"`
QuotaGPUMem string `json:"quotaGpuMemory,omitempty"`
QuotaGPUMemSnake string `json:"quota_gpu_memory,omitempty"`
IsActive *bool `json:"isActive,omitempty"`
IsActiveSnake *bool `json:"is_active,omitempty"`
MustChangePassword *bool `json:"mustChangePassword,omitempty"`
MustChangePasswordSnake *bool `json:"must_change_password,omitempty"`
}
// Normalize folds the snake_case alternates of a registration request into
// the canonical camelCase fields. A canonical string field is kept when it is
// non-blank; otherwise it takes the alternate's value. The IsActive and
// MustChangePassword pointers adopt their alternate only when they are nil.
// Calling Normalize on a nil receiver does nothing.
func (r *RegisterRequest) Normalize() {
	if r == nil {
		return
	}

	for _, field := range []struct {
		primary   *string
		alternate string
	}{
		{&r.WorkspaceID, r.WorkspaceIDSnake},
		{&r.DefaultClusterID, r.DefaultClusterIDSnake},
		{&r.QuotaCPU, r.QuotaCPUSnake},
		{&r.QuotaMemory, r.QuotaMemorySnake},
		{&r.QuotaGPU, r.QuotaGPUSnake},
		{&r.QuotaGPUMem, r.QuotaGPUMemSnake},
	} {
		*field.primary = firstNonBlank(*field.primary, field.alternate)
	}

	for _, flag := range []struct {
		primary   **bool
		alternate *bool
	}{
		{&r.IsActive, r.IsActiveSnake},
		{&r.MustChangePassword, r.MustChangePasswordSnake},
	} {
		if *flag.primary == nil {
			*flag.primary = flag.alternate
		}
	}
}
// LoginRequest 用户登录请求
@ -68,14 +96,47 @@ type UserResponse struct {
// UpdateUserRequest 管理员更新用户状态/角色请求
type UpdateUserRequest struct {
Role string `json:"role,omitempty"`
WorkspaceID string `json:"workspaceId,omitempty"`
Namespace string `json:"namespace,omitempty"`
DefaultClusterID string `json:"defaultClusterId,omitempty"`
QuotaCPU string `json:"quotaCpu,omitempty"`
QuotaMemory string `json:"quotaMemory,omitempty"`
QuotaGPU string `json:"quotaGpu,omitempty"`
QuotaGPUMem string `json:"quotaGpuMemory,omitempty"`
IsActive *bool `json:"isActive,omitempty"`
MustChangePassword *bool `json:"mustChangePassword,omitempty"`
Role string `json:"role,omitempty"`
WorkspaceID string `json:"workspaceId,omitempty"`
WorkspaceIDSnake string `json:"workspace_id,omitempty"`
Namespace string `json:"namespace,omitempty"`
DefaultClusterID string `json:"defaultClusterId,omitempty"`
DefaultClusterIDSnake string `json:"default_cluster_id,omitempty"`
QuotaCPU string `json:"quotaCpu,omitempty"`
QuotaCPUSnake string `json:"quota_cpu,omitempty"`
QuotaMemory string `json:"quotaMemory,omitempty"`
QuotaMemorySnake string `json:"quota_memory,omitempty"`
QuotaGPU string `json:"quotaGpu,omitempty"`
QuotaGPUSnake string `json:"quota_gpu,omitempty"`
QuotaGPUMem string `json:"quotaGpuMemory,omitempty"`
QuotaGPUMemSnake string `json:"quota_gpu_memory,omitempty"`
IsActive *bool `json:"isActive,omitempty"`
IsActiveSnake *bool `json:"is_active,omitempty"`
MustChangePassword *bool `json:"mustChangePassword,omitempty"`
MustChangePasswordSnake *bool `json:"must_change_password,omitempty"`
}
// Normalize folds the snake_case alternates of an update request into the
// canonical camelCase fields. A canonical string field is kept when it is
// non-blank; otherwise it takes the alternate's value. The IsActive and
// MustChangePassword pointers adopt their alternate only when they are nil.
// Calling Normalize on a nil receiver does nothing.
func (r *UpdateUserRequest) Normalize() {
	if r == nil {
		return
	}

	for _, field := range []struct {
		primary   *string
		alternate string
	}{
		{&r.WorkspaceID, r.WorkspaceIDSnake},
		{&r.DefaultClusterID, r.DefaultClusterIDSnake},
		{&r.QuotaCPU, r.QuotaCPUSnake},
		{&r.QuotaMemory, r.QuotaMemorySnake},
		{&r.QuotaGPU, r.QuotaGPUSnake},
		{&r.QuotaGPUMem, r.QuotaGPUMemSnake},
	} {
		*field.primary = firstNonBlank(*field.primary, field.alternate)
	}

	for _, flag := range []struct {
		primary   **bool
		alternate *bool
	}{
		{&r.IsActive, r.IsActiveSnake},
		{&r.MustChangePassword, r.MustChangePasswordSnake},
	} {
		if *flag.primary == nil {
			*flag.primary = flag.alternate
		}
	}
}
// firstNonBlank returns primary unchanged when it contains anything other
// than whitespace; otherwise it returns alternate as-is, even if alternate is
// itself blank.
func firstNonBlank(primary, alternate string) string {
	if len(strings.TrimSpace(primary)) == 0 {
		return alternate
	}
	return primary
}

View File

@ -0,0 +1,51 @@
package dto
import "testing"
// TestRegisterRequestNormalizeUsesSnakeCaseAlternates populates only the
// snake_case alternates of a RegisterRequest and checks that Normalize copies
// every one of them into the corresponding canonical field.
func TestRegisterRequestNormalizeUsesSnakeCaseAlternates(t *testing.T) {
	inactive, forceChange := false, true
	request := RegisterRequest{
		WorkspaceIDSnake:        "workspace-1",
		DefaultClusterIDSnake:   "cluster-1",
		QuotaCPUSnake:           "2",
		QuotaMemorySnake:        "4Gi",
		QuotaGPUSnake:           "1",
		QuotaGPUMemSnake:        "10000",
		IsActiveSnake:           &inactive,
		MustChangePasswordSnake: &forceChange,
	}

	request.Normalize()

	placementOK := request.WorkspaceID == "workspace-1" && request.DefaultClusterID == "cluster-1"
	if !placementOK {
		t.Fatalf("expected snake case workspace/cluster fields to normalize, got %#v", request)
	}
	quotaOK := request.QuotaCPU == "2" &&
		request.QuotaMemory == "4Gi" &&
		request.QuotaGPU == "1" &&
		request.QuotaGPUMem == "10000"
	if !quotaOK {
		t.Fatalf("expected snake case quota fields to normalize, got %#v", request)
	}
	if isActive := request.IsActive; isActive == nil || *isActive {
		t.Fatalf("expected is_active=false to normalize, got %#v", request.IsActive)
	}
	if mustChange := request.MustChangePassword; mustChange == nil || !*mustChange {
		t.Fatalf("expected must_change_password=true to normalize, got %#v", request.MustChangePassword)
	}
}
// TestUpdateUserRequestNormalizeKeepsCamelCasePrimary sets both spellings of
// two fields and checks that Normalize leaves the camelCase value in place.
func TestUpdateUserRequestNormalizeKeepsCamelCasePrimary(t *testing.T) {
	request := UpdateUserRequest{
		DefaultClusterID:      "camel-cluster",
		DefaultClusterIDSnake: "snake-cluster",
		QuotaCPU:              "3",
		QuotaCPUSnake:         "4",
	}

	request.Normalize()

	if got := request.DefaultClusterID; got != "camel-cluster" {
		t.Fatalf("expected camelCase defaultClusterId to win, got %q", got)
	}
	if got := request.QuotaCPU; got != "3" {
		t.Fatalf("expected camelCase quotaCpu to win, got %q", got)
	}
}

View File

@ -2,25 +2,25 @@ package dto
// CreateInstanceRequest 创建实例请求
type CreateInstanceRequest struct {
Name string `json:"name" binding:"required"`
Namespace string `json:"namespace" binding:"required"`
RegistryID string `json:"registryId" binding:"required"`
RegistryIDAlt string `json:"registry_id"`
Repository string `json:"repository" binding:"required"`
Tag string `json:"tag" binding:"required"`
Description string `json:"description"`
Values map[string]interface{} `json:"values"`
ValuesYAML string `json:"valuesYaml"`
ValuesYAMLAlt string `json:"values_yaml"`
Name string `json:"name" binding:"required"`
Namespace string `json:"namespace" binding:"required"`
RegistryID string `json:"registryId" binding:"required"`
RegistryIDAlt string `json:"registry_id"`
Repository string `json:"repository" binding:"required"`
Tag string `json:"tag" binding:"required"`
Description string `json:"description"`
Values map[string]interface{} `json:"values"`
ValuesYAML string `json:"valuesYaml"`
ValuesYAMLAlt string `json:"values_yaml"`
}
// UpdateInstanceRequest 更新实例请求
type UpdateInstanceRequest struct {
Version string `json:"version"`
Description string `json:"description"`
Values map[string]interface{} `json:"values"`
ValuesYAML string `json:"valuesYaml"`
ValuesYAMLAlt string `json:"values_yaml"`
Version string `json:"version"`
Description string `json:"description"`
Values map[string]interface{} `json:"values"`
ValuesYAML string `json:"valuesYaml"`
ValuesYAMLAlt string `json:"values_yaml"`
}
// Normalize 将多种命名风格的字段合并到统一字段
@ -67,6 +67,7 @@ type InstanceResponse struct {
Status string `json:"status"`
WorkspaceID string `json:"workspaceId"`
OwnerID string `json:"ownerId"`
OwnerUsername string `json:"ownerUsername,omitempty"`
AllowedActions []string `json:"allowedActions,omitempty"`
StatusReason string `json:"statusReason,omitempty"`
LastOperation string `json:"lastOperation,omitempty"`

View File

@ -8,29 +8,56 @@ import (
// ClusterMetricsResponse 集群监控响应
type ClusterMetricsResponse struct {
ClusterID string `json:"clusterId"`
ClusterName string `json:"clusterName"`
Status string `json:"status"`
Uptime string `json:"uptime"`
NodeCount int `json:"nodeCount"`
PodCount int `json:"podCount"`
LastCheck time.Time `json:"lastCheck"`
TotalCPU string `json:"totalCpu"`
TotalMemory string `json:"totalMemory"`
TotalGPU int `json:"totalGpu"`
UsedCPU string `json:"usedCpu"`
UsedMemory string `json:"usedMemory"`
UsedGPU int `json:"usedGpu"`
CPUUsage float64 `json:"cpuUsage"`
MemoryUsage float64 `json:"memoryUsage"`
GPUUsage float64 `json:"gpuUsage"`
MaxNodeCPU string `json:"maxNodeCpu"`
MaxNodeMemory string `json:"maxNodeMemory"`
MaxNodeGPU int `json:"maxNodeGpu"`
MaxNodeCPUUsage float64 `json:"maxNodeCpuUsage"`
MaxNodeMemUsage float64 `json:"maxNodeMemUsage"`
MaxNodeGPUUsage float64 `json:"maxNodeGpuUsage"`
Nodes []NodeMetricsResponse `json:"nodes,omitempty"`
ClusterID string `json:"clusterId"`
ClusterName string `json:"clusterName"`
Status string `json:"status"`
Uptime string `json:"uptime"`
NodeCount int `json:"nodeCount"`
PodCount int `json:"podCount"`
LastCheck time.Time `json:"lastCheck"`
TotalCPU string `json:"totalCpu"`
TotalMemory string `json:"totalMemory"`
TotalGPU int `json:"totalGpu"`
UsedCPU string `json:"usedCpu"`
UsedMemory string `json:"usedMemory"`
UsedGPU int `json:"usedGpu"`
CPUUsage float64 `json:"cpuUsage"`
MemoryUsage float64 `json:"memoryUsage"`
GPUUsage float64 `json:"gpuUsage"`
CPURequests string `json:"cpuRequests,omitempty"`
CPULimits string `json:"cpuLimits,omitempty"`
MemoryRequests string `json:"memoryRequests,omitempty"`
MemoryLimits string `json:"memoryLimits,omitempty"`
GPURequests int64 `json:"gpuRequests,omitempty"`
GPULimits int64 `json:"gpuLimits,omitempty"`
GPUMemoryRequestsMB int64 `json:"gpuMemoryRequestsMb,omitempty"`
GPUMemoryLimitsMB int64 `json:"gpuMemoryLimitsMb,omitempty"`
AllocatedGPU int64 `json:"allocatedGpu,omitempty"`
AllocatedGPUMemoryMB int64 `json:"allocatedGpuMemoryMb,omitempty"`
ResourceUsageByUser []UserResourceUsageResponse `json:"resourceUsageByUser,omitempty"`
MaxNodeCPU string `json:"maxNodeCpu"`
MaxNodeMemory string `json:"maxNodeMemory"`
MaxNodeGPU int `json:"maxNodeGpu"`
MaxNodeCPUUsage float64 `json:"maxNodeCpuUsage"`
MaxNodeMemUsage float64 `json:"maxNodeMemUsage"`
MaxNodeGPUUsage float64 `json:"maxNodeGpuUsage"`
Nodes []NodeMetricsResponse `json:"nodes,omitempty"`
}
// UserResourceUsageResponse is one element of
// ClusterMetricsResponse.ResourceUsageByUser: the resource figures attributed
// to a single user. Unlike the aggregate fields on ClusterMetricsResponse,
// none of these fields use omitempty, so every key is always present in the
// encoded JSON.
type UserResourceUsageResponse struct {
// Identity of the user the figures belong to.
UserID string `json:"userId"`
Username string `json:"username"`
WorkspaceID string `json:"workspaceId"`
// Counts of instances and pods attributed to the user.
InstanceCount int `json:"instanceCount"`
PodCount int `json:"podCount"`
// CPU and memory requests/limits, carried as strings exactly as provided by
// the entity layer. NOTE(review): presumably Kubernetes quantity strings
// (e.g. "500m", "4Gi") — confirm against the producer.
CPURequests string `json:"cpuRequests"`
CPULimits string `json:"cpuLimits"`
MemoryRequests string `json:"memoryRequests"`
MemoryLimits string `json:"memoryLimits"`
// GPU device counts and GPU memory figures. NOTE(review): the MB suffix
// suggests megabytes — confirm the unit against the producer.
GPURequests int64 `json:"gpuRequests"`
GPULimits int64 `json:"gpuLimits"`
GPUMemoryRequestsMB int64 `json:"gpuMemoryRequestsMb"`
GPUMemoryLimitsMB int64 `json:"gpuMemoryLimitsMb"`
}
// NodeMetricsResponse 节点监控响应
@ -72,28 +99,59 @@ type MonitoringSummaryResponse struct {
// ToClusterMetricsResponse 转换为响应
func ToClusterMetricsResponse(m *entity.ClusterMetrics) *ClusterMetricsResponse {
resp := &ClusterMetricsResponse{
ClusterID: m.ClusterID,
ClusterName: m.ClusterName,
Status: m.Status,
Uptime: m.Uptime,
NodeCount: m.NodeCount,
PodCount: m.PodCount,
LastCheck: m.LastCheck,
TotalCPU: m.TotalCPU,
TotalMemory: m.TotalMemory,
TotalGPU: m.TotalGPU,
UsedCPU: m.UsedCPU,
UsedMemory: m.UsedMemory,
UsedGPU: m.UsedGPU,
CPUUsage: m.CPUUsage,
MemoryUsage: m.MemoryUsage,
GPUUsage: m.GPUUsage,
MaxNodeCPU: m.MaxNodeCPU,
MaxNodeMemory: m.MaxNodeMemory,
MaxNodeGPU: m.MaxNodeGPU,
MaxNodeCPUUsage: m.MaxNodeCPUUsage,
MaxNodeMemUsage: m.MaxNodeMemUsage,
MaxNodeGPUUsage: m.MaxNodeGPUUsage,
ClusterID: m.ClusterID,
ClusterName: m.ClusterName,
Status: m.Status,
Uptime: m.Uptime,
NodeCount: m.NodeCount,
PodCount: m.PodCount,
LastCheck: m.LastCheck,
TotalCPU: m.TotalCPU,
TotalMemory: m.TotalMemory,
TotalGPU: m.TotalGPU,
UsedCPU: m.UsedCPU,
UsedMemory: m.UsedMemory,
UsedGPU: m.UsedGPU,
CPUUsage: m.CPUUsage,
MemoryUsage: m.MemoryUsage,
GPUUsage: m.GPUUsage,
CPURequests: m.CPURequests,
CPULimits: m.CPULimits,
MemoryRequests: m.MemoryRequests,
MemoryLimits: m.MemoryLimits,
GPURequests: m.GPURequests,
GPULimits: m.GPULimits,
GPUMemoryRequestsMB: m.GPUMemoryRequestsMB,
GPUMemoryLimitsMB: m.GPUMemoryLimitsMB,
AllocatedGPU: m.AllocatedGPU,
AllocatedGPUMemoryMB: m.AllocatedGPUMemoryMB,
MaxNodeCPU: m.MaxNodeCPU,
MaxNodeMemory: m.MaxNodeMemory,
MaxNodeGPU: m.MaxNodeGPU,
MaxNodeCPUUsage: m.MaxNodeCPUUsage,
MaxNodeMemUsage: m.MaxNodeMemUsage,
MaxNodeGPUUsage: m.MaxNodeGPUUsage,
}
if len(m.ResourceUsageByUser) > 0 {
resp.ResourceUsageByUser = make([]UserResourceUsageResponse, len(m.ResourceUsageByUser))
for i, usage := range m.ResourceUsageByUser {
resp.ResourceUsageByUser[i] = UserResourceUsageResponse{
UserID: usage.UserID,
Username: usage.Username,
WorkspaceID: usage.WorkspaceID,
InstanceCount: usage.InstanceCount,
PodCount: usage.PodCount,
CPURequests: usage.CPURequests,
CPULimits: usage.CPULimits,
MemoryRequests: usage.MemoryRequests,
MemoryLimits: usage.MemoryLimits,
GPURequests: usage.GPURequests,
GPULimits: usage.GPULimits,
GPUMemoryRequestsMB: usage.GPUMemoryRequestsMB,
GPUMemoryLimitsMB: usage.GPUMemoryLimitsMB,
}
}
}
if len(m.Nodes) > 0 {

View File

@ -126,6 +126,25 @@ func (h *ArtifactHandler) ListArtifacts(w http.ResponseWriter, r *http.Request)
respondJSON(w, http.StatusOK, tagResponses)
}
// ListRepositoryTags is a compatibility alias for clients that request tags
// directly instead of the canonical artifacts endpoint.
//
// When the matched route carries no registry_id path variable, the registry ID
// is taken from the registry_id query parameter (or, failing that, registryId)
// and injected into the request's route variables before delegating to
// ListArtifacts. A request that supplies no registry ID by any of these means
// is rejected with 400 Bad Request.
func (h *ArtifactHandler) ListRepositoryTags(w http.ResponseWriter, r *http.Request) {
	routeVars := mux.Vars(r)
	if routeVars["registry_id"] == "" {
		// Parse the query string once; r.URL.Query() re-parses on every call.
		query := r.URL.Query()
		registryID := query.Get("registry_id")
		if registryID == "" {
			registryID = query.Get("registryId")
		}
		if registryID == "" {
			respondError(w, http.StatusBadRequest, "Missing registry ID", "registry_id query parameter is required")
			return
		}

		// Build a fresh map instead of writing into the one returned by
		// mux.Vars: that map is nil when the request did not pass through a
		// mux route (assigning to it would panic), and otherwise it is the
		// map held in the request context, which should not be mutated in
		// place.
		merged := make(map[string]string, len(routeVars)+1)
		for name, value := range routeVars {
			merged[name] = value
		}
		merged["registry_id"] = registryID
		r = mux.SetURLVars(r, merged)
	}
	h.ListArtifacts(w, r)
}
// GetArtifact 获取 artifact 详情
// @Summary 获取 Artifact 详情
// @Description 获取指定 Artifact 的详细信息

View File

@ -3,8 +3,11 @@ package rest
import (
"context"
"encoding/json"
"net"
"net/http"
"strings"
"sync"
"time"
"github.com/gorilla/mux"
"github.com/ocdp/cluster-service/internal/adapter/input/http/dto"
@ -18,6 +21,74 @@ type AuthHandler struct {
authService *service.AuthService
}
const (
	// loginRateLimitWindow is how long a run of failed logins is remembered.
	loginRateLimitWindow = time.Minute
	// loginRateLimitFailures is the number of failures inside one window
	// after which further attempts for the same key are refused.
	loginRateLimitFailures = 5
)

// defaultLoginRateLimiter is the package-wide limiter, configured with the
// window and failure count above.
var defaultLoginRateLimiter = newLoginRateLimiter(loginRateLimitWindow, loginRateLimitFailures)

// loginRateLimiter counts failed attempts per key inside a fixed window that
// starts at the first failure. All state is guarded by mu.
type loginRateLimiter struct {
	mu       sync.Mutex
	window   time.Duration
	limit    int
	failures map[string]loginFailureState
	// nextSweep is the earliest time at which RecordFailure will again scan
	// failures for expired entries. Sweeping at most once per window keeps
	// the map from growing without bound while keeping the cost amortised.
	nextSweep time.Time
	// now is the clock; it is a field so tests can substitute a fake.
	now func() time.Time
}

// loginFailureState is the per-key record: how many failures were seen and
// when the window they belong to ends.
type loginFailureState struct {
	count      int
	windowEnds time.Time
}

// newLoginRateLimiter returns a limiter that refuses a key once it has
// accumulated limit failures within window.
func newLoginRateLimiter(window time.Duration, limit int) *loginRateLimiter {
	return &loginRateLimiter{
		window:   window,
		limit:    limit,
		failures: make(map[string]loginFailureState),
		now:      time.Now,
	}
}

// Allow reports whether an attempt for key may proceed. A nil limiter or an
// empty key is always allowed. A key whose window has ended is allowed and its
// stale record is discarded.
func (l *loginRateLimiter) Allow(key string) bool {
	if l == nil || key == "" {
		return true
	}
	l.mu.Lock()
	defer l.mu.Unlock()

	now := l.now()
	state, ok := l.failures[key]
	if !ok {
		return true
	}
	if now.After(state.windowEnds) {
		// The record can no longer influence any decision; free it now
		// rather than waiting for the next sweep.
		delete(l.failures, key)
		return true
	}
	return state.count < l.limit
}

// RecordFailure registers one failed attempt for key. The first failure (or
// the first after the previous window ended) opens a new window; later ones
// inside the window increment the counter. A nil limiter or an empty key is
// ignored.
func (l *loginRateLimiter) RecordFailure(key string) {
	if l == nil || key == "" {
		return
	}
	l.mu.Lock()
	defer l.mu.Unlock()

	now := l.now()
	// Keys are derived from request data, so without pruning every distinct
	// failing key would stay in memory forever.
	l.sweepExpiredLocked(now)

	state, ok := l.failures[key]
	if !ok || now.After(state.windowEnds) {
		l.failures[key] = loginFailureState{count: 1, windowEnds: now.Add(l.window)}
		return
	}
	state.count++
	l.failures[key] = state
}

// Reset forgets all recorded failures for key. A nil limiter or an empty key
// is ignored.
func (l *loginRateLimiter) Reset(key string) {
	if l == nil || key == "" {
		return
	}
	l.mu.Lock()
	defer l.mu.Unlock()
	delete(l.failures, key)
}

// sweepExpiredLocked removes every record whose window has ended, at most once
// per window. The caller must hold l.mu.
func (l *loginRateLimiter) sweepExpiredLocked(now time.Time) {
	if !now.After(l.nextSweep) {
		return
	}
	for key, state := range l.failures {
		if now.After(state.windowEnds) {
			delete(l.failures, key)
		}
	}
	l.nextSweep = now.Add(l.window)
}
// NewAuthHandler 创建认证 Handler
func NewAuthHandler(authService *service.AuthService) *AuthHandler {
return &AuthHandler{
@ -41,6 +112,7 @@ func (h *AuthHandler) Register(w http.ResponseWriter, r *http.Request) {
respondError(w, http.StatusBadRequest, "Invalid request body", err.Error())
return
}
req.Normalize()
// 调用领域服务
user, err := h.authService.Register(r.Context(), req.Username, req.Password, req.Role, req.WorkspaceID, service.UserWorkspaceOptions{
@ -79,6 +151,7 @@ func (h *AuthHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
respondError(w, http.StatusBadRequest, "Invalid request body", err.Error())
return
}
req.Normalize()
user, err := h.authService.UpdateUser(r.Context(), userID, req.Role, req.WorkspaceID, service.UserWorkspaceOptions{
Namespace: req.Namespace,
DefaultClusterID: req.DefaultClusterID,
@ -120,12 +193,21 @@ func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request) {
return
}
rateLimitKey := loginRateLimitKey(r, req.Username)
if !defaultLoginRateLimiter.Allow(rateLimitKey) {
w.Header().Set("Retry-After", "60")
respondError(w, http.StatusTooManyRequests, "Too many login attempts", "too many login attempts; retry later")
return
}
// 调用领域服务
accessToken, refreshToken, user, err := h.authService.Login(r.Context(), req.Username, req.Password)
if err != nil {
respondError(w, http.StatusUnauthorized, "Login failed", err.Error())
defaultLoginRateLimiter.RecordFailure(rateLimitKey)
respondError(w, http.StatusUnauthorized, "Invalid username or password", "invalid username or password")
return
}
defaultLoginRateLimiter.Reset(rateLimitKey)
workspace, _ := h.authService.GetWorkspaceByID(r.Context(), user.WorkspaceID)
@ -151,6 +233,23 @@ func (h *AuthHandler) Login(w http.ResponseWriter, r *http.Request) {
respondJSON(w, http.StatusOK, response)
}
// loginRateLimitKey builds the rate-limit key "<username>|<client>" for a
// login attempt. The username is trimmed and lower-cased. The client part is
// the first non-empty value among: the first comma-separated entry of
// X-Forwarded-For, the X-Real-IP header, and the host portion of
// r.RemoteAddr (or RemoteAddr verbatim when it has no host:port form).
//
// NOTE(review): X-Forwarded-For and X-Real-IP are taken from the request as
// received. Unless a trusted proxy overwrites them, a caller can vary these
// headers to obtain a fresh key per attempt — confirm the proxy configuration.
func loginRateLimitKey(r *http.Request, username string) string {
	forwarded := strings.SplitN(r.Header.Get("X-Forwarded-For"), ",", 2)[0]
	client := strings.TrimSpace(forwarded)
	if client == "" {
		client = strings.TrimSpace(r.Header.Get("X-Real-IP"))
	}
	if client == "" {
		if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
			client = host
		} else {
			client = r.RemoteAddr
		}
	}
	account := strings.ToLower(strings.TrimSpace(username))
	return account + "|" + client
}
func (h *AuthHandler) convertUserResponse(ctx context.Context, user *entity.User) *dto.UserResponse {
workspace, _ := h.authService.GetWorkspaceByID(ctx, user.WorkspaceID)
return &dto.UserResponse{

View File

@ -0,0 +1,44 @@
package rest
import (
"testing"
"time"
)
// TestLoginRateLimiterBlocksAfterConfiguredFailures drives a limiter with a
// limit of two failures per minute on a fake clock: two failed attempts are
// allowed, the third is refused, and the key is allowed again once the clock
// moves past the end of the window.
func TestLoginRateLimiterBlocksAfterConfiguredFailures(t *testing.T) {
	clock := time.Date(2026, 5, 14, 12, 0, 0, 0, time.UTC)
	limiter := newLoginRateLimiter(time.Minute, 2)
	limiter.now = func() time.Time { return clock }
	const key = "user|127.0.0.1"

	for _, failure := range []string{
		"expected first attempt to be allowed",
		"expected second attempt to be allowed",
	} {
		if !limiter.Allow(key) {
			t.Fatal(failure)
		}
		limiter.RecordFailure(key)
	}

	if allowed := limiter.Allow(key); allowed {
		t.Fatal("expected third attempt inside the window to be blocked")
	}

	clock = clock.Add(time.Minute + time.Second)
	if allowed := limiter.Allow(key); !allowed {
		t.Fatal("expected attempts to be allowed after the window expires")
	}
}
// TestLoginRateLimiterResetClearsFailures checks that a key blocked after a
// single failure (limit of one) is allowed again as soon as Reset is called
// for it.
func TestLoginRateLimiterResetClearsFailures(t *testing.T) {
	const key = "user|127.0.0.1"
	limiter := newLoginRateLimiter(time.Minute, 1)

	limiter.RecordFailure(key)
	if allowed := limiter.Allow(key); allowed {
		t.Fatal("expected key to be blocked after one failure")
	}

	limiter.Reset(key)
	if allowed := limiter.Allow(key); !allowed {
		t.Fatal("expected reset key to be allowed")
	}
}

View File

@ -4,6 +4,7 @@ import (
"encoding/json"
"fmt"
"net/http"
"reflect"
"strconv"
"strings"
"time"
@ -49,6 +50,11 @@ func (h *InstanceHandler) CreateInstance(w http.ResponseWriter, r *http.Request)
return
}
req.Normalize()
parsedYAML, hasValuesYAML, err := parseAndCompareValues(req.Values, req.ValuesYAML)
if err != nil {
respondError(w, http.StatusBadRequest, "Invalid values", err.Error())
return
}
// Extract chart name from repository (e.g., "charts/nginx" -> "nginx")
chart := req.Repository
@ -71,21 +77,16 @@ func (h *InstanceHandler) CreateInstance(w http.ResponseWriter, r *http.Request)
if req.Values != nil {
instance.SetValues(req.Values)
}
if req.ValuesYAML != "" {
if hasValuesYAML {
instance.SetValuesYAML(req.ValuesYAML)
if req.Values == nil {
values, err := parseValuesYAML(req.ValuesYAML)
if err != nil {
respondError(w, http.StatusBadRequest, "Invalid values YAML", err.Error())
return
}
instance.SetValues(values)
instance.SetValues(parsedYAML)
}
}
// 调用领域服务
if err := h.instanceService.CreateInstance(r.Context(), instance); err != nil {
respondError(w, http.StatusBadRequest, "Failed to create instance", err.Error())
respondServiceError(w, err, "Failed to create instance")
return
}
@ -116,6 +117,7 @@ func (h *InstanceHandler) GetInstance(w http.ResponseWriter, r *http.Request) {
respondError(w, http.StatusNotFound, "Instance not found", "resource does not belong to cluster")
return
}
h.instanceService.EnrichReplicas(r.Context(), clusterID, []*entity.Instance{instance})
respondJSON(w, http.StatusOK, convertInstanceResponse(instance, true))
}
@ -144,7 +146,7 @@ func (h *InstanceHandler) ListInstances(w http.ResponseWriter, r *http.Request)
responses := make([]*dto.InstanceResponse, 0, len(instances))
for _, instance := range instances {
responses = append(responses, convertInstanceResponse(instance, false))
responses = append(responses, convertInstanceResponse(instance, true))
}
response := &dto.InstanceListResponse{
@ -177,6 +179,11 @@ func (h *InstanceHandler) UpdateInstance(w http.ResponseWriter, r *http.Request)
return
}
req.Normalize()
parsedYAML, hasValuesYAML, err := parseAndCompareValues(req.Values, req.ValuesYAML)
if err != nil {
respondError(w, http.StatusBadRequest, "Invalid values", err.Error())
return
}
// 获取现有实例
instance, err := h.instanceService.GetInstance(r.Context(), instanceID)
@ -194,21 +201,16 @@ func (h *InstanceHandler) UpdateInstance(w http.ResponseWriter, r *http.Request)
if req.Description != "" {
instance.Description = req.Description
}
if req.ValuesYAML != "" {
if hasValuesYAML {
instance.SetValuesYAML(req.ValuesYAML)
if req.Values == nil {
values, err := parseValuesYAML(req.ValuesYAML)
if err != nil {
respondError(w, http.StatusBadRequest, "Invalid values YAML", err.Error())
return
}
instance.SetValues(values)
instance.SetValues(parsedYAML)
}
}
// 调用领域服务
if err := h.instanceService.UpdateInstance(r.Context(), instance); err != nil {
respondError(w, http.StatusBadRequest, "Failed to update instance", err.Error())
respondServiceError(w, err, "Failed to update instance")
return
}
@ -345,7 +347,6 @@ func (h *InstanceHandler) StreamInstanceLogs(w http.ResponseWriter, r *http.Requ
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
w.Header().Set("Access-Control-Allow-Origin", "*")
flusher, ok := w.(http.Flusher)
if !ok {
@ -585,6 +586,7 @@ func convertInstanceResponse(instance *entity.Instance, includeValues bool) *dto
Status: string(instance.Status),
WorkspaceID: instance.WorkspaceID,
OwnerID: instance.OwnerID,
OwnerUsername: instance.OwnerUsername,
StatusReason: instance.StatusReason,
LastOperation: string(instance.LastOperation),
LastError: instance.LastError,
@ -622,6 +624,43 @@ func parseValuesYAML(valuesYAML string) (map[string]interface{}, error) {
return values, nil
}
// parseAndCompareValues cross-checks the structured values and the raw YAML
// document submitted for an instance.
//
// Returns, in order:
//   - the values parsed from valuesYAML (nil when valuesYAML is blank or an
//     error is returned),
//   - whether a non-blank valuesYAML was supplied,
//   - an error when the YAML cannot be parsed, when either side cannot be
//     normalized through JSON, or when both inputs are present but differ.
func parseAndCompareValues(values map[string]interface{}, valuesYAML string) (map[string]interface{}, bool, error) {
	// A whitespace-only document counts as "no YAML supplied".
	if len(strings.TrimSpace(valuesYAML)) == 0 {
		return nil, false, nil
	}

	fromYAML, parseErr := parseValuesYAML(valuesYAML)
	switch {
	case parseErr != nil:
		return nil, true, fmt.Errorf("invalid values YAML: %w", parseErr)
	case values == nil:
		// Only YAML was supplied, so there is nothing to compare against.
		return fromYAML, true, nil
	}

	// Both inputs are present: bring them to the same JSON-shaped
	// representation before comparing them structurally.
	comparableValues, err := normalizeJSONComparable(values)
	if err != nil {
		return nil, true, fmt.Errorf("invalid values: %w", err)
	}
	comparableYAML, err := normalizeJSONComparable(fromYAML)
	if err != nil {
		return nil, true, fmt.Errorf("invalid values YAML: %w", err)
	}

	if reflect.DeepEqual(comparableValues, comparableYAML) {
		return fromYAML, true, nil
	}
	return nil, true, fmt.Errorf("values and valuesYaml conflict")
}
// normalizeJSONComparable round-trips value through encoding/json and returns
// the decoded generic form (maps, slices, float64, string, bool, nil).
//
// Two inputs that encode to the same JSON therefore normalize to values that
// compare equal with reflect.DeepEqual, regardless of their original Go types.
// An error is returned when value cannot be marshaled or decoded again.
func normalizeJSONComparable(value interface{}) (interface{}, error) {
	encoded, err := json.Marshal(value)
	if err != nil {
		return nil, err
	}

	var decoded interface{}
	err = json.Unmarshal(encoded, &decoded)
	if err != nil {
		return nil, err
	}
	return decoded, nil
}
func normalizeYAMLValue(value interface{}) (interface{}, error) {
switch typed := value.(type) {
case map[string]interface{}:

View File

@ -43,6 +43,12 @@ func (h *MonitoringHandler) GetClusterMonitoring(w http.ResponseWriter, r *http.
respondJSON(w, http.StatusOK, response)
}
// GetClusterStats is a compatibility alias for cluster detail dashboards that
// historically read stats from /clusters/{id}/stats.
//
// It forwards the request and response writer unchanged to
// GetClusterMonitoring, so both endpoints produce the same response.
func (h *MonitoringHandler) GetClusterStats(w http.ResponseWriter, r *http.Request) {
h.GetClusterMonitoring(w, r)
}
// ListClusterMonitoring 获取所有集群的监控信息
// @Summary 列出集群监控
// @Tags Monitoring

View File

@ -2,6 +2,7 @@ package rest
import (
"encoding/json"
"errors"
"net/http"
"time"
@ -113,6 +114,15 @@ func (h *WorkspaceHandler) IssueCurrentKubeconfig(w http.ResponseWriter, r *http
if clusterID == "" {
clusterID = r.URL.Query().Get("cluster_id")
}
h.issueCurrentKubeconfigForCluster(w, r, clusterID)
}
// IssueClusterKubeconfig issues a kubeconfig for the cluster identified by the
// "cluster_id" route variable and writes the result through the shared
// issueCurrentKubeconfigForCluster helper.
func (h *WorkspaceHandler) IssueClusterKubeconfig(w http.ResponseWriter, r *http.Request) {
	h.issueCurrentKubeconfigForCluster(w, r, mux.Vars(r)["cluster_id"])
}
func (h *WorkspaceHandler) issueCurrentKubeconfigForCluster(w http.ResponseWriter, r *http.Request, clusterID string) {
kubeconfig, err := h.workspaceService.IssueCurrentKubeconfig(r.Context(), clusterID, 2*time.Hour)
if err != nil {
respondServiceError(w, err, "Failed to issue kubeconfig")
@ -152,11 +162,19 @@ func toWorkspaceResponse(workspace *entity.Workspace) workspaceResponse {
}
func respondServiceError(w http.ResponseWriter, err error, fallback string) {
if errors.Is(err, service.ErrQuotaExceeded) {
respondError(w, http.StatusUnprocessableEntity, "Quota exceeded", err.Error())
return
}
switch err {
case entity.ErrUnauthorized, authz.ErrUnauthenticated:
respondError(w, http.StatusUnauthorized, "Unauthorized", err.Error())
case entity.ErrForbidden, authz.ErrForbidden, entity.ErrUserInactive, entity.ErrWorkspaceSuspended:
respondError(w, http.StatusForbidden, "Forbidden", err.Error())
case entity.ErrWorkspaceNamespaceConflict, entity.ErrUserHasInstances, entity.ErrWorkspaceExists, entity.ErrInstanceExists:
respondError(w, http.StatusConflict, "Conflict", err.Error())
case entity.ErrProtectedNamespace:
respondError(w, http.StatusForbidden, "Forbidden", err.Error())
case entity.ErrClusterNotFound, entity.ErrRegistryNotFound, entity.ErrInstanceNotFound, entity.ErrWorkspaceNotFound:
respondError(w, http.StatusNotFound, fallback, err.Error())
default:

View File

@ -4,7 +4,7 @@ import (
"context"
"fmt"
"time"
"github.com/ocdp/cluster-service/internal/domain/entity"
"github.com/ocdp/cluster-service/internal/domain/repository"
)
@ -12,38 +12,47 @@ import (
// HelmClientMock Helm 客户端 Mock 实现
type HelmClientMock struct {
// Mock 数据存储
releases map[string]map[string]*entity.Instance // clusterID -> releaseName -> instance
history map[string]map[string][]*entity.ReleaseHistory // clusterID -> releaseName -> []history
releases map[string]map[string]*entity.Instance // clusterID -> releaseName -> instance
history map[string]map[string][]*entity.ReleaseHistory // clusterID -> releaseName -> []history
estimates map[string]map[string]*repository.ResourceEstimate // clusterID -> releaseName -> estimate
}
// NewHelmClientMock 创建 Mock 实现
func NewHelmClientMock() repository.HelmClient {
return &HelmClientMock{
releases: make(map[string]map[string]*entity.Instance),
history: make(map[string]map[string][]*entity.ReleaseHistory),
releases: make(map[string]map[string]*entity.Instance),
history: make(map[string]map[string][]*entity.ReleaseHistory),
estimates: make(map[string]map[string]*repository.ResourceEstimate),
}
}
// SetResourceEstimate stores the estimate the mock should hold for a release.
//
// Estimates are kept per cluster under the key "<namespace>/<releaseName>".
// Passing a nil estimate stores nil for that key.
//
// Both map levels are created on demand. Writing to a nil map panics in Go, so
// initialising the outer map here keeps the method safe on a HelmClientMock
// that was created as a zero value instead of through NewHelmClientMock.
func (c *HelmClientMock) SetResourceEstimate(clusterID, namespace, releaseName string, estimate *repository.ResourceEstimate) {
	if c.estimates == nil {
		c.estimates = make(map[string]map[string]*repository.ResourceEstimate)
	}
	if c.estimates[clusterID] == nil {
		c.estimates[clusterID] = make(map[string]*repository.ResourceEstimate)
	}
	c.estimates[clusterID][fmt.Sprintf("%s/%s", namespace, releaseName)] = estimate
}
func (c *HelmClientMock) Install(ctx context.Context, cluster *entity.Cluster, instance *entity.Instance) error {
// 初始化集群数据
if c.releases[cluster.ID] == nil {
c.releases[cluster.ID] = make(map[string]*entity.Instance)
c.history[cluster.ID] = make(map[string][]*entity.ReleaseHistory)
}
// 检查是否已存在
key := fmt.Sprintf("%s/%s", instance.Namespace, instance.Name)
if _, exists := c.releases[cluster.ID][key]; exists {
return entity.ErrInstanceExists
}
// Mock 安装
instance.Status = entity.StatusDeployed
instance.Revision = 1
instance.UpdatedAt = time.Now()
c.releases[cluster.ID][key] = instance
// 添加历史记录
c.history[cluster.ID][key] = []*entity.ReleaseHistory{
{
@ -55,25 +64,25 @@ func (c *HelmClientMock) Install(ctx context.Context, cluster *entity.Cluster, i
Description: "Install complete",
},
}
return nil
}
func (c *HelmClientMock) Upgrade(ctx context.Context, cluster *entity.Cluster, instance *entity.Instance) error {
key := fmt.Sprintf("%s/%s", instance.Namespace, instance.Name)
existing, exists := c.releases[cluster.ID][key]
if !exists {
return entity.ErrInstanceNotFound
}
// Mock 升级
instance.Revision = existing.Revision + 1
instance.Status = entity.StatusDeployed
instance.UpdatedAt = time.Now()
c.releases[cluster.ID][key] = instance
// 添加历史记录
history := &entity.ReleaseHistory{
Revision: instance.Revision,
@ -84,44 +93,44 @@ func (c *HelmClientMock) Upgrade(ctx context.Context, cluster *entity.Cluster, i
Description: "Upgrade complete",
}
c.history[cluster.ID][key] = append(c.history[cluster.ID][key], history)
return nil
}
// Uninstall removes the mocked release stored under "<namespace>/<releaseName>"
// for the given cluster.
//
// It returns entity.ErrInstanceNotFound when no such release is recorded.
func (c *HelmClientMock) Uninstall(ctx context.Context, cluster *entity.Cluster, releaseName, namespace string) error {
	clusterReleases := c.releases[cluster.ID]
	key := fmt.Sprintf("%s/%s", namespace, releaseName)

	if _, found := clusterReleases[key]; found {
		// Mock uninstall: simply forget the release.
		delete(clusterReleases, key)
		return nil
	}
	return entity.ErrInstanceNotFound
}
func (c *HelmClientMock) Rollback(ctx context.Context, cluster *entity.Cluster, releaseName, namespace string, revision int) error {
key := fmt.Sprintf("%s/%s", namespace, releaseName)
instance, exists := c.releases[cluster.ID][key]
if !exists {
return entity.ErrInstanceNotFound
}
// 检查历史记录是否存在
histories := c.history[cluster.ID][key]
if revision > len(histories) || revision < 1 {
return fmt.Errorf("revision %d not found", revision)
}
// Mock 回滚
instance.Revision = len(histories) + 1
instance.Status = entity.StatusDeployed
instance.UpdatedAt = time.Now()
c.releases[cluster.ID][key] = instance
// 添加回滚历史记录
history := &entity.ReleaseHistory{
Revision: instance.Revision,
@ -132,33 +141,33 @@ func (c *HelmClientMock) Rollback(ctx context.Context, cluster *entity.Cluster,
Description: fmt.Sprintf("Rollback to revision %d", revision),
}
c.history[cluster.ID][key] = append(c.history[cluster.ID][key], history)
return nil
}
// GetStatus returns the mocked release stored under "<namespace>/<releaseName>"
// for the given cluster, or entity.ErrInstanceNotFound when it is not recorded.
func (c *HelmClientMock) GetStatus(ctx context.Context, cluster *entity.Cluster, releaseName, namespace string) (*entity.Instance, error) {
	key := fmt.Sprintf("%s/%s", namespace, releaseName)
	if stored, found := c.releases[cluster.ID][key]; found {
		return stored, nil
	}
	return nil, entity.ErrInstanceNotFound
}
// GetHistory returns the recorded revision history of a mocked release.
//
// It returns entity.ErrInstanceNotFound when the release is not recorded, and
// an empty (non-nil) slice when the release exists but has no history entries.
func (c *HelmClientMock) GetHistory(ctx context.Context, cluster *entity.Cluster, releaseName, namespace string) ([]*entity.ReleaseHistory, error) {
	key := fmt.Sprintf("%s/%s", namespace, releaseName)
	if _, found := c.releases[cluster.ID][key]; !found {
		return nil, entity.ErrInstanceNotFound
	}

	if recorded := c.history[cluster.ID][key]; recorded != nil {
		return recorded, nil
	}
	return []*entity.ReleaseHistory{}, nil
}
@ -167,7 +176,7 @@ func (c *HelmClientMock) List(ctx context.Context, cluster *entity.Cluster, name
if clusterReleases == nil {
return []*entity.Instance{}, nil
}
instances := make([]*entity.Instance, 0)
for key, instance := range clusterReleases {
// 如果指定了 namespace只返回该 namespace 的
@ -179,18 +188,18 @@ func (c *HelmClientMock) List(ctx context.Context, cluster *entity.Cluster, name
}
instances = append(instances, c.releases[cluster.ID][key])
}
return instances, nil
}
// GetValues returns the Values of the mocked release stored under
// "<namespace>/<releaseName>", or entity.ErrInstanceNotFound when the release
// is not recorded for the given cluster.
func (c *HelmClientMock) GetValues(ctx context.Context, cluster *entity.Cluster, releaseName, namespace string) (map[string]interface{}, error) {
	key := fmt.Sprintf("%s/%s", namespace, releaseName)
	if stored, found := c.releases[cluster.ID][key]; found {
		return stored.Values, nil
	}
	return nil, entity.ErrInstanceNotFound
}
@ -204,3 +213,16 @@ func (c *HelmClientMock) GetChartDefaultValues(chartPath string) (map[string]int
}, nil
}
// EstimateInstanceResources returns the estimate stored for the instance's
// "<namespace>/<name>" key in the given cluster.
//
// A nil cluster is looked up under the empty cluster ID. When no non-nil
// estimate is stored, an empty ResourceEstimate is returned instead of an error.
func (c *HelmClientMock) EstimateInstanceResources(ctx context.Context, cluster *entity.Cluster, instance *entity.Instance) (*repository.ResourceEstimate, error) {
	var clusterID string
	if cluster != nil {
		clusterID = cluster.ID
	}

	key := fmt.Sprintf("%s/%s", instance.Namespace, instance.Name)
	// Indexing a missing (nil) per-cluster map yields nil, so one lookup
	// covers both "unknown cluster" and "unknown release".
	if stored := c.estimates[clusterID][key]; stored != nil {
		return stored, nil
	}
	return &repository.ResourceEstimate{}, nil
}

View File

@ -10,6 +10,7 @@ import (
"github.com/ocdp/cluster-service/internal/domain/entity"
"github.com/ocdp/cluster-service/internal/domain/repository"
domainservice "github.com/ocdp/cluster-service/internal/domain/service"
"helm.sh/helm/v3/pkg/action"
"helm.sh/helm/v3/pkg/chart/loader"
"helm.sh/helm/v3/pkg/cli"
@ -346,6 +347,41 @@ func (h *HelmClient) GetChartDefaultValues(chartPath string) (map[string]interfa
return vals, nil
}
// EstimateInstanceResources renders the instance's chart with its values and
// derives a resource estimate from the rendered manifest.
//
// The chart archive is loaded from /tmp/charts/<chart>-<version>.tgz. The Helm
// install action is configured as a client-only dry run, so only the rendered
// manifest is used. An empty release name falls back to "quota-precheck" and
// an empty namespace to "default"; nil values are rendered as an empty map.
//
// Errors from loading or rendering the chart are wrapped and returned.
func (h *HelmClient) EstimateInstanceResources(ctx context.Context, cluster *entity.Cluster, instance *entity.Instance) (*repository.ResourceEstimate, error) {
	archive := fmt.Sprintf("/tmp/charts/%s-%s.tgz", instance.Chart, instance.Version)
	loaded, err := loader.Load(archive)
	if err != nil {
		return nil, fmt.Errorf("failed to load chart: %w", err)
	}

	// Resolve fallbacks up front so the action setup below is a plain list.
	releaseName := instance.Name
	if releaseName == "" {
		releaseName = "quota-precheck"
	}
	namespace := instance.Namespace
	if namespace == "" {
		namespace = "default"
	}
	overrides := instance.Values
	if overrides == nil {
		overrides = map[string]interface{}{}
	}

	cfg := new(action.Configuration)
	cfg.Log = func(format string, v ...interface{}) {} // discard Helm's log output

	renderer := action.NewInstall(cfg)
	renderer.ReleaseName = releaseName
	renderer.Namespace = namespace
	renderer.DryRun = true
	renderer.DryRunOption = "client"
	renderer.ClientOnly = true
	renderer.Replace = true
	renderer.SkipSchemaValidation = true

	rendered, err := renderer.RunWithContext(ctx, loaded, overrides)
	if err != nil {
		return nil, fmt.Errorf("failed to render chart for quota estimate: %w", err)
	}
	return domainservice.EstimateRenderedManifestResources(rendered.Manifest)
}
// convertReleaseToInstance 转换 Helm Release 为 Instance
func (h *HelmClient) convertReleaseToInstance(rel *release.Release) *entity.Instance {
return &entity.Instance{

View File

@ -63,7 +63,7 @@ func (c *MetricsClient) GetClusterMetrics(ctx context.Context, clusterID string)
// 计算集群级别汇总
metrics := c.aggregateClusterMetrics(cluster, nodes.Items, pods.Items, nodeMetrics)
return metrics, nil
}
@ -87,6 +87,37 @@ func (c *MetricsClient) GetNodeMetrics(ctx context.Context, clusterID string) ([
return c.getNodeMetricsData(ctx, clientset, metricsClient, nodes.Items)
}
// GetPodResourceAllocations lists every pod in the cluster (all namespaces)
// and reports the requests/limits declared in each pod spec. No utilization
// values are produced; GPU memory is treated as a vendor integer in MB.
//
// Each entry also carries the release name inferred from the pod's labels,
// which is empty when none of the known labels is set.
//
// Errors from resolving the cluster, building the client, or listing pods are
// wrapped and returned.
func (c *MetricsClient) GetPodResourceAllocations(ctx context.Context, clusterID string) ([]*entity.PodResourceAllocation, error) {
	cluster, err := c.clusterRepo.GetByID(ctx, clusterID)
	if err != nil {
		return nil, fmt.Errorf("failed to get cluster: %w", err)
	}

	clientset, _, err := c.createK8sClients(cluster)
	if err != nil {
		return nil, fmt.Errorf("failed to create k8s client: %w", err)
	}

	// An empty namespace selects pods across all namespaces.
	pods, err := clientset.CoreV1().Pods("").List(ctx, metav1.ListOptions{})
	if err != nil {
		return nil, fmt.Errorf("failed to list pods: %w", err)
	}

	allocations := make([]*entity.PodResourceAllocation, 0, len(pods.Items))
	for i := range pods.Items {
		// Work through a pointer into the slice instead of copying each Pod.
		pod := &pods.Items[i]
		entry := &entity.PodResourceAllocation{
			ClusterID:    clusterID,
			Namespace:    pod.Namespace,
			PodName:      pod.Name,
			InstanceName: inferHelmReleaseName(pod.Labels),
			Allocation:   podResourceAllocation(pod),
		}
		allocations = append(allocations, entry)
	}
	return allocations, nil
}
// createK8sClients 创建 Kubernetes 客户端
func (c *MetricsClient) createK8sClients(cluster *entity.Cluster) (*kubernetes.Clientset, *metricsv.Clientset, error) {
config, err := clientcmd.RESTConfigFromKubeConfig([]byte(cluster.GetKubeConfig()))
@ -127,14 +158,14 @@ func (c *MetricsClient) getNodeMetricsData(
for _, node := range nodes {
nodeMetric := &entity.NodeMetrics{
NodeName: node.Name,
Status: getNodeStatus(&node),
Role: getNodeRole(&node),
Age: getNodeAge(&node),
OSImage: node.Status.NodeInfo.OSImage,
KernelVersion: node.Status.NodeInfo.KernelVersion,
ContainerRuntime: node.Status.NodeInfo.ContainerRuntimeVersion,
KubeletVersion: node.Status.NodeInfo.KubeletVersion,
NodeName: node.Name,
Status: getNodeStatus(&node),
Role: getNodeRole(&node),
Age: getNodeAge(&node),
OSImage: node.Status.NodeInfo.OSImage,
KernelVersion: node.Status.NodeInfo.KernelVersion,
ContainerRuntime: node.Status.NodeInfo.ContainerRuntimeVersion,
KubeletVersion: node.Status.NodeInfo.KubeletVersion,
}
// CPU
@ -213,7 +244,7 @@ func (c *MetricsClient) aggregateClusterMetrics(
var totalCPU, totalMem, usedCPU, usedMem int64
var totalGPU, usedGPU int
healthyNodes := 0
// 单机最大值
var maxNodeCPU, maxNodeMem int64
var maxNodeGPU int
@ -251,7 +282,7 @@ func (c *MetricsClient) aggregateClusterMetrics(
// 从 nodeMetrics 获取使用情况
if i < len(nodeMetrics) && nodeMetrics[i] != nil {
metrics.Nodes = append(metrics.Nodes, *nodeMetrics[i])
// 更新单机最大使用率
if nodeMetrics[i].CPUPercent > maxNodeCPUUsage {
maxNodeCPUUsage = nodeMetrics[i].CPUPercent
@ -274,7 +305,7 @@ func (c *MetricsClient) aggregateClusterMetrics(
metrics.TotalCPU = fmt.Sprintf("%.2f cores", float64(totalCPU)/1000.0)
metrics.TotalMemory = formatBytes(totalMem)
metrics.TotalGPU = totalGPU
// 格式化单机最大值
metrics.MaxNodeCPU = fmt.Sprintf("%.2f cores", float64(maxNodeCPU)/1000.0)
metrics.MaxNodeMemory = formatBytes(maxNodeMem)
@ -292,7 +323,7 @@ func (c *MetricsClient) aggregateClusterMetrics(
usedMem += int64(nm.MemoryPercent * float64(totalMem) / 100.0)
usedGPU += nm.GPUUsage
}
if totalCPU > 0 {
metrics.CPUUsage = float64(usedCPU) / float64(totalCPU) * 100
}
@ -302,7 +333,7 @@ func (c *MetricsClient) aggregateClusterMetrics(
if totalGPU > 0 {
metrics.GPUUsage = float64(usedGPU) / float64(totalGPU) * 100
}
metrics.UsedCPU = fmt.Sprintf("%.2f cores", float64(usedCPU)/1000.0)
metrics.UsedMemory = formatBytes(usedMem)
metrics.UsedGPU = usedGPU
@ -348,7 +379,7 @@ func getNodeAge(node *corev1.Node) string {
age := time.Since(node.CreationTimestamp.Time)
days := int(age.Hours() / 24)
hours := int(age.Hours()) % 24
if days > 0 {
return fmt.Sprintf("%dd %dh", days, hours)
}
@ -368,3 +399,110 @@ func formatBytes(bytes int64) string {
return fmt.Sprintf("%.1f %ciB", float64(bytes)/float64(div), "KMGTPE"[exp])
}
// inferHelmReleaseName guesses the release a pod belongs to from its labels.
//
// The label keys are tried in a fixed priority order and the first non-empty
// value wins. An empty string is returned when labels is nil or none of the
// keys carries a value.
func inferHelmReleaseName(labels map[string]string) string {
	candidates := [...]string{
		"app.kubernetes.io/instance",
		"release",
		"helm.sh/release",
		"meta.helm.sh/release-name",
		"app",
	}
	// Reading from a nil map yields "", so no separate nil guard is needed.
	for i := range candidates {
		if name := labels[candidates[i]]; name != "" {
			return name
		}
	}
	return ""
}
// podResourceAllocation computes the allocation declared by a pod spec.
//
// Regular containers are summed, init containers contribute their per-field
// maximum, and the result is the per-field maximum of those two figures.
// A nil pod yields a zero allocation.
func podResourceAllocation(pod *corev1.Pod) entity.ResourceAllocation {
	var regular, largestInit entity.ResourceAllocation
	if pod == nil {
		return regular
	}

	for i := range pod.Spec.Containers {
		regular = addContainerAllocation(regular, pod.Spec.Containers[i])
	}
	for i := range pod.Spec.InitContainers {
		largestInit = maxAllocation(largestInit, containerAllocation(pod.Spec.InitContainers[i]))
	}
	return maxAllocation(regular, largestInit)
}
// addContainerAllocation returns base plus the requests/limits declared by
// container.
func addContainerAllocation(base entity.ResourceAllocation, container corev1.Container) entity.ResourceAllocation {
	declared := containerAllocation(container)
	return addAllocation(base, declared)
}
// containerAllocation extracts the CPU, memory, GPU, and GPU-memory
// requests/limits declared by a single container.
//
// CPU is reported in milli-units; every other resource uses the quantity's
// integer value. GPU memory is read from "nvidia.com/gpumem" first and from
// "requests.nvidia.com/gpumem" as a fallback. Resources that are not declared
// are reported as 0.
func containerAllocation(container corev1.Container) entity.ResourceAllocation {
	gpuKey := corev1.ResourceName("nvidia.com/gpu")
	gpuMemKey := corev1.ResourceName("nvidia.com/gpumem")
	gpuMemQuotaKey := corev1.ResourceName("requests.nvidia.com/gpumem")

	requested := container.Resources.Requests
	capped := container.Resources.Limits

	var out entity.ResourceAllocation
	out.CPURequestsMilli = quantityMilliValue(requested, corev1.ResourceCPU)
	out.CPULimitsMilli = quantityMilliValue(capped, corev1.ResourceCPU)
	out.MemoryRequestsBytes = quantityValue(requested, corev1.ResourceMemory)
	out.MemoryLimitsBytes = quantityValue(capped, corev1.ResourceMemory)
	out.GPURequests = quantityValue(requested, gpuKey)
	out.GPULimits = quantityValue(capped, gpuKey)
	out.GPUMemoryRequestsMB = quantityValueAny(requested, gpuMemKey, gpuMemQuotaKey)
	out.GPUMemoryLimitsMB = quantityValueAny(capped, gpuMemKey, gpuMemQuotaKey)
	return out
}
// addAllocation returns the field-by-field sum of the CPU, memory, GPU, and
// GPU-memory requests/limits of left and right.
func addAllocation(left, right entity.ResourceAllocation) entity.ResourceAllocation {
	var total entity.ResourceAllocation
	total.CPURequestsMilli = left.CPURequestsMilli + right.CPURequestsMilli
	total.CPULimitsMilli = left.CPULimitsMilli + right.CPULimitsMilli
	total.MemoryRequestsBytes = left.MemoryRequestsBytes + right.MemoryRequestsBytes
	total.MemoryLimitsBytes = left.MemoryLimitsBytes + right.MemoryLimitsBytes
	total.GPURequests = left.GPURequests + right.GPURequests
	total.GPULimits = left.GPULimits + right.GPULimits
	total.GPUMemoryRequestsMB = left.GPUMemoryRequestsMB + right.GPUMemoryRequestsMB
	total.GPUMemoryLimitsMB = left.GPUMemoryLimitsMB + right.GPUMemoryLimitsMB
	return total
}
// maxAllocation returns the field-by-field maximum of the CPU, memory, GPU,
// and GPU-memory requests/limits of left and right.
func maxAllocation(left, right entity.ResourceAllocation) entity.ResourceAllocation {
	var peak entity.ResourceAllocation
	peak.CPURequestsMilli = maxInt64(left.CPURequestsMilli, right.CPURequestsMilli)
	peak.CPULimitsMilli = maxInt64(left.CPULimitsMilli, right.CPULimitsMilli)
	peak.MemoryRequestsBytes = maxInt64(left.MemoryRequestsBytes, right.MemoryRequestsBytes)
	peak.MemoryLimitsBytes = maxInt64(left.MemoryLimitsBytes, right.MemoryLimitsBytes)
	peak.GPURequests = maxInt64(left.GPURequests, right.GPURequests)
	peak.GPULimits = maxInt64(left.GPULimits, right.GPULimits)
	peak.GPUMemoryRequestsMB = maxInt64(left.GPUMemoryRequestsMB, right.GPUMemoryRequestsMB)
	peak.GPUMemoryLimitsMB = maxInt64(left.GPUMemoryLimitsMB, right.GPUMemoryLimitsMB)
	return peak
}
// quantityMilliValue returns the named resource in milli-units, or 0 when the
// resource is not present in the list.
func quantityMilliValue(resources corev1.ResourceList, name corev1.ResourceName) int64 {
	declared, present := resources[name]
	if !present {
		return 0
	}
	return declared.MilliValue()
}
// quantityValue returns the integer value of the named resource, or 0 when the
// resource is not present in the list.
func quantityValue(resources corev1.ResourceList, name corev1.ResourceName) int64 {
	declared, present := resources[name]
	if !present {
		return 0
	}
	return declared.Value()
}
// quantityValueAny returns the integer value of the first of names that is
// present in the list, trying them in the order given. It returns 0 when none
// of the names is present.
func quantityValueAny(resources corev1.ResourceList, names ...corev1.ResourceName) int64 {
	for i := range names {
		declared, present := resources[names[i]]
		if !present {
			continue
		}
		return declared.Value()
	}
	return 0
}
// maxInt64 returns the larger of left and right.
func maxInt64(left, right int64) int64 {
	if right >= left {
		return right
	}
	return left
}

View File

@ -0,0 +1,29 @@
package k8s
import (
"testing"
corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/resource"
)
// TestContainerAllocationCountsVendorGPUMemoryKey checks that the vendor
// "nvidia.com/gpumem" resource is reported as GPU memory for both requests and
// limits.
func TestContainerAllocationCountsVendorGPUMemoryKey(t *testing.T) {
	gpuMem := corev1.ResourceName("nvidia.com/gpumem")

	var container corev1.Container
	container.Resources.Requests = corev1.ResourceList{gpuMem: resource.MustParse("10000")}
	container.Resources.Limits = corev1.ResourceList{gpuMem: resource.MustParse("12000")}

	got := containerAllocation(container)

	if got.GPUMemoryRequestsMB != 10000 {
		t.Fatalf("expected GPU memory requests 10000 MB, got %d", got.GPUMemoryRequestsMB)
	}
	if got.GPUMemoryLimitsMB != 12000 {
		t.Fatalf("expected GPU memory limits 12000 MB, got %d", got.GPUMemoryLimitsMB)
	}
}

View File

@ -106,6 +106,25 @@ func (c *TenantClient) IssueKubeconfig(ctx context.Context, cluster *entity.Clus
}, nil
}
// GetResourceQuotaUsage reads the tenant's ResourceQuota object and returns
// its hard limits and current usage as resource vectors.
//
// The binding is completed with defaults and validated first. Errors from
// validation and client creation are returned as-is; a failure to fetch the
// quota object is wrapped.
func (c *TenantClient) GetResourceQuotaUsage(ctx context.Context, cluster *entity.Cluster, binding entity.TenantBinding) (*repository.ResourceQuotaUsage, error) {
	binding = binding.WithDefaults()
	if invalid := binding.Validate(); invalid != nil {
		return nil, invalid
	}

	clientset, _, err := c.clientsetForCluster(cluster)
	if err != nil {
		return nil, err
	}

	quotas := clientset.CoreV1().ResourceQuotas(binding.Namespace)
	quota, err := quotas.Get(ctx, binding.ResourceQuotaName, metav1.GetOptions{})
	if err != nil {
		return nil, fmt.Errorf("failed to get tenant resource quota usage: %w", err)
	}

	usage := &repository.ResourceQuotaUsage{}
	usage.Hard = resourceVectorFromList(quota.Status.Hard)
	usage.Used = resourceVectorFromList(quota.Status.Used)
	return usage, nil
}
// SuspendTenant revokes tenant API access by deleting only the RoleBinding.
func (c *TenantClient) SuspendTenant(ctx context.Context, cluster *entity.Cluster, binding entity.TenantBinding) error {
binding = binding.WithDefaults()
@ -128,6 +147,82 @@ func (c *TenantClient) SuspendTenant(ctx context.Context, cluster *entity.Cluste
return nil
}
// DeleteTenant removes the tenant's role binding, resource quota, service
// account, and finally its namespace.
//
// The binding is completed with defaults and validated first, and protected
// namespaces are rejected with entity.ErrProtectedNamespace before the cluster
// is contacted. Resources that are already gone are treated as deleted. The
// namespace itself is only deleted when it carries the labels
// "ocdp.io/managed-by"="ocdp" and "ocdp.io/tenant"=<namespace>; a missing
// namespace is a success.
//
// NOTE(review): the managed-namespace label check runs after the namespaced
// resources have been deleted, so an unmanaged namespace loses its role
// binding, quota, and service account before the call fails. Confirm this
// ordering is intended.
func (c *TenantClient) DeleteTenant(ctx context.Context, cluster *entity.Cluster, binding entity.TenantBinding) error {
	binding = binding.WithDefaults()
	if err := binding.Validate(); err != nil {
		return err
	}
	if isProtectedTenantNamespace(binding.Namespace) {
		return entity.ErrProtectedNamespace
	}

	clientset, _, err := c.clientsetForCluster(cluster)
	if err != nil {
		return err
	}

	ns := binding.Namespace

	// Namespaced resources, deleted in this order; each step stops on the
	// first failure that is not "not found".
	namespacedDeletes := []struct {
		failure string
		run     func() error
	}{
		{
			failure: "failed to delete tenant role binding: %w",
			run: func() error {
				return clientset.RbacV1().RoleBindings(ns).Delete(ctx, binding.RoleBindingName, metav1.DeleteOptions{})
			},
		},
		{
			failure: "failed to delete tenant resource quota: %w",
			run: func() error {
				return clientset.CoreV1().ResourceQuotas(ns).Delete(ctx, binding.ResourceQuotaName, metav1.DeleteOptions{})
			},
		},
		{
			failure: "failed to delete tenant service account: %w",
			run: func() error {
				return clientset.CoreV1().ServiceAccounts(ns).Delete(ctx, binding.ServiceAccountName, metav1.DeleteOptions{})
			},
		},
	}
	for _, step := range namespacedDeletes {
		if err := deleteIgnoringNotFound(ctx, step.run); err != nil {
			return fmt.Errorf(step.failure, err)
		}
	}

	namespace, err := clientset.CoreV1().Namespaces().Get(ctx, ns, metav1.GetOptions{})
	switch {
	case apierrors.IsNotFound(err):
		// Nothing left to remove.
		return nil
	case err != nil:
		return fmt.Errorf("failed to get tenant namespace before deletion: %w", err)
	}

	// Only namespaces labelled as created for this tenant may be deleted.
	managed := namespace.Labels["ocdp.io/managed-by"] == "ocdp" && namespace.Labels["ocdp.io/tenant"] == ns
	if !managed {
		return fmt.Errorf("refusing to delete unmanaged namespace %q", binding.Namespace)
	}

	deleteNamespace := func() error {
		return clientset.CoreV1().Namespaces().Delete(ctx, ns, metav1.DeleteOptions{})
	}
	if err := deleteIgnoringNotFound(ctx, deleteNamespace); err != nil {
		return fmt.Errorf("failed to delete tenant namespace: %w", err)
	}
	return nil
}
// deleteIgnoringNotFound runs deleteFn unless ctx is already done, and treats
// a Kubernetes "not found" error from it as success.
//
// It returns ctx.Err() without calling deleteFn when the context is cancelled
// or expired; any other error from deleteFn is returned unchanged.
func deleteIgnoringNotFound(ctx context.Context, deleteFn func() error) error {
	if ctxErr := ctx.Err(); ctxErr != nil {
		return ctxErr
	}
	if err := deleteFn(); !apierrors.IsNotFound(err) {
		// Covers both success (nil) and every failure other than "not found".
		return err
	}
	return nil
}
// isProtectedTenantNamespace reports whether namespace, after trimming
// surrounding whitespace, is empty or one of the reserved names "default",
// "kube-system", "kube-public", or "kube-node-lease". The comparison is
// case-sensitive.
func isProtectedTenantNamespace(namespace string) bool {
	name := strings.TrimSpace(namespace)
	for _, reserved := range [...]string{"", "default", "kube-system", "kube-public", "kube-node-lease"} {
		if name == reserved {
			return true
		}
	}
	return false
}
// resourceVectorFromList converts a resource list keyed with "requests.*"
// names into a ResourceVector.
//
// CPU and memory are copied as quantities from "requests.cpu" and
// "requests.memory"; GPU count and GPU memory are taken as integer values from
// "requests.nvidia.com/gpu" and "requests.nvidia.com/gpumem". Keys that are
// absent produce zero values.
func resourceVectorFromList(values corev1.ResourceList) repository.ResourceVector {
	gpuCount := values[corev1.ResourceName("requests.nvidia.com/gpu")]
	gpuMemory := values[corev1.ResourceName("requests.nvidia.com/gpumem")]

	var vector repository.ResourceVector
	vector.CPU = values[corev1.ResourceName("requests.cpu")]
	vector.Memory = values[corev1.ResourceName("requests.memory")]
	vector.GPU = gpuCount.Value()
	vector.GPUMemoryMB = gpuMemory.Value()
	return vector
}
func (c *TenantClient) clientsetForCluster(cluster *entity.Cluster) (kubernetes.Interface, *rest.Config, error) {
if c.clientset != nil {
config := &rest.Config{Host: "https://kubernetes.default.svc"}

View File

@ -2,6 +2,7 @@ package k8s
import (
"context"
"errors"
"strings"
"testing"
"time"
@ -58,7 +59,7 @@ func TestTenantClientEnsureTenantUpdatesExistingResources(t *testing.T) {
ctx := context.Background()
binding := tenantBinding()
clientset := fake.NewSimpleClientset(
&corev1.Namespace{ObjectMeta: metav1.ObjectMeta{Name: binding.Namespace}},
&corev1.Namespace{ObjectMeta: metav1.ObjectMeta{Name: binding.Namespace, Labels: binding.Labels}},
&corev1.ServiceAccount{ObjectMeta: metav1.ObjectMeta{Name: binding.ServiceAccountName, Namespace: binding.Namespace}},
&rbacv1.RoleBinding{
ObjectMeta: metav1.ObjectMeta{Name: binding.RoleBindingName, Namespace: binding.Namespace},
@ -100,7 +101,7 @@ func TestTenantClientSuspendTenantDeletesOnlyRoleBinding(t *testing.T) {
ctx := context.Background()
binding := tenantBinding()
clientset := fake.NewSimpleClientset(
&corev1.Namespace{ObjectMeta: metav1.ObjectMeta{Name: binding.Namespace}},
&corev1.Namespace{ObjectMeta: metav1.ObjectMeta{Name: binding.Namespace, Labels: binding.Labels}},
&corev1.ServiceAccount{ObjectMeta: metav1.ObjectMeta{Name: binding.ServiceAccountName, Namespace: binding.Namespace}},
desiredRoleBinding(binding),
)
@ -117,6 +118,47 @@ func TestTenantClientSuspendTenantDeletesOnlyRoleBinding(t *testing.T) {
}
}
// TestTenantClientDeleteTenantDeletesTenantResources seeds a labelled tenant
// namespace with its service account, role binding, and resource quota, and
// checks that DeleteTenant removes all four objects.
func TestTenantClientDeleteTenantDeletesTenantResources(t *testing.T) {
	ctx := context.Background()
	binding := tenantBinding()
	ns := binding.Namespace

	clientset := fake.NewSimpleClientset(
		&corev1.Namespace{ObjectMeta: metav1.ObjectMeta{Name: ns, Labels: binding.Labels}},
		&corev1.ServiceAccount{ObjectMeta: metav1.ObjectMeta{Name: binding.ServiceAccountName, Namespace: ns}},
		desiredRoleBinding(binding),
		&corev1.ResourceQuota{ObjectMeta: metav1.ObjectMeta{Name: binding.ResourceQuotaName, Namespace: ns}},
	)
	client := NewTenantClientForClientset(clientset)

	if err := client.DeleteTenant(ctx, nil, binding); err != nil {
		t.Fatalf("DeleteTenant returned error: %v", err)
	}

	// Every object must now be reported as "not found".
	_, err := clientset.RbacV1().RoleBindings(ns).Get(ctx, binding.RoleBindingName, metav1.GetOptions{})
	if !apierrors.IsNotFound(err) {
		t.Fatalf("expected role binding deleted, got %v", err)
	}
	_, err = clientset.CoreV1().ResourceQuotas(ns).Get(ctx, binding.ResourceQuotaName, metav1.GetOptions{})
	if !apierrors.IsNotFound(err) {
		t.Fatalf("expected resource quota deleted, got %v", err)
	}
	_, err = clientset.CoreV1().ServiceAccounts(ns).Get(ctx, binding.ServiceAccountName, metav1.GetOptions{})
	if !apierrors.IsNotFound(err) {
		t.Fatalf("expected service account deleted, got %v", err)
	}
	_, err = clientset.CoreV1().Namespaces().Get(ctx, ns, metav1.GetOptions{})
	if !apierrors.IsNotFound(err) {
		t.Fatalf("expected namespace deleted, got %v", err)
	}
}
func TestTenantClientDeleteTenantRejectsProtectedNamespace(t *testing.T) {
ctx := context.Background()
client := NewTenantClientForClientset(fake.NewSimpleClientset(
&corev1.Namespace{ObjectMeta: metav1.ObjectMeta{Name: "default"}},
))
binding := entity.NewTenantBinding("default")
err := client.DeleteTenant(ctx, nil, binding)
if !errors.Is(err, entity.ErrProtectedNamespace) {
t.Fatalf("expected protected namespace error, got %v", err)
}
}
func TestTenantClientIssueKubeconfigCapsTokenTTL(t *testing.T) {
ctx := context.Background()
binding := tenantBinding()

View File

@ -31,6 +31,28 @@ func (c *MockTenantClient) IssueKubeconfig(ctx context.Context, cluster *entity.
}, nil
}
// GetResourceQuotaUsage validates the binding and reports its configured hard
// quota with zero usage. No cluster is contacted.
func (c *MockTenantClient) GetResourceQuotaUsage(ctx context.Context, cluster *entity.Cluster, binding entity.TenantBinding) (*repository.ResourceQuotaUsage, error) {
	err := binding.Validate()
	if err != nil {
		return nil, err
	}

	// Used is left at its zero value: the mock never reports consumption.
	usage := &repository.ResourceQuotaUsage{
		Hard: resourceVectorFromList(binding.ResourceQuotaHard),
	}
	return usage, nil
}
// SuspendTenant only validates the binding; the mock performs no other work.
func (c *MockTenantClient) SuspendTenant(ctx context.Context, cluster *entity.Cluster, binding entity.TenantBinding) error {
	err := binding.Validate()
	return err
}
// DeleteTenant validates the binding and rejects protected namespaces with
// entity.ErrProtectedNamespace; otherwise it succeeds without doing anything.
//
// The namespace is compared as-is (no whitespace trimming) against "",
// "default", "kube-system", "kube-public", and "kube-node-lease".
func (c *MockTenantClient) DeleteTenant(ctx context.Context, cluster *entity.Cluster, binding entity.TenantBinding) error {
	err := binding.Validate()
	if err != nil {
		return err
	}

	switch binding.Namespace {
	case "", "default", "kube-system", "kube-public", "kube-node-lease":
		return entity.ErrProtectedNamespace
	}
	return nil
}

View File

@ -72,6 +72,16 @@ func (r *WorkspaceRepositoryMock) Update(ctx context.Context, workspace *entity.
return nil
}
// Delete removes the workspace with the given id from the in-memory store
// while holding the write lock. It returns entity.ErrWorkspaceNotFound when no
// such workspace is stored.
func (r *WorkspaceRepositoryMock) Delete(ctx context.Context, id string) error {
	r.mu.Lock()
	defer r.mu.Unlock()

	if _, found := r.workspaces[id]; found {
		delete(r.workspaces, id)
		return nil
	}
	return entity.ErrWorkspaceNotFound
}
func (r *WorkspaceRepositoryMock) List(ctx context.Context) ([]*entity.Workspace, error) {
r.mu.RLock()
defer r.mu.RUnlock()
@ -118,6 +128,20 @@ func (r *WorkspaceClusterBindingRepositoryMock) Get(ctx context.Context, workspa
return &copy, nil
}
// ListByWorkspace returns copies of all stored bindings that belong to
// workspaceID, read under the read lock. The result is an empty (non-nil)
// slice when nothing matches.
func (r *WorkspaceClusterBindingRepositoryMock) ListByWorkspace(ctx context.Context, workspaceID string) ([]*entity.WorkspaceClusterBinding, error) {
	r.mu.RLock()
	defer r.mu.RUnlock()

	matches := make([]*entity.WorkspaceClusterBinding, 0)
	for _, binding := range r.bindings {
		if binding.WorkspaceID == workspaceID {
			// Hand out a copy so callers cannot mutate the stored binding.
			clone := *binding
			matches = append(matches, &clone)
		}
	}
	return matches, nil
}
func (r *WorkspaceClusterBindingRepositoryMock) Delete(ctx context.Context, workspaceID, clusterID string) error {
r.mu.Lock()
defer r.mu.Unlock()

View File

@ -27,8 +27,9 @@ func (r *WorkspaceRepository) Create(ctx context.Context, workspace *entity.Work
query := `
INSERT INTO workspaces (id, name, status, k8s_namespace, k8s_sa_name, default_cluster_id, quota_cpu, quota_memory, quota_gpu, quota_gpu_memory, created_by, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
ON CONFLICT (name) DO NOTHING
`
_, err := r.db.conn.ExecContext(ctx, query,
result, err := r.db.conn.ExecContext(ctx, query,
workspace.ID,
workspace.Name,
workspace.Status,
@ -46,6 +47,13 @@ func (r *WorkspaceRepository) Create(ctx context.Context, workspace *entity.Work
if err != nil {
return fmt.Errorf("failed to create workspace: %w", err)
}
rows, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get affected rows: %w", err)
}
if rows == 0 {
return entity.ErrWorkspaceExists
}
return nil
}
@ -132,6 +140,21 @@ func (r *WorkspaceRepository) Update(ctx context.Context, workspace *entity.Work
return nil
}
// Delete removes the workspace row with the given id.
//
// It returns entity.ErrWorkspaceNotFound when no row was deleted; failures of
// the statement or of reading the affected-row count are wrapped.
func (r *WorkspaceRepository) Delete(ctx context.Context, id string) error {
	outcome, err := r.db.conn.ExecContext(ctx, `DELETE FROM workspaces WHERE id = $1`, id)
	if err != nil {
		return fmt.Errorf("failed to delete workspace: %w", err)
	}

	affected, err := outcome.RowsAffected()
	switch {
	case err != nil:
		return fmt.Errorf("failed to get affected rows: %w", err)
	case affected == 0:
		return entity.ErrWorkspaceNotFound
	}
	return nil
}
func (r *WorkspaceRepository) List(ctx context.Context) ([]*entity.Workspace, error) {
query := `
SELECT id, name, status, k8s_namespace, k8s_sa_name, default_cluster_id, quota_cpu, quota_memory, quota_gpu, quota_gpu_memory, created_by, created_at, updated_at
@ -256,6 +279,42 @@ func (r *WorkspaceClusterBindingRepository) Get(ctx context.Context, workspaceID
return binding, nil
}
// ListByWorkspace returns all cluster bindings stored for workspaceID, ordered
// by created_at ascending. The returned slice is empty (not nil) when the
// workspace has no bindings. On any query, scan, or iteration failure it
// returns a nil slice and a wrapped error.
func (r *WorkspaceClusterBindingRepository) ListByWorkspace(ctx context.Context, workspaceID string) ([]*entity.WorkspaceClusterBinding, error) {
	query := `
		SELECT id, workspace_id, cluster_id, namespace, service_account, quota_cpu, quota_memory, quota_gpu, quota_gpu_memory, status, created_at, updated_at
		FROM workspace_cluster_bindings
		WHERE workspace_id = $1
		ORDER BY created_at ASC
	`
	rows, err := r.db.conn.QueryContext(ctx, query, workspaceID)
	if err != nil {
		return nil, fmt.Errorf("failed to list workspace cluster bindings: %w", err)
	}
	defer rows.Close()

	bindings := make([]*entity.WorkspaceClusterBinding, 0)
	for rows.Next() {
		binding := &entity.WorkspaceClusterBinding{}
		if err := rows.Scan(
			&binding.ID,
			&binding.WorkspaceID,
			&binding.ClusterID,
			&binding.Namespace,
			&binding.ServiceAccount,
			&binding.QuotaCPU,
			&binding.QuotaMemory,
			&binding.QuotaGPU,
			&binding.QuotaGPUMem,
			&binding.Status,
			&binding.CreatedAt,
			&binding.UpdatedAt,
		); err != nil {
			return nil, fmt.Errorf("failed to scan workspace cluster binding: %w", err)
		}
		bindings = append(bindings, binding)
	}
	// rows.Next returning false can mean "done" or "failed"; a failure must not
	// hand the caller a truncated list that looks complete.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("failed to iterate workspace cluster bindings: %w", err)
	}
	return bindings, nil
}
func (r *WorkspaceClusterBindingRepository) Delete(ctx context.Context, workspaceID, clusterID string) error {
_, err := r.db.conn.ExecContext(ctx, `DELETE FROM workspace_cluster_bindings WHERE workspace_id = $1 AND cluster_id = $2`, workspaceID, clusterID)
return err

View File

@ -43,6 +43,9 @@ var (
ErrValuesSchemaNotFound = errors.New("values schema not found")
// Workspace errors
ErrWorkspaceNotFound = errors.New("workspace not found")
ErrWorkspaceExists = errors.New("workspace already exists")
ErrWorkspaceNotFound = errors.New("workspace not found")
ErrWorkspaceExists = errors.New("workspace already exists")
ErrWorkspaceNamespaceConflict = errors.New("workspace namespace conflict")
ErrUserHasInstances = errors.New("user has active instances")
ErrProtectedNamespace = errors.New("protected namespace")
)

View File

@ -54,6 +54,7 @@ type Instance struct {
CreatedAt time.Time
UpdatedAt time.Time
Replicas int // Running K8s replicas (enriched, not persisted)
OwnerUsername string
}
// NewInstance creates a new instance.

View File

@ -25,6 +25,18 @@ type ClusterMetrics struct {
MemoryUsage float64 `json:"memory_usage"` // 百分比
GPUUsage float64 `json:"gpu_usage"` // 百分比
CPURequests string `json:"cpu_requests,omitempty"`
CPULimits string `json:"cpu_limits,omitempty"`
MemoryRequests string `json:"memory_requests,omitempty"`
MemoryLimits string `json:"memory_limits,omitempty"`
GPURequests int64 `json:"gpu_requests,omitempty"`
GPULimits int64 `json:"gpu_limits,omitempty"`
GPUMemoryRequestsMB int64 `json:"gpu_memory_requests_mb,omitempty"`
GPUMemoryLimitsMB int64 `json:"gpu_memory_limits_mb,omitempty"`
AllocatedGPU int64 `json:"allocated_gpu,omitempty"`
AllocatedGPUMemoryMB int64 `json:"allocated_gpu_memory_mb,omitempty"`
ResourceUsageByUser []UserResourceUsage `json:"resource_usage_by_user,omitempty"`
// 单机资源最大值
MaxNodeCPU string `json:"max_node_cpu"` // 单机最大CPU容量如 "8 cores"
MaxNodeMemory string `json:"max_node_memory"` // 单机最大内存容量,如 "32 GB"
@ -37,6 +49,42 @@ type ClusterMetrics struct {
Nodes []NodeMetrics `json:"nodes,omitempty"`
}
// ResourceAllocation is derived from Kubernetes Pod resources requests/limits.
// Each resource is tracked as a requests/limits pair of integer counters.
// NOTE(review): units are implied only by the field suffixes (Milli, Bytes, MB);
// presumably CPU millicores, bytes and megabytes — confirm against the code that fills them.
type ResourceAllocation struct {
// CPU requests/limits.
CPURequestsMilli int64
CPULimitsMilli int64
// Memory requests/limits.
MemoryRequestsBytes int64
MemoryLimitsBytes int64
// GPU device counts requested/limited.
GPURequests int64
GPULimits int64
// GPU memory requests/limits.
GPUMemoryRequestsMB int64
GPUMemoryLimitsMB int64
}
// PodResourceAllocation pairs a ResourceAllocation with the identity of the
// Pod it belongs to (cluster, namespace, pod name).
type PodResourceAllocation struct {
ClusterID string
Namespace string
PodName string
// InstanceName presumably names the deployed instance that owns the Pod — TODO confirm how it is resolved.
InstanceName string
Allocation ResourceAllocation
}
// UserResourceUsage is a JSON-serializable summary of the resources attributed
// to one user. CPU and memory values are carried as strings, GPU values as
// integers.
// NOTE(review): the string format of the CPU/memory fields is not visible here — confirm with the producer.
type UserResourceUsage struct {
// Identity of the user and workspace the figures belong to.
UserID string `json:"user_id"`
Username string `json:"username"`
WorkspaceID string `json:"workspace_id"`
// Number of instances and Pods counted into this entry.
InstanceCount int `json:"instance_count"`
PodCount int `json:"pod_count"`
// Requests/limits totals per resource kind.
CPURequests string `json:"cpu_requests"`
CPULimits string `json:"cpu_limits"`
MemoryRequests string `json:"memory_requests"`
MemoryLimits string `json:"memory_limits"`
GPURequests int64 `json:"gpu_requests"`
GPULimits int64 `json:"gpu_limits"`
GPUMemoryRequestsMB int64 `json:"gpu_memory_requests_mb"`
GPUMemoryLimitsMB int64 `json:"gpu_memory_limits_mb"`
}
// NodeMetrics holds the monitoring metrics of a single node.
type NodeMetrics struct {
NodeName string `json:"node_name"`

View File

@ -3,8 +3,21 @@ package repository
import (
"context"
"github.com/ocdp/cluster-service/internal/domain/entity"
"k8s.io/apimachinery/pkg/api/resource"
)
// ResourceVector groups one amount per resource kind: CPU and memory as
// Kubernetes resource.Quantity values, GPU and GPU memory as plain integers.
type ResourceVector struct {
CPU resource.Quantity
Memory resource.Quantity
// GPU is an integer count; presumably whole GPU devices — TODO confirm.
GPU int64
// GPUMemoryMB is GPU memory, in MB per the field name.
GPUMemoryMB int64
}
// ResourceEstimate carries a requests vector and a limits vector.
// It is the result type of HelmClient.EstimateInstanceResources.
type ResourceEstimate struct {
Requests ResourceVector
Limits ResourceVector
}
// HelmClient is the Helm client interface (output port).
type HelmClient interface {
// Install 安装 Helm Chart
@ -33,4 +46,7 @@ type HelmClient interface {
// GetChartDefaultValues 从 chart 包中读取默认 values
GetChartDefaultValues(chartPath string) (map[string]interface{}, error)
// EstimateInstanceResources renders an instance chart with final values and sums Pod template resources.
EstimateInstanceResources(ctx context.Context, cluster *entity.Cluster, instance *entity.Instance) (*ResourceEstimate, error)
}

View File

@ -13,4 +13,7 @@ type MetricsClient interface {
// GetNodeMetrics 获取集群的节点指标
GetNodeMetrics(ctx context.Context, clusterID string) ([]*entity.NodeMetrics, error)
// GetPodResourceAllocations returns Pod requests/limits grouped by Pod.
GetPodResourceAllocations(ctx context.Context, clusterID string) ([]*entity.PodResourceAllocation, error)
}

View File

@ -11,5 +11,12 @@ import (
type TenantKubeClient interface {
// EnsureTenant applies the given tenant binding to the cluster.
// NOTE(review): presumably creates or updates namespace, service account and quota idempotently — confirm in the implementation.
EnsureTenant(ctx context.Context, cluster *entity.Cluster, binding entity.TenantBinding) error
// IssueKubeconfig returns a tenant kubeconfig for the binding; ttl is presumably the credential lifetime — TODO confirm.
IssueKubeconfig(ctx context.Context, cluster *entity.Cluster, binding entity.TenantBinding, ttl time.Duration) (*entity.TenantKubeconfig, error)
// GetResourceQuotaUsage reports hard and used resource vectors for the binding.
GetResourceQuotaUsage(ctx context.Context, cluster *entity.Cluster, binding entity.TenantBinding) (*ResourceQuotaUsage, error)
// SuspendTenant suspends the tenant described by the binding; exact effect is implementation-defined.
SuspendTenant(ctx context.Context, cluster *entity.Cluster, binding entity.TenantBinding) error
// DeleteTenant removes the tenant described by the binding from the cluster.
DeleteTenant(ctx context.Context, cluster *entity.Cluster, binding entity.TenantBinding) error
}
// ResourceQuotaUsage is the result of TenantKubeClient.GetResourceQuotaUsage.
type ResourceQuotaUsage struct {
// Hard presumably mirrors the ResourceQuota hard limits — TODO confirm.
Hard ResourceVector
// Used presumably mirrors the ResourceQuota current usage — TODO confirm.
Used ResourceVector
}

View File

@ -11,12 +11,14 @@ type WorkspaceRepository interface {
GetByID(ctx context.Context, id string) (*entity.Workspace, error)
GetByName(ctx context.Context, name string) (*entity.Workspace, error)
Update(ctx context.Context, workspace *entity.Workspace) error
Delete(ctx context.Context, id string) error
List(ctx context.Context) ([]*entity.Workspace, error)
}
// WorkspaceClusterBindingRepository persists workspace-to-cluster bindings,
// addressed by the (workspaceID, clusterID) pair.
type WorkspaceClusterBindingRepository interface {
// Upsert stores the binding, creating or replacing it.
Upsert(ctx context.Context, binding *entity.WorkspaceClusterBinding) error
// Get returns the binding for one workspace/cluster pair.
Get(ctx context.Context, workspaceID, clusterID string) (*entity.WorkspaceClusterBinding, error)
// ListByWorkspace returns every binding of the given workspace.
ListByWorkspace(ctx context.Context, workspaceID string) ([]*entity.WorkspaceClusterBinding, error)
// Delete removes the binding for one workspace/cluster pair.
Delete(ctx context.Context, workspaceID, clusterID string) error
}

View File

@ -2,6 +2,7 @@ package service
import (
"context"
"errors"
"strings"
"time"
@ -18,6 +19,10 @@ import (
type AuthService struct {
userRepo repository.UserRepository
workspaceRepo repository.WorkspaceRepository
// The four collaborators below are injected through SetUserLifecycleCleanup
// and may be nil; the methods that use them check for nil before calling.
instanceRepo repository.InstanceRepository
clusterRepo repository.ClusterRepository
bindingRepo repository.WorkspaceClusterBindingRepository
tenantClient repository.TenantKubeClient
passwordHasher PasswordHasher
tokenGenerator TokenGenerator
}
@ -53,6 +58,18 @@ func NewAuthService(
}
}
// SetUserLifecycleCleanup wires the optional collaborators used for quota
// synchronisation and for cleaning up a user's workspace on deletion.
// Any argument may be nil; it is stored as given.
func (s *AuthService) SetUserLifecycleCleanup(
	instanceRepo repository.InstanceRepository,
	clusterRepo repository.ClusterRepository,
	bindingRepo repository.WorkspaceClusterBindingRepository,
	tenantClient repository.TenantKubeClient,
) {
	s.instanceRepo, s.clusterRepo = instanceRepo, clusterRepo
	s.bindingRepo, s.tenantClient = bindingRepo, tenantClient
}
// Register registers a new user. The business entry point may only be called by an admin; the initial admin is created by the bootstrap seeder.
type UserWorkspaceOptions struct {
Namespace string
@ -87,6 +104,9 @@ func (s *AuthService) Register(ctx context.Context, username, password, role, wo
if err != nil {
return nil, err
}
if normalizeUserRole(role) == authz.RoleUser {
normalizedOpts = defaultUserQuotaOptions(normalizedOpts)
}
// 默认生成占位邮箱,避免数据库约束失败
email := username + "@local.ocdp"
@ -96,7 +116,7 @@ func (s *AuthService) Register(ctx context.Context, username, password, role, wo
user.ID = uuid.New().String()
user.Role = normalizeUserRole(role)
user.WorkspaceID = workspaceID
if user.Role == authz.RoleUser && (user.WorkspaceID == "" || user.WorkspaceID == entity.DefaultWorkspaceID) {
if user.Role == authz.RoleUser {
workspace, err := s.createUserWorkspace(ctx, username, principal.UserID, normalizedOpts)
if err != nil {
return nil, err
@ -131,10 +151,7 @@ func (s *AuthService) createUserWorkspace(ctx context.Context, username, created
if s.workspaceRepo == nil {
return nil, entity.ErrWorkspaceNotFound
}
name := strings.TrimPrefix(entity.NamespaceForUser(username), "ocdp-u-")
workspace := entity.NewWorkspace(name, createdBy)
workspace.ID = uuid.New().String()
workspace.DefaultClusterID = strings.TrimSpace(opts.DefaultClusterID)
name := userWorkspaceName(username)
namespace := strings.TrimSpace(opts.Namespace)
if namespace == "" {
namespace = entity.NamespaceForUser(username)
@ -143,6 +160,32 @@ func (s *AuthService) createUserWorkspace(ctx context.Context, username, created
if len(validation.IsDNS1123Label(namespace)) > 0 {
return nil, entity.ErrInvalidNamespace
}
}
if existing, err := s.workspaceRepo.GetByName(ctx, name); err == nil && existing != nil {
if namespace != "" && existing.K8sNamespace != namespace {
if err := s.ensureNamespaceAvailable(ctx, namespace, existing.ID); err != nil {
return nil, err
}
}
applyWorkspaceOptions(existing, opts)
if namespace != "" {
existing.K8sNamespace = namespace
existing.K8sSAName = entity.ServiceAccountForNamespace(namespace)
}
if err := s.workspaceRepo.Update(ctx, existing); err != nil {
return nil, err
}
return existing, nil
} else if err != nil && !errors.Is(err, entity.ErrWorkspaceNotFound) {
return nil, err
}
if err := s.ensureNamespaceAvailable(ctx, namespace, ""); err != nil {
return nil, err
}
workspace := entity.NewWorkspace(name, createdBy)
workspace.ID = uuid.New().String()
workspace.DefaultClusterID = strings.TrimSpace(opts.DefaultClusterID)
if namespace != "" {
workspace.K8sNamespace = namespace
workspace.K8sSAName = entity.ServiceAccountForNamespace(namespace)
}
@ -151,11 +194,45 @@ func (s *AuthService) createUserWorkspace(ctx context.Context, username, created
workspace.QuotaGPU = strings.TrimSpace(opts.QuotaGPU)
workspace.QuotaGPUMem = strings.TrimSpace(opts.QuotaGPUMem)
if err := s.workspaceRepo.Create(ctx, workspace); err != nil {
if errors.Is(err, entity.ErrWorkspaceExists) {
existing, getErr := s.workspaceRepo.GetByName(ctx, name)
if getErr != nil {
return nil, err
}
if existing.K8sNamespace != namespace {
return nil, entity.ErrWorkspaceNamespaceConflict
}
return existing, nil
}
return nil, err
}
return workspace, nil
}
// userWorkspaceName derives the workspace name of a user's private workspace:
// the user's namespace with a leading "ocdp-u-" removed, if present.
func userWorkspaceName(username string) string {
	const userNamespacePrefix = "ocdp-u-"
	namespace := entity.NamespaceForUser(username)
	return strings.TrimPrefix(namespace, userNamespacePrefix)
}
// ensureNamespaceAvailable reports entity.ErrWorkspaceNamespaceConflict when a
// workspace other than allowedWorkspaceID already uses namespace. A blank
// namespace or a missing workspace repository is treated as "available".
func (s *AuthService) ensureNamespaceAvailable(ctx context.Context, namespace, allowedWorkspaceID string) error {
	if strings.TrimSpace(namespace) == "" || s.workspaceRepo == nil {
		return nil
	}

	all, listErr := s.workspaceRepo.List(ctx)
	if listErr != nil {
		return listErr
	}

	for _, ws := range all {
		usesNamespace := ws != nil && ws.K8sNamespace == namespace
		if !usesNamespace {
			continue
		}
		// The workspace being updated is allowed to keep its own namespace.
		isAllowedOwner := allowedWorkspaceID != "" && ws.ID == allowedWorkspaceID
		if !isAllowedOwner {
			return entity.ErrWorkspaceNamespaceConflict
		}
	}
	return nil
}
func normalizeQuotaOptions(opts UserWorkspaceOptions) (UserWorkspaceOptions, error) {
opts.Namespace = strings.TrimSpace(opts.Namespace)
opts.DefaultClusterID = strings.TrimSpace(opts.DefaultClusterID)
@ -181,6 +258,16 @@ func normalizeQuotaOptions(opts UserWorkspaceOptions) (UserWorkspaceOptions, err
return opts, nil
}
// defaultUserQuotaOptions returns a copy of opts in which a blank GPU quota
// and a blank GPU-memory quota are each replaced by "0". Other fields are
// returned untouched.
func defaultUserQuotaOptions(opts UserWorkspaceOptions) UserWorkspaceOptions {
	// opts is a by-value copy, so writing through these pointers only
	// affects the value we return.
	for _, quota := range []*string{&opts.QuotaGPU, &opts.QuotaGPUMem} {
		if strings.TrimSpace(*quota) == "" {
			*quota = "0"
		}
	}
	return opts
}
func (s *AuthService) ListUsers(ctx context.Context) ([]*entity.User, error) {
principal, err := authz.RequirePrincipal(ctx)
if err != nil {
@ -204,25 +291,35 @@ func (s *AuthService) UpdateUser(ctx context.Context, userID, role, workspaceID
if err != nil {
return nil, entity.ErrUserNotFound
}
previousRole := user.Role
if role != "" {
user.Role = normalizeUserRole(role)
}
if workspaceID != "" {
if workspaceID != "" && user.Role != authz.RoleUser {
user.WorkspaceID = workspaceID
}
workspaceHandled := false
if user.Role == authz.RoleAdmin {
user.WorkspaceID = entity.DefaultWorkspaceID
}
if user.Role == authz.RoleUser && (user.WorkspaceID == "" || user.WorkspaceID == entity.DefaultWorkspaceID) {
if user.Role == authz.RoleUser && (role != "" || workspaceID != "" || hasWorkspaceUpdates(opts)) {
normalizedOpts, err := normalizeQuotaOptions(opts)
if err != nil {
return nil, err
}
workspace, err := s.createUserWorkspace(ctx, user.Username, principal.UserID, normalizedOpts)
normalizedOpts = defaultUserQuotaOptions(normalizedOpts)
currentWorkspace, _ := s.currentUserWorkspace(ctx, user)
if currentWorkspace != nil && shouldCreatePrivateWorkspace(user, previousRole, currentWorkspace) {
if normalizedOpts.Namespace == "" || normalizedOpts.Namespace == currentWorkspace.K8sNamespace {
normalizedOpts.Namespace = ""
}
}
workspace, err := s.ensureUserWorkspaceForUpdate(ctx, user, previousRole, currentWorkspace, opts, normalizedOpts, principal.UserID)
if err != nil {
return nil, err
}
user.WorkspaceID = workspace.ID
workspaceHandled = true
}
if isActive != nil {
if user.ID == principal.UserID && !*isActive {
@ -233,7 +330,7 @@ func (s *AuthService) UpdateUser(ctx context.Context, userID, role, workspaceID
if mustChangePassword != nil {
user.MustChangePassword = *mustChangePassword
}
if user.Role != authz.RoleAdmin && hasWorkspaceUpdates(opts) {
if user.Role != authz.RoleAdmin && !workspaceHandled && hasWorkspaceUpdates(opts) {
normalizedOpts, err := normalizeQuotaOptions(opts)
if err != nil {
return nil, err
@ -242,10 +339,13 @@ func (s *AuthService) UpdateUser(ctx context.Context, userID, role, workspaceID
if err != nil {
return nil, err
}
applyWorkspaceOptions(workspace, normalizedOpts)
applyWorkspaceOptionsForUpdate(workspace, opts, normalizedOpts)
if err := s.workspaceRepo.Update(ctx, workspace); err != nil {
return nil, err
}
if err := s.syncWorkspaceBindings(ctx, workspace); err != nil {
return nil, err
}
}
user.RevokedAfter = time.Now()
user.UpdatedAt = time.Now()
@ -289,6 +389,115 @@ func applyWorkspaceOptions(workspace *entity.Workspace, opts UserWorkspaceOption
}
}
// currentUserWorkspace loads the workspace referenced by user.WorkspaceID.
// It returns entity.ErrWorkspaceNotFound without touching the repository when
// the repository is unset, user is nil, or the user has no workspace id.
func (s *AuthService) currentUserWorkspace(ctx context.Context, user *entity.User) (*entity.Workspace, error) {
	switch {
	case s.workspaceRepo == nil, user == nil:
		return nil, entity.ErrWorkspaceNotFound
	case user.WorkspaceID == "":
		return nil, entity.ErrWorkspaceNotFound
	}
	return s.workspaceRepo.GetByID(ctx, user.WorkspaceID)
}
// shouldCreatePrivateWorkspace decides whether the user needs a private
// workspace resolved by username instead of updating the current one in place.
// It answers true when user or current is nil, when the user was previously an
// admin, when the user has no workspace or sits in the default workspace, or
// when the current workspace is not named after the user.
func shouldCreatePrivateWorkspace(user *entity.User, previousRole string, current *entity.Workspace) bool {
	switch {
	case user == nil, current == nil:
		return true
	case previousRole == authz.RoleAdmin:
		return true
	case user.WorkspaceID == "", user.WorkspaceID == entity.DefaultWorkspaceID:
		return true
	default:
		return current.Name != userWorkspaceName(user.Username)
	}
}
// ensureUserWorkspaceForUpdate resolves the workspace a user should end up in
// after an update.
//
// When shouldCreatePrivateWorkspace says so, the work is delegated to
// createUserWorkspace using normalizedOpts. Otherwise the current workspace is
// updated in place: a requested namespace change is checked for conflicts,
// the options are applied, the workspace is persisted, and its cluster
// bindings are synchronised.
//
// It returns entity.ErrWorkspaceNotFound when no workspace repository is
// configured and entity.ErrUserNotFound when user is nil.
func (s *AuthService) ensureUserWorkspaceForUpdate(ctx context.Context, user *entity.User, previousRole string, current *entity.Workspace, rawOpts, normalizedOpts UserWorkspaceOptions, createdBy string) (*entity.Workspace, error) {
	if s.workspaceRepo == nil {
		return nil, entity.ErrWorkspaceNotFound
	}
	// shouldCreatePrivateWorkspace returns true for a nil user, which would
	// then dereference user.Username below; fail cleanly instead of panicking.
	if user == nil {
		return nil, entity.ErrUserNotFound
	}
	if shouldCreatePrivateWorkspace(user, previousRole, current) {
		return s.createUserWorkspace(ctx, user.Username, createdBy, normalizedOpts)
	}
	// From here on current is non-nil: a nil current always takes the branch above.
	if rawNamespace := strings.TrimSpace(rawOpts.Namespace); rawNamespace != "" && rawNamespace != current.K8sNamespace {
		if err := s.ensureNamespaceAvailable(ctx, rawNamespace, current.ID); err != nil {
			return nil, err
		}
	}
	applyWorkspaceOptionsForUpdate(current, rawOpts, normalizedOpts)
	if err := s.workspaceRepo.Update(ctx, current); err != nil {
		return nil, err
	}
	if err := s.syncWorkspaceBindings(ctx, current); err != nil {
		return nil, err
	}
	return current, nil
}
// applyWorkspaceOptionsForUpdate copies onto workspace only the options the
// caller actually supplied. rawOpts decides whether a field was supplied
// (non-blank after trimming); the value written comes from normalizedOpts,
// except for the namespace, which is taken from the trimmed raw value and
// also determines the service account name.
func applyWorkspaceOptionsForUpdate(workspace *entity.Workspace, rawOpts, normalizedOpts UserWorkspaceOptions) {
	if ns := strings.TrimSpace(rawOpts.Namespace); ns != "" {
		workspace.K8sNamespace = ns
		workspace.K8sSAName = entity.ServiceAccountForNamespace(ns)
	}

	overrides := []struct {
		raw        string
		normalized string
		target     *string
	}{
		{rawOpts.DefaultClusterID, normalizedOpts.DefaultClusterID, &workspace.DefaultClusterID},
		{rawOpts.QuotaCPU, normalizedOpts.QuotaCPU, &workspace.QuotaCPU},
		{rawOpts.QuotaMemory, normalizedOpts.QuotaMemory, &workspace.QuotaMemory},
		{rawOpts.QuotaGPU, normalizedOpts.QuotaGPU, &workspace.QuotaGPU},
		{rawOpts.QuotaGPUMem, normalizedOpts.QuotaGPUMem, &workspace.QuotaGPUMem},
	}
	for _, o := range overrides {
		if strings.TrimSpace(o.raw) == "" {
			continue // not supplied: keep the stored value
		}
		*o.target = o.normalized
	}
}
// syncWorkspaceBindings copies the workspace quotas onto every cluster binding
// of the workspace and persists them. Blank GPU and GPU-memory quotas are
// stored as "0". When both a tenant client and a cluster repository are
// configured, the tenant is re-applied on the cluster before the binding is
// saved; a binding whose cluster no longer exists is skipped entirely.
func (s *AuthService) syncWorkspaceBindings(ctx context.Context, workspace *entity.Workspace) error {
	if s.bindingRepo == nil || workspace == nil {
		return nil
	}

	bindings, err := s.bindingRepo.ListByWorkspace(ctx, workspace.ID)
	if err != nil {
		return err
	}

	zeroIfBlank := func(value string) string {
		if trimmed := strings.TrimSpace(value); trimmed != "" {
			return trimmed
		}
		return "0"
	}
	applyOnCluster := s.tenantClient != nil && s.clusterRepo != nil

	for _, b := range bindings {
		if b == nil {
			continue
		}

		b.QuotaCPU = strings.TrimSpace(workspace.QuotaCPU)
		b.QuotaMemory = strings.TrimSpace(workspace.QuotaMemory)
		b.QuotaGPU = zeroIfBlank(workspace.QuotaGPU)
		b.QuotaGPUMem = zeroIfBlank(workspace.QuotaGPUMem)
		b.UpdatedAt = time.Now()

		if applyOnCluster {
			cluster, lookupErr := s.clusterRepo.GetByID(ctx, b.ClusterID)
			if errors.Is(lookupErr, entity.ErrClusterNotFound) {
				// Cluster is gone: nothing to apply and nothing is persisted.
				continue
			}
			if lookupErr != nil {
				return lookupErr
			}

			tenant := entity.NewTenantBinding(b.Namespace)
			tenant.ServiceAccountName = b.ServiceAccount
			tenant.ResourceQuotaHard = bindingQuotaHard(b)
			if applyErr := s.tenantClient.EnsureTenant(ctx, cluster, tenant); applyErr != nil {
				return applyErr
			}
		}

		if saveErr := s.bindingRepo.Upsert(ctx, b); saveErr != nil {
			return saveErr
		}
	}
	return nil
}
func (s *AuthService) DeleteUser(ctx context.Context, userID string) error {
principal, err := authz.RequirePrincipal(ctx)
if err != nil {
@ -300,9 +509,117 @@ func (s *AuthService) DeleteUser(ctx context.Context, userID string) error {
if userID == principal.UserID {
return entity.ErrForbidden
}
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return entity.ErrUserNotFound
}
if err := s.ensureUserHasNoInstances(ctx, user); err != nil {
return err
}
if s.isExclusiveUserWorkspace(ctx, user) {
if err := s.cleanupUserWorkspace(ctx, user.WorkspaceID); err != nil {
return err
}
}
return s.userRepo.Delete(ctx, userID)
}
// ensureUserHasNoInstances returns entity.ErrUserHasInstances when any listed
// instance is owned by the user, or lives in the user's workspace (unless that
// workspace is empty or the default workspace). With no instance repository
// or a nil user it reports no conflict.
func (s *AuthService) ensureUserHasNoInstances(ctx context.Context, user *entity.User) error {
	if s.instanceRepo == nil || user == nil {
		return nil
	}

	all, listErr := s.instanceRepo.List(ctx)
	if listErr != nil {
		return listErr
	}

	// Only a non-default workspace counts as "the user's" workspace.
	matchWorkspace := user.WorkspaceID != "" && user.WorkspaceID != entity.DefaultWorkspaceID

	for _, inst := range all {
		if inst == nil {
			continue
		}
		ownedByUser := inst.OwnerID == user.ID
		inUserWorkspace := matchWorkspace && inst.WorkspaceID == user.WorkspaceID
		if ownedByUser || inUserWorkspace {
			return entity.ErrUserHasInstances
		}
	}
	return nil
}
// isExclusiveUserWorkspace reports whether user is the only user assigned to
// their workspace. Admins, users without a workspace, and users in the default
// workspace never qualify. A failure to list users is answered with false.
func (s *AuthService) isExclusiveUserWorkspace(ctx context.Context, user *entity.User) bool {
	if user == nil {
		return false
	}
	hasOwnWorkspace := user.Role != authz.RoleAdmin &&
		user.WorkspaceID != "" &&
		user.WorkspaceID != entity.DefaultWorkspaceID
	if !hasOwnWorkspace {
		return false
	}

	everyone, listErr := s.userRepo.List(ctx)
	if listErr != nil {
		return false
	}

	for _, candidate := range everyone {
		sharesWorkspace := candidate != nil &&
			candidate.ID != user.ID &&
			candidate.WorkspaceID == user.WorkspaceID
		if sharesWorkspace {
			return false
		}
	}
	return true
}
// cleanupUserWorkspace tears down a workspace: for each cluster binding it
// deletes the tenant on the cluster (when a tenant client and cluster
// repository are configured and the cluster still exists), removes the binding
// record, and finally deletes the workspace record.
//
// It is a no-op when the workspace or binding repository is not configured.
// It returns entity.ErrProtectedNamespace, before deleting anything, when the
// workspace or any of its bindings points at a protected namespace. A
// workspace record that is already gone at the final step is not an error.
func (s *AuthService) cleanupUserWorkspace(ctx context.Context, workspaceID string) error {
	if s.workspaceRepo == nil || s.bindingRepo == nil {
		return nil
	}
	workspace, err := s.workspaceRepo.GetByID(ctx, workspaceID)
	if err != nil {
		return err
	}
	if isProtectedWorkspaceNamespace(workspace.K8sNamespace) {
		return entity.ErrProtectedNamespace
	}
	bindings, err := s.bindingRepo.ListByWorkspace(ctx, workspace.ID)
	if err != nil {
		return err
	}

	// Validate every binding up front. Checking inside the delete loop would
	// let a protected binding late in the list abort the cleanup after earlier
	// tenants and bindings had already been destroyed.
	for _, binding := range bindings {
		if binding != nil && isProtectedWorkspaceNamespace(binding.Namespace) {
			return entity.ErrProtectedNamespace
		}
	}

	for _, binding := range bindings {
		if binding == nil {
			continue
		}
		if s.tenantClient != nil && s.clusterRepo != nil {
			cluster, err := s.clusterRepo.GetByID(ctx, binding.ClusterID)
			switch {
			case err == nil:
				tenantBinding := entity.NewTenantBinding(binding.Namespace)
				tenantBinding.ServiceAccountName = binding.ServiceAccount
				tenantBinding.ResourceQuotaHard = resourceQuotaHard(workspace)
				if err := s.tenantClient.DeleteTenant(ctx, cluster, tenantBinding); err != nil {
					return err
				}
			case !errors.Is(err, entity.ErrClusterNotFound):
				return err
			}
			// A missing cluster has nothing to delete remotely; the binding
			// record is still removed below.
		}
		if err := s.bindingRepo.Delete(ctx, binding.WorkspaceID, binding.ClusterID); err != nil {
			return err
		}
	}
	if err := s.workspaceRepo.Delete(ctx, workspace.ID); err != nil && !errors.Is(err, entity.ErrWorkspaceNotFound) {
		return err
	}
	return nil
}
// isProtectedWorkspaceNamespace reports whether namespace, after trimming
// surrounding whitespace, is empty or one of the namespaces that must never be
// deleted: "default", "kube-system", "kube-public", "kube-node-lease".
// The comparison is case-sensitive.
func isProtectedWorkspaceNamespace(namespace string) bool {
	trimmed := strings.TrimSpace(namespace)
	if trimmed == "" {
		return true
	}
	for _, reserved := range [...]string{"default", "kube-system", "kube-public", "kube-node-lease"} {
		if trimmed == reserved {
			return true
		}
	}
	return false
}
func normalizeUserRole(role string) string {
if role == authz.RoleAdmin {
return authz.RoleAdmin

View File

@ -0,0 +1,322 @@
package service
import (
"context"
"errors"
"testing"
"time"
"github.com/google/uuid"
"github.com/ocdp/cluster-service/internal/adapter/output/persistence/mock"
"github.com/ocdp/cluster-service/internal/domain/entity"
"github.com/ocdp/cluster-service/internal/domain/repository"
"github.com/ocdp/cluster-service/internal/pkg/authz"
jwtpkg "github.com/ocdp/cluster-service/internal/pkg/jwt"
)
// TestAuthServiceUpdateUserDowngradeReusesUsernameWorkspace seeds an admin
// "alice" in the default workspace plus an existing workspace named after her,
// downgrades her to the user role, and checks that she is attached to that
// existing workspace and that the supplied default cluster was applied to it.
func TestAuthServiceUpdateUserDowngradeReusesUsernameWorkspace(t *testing.T) {
ctx := adminContext()
userRepo := mock.NewUserRepositoryMock()
workspaceRepo := mock.NewWorkspaceRepositoryMock()
svc := NewAuthService(userRepo, workspaceRepo, testPasswordHasher{}, testTokenGenerator{})
target := testUser("user-1", "alice", authz.RoleAdmin, entity.DefaultWorkspaceID)
if err := userRepo.Create(ctx, target); err != nil {
t.Fatalf("seed user: %v", err)
}
// Pre-existing workspace whose name matches what the service derives for "alice".
workspace := entity.NewWorkspace(userWorkspaceName("alice"), "admin")
workspace.ID = "workspace-alice"
workspace.K8sNamespace = entity.NamespaceForUser("alice")
workspace.K8sSAName = entity.ServiceAccountForNamespace(workspace.K8sNamespace)
if err := workspaceRepo.Create(ctx, workspace); err != nil {
t.Fatalf("seed workspace: %v", err)
}
updated, err := svc.UpdateUser(ctx, target.ID, authz.RoleUser, "", UserWorkspaceOptions{DefaultClusterID: "cluster-1"}, nil, nil)
if err != nil {
t.Fatalf("UpdateUser returned error: %v", err)
}
if updated.Role != authz.RoleUser {
t.Fatalf("expected user role, got %q", updated.Role)
}
if updated.WorkspaceID != workspace.ID {
t.Fatalf("expected reused workspace %q, got %q", workspace.ID, updated.WorkspaceID)
}
reused, err := workspaceRepo.GetByID(ctx, workspace.ID)
if err != nil {
t.Fatalf("get reused workspace: %v", err)
}
if reused.DefaultClusterID != "cluster-1" {
t.Fatalf("expected updated default cluster, got %q", reused.DefaultClusterID)
}
}
// TestAuthServiceRegisterUserAlwaysCreatesPrivateWorkspaceWithZeroDefaultQuotas
// registers a user-role account while passing a shared workspace id and no
// quota options. It checks that the user still gets a private workspace in the
// user's own namespace, with CPU/memory quotas left empty and GPU/GPU-memory
// quotas set to "0".
func TestAuthServiceRegisterUserAlwaysCreatesPrivateWorkspaceWithZeroDefaultQuotas(t *testing.T) {
ctx := adminContext()
userRepo := mock.NewUserRepositoryMock()
workspaceRepo := mock.NewWorkspaceRepositoryMock()
svc := NewAuthService(userRepo, workspaceRepo, testPasswordHasher{}, testTokenGenerator{})
user, err := svc.Register(ctx, "alice", "password", authz.RoleUser, "shared-workspace", UserWorkspaceOptions{}, nil, nil)
if err != nil {
t.Fatalf("Register returned error: %v", err)
}
// The requested shared workspace id must be ignored for user-role accounts.
if user.WorkspaceID == "shared-workspace" || user.WorkspaceID == entity.DefaultWorkspaceID {
t.Fatalf("expected private user workspace, got %q", user.WorkspaceID)
}
workspace, err := workspaceRepo.GetByID(ctx, user.WorkspaceID)
if err != nil {
t.Fatalf("get user workspace: %v", err)
}
if workspace.K8sNamespace != entity.NamespaceForUser("alice") {
t.Fatalf("expected user namespace %q, got %q", entity.NamespaceForUser("alice"), workspace.K8sNamespace)
}
if workspace.QuotaCPU != "" || workspace.QuotaMemory != "" || workspace.QuotaGPU != "0" || workspace.QuotaGPUMem != "0" {
t.Fatalf("expected omitted CPU/memory to stay unlimited and GPU/gpumem to default zero, got cpu=%q memory=%q gpu=%q gpumem=%q", workspace.QuotaCPU, workspace.QuotaMemory, workspace.QuotaGPU, workspace.QuotaGPUMem)
}
}
// TestAuthServiceUpdateUserDowngradeRejectsNamespaceConflict seeds an admin
// "alice" and an unrelated workspace that already occupies the namespace
// derived for "alice", then checks that downgrading her to the user role fails
// with entity.ErrWorkspaceNamespaceConflict.
func TestAuthServiceUpdateUserDowngradeRejectsNamespaceConflict(t *testing.T) {
ctx := adminContext()
userRepo := mock.NewUserRepositoryMock()
workspaceRepo := mock.NewWorkspaceRepositoryMock()
svc := NewAuthService(userRepo, workspaceRepo, testPasswordHasher{}, testTokenGenerator{})
target := testUser("user-1", "alice", authz.RoleAdmin, entity.DefaultWorkspaceID)
if err := userRepo.Create(ctx, target); err != nil {
t.Fatalf("seed user: %v", err)
}
// Different workspace name, same namespace as alice's would-be workspace.
conflicting := entity.NewWorkspace("someone-else", "admin")
conflicting.ID = "workspace-other"
conflicting.K8sNamespace = entity.NamespaceForUser("alice")
conflicting.K8sSAName = entity.ServiceAccountForNamespace(conflicting.K8sNamespace)
if err := workspaceRepo.Create(ctx, conflicting); err != nil {
t.Fatalf("seed conflicting workspace: %v", err)
}
_, err := svc.UpdateUser(ctx, target.ID, authz.RoleUser, "", UserWorkspaceOptions{}, nil, nil)
if !errors.Is(err, entity.ErrWorkspaceNamespaceConflict) {
t.Fatalf("expected namespace conflict, got %v", err)
}
}
// TestAuthServiceDeleteUserRejectsUserWithInstances checks that DeleteUser
// fails with entity.ErrUserHasInstances, and leaves the user in place, when an
// instance owned by that user exists.
func TestAuthServiceDeleteUserRejectsUserWithInstances(t *testing.T) {
ctx := adminContext()
userRepo := mock.NewUserRepositoryMock()
workspaceRepo := mock.NewWorkspaceRepositoryMock()
instanceRepo := mock.NewInstanceRepositoryMock()
svc := NewAuthService(userRepo, workspaceRepo, testPasswordHasher{}, testTokenGenerator{})
// Only the instance repository is wired; cluster/binding/tenant collaborators stay nil.
svc.SetUserLifecycleCleanup(instanceRepo, nil, nil, nil)
user := testUser("user-1", "alice", authz.RoleUser, "workspace-alice")
if err := userRepo.Create(ctx, user); err != nil {
t.Fatalf("seed user: %v", err)
}
instance := entity.NewInstance("cluster-1", "app", "ocdp-u-alice", "registry-1", "repo", "chart", "1.0.0")
instance.ID = "instance-1"
instance.OwnerID = user.ID
instance.WorkspaceID = user.WorkspaceID
if err := instanceRepo.Create(ctx, instance); err != nil {
t.Fatalf("seed instance: %v", err)
}
err := svc.DeleteUser(ctx, user.ID)
if !errors.Is(err, entity.ErrUserHasInstances) {
t.Fatalf("expected user instance conflict, got %v", err)
}
if _, err := userRepo.GetByID(ctx, user.ID); err != nil {
t.Fatalf("user should not be deleted: %v", err)
}
}
// TestAuthServiceDeleteUserRejectsWorkspaceInstanceEvenWithDifferentOwner
// checks that DeleteUser fails with entity.ErrUserHasInstances, and leaves the
// user in place, when an instance lives in the user's workspace even though a
// different user owns it.
func TestAuthServiceDeleteUserRejectsWorkspaceInstanceEvenWithDifferentOwner(t *testing.T) {
ctx := adminContext()
userRepo := mock.NewUserRepositoryMock()
workspaceRepo := mock.NewWorkspaceRepositoryMock()
instanceRepo := mock.NewInstanceRepositoryMock()
svc := NewAuthService(userRepo, workspaceRepo, testPasswordHasher{}, testTokenGenerator{})
svc.SetUserLifecycleCleanup(instanceRepo, nil, nil, nil)
user := testUser("user-1", "alice", authz.RoleUser, "workspace-alice")
if err := userRepo.Create(ctx, user); err != nil {
t.Fatalf("seed user: %v", err)
}
instance := entity.NewInstance("cluster-1", "shared-workspace-app", "ocdp-u-alice", "registry-1", "repo", "chart", "1.0.0")
instance.ID = "instance-1"
// Owner differs from the user under deletion; only the workspace matches.
instance.OwnerID = "other-user"
instance.WorkspaceID = user.WorkspaceID
if err := instanceRepo.Create(ctx, instance); err != nil {
t.Fatalf("seed workspace instance: %v", err)
}
err := svc.DeleteUser(ctx, user.ID)
if !errors.Is(err, entity.ErrUserHasInstances) {
t.Fatalf("expected workspace instance conflict, got %v", err)
}
if _, err := userRepo.GetByID(ctx, user.ID); err != nil {
t.Fatalf("user should not be deleted: %v", err)
}
}
// TestAuthServiceDeleteUserCleansExclusiveWorkspaceBindings deletes a user who
// is the only member of a private workspace bound to one cluster and checks
// that the user, the binding, and the workspace are all gone and that the
// tenant client was asked to delete the workspace namespace exactly once.
func TestAuthServiceDeleteUserCleansExclusiveWorkspaceBindings(t *testing.T) {
ctx := adminContext()
userRepo := mock.NewUserRepositoryMock()
workspaceRepo := mock.NewWorkspaceRepositoryMock()
instanceRepo := mock.NewInstanceRepositoryMock()
bindingRepo := mock.NewWorkspaceClusterBindingRepositoryMock()
clusterRepo := &testClusterRepo{clusters: map[string]*entity.Cluster{
"cluster-1": {ID: "cluster-1", Name: "cluster-1", Host: "https://cluster.invalid", Token: "token"},
}}
tenantClient := &recordingTenantClient{}
svc := NewAuthService(userRepo, workspaceRepo, testPasswordHasher{}, testTokenGenerator{})
svc.SetUserLifecycleCleanup(instanceRepo, clusterRepo, bindingRepo, tenantClient)
workspace := entity.NewWorkspace(userWorkspaceName("alice"), "admin")
workspace.ID = "workspace-alice"
workspace.K8sNamespace = entity.NamespaceForUser("alice")
workspace.K8sSAName = entity.ServiceAccountForNamespace(workspace.K8sNamespace)
if err := workspaceRepo.Create(ctx, workspace); err != nil {
t.Fatalf("seed workspace: %v", err)
}
user := testUser("user-1", "alice", authz.RoleUser, workspace.ID)
if err := userRepo.Create(ctx, user); err != nil {
t.Fatalf("seed user: %v", err)
}
if err := bindingRepo.Upsert(ctx, &entity.WorkspaceClusterBinding{
ID: "binding-1",
WorkspaceID: workspace.ID,
ClusterID: "cluster-1",
Namespace: workspace.K8sNamespace,
ServiceAccount: workspace.K8sSAName,
Status: "active",
}); err != nil {
t.Fatalf("seed binding: %v", err)
}
if err := svc.DeleteUser(ctx, user.ID); err != nil {
t.Fatalf("DeleteUser returned error: %v", err)
}
if _, err := userRepo.GetByID(ctx, user.ID); !errors.Is(err, entity.ErrUserNotFound) {
t.Fatalf("expected user deleted, got %v", err)
}
if bindings, err := bindingRepo.ListByWorkspace(ctx, workspace.ID); err != nil || len(bindings) != 0 {
t.Fatalf("expected bindings cleaned, got len=%d err=%v", len(bindings), err)
}
// recordingTenantClient appends the namespace of every DeleteTenant call.
if len(tenantClient.deleted) != 1 || tenantClient.deleted[0] != workspace.K8sNamespace {
t.Fatalf("expected tenant namespace cleanup, got %#v", tenantClient.deleted)
}
if _, err := workspaceRepo.GetByID(ctx, workspace.ID); !errors.Is(err, entity.ErrWorkspaceNotFound) {
t.Fatalf("expected exclusive workspace deleted, got %v", err)
}
}
// adminContext returns a background context carrying a fixed admin principal
// ("admin-1" / "admin") placed in the default workspace.
func adminContext() context.Context {
	admin := &authz.Principal{
		UserID:      "admin-1",
		Username:    "admin",
		Role:        authz.RoleAdmin,
		WorkspaceID: entity.DefaultWorkspaceID,
	}
	return authz.WithPrincipal(context.Background(), admin)
}
// testUser builds a user fixture with the password hash "hash", the email
// "<username>@local.ocdp", and the given id, role and workspace id.
func testUser(id, username, role, workspaceID string) *entity.User {
	email := username + "@local.ocdp"
	fixture := entity.NewUser(username, "hash", email)
	fixture.ID, fixture.Role, fixture.WorkspaceID = id, role, workspaceID
	return fixture
}
// testPasswordHasher is a deterministic hasher stub for tests.
type testPasswordHasher struct{}
// Hash returns the password prefixed with "hash:" and never fails.
func (testPasswordHasher) Hash(password string) (string, error) { return "hash:" + password, nil }
// Verify accepts every password/hash pair.
func (testPasswordHasher) Verify(password, hash string) error { return nil }
// testTokenGenerator is a token generator stub for tests: every method returns
// fixed values and a nil error.
type testTokenGenerator struct{}
// Generate always returns the tokens "access" and "refresh".
func (testTokenGenerator) Generate(userID, username, role, workspaceID string) (string, string, error) {
return "access", "refresh", nil
}
// The Verify* methods return zero values; they do not inspect the token.
func (testTokenGenerator) Verify(token string) (string, string, error) { return "", "", nil }
func (testTokenGenerator) VerifyWithIssuedAt(token string) (string, string, int64, error) {
return "", "", 0, nil
}
func (testTokenGenerator) VerifyAccess(token string) (*jwtpkg.Claims, error) { return nil, nil }
func (testTokenGenerator) VerifyRefresh(token string) (*jwtpkg.Claims, error) { return nil, nil }
// Refresh always returns the token "access".
func (testTokenGenerator) Refresh(refreshToken string) (string, error) { return "access", nil }
// testClusterRepo is an in-memory cluster repository for tests, keyed by
// cluster ID. It stores and hands out copies, so callers never share a
// pointer with the map. The clusters map must be initialised by the test.
type testClusterRepo struct {
	clusters map[string]*entity.Cluster
}

// Create stores a copy of cluster, assigning a fresh UUID when it has no ID.
func (r *testClusterRepo) Create(ctx context.Context, cluster *entity.Cluster) error {
	if cluster.ID == "" {
		cluster.ID = uuid.New().String()
	}
	stored := *cluster
	r.clusters[cluster.ID] = &stored
	return nil
}

// GetByID returns a copy of the cluster with the given id, or
// entity.ErrClusterNotFound.
func (r *testClusterRepo) GetByID(ctx context.Context, id string) (*entity.Cluster, error) {
	if found, ok := r.clusters[id]; ok {
		clone := *found
		return &clone, nil
	}
	return nil, entity.ErrClusterNotFound
}

// GetByName returns a copy of a cluster with the given name, or
// entity.ErrClusterNotFound.
func (r *testClusterRepo) GetByName(ctx context.Context, name string) (*entity.Cluster, error) {
	for _, candidate := range r.clusters {
		if candidate.Name != name {
			continue
		}
		clone := *candidate
		return &clone, nil
	}
	return nil, entity.ErrClusterNotFound
}

// Update overwrites the stored entry for cluster.ID with a copy of cluster.
func (r *testClusterRepo) Update(ctx context.Context, cluster *entity.Cluster) error {
	stored := *cluster
	r.clusters[cluster.ID] = &stored
	return nil
}

// Delete removes the entry for id; deleting a missing id is not an error.
func (r *testClusterRepo) Delete(ctx context.Context, id string) error {
	delete(r.clusters, id)
	return nil
}

// List returns copies of all stored clusters in map iteration order.
func (r *testClusterRepo) List(ctx context.Context) ([]*entity.Cluster, error) {
	all := make([]*entity.Cluster, 0, len(r.clusters))
	for _, entry := range r.clusters {
		clone := *entry
		all = append(all, &clone)
	}
	return all, nil
}
// recordingTenantClient is a tenant Kubernetes client stub. It records the
// namespaces passed to DeleteTenant and can serve a canned quota usage.
type recordingTenantClient struct {
	deleted []string
	usage   *repository.ResourceQuotaUsage
}

// EnsureTenant is a no-op that always succeeds.
func (c *recordingTenantClient) EnsureTenant(ctx context.Context, cluster *entity.Cluster, binding entity.TenantBinding) error {
	return nil
}

// IssueKubeconfig is a no-op that returns no kubeconfig and no error.
func (c *recordingTenantClient) IssueKubeconfig(ctx context.Context, cluster *entity.Cluster, binding entity.TenantBinding, ttl time.Duration) (*entity.TenantKubeconfig, error) {
	return nil, nil
}

// GetResourceQuotaUsage returns the canned usage when one is configured and
// an empty usage value otherwise.
func (c *recordingTenantClient) GetResourceQuotaUsage(ctx context.Context, cluster *entity.Cluster, binding entity.TenantBinding) (*repository.ResourceQuotaUsage, error) {
	if c.usage == nil {
		return &repository.ResourceQuotaUsage{}, nil
	}
	return c.usage, nil
}

// SuspendTenant is a no-op that always succeeds.
func (c *recordingTenantClient) SuspendTenant(ctx context.Context, cluster *entity.Cluster, binding entity.TenantBinding) error {
	return nil
}

// DeleteTenant validates the binding and, when valid, records its namespace
// in c.deleted. A validation error is returned unchanged and nothing is
// recorded.
func (c *recordingTenantClient) DeleteTenant(ctx context.Context, cluster *entity.Cluster, binding entity.TenantBinding) error {
	err := binding.Validate()
	if err != nil {
		return err
	}
	c.deleted = append(c.deleted, binding.Namespace)
	return nil
}

View File

@ -6,6 +6,7 @@ import (
"fmt"
"os"
"path/filepath"
"strings"
"time"
"github.com/google/uuid"
@ -34,6 +35,7 @@ type InstanceService struct {
entryClient repository.InstanceEntryClient
diagClient repository.InstanceDiagnosticsClient
workspaceRepo repository.WorkspaceRepository
userRepo repository.UserRepository
tenantClient repository.TenantKubeClient
scaleClient ScaleClient
}
@ -76,6 +78,10 @@ func (s *InstanceService) SetTenantProvisioning(workspaceRepo repository.Workspa
s.tenantClient = tenantClient
}
// SetUserRepository injects the user repository that enrichOwnerUsernames
// uses to resolve instance owner IDs to usernames. While it is unset (nil),
// owner-username enrichment is skipped.
func (s *InstanceService) SetUserRepository(userRepo repository.UserRepository) {
s.userRepo = userRepo
}
const chartCacheDir = "/tmp/charts"
func (s *InstanceService) chartArchivePath(instance *entity.Instance) string {
@ -131,15 +137,21 @@ func (s *InstanceService) CreateInstance(ctx context.Context, instance *entity.I
return err
}
enforceNamespaceValues(instance)
if err := s.ensureTenantForInstance(ctx, principal, cluster, instance); err != nil {
return err
}
// 检查实例是否已存在
existingInstance, _ := s.instanceRepo.GetByClusterAndName(ctx, instance.ClusterID, instance.Name)
if existingInstance != nil {
return entity.ErrInstanceExists
}
if err := s.downloadChart(ctx, registry, instance); err != nil {
return err
}
binding, err := s.ensureTenantForInstance(ctx, principal, cluster, instance)
if err != nil {
return err
}
if err := s.precheckInstanceQuota(ctx, principal, cluster, binding, instance, nil); err != nil {
return err
}
instance.BeginOperation(entity.OperationInstall, "Preparing installation")
@ -148,13 +160,6 @@ func (s *InstanceService) CreateInstance(ctx context.Context, instance *entity.I
return err
}
// 下载 chart artifact 供 Helm 使用
if err := s.downloadChart(ctx, registry, instance); err != nil {
instance.MarkFailure("Failed to download chart", err)
_ = s.instanceRepo.Update(ctx, instance)
return err
}
// 异步执行 Helm 安装并监控状态
go s.executeAndSyncInstall(context.Background(), instance.ID, cluster, registry, instance)
@ -175,6 +180,7 @@ func (s *InstanceService) GetInstance(ctx context.Context, id string) (*entity.I
if !s.canReadInstance(principal, instance) {
return nil, entity.ErrInstanceNotFound
}
s.enrichOwnerUsernames(ctx, []*entity.Instance{instance})
return instance, nil
}
@ -219,8 +225,22 @@ func (s *InstanceService) UpdateInstance(ctx context.Context, instance *entity.I
if !s.canWriteInstance(principal, existingInstance) {
return entity.ErrForbidden
}
instance.ClusterID = existingInstance.ClusterID
instance.WorkspaceID = existingInstance.WorkspaceID
instance.OwnerID = existingInstance.OwnerID
instance.Name = existingInstance.Name
if instance.RegistryID == "" {
instance.RegistryID = existingInstance.RegistryID
}
if instance.Repository == "" {
instance.Repository = existingInstance.Repository
}
if instance.Chart == "" {
instance.Chart = existingInstance.Chart
}
if instance.Version == "" {
instance.Version = existingInstance.Version
}
// 获取集群信息
cluster, err := s.clusterRepo.GetByID(ctx, existingInstance.ClusterID)
@ -236,15 +256,21 @@ func (s *InstanceService) UpdateInstance(ctx context.Context, instance *entity.I
instance.Namespace = existingInstance.Namespace
enforceNamespaceValues(instance)
instance.BeginOperation(entity.OperationUpgrade, "Pending upgrade")
if err := s.instanceRepo.Update(ctx, instance); err != nil {
return err
}
// 下载所需 Chart
if err := s.downloadChart(ctx, registry, instance); err != nil {
instance.MarkFailure("Failed to download chart", err)
_ = s.instanceRepo.Update(ctx, instance)
return err
}
binding, err := s.ensureTenantForInstance(ctx, principal, cluster, instance)
if err != nil {
return err
}
if err := s.precheckInstanceQuota(ctx, principal, cluster, binding, instance, existingInstance); err != nil {
return err
}
instance.BeginOperation(entity.OperationUpgrade, "Pending upgrade")
if err := s.instanceRepo.Update(ctx, instance); err != nil {
return err
}
@ -364,9 +390,32 @@ func (s *InstanceService) ListInstancesByCluster(ctx context.Context, clusterID
visible = append(visible, instance)
}
}
s.enrichOwnerUsernames(ctx, visible)
return visible, nil
}
// enrichOwnerUsernames fills OwnerUsername on each instance by looking up its
// OwnerID in the user repository. It is best-effort: nil instances, instances
// without an owner, and owners whose lookup fails are left untouched.
// Successful lookups are cached so each owner is fetched once per call.
// Nothing happens when no user repository is configured.
func (s *InstanceService) enrichOwnerUsernames(ctx context.Context, instances []*entity.Instance) {
	if s.userRepo == nil || len(instances) == 0 {
		return
	}

	resolved := make(map[string]string)
	for _, inst := range instances {
		if inst == nil || inst.OwnerID == "" {
			continue
		}
		name, cached := resolved[inst.OwnerID]
		if !cached {
			owner, err := s.userRepo.GetByID(ctx, inst.OwnerID)
			if err != nil || owner == nil {
				// Failed lookups are not cached; the next instance with the
				// same owner triggers another lookup.
				continue
			}
			name = owner.Username
			resolved[inst.OwnerID] = name
		}
		inst.OwnerUsername = name
	}
}
// ListInstanceEntries 列出实例关联的入口信息Service / Ingress
func (s *InstanceService) ListInstanceEntries(ctx context.Context, clusterID, instanceID string) ([]*entity.InstanceEntry, error) {
instance, err := s.GetInstance(ctx, instanceID)
@ -442,27 +491,57 @@ func (s *InstanceService) ScaleInstance(ctx context.Context, clusterID, instance
if !s.canWriteInstance(principal, instance) {
return nil, entity.ErrForbidden
}
if instance.ClusterID != clusterID {
return nil, entity.ErrInstanceNotFound
}
cluster, err := s.clusterRepo.GetByID(ctx, clusterID)
if err != nil {
return nil, entity.ErrClusterNotFound
}
current := cloneInstanceForQuota(instance)
currentValues, err := s.helmClient.GetValues(ctx, cluster, instance.Name, instance.Namespace)
if err == nil && currentValues != nil {
current.SetValues(currentValues)
}
target := cloneInstanceForQuota(instance)
targetValues := copyValues(current.Values)
if targetValues == nil {
targetValues = copyValues(instance.Values)
}
if targetValues == nil {
targetValues = map[string]interface{}{}
}
targetValues["replicaCount"] = replicas
target.SetValues(targetValues)
registry, err := s.registryRepo.GetByID(ctx, instance.RegistryID)
if err != nil {
return nil, entity.ErrRegistryNotFound
}
if err := s.downloadChart(ctx, registry, target); err != nil {
return nil, err
}
binding, err := s.ensureTenantForInstance(ctx, principal, cluster, target)
if err != nil {
return nil, err
}
if err := s.precheckInstanceQuota(ctx, principal, cluster, binding, target, current); err != nil {
return nil, err
}
// Scale via K8s API directly (like kubectl scale deploy --replicas=N)
if s.scaleClient != nil {
if err := s.scaleClient.ScaleDeployment(ctx, cluster, instance.Namespace, instance.Name, int32(replicas)); err != nil {
return nil, fmt.Errorf("failed to scale deployment: %w", err)
}
instance.SetValues(targetValues)
instance.Replicas = replicas
if err := s.instanceRepo.Update(ctx, instance); err != nil {
return nil, err
}
} else {
// Fallback: Helm upgrade with replicaCount
vals, err := s.helmClient.GetValues(ctx, cluster, instance.Name, instance.Namespace)
if err != nil {
return nil, fmt.Errorf("failed to get current values: %w", err)
}
if vals == nil {
vals = make(map[string]interface{})
}
vals["replicaCount"] = replicas
instance.SetValues(vals)
instance.SetValues(targetValues)
instance.BeginOperation(entity.OperationUpgrade, fmt.Sprintf("Scaling to %d replicas", replicas))
if err := s.instanceRepo.Update(ctx, instance); err != nil {
return nil, err
@ -516,6 +595,9 @@ func (s *InstanceService) GetInstanceValuesDiff(ctx context.Context, clusterID,
if !s.canReadInstance(principal, instance) {
return nil, entity.ErrInstanceNotFound
}
if instance.ClusterID != clusterID {
return nil, entity.ErrInstanceNotFound
}
cluster, err := s.clusterRepo.GetByID(ctx, clusterID)
if err != nil {
return nil, entity.ErrClusterNotFound
@ -528,6 +610,18 @@ func (s *InstanceService) GetInstanceValuesDiff(ctx context.Context, clusterID,
// Get default values from the chart archive
chartPath := s.chartArchivePath(instance)
if _, statErr := os.Stat(chartPath); statErr != nil {
if !errors.Is(statErr, os.ErrNotExist) {
return nil, fmt.Errorf("failed to inspect chart defaults: %w", statErr)
}
registry, err := s.registryRepo.GetByID(ctx, instance.RegistryID)
if err != nil {
return nil, entity.ErrRegistryNotFound
}
if err := s.downloadChart(ctx, registry, instance); err != nil {
return nil, err
}
}
defaults, err := s.helmClient.GetChartDefaultValues(chartPath)
if err != nil {
return nil, fmt.Errorf("failed to read chart defaults: %w", err)
@ -593,9 +687,6 @@ func (s *InstanceService) applyNamespacePolicy(ctx context.Context, principal *a
}
return nil
}
if isReservedNamespace(instance.Namespace) {
return entity.ErrInvalidNamespace
}
if cluster.Visibility != authz.VisibilityPrivate || cluster.OwnerID != principal.UserID {
namespace := principal.Namespace
if namespace == "" {
@ -606,9 +697,15 @@ func (s *InstanceService) applyNamespacePolicy(ctx context.Context, principal *a
namespace = binding.Namespace
}
}
if instance.Namespace != "" && instance.Namespace != namespace {
return entity.ErrForbidden
}
instance.Namespace = namespace
return nil
}
if isReservedNamespace(instance.Namespace) {
return entity.ErrInvalidNamespace
}
if instance.Namespace == "" {
if cluster.DefaultNamespace != "" {
instance.Namespace = cluster.DefaultNamespace
@ -621,8 +718,62 @@ func (s *InstanceService) applyNamespacePolicy(ctx context.Context, principal *a
return nil
}
func (s *InstanceService) ensureTenantForInstance(ctx context.Context, principal *authz.Principal, cluster *entity.Cluster, instance *entity.Instance) error {
// ensureTenantForInstance provisions the tenant for the principal's workspace
// on the given cluster and returns the workspace/cluster binding it used.
//
// It returns (nil, nil) without doing anything when the principal is an admin
// or when the workspace repository or tenant client has not been configured.
// It returns entity.ErrWorkspaceSuspended for a suspended workspace, and
// propagates errors from the workspace lookup, EnsureTenant and Upsert.
//
// Side effect: when a stored binding with a non-empty namespace exists,
// instance.Namespace is overwritten with it and enforceNamespaceValues is
// re-applied to the instance.
func (s *InstanceService) ensureTenantForInstance(ctx context.Context, principal *authz.Principal, cluster *entity.Cluster, instance *entity.Instance) (*entity.WorkspaceClusterBinding, error) {
if principal.IsAdmin() || s.workspaceRepo == nil || s.tenantClient == nil {
return nil, nil
}
workspace, err := s.workspaceRepo.GetByID(ctx, principal.WorkspaceID)
if err != nil {
return nil, err
}
if workspace.Status == entity.WorkspaceSuspended {
return nil, entity.ErrWorkspaceSuspended
}
// Start from a fresh binding derived from the workspace. CPU/memory quotas
// are whitespace-trimmed; blank GPU quotas become "0" (zeroIfEmptyQuota).
binding := &entity.WorkspaceClusterBinding{
ID: uuid.New().String(),
WorkspaceID: workspace.ID,
ClusterID: cluster.ID,
Namespace: instance.Namespace,
ServiceAccount: workspace.K8sSAName,
QuotaCPU: strings.TrimSpace(workspace.QuotaCPU),
QuotaMemory: strings.TrimSpace(workspace.QuotaMemory),
QuotaGPU: zeroIfEmptyQuota(workspace.QuotaGPU),
QuotaGPUMem: zeroIfEmptyQuota(workspace.QuotaGPUMem),
Status: "active",
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
if s.bindingRepo != nil {
// Reuse identity fields of an already stored binding. Any error from Get
// is treated the same as "no existing binding" and is not reported.
if existing, err := s.bindingRepo.Get(ctx, workspace.ID, cluster.ID); err == nil && existing != nil {
binding.ID = existing.ID
binding.CreatedAt = existing.CreatedAt
if existing.Namespace != "" {
// The stored namespace wins over whatever the instance requested.
binding.Namespace = existing.Namespace
instance.Namespace = existing.Namespace
enforceNamespaceValues(instance)
}
if existing.ServiceAccount != "" {
binding.ServiceAccount = existing.ServiceAccount
}
if existing.Status != "" {
binding.Status = existing.Status
}
}
}
// Provision on the cluster first; the binding is persisted only after
// EnsureTenant succeeds.
tenantBinding := tenantBindingFromWorkspaceClusterBinding(binding)
if err := s.tenantClient.EnsureTenant(ctx, cluster, tenantBinding); err != nil {
return nil, err
}
if s.bindingRepo != nil {
if err := s.bindingRepo.Upsert(ctx, binding); err != nil {
return nil, err
}
}
return binding, nil
}
func (s *InstanceService) precheckInstanceQuota(ctx context.Context, principal *authz.Principal, cluster *entity.Cluster, binding *entity.WorkspaceClusterBinding, target, current *entity.Instance) error {
if principal.IsAdmin() || s.workspaceRepo == nil || s.helmClient == nil {
return nil
}
workspace, err := s.workspaceRepo.GetByID(ctx, principal.WorkspaceID)
@ -632,29 +783,45 @@ func (s *InstanceService) ensureTenantForInstance(ctx context.Context, principal
if workspace.Status == entity.WorkspaceSuspended {
return entity.ErrWorkspaceSuspended
}
binding := entity.NewTenantBinding(instance.Namespace)
binding.ServiceAccountName = workspace.K8sSAName
binding.ResourceQuotaHard = instanceResourceQuotaHard(workspace)
if err := s.tenantClient.EnsureTenant(ctx, cluster, binding); err != nil {
return err
if binding == nil {
binding = &entity.WorkspaceClusterBinding{
WorkspaceID: principal.WorkspaceID,
ClusterID: cluster.ID,
Namespace: target.Namespace,
QuotaCPU: strings.TrimSpace(workspace.QuotaCPU),
QuotaMemory: strings.TrimSpace(workspace.QuotaMemory),
QuotaGPU: zeroIfEmptyQuota(workspace.QuotaGPU),
QuotaGPUMem: zeroIfEmptyQuota(workspace.QuotaGPUMem),
}
}
if s.bindingRepo != nil {
_ = s.bindingRepo.Upsert(ctx, &entity.WorkspaceClusterBinding{
ID: uuid.New().String(),
WorkspaceID: workspace.ID,
ClusterID: cluster.ID,
Namespace: instance.Namespace,
ServiceAccount: workspace.K8sSAName,
QuotaCPU: workspace.QuotaCPU,
QuotaMemory: workspace.QuotaMemory,
QuotaGPU: workspace.QuotaGPU,
QuotaGPUMem: workspace.QuotaGPUMem,
Status: "active",
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
})
var usage *repository.ResourceQuotaUsage
if s.tenantClient != nil {
tenantBinding := tenantBindingFromWorkspaceClusterBinding(binding)
quotaUsage, err := s.tenantClient.GetResourceQuotaUsage(ctx, cluster, tenantBinding)
if err != nil {
return err
}
usage = quotaUsage
}
return nil
result, err := NewQuotaPrecheckService(s.helmClient).EstimateAndCompareBinding(ctx, cluster, binding, usage, target, current)
if err == nil {
return nil
}
if errors.Is(err, ErrQuotaExceeded) && result != nil {
return fmt.Errorf("%w: %s", ErrQuotaExceeded, formatQuotaExceeded(result.Exceeded))
}
return err
}
// formatQuotaExceeded renders the exceeded resources as
// "<name> required=<required> quota=<hard>" entries joined by "; ".
// An empty list yields a generic quota-exceeded message.
func formatQuotaExceeded(exceeded []QuotaExceededResource) string {
	if len(exceeded) == 0 {
		return "requested resources exceed workspace quota"
	}

	var msg strings.Builder
	for i, item := range exceeded {
		if i > 0 {
			msg.WriteString("; ")
		}
		fmt.Fprintf(&msg, "%s required=%s quota=%s", item.Name, item.Required, item.Hard)
	}
	return msg.String()
}
func instanceResourceQuotaHard(workspace *entity.Workspace) corev1.ResourceList {
@ -687,6 +854,46 @@ func instanceResourceQuotaHard(workspace *entity.Workspace) corev1.ResourceList
return hard
}
// tenantBindingFromWorkspaceClusterBinding converts a workspace/cluster
// binding into the tenant binding passed to the tenant client. A nil binding
// yields a tenant binding created for the empty namespace with no further
// fields set.
func tenantBindingFromWorkspaceClusterBinding(binding *entity.WorkspaceClusterBinding) entity.TenantBinding {
	if binding == nil {
		return entity.NewTenantBinding("")
	}

	tenant := entity.NewTenantBinding(binding.Namespace)
	tenant.ServiceAccountName = binding.ServiceAccount
	tenant.ResourceQuotaHard = bindingQuotaHard(binding)
	return tenant
}
// zeroIfEmptyQuota trims surrounding whitespace from a quota string and
// substitutes "0" when nothing is left.
func zeroIfEmptyQuota(value string) string {
	trimmed := strings.TrimSpace(value)
	if trimmed == "" {
		return "0"
	}
	return trimmed
}
// cloneInstanceForQuota returns a copy of the instance whose top-level values
// map is detached from the original, so quota estimation can alter values
// without touching the caller's instance. The struct copy is shallow apart
// from that map. A nil instance yields nil.
func cloneInstanceForQuota(instance *entity.Instance) *entity.Instance {
	if instance == nil {
		return nil
	}
	duplicate := new(entity.Instance)
	*duplicate = *instance
	duplicate.SetValues(copyValues(instance.Values))
	return duplicate
}
// copyValues returns a shallow copy of a values map: top-level keys are
// duplicated, nested maps and slices remain shared with the source.
// A nil map yields nil; an empty map yields a new empty map.
func copyValues(values map[string]interface{}) map[string]interface{} {
	if values == nil {
		return nil
	}
	duplicate := make(map[string]interface{}, len(values))
	for key := range values {
		duplicate[key] = values[key]
	}
	return duplicate
}
func isReservedNamespace(namespace string) bool {
switch namespace {
case "default", "kube-system", "kube-public", "kube-node-lease":

View File

@ -10,6 +10,7 @@ import (
"github.com/ocdp/cluster-service/internal/domain/entity"
"github.com/ocdp/cluster-service/internal/domain/repository"
"github.com/ocdp/cluster-service/internal/pkg/authz"
"k8s.io/apimachinery/pkg/api/resource"
)
func TestDeleteInstanceIgnoresMissingRelease(t *testing.T) {
@ -85,6 +86,210 @@ func TestEnforceNamespaceValuesOverridesChartNamespaceKnobs(t *testing.T) {
}
}
// TestApplyNamespacePolicyRejectsMismatchedTenantNamespace checks that a
// regular user on a workspace-shared cluster cannot request a namespace other
// than the tenant namespace, and that the rejected request leaves the
// instance's namespace untouched.
func TestApplyNamespacePolicyRejectsMismatchedTenantNamespace(t *testing.T) {
	const requested = "other-namespace"

	tenant := &authz.Principal{
		UserID:        "user-1",
		Username:      "alice",
		Role:          authz.RoleUser,
		WorkspaceID:   "workspace-1",
		WorkspaceName: "alice",
		Namespace:     "ocdp-u-alice",
	}
	sharedCluster := &entity.Cluster{
		ID:         "cluster-1",
		OwnerID:    "admin",
		Visibility: authz.VisibilityWorkspaceShared,
	}
	inst := &entity.Instance{Namespace: requested}

	service := NewInstanceService(nil, nil, nil, nil, nil, nil)
	err := service.applyNamespacePolicy(context.Background(), tenant, sharedCluster, inst)
	if !errors.Is(err, entity.ErrForbidden) {
		t.Fatalf("expected ErrForbidden for mismatched tenant namespace, got %v", err)
	}
	if got := inst.Namespace; got != requested {
		t.Fatalf("expected namespace to remain unchanged on rejection, got %q", got)
	}
}
// TestApplyNamespacePolicyAllowsTenantNamespace checks that a regular user on
// a workspace-shared cluster may explicitly request the tenant namespace and
// that the namespace is kept as-is.
func TestApplyNamespacePolicyAllowsTenantNamespace(t *testing.T) {
	const tenantNamespace = "ocdp-u-alice"

	tenant := &authz.Principal{
		UserID:        "user-1",
		Username:      "alice",
		Role:          authz.RoleUser,
		WorkspaceID:   "workspace-1",
		WorkspaceName: "alice",
		Namespace:     tenantNamespace,
	}
	sharedCluster := &entity.Cluster{
		ID:         "cluster-1",
		OwnerID:    "admin",
		Visibility: authz.VisibilityWorkspaceShared,
	}
	inst := &entity.Instance{Namespace: tenantNamespace}

	service := NewInstanceService(nil, nil, nil, nil, nil, nil)
	err := service.applyNamespacePolicy(context.Background(), tenant, sharedCluster, inst)
	if err != nil {
		t.Fatalf("expected matching tenant namespace to be allowed, got %v", err)
	}
	if got := inst.Namespace; got != tenantNamespace {
		t.Fatalf("expected namespace to remain the allowed tenant namespace, got %q", got)
	}
}
func TestEnrichReplicasSetsLiveReplicaCount(t *testing.T) {
ctx := context.Background()
cluster := &entity.Cluster{ID: "cluster-1", Name: "cluster"}
svc := NewInstanceService(nil, &stubClusterRepo{cluster: cluster}, nil, nil, nil, nil)
svc.SetScaleClient(&stubScaleClient{replicas: 3})
instances := []*entity.Instance{{
ID: "inst-1",
ClusterID: "cluster-1",
Name: "demo",
Namespace: "ocdp-u-alice",
Replicas: 1,
}}
enriched := svc.EnrichReplicas(ctx, "cluster-1", instances)
if enriched[0].Replicas != 3 {
t.Fatalf("expected live replicas to overwrite stored count, got %d", enriched[0].Replicas)
}
}
// TestListInstancesByClusterHydratesOwnerUsername checks that listing
// instances as an admin fills OwnerUsername from the user repository.
func TestListInstancesByClusterHydratesOwnerUsername(t *testing.T) {
	const (
		clusterID = "cluster-1"
		ownerID   = "user-1"
		ownerName = "alice"
	)

	adminCtx := authz.WithPrincipal(context.Background(), &authz.Principal{
		UserID:      "admin-1",
		Username:    "admin",
		Role:        authz.RoleAdmin,
		WorkspaceID: "workspace-admin",
	})

	instances := persistencemock.NewInstanceRepositoryMock()
	users := persistencemock.NewUserRepositoryMock()

	owner := &entity.User{ID: ownerID, Username: ownerName, PasswordHash: "hash", Role: "user", WorkspaceID: "workspace-1"}
	if err := users.Create(adminCtx, owner); err != nil {
		t.Fatalf("failed to seed user: %v", err)
	}

	seeded := &entity.Instance{
		ID:          "inst-1",
		WorkspaceID: "workspace-1",
		OwnerID:     ownerID,
		ClusterID:   clusterID,
		Name:        "demo",
		Namespace:   "ocdp-u-alice",
	}
	if err := instances.Create(adminCtx, seeded); err != nil {
		t.Fatalf("failed to seed instance: %v", err)
	}

	clusters := &stubClusterRepo{cluster: &entity.Cluster{ID: clusterID, Name: "cluster"}}
	service := NewInstanceService(instances, clusters, nil, nil, nil, nil)
	service.SetUserRepository(users)

	listed, err := service.ListInstancesByCluster(adminCtx, clusterID)
	if err != nil {
		t.Fatalf("ListInstancesByCluster returned error: %v", err)
	}
	if count := len(listed); count != 1 {
		t.Fatalf("expected 1 instance, got %d", count)
	}
	if got := listed[0].OwnerUsername; got != ownerName {
		t.Fatalf("expected owner username alice, got %q", got)
	}
}
// TestCreateInstanceRejectsGPUWhenWorkspaceQuotaEmptyBeforeCreate checks the
// quota precheck ordering in CreateInstance: for a workspace whose GPU quotas
// are blank, a chart estimated to need a GPU must be rejected with
// ErrQuotaExceeded after the chart has been pulled once, but before any
// instance row is created and before Helm install is invoked.
func TestCreateInstanceRejectsGPUWhenWorkspaceQuotaEmptyBeforeCreate(t *testing.T) {
	const tenantNamespace = "ocdp-u-ivanwu"

	tenantCtx := authz.WithPrincipal(context.Background(), &authz.Principal{
		UserID:        "user-ivanwu",
		Username:      "ivanwu",
		Role:          authz.RoleUser,
		WorkspaceID:   "workspace-ivanwu",
		WorkspaceName: "ivanwu",
		Namespace:     tenantNamespace,
	})

	instances := persistencemock.NewInstanceRepositoryMock()
	workspaces := persistencemock.NewWorkspaceRepositoryMock()
	bindings := persistencemock.NewWorkspaceClusterBindingRepositoryMock()

	// Workspace with CPU/memory quota but blank GPU quotas.
	ws := entity.NewWorkspace("ivanwu", "admin")
	ws.ID = "workspace-ivanwu"
	ws.K8sNamespace = tenantNamespace
	ws.K8sSAName = entity.ServiceAccountForNamespace(ws.K8sNamespace)
	ws.QuotaCPU = "8"
	ws.QuotaMemory = "32Gi"
	ws.QuotaGPU = ""
	ws.QuotaGPUMem = ""
	if err := workspaces.Create(tenantCtx, ws); err != nil {
		t.Fatalf("seed workspace: %v", err)
	}

	sharedCluster := &entity.Cluster{
		ID:         "k3s",
		Name:       "k3s",
		Host:       "https://k3s.invalid",
		Token:      "token",
		OwnerID:    "admin",
		Visibility: authz.VisibilityGlobalShared,
	}
	sharedRegistry := &entity.Registry{
		ID:         "registry-1",
		Name:       "harbor",
		URL:        "https://harbor.invalid",
		OwnerID:    "admin",
		Visibility: authz.VisibilityGlobalShared,
	}

	// The Helm stub reports a workload that requests one GPU.
	gpuEstimate := &repository.ResourceEstimate{
		Requests: repository.ResourceVector{
			CPU:         resource.MustParse("2"),
			Memory:      resource.MustParse("8Gi"),
			GPU:         1,
			GPUMemoryMB: 10000,
		},
	}
	helmStub := &stubHelmClient{estimate: gpuEstimate}
	ociStub := &stubOCIClient{}

	service := NewInstanceService(
		instances,
		&stubClusterRepo{cluster: sharedCluster},
		&stubRegistryRepo{registry: sharedRegistry},
		helmStub,
		ociStub,
		nil,
		bindings,
	)
	service.SetTenantProvisioning(workspaces, &recordingTenantClient{usage: &repository.ResourceQuotaUsage{}})

	candidate := entity.NewInstance("k3s", "vllm-qwen", "ocdp-u-ivanwu", sharedRegistry.ID, "library/vllm-serve", "vllm-serve", "0.1.0")
	candidate.SetValues(map[string]interface{}{
		"image": map[string]interface{}{
			"repository": "harbor.bwgdi.com/library/vllm-openai",
			"tag":        "v0.17.1",
		},
		"model": "Qwen/Qwen2.5-0.5B",
	})

	if err := service.CreateInstance(tenantCtx, candidate); !errors.Is(err, ErrQuotaExceeded) {
		t.Fatalf("expected GPU quota rejection, got %v", err)
	}

	stored, listErr := instances.List(tenantCtx)
	if listErr != nil {
		t.Fatalf("list instances: %v", listErr)
	}
	if len(stored) != 0 {
		t.Fatalf("expected quota rejection before instance DB create, got %#v", stored)
	}
	if calls := helmStub.installCalls; calls != 0 {
		t.Fatalf("expected Helm install not to be called, got %d calls", calls)
	}
	if pulls := ociStub.pullCalls; pulls != 1 {
		t.Fatalf("expected chart pull for quota rendering, got %d pulls", pulls)
	}
}
func waitForInstanceDeleted(t *testing.T, ctx context.Context, repo repository.InstanceRepository, id string) {
t.Helper()
@ -133,13 +338,19 @@ func (*stubClusterRepo) List(ctx context.Context) ([]*entity.Cluster, error) { r
// stubHelmClient is a configurable Helm client stub that serves canned
// results and counts Install/Upgrade invocations.
type stubHelmClient struct {
// NOTE(review): presumably the error returned by Uninstall; its use is outside this view.
uninstallErr error
// estimate is returned by EstimateInstanceResources when non-nil.
estimate *repository.ResourceEstimate
// values is returned as-is by GetValues.
values map[string]interface{}
// installCalls counts Install invocations.
installCalls int
// upgradeCalls counts Upgrade invocations.
upgradeCalls int
}
func (*stubHelmClient) Install(ctx context.Context, cluster *entity.Cluster, instance *entity.Instance) error {
// Install records the call in installCalls and always succeeds.
func (s *stubHelmClient) Install(ctx context.Context, cluster *entity.Cluster, instance *entity.Instance) error {
s.installCalls++
return nil
}
func (*stubHelmClient) Upgrade(ctx context.Context, cluster *entity.Cluster, instance *entity.Instance) error {
// Upgrade records the call in upgradeCalls and always succeeds.
func (s *stubHelmClient) Upgrade(ctx context.Context, cluster *entity.Cluster, instance *entity.Instance) error {
s.upgradeCalls++
return nil
}
@ -163,13 +374,116 @@ func (*stubHelmClient) List(ctx context.Context, cluster *entity.Cluster, namesp
return nil, nil
}
func (*stubHelmClient) GetValues(ctx context.Context, cluster *entity.Cluster, releaseName, namespace string) (map[string]interface{}, error) {
return nil, nil
// GetValues returns the configured values map (nil when none was set) for
// any release and never fails. The map is returned without copying.
func (s *stubHelmClient) GetValues(ctx context.Context, cluster *entity.Cluster, releaseName, namespace string) (map[string]interface{}, error) {
return s.values, nil
}
// GetChartDefaultValues ignores the chart path and returns no defaults and
// no error.
func (*stubHelmClient) GetChartDefaultValues(chartPath string) (map[string]interface{}, error) {
return nil, nil
}
// EstimateInstanceResources returns the configured estimate, or an empty
// estimate when none was set. It never fails.
func (s *stubHelmClient) EstimateInstanceResources(ctx context.Context, cluster *entity.Cluster, instance *entity.Instance) (*repository.ResourceEstimate, error) {
	if s.estimate == nil {
		return &repository.ResourceEstimate{}, nil
	}
	return s.estimate, nil
}
// stubRegistryRepo is a registry repository stub that holds at most one
// registry. The stored pointer is shared with callers, not copied.
type stubRegistryRepo struct {
	registry *entity.Registry
}

// Create replaces the held registry with the given one.
func (s *stubRegistryRepo) Create(ctx context.Context, registry *entity.Registry) error {
	s.registry = registry
	return nil
}

// GetByID returns the held registry when its ID matches, otherwise
// entity.ErrRegistryNotFound.
func (s *stubRegistryRepo) GetByID(ctx context.Context, id string) (*entity.Registry, error) {
	if s.registry == nil || s.registry.ID != id {
		return nil, entity.ErrRegistryNotFound
	}
	return s.registry, nil
}

// GetByName returns the held registry when its Name matches, otherwise
// entity.ErrRegistryNotFound.
func (s *stubRegistryRepo) GetByName(ctx context.Context, name string) (*entity.Registry, error) {
	if s.registry == nil || s.registry.Name != name {
		return nil, entity.ErrRegistryNotFound
	}
	return s.registry, nil
}

// Update replaces the held registry with the given one.
func (s *stubRegistryRepo) Update(ctx context.Context, registry *entity.Registry) error {
	s.registry = registry
	return nil
}

// Delete drops the held registry when its ID matches, otherwise it returns
// entity.ErrRegistryNotFound.
func (s *stubRegistryRepo) Delete(ctx context.Context, id string) error {
	if s.registry == nil || s.registry.ID != id {
		return entity.ErrRegistryNotFound
	}
	s.registry = nil
	return nil
}

// List returns the held registry as a one-element slice, or nil when the
// stub is empty.
func (s *stubRegistryRepo) List(ctx context.Context) ([]*entity.Registry, error) {
	if s.registry != nil {
		return []*entity.Registry{s.registry}, nil
	}
	return nil, nil
}
// stubOCIClient is an OCI client stub: every operation succeeds with a
// zero-value result, and PullArtifact additionally counts its invocations.
type stubOCIClient struct {
// pullCalls counts PullArtifact invocations.
pullCalls int
}
// ListRepositories returns no repositories and no error.
func (*stubOCIClient) ListRepositories(ctx context.Context, registry *entity.Registry, artifactType string) ([]string, error) {
return nil, nil
}
// ListArtifacts returns no artifacts and no error.
func (*stubOCIClient) ListArtifacts(ctx context.Context, registry *entity.Registry, repositoryName, mediaTypeFilter string) ([]*entity.Artifact, error) {
return nil, nil
}
// GetArtifact returns a nil artifact and no error.
func (*stubOCIClient) GetArtifact(ctx context.Context, registry *entity.Registry, repositoryName, reference string) (*entity.Artifact, error) {
return nil, nil
}
// GetValuesSchema returns an empty schema string and no error.
func (*stubOCIClient) GetValuesSchema(ctx context.Context, registry *entity.Registry, repositoryName, reference string) (string, error) {
return "", nil
}
// GetValuesYAML returns an empty YAML string and no error.
func (*stubOCIClient) GetValuesYAML(ctx context.Context, registry *entity.Registry, repositoryName, reference string) (string, error) {
return "", nil
}
// PullArtifact records the call in pullCalls and succeeds without writing
// anything to destPath.
func (s *stubOCIClient) PullArtifact(ctx context.Context, registry *entity.Registry, repositoryName, reference, destPath string) error {
s.pullCalls++
return nil
}
// PushArtifact is a no-op that always succeeds.
func (*stubOCIClient) PushArtifact(ctx context.Context, registry *entity.Registry, repositoryName, tag, sourcePath string) error {
return nil
}
// CheckHealth always reports the registry as healthy.
func (*stubOCIClient) CheckHealth(ctx context.Context, registry *entity.Registry) error {
return nil
}
// stubScaleClient is a scale client stub that keeps a single replica count,
// shared by every cluster, namespace and release.
type stubScaleClient struct {
// replicas is the count reported by GetDeploymentReplicas and overwritten by ScaleDeployment.
replicas int32
}
// GetDeploymentReplicas returns the stored replica count and never fails.
func (s *stubScaleClient) GetDeploymentReplicas(ctx context.Context, cluster *entity.Cluster, namespace, releaseName string) (int32, error) {
return s.replicas, nil
}
// ScaleDeployment stores the requested replica count and always succeeds.
func (s *stubScaleClient) ScaleDeployment(ctx context.Context, cluster *entity.Cluster, namespace, releaseName string, replicas int32) error {
s.replicas = replicas
return nil
}
// Compile-time checks that the test stubs satisfy the interfaces they stand
// in for.
var (
	_ repository.ClusterRepository  = (*stubClusterRepo)(nil)
	_ repository.RegistryRepository = (*stubRegistryRepo)(nil)
	_ repository.HelmClient         = (*stubHelmClient)(nil)
	_ repository.OCIClient          = (*stubOCIClient)(nil)
	_ ScaleClient                   = (*stubScaleClient)(nil)
)

View File

@ -3,6 +3,7 @@ package service
import (
"context"
"fmt"
"sort"
"github.com/ocdp/cluster-service/internal/domain/entity"
"github.com/ocdp/cluster-service/internal/domain/repository"
@ -13,16 +14,22 @@ import (
type MonitoringService struct {
// NOTE(review): presumably used to resolve clusters for the monitoring endpoints; its use is outside this view.
clusterRepo repository.ClusterRepository
// metricsClient supplies per-pod resource allocations (GetPodResourceAllocations).
metricsClient repository.MetricsClient
// instanceRepo lists a cluster's instances so pods can be attributed to owners.
instanceRepo repository.InstanceRepository
// userRepo lists users so admins also see rows for users with no usage.
userRepo repository.UserRepository
}
// NewMonitoringService creates a monitoring service wired to the given
// cluster repository, metrics client, instance repository and user
// repository. The dependencies are stored as passed, without validation.
func NewMonitoringService(
	clusterRepo repository.ClusterRepository,
	metricsClient repository.MetricsClient,
	instanceRepo repository.InstanceRepository,
	userRepo repository.UserRepository,
) *MonitoringService {
	svc := new(MonitoringService)
	svc.clusterRepo = clusterRepo
	svc.metricsClient = metricsClient
	svc.instanceRepo = instanceRepo
	svc.userRepo = userRepo
	return svc
}
@ -43,6 +50,8 @@ func (s *MonitoringService) GetClusterMonitoring(ctx context.Context, clusterID
if err != nil {
return nil, fmt.Errorf("failed to get cluster metrics: %w", err)
}
s.enrichResourceUsage(ctx, principal, metrics)
s.scopeTenantMetrics(principal, metrics)
return metrics, nil
}
@ -75,12 +84,310 @@ func (s *MonitoringService) ListClusterMonitoring(ctx context.Context) ([]*entit
Status: "unknown",
}
}
s.enrichResourceUsage(ctx, principal, metrics)
s.scopeTenantMetrics(principal, metrics)
result = append(result, metrics)
}
return result, nil
}
// enrichResourceUsage fills the request/limit totals and the per-user usage
// rows of metrics from the pod resource allocations of the cluster.
//
// Only pods that match an instance the principal may read (per
// canReadMonitoringInstance) are counted, both in the totals and in the
// per-user rows. The function is best-effort: when a dependency is missing or
// a lookup fails it prints a warning to stdout where applicable and leaves
// the totals untouched. On every path it finishes by calling
// addVisibleUserRows.
func (s *MonitoringService) enrichResourceUsage(ctx context.Context, principal *authz.Principal, metrics *entity.ClusterMetrics) {
if metrics == nil || s.instanceRepo == nil || s.metricsClient == nil {
s.addVisibleUserRows(ctx, principal, metrics)
return
}
instances, err := s.instanceRepo.ListByCluster(ctx, metrics.ClusterID)
if err != nil {
fmt.Printf("Warning: failed to list instances for cluster %s resource usage: %v\n", metrics.ClusterID, err)
s.addVisibleUserRows(ctx, principal, metrics)
return
}
allocations, err := s.metricsClient.GetPodResourceAllocations(ctx, metrics.ClusterID)
if err != nil {
fmt.Printf("Warning: failed to list pod resource allocations for cluster %s: %v\n", metrics.ClusterID, err)
s.addVisibleUserRows(ctx, principal, metrics)
return
}
// Index readable instances by (namespace, name); if two instances share a
// key, the later one in the list wins.
visibleInstances := make(map[string]*entity.Instance)
for _, instance := range instances {
if instance == nil || !canReadMonitoringInstance(principal, instance) {
continue
}
key := monitoringInstanceKey(instance.Namespace, instance.Name)
visibleInstances[key] = instance
}
// usageAccumulator collects the usage of one owner across all of their pods.
type usageAccumulator struct {
userID string
username string
workspaceID string
allocation entity.ResourceAllocation
podCount int
instances map[string]struct{}
}
// byUser is keyed by the instance OwnerID.
byUser := make(map[string]*usageAccumulator)
total := entity.ResourceAllocation{}
for _, pod := range allocations {
if pod == nil {
continue
}
// Pods that do not belong to a readable instance are ignored entirely.
instance := visibleInstances[monitoringInstanceKey(pod.Namespace, pod.InstanceName)]
if instance == nil {
continue
}
total = addResourceAllocation(total, pod.Allocation)
// Prefer the username already present on the instance; otherwise resolve it.
username := instance.OwnerUsername
if username == "" {
username = s.usernameForOwner(ctx, instance.OwnerID, principal)
}
acc := byUser[instance.OwnerID]
if acc == nil {
acc = &usageAccumulator{
userID: instance.OwnerID,
username: username,
workspaceID: instance.WorkspaceID,
instances: map[string]struct{}{},
}
byUser[instance.OwnerID] = acc
}
// Backfill a username that was empty when the accumulator was created.
if acc.username == "" {
acc.username = username
}
acc.allocation = addResourceAllocation(acc.allocation, pod.Allocation)
acc.podCount++
// The set of instance IDs yields the distinct instance count per owner.
acc.instances[instance.ID] = struct{}{}
}
metrics.CPURequests = formatCPUAllocation(total.CPURequestsMilli)
metrics.CPULimits = formatCPUAllocation(total.CPULimitsMilli)
metrics.MemoryRequests = formatMemoryAllocation(total.MemoryRequestsBytes)
metrics.MemoryLimits = formatMemoryAllocation(total.MemoryLimitsBytes)
metrics.GPURequests = total.GPURequests
metrics.GPULimits = total.GPULimits
metrics.GPUMemoryRequestsMB = total.GPUMemoryRequestsMB
metrics.GPUMemoryLimitsMB = total.GPUMemoryLimitsMB
// The "allocated" GPU figures mirror the request totals.
metrics.AllocatedGPU = total.GPURequests
metrics.AllocatedGPUMemoryMB = total.GPUMemoryRequestsMB
// Emit rows in a deterministic order: by username, then by user ID.
userIDs := make([]string, 0, len(byUser))
for userID := range byUser {
userIDs = append(userIDs, userID)
}
sort.Slice(userIDs, func(i, j int) bool {
left := byUser[userIDs[i]]
right := byUser[userIDs[j]]
if left.username == right.username {
return left.userID < right.userID
}
return left.username < right.username
})
usage := make([]entity.UserResourceUsage, 0, len(userIDs))
for _, userID := range userIDs {
acc := byUser[userID]
usage = append(usage, entity.UserResourceUsage{
UserID: acc.userID,
Username: acc.username,
WorkspaceID: acc.workspaceID,
InstanceCount: len(acc.instances),
PodCount: acc.podCount,
CPURequests: formatCPUAllocation(acc.allocation.CPURequestsMilli),
CPULimits: formatCPUAllocation(acc.allocation.CPULimitsMilli),
MemoryRequests: formatMemoryAllocation(acc.allocation.MemoryRequestsBytes),
MemoryLimits: formatMemoryAllocation(acc.allocation.MemoryLimitsBytes),
GPURequests: acc.allocation.GPURequests,
GPULimits: acc.allocation.GPULimits,
GPUMemoryRequestsMB: acc.allocation.GPUMemoryRequestsMB,
GPUMemoryLimitsMB: acc.allocation.GPUMemoryLimitsMB,
})
}
// This replaces any rows that were previously on metrics.
metrics.ResourceUsageByUser = usage
s.addVisibleUserRows(ctx, principal, metrics)
}
// addVisibleUserRows pads metrics.ResourceUsageByUser with zero-usage rows for
// users the principal may see but who have no row yet.
//
// A non-admin principal only gets a placeholder row for itself. For an admin,
// every active user with the ordinary user role returned by the user
// repository gets a placeholder row, and the combined list is re-sorted by
// username with the user ID as tie-breaker. A missing repository or a listing
// error leaves the rows as they were (the error is only printed as a warning).
//
// Placeholder rows use the same formatters as computed rows so that a zero
// allocation is rendered identically everywhere.
func (s *MonitoringService) addVisibleUserRows(ctx context.Context, principal *authz.Principal, metrics *entity.ClusterMetrics) {
	if principal == nil || metrics == nil {
		return
	}

	// Track which users already have a row so nobody is listed twice.
	existing := make(map[string]struct{}, len(metrics.ResourceUsageByUser))
	for _, row := range metrics.ResourceUsageByUser {
		if row.UserID != "" {
			existing[row.UserID] = struct{}{}
		}
	}

	appendEmpty := func(userID, username, workspaceID string) {
		if userID == "" {
			return
		}
		if _, ok := existing[userID]; ok {
			return
		}
		metrics.ResourceUsageByUser = append(metrics.ResourceUsageByUser, entity.UserResourceUsage{
			UserID:         userID,
			Username:       username,
			WorkspaceID:    workspaceID,
			InstanceCount:  0,
			PodCount:       0,
			CPURequests:    formatCPUAllocation(0),
			CPULimits:      formatCPUAllocation(0),
			MemoryRequests: formatMemoryAllocation(0),
			MemoryLimits:   formatMemoryAllocation(0),
		})
		existing[userID] = struct{}{}
	}

	if !principal.IsAdmin() {
		appendEmpty(principal.UserID, principal.Username, principal.WorkspaceID)
		return
	}
	if s.userRepo == nil {
		return
	}

	users, err := s.userRepo.List(ctx)
	if err != nil {
		// Best effort: the rows computed from real allocations are still valid.
		fmt.Printf("Warning: failed to list users for monitoring rows: %v\n", err)
		return
	}
	for _, user := range users {
		if user == nil || user.Role != authz.RoleUser || !user.IsActive {
			continue
		}
		appendEmpty(user.ID, user.Username, user.WorkspaceID)
	}

	// Appending placeholders may have broken the ordering; restore it.
	sort.Slice(metrics.ResourceUsageByUser, func(i, j int) bool {
		left := metrics.ResourceUsageByUser[i]
		right := metrics.ResourceUsageByUser[j]
		if left.Username == right.Username {
			return left.UserID < right.UserID
		}
		return left.Username < right.Username
	})
}
// scopeTenantMetrics strips cluster-wide information from metrics for a
// non-admin principal and keeps only figures derived from that principal's
// own usage rows.
//
// It is a no-op for a nil principal, an admin principal or nil metrics.
// Node lists, cluster totals, usage percentages and per-node maxima are
// cleared. Pod count and used GPU are recomputed from the principal's rows,
// and ResourceUsageByUser is reduced to those rows. When the principal owns
// no instances at all, the request/limit summary fields are cleared as well.
func (s *MonitoringService) scopeTenantMetrics(principal *authz.Principal, metrics *entity.ClusterMetrics) {
	if principal == nil || principal.IsAdmin() || metrics == nil {
		return
	}

	// Only the figures that are actually written back are accumulated.
	podCount := 0
	instanceCount := 0
	usedGPU := 0
	for _, usage := range metrics.ResourceUsageByUser {
		if usage.UserID != principal.UserID {
			continue
		}
		podCount += usage.PodCount
		instanceCount += usage.InstanceCount
		usedGPU += int(usage.GPURequests)
	}

	// Hide everything that describes the cluster as a whole.
	metrics.NodeCount = 0
	metrics.Nodes = nil
	metrics.PodCount = podCount
	metrics.TotalCPU = ""
	metrics.TotalMemory = ""
	metrics.TotalGPU = 0
	metrics.UsedCPU = metrics.CPURequests
	metrics.UsedMemory = metrics.MemoryRequests
	metrics.UsedGPU = usedGPU
	metrics.CPUUsage = 0
	metrics.MemoryUsage = 0
	metrics.GPUUsage = 0
	metrics.MaxNodeCPU = ""
	metrics.MaxNodeMemory = ""
	metrics.MaxNodeGPU = 0
	metrics.MaxNodeCPUUsage = 0
	metrics.MaxNodeMemUsage = 0
	metrics.MaxNodeGPUUsage = 0
	metrics.ResourceUsageByUser = filterSelfUsage(principal.UserID, metrics.ResourceUsageByUser)

	// A principal without instances must not see any allocation summary.
	if instanceCount == 0 {
		metrics.CPURequests = ""
		metrics.CPULimits = ""
		metrics.MemoryRequests = ""
		metrics.MemoryLimits = ""
		metrics.GPURequests = 0
		metrics.GPULimits = 0
		metrics.GPUMemoryRequestsMB = 0
		metrics.GPUMemoryLimitsMB = 0
		metrics.AllocatedGPU = 0
		metrics.AllocatedGPUMemoryMB = 0
	}
}
// filterSelfUsage returns a new slice holding only the rows of usage whose
// UserID equals userID, in their original order. The result is never nil.
func filterSelfUsage(userID string, usage []entity.UserResourceUsage) []entity.UserResourceUsage {
	own := make([]entity.UserResourceUsage, 0, len(usage))
	for i := range usage {
		if usage[i].UserID != userID {
			continue
		}
		own = append(own, usage[i])
	}
	return own
}
// canReadMonitoringInstance reports whether principal may see monitoring data
// for instance. Admins may read every instance; other principals only those
// they own inside their own workspace. A nil principal or instance is denied.
func canReadMonitoringInstance(principal *authz.Principal, instance *entity.Instance) bool {
	switch {
	case principal == nil, instance == nil:
		return false
	case principal.IsAdmin():
		return true
	default:
		sameWorkspace := instance.WorkspaceID == principal.WorkspaceID
		sameOwner := instance.OwnerID == principal.UserID
		return sameWorkspace && sameOwner
	}
}
// usernameForOwner resolves the username of ownerID. The principal's own name
// is used without a repository lookup when the IDs match. An empty string is
// returned for an empty ID, a missing repository, a lookup error or an
// unknown user.
func (s *MonitoringService) usernameForOwner(ctx context.Context, ownerID string, principal *authz.Principal) string {
	switch {
	case ownerID == "":
		return ""
	case principal != nil && ownerID == principal.UserID:
		return principal.Username
	case s.userRepo == nil:
		return ""
	}

	owner, err := s.userRepo.GetByID(ctx, ownerID)
	if err == nil && owner != nil {
		return owner.Username
	}
	return ""
}
// monitoringInstanceKey builds the "<namespace>/<name>" lookup key for an
// instance.
func monitoringInstanceKey(namespace, name string) string {
	return fmt.Sprintf("%s/%s", namespace, name)
}
// addResourceAllocation returns the field-wise sum of two allocations. Only
// the CPU, memory, GPU and GPU-memory request/limit fields are populated in
// the result.
func addResourceAllocation(left, right entity.ResourceAllocation) entity.ResourceAllocation {
	var sum entity.ResourceAllocation

	sum.CPURequestsMilli = left.CPURequestsMilli + right.CPURequestsMilli
	sum.CPULimitsMilli = left.CPULimitsMilli + right.CPULimitsMilli

	sum.MemoryRequestsBytes = left.MemoryRequestsBytes + right.MemoryRequestsBytes
	sum.MemoryLimitsBytes = left.MemoryLimitsBytes + right.MemoryLimitsBytes

	sum.GPURequests = left.GPURequests + right.GPURequests
	sum.GPULimits = left.GPULimits + right.GPULimits

	sum.GPUMemoryRequestsMB = left.GPUMemoryRequestsMB + right.GPUMemoryRequestsMB
	sum.GPUMemoryLimitsMB = left.GPUMemoryLimitsMB + right.GPUMemoryLimitsMB

	return sum
}
// formatCPUAllocation renders a CPU amount given in millicores as a number of
// cores with two decimals, e.g. 500 -> "0.50 cores".
func formatCPUAllocation(milli int64) string {
	cores := float64(milli) / 1000.0
	return fmt.Sprintf("%.2f cores", cores)
}
// formatMemoryAllocation renders a byte count with binary (1024-based) units.
// Values below 1024 are printed as plain bytes ("512 B"); larger values are
// scaled to the largest fitting unit with one decimal ("1.5 KiB").
func formatMemoryAllocation(bytes int64) string {
	const unit = 1024
	const prefixes = "KMGTPE"

	if bytes < unit {
		return fmt.Sprintf("%d B", bytes)
	}

	// Grow the divisor one unit step at a time until the remainder fits.
	divisor := int64(unit)
	prefix := 0
	remaining := bytes / unit
	for remaining >= unit {
		divisor *= unit
		prefix++
		remaining /= unit
	}

	scaled := float64(bytes) / float64(divisor)
	return fmt.Sprintf("%.1f %ciB", scaled, prefixes[prefix])
}
// GetMonitoringSummary 获取监控汇总信息
func (s *MonitoringService) GetMonitoringSummary(ctx context.Context) (*entity.MonitoringSummary, error) {
// 获取所有集群监控数据
@ -123,6 +430,9 @@ func (s *MonitoringService) GetNodeMetrics(ctx context.Context, clusterID string
if !authz.CanReadResource(principal, cluster.WorkspaceID, cluster.OwnerID, cluster.Visibility) {
return nil, entity.ErrClusterNotFound
}
if !principal.IsAdmin() {
return nil, entity.ErrForbidden
}
nodes, err := s.metricsClient.GetNodeMetrics(ctx, clusterID)
if err != nil {
return nil, fmt.Errorf("failed to get node metrics: %w", err)

View File

@ -0,0 +1,228 @@
package service
import (
"context"
"testing"
"time"
persistencemock "github.com/ocdp/cluster-service/internal/adapter/output/persistence/mock"
"github.com/ocdp/cluster-service/internal/domain/entity"
"github.com/ocdp/cluster-service/internal/pkg/authz"
)
// TestListClusterMonitoringAggregatesResourceUsageForAdmin verifies that an
// admin sees the GPU allocation summed over all seeded owners and one usage
// row per owner, ordered alice then bob.
func TestListClusterMonitoringAggregatesResourceUsageForAdmin(t *testing.T) {
	admin := &authz.Principal{
		UserID:      "admin-1",
		Username:    "admin",
		Role:        authz.RoleAdmin,
		WorkspaceID: "workspace-admin",
	}
	ctx := authz.WithPrincipal(context.Background(), admin)

	instanceRepo, userRepo := seedMonitoringOwners(t, ctx)
	clusterRepo := &monitoringClusterRepo{clusters: []*entity.Cluster{
		{ID: "cluster-1", Name: "cluster", Visibility: authz.VisibilityGlobalShared},
	}}
	metricsClient := &stubMetricsClient{allocations: monitoringAllocations()}
	svc := NewMonitoringService(clusterRepo, metricsClient, instanceRepo, userRepo)

	clusters, err := svc.ListClusterMonitoring(ctx)
	if err != nil {
		t.Fatalf("ListClusterMonitoring returned error: %v", err)
	}
	if len(clusters) != 1 {
		t.Fatalf("expected 1 cluster metric, got %d", len(clusters))
	}

	got := clusters[0]
	if got.AllocatedGPU != 3 || got.AllocatedGPUMemoryMB != 30000 {
		t.Fatalf("expected total GPU/gpumem allocation 3/30000, got %d/%d", got.AllocatedGPU, got.AllocatedGPUMemoryMB)
	}

	rows := got.ResourceUsageByUser
	if len(rows) != 2 {
		t.Fatalf("expected 2 user usage rows, got %d: %#v", len(rows), rows)
	}
	if first := rows[0]; first.Username != "alice" || first.GPURequests != 1 {
		t.Fatalf("expected alice GPU request row first, got %#v", first)
	}
	if second := rows[1]; second.Username != "bob" || second.GPURequests != 2 {
		t.Fatalf("expected bob GPU request row second, got %#v", second)
	}
}
// TestListClusterMonitoringFiltersResourceUsageForOrdinaryUser verifies that
// an ordinary user only sees their own allocation and usage row, and that
// cluster-wide fields (nodes, totals) are blanked out.
func TestListClusterMonitoringFiltersResourceUsageForOrdinaryUser(t *testing.T) {
	alice := &authz.Principal{
		UserID:      "user-1",
		Username:    "alice",
		Role:        authz.RoleUser,
		WorkspaceID: "workspace-1",
	}
	ctx := authz.WithPrincipal(context.Background(), alice)

	instanceRepo, userRepo := seedMonitoringOwners(t, ctx)
	clusterRepo := &monitoringClusterRepo{clusters: []*entity.Cluster{
		{ID: "cluster-1", Name: "cluster", Visibility: authz.VisibilityGlobalShared},
	}}
	metricsClient := &stubMetricsClient{allocations: monitoringAllocations()}
	svc := NewMonitoringService(clusterRepo, metricsClient, instanceRepo, userRepo)

	clusters, err := svc.ListClusterMonitoring(ctx)
	if err != nil {
		t.Fatalf("ListClusterMonitoring returned error: %v", err)
	}

	got := clusters[0]
	if got.AllocatedGPU != 1 || got.AllocatedGPUMemoryMB != 10000 {
		t.Fatalf("expected ordinary user allocation to be scoped to alice, got %d/%d", got.AllocatedGPU, got.AllocatedGPUMemoryMB)
	}

	rows := got.ResourceUsageByUser
	if len(rows) != 1 {
		t.Fatalf("expected only alice usage row, got %d: %#v", len(rows), rows)
	}
	if own := rows[0]; own.UserID != "user-1" || own.Username != "alice" {
		t.Fatalf("expected alice usage row, got %#v", own)
	}

	// Cluster-wide figures must not leak to ordinary users.
	if got.NodeCount != 0 || len(got.Nodes) != 0 || got.TotalCPU != "" || got.TotalMemory != "" {
		t.Fatalf("expected ordinary user cluster-wide metrics to be sanitized, got nodes=%d/%d totalCPU=%q totalMemory=%q", got.NodeCount, len(got.Nodes), got.TotalCPU, got.TotalMemory)
	}
	if got.PodCount != 1 {
		t.Fatalf("expected ordinary user pod count to be self scoped, got %d", got.PodCount)
	}
}
// TestGetNodeMetricsForbiddenForOrdinaryUser verifies that node-level metrics
// are refused with entity.ErrForbidden for a non-admin principal, even on a
// globally shared cluster.
func TestGetNodeMetricsForbiddenForOrdinaryUser(t *testing.T) {
	alice := &authz.Principal{
		UserID:      "user-1",
		Username:    "alice",
		Role:        authz.RoleUser,
		WorkspaceID: "workspace-1",
	}
	ctx := authz.WithPrincipal(context.Background(), alice)

	clusterRepo := &monitoringClusterRepo{clusters: []*entity.Cluster{
		{ID: "cluster-1", Name: "cluster", Visibility: authz.VisibilityGlobalShared},
	}}
	metricsClient := &stubMetricsClient{allocations: monitoringAllocations()}
	// Instance and user repositories are not needed for this code path.
	svc := NewMonitoringService(clusterRepo, metricsClient, nil, nil)

	if _, err := svc.GetNodeMetrics(ctx, "cluster-1"); err != entity.ErrForbidden {
		t.Fatalf("expected ordinary user node metrics to be forbidden, got %v", err)
	}
}
// seedMonitoringOwners creates mock instance and user repositories populated
// with two users (alice, bob) and one instance owned by each on cluster-1.
// Any seeding failure aborts the calling test.
func seedMonitoringOwners(t *testing.T, ctx context.Context) (*persistencemock.InstanceRepositoryMock, *persistencemock.UserRepositoryMock) {
	t.Helper()

	userRepo := persistencemock.NewUserRepositoryMock().(*persistencemock.UserRepositoryMock)
	instanceRepo := persistencemock.NewInstanceRepositoryMock().(*persistencemock.InstanceRepositoryMock)

	owners := []*entity.User{
		{ID: "user-1", Username: "alice", PasswordHash: "hash", Role: "user", WorkspaceID: "workspace-1"},
		{ID: "user-2", Username: "bob", PasswordHash: "hash", Role: "user", WorkspaceID: "workspace-2"},
	}
	for _, owner := range owners {
		err := userRepo.Create(ctx, owner)
		if err != nil {
			t.Fatalf("failed to seed user %s: %v", owner.ID, err)
		}
	}

	workloads := []*entity.Instance{
		{ID: "inst-1", ClusterID: "cluster-1", Name: "alice-app", Namespace: "ocdp-u-alice", WorkspaceID: "workspace-1", OwnerID: "user-1"},
		{ID: "inst-2", ClusterID: "cluster-1", Name: "bob-app", Namespace: "ocdp-u-bob", WorkspaceID: "workspace-2", OwnerID: "user-2"},
	}
	for _, workload := range workloads {
		err := instanceRepo.Create(ctx, workload)
		if err != nil {
			t.Fatalf("failed to seed instance %s: %v", workload.ID, err)
		}
	}

	return instanceRepo, userRepo
}
// monitoringAllocations returns the fixed pod allocations used by the
// monitoring tests: one pod for alice-app (1 GPU / 10000 MB) and one pod for
// bob-app (2 GPUs / 20000 MB), both on cluster-1.
func monitoringAllocations() []*entity.PodResourceAllocation {
	const gib = 1024 * 1024 * 1024

	alicePod := &entity.PodResourceAllocation{
		ClusterID:    "cluster-1",
		Namespace:    "ocdp-u-alice",
		PodName:      "alice-app-0",
		InstanceName: "alice-app",
		Allocation: entity.ResourceAllocation{
			CPURequestsMilli:    500,
			CPULimitsMilli:      1000,
			MemoryRequestsBytes: gib,
			MemoryLimitsBytes:   2 * gib,
			GPURequests:         1,
			GPULimits:           1,
			GPUMemoryRequestsMB: 10000,
			GPUMemoryLimitsMB:   10000,
		},
	}
	bobPod := &entity.PodResourceAllocation{
		ClusterID:    "cluster-1",
		Namespace:    "ocdp-u-bob",
		PodName:      "bob-app-0",
		InstanceName: "bob-app",
		Allocation: entity.ResourceAllocation{
			CPURequestsMilli:    2000,
			CPULimitsMilli:      4000,
			MemoryRequestsBytes: 4 * gib,
			MemoryLimitsBytes:   8 * gib,
			GPURequests:         2,
			GPULimits:           2,
			GPUMemoryRequestsMB: 20000,
			GPUMemoryLimitsMB:   20000,
		},
	}

	return []*entity.PodResourceAllocation{alicePod, bobPod}
}
// monitoringClusterRepo is a minimal in-memory cluster repository for the
// monitoring tests. Lookups scan the slice; Update and Delete are no-ops.
type monitoringClusterRepo struct {
	clusters []*entity.Cluster
}

// Create appends cluster to the in-memory list.
func (r *monitoringClusterRepo) Create(_ context.Context, cluster *entity.Cluster) error {
	r.clusters = append(r.clusters, cluster)
	return nil
}

// GetByID returns the first cluster with the given ID or
// entity.ErrClusterNotFound.
func (r *monitoringClusterRepo) GetByID(_ context.Context, id string) (*entity.Cluster, error) {
	for i := range r.clusters {
		if candidate := r.clusters[i]; candidate.ID == id {
			return candidate, nil
		}
	}
	return nil, entity.ErrClusterNotFound
}

// GetByName returns the first cluster with the given name or
// entity.ErrClusterNotFound.
func (r *monitoringClusterRepo) GetByName(_ context.Context, name string) (*entity.Cluster, error) {
	for i := range r.clusters {
		if candidate := r.clusters[i]; candidate.Name == name {
			return candidate, nil
		}
	}
	return nil, entity.ErrClusterNotFound
}

// Update does nothing and always succeeds.
func (r *monitoringClusterRepo) Update(_ context.Context, _ *entity.Cluster) error {
	return nil
}

// Delete does nothing and always succeeds.
func (r *monitoringClusterRepo) Delete(_ context.Context, _ string) error {
	return nil
}

// List returns the stored clusters as-is.
func (r *monitoringClusterRepo) List(_ context.Context) ([]*entity.Cluster, error) {
	return r.clusters, nil
}
// stubMetricsClient is a canned metrics client for the monitoring tests. It
// reports a fixed healthy three-node cluster and serves the configured pod
// allocations.
type stubMetricsClient struct {
	allocations []*entity.PodResourceAllocation
}

// GetClusterMetrics returns fixed cluster-wide figures for clusterID; the
// values are deliberately non-zero so tests can detect whether they were
// sanitized.
func (c *stubMetricsClient) GetClusterMetrics(_ context.Context, clusterID string) (*entity.ClusterMetrics, error) {
	canned := entity.ClusterMetrics{
		ClusterID:   clusterID,
		ClusterName: "cluster",
		Status:      "healthy",
		NodeCount:   3,
		PodCount:    99,
		TotalCPU:    "48 cores",
		TotalMemory: "256Gi",
		Nodes:       []entity.NodeMetrics{{NodeName: "node-a"}},
		LastCheck:   time.Now(),
	}
	return &canned, nil
}

// GetNodeMetrics returns no nodes and no error.
func (c *stubMetricsClient) GetNodeMetrics(_ context.Context, _ string) ([]*entity.NodeMetrics, error) {
	return nil, nil
}

// GetPodResourceAllocations returns the allocations the stub was built with,
// regardless of the cluster ID.
func (c *stubMetricsClient) GetPodResourceAllocations(_ context.Context, _ string) ([]*entity.PodResourceAllocation, error) {
	return c.allocations, nil
}

View File

@ -0,0 +1,400 @@
package service
import (
"context"
"errors"
"fmt"
"io"
"sort"
"strconv"
"strings"
"github.com/ocdp/cluster-service/internal/domain/entity"
"github.com/ocdp/cluster-service/internal/domain/repository"
corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/resource"
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
"k8s.io/apimachinery/pkg/util/yaml"
)
// ErrQuotaExceeded is returned by the quota comparison helpers when at least
// one resource would exceed its hard limit.
var ErrQuotaExceeded = errors.New("quota exceeded")
// QuotaExceededResource describes one resource whose required amount is above
// its hard limit.
type QuotaExceededResource struct {
// Name is the quota key, e.g. "requests.cpu".
Name string
// Required is the amount that would be needed, rendered as a string.
Required string
// Hard is the configured hard limit, rendered as a string.
Hard string
}
// QuotaPrecheckResult is the outcome of comparing a resource estimate with a
// quota.
type QuotaPrecheckResult struct {
// Allowed is false as soon as one resource exceeds its limit.
Allowed bool
// Required holds the amounts the comparison was evaluated against.
Required repository.ResourceEstimate
// Hard holds the hard limits extracted from the quota.
Hard repository.ResourceVector
// Exceeded lists the offending resources; empty when Allowed is true.
Exceeded []QuotaExceededResource
}
// QuotaPrecheckService estimates the resources of an instance through the
// Helm client and compares them with a quota.
type QuotaPrecheckService struct {
helmClient repository.HelmClient
}
// NewQuotaPrecheckService returns a QuotaPrecheckService backed by helmClient.
func NewQuotaPrecheckService(helmClient repository.HelmClient) *QuotaPrecheckService {
return &QuotaPrecheckService{helmClient: helmClient}
}
// EstimateAndCompare estimates the resources instance would need on cluster
// and compares them with the workspace quota.
//
// It fails when the service or its Helm client is nil, or when the estimate
// cannot be produced. A quota violation is reported as ErrQuotaExceeded
// together with the populated result.
func (s *QuotaPrecheckService) EstimateAndCompare(ctx context.Context, cluster *entity.Cluster, workspace *entity.Workspace, instance *entity.Instance) (*QuotaPrecheckResult, error) {
	if s == nil || s.helmClient == nil {
		return nil, errors.New("quota precheck requires helm client")
	}

	required, estimateErr := s.helmClient.EstimateInstanceResources(ctx, cluster, instance)
	if estimateErr != nil {
		return nil, estimateErr
	}

	return CompareWorkspaceQuota(workspace, required)
}
// EstimateAndCompareBinding estimates the resources of the target instance
// and, when current is non-nil, of the release being replaced, then checks
// them against the binding quota and its recorded usage.
//
// It fails when the service or its Helm client is nil, or when either
// estimate cannot be produced. A quota violation is reported as
// ErrQuotaExceeded together with the populated result.
func (s *QuotaPrecheckService) EstimateAndCompareBinding(ctx context.Context, cluster *entity.Cluster, binding *entity.WorkspaceClusterBinding, usage *repository.ResourceQuotaUsage, target *entity.Instance, current *entity.Instance) (*QuotaPrecheckResult, error) {
	if s == nil || s.helmClient == nil {
		return nil, errors.New("quota precheck requires helm client")
	}

	desired, err := s.helmClient.EstimateInstanceResources(ctx, cluster, target)
	if err != nil {
		return nil, err
	}

	// The running release is only estimated for updates of an existing instance.
	var running *repository.ResourceEstimate
	if current != nil {
		if running, err = s.helmClient.EstimateInstanceResources(ctx, cluster, current); err != nil {
			return nil, err
		}
	}

	return CompareBindingQuota(binding, usage, desired, running)
}
// CompareWorkspaceQuota checks estimate against the hard limits derived from
// workspace, without taking existing usage or a current release into account.
func CompareWorkspaceQuota(workspace *entity.Workspace, estimate *repository.ResourceEstimate) (*QuotaPrecheckResult, error) {
	hardLimits := resourceQuotaHard(workspace)
	return compareQuotaList(hardLimits, nil, estimate, nil)
}
// CompareBindingQuota checks targetEstimate against the hard limits derived
// from binding, on top of the recorded usage minus currentEstimate.
func CompareBindingQuota(binding *entity.WorkspaceClusterBinding, usage *repository.ResourceQuotaUsage, targetEstimate, currentEstimate *repository.ResourceEstimate) (*QuotaPrecheckResult, error) {
	hardLimits := bindingQuotaHard(binding)
	return compareQuotaList(hardLimits, usage, targetEstimate, currentEstimate)
}
// compareQuotaList evaluates whether deploying targetEstimate fits hardList.
//
// The amount checked per resource is: used - current + target, where "used"
// comes from usage (zero when nil), and current/target are the effective
// requests of the respective estimates. The subtraction never goes below
// zero. Only resources present in hardList are checked.
//
// The result is always returned; the error is ErrQuotaExceeded when any
// resource exceeds its limit, with the offenders sorted by name.
func compareQuotaList(hardList corev1.ResourceList, usage *repository.ResourceQuotaUsage, targetEstimate, currentEstimate *repository.ResourceEstimate) (*QuotaPrecheckResult, error) {
	if targetEstimate == nil {
		targetEstimate = &repository.ResourceEstimate{}
	}

	var used repository.ResourceVector
	if usage != nil {
		used = usage.Used
	}
	// Release what the current deployment holds before adding the target.
	withoutCurrent := subtractResourceVectorFloorZero(used, effectiveQuotaRequests(currentEstimate))
	required := addResourceVector(withoutCurrent, effectiveQuotaRequests(targetEstimate))

	result := &QuotaPrecheckResult{
		Allowed:  true,
		Required: repository.ResourceEstimate{Requests: required},
		Hard:     resourceVectorFromQuotaHard(hardList),
	}
	deny := func(name, need, limit string) {
		result.Allowed = false
		result.Exceeded = append(result.Exceeded, QuotaExceededResource{Name: name, Required: need, Hard: limit})
	}

	if limit, limited := hardList[corev1.ResourceName("requests.cpu")]; limited && required.CPU.Cmp(limit) > 0 {
		deny("requests.cpu", required.CPU.String(), limit.String())
	}
	if limit, limited := hardList[corev1.ResourceName("requests.memory")]; limited && required.Memory.Cmp(limit) > 0 {
		deny("requests.memory", required.Memory.String(), limit.String())
	}
	if limit, limited := hardList[corev1.ResourceName("requests.nvidia.com/gpu")]; limited && required.GPU > limit.Value() {
		deny("requests.nvidia.com/gpu", strconv.FormatInt(required.GPU, 10), limit.String())
	}
	if limit, limited := hardList[corev1.ResourceName("requests.nvidia.com/gpumem")]; limited && required.GPUMemoryMB > limit.Value() {
		deny("requests.nvidia.com/gpumem", strconv.FormatInt(required.GPUMemoryMB, 10), limit.String())
	}

	sort.Slice(result.Exceeded, func(a, b int) bool {
		return result.Exceeded[a].Name < result.Exceeded[b].Name
	})

	if result.Allowed {
		return result, nil
	}
	return result, ErrQuotaExceeded
}
// legacyCompareWorkspaceQuota compares the effective requests of estimate with
// the workspace hard limits, ignoring any existing usage. Unlike
// compareQuotaList it reports the unmodified estimate in result.Required.
//
// The result is always returned; the error is ErrQuotaExceeded when any
// resource exceeds its limit, with the offenders sorted by name.
func legacyCompareWorkspaceQuota(workspace *entity.Workspace, estimate *repository.ResourceEstimate) (*QuotaPrecheckResult, error) {
	if estimate == nil {
		estimate = &repository.ResourceEstimate{}
	}

	hardList := resourceQuotaHard(workspace)
	result := &QuotaPrecheckResult{
		Allowed:  true,
		Required: *estimate,
		Hard:     resourceVectorFromQuotaHard(hardList),
	}
	needed := effectiveQuotaRequests(estimate)

	deny := func(name, need, limit string) {
		result.Allowed = false
		result.Exceeded = append(result.Exceeded, QuotaExceededResource{Name: name, Required: need, Hard: limit})
	}

	if limit, limited := hardList[corev1.ResourceName("requests.cpu")]; limited && needed.CPU.Cmp(limit) > 0 {
		deny("requests.cpu", needed.CPU.String(), limit.String())
	}
	if limit, limited := hardList[corev1.ResourceName("requests.memory")]; limited && needed.Memory.Cmp(limit) > 0 {
		deny("requests.memory", needed.Memory.String(), limit.String())
	}
	if limit, limited := hardList[corev1.ResourceName("requests.nvidia.com/gpu")]; limited && needed.GPU > limit.Value() {
		deny("requests.nvidia.com/gpu", strconv.FormatInt(needed.GPU, 10), limit.String())
	}
	if limit, limited := hardList[corev1.ResourceName("requests.nvidia.com/gpumem")]; limited && needed.GPUMemoryMB > limit.Value() {
		deny("requests.nvidia.com/gpumem", strconv.FormatInt(needed.GPUMemoryMB, 10), limit.String())
	}

	sort.Slice(result.Exceeded, func(a, b int) bool {
		return result.Exceeded[a].Name < result.Exceeded[b].Name
	})

	if result.Allowed {
		return result, nil
	}
	return result, ErrQuotaExceeded
}
// effectiveQuotaRequests returns, per resource, the larger of the request and
// the limit in estimate, so that limits-only containers still count against
// request quotas. A nil estimate yields a zero vector.
func effectiveQuotaRequests(estimate *repository.ResourceEstimate) repository.ResourceVector {
	var effective repository.ResourceVector
	if estimate == nil {
		return effective
	}

	requests, limits := estimate.Requests, estimate.Limits
	effective.CPU = maxQuantity(requests.CPU, limits.CPU)
	effective.Memory = maxQuantity(requests.Memory, limits.Memory)
	effective.GPU = maxInt64(requests.GPU, limits.GPU)
	effective.GPUMemoryMB = maxInt64(requests.GPUMemoryMB, limits.GPUMemoryMB)
	return effective
}
// addResourceVector returns left + right, resource by resource, without
// modifying either argument.
//
// resource.Quantity arithmetic is in-place and a plain struct copy shares the
// big-decimal backing value when one is present, so the CPU and memory
// quantities are deep-copied before Add is called. Otherwise adding to the
// copy could silently change the caller's quantities.
func addResourceVector(left, right repository.ResourceVector) repository.ResourceVector {
	out := left
	out.CPU = left.CPU.DeepCopy()
	out.Memory = left.Memory.DeepCopy()

	out.CPU.Add(right.CPU)
	out.Memory.Add(right.Memory)
	out.GPU += right.GPU
	out.GPUMemoryMB += right.GPUMemoryMB
	return out
}
// subtractResourceVectorFloorZero returns left - right, resource by resource,
// clamping every negative result to zero. Neither argument is modified.
//
// resource.Quantity arithmetic is in-place and a plain struct copy shares the
// big-decimal backing value when one is present, so the CPU and memory
// quantities are deep-copied before Sub is called. Otherwise subtracting from
// the copy could silently change the caller's quantities.
func subtractResourceVectorFloorZero(left, right repository.ResourceVector) repository.ResourceVector {
	out := left
	out.CPU = left.CPU.DeepCopy()
	out.Memory = left.Memory.DeepCopy()

	out.CPU.Sub(right.CPU)
	if out.CPU.Sign() < 0 {
		out.CPU = resource.Quantity{}
	}
	out.Memory.Sub(right.Memory)
	if out.Memory.Sign() < 0 {
		out.Memory = resource.Quantity{}
	}
	out.GPU -= right.GPU
	if out.GPU < 0 {
		out.GPU = 0
	}
	out.GPUMemoryMB -= right.GPUMemoryMB
	if out.GPUMemoryMB < 0 {
		out.GPUMemoryMB = 0
	}
	return out
}
// maxQuantity returns the larger of two quantities; left wins on a tie.
func maxQuantity(left, right resource.Quantity) resource.Quantity {
	if left.Cmp(right) < 0 {
		return right
	}
	return left
}
// maxInt64 returns the larger of two int64 values.
func maxInt64(left, right int64) int64 {
	if right > left {
		return right
	}
	return left
}
// EstimateRenderedManifestResources walks every document of a rendered
// multi-document manifest and sums the container requests and limits of all
// recognised pod-bearing kinds. Documents without a kind, or of a kind that
// carries no pod template, are ignored.
//
// It returns an error when a document cannot be decoded.
func EstimateRenderedManifestResources(manifest string) (*repository.ResourceEstimate, error) {
	total := &repository.ResourceEstimate{}
	documents := yaml.NewYAMLOrJSONDecoder(strings.NewReader(manifest), 4096)

	for {
		var doc unstructured.Unstructured
		err := documents.Decode(&doc)
		if errors.Is(err, io.EOF) {
			// End of the manifest stream.
			return total, nil
		}
		if err != nil {
			return nil, fmt.Errorf("failed to decode rendered manifest: %w", err)
		}
		if doc.GetKind() == "" {
			continue
		}
		if podSpec, replicas, ok := podTemplateSpec(doc.Object); ok {
			addPodSpecResources(total, podSpec, replicas)
		}
	}
}
// resourceVectorFromQuotaHard converts the request-related entries of a
// quota hard list into a ResourceVector. Missing entries become zero values.
func resourceVectorFromQuotaHard(hard corev1.ResourceList) repository.ResourceVector {
	var limits repository.ResourceVector

	limits.CPU = hard[corev1.ResourceName("requests.cpu")]
	limits.Memory = hard[corev1.ResourceName("requests.memory")]

	gpuLimit := hard[corev1.ResourceName("requests.nvidia.com/gpu")]
	limits.GPU = gpuLimit.Value()

	gpuMemoryLimit := hard[corev1.ResourceName("requests.nvidia.com/gpumem")]
	limits.GPUMemoryMB = gpuMemoryLimit.Value()

	return limits
}
// bindingQuotaHard builds the quota hard list for a workspace/cluster
// binding from its CPU, memory, GPU and GPU-memory quota strings.
//
// Values that normalise to an empty string, or that cannot be parsed as a
// quantity, are left out of the list. A nil binding yields an empty list.
func bindingQuotaHard(binding *entity.WorkspaceClusterBinding) corev1.ResourceList {
	hard := corev1.ResourceList{}
	if binding == nil {
		return hard
	}

	// store records raw under key when it is a parseable, non-empty quantity.
	store := func(key corev1.ResourceName, raw string) {
		if raw == "" {
			return
		}
		parsed, err := resource.ParseQuantity(raw)
		if err != nil {
			return
		}
		hard[key] = parsed
	}

	store(corev1.ResourceName("requests.cpu"), normalizeStandardQuotaQuantity(binding.QuotaCPU))
	store(corev1.ResourceName("requests.memory"), normalizeStandardQuotaQuantity(binding.QuotaMemory))
	store(corev1.ResourceName("requests.nvidia.com/gpu"), normalizeStandardQuotaQuantity(binding.QuotaGPU))

	// GPU memory has its own normalisation, which can itself fail.
	if gpuMemory, err := normalizeGPUMemoryQuota(binding.QuotaGPUMem); err == nil {
		store(corev1.ResourceName("requests.nvidia.com/gpumem"), gpuMemory)
	}

	return hard
}
// podTemplateSpec locates the pod spec inside a decoded manifest object and
// returns it with the replica count to multiply by.
//
// Deployments, ReplicaSets, StatefulSets and ReplicationControllers use their
// declared replica count; Pods, DaemonSets, Jobs and CronJobs count as one.
// The boolean is false for any other kind or when the spec is missing.
func podTemplateSpec(obj map[string]interface{}) (map[string]interface{}, int64, bool) {
	kind, _, _ := unstructured.NestedString(obj, "kind")

	var specPath []string
	switch kind {
	case "Deployment", "ReplicaSet", "StatefulSet", "ReplicationController":
		return workloadTemplateSpec(obj)
	case "Pod":
		specPath = []string{"spec"}
	case "DaemonSet", "Job":
		specPath = []string{"spec", "template", "spec"}
	case "CronJob":
		specPath = []string{"spec", "jobTemplate", "spec", "template", "spec"}
	default:
		return nil, 0, false
	}

	podSpec, found := nestedMap(obj, specPath...)
	return podSpec, 1, found
}
// workloadTemplateSpec returns the pod template spec of a replica-scaled
// workload and its replica count. A missing, unreadable or sub-1 replica
// field is treated as one replica.
func workloadTemplateSpec(obj map[string]interface{}) (map[string]interface{}, int64, bool) {
	podSpec, found := nestedMap(obj, "spec", "template", "spec")
	if !found {
		return nil, 0, false
	}

	count := int64(1)
	if declared, _, err := unstructured.NestedInt64(obj, "spec", "replicas"); err == nil && declared >= 1 {
		count = declared
	}
	return podSpec, count, true
}
// nestedMap wraps unstructured.NestedMap and folds "not found" and "wrong
// type" into a single false result.
func nestedMap(obj map[string]interface{}, fields ...string) (map[string]interface{}, bool) {
	value, found, err := unstructured.NestedMap(obj, fields...)
	if err != nil {
		return value, false
	}
	return value, found
}
// addPodSpecResources adds the requests and limits of every init container
// and regular container in podSpec to estimate, multiplied by replicas
// (values below one count as one). Entries that are not container objects
// are skipped.
func addPodSpecResources(estimate *repository.ResourceEstimate, podSpec map[string]interface{}, replicas int64) {
	if replicas < 1 {
		replicas = 1
	}

	containerFields := [...]string{"initContainers", "containers"}
	for _, field := range containerFields {
		entries, found, err := unstructured.NestedSlice(podSpec, field)
		if err != nil || !found {
			continue
		}
		for _, entry := range entries {
			if container, isMap := entry.(map[string]interface{}); isMap {
				addContainerResourceList(&estimate.Requests, replicas, container, "resources", "requests")
				addContainerResourceList(&estimate.Limits, replicas, container, "resources", "limits")
			}
		}
	}
}
// addContainerResourceList reads the resource map found at fields inside
// container and adds its CPU, memory, GPU and GPU-memory entries, multiplied
// by replicas, to target. Other resource names are ignored.
func addContainerResourceList(target *repository.ResourceVector, replicas int64, container map[string]interface{}, fields ...string) {
	entries, found := nestedMap(container, fields...)
	if !found {
		return
	}

	for resourceName, raw := range entries {
		switch resourceName {
		case "cpu":
			addQuantity(&target.CPU, raw, replicas)
		case "memory":
			addQuantity(&target.Memory, raw, replicas)
		case "nvidia.com/gpu", "requests.nvidia.com/gpu", "limits.nvidia.com/gpu":
			perReplica := parseIntegerResource(raw)
			target.GPU += perReplica * replicas
		case "nvidia.com/gpumem", "requests.nvidia.com/gpumem", "limits.nvidia.com/gpumem":
			perReplica := parseGPUMemoryResource(raw)
			target.GPUMemoryMB += perReplica * replicas
		}
	}
}
// addQuantity parses value as a resource quantity, multiplies it by replicas
// and adds it to target. Unparseable values are ignored.
func addQuantity(target *resource.Quantity, value interface{}, replicas int64) {
	parsed, err := resource.ParseQuantity(fmt.Sprint(value))
	if err == nil {
		parsed.Mul(replicas)
		target.Add(parsed)
	}
}
// parseIntegerResource parses value as a resource quantity and returns its
// integer value, or 0 when it cannot be parsed.
func parseIntegerResource(value interface{}) int64 {
	parsed, err := resource.ParseQuantity(fmt.Sprint(value))
	if err == nil {
		return parsed.Value()
	}
	return 0
}
// parseGPUMemoryResource normalises value with normalizeGPUMemoryQuota and
// parses the result as a base-10 integer. It returns 0 when normalisation
// fails, yields an empty string, or the result is not an integer.
func parseGPUMemoryResource(value interface{}) int64 {
	text, normalizeErr := normalizeGPUMemoryQuota(fmt.Sprint(value))
	if normalizeErr != nil || text == "" {
		return 0
	}

	if megabytes, parseErr := strconv.ParseInt(text, 10, 64); parseErr == nil {
		return megabytes
	}
	return 0
}

View File

@ -0,0 +1,241 @@
package service
import (
"errors"
"testing"
"github.com/ocdp/cluster-service/internal/domain/entity"
"github.com/ocdp/cluster-service/internal/domain/repository"
corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/resource"
)
// TestCompareWorkspaceQuotaReportsExceededRequests verifies that only the
// resources above their limit (CPU and GPU memory here) are reported, sorted
// by name, and that the call fails with ErrQuotaExceeded.
func TestCompareWorkspaceQuotaReportsExceededRequests(t *testing.T) {
	t.Parallel()

	quota := &entity.Workspace{
		QuotaCPU:    "2",
		QuotaMemory: "4Gi",
		QuotaGPU:    "1",
		QuotaGPUMem: "10000",
	}
	// CPU and GPU memory are over quota; memory and GPU count fit.
	wanted := &repository.ResourceEstimate{
		Requests: repository.ResourceVector{
			CPU:         resource.MustParse("2500m"),
			Memory:      resource.MustParse("3Gi"),
			GPU:         1,
			GPUMemoryMB: 12000,
		},
	}

	result, err := CompareWorkspaceQuota(quota, wanted)
	if !errors.Is(err, ErrQuotaExceeded) {
		t.Fatalf("expected ErrQuotaExceeded, got %v", err)
	}
	if result == nil || result.Allowed {
		t.Fatalf("expected denied result, got %#v", result)
	}

	exceeded := result.Exceeded
	if len(exceeded) != 2 {
		t.Fatalf("expected 2 exceeded resources, got %#v", exceeded)
	}
	if exceeded[0].Name != "requests.cpu" {
		t.Fatalf("expected requests.cpu exceeded first, got %#v", exceeded)
	}
	if exceeded[1].Name != "requests.nvidia.com/gpumem" {
		t.Fatalf("expected requests.nvidia.com/gpumem exceeded second, got %#v", exceeded)
	}
}
// TestCompareWorkspaceQuotaUsesLimitsAsEffectiveRequests verifies that GPU
// resources declared only as limits still count against the request quota.
func TestCompareWorkspaceQuotaUsesLimitsAsEffectiveRequests(t *testing.T) {
	t.Parallel()

	quota := &entity.Workspace{QuotaGPU: "0", QuotaGPUMem: "9999"}
	limitsOnly := &repository.ResourceEstimate{
		Limits: repository.ResourceVector{GPU: 1, GPUMemoryMB: 10000},
	}

	result, err := CompareWorkspaceQuota(quota, limitsOnly)
	if !errors.Is(err, ErrQuotaExceeded) {
		t.Fatalf("expected ErrQuotaExceeded from limits-only GPU resources, got %v", err)
	}
	if result == nil || len(result.Exceeded) != 2 {
		t.Fatalf("expected gpu and gpumem to be exceeded, got %#v", result)
	}
}
func TestCompareBindingQuotaSubtractsCurrentReleaseFromUsedQuota(t *testing.T) {
t.Parallel()
binding := &entity.WorkspaceClusterBinding{
QuotaCPU: "1",
QuotaMemory: "2Gi",
QuotaGPU: "1",
QuotaGPUMem: "10000",
}
usage := &repository.ResourceQuotaUsage{
Used: repository.ResourceVector{
CPU: resource.MustParse("1"),
Memory: resource.MustParse("2Gi"),
GPU: 1,
GPUMemoryMB: 10000,
},
}
current := &repository.ResourceEstimate{
Requests: repository.ResourceVector{
CPU: resource.MustParse("1"),
Memory: resource.MustParse("2Gi"),
GPU: 1,
GPUMemoryMB: 10000,
},
}
targetSameSize := &repository.ResourceEstimate{
Requests: repository.ResourceVector{
CPU: resource.MustParse("1"),
Memory: resource.MustParse("2Gi"),
GPU: 1,
GPUMemoryMB: 10000,
},
}
result, err := CompareBindingQuota(binding, usage, targetSameSize, current)
if err != nil {
t.Fatalf("expected update with same resource footprint to fit quota, got %v", err)
}
if result.Required.Requests.GPU != 1 || result.Required.Requests.GPUMemoryMB != 10000 {
t.Fatalf("expected required resources to subtract current release before target, got %#v", result.Required.Requests)
}
targetScaledUp := &repository.ResourceEstimate{
Requests: repository.ResourceVector{
CPU: resource.MustParse("2"),
Memory: resource.MustParse("4Gi"),
GPU: 2,
GPUMemoryMB: 20000,
},
}
result, err = CompareBindingQuota(binding, usage, targetScaledUp, current)
if !errors.Is(err, ErrQuotaExceeded) {
t.Fatalf("expected scale-up beyond quota to be rejected, got %v", err)
}
if result == nil || result.Allowed {
t.Fatalf("expected denied quota result, got %#v", result)
}
}
// TestCompareBindingQuotaTreatsExplicitZeroGPUAsNoGPUAllowed verifies that a
// GPU quota of "0" is enforced as a real limit, so any GPU request is
// rejected for both GPU count and GPU memory.
func TestCompareBindingQuotaTreatsExplicitZeroGPUAsNoGPUAllowed(t *testing.T) {
	t.Parallel()

	binding := &entity.WorkspaceClusterBinding{
		QuotaCPU:    "8",
		QuotaMemory: "32Gi",
		QuotaGPU:    "0",
		QuotaGPUMem: "0",
	}
	vllmLikeEstimate := &repository.ResourceEstimate{
		Requests: repository.ResourceVector{
			CPU:         resource.MustParse("2"),
			Memory:      resource.MustParse("8Gi"),
			GPU:         1,
			GPUMemoryMB: 10000,
		},
	}

	noUsage := &repository.ResourceQuotaUsage{}
	result, err := CompareBindingQuota(binding, noUsage, vllmLikeEstimate, nil)
	if !errors.Is(err, ErrQuotaExceeded) {
		t.Fatalf("expected GPU request to exceed explicit zero quota, got %v", err)
	}

	reported := make(map[string]struct{}, len(result.Exceeded))
	for _, item := range result.Exceeded {
		reported[item.Name] = struct{}{}
	}
	for _, name := range []string{"requests.nvidia.com/gpu", "requests.nvidia.com/gpumem"} {
		if _, ok := reported[name]; !ok {
			t.Fatalf("expected %s to be exceeded, got %#v", name, result.Exceeded)
		}
	}
}
func TestBindingQuotaHardKeepsGPUMemoryAsIntegerMB(t *testing.T) {
t.Parallel()
hard := bindingQuotaHard(&entity.WorkspaceClusterBinding{QuotaGPU: "1", QuotaGPUMem: "10000"})
gpuMem := hard[corev1.ResourceName("requests.nvidia.com/gpumem")]
if gpuMem.Value() != 10000 {
t.Fatalf("expected gpumem quota to remain integer MB 10000, got %s value=%d", gpuMem.String(), gpuMem.Value())
}
}
// TestEstimateRenderedManifestResourcesSumsPodTemplates verifies that the
// estimator sums init-container and container resources of a Deployment,
// multiplies them by the replica count (3), and ignores the Service document.
func TestEstimateRenderedManifestResourcesSumsPodTemplates(t *testing.T) {
t.Parallel()
// Two documents: a 3-replica Deployment with one init container and one
// GPU container, followed by a Service that carries no pod template.
manifest := `
apiVersion: apps/v1
kind: Deployment
metadata:
name: gpu-worker
spec:
replicas: 3
template:
spec:
initContainers:
- name: init
image: busybox
resources:
requests:
cpu: 100m
memory: 128Mi
containers:
- name: app
image: busybox
resources:
requests:
cpu: 500m
memory: 1Gi
nvidia.com/gpu: "1"
nvidia.com/gpumem: "10000"
limits:
cpu: "1"
memory: 2Gi
nvidia.com/gpu: "1"
nvidia.com/gpumem: "12000"
---
apiVersion: v1
kind: Service
metadata:
name: ignored
`
estimate, err := EstimateRenderedManifestResources(manifest)
if err != nil {
t.Fatalf("EstimateRenderedManifestResources returned error: %v", err)
}
// Requests: 3 x (100m + 500m) CPU and 3 x (128Mi + 1Gi) memory.
if estimate.Requests.CPU.Cmp(resource.MustParse("1800m")) != 0 {
t.Fatalf("expected requests cpu 1800m, got %s", estimate.Requests.CPU.String())
}
if estimate.Requests.Memory.Cmp(resource.MustParse("3456Mi")) != 0 {
t.Fatalf("expected requests memory 3456Mi, got %s", estimate.Requests.Memory.String())
}
if estimate.Requests.GPU != 3 {
t.Fatalf("expected requests gpu 3, got %d", estimate.Requests.GPU)
}
if estimate.Requests.GPUMemoryMB != 30000 {
t.Fatalf("expected requests gpumem 30000, got %d", estimate.Requests.GPUMemoryMB)
}
// Limits: only the app container declares limits, times 3 replicas.
if estimate.Limits.CPU.Cmp(resource.MustParse("3")) != 0 {
t.Fatalf("expected limits cpu 3, got %s", estimate.Limits.CPU.String())
}
if estimate.Limits.Memory.Cmp(resource.MustParse("6Gi")) != 0 {
t.Fatalf("expected limits memory 6Gi, got %s", estimate.Limits.Memory.String())
}
if estimate.Limits.GPU != 3 {
t.Fatalf("expected limits gpu 3, got %d", estimate.Limits.GPU)
}
if estimate.Limits.GPUMemoryMB != 36000 {
t.Fatalf("expected limits gpumem 36000, got %d", estimate.Limits.GPUMemoryMB)
}
}

View File

@ -9,6 +9,10 @@ import (
func normalizeStandardQuotaQuantity(value string) string {
value = strings.TrimSpace(value)
switch strings.ToLower(value) {
case "unlimited", "none", "no-limit", "nolimit":
return ""
}
upper := strings.ToUpper(value)
switch {
case strings.HasSuffix(upper, "MB"):

View File

@ -3,6 +3,7 @@ package service
import (
"context"
"sort"
"strings"
"time"
"github.com/google/uuid"
@ -94,17 +95,17 @@ func (s *WorkspaceService) EnsureClusterBinding(ctx context.Context, workspaceID
ClusterID: cluster.ID,
Namespace: workspace.K8sNamespace,
ServiceAccount: workspace.K8sSAName,
QuotaCPU: workspace.QuotaCPU,
QuotaMemory: workspace.QuotaMemory,
QuotaGPU: workspace.QuotaGPU,
QuotaGPUMem: workspace.QuotaGPUMem,
QuotaCPU: strings.TrimSpace(workspace.QuotaCPU),
QuotaMemory: strings.TrimSpace(workspace.QuotaMemory),
QuotaGPU: zeroIfEmptyQuota(workspace.QuotaGPU),
QuotaGPUMem: zeroIfEmptyQuota(workspace.QuotaGPUMem),
Status: "active",
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
tenantBinding := entity.NewTenantBinding(binding.Namespace)
tenantBinding.ServiceAccountName = binding.ServiceAccount
tenantBinding.ResourceQuotaHard = resourceQuotaHard(workspace)
tenantBinding.ResourceQuotaHard = bindingQuotaHard(binding)
if s.tenantClient != nil {
if err := s.tenantClient.EnsureTenant(ctx, cluster, tenantBinding); err != nil {
return nil, err
@ -145,10 +146,22 @@ func (s *WorkspaceService) IssueKubeconfig(ctx context.Context, workspaceID, clu
if err != nil {
return nil, err
}
} else {
binding.QuotaCPU = strings.TrimSpace(workspace.QuotaCPU)
binding.QuotaMemory = strings.TrimSpace(workspace.QuotaMemory)
binding.QuotaGPU = zeroIfEmptyQuota(workspace.QuotaGPU)
binding.QuotaGPUMem = zeroIfEmptyQuota(workspace.QuotaGPUMem)
binding.UpdatedAt = time.Now()
}
tenantBinding := entity.NewTenantBinding(binding.Namespace)
tenantBinding.ServiceAccountName = binding.ServiceAccount
tenantBinding.ResourceQuotaHard = resourceQuotaHard(workspace)
tenantBinding.ResourceQuotaHard = bindingQuotaHard(binding)
if s.tenantClient != nil {
if err := s.tenantClient.EnsureTenant(ctx, cluster, tenantBinding); err != nil {
return nil, err
}
}
_ = s.bindingRepo.Upsert(ctx, binding)
kubeconfig, err := s.tenantClient.IssueKubeconfig(ctx, cluster, tenantBinding, ttl)
if err != nil {
return nil, err