This commit is contained in:
mangomqy
2025-11-13 02:54:06 +00:00
commit c5e51ed069
254 changed files with 54901 additions and 0 deletions

View File

@ -0,0 +1,174 @@
package mock
import (
"context"
"sync"
"github.com/ocdp/cluster-service/internal/domain/entity"
"github.com/ocdp/cluster-service/internal/domain/repository"
"github.com/ocdp/cluster-service/internal/pkg/crypto"
)
// ClusterRepositoryMock 集群仓储 Mock 实现(内存存储,支持加密)
type ClusterRepositoryMock struct {
mu sync.RWMutex
clusters map[string]*entity.Cluster // key: cluster ID
encryptor crypto.Encryptor // 加密器
}
// NewClusterRepositoryMock 创建 Mock 实现
func NewClusterRepositoryMock(encryptor crypto.Encryptor) repository.ClusterRepository {
return &ClusterRepositoryMock{
clusters: make(map[string]*entity.Cluster),
encryptor: encryptor,
}
}
func (r *ClusterRepositoryMock) Create(ctx context.Context, cluster *entity.Cluster) error {
r.mu.Lock()
defer r.mu.Unlock()
// 检查名称是否已存在
for _, c := range r.clusters {
if c.Name == cluster.Name {
return entity.ErrClusterExists
}
}
// Mock 模式:如果没有提供认证信息,自动填充默认的 Mock 证书
if (cluster.CertData == "" || cluster.KeyData == "") && cluster.Token == "" {
cluster.CAData = "LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0tCk1vY2sgQ0EgQ2VydGlmaWNhdGUKLS0tLS1FTkQgQ0VSVElGSUNBVEUtLS0tLQ=="
cluster.CertData = "LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0tCk1vY2sgQ2xpZW50IENlcnRpZmljYXRlCi0tLS0tRU5EIENFUlRJRklDQVRFLS0tLS0="
cluster.KeyData = "LS0tLS1CRUdJTiBSU0EgUFJJVkFURSBLRVktLS0tLQpNb2NrIFByaXZhdGUgS2V5Ci0tLS0tRU5EIFJTQSBQUklWQVRFIEtFWS0tLS0t"
}
// 加密敏感数据后存储
encryptedCluster := r.encryptCluster(cluster)
r.clusters[cluster.ID] = encryptedCluster
return nil
}
func (r *ClusterRepositoryMock) GetByID(ctx context.Context, id string) (*entity.Cluster, error) {
r.mu.RLock()
defer r.mu.RUnlock()
cluster, exists := r.clusters[id]
if !exists {
return nil, entity.ErrClusterNotFound
}
// 解密敏感数据后返回
return r.decryptCluster(cluster), nil
}
func (r *ClusterRepositoryMock) GetByName(ctx context.Context, name string) (*entity.Cluster, error) {
r.mu.RLock()
defer r.mu.RUnlock()
for _, cluster := range r.clusters {
if cluster.Name == name {
// 解密敏感数据后返回
return r.decryptCluster(cluster), nil
}
}
return nil, entity.ErrClusterNotFound
}
func (r *ClusterRepositoryMock) Update(ctx context.Context, cluster *entity.Cluster) error {
r.mu.Lock()
defer r.mu.Unlock()
if _, exists := r.clusters[cluster.ID]; !exists {
return entity.ErrClusterNotFound
}
// 加密敏感数据后存储
encryptedCluster := r.encryptCluster(cluster)
r.clusters[cluster.ID] = encryptedCluster
return nil
}
func (r *ClusterRepositoryMock) Delete(ctx context.Context, id string) error {
r.mu.Lock()
defer r.mu.Unlock()
if _, exists := r.clusters[id]; !exists {
return entity.ErrClusterNotFound
}
delete(r.clusters, id)
return nil
}
func (r *ClusterRepositoryMock) List(ctx context.Context) ([]*entity.Cluster, error) {
r.mu.RLock()
defer r.mu.RUnlock()
clusters := make([]*entity.Cluster, 0, len(r.clusters))
for _, cluster := range r.clusters {
// 解密敏感数据后返回
clusters = append(clusters, r.decryptCluster(cluster))
}
return clusters, nil
}
// encryptCluster 加密 Cluster 的敏感数据
func (r *ClusterRepositoryMock) encryptCluster(cluster *entity.Cluster) *entity.Cluster {
encrypted := *cluster // 复制
// 加密证书数据
if cluster.CAData != "" && !crypto.IsEncrypted(cluster.CAData) {
if encryptedData, err := r.encryptor.Encrypt(cluster.CAData); err == nil {
encrypted.CAData = encryptedData
}
}
if cluster.CertData != "" && !crypto.IsEncrypted(cluster.CertData) {
if encryptedData, err := r.encryptor.Encrypt(cluster.CertData); err == nil {
encrypted.CertData = encryptedData
}
}
if cluster.KeyData != "" && !crypto.IsEncrypted(cluster.KeyData) {
if encryptedData, err := r.encryptor.Encrypt(cluster.KeyData); err == nil {
encrypted.KeyData = encryptedData
}
}
if cluster.Token != "" && !crypto.IsEncrypted(cluster.Token) {
if encryptedData, err := r.encryptor.Encrypt(cluster.Token); err == nil {
encrypted.Token = encryptedData
}
}
return &encrypted
}
// decryptCluster 解密 Cluster 的敏感数据
func (r *ClusterRepositoryMock) decryptCluster(cluster *entity.Cluster) *entity.Cluster {
decrypted := *cluster // 复制
// 解密证书数据
if cluster.CAData != "" && crypto.IsEncrypted(cluster.CAData) {
if decryptedData, err := r.encryptor.Decrypt(cluster.CAData); err == nil {
decrypted.CAData = decryptedData
}
}
if cluster.CertData != "" && crypto.IsEncrypted(cluster.CertData) {
if decryptedData, err := r.encryptor.Decrypt(cluster.CertData); err == nil {
decrypted.CertData = decryptedData
}
}
if cluster.KeyData != "" && crypto.IsEncrypted(cluster.KeyData) {
if decryptedData, err := r.encryptor.Decrypt(cluster.KeyData); err == nil {
decrypted.KeyData = decryptedData
}
}
if cluster.Token != "" && crypto.IsEncrypted(cluster.Token) {
if decryptedData, err := r.encryptor.Decrypt(cluster.Token); err == nil {
decrypted.Token = decryptedData
}
}
return &decrypted
}

