package postgres

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"time"

	"github.com/google/uuid"
	"github.com/ocdp/cluster-service/internal/domain/entity"
	"github.com/ocdp/cluster-service/internal/domain/repository"
	"github.com/ocdp/cluster-service/internal/pkg/crypto"
)

// clusterColumns is the column list shared by every query that reads or
// inserts a full cluster row. The order must match the destination order in
// scanCluster and the argument order in Create.
const clusterColumns = `id, workspace_id, owner_id, name, host, ca_data, cert_data, key_data, token,
	description, isolation_mode, default_namespace, is_shared, created_at, updated_at`

// rowScanner is the subset of *sql.Row and *sql.Rows needed to read one row,
// so single-row and multi-row queries can share the same scanning code.
type rowScanner interface {
	Scan(dest ...interface{}) error
}

// ClusterRepository is the PostgreSQL implementation of the cluster repository.
// The CA data, certificate data, key data and token are passed through the
// encryptor before being written and after being read.
type ClusterRepository struct {
	db        *DB
	encryptor crypto.Encryptor
}

// NewClusterRepository creates a PostgreSQL-backed cluster repository that uses
// db for storage and encryptor for the sensitive columns.
func NewClusterRepository(db *DB, encryptor crypto.Encryptor) repository.ClusterRepository {
	return &ClusterRepository{
		db:        db,
		encryptor: encryptor,
	}
}

// Create inserts a new cluster row.
//
// The passed cluster is mutated: an empty ID is replaced with a freshly
// generated UUID and an empty IsolationMode defaults to
// entity.IsolationModeNamespace. CreatedAt and UpdatedAt are stored exactly as
// supplied by the caller. An error is returned when encryption or the insert
// fails.
func (r *ClusterRepository) Create(ctx context.Context, cluster *entity.Cluster) error {
	if cluster.ID == "" {
		cluster.ID = uuid.New().String()
	}

	// Apply defaults.
	if cluster.IsolationMode == "" {
		cluster.IsolationMode = entity.IsolationModeNamespace
	}

	// Encrypt sensitive data before it reaches the database.
	encryptedCAData, err := r.encryptor.Encrypt(cluster.CAData)
	if err != nil {
		return fmt.Errorf("failed to encrypt CA data: %w", err)
	}
	encryptedCertData, err := r.encryptor.Encrypt(cluster.CertData)
	if err != nil {
		return fmt.Errorf("failed to encrypt cert data: %w", err)
	}
	encryptedKeyData, err := r.encryptor.Encrypt(cluster.KeyData)
	if err != nil {
		return fmt.Errorf("failed to encrypt key data: %w", err)
	}
	encryptedToken, err := r.encryptor.Encrypt(cluster.Token)
	if err != nil {
		return fmt.Errorf("failed to encrypt token: %w", err)
	}

	query := `INSERT INTO clusters (` + clusterColumns + `)
		VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15)`

	_, err = r.db.conn.ExecContext(ctx, query,
		cluster.ID,
		cluster.WorkspaceID,
		cluster.OwnerID,
		cluster.Name,
		cluster.Host,
		encryptedCAData,
		encryptedCertData,
		encryptedKeyData,
		encryptedToken,
		cluster.Description,
		cluster.IsolationMode,
		cluster.DefaultNamespace,
		cluster.IsShared,
		cluster.CreatedAt,
		cluster.UpdatedAt,
	)
	if err != nil {
		return fmt.Errorf("failed to create cluster: %w", err)
	}
	return nil
}

// GetByID returns the cluster with the given ID, or entity.ErrClusterNotFound
// when no such row exists.
func (r *ClusterRepository) GetByID(ctx context.Context, id string) (*entity.Cluster, error) {
	query := `SELECT ` + clusterColumns + ` FROM clusters WHERE id = $1`
	return r.getOne(ctx, query, id)
}

// GetByName returns the cluster with the given name, or
// entity.ErrClusterNotFound when no such row exists.
func (r *ClusterRepository) GetByName(ctx context.Context, name string) (*entity.Cluster, error) {
	query := `SELECT ` + clusterColumns + ` FROM clusters WHERE name = $1`
	return r.getOne(ctx, query, name)
}

