338 lines
8.1 KiB
Go
338 lines
8.1 KiB
Go
package postgres
|
|
|
|
import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"time"

	"github.com/google/uuid"
	"github.com/ocdp/cluster-service/internal/domain/entity"
	"github.com/ocdp/cluster-service/internal/domain/repository"
	"github.com/ocdp/cluster-service/internal/pkg/crypto"
)
|
|
|
|
// ClusterRepository is the PostgreSQL implementation of the cluster
// repository. It persists cluster connection details in the "clusters" table
// and runs CA data, client certificate, client key and token through
// encryptor before writing them (and back through it after reading).
type ClusterRepository struct {
	// db is the database handle; statements are executed via db.conn.
	db *DB

	// encryptor encrypts sensitive fields on write and decrypts them on read.
	encryptor crypto.Encryptor
}
|
|
|
|
// NewClusterRepository 创建 PostgreSQL 集群仓储
|
|
func NewClusterRepository(db *DB, encryptor crypto.Encryptor) repository.ClusterRepository {
|
|
return &ClusterRepository{
|
|
db: db,
|
|
encryptor: encryptor,
|
|
}
|
|
}
|
|
|
|
// Create 创建集群
|
|
func (r *ClusterRepository) Create(ctx context.Context, cluster *entity.Cluster) error {
|
|
if cluster.ID == "" {
|
|
cluster.ID = uuid.New().String()
|
|
}
|
|
|
|
// 加密敏感数据
|
|
encryptedCAData, err := r.encryptor.Encrypt(cluster.CAData)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to encrypt CA data: %w", err)
|
|
}
|
|
|
|
encryptedCertData, err := r.encryptor.Encrypt(cluster.CertData)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to encrypt cert data: %w", err)
|
|
}
|
|
|
|
encryptedKeyData, err := r.encryptor.Encrypt(cluster.KeyData)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to encrypt key data: %w", err)
|
|
}
|
|
|
|
encryptedToken, err := r.encryptor.Encrypt(cluster.Token)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to encrypt token: %w", err)
|
|
}
|
|
|
|
query := `
|
|
INSERT INTO clusters (id, name, host, ca_data, cert_data, key_data, token, description, created_at, updated_at)
|
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
|
|
`
|
|
|
|
_, err = r.db.conn.ExecContext(ctx, query,
|
|
cluster.ID,
|
|
cluster.Name,
|
|
cluster.Host,
|
|
encryptedCAData,
|
|
encryptedCertData,
|
|
encryptedKeyData,
|
|
encryptedToken,
|
|
cluster.Description,
|
|
cluster.CreatedAt,
|
|
cluster.UpdatedAt,
|
|
)
|
|
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create cluster: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetByID 根据 ID 获取集群
|
|
func (r *ClusterRepository) GetByID(ctx context.Context, id string) (*entity.Cluster, error) {
|
|
query := `
|
|
SELECT id, name, host, ca_data, cert_data, key_data, token, description, created_at, updated_at
|
|
FROM clusters
|
|
WHERE id = $1
|
|
`
|
|
|
|
cluster := &entity.Cluster{}
|
|
var encryptedCAData, encryptedCertData, encryptedKeyData, encryptedToken string
|
|
|
|
err := r.db.conn.QueryRowContext(ctx, query, id).Scan(
|
|
&cluster.ID,
|
|
&cluster.Name,
|
|
&cluster.Host,
|
|
&encryptedCAData,
|
|
&encryptedCertData,
|
|
&encryptedKeyData,
|
|
&encryptedToken,
|
|
&cluster.Description,
|
|
&cluster.CreatedAt,
|
|
&cluster.UpdatedAt,
|
|
)
|
|
|
|
if err == sql.ErrNoRows {
|
|
return nil, entity.ErrClusterNotFound
|
|
}
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get cluster: %w", err)
|
|
}
|
|
|
|
// 解密敏感数据
|
|
cluster.CAData, err = r.encryptor.Decrypt(encryptedCAData)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to decrypt CA data: %w", err)
|
|
}
|
|
|
|
cluster.CertData, err = r.encryptor.Decrypt(encryptedCertData)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to decrypt cert data: %w", err)
|
|
}
|
|
|
|
cluster.KeyData, err = r.encryptor.Decrypt(encryptedKeyData)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to decrypt key data: %w", err)
|
|
}
|
|
|
|
cluster.Token, err = r.encryptor.Decrypt(encryptedToken)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to decrypt token: %w", err)
|
|
}
|
|
|
|
return cluster, nil
|
|
}
|
|
|
|
// GetByName 根据名称获取集群
|
|
func (r *ClusterRepository) GetByName(ctx context.Context, name string) (*entity.Cluster, error) {
|
|
query := `
|
|
SELECT id, name, host, ca_data, cert_data, key_data, token, description, created_at, updated_at
|
|
FROM clusters
|
|
WHERE name = $1
|
|
`
|
|
|
|
cluster := &entity.Cluster{}
|
|
var encryptedCAData, encryptedCertData, encryptedKeyData, encryptedToken string
|
|
|
|
err := r.db.conn.QueryRowContext(ctx, query, name).Scan(
|
|
&cluster.ID,
|
|
&cluster.Name,
|
|
&cluster.Host,
|
|
&encryptedCAData,
|
|
&encryptedCertData,
|
|
&encryptedKeyData,
|
|
&encryptedToken,
|
|
&cluster.Description,
|
|
&cluster.CreatedAt,
|
|
&cluster.UpdatedAt,
|
|
)
|
|
|
|
if err == sql.ErrNoRows {
|
|
return nil, entity.ErrClusterNotFound
|
|
}
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get cluster: %w", err)
|
|
}
|
|
|
|
// 解密敏感数据
|
|
cluster.CAData, err = r.encryptor.Decrypt(encryptedCAData)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to decrypt CA data: %w", err)
|
|
}
|
|
|
|
cluster.CertData, err = r.encryptor.Decrypt(encryptedCertData)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to decrypt cert data: %w", err)
|
|
}
|
|
|
|
cluster.KeyData, err = r.encryptor.Decrypt(encryptedKeyData)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to decrypt key data: %w", err)
|
|
}
|
|
|
|
cluster.Token, err = r.encryptor.Decrypt(encryptedToken)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to decrypt token: %w", err)
|
|
}
|
|
|
|
return cluster, nil
|
|
}
|
|
|
|
// Update 更新集群
|
|
func (r *ClusterRepository) Update(ctx context.Context, cluster *entity.Cluster) error {
|
|
cluster.UpdatedAt = time.Now()
|
|
|
|
// 加密敏感数据
|
|
encryptedCAData, err := r.encryptor.Encrypt(cluster.CAData)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to encrypt CA data: %w", err)
|
|
}
|
|
|
|
encryptedCertData, err := r.encryptor.Encrypt(cluster.CertData)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to encrypt cert data: %w", err)
|
|
}
|
|
|
|
encryptedKeyData, err := r.encryptor.Encrypt(cluster.KeyData)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to encrypt key data: %w", err)
|
|
}
|
|
|
|
encryptedToken, err := r.encryptor.Encrypt(cluster.Token)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to encrypt token: %w", err)
|
|
}
|
|
|
|
query := `
|
|
UPDATE clusters
|
|
SET name = $1, host = $2, ca_data = $3, cert_data = $4, key_data = $5,
|
|
token = $6, description = $7, updated_at = $8
|
|
WHERE id = $9
|
|
`
|
|
|
|
result, err := r.db.conn.ExecContext(ctx, query,
|
|
cluster.Name,
|
|
cluster.Host,
|
|
encryptedCAData,
|
|
encryptedCertData,
|
|
encryptedKeyData,
|
|
encryptedToken,
|
|
cluster.Description,
|
|
cluster.UpdatedAt,
|
|
cluster.ID,
|
|
)
|
|
|
|
if err != nil {
|
|
return fmt.Errorf("failed to update cluster: %w", err)
|
|
}
|
|
|
|
rows, err := result.RowsAffected()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get affected rows: %w", err)
|
|
}
|
|
|
|
if rows == 0 {
|
|
return entity.ErrClusterNotFound
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Delete 删除集群
|
|
func (r *ClusterRepository) Delete(ctx context.Context, id string) error {
|
|
query := `DELETE FROM clusters WHERE id = $1`
|
|
|
|
result, err := r.db.conn.ExecContext(ctx, query, id)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to delete cluster: %w", err)
|
|
}
|
|
|
|
rows, err := result.RowsAffected()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get affected rows: %w", err)
|
|
}
|
|
|
|
if rows == 0 {
|
|
return entity.ErrClusterNotFound
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// List 列出所有集群
|
|
func (r *ClusterRepository) List(ctx context.Context) ([]*entity.Cluster, error) {
|
|
query := `
|
|
SELECT id, name, host, ca_data, cert_data, key_data, token, description, created_at, updated_at
|
|
FROM clusters
|
|
ORDER BY created_at DESC
|
|
`
|
|
|
|
rows, err := r.db.conn.QueryContext(ctx, query)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to list clusters: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
clusters := make([]*entity.Cluster, 0)
|
|
for rows.Next() {
|
|
cluster := &entity.Cluster{}
|
|
var encryptedCAData, encryptedCertData, encryptedKeyData, encryptedToken string
|
|
|
|
err := rows.Scan(
|
|
&cluster.ID,
|
|
&cluster.Name,
|
|
&cluster.Host,
|
|
&encryptedCAData,
|
|
&encryptedCertData,
|
|
&encryptedKeyData,
|
|
&encryptedToken,
|
|
&cluster.Description,
|
|
&cluster.CreatedAt,
|
|
&cluster.UpdatedAt,
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to scan cluster: %w", err)
|
|
}
|
|
|
|
// 解密敏感数据
|
|
cluster.CAData, err = r.encryptor.Decrypt(encryptedCAData)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to decrypt CA data: %w", err)
|
|
}
|
|
|
|
cluster.CertData, err = r.encryptor.Decrypt(encryptedCertData)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to decrypt cert data: %w", err)
|
|
}
|
|
|
|
cluster.KeyData, err = r.encryptor.Decrypt(encryptedKeyData)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to decrypt key data: %w", err)
|
|
}
|
|
|
|
cluster.Token, err = r.encryptor.Decrypt(encryptedToken)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to decrypt token: %w", err)
|
|
}
|
|
|
|
clusters = append(clusters, cluster)
|
|
}
|
|
|
|
if err := rows.Err(); err != nil {
|
|
return nil, fmt.Errorf("rows iteration error: %w", err)
|
|
}
|
|
|
|
return clusters, nil
|
|
}
|
|
|