View File

@ -0,0 +1,113 @@
package mock
import (
"context"
"sync"
"github.com/ocdp/cluster-service/internal/domain/entity"
"github.com/ocdp/cluster-service/internal/domain/repository"
)
// InstanceRepositoryMock 实例仓储 Mock 实现(内存存储)
type InstanceRepositoryMock struct {
mu sync.RWMutex
instances map[string]*entity.Instance // key: instance ID
}
// NewInstanceRepositoryMock 创建 Mock 实现
func NewInstanceRepositoryMock() repository.InstanceRepository {
return &InstanceRepositoryMock{
instances: make(map[string]*entity.Instance),
}
}
func (r *InstanceRepositoryMock) Create(ctx context.Context, instance *entity.Instance) error {
r.mu.Lock()
defer r.mu.Unlock()
// 检查同一集群中名称是否已存在
for _, inst := range r.instances {
if inst.ClusterID == instance.ClusterID && inst.Name == instance.Name {
return entity.ErrInstanceExists
}
}
r.instances[instance.ID] = instance
return nil
}
func (r *InstanceRepositoryMock) GetByID(ctx context.Context, id string) (*entity.Instance, error) {
r.mu.RLock()
defer r.mu.RUnlock()
instance, exists := r.instances[id]
if !exists {
return nil, entity.ErrInstanceNotFound
}
return instance, nil
}
func (r *InstanceRepositoryMock) GetByClusterAndName(ctx context.Context, clusterID, name string) (*entity.Instance, error) {
r.mu.RLock()
defer r.mu.RUnlock()
for _, instance := range r.instances {
if instance.ClusterID == clusterID && instance.Name == name {
return instance, nil
}
}
return nil, entity.ErrInstanceNotFound
}
func (r *InstanceRepositoryMock) Update(ctx context.Context, instance *entity.Instance) error {
r.mu.Lock()
defer r.mu.Unlock()
if _, exists := r.instances[instance.ID]; !exists {
return entity.ErrInstanceNotFound
}
r.instances[instance.ID] = instance
return nil
}
func (r *InstanceRepositoryMock) Delete(ctx context.Context, id string) error {
r.mu.Lock()
defer r.mu.Unlock()
if _, exists := r.instances[id]; !exists {
return entity.ErrInstanceNotFound
}
delete(r.instances, id)
return nil
}
func (r *InstanceRepositoryMock) ListByCluster(ctx context.Context, clusterID string) ([]*entity.Instance, error) {
r.mu.RLock()
defer r.mu.RUnlock()
instances := make([]*entity.Instance, 0)
for _, instance := range r.instances {
if instance.ClusterID == clusterID {
instances = append(instances, instance)
}
}
return instances, nil
}
func (r *InstanceRepositoryMock) List(ctx context.Context) ([]*entity.Instance, error) {
r.mu.RLock()
defer r.mu.RUnlock()
instances := make([]*entity.Instance, 0, len(r.instances))
for _, instance := range r.instances {
instances = append(instances, instance)
}
return instances, nil
}

View File

@ -0,0 +1,137 @@
package mock
import (
"context"
"sync"
"github.com/ocdp/cluster-service/internal/domain/entity"
"github.com/ocdp/cluster-service/internal/domain/repository"
"github.com/ocdp/cluster-service/internal/pkg/crypto"
)
// RegistryRepositoryMock Registry 仓储 Mock 实现(内存存储,支持加密)
type RegistryRepositoryMock struct {
mu sync.RWMutex
registries map[string]*entity.Registry // key: registry ID
encryptor crypto.Encryptor // 加密器
}
// NewRegistryRepositoryMock 创建 Mock 实现
func NewRegistryRepositoryMock(encryptor crypto.Encryptor) repository.RegistryRepository {
return &RegistryRepositoryMock{
registries: make(map[string]*entity.Registry),
encryptor: encryptor,
}
}
func (r *RegistryRepositoryMock) Create(ctx context.Context, registry *entity.Registry) error {
r.mu.Lock()
defer r.mu.Unlock()
// 检查名称是否已存在
for _, reg := range r.registries {
if reg.Name == registry.Name {
return entity.ErrRegistryExists
}
}
// 加密敏感数据后存储
encryptedRegistry := r.encryptRegistry(registry)
r.registries[registry.ID] = encryptedRegistry
return nil
}
func (r *RegistryRepositoryMock) GetByID(ctx context.Context, id string) (*entity.Registry, error) {
r.mu.RLock()
defer r.mu.RUnlock()
registry, exists := r.registries[id]
if !exists {
return nil, entity.ErrRegistryNotFound
}
// 解密敏感数据后返回
return r.decryptRegistry(registry), nil
}
func (r *RegistryRepositoryMock) GetByName(ctx context.Context, name string) (*entity.Registry, error) {
r.mu.RLock()
defer r.mu.RUnlock()
for _, registry := range r.registries {
if registry.Name == name {
// 解密敏感数据后返回
return r.decryptRegistry(registry), nil
}
}
return nil, entity.ErrRegistryNotFound
}
func (r *RegistryRepositoryMock) Update(ctx context.Context, registry *entity.Registry) error {
r.mu.Lock()
defer r.mu.Unlock()
if _, exists := r.registries[registry.ID]; !exists {
return entity.ErrRegistryNotFound
}
// 加密敏感数据后存储
encryptedRegistry := r.encryptRegistry(registry)
r.registries[registry.ID] = encryptedRegistry
return nil
}
func (r *RegistryRepositoryMock) Delete(ctx context.Context, id string) error {
r.mu.Lock()
defer r.mu.Unlock()
if _, exists := r.registries[id]; !exists {
return entity.ErrRegistryNotFound
}
delete(r.registries, id)
return nil
}
func (r *RegistryRepositoryMock) List(ctx context.Context) ([]*entity.Registry, error) {
r.mu.RLock()
defer r.mu.RUnlock()
registries := make([]*entity.Registry, 0, len(r.registries))
for _, registry := range r.registries {
// 解密敏感数据后返回
registries = append(registries, r.decryptRegistry(registry))
}
return registries, nil
}
// encryptRegistry 加密 Registry 的敏感数据
func (r *RegistryRepositoryMock) encryptRegistry(registry *entity.Registry) *entity.Registry {
encrypted := *registry // 复制
// 加密密码
if registry.Password != "" && !crypto.IsEncrypted(registry.Password) {
if encryptedPassword, err := r.encryptor.Encrypt(registry.Password); err == nil {
encrypted.Password = encryptedPassword
}
}
return &encrypted
}
// decryptRegistry 解密 Registry 的敏感数据
func (r *RegistryRepositoryMock) decryptRegistry(registry *entity.Registry) *entity.Registry {
decrypted := *registry // 复制
// 解密密码
if registry.Password != "" && crypto.IsEncrypted(registry.Password) {
if decryptedPassword, err := r.encryptor.Decrypt(registry.Password); err == nil {
decrypted.Password = decryptedPassword
}
}
return &decrypted
}