// Update writes the mutable fields of cluster back to the row identified by
// cluster.ID.
//
// cluster.UpdatedAt is overwritten with the current time before the write.
// workspace_id, owner_id and created_at are not part of the UPDATE statement.
// entity.ErrClusterNotFound is returned when no row was affected.
func (r *ClusterRepository) Update(ctx context.Context, cluster *entity.Cluster) error {
	cluster.UpdatedAt = time.Now()

	// Encrypt sensitive data before it reaches the database.
	encryptedCAData, err := r.encryptor.Encrypt(cluster.CAData)
	if err != nil {
		return fmt.Errorf("failed to encrypt CA data: %w", err)
	}
	encryptedCertData, err := r.encryptor.Encrypt(cluster.CertData)
	if err != nil {
		return fmt.Errorf("failed to encrypt cert data: %w", err)
	}
	encryptedKeyData, err := r.encryptor.Encrypt(cluster.KeyData)
	if err != nil {
		return fmt.Errorf("failed to encrypt key data: %w", err)
	}
	encryptedToken, err := r.encryptor.Encrypt(cluster.Token)
	if err != nil {
		return fmt.Errorf("failed to encrypt token: %w", err)
	}

	query := `
		UPDATE clusters
		SET name = $1, host = $2, ca_data = $3, cert_data = $4, key_data = $5, token = $6,
		    description = $7, isolation_mode = $8, default_namespace = $9, is_shared = $10, updated_at = $11
		WHERE id = $12
	`

	result, err := r.db.conn.ExecContext(ctx, query,
		cluster.Name,
		cluster.Host,
		encryptedCAData,
		encryptedCertData,
		encryptedKeyData,
		encryptedToken,
		cluster.Description,
		cluster.IsolationMode,
		cluster.DefaultNamespace,
		cluster.IsShared,
		cluster.UpdatedAt,
		cluster.ID,
	)
	if err != nil {
		return fmt.Errorf("failed to update cluster: %w", err)
	}

	rows, err := result.RowsAffected()
	if err != nil {
		return fmt.Errorf("failed to get affected rows: %w", err)
	}
	if rows == 0 {
		return entity.ErrClusterNotFound
	}
	return nil
}

// Delete removes the cluster with the given ID and returns
// entity.ErrClusterNotFound when no row was affected.
func (r *ClusterRepository) Delete(ctx context.Context, id string) error {
	query := `DELETE FROM clusters WHERE id = $1`

	result, err := r.db.conn.ExecContext(ctx, query, id)
	if err != nil {
		return fmt.Errorf("failed to delete cluster: %w", err)
	}

	rows, err := result.RowsAffected()
	if err != nil {
		return fmt.Errorf("failed to get affected rows: %w", err)
	}
	if rows == 0 {
		return entity.ErrClusterNotFound
	}
	return nil
}

// List returns every cluster, newest first.
func (r *ClusterRepository) List(ctx context.Context) ([]*entity.Cluster, error) {
	query := `SELECT ` + clusterColumns + ` FROM clusters ORDER BY created_at DESC`

	rows, err := r.db.conn.QueryContext(ctx, query)
	if err != nil {
		return nil, fmt.Errorf("failed to list clusters: %w", err)
	}
	defer rows.Close()

	return r.scanClusters(rows)
}

// GetByWorkspace returns the clusters that belong to workspaceID together with
// all shared clusters. Non-shared clusters sort before shared ones; within each
// group the newest come first.
func (r *ClusterRepository) GetByWorkspace(ctx context.Context, workspaceID string) ([]*entity.Cluster, error) {
	query := `SELECT ` + clusterColumns + `
		FROM clusters
		WHERE workspace_id = $1 OR is_shared = TRUE
		ORDER BY is_shared, created_at DESC`

	rows, err := r.db.conn.QueryContext(ctx, query, workspaceID)
	if err != nil {
		return nil, fmt.Errorf("failed to list clusters by workspace: %w", err)
	}
	defer rows.Close()

	return r.scanClusters(rows)
}

