package postgres

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"time"

	"github.com/google/uuid"

	"github.com/ocdp/cluster-service/internal/domain/entity"
	"github.com/ocdp/cluster-service/internal/domain/repository"
	"github.com/ocdp/cluster-service/internal/pkg/crypto"
)

// clusterSelectSQL selects every column of the clusters table in the order
// expected by scanClusterRow. Callers append their WHERE / ORDER BY clause.
const clusterSelectSQL = `
	SELECT id, name, host, ca_data, cert_data, key_data, token, description, created_at, updated_at
	FROM clusters`

// clusterCiphertexts holds the encrypted credential columns of one clusters
// row exactly as read from the database, before decryption.
type clusterCiphertexts struct {
	caData, certData, keyData, token string
}

// ClusterRepository is the PostgreSQL implementation of
// repository.ClusterRepository. A cluster's CA data, cert data, key data and
// token are passed through encryptor before being written and after being read.
type ClusterRepository struct {
	db        *DB
	encryptor crypto.Encryptor
}

// NewClusterRepository creates a PostgreSQL-backed cluster repository that
// stores credential fields encrypted with encryptor.
func NewClusterRepository(db *DB, encryptor crypto.Encryptor) repository.ClusterRepository {
	return &ClusterRepository{
		db:        db,
		encryptor: encryptor,
	}
}

// Create inserts cluster into the clusters table. If cluster.ID is empty, a
// new UUID is generated and written back to cluster.ID. CreatedAt and
// UpdatedAt are stored exactly as supplied by the caller.
func (r *ClusterRepository) Create(ctx context.Context, cluster *entity.Cluster) error {
	if cluster.ID == "" {
		cluster.ID = uuid.New().String()
	}

	caData, certData, keyData, token, err := r.encryptCredentials(cluster)
	if err != nil {
		return err
	}

	const query = `
		INSERT INTO clusters (id, name, host, ca_data, cert_data, key_data, token, description, created_at, updated_at)
		VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)`
	if _, err := r.db.conn.ExecContext(ctx, query,
		cluster.ID, cluster.Name, cluster.Host,
		caData, certData, keyData, token,
		cluster.Description, cluster.CreatedAt, cluster.UpdatedAt,
	); err != nil {
		return fmt.Errorf("failed to create cluster: %w", err)
	}
	return nil
}

// GetByID returns the cluster with the given ID, with credentials decrypted.
// It returns entity.ErrClusterNotFound when no such row exists.
func (r *ClusterRepository) GetByID(ctx context.Context, id string) (*entity.Cluster, error) {
	const query = clusterSelectSQL + `
	WHERE id = $1`
	return r.getOne(ctx, query, id)
}

// GetByName returns the cluster with the given name, with credentials
// decrypted. It returns entity.ErrClusterNotFound when no such row exists.
func (r *ClusterRepository) GetByName(ctx context.Context, name string) (*entity.Cluster, error) {
	const query = clusterSelectSQL + `
	WHERE name = $1`
	return r.getOne(ctx, query, name)
}

// Update overwrites name, host, credentials and description of the cluster
// identified by cluster.ID and stamps updated_at with the current time. It
// returns entity.ErrClusterNotFound when no row has that ID. On success,
// cluster.UpdatedAt is set to the stored timestamp.
func (r *ClusterRepository) Update(ctx context.Context, cluster *entity.Cluster) error {
	caData, certData, keyData, token, err := r.encryptCredentials(cluster)
	if err != nil {
		return err
	}

	// Keep the new timestamp local until the row is actually updated, so a
	// failed or not-found update leaves the caller's struct untouched.
	now := time.Now()

	const query = `
		UPDATE clusters
		SET name = $1, host = $2, ca_data = $3, cert_data = $4, key_data = $5, token = $6, description = $7, updated_at = $8
		WHERE id = $9`
	result, err := r.db.conn.ExecContext(ctx, query,
		cluster.Name, cluster.Host,
		caData, certData, keyData, token,
		cluster.Description, now, cluster.ID,
	)
	if err != nil {
		return fmt.Errorf("failed to update cluster: %w", err)
	}
	if err := requireClusterRow(result); err != nil {
		return err
	}

	cluster.UpdatedAt = now
	return nil
}

// Delete removes the cluster with the given ID. It returns
// entity.ErrClusterNotFound when no row has that ID.
func (r *ClusterRepository) Delete(ctx context.Context, id string) error {
	const query = `DELETE FROM clusters WHERE id = $1`
	result, err := r.db.conn.ExecContext(ctx, query, id)
	if err != nil {
		return fmt.Errorf("failed to delete cluster: %w", err)
	}
	return requireClusterRow(result)
}