View File

@ -0,0 +1,99 @@
package mock
import (
"context"
"sync"
"github.com/ocdp/cluster-service/internal/domain/entity"
"github.com/ocdp/cluster-service/internal/domain/repository"
)
// UserRepositoryMock 用户仓储 Mock 实现(内存存储)
type UserRepositoryMock struct {
mu sync.RWMutex
users map[string]*entity.User // key: user ID
}
// NewUserRepositoryMock 创建 Mock 实现
func NewUserRepositoryMock() repository.UserRepository {
return &UserRepositoryMock{
users: make(map[string]*entity.User),
}
}
func (r *UserRepositoryMock) Create(ctx context.Context, user *entity.User) error {
r.mu.Lock()
defer r.mu.Unlock()
// 检查是否已存在
for _, u := range r.users {
if u.Username == user.Username {
return entity.ErrUserExists
}
}
r.users[user.ID] = user
return nil
}
func (r *UserRepositoryMock) GetByID(ctx context.Context, id string) (*entity.User, error) {
r.mu.RLock()
defer r.mu.RUnlock()
user, exists := r.users[id]
if !exists {
return nil, entity.ErrUserNotFound
}
return user, nil
}
func (r *UserRepositoryMock) GetByUsername(ctx context.Context, username string) (*entity.User, error) {
r.mu.RLock()
defer r.mu.RUnlock()
for _, user := range r.users {
if user.Username == username {
return user, nil
}
}
return nil, entity.ErrUserNotFound
}
func (r *UserRepositoryMock) Update(ctx context.Context, user *entity.User) error {
r.mu.Lock()
defer r.mu.Unlock()
if _, exists := r.users[user.ID]; !exists {
return entity.ErrUserNotFound
}
r.users[user.ID] = user
return nil
}
func (r *UserRepositoryMock) Delete(ctx context.Context, id string) error {
r.mu.Lock()
defer r.mu.Unlock()
if _, exists := r.users[id]; !exists {
return entity.ErrUserNotFound
}
delete(r.users, id)
return nil
}
func (r *UserRepositoryMock) List(ctx context.Context) ([]*entity.User, error) {
r.mu.RLock()
defer r.mu.RUnlock()
users := make([]*entity.User, 0, len(r.users))
for _, user := range r.users {
users = append(users, user)
}
return users, nil
}

View File