// GetShared returns every shared cluster, newest first.
func (r *ClusterRepository) GetShared(ctx context.Context) ([]*entity.Cluster, error) {
	query := `SELECT ` + clusterColumns + `
		FROM clusters
		WHERE is_shared = TRUE
		ORDER BY created_at DESC`

	rows, err := r.db.conn.QueryContext(ctx, query)
	if err != nil {
		return nil, fmt.Errorf("failed to list shared clusters: %w", err)
	}
	defer rows.Close()

	return r.scanClusters(rows)
}

// getOne runs a single-row query with one bind argument and maps
// sql.ErrNoRows to entity.ErrClusterNotFound.
func (r *ClusterRepository) getOne(ctx context.Context, query string, arg interface{}) (*entity.Cluster, error) {
	cluster, err := r.scanCluster(r.db.conn.QueryRowContext(ctx, query, arg))
	if errors.Is(err, sql.ErrNoRows) {
		return nil, entity.ErrClusterNotFound
	}
	if err != nil {
		return nil, fmt.Errorf("failed to get cluster: %w", err)
	}
	return cluster, nil
}

// scanClusters reads every remaining row from rows into a slice of clusters.
// The result is a non-nil, possibly empty slice. The caller stays responsible
// for closing rows.
func (r *ClusterRepository) scanClusters(rows *sql.Rows) ([]*entity.Cluster, error) {
	clusters := make([]*entity.Cluster, 0)

	for rows.Next() {
		cluster, err := r.scanCluster(rows)
		if err != nil {
			return nil, fmt.Errorf("failed to scan cluster: %w", err)
		}
		clusters = append(clusters, cluster)
	}

	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("rows iteration error: %w", err)
	}
	return clusters, nil
}

// scanCluster reads one row laid out as clusterColumns and decrypts its
// sensitive fields.
//
// The scan error is returned unwrapped so callers can test it against
// sql.ErrNoRows. Nullable columns are read through sql.NullString; a NULL
// becomes the empty string on the entity. Decryption failures are returned
// instead of being dropped, so a wrong key or corrupt ciphertext never yields a
// cluster with silently blank credentials.
func (r *ClusterRepository) scanCluster(row rowScanner) (*entity.Cluster, error) {
	cluster := &entity.Cluster{}

	var encCA, encCert, encKey, encToken sql.NullString
	var workspaceID, ownerID, defaultNamespace sql.NullString

	if err := row.Scan(
		&cluster.ID,
		&workspaceID,
		&ownerID,
		&cluster.Name,
		&cluster.Host,
		&encCA,
		&encCert,
		&encKey,
		&encToken,
		&cluster.Description,
		&cluster.IsolationMode,
		&defaultNamespace,
		&cluster.IsShared,
		&cluster.CreatedAt,
		&cluster.UpdatedAt,
	); err != nil {
		return nil, err
	}

	// NULL columns map to the empty string.
	cluster.WorkspaceID = workspaceID.String
	cluster.OwnerID = ownerID.String
	cluster.DefaultNamespace = defaultNamespace.String

	// Decrypt sensitive data. NULL or empty ciphertext is treated as "no value"
	// and left at the field's zero value without calling Decrypt.
	// NOTE(review): assumes the encryptor never needs to decrypt an empty
	// string into a non-empty value - confirm against crypto.Encryptor.
	var err error
	if encCA.Valid && encCA.String != "" {
		if cluster.CAData, err = r.encryptor.Decrypt(encCA.String); err != nil {
			return nil, fmt.Errorf("failed to decrypt CA data: %w", err)
		}
	}
	if encCert.Valid && encCert.String != "" {
		if cluster.CertData, err = r.encryptor.Decrypt(encCert.String); err != nil {
			return nil, fmt.Errorf("failed to decrypt cert data: %w", err)
		}
	}
	if encKey.Valid && encKey.String != "" {
		if cluster.KeyData, err = r.encryptor.Decrypt(encKey.String); err != nil {
			return nil, fmt.Errorf("failed to decrypt key data: %w", err)
		}
	}
	if encToken.Valid && encToken.String != "" {
		if cluster.Token, err = r.encryptor.Decrypt(encToken.String); err != nil {
			return nil, fmt.Errorf("failed to decrypt token: %w", err)
		}
	}

	return cluster, nil
}