// List returns all clusters, newest (by created_at) first, with credentials
// decrypted. When there are no clusters it returns an empty, non-nil slice.
func (r *ClusterRepository) List(ctx context.Context) ([]*entity.Cluster, error) {
	const query = clusterSelectSQL + `
	ORDER BY created_at DESC`
	rows, err := r.db.conn.QueryContext(ctx, query)
	if err != nil {
		return nil, fmt.Errorf("failed to list clusters: %w", err)
	}
	defer rows.Close()

	// Non-nil so callers receive an empty slice rather than nil for no rows.
	clusters := make([]*entity.Cluster, 0)
	for rows.Next() {
		cluster, sealed, err := scanClusterRow(rows.Scan)
		if err != nil {
			return nil, fmt.Errorf("failed to scan cluster: %w", err)
		}
		if err := r.decryptCredentials(cluster, sealed); err != nil {
			return nil, err
		}
		clusters = append(clusters, cluster)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("rows iteration error: %w", err)
	}
	return clusters, nil
}

// getOne runs a single-row query built on clusterSelectSQL with one
// placeholder bound to arg, and returns the decrypted cluster. A missing row
// is reported as entity.ErrClusterNotFound.
func (r *ClusterRepository) getOne(ctx context.Context, query string, arg string) (*entity.Cluster, error) {
	cluster, sealed, err := scanClusterRow(r.db.conn.QueryRowContext(ctx, query, arg).Scan)
	if errors.Is(err, sql.ErrNoRows) {
		return nil, entity.ErrClusterNotFound
	}
	if err != nil {
		return nil, fmt.Errorf("failed to get cluster: %w", err)
	}
	if err := r.decryptCredentials(cluster, sealed); err != nil {
		return nil, err
	}
	return cluster, nil
}

// scanClusterRow scans one row selected with clusterSelectSQL. scan is the
// Scan method of a *sql.Row or *sql.Rows. Plain columns are stored on the
// returned Cluster; the still-encrypted credential columns are returned
// separately. The scan error, if any, is returned unwrapped so callers can
// test for sql.ErrNoRows and add their own context.
func scanClusterRow(scan func(dest ...any) error) (*entity.Cluster, clusterCiphertexts, error) {
	cluster := &entity.Cluster{}
	var sealed clusterCiphertexts
	err := scan(
		&cluster.ID, &cluster.Name, &cluster.Host,
		&sealed.caData, &sealed.certData, &sealed.keyData, &sealed.token,
		&cluster.Description, &cluster.CreatedAt, &cluster.UpdatedAt,
	)
	return cluster, sealed, err
}

// encryptCredentials encrypts the CA data, cert data, key data and token of
// cluster. It returns the ciphertexts in the column order ca_data, cert_data,
// key_data, token, ready to be passed as query arguments. The error names the
// field that failed.
func (r *ClusterRepository) encryptCredentials(cluster *entity.Cluster) (caData, certData, keyData, token any, err error) {
	if caData, err = r.encryptor.Encrypt(cluster.CAData); err != nil {
		return nil, nil, nil, nil, fmt.Errorf("failed to encrypt CA data: %w", err)
	}
	if certData, err = r.encryptor.Encrypt(cluster.CertData); err != nil {
		return nil, nil, nil, nil, fmt.Errorf("failed to encrypt cert data: %w", err)
	}
	if keyData, err = r.encryptor.Encrypt(cluster.KeyData); err != nil {
		return nil, nil, nil, nil, fmt.Errorf("failed to encrypt key data: %w", err)
	}
	if token, err = r.encryptor.Encrypt(cluster.Token); err != nil {
		return nil, nil, nil, nil, fmt.Errorf("failed to encrypt token: %w", err)
	}
	return caData, certData, keyData, token, nil
}

// decryptCredentials decrypts sealed and stores the plaintexts in cluster's
// CAData, CertData, KeyData and Token fields. The error names the field that
// failed; on error, cluster may be partially populated and should be discarded.
func (r *ClusterRepository) decryptCredentials(cluster *entity.Cluster, sealed clusterCiphertexts) error {
	var err error
	if cluster.CAData, err = r.encryptor.Decrypt(sealed.caData); err != nil {
		return fmt.Errorf("failed to decrypt CA data: %w", err)
	}
	if cluster.CertData, err = r.encryptor.Decrypt(sealed.certData); err != nil {
		return fmt.Errorf("failed to decrypt cert data: %w", err)
	}
	if cluster.KeyData, err = r.encryptor.Decrypt(sealed.keyData); err != nil {
		return fmt.Errorf("failed to decrypt key data: %w", err)
	}
	if cluster.Token, err = r.encryptor.Decrypt(sealed.token); err != nil {
		return fmt.Errorf("failed to decrypt token: %w", err)
	}
	return nil
}

// requireClusterRow maps a write that affected no rows to
// entity.ErrClusterNotFound.
func requireClusterRow(result sql.Result) error {
	rows, err := result.RowsAffected()
	if err != nil {
		return fmt.Errorf("failed to get affected rows: %w", err)
	}
	if rows == 0 {
		return entity.ErrClusterNotFound
	}
	return nil
}