@ -0,0 +1,337 @@
package postgres
import (
"context"
"database/sql"
"fmt"
"time"
"github.com/google/uuid"
"github.com/ocdp/cluster-service/internal/domain/entity"
"github.com/ocdp/cluster-service/internal/domain/repository"
"github.com/ocdp/cluster-service/internal/pkg/crypto"
)
// ClusterRepository PostgreSQL 集群仓储实现
type ClusterRepository struct {
db *DB
encryptor crypto.Encryptor
}
// NewClusterRepository 创建 PostgreSQL 集群仓储
func NewClusterRepository(db *DB, encryptor crypto.Encryptor) repository.ClusterRepository {
return &ClusterRepository{
db: db,
encryptor: encryptor,
}
}
// Create 创建集群
func (r *ClusterRepository) Create(ctx context.Context, cluster *entity.Cluster) error {
if cluster.ID == "" {
cluster.ID = uuid.New().String()
}
// 加密敏感数据
encryptedCAData, err := r.encryptor.Encrypt(cluster.CAData)
if err != nil {
return fmt.Errorf("failed to encrypt CA data: %w", err)
}
encryptedCertData, err := r.encryptor.Encrypt(cluster.CertData)
if err != nil {
return fmt.Errorf("failed to encrypt cert data: %w", err)
}
encryptedKeyData, err := r.encryptor.Encrypt(cluster.KeyData)
if err != nil {
return fmt.Errorf("failed to encrypt key data: %w", err)
}
encryptedToken, err := r.encryptor.Encrypt(cluster.Token)
if err != nil {
return fmt.Errorf("failed to encrypt token: %w", err)
}
query := `
INSERT INTO clusters (id, name, host, ca_data, cert_data, key_data, token, description, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
`
_, err = r.db.conn.ExecContext(ctx, query,
cluster.ID,
cluster.Name,
cluster.Host,
encryptedCAData,
encryptedCertData,
encryptedKeyData,
encryptedToken,
cluster.Description,
cluster.CreatedAt,
cluster.UpdatedAt,
)
if err != nil {
return fmt.Errorf("failed to create cluster: %w", err)
}
return nil
}
// GetByID 根据 ID 获取集群
func (r *ClusterRepository) GetByID(ctx context.Context, id string) (*entity.Cluster, error) {
query := `
SELECT id, name, host, ca_data, cert_data, key_data, token, description, created_at, updated_at
FROM clusters
WHERE id = $1
`
cluster := &entity.Cluster{}
var encryptedCAData, encryptedCertData, encryptedKeyData, encryptedToken string
err := r.db.conn.QueryRowContext(ctx, query, id).Scan(
&cluster.ID,
&cluster.Name,
&cluster.Host,
&encryptedCAData,
&encryptedCertData,
&encryptedKeyData,
&encryptedToken,
&cluster.Description,
&cluster.CreatedAt,
&cluster.UpdatedAt,
)
if err == sql.ErrNoRows {
return nil, entity.ErrClusterNotFound
}
if err != nil {
return nil, fmt.Errorf("failed to get cluster: %w", err)
}
// 解密敏感数据
cluster.CAData, err = r.encryptor.Decrypt(encryptedCAData)
if err != nil {
return nil, fmt.Errorf("failed to decrypt CA data: %w", err)
}
cluster.CertData, err = r.encryptor.Decrypt(encryptedCertData)
if err != nil {
return nil, fmt.Errorf("failed to decrypt cert data: %w", err)
}
cluster.KeyData, err = r.encryptor.Decrypt(encryptedKeyData)
if err != nil {
return nil, fmt.Errorf("failed to decrypt key data: %w", err)
}
cluster.Token, err = r.encryptor.Decrypt(encryptedToken)
if err != nil {
return nil, fmt.Errorf("failed to decrypt token: %w", err)
}
return cluster, nil
}
// GetByName 根据名称获取集群
func (r *ClusterRepository) GetByName(ctx context.Context, name string) (*entity.Cluster, error) {
query := `
SELECT id, name, host, ca_data, cert_data, key_data, token, description, created_at, updated_at
FROM clusters
WHERE name = $1
`
cluster := &entity.Cluster{}
var encryptedCAData, encryptedCertData, encryptedKeyData, encryptedToken string
err := r.db.conn.QueryRowContext(ctx, query, name).Scan(
&cluster.ID,
&cluster.Name,
&cluster.Host,
&encryptedCAData,
&encryptedCertData,
&encryptedKeyData,
&encryptedToken,
&cluster.Description,
&cluster.CreatedAt,
&cluster.UpdatedAt,
)
if err == sql.ErrNoRows {
return nil, entity.ErrClusterNotFound
}
if err != nil {
return nil, fmt.Errorf("failed to get cluster: %w", err)
}
// 解密敏感数据
cluster.CAData, err = r.encryptor.Decrypt(encryptedCAData)
if err != nil {
return nil, fmt.Errorf("failed to decrypt CA data: %w", err)
}
cluster.CertData, err = r.encryptor.Decrypt(encryptedCertData)
if err != nil {
return nil, fmt.Errorf("failed to decrypt cert data: %w", err)
}
cluster.KeyData, err = r.encryptor.Decrypt(encryptedKeyData)
if err != nil {
return nil, fmt.Errorf("failed to decrypt key data: %w", err)
}
cluster.Token, err = r.encryptor.Decrypt(encryptedToken)
if err != nil {
return nil, fmt.Errorf("failed to decrypt token: %w", err)
}
return cluster, nil
}
// Update 更新集群
func (r *ClusterRepository) Update(ctx context.Context, cluster *entity.Cluster) error {
cluster.UpdatedAt = time.Now()
// 加密敏感数据
encryptedCAData, err := r.encryptor.Encrypt(cluster.CAData)
if err != nil {
return fmt.Errorf("failed to encrypt CA data: %w", err)
}
encryptedCertData, err := r.encryptor.Encrypt(cluster.CertData)
if err != nil {
return fmt.Errorf("failed to encrypt cert data: %w", err)
}
encryptedKeyData, err := r.encryptor.Encrypt(cluster.KeyData)
if err != nil {
return fmt.Errorf("failed to encrypt key data: %w", err)
}
encryptedToken, err := r.encryptor.Encrypt(cluster.Token)
if err != nil {
return fmt.Errorf("failed to encrypt token: %w", err)
}
query := `
UPDATE clusters
SET name = $1, host = $2, ca_data = $3, cert_data = $4, key_data = $5,
token = $6, description = $7, updated_at = $8
WHERE id = $9
`
result, err := r.db.conn.ExecContext(ctx, query,
cluster.Name,
cluster.Host,
encryptedCAData,
encryptedCertData,
encryptedKeyData,
encryptedToken,
cluster.Description,
cluster.UpdatedAt,
cluster.ID,
)
if err != nil {
return fmt.Errorf("failed to update cluster: %w", err)
}
rows, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get affected rows: %w", err)
}
if rows == 0 {
return entity.ErrClusterNotFound
}
return nil
}
// Delete 删除集群
func (r *ClusterRepository) Delete(ctx context.Context, id string) error {
query := `DELETE FROM clusters WHERE id = $1`
result, err := r.db.conn.ExecContext(ctx, query, id)
if err != nil {
return fmt.Errorf("failed to delete cluster: %w", err)
}
rows, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get affected rows: %w", err)
}
if rows == 0 {
return entity.ErrClusterNotFound
}
return nil
}
// List 列出所有集群
func (r *ClusterRepository) List(ctx context.Context) ([]*entity.Cluster, error) {
query := `
SELECT id, name, host, ca_data, cert_data, key_data, token, description, created_at, updated_at
FROM clusters
ORDER BY created_at DESC
`
rows, err := r.db.conn.QueryContext(ctx, query)
if err != nil {
return nil, fmt.Errorf("failed to list clusters: %w", err)
}
defer rows.Close()
clusters := make([]*entity.Cluster, 0)
for rows.Next() {
cluster := &entity.Cluster{}
var encryptedCAData, encryptedCertData, encryptedKeyData, encryptedToken string
err := rows.Scan(
&cluster.ID,
&cluster.Name,
&cluster.Host,
&encryptedCAData,
&encryptedCertData,
&encryptedKeyData,
&encryptedToken,
&cluster.Description,
&cluster.CreatedAt,
&cluster.UpdatedAt,
)
if err != nil {
return nil, fmt.Errorf("failed to scan cluster: %w", err)
}
// 解密敏感数据
cluster.CAData, err = r.encryptor.Decrypt(encryptedCAData)
if err != nil {
return nil, fmt.Errorf("failed to decrypt CA data: %w", err)
}
cluster.CertData, err = r.encryptor.Decrypt(encryptedCertData)
if err != nil {
return nil, fmt.Errorf("failed to decrypt cert data: %w", err)
}
cluster.KeyData, err = r.encryptor.Decrypt(encryptedKeyData)
if err != nil {
return nil, fmt.Errorf("failed to decrypt key data: %w", err)
}
cluster.Token, err = r.encryptor.Decrypt(encryptedToken)
if err != nil {
return nil, fmt.Errorf("failed to decrypt token: %w", err)
}
clusters = append(clusters, cluster)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("rows iteration error: %w", err)
}
return clusters, nil
}

View File

@ -0,0 +1,135 @@
package postgres
import (
"database/sql"
"fmt"
"time"
_ "github.com/lib/pq"
)
// DB 数据库连接包装器
type DB struct {
conn *sql.DB
}
// NewDB 创建新的数据库连接
func NewDB(connString string) (*DB, error) {
if connString == "" {
return nil, fmt.Errorf("database connection string cannot be empty")
}
conn, err := sql.Open("postgres", connString)
if err != nil {
return nil, fmt.Errorf("failed to open database: %w", err)
}
// 配置连接池
conn.SetMaxOpenConns(25)
conn.SetMaxIdleConns(5)
conn.SetConnMaxLifetime(5 * time.Minute)
// 测试连接
if err := conn.Ping(); err != nil {
return nil, fmt.Errorf("failed to ping database: %w", err)
}
return &DB{conn: conn}, nil
}
// Close 关闭数据库连接
func (db *DB) Close() error {
if db.conn != nil {
return db.conn.Close()
}
return nil
}
// GetConn 获取底层连接(用于事务等高级操作)
func (db *DB) GetConn() *sql.DB {
return db.conn
}
// InitSchema 初始化数据库 schema
func (db *DB) InitSchema() error {
schema := `
-- Users 表
CREATE TABLE IF NOT EXISTS users (
id VARCHAR(36) PRIMARY KEY,
username VARCHAR(255) NOT NULL UNIQUE,
password_hash TEXT NOT NULL,
email VARCHAR(255) NOT NULL,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);
CREATE INDEX IF NOT EXISTS idx_users_username ON users(username);
-- Clusters 表
CREATE TABLE IF NOT EXISTS clusters (
id VARCHAR(36) PRIMARY KEY,
name VARCHAR(255) NOT NULL UNIQUE,
host TEXT NOT NULL,
ca_data TEXT,
cert_data TEXT,
key_data TEXT,
token TEXT,
description TEXT,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);
CREATE INDEX IF NOT EXISTS idx_clusters_name ON clusters(name);
-- Registries 表
CREATE TABLE IF NOT EXISTS registries (
id VARCHAR(36) PRIMARY KEY,
name VARCHAR(255) NOT NULL UNIQUE,
url TEXT NOT NULL,
description TEXT,
username VARCHAR(255),
password TEXT,
insecure BOOLEAN DEFAULT FALSE,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);
CREATE INDEX IF NOT EXISTS idx_registries_name ON registries(name);
-- Instances 表
CREATE TABLE IF NOT EXISTS instances (
id VARCHAR(36) PRIMARY KEY,
cluster_id VARCHAR(36) NOT NULL,
name VARCHAR(255) NOT NULL,
namespace VARCHAR(255) NOT NULL,
registry_id VARCHAR(36) NOT NULL,
repository TEXT NOT NULL,
chart VARCHAR(255) NOT NULL,
version VARCHAR(255) NOT NULL,
description TEXT,
values JSONB,
values_yaml TEXT,
status VARCHAR(50) NOT NULL,
status_reason TEXT,
last_operation VARCHAR(50),
last_error TEXT,
revision INTEGER NOT NULL DEFAULT 1,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
CONSTRAINT fk_cluster FOREIGN KEY (cluster_id) REFERENCES clusters(id) ON DELETE CASCADE,
CONSTRAINT fk_registry FOREIGN KEY (registry_id) REFERENCES registries(id) ON DELETE CASCADE,
CONSTRAINT unique_cluster_name UNIQUE (cluster_id, name, namespace)
);
CREATE INDEX IF NOT EXISTS idx_instances_cluster ON instances(cluster_id);
CREATE INDEX IF NOT EXISTS idx_instances_registry ON instances(registry_id);
CREATE INDEX IF NOT EXISTS idx_instances_name ON instances(name);
`
_, err := db.conn.Exec(schema)
if err != nil {
return fmt.Errorf("failed to initialize schema: %w", err)
}
return nil
}

View File

@ -0,0 +1,433 @@
package postgres
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"time"
"github.com/google/uuid"
"github.com/ocdp/cluster-service/internal/domain/entity"
"github.com/ocdp/cluster-service/internal/domain/repository"
)
// InstanceRepository PostgreSQL 实例仓储实现
type InstanceRepository struct {
db *DB
}
// NewInstanceRepository 创建 PostgreSQL 实例仓储
func NewInstanceRepository(db *DB) repository.InstanceRepository {
return &InstanceRepository{db: db}
}
// Create 创建实例
func (r *InstanceRepository) Create(ctx context.Context, instance *entity.Instance) error {
if instance.ID == "" {
instance.ID = uuid.New().String()
}
// 将 Values 转换为 JSON
valuesJSON, err := json.Marshal(instance.Values)
if err != nil {
return fmt.Errorf("failed to marshal values: %w", err)
}
query := `
INSERT INTO instances (id, cluster_id, name, namespace, registry_id, repository, chart, version,
description, values, values_yaml, status, status_reason, last_operation, last_error,
revision, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18)
`
_, err = r.db.conn.ExecContext(ctx, query,
instance.ID,
instance.ClusterID,
instance.Name,
instance.Namespace,
instance.RegistryID,
instance.Repository,
instance.Chart,
instance.Version,
instance.Description,
valuesJSON,
instance.ValuesYAML,
instance.Status,
instance.StatusReason,
instance.LastOperation,
instance.LastError,
instance.Revision,
instance.CreatedAt,
instance.UpdatedAt,
)
if err != nil {
return fmt.Errorf("failed to create instance: %w", err)
}
return nil
}
// GetByID 根据 ID 获取实例
func (r *InstanceRepository) GetByID(ctx context.Context, id string) (*entity.Instance, error) {
query := `
SELECT id, cluster_id, name, namespace, registry_id, repository, chart, version,
description, values, values_yaml, status, status_reason, last_operation, last_error,
revision, created_at, updated_at
FROM instances
WHERE id = $1
`
instance := &entity.Instance{}
var (
valuesJSON []byte
statusReason sql.NullString
lastOperation sql.NullString
lastError sql.NullString
)
err := r.db.conn.QueryRowContext(ctx, query, id).Scan(
&instance.ID,
&instance.ClusterID,
&instance.Name,
&instance.Namespace,
&instance.RegistryID,
&instance.Repository,
&instance.Chart,
&instance.Version,
&instance.Description,
&valuesJSON,
&instance.ValuesYAML,
&instance.Status,
&statusReason,
&lastOperation,
&lastError,
&instance.Revision,
&instance.CreatedAt,
&instance.UpdatedAt,
)
if err == sql.ErrNoRows {
return nil, entity.ErrInstanceNotFound
}
if err != nil {
return nil, fmt.Errorf("failed to get instance: %w", err)
}
// 解析 JSON Values
if len(valuesJSON) > 0 {
if err := json.Unmarshal(valuesJSON, &instance.Values); err != nil {
return nil, fmt.Errorf("failed to unmarshal values: %w", err)
}
}
if statusReason.Valid {
instance.StatusReason = statusReason.String
}
if lastOperation.Valid {
instance.LastOperation = entity.InstanceOperation(lastOperation.String)
}
if lastError.Valid {
instance.LastError = lastError.String
}
return instance, nil
}
// GetByClusterAndName 根据集群 ID 和名称获取实例
func (r *InstanceRepository) GetByClusterAndName(ctx context.Context, clusterID, name string) (*entity.Instance, error) {
query := `
SELECT id, cluster_id, name, namespace, registry_id, repository, chart, version,
description, values, values_yaml, status, status_reason, last_operation, last_error,
revision, created_at, updated_at
FROM instances
WHERE cluster_id = $1 AND name = $2
`
instance := &entity.Instance{}
var (
valuesJSON []byte
statusReason sql.NullString
lastOperation sql.NullString
lastError sql.NullString
)
err := r.db.conn.QueryRowContext(ctx, query, clusterID, name).Scan(
&instance.ID,
&instance.ClusterID,
&instance.Name,
&instance.Namespace,
&instance.RegistryID,
&instance.Repository,
&instance.Chart,
&instance.Version,
&instance.Description,
&valuesJSON,
&instance.ValuesYAML,
&instance.Status,
&statusReason,
&lastOperation,
&lastError,
&instance.Revision,
&instance.CreatedAt,
&instance.UpdatedAt,
)
if err == sql.ErrNoRows {
return nil, entity.ErrInstanceNotFound
}
if err != nil {
return nil, fmt.Errorf("failed to get instance: %w", err)
}
// 解析 JSON Values
if len(valuesJSON) > 0 {
if err := json.Unmarshal(valuesJSON, &instance.Values); err != nil {
return nil, fmt.Errorf("failed to unmarshal values: %w", err)
}
}
if statusReason.Valid {
instance.StatusReason = statusReason.String
}
if lastOperation.Valid {
instance.LastOperation = entity.InstanceOperation(lastOperation.String)
}
if lastError.Valid {
instance.LastError = lastError.String
}
return instance, nil
}
// Update 更新实例
func (r *InstanceRepository) Update(ctx context.Context, instance *entity.Instance) error {
instance.UpdatedAt = time.Now()
// 将 Values 转换为 JSON
valuesJSON, err := json.Marshal(instance.Values)
if err != nil {
return fmt.Errorf("failed to marshal values: %w", err)
}
query := `
UPDATE instances
SET cluster_id = $1, name = $2, namespace = $3, registry_id = $4, repository = $5,
chart = $6, version = $7, description = $8, values = $9, values_yaml = $10,
status = $11, status_reason = $12, last_operation = $13, last_error = $14,
revision = $15, updated_at = $16
WHERE id = $17
`
result, err := r.db.conn.ExecContext(ctx, query,
instance.ClusterID,
instance.Name,
instance.Namespace,
instance.RegistryID,
instance.Repository,
instance.Chart,
instance.Version,
instance.Description,
valuesJSON,
instance.ValuesYAML,
instance.Status,
instance.StatusReason,
instance.LastOperation,
instance.LastError,
instance.Revision,
instance.UpdatedAt,
instance.ID,
)
if err != nil {
return fmt.Errorf("failed to update instance: %w", err)
}
rows, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get affected rows: %w", err)
}
if rows == 0 {
return entity.ErrInstanceNotFound
}
return nil
}
// Delete 删除实例
func (r *InstanceRepository) Delete(ctx context.Context, id string) error {
query := `DELETE FROM instances WHERE id = $1`
result, err := r.db.conn.ExecContext(ctx, query, id)
if err != nil {
return fmt.Errorf("failed to delete instance: %w", err)
}
rows, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get affected rows: %w", err)
}
if rows == 0 {
return entity.ErrInstanceNotFound
}
return nil
}
// ListByCluster 列出指定集群的所有实例
func (r *InstanceRepository) ListByCluster(ctx context.Context, clusterID string) ([]*entity.Instance, error) {
query := `
SELECT id, cluster_id, name, namespace, registry_id, repository, chart, version,
description, values, values_yaml, status, status_reason, last_operation, last_error,
revision, created_at, updated_at
FROM instances
WHERE cluster_id = $1
ORDER BY created_at DESC
`
rows, err := r.db.conn.QueryContext(ctx, query, clusterID)
if err != nil {
return nil, fmt.Errorf("failed to list instances: %w", err)
}
defer rows.Close()
instances := make([]*entity.Instance, 0)
for rows.Next() {
instance := &entity.Instance{}
var (
valuesJSON []byte
statusReason sql.NullString
lastOperation sql.NullString
lastError sql.NullString
)
err := rows.Scan(
&instance.ID,
&instance.ClusterID,
&instance.Name,
&instance.Namespace,
&instance.RegistryID,
&instance.Repository,
&instance.Chart,
&instance.Version,
&instance.Description,
&valuesJSON,
&instance.ValuesYAML,
&instance.Status,
&statusReason,
&lastOperation,
&lastError,
&instance.Revision,
&instance.CreatedAt,
&instance.UpdatedAt,
)
if err != nil {
return nil, fmt.Errorf("failed to scan instance: %w", err)
}
// 解析 JSON Values
if len(valuesJSON) > 0 {
if err := json.Unmarshal(valuesJSON, &instance.Values); err != nil {
return nil, fmt.Errorf("failed to unmarshal values: %w", err)
}
}
if statusReason.Valid {
instance.StatusReason = statusReason.String
}
if lastOperation.Valid {
instance.LastOperation = entity.InstanceOperation(lastOperation.String)
}
if lastError.Valid {
instance.LastError = lastError.String
}
instances = append(instances, instance)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("rows iteration error: %w", err)
}
return instances, nil
}
// List 列出所有实例
func (r *InstanceRepository) List(ctx context.Context) ([]*entity.Instance, error) {
query := `
SELECT id, cluster_id, name, namespace, registry_id, repository, chart, version,
description, values, values_yaml, status, status_reason, last_operation, last_error,
revision, created_at, updated_at
FROM instances
ORDER BY created_at DESC
`
rows, err := r.db.conn.QueryContext(ctx, query)
if err != nil {
return nil, fmt.Errorf("failed to list instances: %w", err)
}
defer rows.Close()
instances := make([]*entity.Instance, 0)
for rows.Next() {
instance := &entity.Instance{}
var (
valuesJSON []byte
statusReason sql.NullString
lastOperation sql.NullString
lastError sql.NullString
)
err := rows.Scan(
&instance.ID,
&instance.ClusterID,
&instance.Name,
&instance.Namespace,
&instance.RegistryID,
&instance.Repository,
&instance.Chart,
&instance.Version,
&instance.Description,
&valuesJSON,
&instance.ValuesYAML,
&instance.Status,
&statusReason,
&lastOperation,
&lastError,
&instance.Revision,
&instance.CreatedAt,
&instance.UpdatedAt,
)
if err != nil {
return nil, fmt.Errorf("failed to scan instance: %w", err)
}
// 解析 JSON Values
if len(valuesJSON) > 0 {
if err := json.Unmarshal(valuesJSON, &instance.Values); err != nil {
return nil, fmt.Errorf("failed to unmarshal values: %w", err)
}
}
if statusReason.Valid {
instance.StatusReason = statusReason.String
}
if lastOperation.Valid {
instance.LastOperation = entity.InstanceOperation(lastOperation.String)
}
if lastError.Valid {
instance.LastError = lastError.String
}
instances = append(instances, instance)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("rows iteration error: %w", err)
}
return instances, nil
}

View File

@ -0,0 +1,257 @@
package postgres
import (
"context"
"database/sql"
"fmt"
"time"
"github.com/google/uuid"
"github.com/ocdp/cluster-service/internal/domain/entity"
"github.com/ocdp/cluster-service/internal/domain/repository"
"github.com/ocdp/cluster-service/internal/pkg/crypto"
)
// RegistryRepository PostgreSQL Registry 仓储实现
type RegistryRepository struct {
db *DB
encryptor crypto.Encryptor
}
// NewRegistryRepository 创建 PostgreSQL Registry 仓储
func NewRegistryRepository(db *DB, encryptor crypto.Encryptor) repository.RegistryRepository {
return &RegistryRepository{
db: db,
encryptor: encryptor,
}
}
// Create 创建 Registry
func (r *RegistryRepository) Create(ctx context.Context, registry *entity.Registry) error {
if registry.ID == "" {
registry.ID = uuid.New().String()
}
// 加密密码
encryptedPassword, err := r.encryptor.Encrypt(registry.Password)
if err != nil {
return fmt.Errorf("failed to encrypt password: %w", err)
}
query := `
INSERT INTO registries (id, name, url, description, username, password, insecure, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
`
_, err = r.db.conn.ExecContext(ctx, query,
registry.ID,
registry.Name,
registry.URL,
registry.Description,
registry.Username,
encryptedPassword,
registry.Insecure,
registry.CreatedAt,
registry.UpdatedAt,
)
if err != nil {
return fmt.Errorf("failed to create registry: %w", err)
}
return nil
}
// GetByID 根据 ID 获取 Registry
func (r *RegistryRepository) GetByID(ctx context.Context, id string) (*entity.Registry, error) {
query := `
SELECT id, name, url, description, username, password, insecure, created_at, updated_at
FROM registries
WHERE id = $1
`
registry := &entity.Registry{}
var encryptedPassword string
err := r.db.conn.QueryRowContext(ctx, query, id).Scan(
&registry.ID,
&registry.Name,
&registry.URL,
&registry.Description,
&registry.Username,
&encryptedPassword,
&registry.Insecure,
&registry.CreatedAt,
&registry.UpdatedAt,
)
if err == sql.ErrNoRows {
return nil, entity.ErrRegistryNotFound
}
if err != nil {
return nil, fmt.Errorf("failed to get registry: %w", err)
}
// 解密密码
registry.Password, err = r.encryptor.Decrypt(encryptedPassword)
if err != nil {
return nil, fmt.Errorf("failed to decrypt password: %w", err)
}
return registry, nil
}
// GetByName 根据名称获取 Registry
func (r *RegistryRepository) GetByName(ctx context.Context, name string) (*entity.Registry, error) {
query := `
SELECT id, name, url, description, username, password, insecure, created_at, updated_at
FROM registries
WHERE name = $1
`
registry := &entity.Registry{}
var encryptedPassword string
err := r.db.conn.QueryRowContext(ctx, query, name).Scan(
&registry.ID,
&registry.Name,
&registry.URL,
&registry.Description,
&registry.Username,
&encryptedPassword,
&registry.Insecure,
&registry.CreatedAt,
&registry.UpdatedAt,
)
if err == sql.ErrNoRows {
return nil, entity.ErrRegistryNotFound
}
if err != nil {
return nil, fmt.Errorf("failed to get registry: %w", err)
}
// 解密密码
registry.Password, err = r.encryptor.Decrypt(encryptedPassword)
if err != nil {
return nil, fmt.Errorf("failed to decrypt password: %w", err)
}
return registry, nil
}
// Update 更新 Registry
func (r *RegistryRepository) Update(ctx context.Context, registry *entity.Registry) error {
registry.UpdatedAt = time.Now()
// 加密密码
encryptedPassword, err := r.encryptor.Encrypt(registry.Password)
if err != nil {
return fmt.Errorf("failed to encrypt password: %w", err)
}
query := `
UPDATE registries
SET name = $1, url = $2, description = $3, username = $4, password = $5,
insecure = $6, updated_at = $7
WHERE id = $8
`
result, err := r.db.conn.ExecContext(ctx, query,
registry.Name,
registry.URL,
registry.Description,
registry.Username,
encryptedPassword,
registry.Insecure,
registry.UpdatedAt,
registry.ID,
)
if err != nil {
return fmt.Errorf("failed to update registry: %w", err)
}
rows, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get affected rows: %w", err)
}
if rows == 0 {
return entity.ErrRegistryNotFound
}
return nil
}
// Delete 删除 Registry
func (r *RegistryRepository) Delete(ctx context.Context, id string) error {
query := `DELETE FROM registries WHERE id = $1`
result, err := r.db.conn.ExecContext(ctx, query, id)
if err != nil {
return fmt.Errorf("failed to delete registry: %w", err)
}
rows, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get affected rows: %w", err)
}
if rows == 0 {
return entity.ErrRegistryNotFound
}
return nil
}
// List 列出所有 Registries
func (r *RegistryRepository) List(ctx context.Context) ([]*entity.Registry, error) {
query := `
SELECT id, name, url, description, username, password, insecure, created_at, updated_at
FROM registries
ORDER BY created_at DESC
`
rows, err := r.db.conn.QueryContext(ctx, query)
if err != nil {
return nil, fmt.Errorf("failed to list registries: %w", err)
}
defer rows.Close()
registries := make([]*entity.Registry, 0)
for rows.Next() {
registry := &entity.Registry{}
var encryptedPassword string
err := rows.Scan(
&registry.ID,
&registry.Name,
&registry.URL,
&registry.Description,
&registry.Username,
&encryptedPassword,
&registry.Insecure,
&registry.CreatedAt,
&registry.UpdatedAt,
)
if err != nil {
return nil, fmt.Errorf("failed to scan registry: %w", err)
}
// 解密密码
registry.Password, err = r.encryptor.Decrypt(encryptedPassword)
if err != nil {
return nil, fmt.Errorf("failed to decrypt password: %w", err)
}
registries = append(registries, registry)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("rows iteration error: %w", err)
}
return registries, nil
}

View File

@ -0,0 +1,204 @@
package postgres
import (
"context"
"database/sql"
"fmt"
"time"
"github.com/google/uuid"
"github.com/ocdp/cluster-service/internal/domain/entity"
"github.com/ocdp/cluster-service/internal/domain/repository"
)
// UserRepository PostgreSQL 用户仓储实现
type UserRepository struct {
db *DB
}
// NewUserRepository 创建 PostgreSQL 用户仓储
func NewUserRepository(db *DB) repository.UserRepository {
return &UserRepository{db: db}
}
// Create 创建用户
func (r *UserRepository) Create(ctx context.Context, user *entity.User) error {
if user.ID == "" {
user.ID = uuid.New().String()
}
query := `
INSERT INTO users (id, username, password_hash, email, revoked_after, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $7)
`
_, err := r.db.conn.ExecContext(ctx, query,
user.ID,
user.Username,
user.PasswordHash,
user.Email,
user.RevokedAfter,
user.CreatedAt,
user.UpdatedAt,
)
if err != nil {
return fmt.Errorf("failed to create user: %w", err)
}
return nil
}
// GetByID 根据 ID 获取用户
func (r *UserRepository) GetByID(ctx context.Context, id string) (*entity.User, error) {
query := `
SELECT id, username, password_hash, email, revoked_after, created_at, updated_at
FROM users
WHERE id = $1
`
user := &entity.User{}
err := r.db.conn.QueryRowContext(ctx, query, id).Scan(
&user.ID,
&user.Username,
&user.PasswordHash,
&user.Email,
&user.RevokedAfter,
&user.CreatedAt,
&user.UpdatedAt,
)
if err == sql.ErrNoRows {
return nil, entity.ErrUserNotFound
}
if err != nil {
return nil, fmt.Errorf("failed to get user: %w", err)
}
return user, nil
}
// GetByUsername 根据用户名获取用户
func (r *UserRepository) GetByUsername(ctx context.Context, username string) (*entity.User, error) {
query := `
SELECT id, username, password_hash, email, revoked_after, created_at, updated_at
FROM users
WHERE username = $1
`
user := &entity.User{}
err := r.db.conn.QueryRowContext(ctx, query, username).Scan(
&user.ID,
&user.Username,
&user.PasswordHash,
&user.Email,
&user.RevokedAfter,
&user.CreatedAt,
&user.UpdatedAt,
)
if err == sql.ErrNoRows {
return nil, entity.ErrUserNotFound
}
if err != nil {
return nil, fmt.Errorf("failed to get user: %w", err)
}
return user, nil
}
// Update 更新用户
func (r *UserRepository) Update(ctx context.Context, user *entity.User) error {
user.UpdatedAt = time.Now()
query := `
UPDATE users
SET username = $1, password_hash = $2, email = $3, revoked_after = $4, updated_at = $5
WHERE id = $6
`
result, err := r.db.conn.ExecContext(ctx, query,
user.Username,
user.PasswordHash,
user.Email,
user.RevokedAfter,
user.UpdatedAt,
user.ID,
)
if err != nil {
return fmt.Errorf("failed to update user: %w", err)
}
rows, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get affected rows: %w", err)
}
if rows == 0 {
return entity.ErrUserNotFound
}
return nil
}
// Delete 删除用户
func (r *UserRepository) Delete(ctx context.Context, id string) error {
query := `DELETE FROM users WHERE id = $1`
result, err := r.db.conn.ExecContext(ctx, query, id)
if err != nil {
return fmt.Errorf("failed to delete user: %w", err)
}
rows, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get affected rows: %w", err)
}
if rows == 0 {
return entity.ErrUserNotFound
}
return nil
}
// List 列出所有用户
func (r *UserRepository) List(ctx context.Context) ([]*entity.User, error) {
query := `
SELECT id, username, password_hash, email, revoked_after, created_at, updated_at
FROM users
ORDER BY created_at DESC
`
rows, err := r.db.conn.QueryContext(ctx, query)
if err != nil {
return nil, fmt.Errorf("failed to list users: %w", err)
}
defer rows.Close()
users := make([]*entity.User, 0)
for rows.Next() {
user := &entity.User{}
err := rows.Scan(
&user.ID,
&user.Username,
&user.PasswordHash,
&user.Email,
&user.RevokedAfter,
&user.CreatedAt,
&user.UpdatedAt,
)
if err != nil {
return nil, fmt.Errorf("failed to scan user: %w", err)
}
users = append(users, user)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("rows iteration error: %w", err)
}
return users, nil
}