ocdp v1
This commit is contained in:
220
backend/internal/adapter/output/factory.go
Normal file
220
backend/internal/adapter/output/factory.go
Normal file
@ -0,0 +1,220 @@
|
||||
package output
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
helmMock "github.com/ocdp/cluster-service/internal/adapter/output/helm/mock"
|
||||
helmReal "github.com/ocdp/cluster-service/internal/adapter/output/helm/real"
|
||||
"github.com/ocdp/cluster-service/internal/adapter/output/k8s"
|
||||
ociMock "github.com/ocdp/cluster-service/internal/adapter/output/oci/mock"
|
||||
ociReal "github.com/ocdp/cluster-service/internal/adapter/output/oci/real"
|
||||
"github.com/ocdp/cluster-service/internal/adapter/output/persistence/mock"
|
||||
"github.com/ocdp/cluster-service/internal/adapter/output/persistence/postgres"
|
||||
"github.com/ocdp/cluster-service/internal/domain/repository"
|
||||
"github.com/ocdp/cluster-service/internal/pkg/crypto"
|
||||
)
|
||||
|
||||
// AdapterMode 适配器模式
|
||||
type AdapterMode string
|
||||
|
||||
const (
|
||||
ModeMock AdapterMode = "mock" // Mock 模式(内存存储,用于开发调试)
|
||||
// 默认模式:连接真实 PostgreSQL 和服务(任何非 "mock" 的值都是默认模式)
|
||||
)
|
||||
|
||||
// AdapterFactory 适配器工厂
|
||||
// 用于创建所有 Output Adapters,支持 Mock 和真实实现切换
|
||||
type AdapterFactory struct {
|
||||
mode AdapterMode
|
||||
encryptor crypto.Encryptor // 加密器(用于敏感数据加密)
|
||||
|
||||
// 数据库连接字符串(非 Mock 模式需要)
|
||||
dbConnString string
|
||||
|
||||
// 数据库连接(非 Mock 模式)
|
||||
db *postgres.DB
|
||||
}
|
||||
|
||||
// NewAdapterFactory 创建适配器工厂
|
||||
func NewAdapterFactory(mode AdapterMode, encryptor crypto.Encryptor, dbConnString string) *AdapterFactory {
|
||||
return &AdapterFactory{
|
||||
mode: mode,
|
||||
encryptor: encryptor,
|
||||
dbConnString: dbConnString,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateUserRepository 创建用户仓储
|
||||
func (f *AdapterFactory) CreateUserRepository() (repository.UserRepository, error) {
|
||||
if f.mode == ModeMock {
|
||||
return mock.NewUserRepositoryMock(), nil
|
||||
}
|
||||
|
||||
// 默认:真实实现(PostgreSQL)
|
||||
if err := f.ensureDBConnection(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return postgres.NewUserRepository(f.db), nil
|
||||
}
|
||||
|
||||
// CreateClusterRepository 创建集群仓储
|
||||
func (f *AdapterFactory) CreateClusterRepository() (repository.ClusterRepository, error) {
|
||||
if f.mode == ModeMock {
|
||||
return mock.NewClusterRepositoryMock(f.encryptor), nil
|
||||
}
|
||||
|
||||
// 默认:真实实现(PostgreSQL)
|
||||
if err := f.ensureDBConnection(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return postgres.NewClusterRepository(f.db, f.encryptor), nil
|
||||
}
|
||||
|
||||
// CreateRegistryRepository 创建 Registry 仓储
|
||||
func (f *AdapterFactory) CreateRegistryRepository() (repository.RegistryRepository, error) {
|
||||
if f.mode == ModeMock {
|
||||
return mock.NewRegistryRepositoryMock(f.encryptor), nil
|
||||
}
|
||||
|
||||
// 默认:真实实现(PostgreSQL)
|
||||
if err := f.ensureDBConnection(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return postgres.NewRegistryRepository(f.db, f.encryptor), nil
|
||||
}
|
||||
|
||||
// CreateInstanceRepository 创建实例仓储
|
||||
func (f *AdapterFactory) CreateInstanceRepository() (repository.InstanceRepository, error) {
|
||||
if f.mode == ModeMock {
|
||||
return mock.NewInstanceRepositoryMock(), nil
|
||||
}
|
||||
|
||||
// 默认:真实实现(PostgreSQL)
|
||||
if err := f.ensureDBConnection(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return postgres.NewInstanceRepository(f.db), nil
|
||||
}
|
||||
|
||||
// CreateOCIClient 创建 OCI 客户端
|
||||
func (f *AdapterFactory) CreateOCIClient() (repository.OCIClient, error) {
|
||||
if f.mode == ModeMock {
|
||||
return ociMock.NewOCIClientMock(), nil
|
||||
}
|
||||
|
||||
// 默认:真实实现(ORAS SDK)
|
||||
return ociReal.NewOCIClient(), nil
|
||||
}
|
||||
|
||||
// CreateHelmClient 创建 Helm 客户端
|
||||
func (f *AdapterFactory) CreateHelmClient() (repository.HelmClient, error) {
|
||||
if f.mode == ModeMock {
|
||||
return helmMock.NewHelmClientMock(), nil
|
||||
}
|
||||
|
||||
// 默认:真实实现(Helm SDK)
|
||||
return helmReal.NewHelmClient(), nil
|
||||
}
|
||||
|
||||
// CreateMetricsClient 创建 Metrics 客户端
|
||||
func (f *AdapterFactory) CreateMetricsClient(clusterRepo repository.ClusterRepository) repository.MetricsClient {
|
||||
// Metrics client 总是使用真实的 Kubernetes API
|
||||
return k8s.NewMetricsClient(clusterRepo)
|
||||
}
|
||||
|
||||
// CreateEntryClient 创建实例入口查询客户端
|
||||
func (f *AdapterFactory) CreateEntryClient() repository.InstanceEntryClient {
|
||||
return k8s.NewEntryClient()
|
||||
}
|
||||
|
||||
// CreateAllRepositories 一次性创建所有 Repositories
|
||||
func (f *AdapterFactory) CreateAllRepositories() (*Repositories, error) {
|
||||
userRepo, err := f.CreateUserRepository()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create user repository: %w", err)
|
||||
}
|
||||
|
||||
clusterRepo, err := f.CreateClusterRepository()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create cluster repository: %w", err)
|
||||
}
|
||||
|
||||
registryRepo, err := f.CreateRegistryRepository()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create registry repository: %w", err)
|
||||
}
|
||||
|
||||
instanceRepo, err := f.CreateInstanceRepository()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create instance repository: %w", err)
|
||||
}
|
||||
|
||||
ociClient, err := f.CreateOCIClient()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create OCI client: %w", err)
|
||||
}
|
||||
|
||||
helmClient, err := f.CreateHelmClient()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create Helm client: %w", err)
|
||||
}
|
||||
|
||||
// 创建 Metrics client(依赖 clusterRepo)
|
||||
metricsClient := f.CreateMetricsClient(clusterRepo)
|
||||
entryClient := f.CreateEntryClient()
|
||||
|
||||
return &Repositories{
|
||||
UserRepo: userRepo,
|
||||
ClusterRepo: clusterRepo,
|
||||
RegistryRepo: registryRepo,
|
||||
InstanceRepo: instanceRepo,
|
||||
OCIClient: ociClient,
|
||||
HelmClient: helmClient,
|
||||
MetricsClient: metricsClient,
|
||||
EntryClient: entryClient,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Repositories 所有仓储的集合
|
||||
type Repositories struct {
|
||||
UserRepo repository.UserRepository
|
||||
ClusterRepo repository.ClusterRepository
|
||||
RegistryRepo repository.RegistryRepository
|
||||
InstanceRepo repository.InstanceRepository
|
||||
OCIClient repository.OCIClient
|
||||
HelmClient repository.HelmClient
|
||||
MetricsClient repository.MetricsClient
|
||||
EntryClient repository.InstanceEntryClient
|
||||
}
|
||||
|
||||
// ensureDBConnection 确保数据库连接已建立
|
||||
func (f *AdapterFactory) ensureDBConnection() error {
|
||||
if f.db != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if f.dbConnString == "" {
|
||||
return fmt.Errorf("database connection string is required (set DATABASE_URL environment variable)")
|
||||
}
|
||||
|
||||
db, err := postgres.NewDB(f.dbConnString)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to database: %w", err)
|
||||
}
|
||||
|
||||
// 初始化数据库 schema
|
||||
if err := db.InitSchema(); err != nil {
|
||||
return fmt.Errorf("failed to initialize database schema: %w", err)
|
||||
}
|
||||
|
||||
f.db = db
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close 关闭工厂资源
|
||||
func (f *AdapterFactory) Close() error {
|
||||
if f.db != nil {
|
||||
return f.db.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
196
backend/internal/adapter/output/helm/mock/helm_client_mock.go
Normal file
196
backend/internal/adapter/output/helm/mock/helm_client_mock.go
Normal file
@ -0,0 +1,196 @@
|
||||
package mock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/ocdp/cluster-service/internal/domain/entity"
|
||||
"github.com/ocdp/cluster-service/internal/domain/repository"
|
||||
)
|
||||
|
||||
// HelmClientMock Helm 客户端 Mock 实现
|
||||
type HelmClientMock struct {
|
||||
// Mock 数据存储
|
||||
releases map[string]map[string]*entity.Instance // clusterID -> releaseName -> instance
|
||||
history map[string]map[string][]*entity.ReleaseHistory // clusterID -> releaseName -> []history
|
||||
}
|
||||
|
||||
// NewHelmClientMock 创建 Mock 实现
|
||||
func NewHelmClientMock() repository.HelmClient {
|
||||
return &HelmClientMock{
|
||||
releases: make(map[string]map[string]*entity.Instance),
|
||||
history: make(map[string]map[string][]*entity.ReleaseHistory),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *HelmClientMock) Install(ctx context.Context, cluster *entity.Cluster, instance *entity.Instance) error {
|
||||
// 初始化集群数据
|
||||
if c.releases[cluster.ID] == nil {
|
||||
c.releases[cluster.ID] = make(map[string]*entity.Instance)
|
||||
c.history[cluster.ID] = make(map[string][]*entity.ReleaseHistory)
|
||||
}
|
||||
|
||||
// 检查是否已存在
|
||||
key := fmt.Sprintf("%s/%s", instance.Namespace, instance.Name)
|
||||
if _, exists := c.releases[cluster.ID][key]; exists {
|
||||
return entity.ErrInstanceExists
|
||||
}
|
||||
|
||||
// Mock 安装
|
||||
instance.Status = entity.StatusDeployed
|
||||
instance.Revision = 1
|
||||
instance.UpdatedAt = time.Now()
|
||||
|
||||
c.releases[cluster.ID][key] = instance
|
||||
|
||||
// 添加历史记录
|
||||
c.history[cluster.ID][key] = []*entity.ReleaseHistory{
|
||||
{
|
||||
Revision: 1,
|
||||
Updated: time.Now(),
|
||||
Status: entity.StatusDeployed,
|
||||
Chart: fmt.Sprintf("%s-%s", instance.Chart, instance.Version),
|
||||
AppVersion: instance.Version,
|
||||
Description: "Install complete",
|
||||
},
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *HelmClientMock) Upgrade(ctx context.Context, cluster *entity.Cluster, instance *entity.Instance) error {
|
||||
key := fmt.Sprintf("%s/%s", instance.Namespace, instance.Name)
|
||||
|
||||
existing, exists := c.releases[cluster.ID][key]
|
||||
if !exists {
|
||||
return entity.ErrInstanceNotFound
|
||||
}
|
||||
|
||||
// Mock 升级
|
||||
instance.Revision = existing.Revision + 1
|
||||
instance.Status = entity.StatusDeployed
|
||||
instance.UpdatedAt = time.Now()
|
||||
|
||||
c.releases[cluster.ID][key] = instance
|
||||
|
||||
// 添加历史记录
|
||||
history := &entity.ReleaseHistory{
|
||||
Revision: instance.Revision,
|
||||
Updated: time.Now(),
|
||||
Status: entity.StatusDeployed,
|
||||
Chart: fmt.Sprintf("%s-%s", instance.Chart, instance.Version),
|
||||
AppVersion: instance.Version,
|
||||
Description: "Upgrade complete",
|
||||
}
|
||||
c.history[cluster.ID][key] = append(c.history[cluster.ID][key], history)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *HelmClientMock) Uninstall(ctx context.Context, cluster *entity.Cluster, releaseName, namespace string) error {
|
||||
key := fmt.Sprintf("%s/%s", namespace, releaseName)
|
||||
|
||||
if _, exists := c.releases[cluster.ID][key]; !exists {
|
||||
return entity.ErrInstanceNotFound
|
||||
}
|
||||
|
||||
// Mock 卸载
|
||||
delete(c.releases[cluster.ID], key)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *HelmClientMock) Rollback(ctx context.Context, cluster *entity.Cluster, releaseName, namespace string, revision int) error {
|
||||
key := fmt.Sprintf("%s/%s", namespace, releaseName)
|
||||
|
||||
instance, exists := c.releases[cluster.ID][key]
|
||||
if !exists {
|
||||
return entity.ErrInstanceNotFound
|
||||
}
|
||||
|
||||
// 检查历史记录是否存在
|
||||
histories := c.history[cluster.ID][key]
|
||||
if revision > len(histories) || revision < 1 {
|
||||
return fmt.Errorf("revision %d not found", revision)
|
||||
}
|
||||
|
||||
// Mock 回滚
|
||||
instance.Revision = len(histories) + 1
|
||||
instance.Status = entity.StatusDeployed
|
||||
instance.UpdatedAt = time.Now()
|
||||
|
||||
c.releases[cluster.ID][key] = instance
|
||||
|
||||
// 添加回滚历史记录
|
||||
history := &entity.ReleaseHistory{
|
||||
Revision: instance.Revision,
|
||||
Updated: time.Now(),
|
||||
Status: entity.StatusDeployed,
|
||||
Chart: instance.Chart,
|
||||
AppVersion: instance.Version,
|
||||
Description: fmt.Sprintf("Rollback to revision %d", revision),
|
||||
}
|
||||
c.history[cluster.ID][key] = append(c.history[cluster.ID][key], history)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *HelmClientMock) GetStatus(ctx context.Context, cluster *entity.Cluster, releaseName, namespace string) (*entity.Instance, error) {
|
||||
key := fmt.Sprintf("%s/%s", namespace, releaseName)
|
||||
|
||||
instance, exists := c.releases[cluster.ID][key]
|
||||
if !exists {
|
||||
return nil, entity.ErrInstanceNotFound
|
||||
}
|
||||
|
||||
return instance, nil
|
||||
}
|
||||
|
||||
func (c *HelmClientMock) GetHistory(ctx context.Context, cluster *entity.Cluster, releaseName, namespace string) ([]*entity.ReleaseHistory, error) {
|
||||
key := fmt.Sprintf("%s/%s", namespace, releaseName)
|
||||
|
||||
if _, exists := c.releases[cluster.ID][key]; !exists {
|
||||
return nil, entity.ErrInstanceNotFound
|
||||
}
|
||||
|
||||
histories := c.history[cluster.ID][key]
|
||||
if histories == nil {
|
||||
return []*entity.ReleaseHistory{}, nil
|
||||
}
|
||||
|
||||
return histories, nil
|
||||
}
|
||||
|
||||
func (c *HelmClientMock) List(ctx context.Context, cluster *entity.Cluster, namespace string) ([]*entity.Instance, error) {
|
||||
clusterReleases := c.releases[cluster.ID]
|
||||
if clusterReleases == nil {
|
||||
return []*entity.Instance{}, nil
|
||||
}
|
||||
|
||||
instances := make([]*entity.Instance, 0)
|
||||
for key, instance := range clusterReleases {
|
||||
// 如果指定了 namespace,只返回该 namespace 的
|
||||
if namespace != "" && namespace != "all" {
|
||||
keyNamespace := instance.Namespace
|
||||
if keyNamespace != namespace {
|
||||
continue
|
||||
}
|
||||
}
|
||||
instances = append(instances, c.releases[cluster.ID][key])
|
||||
}
|
||||
|
||||
return instances, nil
|
||||
}
|
||||
|
||||
func (c *HelmClientMock) GetValues(ctx context.Context, cluster *entity.Cluster, releaseName, namespace string) (map[string]interface{}, error) {
|
||||
key := fmt.Sprintf("%s/%s", namespace, releaseName)
|
||||
|
||||
instance, exists := c.releases[cluster.ID][key]
|
||||
if !exists {
|
||||
return nil, entity.ErrInstanceNotFound
|
||||
}
|
||||
|
||||
return instance.Values, nil
|
||||
}
|
||||
|
||||
313
backend/internal/adapter/output/helm/real/helm_client.go
Normal file
313
backend/internal/adapter/output/helm/real/helm_client.go
Normal file
@ -0,0 +1,313 @@
|
||||
package real
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/ocdp/cluster-service/internal/domain/entity"
|
||||
"github.com/ocdp/cluster-service/internal/domain/repository"
|
||||
"helm.sh/helm/v3/pkg/action"
|
||||
"helm.sh/helm/v3/pkg/chart/loader"
|
||||
"helm.sh/helm/v3/pkg/cli"
|
||||
"helm.sh/helm/v3/pkg/release"
|
||||
"helm.sh/helm/v3/pkg/storage/driver"
|
||||
"k8s.io/apimachinery/pkg/api/meta"
|
||||
"k8s.io/client-go/discovery"
|
||||
"k8s.io/client-go/discovery/cached/memory"
|
||||
"k8s.io/client-go/rest"
|
||||
"k8s.io/client-go/restmapper"
|
||||
"k8s.io/client-go/tools/clientcmd"
|
||||
)
|
||||
|
||||
// HelmClient 真实的 Helm 客户端实现
|
||||
type HelmClient struct {
|
||||
settings *cli.EnvSettings
|
||||
}
|
||||
|
||||
// NewHelmClient 创建真实的 Helm 客户端
|
||||
func NewHelmClient() repository.HelmClient {
|
||||
return &HelmClient{
|
||||
settings: cli.New(),
|
||||
}
|
||||
}
|
||||
|
||||
// getActionConfig 获取 Helm action configuration
|
||||
func (h *HelmClient) getActionConfig(cluster *entity.Cluster, namespace string) (*action.Configuration, error) {
|
||||
actionConfig := new(action.Configuration)
|
||||
|
||||
// 创建临时 kubeconfig 文件
|
||||
kubeconfigContent := cluster.GetKubeConfig()
|
||||
tmpDir, err := os.MkdirTemp("", "helm-kubeconfig-*")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create temp dir: %w", err)
|
||||
}
|
||||
|
||||
kubeconfigPath := filepath.Join(tmpDir, "kubeconfig")
|
||||
if err := os.WriteFile(kubeconfigPath, []byte(kubeconfigContent), 0600); err != nil {
|
||||
return nil, fmt.Errorf("failed to write kubeconfig: %w", err)
|
||||
}
|
||||
|
||||
// 使用 kubeconfig 初始化 action config
|
||||
if err := actionConfig.Init(
|
||||
&kubeconfigGetter{kubeconfigPath: kubeconfigPath},
|
||||
namespace,
|
||||
os.Getenv("HELM_DRIVER"), // storage driver: configmap, secret, memory
|
||||
func(format string, v ...interface{}) {
|
||||
// Log function
|
||||
},
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize action config: %w", err)
|
||||
}
|
||||
|
||||
return actionConfig, nil
|
||||
}
|
||||
|
||||
// kubeconfigGetter implements RESTClientGetter
|
||||
type kubeconfigGetter struct {
|
||||
kubeconfigPath string
|
||||
}
|
||||
|
||||
func (k *kubeconfigGetter) ToRESTConfig() (*rest.Config, error) {
|
||||
return clientcmd.BuildConfigFromFlags("", k.kubeconfigPath)
|
||||
}
|
||||
|
||||
func (k *kubeconfigGetter) ToDiscoveryClient() (discovery.CachedDiscoveryInterface, error) {
|
||||
config, err := k.ToRESTConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
discoveryClient := discovery.NewDiscoveryClientForConfigOrDie(config)
|
||||
// Wrap in a memory cache
|
||||
return memory.NewMemCacheClient(discoveryClient), nil
|
||||
}
|
||||
|
||||
func (k *kubeconfigGetter) ToRESTMapper() (meta.RESTMapper, error) {
|
||||
discoveryClient, err := k.ToDiscoveryClient()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mapper := restmapper.NewDeferredDiscoveryRESTMapper(discoveryClient)
|
||||
return mapper, nil
|
||||
}
|
||||
|
||||
func (k *kubeconfigGetter) ToRawKubeConfigLoader() clientcmd.ClientConfig {
|
||||
return clientcmd.NewNonInteractiveDeferredLoadingClientConfig(
|
||||
&clientcmd.ClientConfigLoadingRules{ExplicitPath: k.kubeconfigPath},
|
||||
&clientcmd.ConfigOverrides{},
|
||||
)
|
||||
}
|
||||
|
||||
// Install 安装 Helm Chart
|
||||
func (h *HelmClient) Install(ctx context.Context, cluster *entity.Cluster, instance *entity.Instance) error {
|
||||
actionConfig, err := h.getActionConfig(cluster, instance.Namespace)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
install := action.NewInstall(actionConfig)
|
||||
install.ReleaseName = instance.Name
|
||||
install.Namespace = instance.Namespace
|
||||
install.CreateNamespace = true
|
||||
install.Wait = true
|
||||
install.Timeout = 5 * time.Minute
|
||||
|
||||
// 加载 Chart(从本地路径或 OCI registry)
|
||||
// 这里简化处理,假设 chart 已经被拉取到本地
|
||||
chartPath := fmt.Sprintf("/tmp/charts/%s-%s.tgz", instance.Chart, instance.Version)
|
||||
|
||||
chart, err := loader.Load(chartPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load chart: %w", err)
|
||||
}
|
||||
|
||||
// 执行安装
|
||||
rel, err := install.Run(chart, instance.Values)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to install release: %w", err)
|
||||
}
|
||||
|
||||
// 更新 revision(状态由调用方根据操作结果设置)
|
||||
instance.Revision = rel.Version
|
||||
// 注意:不在这里设置 Status,让调用方通过 MarkSuccess/MarkFailure 来设置
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Upgrade 升级 Helm Release
|
||||
func (h *HelmClient) Upgrade(ctx context.Context, cluster *entity.Cluster, instance *entity.Instance) error {
|
||||
actionConfig, err := h.getActionConfig(cluster, instance.Namespace)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
upgrade := action.NewUpgrade(actionConfig)
|
||||
upgrade.Namespace = instance.Namespace
|
||||
upgrade.Wait = true
|
||||
upgrade.Timeout = 5 * time.Minute
|
||||
|
||||
// 加载 Chart
|
||||
chartPath := fmt.Sprintf("/tmp/charts/%s-%s.tgz", instance.Chart, instance.Version)
|
||||
|
||||
chart, err := loader.Load(chartPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load chart: %w", err)
|
||||
}
|
||||
|
||||
// 执行升级
|
||||
rel, err := upgrade.Run(instance.Name, chart, instance.Values)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to upgrade release: %w", err)
|
||||
}
|
||||
|
||||
// 更新 revision(状态由调用方根据操作结果设置)
|
||||
instance.Revision = rel.Version
|
||||
// 注意:不在这里设置 Status,让调用方通过 MarkSuccess/MarkFailure 来设置
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Uninstall 卸载 Helm Release
|
||||
func (h *HelmClient) Uninstall(ctx context.Context, cluster *entity.Cluster, releaseName, namespace string) error {
|
||||
actionConfig, err := h.getActionConfig(cluster, namespace)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
uninstall := action.NewUninstall(actionConfig)
|
||||
uninstall.Wait = true
|
||||
uninstall.Timeout = 5 * time.Minute
|
||||
|
||||
_, err = uninstall.Run(releaseName)
|
||||
if err != nil {
|
||||
if errors.Is(err, driver.ErrReleaseNotFound) {
|
||||
return entity.ErrInstanceNotFound
|
||||
}
|
||||
return fmt.Errorf("failed to uninstall release: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Rollback 回滚 Helm Release
|
||||
func (h *HelmClient) Rollback(ctx context.Context, cluster *entity.Cluster, releaseName, namespace string, revision int) error {
|
||||
actionConfig, err := h.getActionConfig(cluster, namespace)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rollback := action.NewRollback(actionConfig)
|
||||
rollback.Version = revision
|
||||
rollback.Wait = true
|
||||
rollback.Timeout = 5 * time.Minute
|
||||
|
||||
if err := rollback.Run(releaseName); err != nil {
|
||||
return fmt.Errorf("failed to rollback release: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetStatus 获取 Release 状态
|
||||
func (h *HelmClient) GetStatus(ctx context.Context, cluster *entity.Cluster, releaseName, namespace string) (*entity.Instance, error) {
|
||||
actionConfig, err := h.getActionConfig(cluster, namespace)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
status := action.NewStatus(actionConfig)
|
||||
rel, err := status.Run(releaseName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get release status: %w", err)
|
||||
}
|
||||
|
||||
return h.convertReleaseToInstance(rel), nil
|
||||
}
|
||||
|
||||
// GetHistory 获取 Release 历史
|
||||
func (h *HelmClient) GetHistory(ctx context.Context, cluster *entity.Cluster, releaseName, namespace string) ([]*entity.ReleaseHistory, error) {
|
||||
actionConfig, err := h.getActionConfig(cluster, namespace)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
history := action.NewHistory(actionConfig)
|
||||
history.Max = 256
|
||||
|
||||
releases, err := history.Run(releaseName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get release history: %w", err)
|
||||
}
|
||||
|
||||
result := make([]*entity.ReleaseHistory, 0, len(releases))
|
||||
for _, rel := range releases {
|
||||
result = append(result, &entity.ReleaseHistory{
|
||||
Revision: rel.Version,
|
||||
Updated: rel.Info.LastDeployed.Time,
|
||||
Status: entity.InstanceStatus(rel.Info.Status),
|
||||
Chart: rel.Chart.Metadata.Name,
|
||||
AppVersion: rel.Chart.Metadata.AppVersion,
|
||||
Description: rel.Info.Description,
|
||||
})
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// List 列出集群中的所有 Releases
|
||||
func (h *HelmClient) List(ctx context.Context, cluster *entity.Cluster, namespace string) ([]*entity.Instance, error) {
|
||||
actionConfig, err := h.getActionConfig(cluster, namespace)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
list := action.NewList(actionConfig)
|
||||
if namespace == "" {
|
||||
list.AllNamespaces = true
|
||||
}
|
||||
|
||||
releases, err := list.Run()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list releases: %w", err)
|
||||
}
|
||||
|
||||
instances := make([]*entity.Instance, 0, len(releases))
|
||||
for _, rel := range releases {
|
||||
instances = append(instances, h.convertReleaseToInstance(rel))
|
||||
}
|
||||
|
||||
return instances, nil
|
||||
}
|
||||
|
||||
// GetValues 获取 Release 的 values
|
||||
func (h *HelmClient) GetValues(ctx context.Context, cluster *entity.Cluster, releaseName, namespace string) (map[string]interface{}, error) {
|
||||
actionConfig, err := h.getActionConfig(cluster, namespace)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
getValues := action.NewGetValues(actionConfig)
|
||||
values, err := getValues.Run(releaseName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get values: %w", err)
|
||||
}
|
||||
|
||||
return values, nil
|
||||
}
|
||||
|
||||
// convertReleaseToInstance 转换 Helm Release 为 Instance
|
||||
func (h *HelmClient) convertReleaseToInstance(rel *release.Release) *entity.Instance {
|
||||
return &entity.Instance{
|
||||
Name: rel.Name,
|
||||
Namespace: rel.Namespace,
|
||||
Chart: rel.Chart.Metadata.Name,
|
||||
Version: rel.Chart.Metadata.Version,
|
||||
Status: entity.InstanceStatus(rel.Info.Status),
|
||||
Revision: rel.Version,
|
||||
Values: rel.Config,
|
||||
UpdatedAt: rel.Info.LastDeployed.Time,
|
||||
}
|
||||
}
|
||||
321
backend/internal/adapter/output/k8s/entry_client.go
Normal file
321
backend/internal/adapter/output/k8s/entry_client.go
Normal file
@ -0,0 +1,321 @@
|
||||
package k8s
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
corev1 "k8s.io/api/core/v1"
|
||||
networkingv1 "k8s.io/api/networking/v1"
|
||||
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
|
||||
"k8s.io/apimachinery/pkg/util/intstr"
|
||||
"k8s.io/client-go/kubernetes"
|
||||
"k8s.io/client-go/rest"
|
||||
"k8s.io/client-go/tools/clientcmd"
|
||||
|
||||
"github.com/ocdp/cluster-service/internal/domain/entity"
|
||||
"github.com/ocdp/cluster-service/internal/domain/repository"
|
||||
)
|
||||
|
||||
// EntryClient 使用 Kubernetes API 查询实例相关 Service/Ingress
|
||||
type EntryClient struct{}
|
||||
|
||||
// NewEntryClient 创建 EntryClient
|
||||
func NewEntryClient() repository.InstanceEntryClient {
|
||||
return &EntryClient{}
|
||||
}
|
||||
|
||||
// ListEntries 查询实例的 Service/Ingress 入口
|
||||
func (c *EntryClient) ListEntries(
|
||||
ctx context.Context,
|
||||
cluster *entity.Cluster,
|
||||
instance *entity.Instance,
|
||||
) ([]*entity.InstanceEntry, error) {
|
||||
clientset, err := c.createClientset(cluster)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
selector := fmt.Sprintf("app.kubernetes.io/instance=%s", instance.Name)
|
||||
|
||||
serviceEntries, err := c.collectServiceEntries(ctx, clientset, instance, selector)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ingressEntries, err := c.collectIngressEntries(ctx, clientset, instance, selector)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return append(serviceEntries, ingressEntries...), nil
|
||||
}
|
||||
|
||||
func (c *EntryClient) collectServiceEntries(
|
||||
ctx context.Context,
|
||||
clientset *kubernetes.Clientset,
|
||||
instance *entity.Instance,
|
||||
selector string,
|
||||
) ([]*entity.InstanceEntry, error) {
|
||||
services, err := c.listServices(ctx, clientset, instance.Namespace, selector)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
entries := convertServicesToEntries(services, instance, selector == "")
|
||||
if len(entries) == 0 && selector != "" {
|
||||
// Fallback: widen the search scope and filter manually.
|
||||
services, err = c.listServices(ctx, clientset, instance.Namespace, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
entries = convertServicesToEntries(services, instance, true)
|
||||
}
|
||||
return entries, nil
|
||||
}
|
||||
|
||||
func (c *EntryClient) collectIngressEntries(
|
||||
ctx context.Context,
|
||||
clientset *kubernetes.Clientset,
|
||||
instance *entity.Instance,
|
||||
selector string,
|
||||
) ([]*entity.InstanceEntry, error) {
|
||||
ingresses, err := c.listIngresses(ctx, clientset, instance.Namespace, selector)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
entries := convertIngressesToEntries(ingresses, instance, selector == "")
|
||||
if len(entries) == 0 && selector != "" {
|
||||
ingresses, err = c.listIngresses(ctx, clientset, instance.Namespace, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
entries = convertIngressesToEntries(ingresses, instance, true)
|
||||
}
|
||||
return entries, nil
|
||||
}
|
||||
|
||||
func (c *EntryClient) listServices(
|
||||
ctx context.Context,
|
||||
clientset *kubernetes.Clientset,
|
||||
namespace, selector string,
|
||||
) ([]corev1.Service, error) {
|
||||
listOptions := metav1.ListOptions{}
|
||||
if selector != "" {
|
||||
listOptions.LabelSelector = selector
|
||||
}
|
||||
services, err := clientset.CoreV1().
|
||||
Services(namespace).
|
||||
List(ctx, listOptions)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list services: %w", err)
|
||||
}
|
||||
return services.Items, nil
|
||||
}
|
||||
|
||||
func (c *EntryClient) listIngresses(
|
||||
ctx context.Context,
|
||||
clientset *kubernetes.Clientset,
|
||||
namespace, selector string,
|
||||
) ([]networkingv1.Ingress, error) {
|
||||
listOptions := metav1.ListOptions{}
|
||||
if selector != "" {
|
||||
listOptions.LabelSelector = selector
|
||||
}
|
||||
ingresses, err := clientset.NetworkingV1().
|
||||
Ingresses(namespace).
|
||||
List(ctx, listOptions)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list ingresses: %w", err)
|
||||
}
|
||||
return ingresses.Items, nil
|
||||
}
|
||||
|
||||
func convertServicesToEntries(services []corev1.Service, instance *entity.Instance, enforceMatch bool) []*entity.InstanceEntry {
|
||||
entries := make([]*entity.InstanceEntry, 0, len(services))
|
||||
for _, svc := range services {
|
||||
if enforceMatch && !resourceMatchesInstance(svc.ObjectMeta, instance) {
|
||||
continue
|
||||
}
|
||||
entries = append(entries, convertServiceToEntry(&svc))
|
||||
}
|
||||
return entries
|
||||
}
|
||||
|
||||
func convertIngressesToEntries(ingresses []networkingv1.Ingress, instance *entity.Instance, enforceMatch bool) []*entity.InstanceEntry {
|
||||
entries := make([]*entity.InstanceEntry, 0, len(ingresses))
|
||||
for _, ing := range ingresses {
|
||||
if enforceMatch && !resourceMatchesInstance(ing.ObjectMeta, instance) {
|
||||
continue
|
||||
}
|
||||
entries = append(entries, convertIngressToEntry(&ing))
|
||||
}
|
||||
return entries
|
||||
}
|
||||
|
||||
func (c *EntryClient) createClientset(cluster *entity.Cluster) (*kubernetes.Clientset, error) {
|
||||
config, err := clientcmd.RESTConfigFromKubeConfig([]byte(cluster.GetKubeConfig()))
|
||||
if err != nil {
|
||||
config = &rest.Config{
|
||||
Host: cluster.Host,
|
||||
TLSClientConfig: rest.TLSClientConfig{
|
||||
CAData: []byte(cluster.CAData),
|
||||
CertData: []byte(cluster.CertData),
|
||||
KeyData: []byte(cluster.KeyData),
|
||||
},
|
||||
BearerToken: cluster.Token,
|
||||
}
|
||||
}
|
||||
|
||||
clientset, err := kubernetes.NewForConfig(config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create kubernetes client: %w", err)
|
||||
}
|
||||
|
||||
return clientset, nil
|
||||
}
|
||||
|
||||
func convertServiceToEntry(svc *corev1.Service) *entity.InstanceEntry {
|
||||
clusterIP := svc.Spec.ClusterIP
|
||||
if clusterIP == corev1.ClusterIPNone {
|
||||
clusterIP = ""
|
||||
}
|
||||
|
||||
lbIngress := make([]string, 0, len(svc.Status.LoadBalancer.Ingress))
|
||||
for _, ing := range svc.Status.LoadBalancer.Ingress {
|
||||
if ing.IP != "" {
|
||||
lbIngress = append(lbIngress, ing.IP)
|
||||
}
|
||||
if ing.Hostname != "" {
|
||||
lbIngress = append(lbIngress, ing.Hostname)
|
||||
}
|
||||
}
|
||||
|
||||
ports := make([]entity.InstanceEntryPort, 0, len(svc.Spec.Ports))
|
||||
for _, port := range svc.Spec.Ports {
|
||||
ports = append(ports, entity.InstanceEntryPort{
|
||||
Name: port.Name,
|
||||
Protocol: string(port.Protocol),
|
||||
Port: port.Port,
|
||||
TargetPort: intOrStringToString(port.TargetPort),
|
||||
NodePort: port.NodePort,
|
||||
})
|
||||
}
|
||||
|
||||
return &entity.InstanceEntry{
|
||||
Kind: "Service",
|
||||
Name: svc.Name,
|
||||
Namespace: svc.Namespace,
|
||||
Type: string(svc.Spec.Type),
|
||||
ClusterIP: clusterIP,
|
||||
ExternalIPs: append([]string{}, svc.Spec.ExternalIPs...),
|
||||
LoadBalancerIngress: lbIngress,
|
||||
Ports: ports,
|
||||
}
|
||||
}
|
||||
|
||||
func convertIngressToEntry(ing *networkingv1.Ingress) *entity.InstanceEntry {
|
||||
lbIngress := make([]string, 0, len(ing.Status.LoadBalancer.Ingress))
|
||||
for _, addr := range ing.Status.LoadBalancer.Ingress {
|
||||
if addr.IP != "" {
|
||||
lbIngress = append(lbIngress, addr.IP)
|
||||
}
|
||||
if addr.Hostname != "" {
|
||||
lbIngress = append(lbIngress, addr.Hostname)
|
||||
}
|
||||
}
|
||||
|
||||
hosts := make([]entity.InstanceEntryHost, 0, len(ing.Spec.Rules))
|
||||
for _, rule := range ing.Spec.Rules {
|
||||
hostEntry := entity.InstanceEntryHost{
|
||||
Host: rule.Host,
|
||||
}
|
||||
if rule.HTTP != nil {
|
||||
paths := make([]entity.InstanceEntryPath, 0, len(rule.HTTP.Paths))
|
||||
for _, path := range rule.HTTP.Paths {
|
||||
name := ""
|
||||
port := ""
|
||||
if path.Backend.Service != nil {
|
||||
name = path.Backend.Service.Name
|
||||
port = serviceBackendPortString(path.Backend.Service.Port)
|
||||
}
|
||||
paths = append(paths, entity.InstanceEntryPath{
|
||||
Path: path.Path,
|
||||
ServiceName: name,
|
||||
ServicePort: port,
|
||||
})
|
||||
}
|
||||
hostEntry.Paths = paths
|
||||
}
|
||||
hosts = append(hosts, hostEntry)
|
||||
}
|
||||
|
||||
tlsEntries := make([]entity.InstanceEntryTLS, 0, len(ing.Spec.TLS))
|
||||
for _, tls := range ing.Spec.TLS {
|
||||
tlsEntries = append(tlsEntries, entity.InstanceEntryTLS{
|
||||
Hosts: append([]string{}, tls.Hosts...),
|
||||
SecretName: tls.SecretName,
|
||||
})
|
||||
}
|
||||
|
||||
entryType := "Ingress"
|
||||
if ing.Spec.IngressClassName != nil {
|
||||
entryType = *ing.Spec.IngressClassName
|
||||
}
|
||||
|
||||
return &entity.InstanceEntry{
|
||||
Kind: "Ingress",
|
||||
Name: ing.Name,
|
||||
Namespace: ing.Namespace,
|
||||
Type: entryType,
|
||||
LoadBalancerIngress: lbIngress,
|
||||
Hosts: hosts,
|
||||
TLS: tlsEntries,
|
||||
}
|
||||
}
|
||||
|
||||
func intOrStringToString(v intstr.IntOrString) string {
|
||||
if v.Type == intstr.String {
|
||||
return v.StrVal
|
||||
}
|
||||
return fmt.Sprintf("%d", v.IntValue())
|
||||
}
|
||||
|
||||
func serviceBackendPortString(port networkingv1.ServiceBackendPort) string {
|
||||
if port.Name != "" {
|
||||
return port.Name
|
||||
}
|
||||
if port.Number != 0 {
|
||||
return fmt.Sprintf("%d", port.Number)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func resourceMatchesInstance(meta metav1.ObjectMeta, instance *entity.Instance) bool {
|
||||
if instance == nil {
|
||||
return false
|
||||
}
|
||||
labels := meta.GetLabels()
|
||||
if labels != nil {
|
||||
if labels["app.kubernetes.io/instance"] == instance.Name {
|
||||
return true
|
||||
}
|
||||
labelKeys := []string{"app", "app.kubernetes.io/name", "app.kubernetes.io/component", "release"}
|
||||
for _, key := range labelKeys {
|
||||
if labels[key] == instance.Name {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
annotations := meta.GetAnnotations()
|
||||
if annotations != nil {
|
||||
if annotations["meta.helm.sh/release-name"] == instance.Name {
|
||||
if ns := annotations["meta.helm.sh/release-namespace"]; ns == "" || ns == instance.Namespace {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
name := meta.GetName()
|
||||
if name == instance.Name || strings.HasPrefix(name, instance.Name+"-") {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
54
backend/internal/adapter/output/k8s/entry_client_test.go
Normal file
54
backend/internal/adapter/output/k8s/entry_client_test.go
Normal file
@ -0,0 +1,54 @@
|
||||
package k8s
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
|
||||
|
||||
"github.com/ocdp/cluster-service/internal/domain/entity"
|
||||
)
|
||||
|
||||
func TestResourceMatchesInstance(t *testing.T) {
|
||||
instance := &entity.Instance{
|
||||
Name: "demo",
|
||||
Namespace: "default",
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
meta metav1.ObjectMeta
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "matches by standard label",
|
||||
meta: metav1.ObjectMeta{Labels: map[string]string{
|
||||
"app.kubernetes.io/instance": "demo",
|
||||
}},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "matches by helm annotations",
|
||||
meta: metav1.ObjectMeta{Annotations: map[string]string{
|
||||
"meta.helm.sh/release-name": "demo",
|
||||
"meta.helm.sh/release-namespace": "default",
|
||||
}},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "matches by resource name prefix",
|
||||
meta: metav1.ObjectMeta{Name: "demo-nginx"},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "does not match unrelated resource",
|
||||
meta: metav1.ObjectMeta{Name: "other"},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
if got := resourceMatchesInstance(tc.meta, instance); got != tc.want {
|
||||
t.Fatalf("%s: expected %v, got %v", tc.name, tc.want, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
370
backend/internal/adapter/output/k8s/metrics_client.go
Normal file
370
backend/internal/adapter/output/k8s/metrics_client.go
Normal file
@ -0,0 +1,370 @@
|
||||
package k8s
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
corev1 "k8s.io/api/core/v1"
|
||||
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
|
||||
"k8s.io/client-go/kubernetes"
|
||||
"k8s.io/client-go/rest"
|
||||
"k8s.io/client-go/tools/clientcmd"
|
||||
metricsv "k8s.io/metrics/pkg/client/clientset/versioned"
|
||||
|
||||
"github.com/ocdp/cluster-service/internal/domain/entity"
|
||||
"github.com/ocdp/cluster-service/internal/domain/repository"
|
||||
)
|
||||
|
||||
// MetricsClient 实现从 Kubernetes 集群获取监控指标
|
||||
type MetricsClient struct {
|
||||
clusterRepo repository.ClusterRepository
|
||||
}
|
||||
|
||||
// NewMetricsClient 创建 MetricsClient
|
||||
func NewMetricsClient(clusterRepo repository.ClusterRepository) *MetricsClient {
|
||||
return &MetricsClient{
|
||||
clusterRepo: clusterRepo,
|
||||
}
|
||||
}
|
||||
|
||||
// GetClusterMetrics 获取集群监控指标
|
||||
func (c *MetricsClient) GetClusterMetrics(ctx context.Context, clusterID string) (*entity.ClusterMetrics, error) {
|
||||
// 获取集群信息
|
||||
cluster, err := c.clusterRepo.GetByID(ctx, clusterID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get cluster: %w", err)
|
||||
}
|
||||
|
||||
// 创建 Kubernetes 客户端
|
||||
clientset, metricsClient, err := c.createK8sClients(cluster)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create k8s client: %w", err)
|
||||
}
|
||||
|
||||
// 获取节点列表
|
||||
nodes, err := clientset.CoreV1().Nodes().List(ctx, metav1.ListOptions{})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list nodes: %w", err)
|
||||
}
|
||||
|
||||
// 获取所有 Pods
|
||||
pods, err := clientset.CoreV1().Pods("").List(ctx, metav1.ListOptions{})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list pods: %w", err)
|
||||
}
|
||||
|
||||
// 获取节点指标(CPU/内存使用情况)
|
||||
nodeMetrics, err := c.getNodeMetricsData(ctx, clientset, metricsClient, nodes.Items)
|
||||
if err != nil {
|
||||
// 如果无法获取 metrics,记录错误但继续
|
||||
fmt.Printf("Warning: failed to get node metrics: %v\n", err)
|
||||
}
|
||||
|
||||
// 计算集群级别汇总
|
||||
metrics := c.aggregateClusterMetrics(cluster, nodes.Items, pods.Items, nodeMetrics)
|
||||
|
||||
return metrics, nil
|
||||
}
|
||||
|
||||
// GetNodeMetrics 获取集群节点指标
|
||||
func (c *MetricsClient) GetNodeMetrics(ctx context.Context, clusterID string) ([]*entity.NodeMetrics, error) {
|
||||
cluster, err := c.clusterRepo.GetByID(ctx, clusterID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get cluster: %w", err)
|
||||
}
|
||||
|
||||
clientset, metricsClient, err := c.createK8sClients(cluster)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create k8s client: %w", err)
|
||||
}
|
||||
|
||||
nodes, err := clientset.CoreV1().Nodes().List(ctx, metav1.ListOptions{})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list nodes: %w", err)
|
||||
}
|
||||
|
||||
return c.getNodeMetricsData(ctx, clientset, metricsClient, nodes.Items)
|
||||
}
|
||||
|
||||
// createK8sClients 创建 Kubernetes 客户端
|
||||
func (c *MetricsClient) createK8sClients(cluster *entity.Cluster) (*kubernetes.Clientset, *metricsv.Clientset, error) {
|
||||
config, err := clientcmd.RESTConfigFromKubeConfig([]byte(cluster.GetKubeConfig()))
|
||||
if err != nil {
|
||||
// 如果无法从 kubeconfig 创建,尝试使用集群配置
|
||||
config = &rest.Config{
|
||||
Host: cluster.Host,
|
||||
TLSClientConfig: rest.TLSClientConfig{
|
||||
CAData: []byte(cluster.CAData),
|
||||
CertData: []byte(cluster.CertData),
|
||||
KeyData: []byte(cluster.KeyData),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
clientset, err := kubernetes.NewForConfig(config)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to create clientset: %w", err)
|
||||
}
|
||||
|
||||
metricsClient, err := metricsv.NewForConfig(config)
|
||||
if err != nil {
|
||||
// Metrics API 可能不可用,返回 nil 但不报错
|
||||
return clientset, nil, nil
|
||||
}
|
||||
|
||||
return clientset, metricsClient, nil
|
||||
}
|
||||
|
||||
// getNodeMetricsData 获取节点详细指标
|
||||
func (c *MetricsClient) getNodeMetricsData(
|
||||
ctx context.Context,
|
||||
clientset *kubernetes.Clientset,
|
||||
metricsClient *metricsv.Clientset,
|
||||
nodes []corev1.Node,
|
||||
) ([]*entity.NodeMetrics, error) {
|
||||
result := make([]*entity.NodeMetrics, 0, len(nodes))
|
||||
|
||||
for _, node := range nodes {
|
||||
nodeMetric := &entity.NodeMetrics{
|
||||
NodeName: node.Name,
|
||||
Status: getNodeStatus(&node),
|
||||
Role: getNodeRole(&node),
|
||||
Age: getNodeAge(&node),
|
||||
OSImage: node.Status.NodeInfo.OSImage,
|
||||
KernelVersion: node.Status.NodeInfo.KernelVersion,
|
||||
ContainerRuntime: node.Status.NodeInfo.ContainerRuntimeVersion,
|
||||
KubeletVersion: node.Status.NodeInfo.KubeletVersion,
|
||||
}
|
||||
|
||||
// CPU
|
||||
cpuCapacity := node.Status.Capacity.Cpu()
|
||||
cpuAllocatable := node.Status.Allocatable.Cpu()
|
||||
nodeMetric.CPUCapacity = fmt.Sprintf("%.2f cores", float64(cpuCapacity.MilliValue())/1000.0)
|
||||
nodeMetric.CPUAllocatable = fmt.Sprintf("%.2f cores", float64(cpuAllocatable.MilliValue())/1000.0)
|
||||
|
||||
// Memory
|
||||
memCapacity := node.Status.Capacity.Memory()
|
||||
memAllocatable := node.Status.Allocatable.Memory()
|
||||
nodeMetric.MemoryCapacity = formatBytes(memCapacity.Value())
|
||||
nodeMetric.MemoryAllocatable = formatBytes(memAllocatable.Value())
|
||||
|
||||
// GPU (从 node allocatable 中查找)
|
||||
if gpu, ok := node.Status.Allocatable["nvidia.com/gpu"]; ok {
|
||||
nodeMetric.GPUCapacity = int(gpu.Value())
|
||||
// 尝试获取 GPU 类型
|
||||
if gpuType, ok := node.Labels["nvidia.com/gpu.product"]; ok {
|
||||
nodeMetric.GPUType = gpuType
|
||||
}
|
||||
}
|
||||
|
||||
// 获取 Pod 数量
|
||||
pods, err := clientset.CoreV1().Pods("").List(ctx, metav1.ListOptions{
|
||||
FieldSelector: fmt.Sprintf("spec.nodeName=%s", node.Name),
|
||||
})
|
||||
if err == nil {
|
||||
nodeMetric.PodCount = len(pods.Items)
|
||||
}
|
||||
|
||||
// 如果有 metrics client,获取实时使用情况
|
||||
if metricsClient != nil {
|
||||
nodeMetricData, err := metricsClient.MetricsV1beta1().NodeMetricses().Get(ctx, node.Name, metav1.GetOptions{})
|
||||
if err == nil {
|
||||
// CPU 使用
|
||||
cpuUsage := nodeMetricData.Usage.Cpu()
|
||||
nodeMetric.CPUUsage = fmt.Sprintf("%.2f cores", float64(cpuUsage.MilliValue())/1000.0)
|
||||
if cpuAllocatable.MilliValue() > 0 {
|
||||
nodeMetric.CPUPercent = float64(cpuUsage.MilliValue()) / float64(cpuAllocatable.MilliValue()) * 100
|
||||
}
|
||||
|
||||
// Memory 使用
|
||||
memUsage := nodeMetricData.Usage.Memory()
|
||||
nodeMetric.MemoryUsage = formatBytes(memUsage.Value())
|
||||
if memAllocatable.Value() > 0 {
|
||||
nodeMetric.MemoryPercent = float64(memUsage.Value()) / float64(memAllocatable.Value()) * 100
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result = append(result, nodeMetric)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// aggregateClusterMetrics 聚合集群级别指标
|
||||
func (c *MetricsClient) aggregateClusterMetrics(
|
||||
cluster *entity.Cluster,
|
||||
nodes []corev1.Node,
|
||||
pods []corev1.Pod,
|
||||
nodeMetrics []*entity.NodeMetrics,
|
||||
) *entity.ClusterMetrics {
|
||||
metrics := &entity.ClusterMetrics{
|
||||
ClusterID: cluster.ID,
|
||||
ClusterName: cluster.Name,
|
||||
Status: "healthy",
|
||||
NodeCount: len(nodes),
|
||||
PodCount: len(pods),
|
||||
LastCheck: time.Now(),
|
||||
Nodes: make([]entity.NodeMetrics, 0),
|
||||
}
|
||||
|
||||
// 汇总资源
|
||||
var totalCPU, totalMem, usedCPU, usedMem int64
|
||||
var totalGPU, usedGPU int
|
||||
healthyNodes := 0
|
||||
|
||||
// 单机最大值
|
||||
var maxNodeCPU, maxNodeMem int64
|
||||
var maxNodeGPU int
|
||||
var maxNodeCPUUsage, maxNodeMemUsage, maxNodeGPUUsage float64
|
||||
|
||||
for i, node := range nodes {
|
||||
// CPU
|
||||
cpuCap := node.Status.Capacity.Cpu()
|
||||
totalCPU += cpuCap.MilliValue()
|
||||
if cpuCap.MilliValue() > maxNodeCPU {
|
||||
maxNodeCPU = cpuCap.MilliValue()
|
||||
}
|
||||
|
||||
// Memory
|
||||
memCap := node.Status.Capacity.Memory()
|
||||
totalMem += memCap.Value()
|
||||
if memCap.Value() > maxNodeMem {
|
||||
maxNodeMem = memCap.Value()
|
||||
}
|
||||
|
||||
// GPU
|
||||
if gpu, ok := node.Status.Allocatable["nvidia.com/gpu"]; ok {
|
||||
gpuCount := int(gpu.Value())
|
||||
totalGPU += gpuCount
|
||||
if gpuCount > maxNodeGPU {
|
||||
maxNodeGPU = gpuCount
|
||||
}
|
||||
}
|
||||
|
||||
// Node status
|
||||
if getNodeStatus(&node) == "Ready" {
|
||||
healthyNodes++
|
||||
}
|
||||
|
||||
// 从 nodeMetrics 获取使用情况
|
||||
if i < len(nodeMetrics) && nodeMetrics[i] != nil {
|
||||
metrics.Nodes = append(metrics.Nodes, *nodeMetrics[i])
|
||||
|
||||
// 更新单机最大使用率
|
||||
if nodeMetrics[i].CPUPercent > maxNodeCPUUsage {
|
||||
maxNodeCPUUsage = nodeMetrics[i].CPUPercent
|
||||
}
|
||||
if nodeMetrics[i].MemoryPercent > maxNodeMemUsage {
|
||||
maxNodeMemUsage = nodeMetrics[i].MemoryPercent
|
||||
}
|
||||
if nodeMetrics[i].GPUPercent > maxNodeGPUUsage {
|
||||
maxNodeGPUUsage = nodeMetrics[i].GPUPercent
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 计算集群 uptime(简化:使用最老节点的年龄)
|
||||
if len(nodes) > 0 {
|
||||
metrics.Uptime = getNodeAge(&nodes[0])
|
||||
}
|
||||
|
||||
// 格式化总资源
|
||||
metrics.TotalCPU = fmt.Sprintf("%.2f cores", float64(totalCPU)/1000.0)
|
||||
metrics.TotalMemory = formatBytes(totalMem)
|
||||
metrics.TotalGPU = totalGPU
|
||||
|
||||
// 格式化单机最大值
|
||||
metrics.MaxNodeCPU = fmt.Sprintf("%.2f cores", float64(maxNodeCPU)/1000.0)
|
||||
metrics.MaxNodeMemory = formatBytes(maxNodeMem)
|
||||
metrics.MaxNodeGPU = maxNodeGPU
|
||||
metrics.MaxNodeCPUUsage = maxNodeCPUUsage
|
||||
metrics.MaxNodeMemUsage = maxNodeMemUsage
|
||||
metrics.MaxNodeGPUUsage = maxNodeGPUUsage
|
||||
|
||||
// 使用情况(简化处理)
|
||||
if len(nodeMetrics) > 0 {
|
||||
for _, nm := range nodeMetrics {
|
||||
// 解析使用的 CPU 和内存
|
||||
// 这里简化处理,实际应该解析字符串
|
||||
usedCPU += int64(nm.CPUPercent * float64(totalCPU) / 100.0)
|
||||
usedMem += int64(nm.MemoryPercent * float64(totalMem) / 100.0)
|
||||
usedGPU += nm.GPUUsage
|
||||
}
|
||||
|
||||
if totalCPU > 0 {
|
||||
metrics.CPUUsage = float64(usedCPU) / float64(totalCPU) * 100
|
||||
}
|
||||
if totalMem > 0 {
|
||||
metrics.MemoryUsage = float64(usedMem) / float64(totalMem) * 100
|
||||
}
|
||||
if totalGPU > 0 {
|
||||
metrics.GPUUsage = float64(usedGPU) / float64(totalGPU) * 100
|
||||
}
|
||||
|
||||
metrics.UsedCPU = fmt.Sprintf("%.2f cores", float64(usedCPU)/1000.0)
|
||||
metrics.UsedMemory = formatBytes(usedMem)
|
||||
metrics.UsedGPU = usedGPU
|
||||
}
|
||||
|
||||
// 确定集群状态
|
||||
if healthyNodes == len(nodes) {
|
||||
metrics.Status = "healthy"
|
||||
} else if healthyNodes > 0 {
|
||||
metrics.Status = "warning"
|
||||
} else {
|
||||
metrics.Status = "error"
|
||||
}
|
||||
|
||||
return metrics
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
func getNodeStatus(node *corev1.Node) string {
|
||||
for _, condition := range node.Status.Conditions {
|
||||
if condition.Type == corev1.NodeReady {
|
||||
if condition.Status == corev1.ConditionTrue {
|
||||
return "Ready"
|
||||
}
|
||||
return "NotReady"
|
||||
}
|
||||
}
|
||||
return "Unknown"
|
||||
}
|
||||
|
||||
func getNodeRole(node *corev1.Node) string {
|
||||
if _, ok := node.Labels["node-role.kubernetes.io/control-plane"]; ok {
|
||||
return "control-plane"
|
||||
}
|
||||
if _, ok := node.Labels["node-role.kubernetes.io/master"]; ok {
|
||||
return "control-plane"
|
||||
}
|
||||
return "worker"
|
||||
}
|
||||
|
||||
func getNodeAge(node *corev1.Node) string {
|
||||
age := time.Since(node.CreationTimestamp.Time)
|
||||
days := int(age.Hours() / 24)
|
||||
hours := int(age.Hours()) % 24
|
||||
|
||||
if days > 0 {
|
||||
return fmt.Sprintf("%dd %dh", days, hours)
|
||||
}
|
||||
return fmt.Sprintf("%dh", hours)
|
||||
}
|
||||
|
||||
func formatBytes(bytes int64) string {
|
||||
const unit = 1024
|
||||
if bytes < unit {
|
||||
return fmt.Sprintf("%d B", bytes)
|
||||
}
|
||||
div, exp := int64(unit), 0
|
||||
for n := bytes / unit; n >= unit; n /= unit {
|
||||
div *= unit
|
||||
exp++
|
||||
}
|
||||
return fmt.Sprintf("%.1f %ciB", float64(bytes)/float64(div), "KMGTPE"[exp])
|
||||
}
|
||||
|
||||
284
backend/internal/adapter/output/oci/mock/oci_client_mock.go
Normal file
284
backend/internal/adapter/output/oci/mock/oci_client_mock.go
Normal file
@ -0,0 +1,284 @@
|
||||
package mock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ocdp/cluster-service/internal/domain/entity"
|
||||
"github.com/ocdp/cluster-service/internal/domain/repository"
|
||||
)
|
||||
|
||||
// OCIClientMock OCI Registry 客户端 Mock 实现
|
||||
type OCIClientMock struct {
|
||||
// Mock 数据存储
|
||||
repositories map[string][]string // registryID -> []repositoryName
|
||||
artifacts map[string]map[string][]*entity.Artifact // registryID -> repository -> []artifact
|
||||
}
|
||||
|
||||
// NewOCIClientMock 创建 Mock 实现
|
||||
func NewOCIClientMock() repository.OCIClient {
|
||||
mock := &OCIClientMock{
|
||||
repositories: make(map[string][]string),
|
||||
artifacts: make(map[string]map[string][]*entity.Artifact),
|
||||
}
|
||||
|
||||
// 初始化一些测试数据
|
||||
mock.initMockData()
|
||||
|
||||
return mock
|
||||
}
|
||||
|
||||
func (c *OCIClientMock) initMockData() {
|
||||
// Note: This method intentionally left empty
|
||||
// Mock data will be generated dynamically per registry to support any registry ID
|
||||
}
|
||||
|
||||
// initArtifactsForRegistry initializes mock artifacts for a given registry ID
|
||||
func (c *OCIClientMock) initArtifactsForRegistry(registryID string) {
|
||||
c.artifacts[registryID] = make(map[string][]*entity.Artifact)
|
||||
|
||||
// vllm-serve artifacts (OCI 格式的 Helm Chart)
|
||||
c.artifacts[registryID]["charts/vllm-serve"] = []*entity.Artifact{
|
||||
{
|
||||
RegistryID: registryID,
|
||||
Repository: "charts/vllm-serve",
|
||||
Tag: "0.1.0",
|
||||
Digest: "sha256:abc123def456",
|
||||
Type: entity.ArtifactTypeChart,
|
||||
Size: 12345678,
|
||||
MediaType: "application/vnd.oci.image.manifest.v1+json",
|
||||
ConfigType: "application/vnd.cncf.helm.config.v1+json", // Helm Chart 的 config type
|
||||
Annotations: map[string]string{
|
||||
"org.opencontainers.image.title": "vllm-serve",
|
||||
"org.opencontainers.image.version": "0.1.0",
|
||||
},
|
||||
CreatedAt: time.Now().Add(-24 * time.Hour),
|
||||
},
|
||||
{
|
||||
RegistryID: registryID,
|
||||
Repository: "charts/vllm-serve",
|
||||
Tag: "0.2.0",
|
||||
Digest: "sha256:xyz789uvw012",
|
||||
Type: entity.ArtifactTypeChart,
|
||||
Size: 13456789,
|
||||
MediaType: "application/vnd.oci.image.manifest.v1+json",
|
||||
ConfigType: "application/vnd.cncf.helm.config.v1+json", // Helm Chart 的 config type
|
||||
Annotations: map[string]string{
|
||||
"org.opencontainers.image.title": "vllm-serve",
|
||||
"org.opencontainers.image.version": "0.2.0",
|
||||
},
|
||||
CreatedAt: time.Now(),
|
||||
},
|
||||
}
|
||||
|
||||
// nginx artifacts (OCI 格式的 Helm Chart)
|
||||
c.artifacts[registryID]["charts/nginx"] = []*entity.Artifact{
|
||||
{
|
||||
RegistryID: registryID,
|
||||
Repository: "charts/nginx",
|
||||
Tag: "1.0.0",
|
||||
Digest: "sha256:nginx123456",
|
||||
Type: entity.ArtifactTypeChart,
|
||||
Size: 5678901,
|
||||
MediaType: "application/vnd.oci.image.manifest.v1+json",
|
||||
ConfigType: "application/vnd.cncf.helm.config.v1+json", // Helm Chart 的 config type
|
||||
Annotations: map[string]string{
|
||||
"org.opencontainers.image.title": "nginx",
|
||||
},
|
||||
CreatedAt: time.Now().Add(-48 * time.Hour),
|
||||
},
|
||||
}
|
||||
|
||||
// redis artifacts (OCI 格式的 Helm Chart)
|
||||
c.artifacts[registryID]["charts/redis"] = []*entity.Artifact{
|
||||
{
|
||||
RegistryID: registryID,
|
||||
Repository: "charts/redis",
|
||||
Tag: "6.2.0",
|
||||
Digest: "sha256:redis789abc",
|
||||
Type: entity.ArtifactTypeChart,
|
||||
Size: 8901234,
|
||||
MediaType: "application/vnd.oci.image.manifest.v1+json",
|
||||
ConfigType: "application/vnd.cncf.helm.config.v1+json", // Helm Chart 的 config type
|
||||
Annotations: map[string]string{
|
||||
"org.opencontainers.image.title": "redis",
|
||||
"org.opencontainers.image.version": "6.2.0",
|
||||
},
|
||||
CreatedAt: time.Now().Add(-72 * time.Hour),
|
||||
},
|
||||
}
|
||||
|
||||
// alpine artifacts (Docker Image)
|
||||
c.artifacts[registryID]["library/alpine"] = []*entity.Artifact{
|
||||
{
|
||||
RegistryID: registryID,
|
||||
Repository: "library/alpine",
|
||||
Tag: "3.18",
|
||||
Digest: "sha256:alpine123",
|
||||
Type: entity.ArtifactTypeImage,
|
||||
Size: 2345678,
|
||||
MediaType: "application/vnd.docker.distribution.manifest.v2+json",
|
||||
ConfigType: "application/vnd.docker.container.image.v1+json", // Docker Image 的 config type
|
||||
Annotations: map[string]string{
|
||||
"org.opencontainers.image.title": "alpine",
|
||||
"org.opencontainers.image.version": "3.18",
|
||||
},
|
||||
CreatedAt: time.Now().Add(-96 * time.Hour),
|
||||
},
|
||||
{
|
||||
RegistryID: registryID,
|
||||
Repository: "library/alpine",
|
||||
Tag: "latest",
|
||||
Digest: "sha256:alpine456",
|
||||
Type: entity.ArtifactTypeImage,
|
||||
Size: 2456789,
|
||||
MediaType: "application/vnd.docker.distribution.manifest.v2+json",
|
||||
ConfigType: "application/vnd.docker.container.image.v1+json", // Docker Image 的 config type
|
||||
Annotations: map[string]string{
|
||||
"org.opencontainers.image.title": "alpine",
|
||||
},
|
||||
CreatedAt: time.Now().Add(-24 * time.Hour),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *OCIClientMock) ListRepositories(ctx context.Context, registry *entity.Registry) ([]string, error) {
|
||||
// Check if we have cached data for this registry
|
||||
repos, exists := c.repositories[registry.ID]
|
||||
if !exists {
|
||||
// Generate mock data dynamically for any registry
|
||||
repos = []string{
|
||||
"charts/vllm-serve",
|
||||
"charts/nginx",
|
||||
"charts/redis",
|
||||
"library/alpine",
|
||||
}
|
||||
c.repositories[registry.ID] = repos
|
||||
|
||||
// Also initialize artifacts for this registry
|
||||
c.initArtifactsForRegistry(registry.ID)
|
||||
}
|
||||
return repos, nil
|
||||
}
|
||||
|
||||
func (c *OCIClientMock) ListArtifacts(ctx context.Context, registry *entity.Registry, repository, mediaTypeFilter string) ([]*entity.Artifact, error) {
|
||||
regArtifacts, exists := c.artifacts[registry.ID]
|
||||
if !exists {
|
||||
// Initialize artifacts for this registry if not exists
|
||||
c.initArtifactsForRegistry(registry.ID)
|
||||
regArtifacts = c.artifacts[registry.ID]
|
||||
}
|
||||
|
||||
artifacts, exists := regArtifacts[repository]
|
||||
if !exists {
|
||||
return []*entity.Artifact{}, nil
|
||||
}
|
||||
|
||||
// 应用 mediaType 过滤
|
||||
if mediaTypeFilter == "" || mediaTypeFilter == "all" {
|
||||
return artifacts, nil
|
||||
}
|
||||
|
||||
filtered := make([]*entity.Artifact, 0)
|
||||
filter := strings.ToLower(strings.TrimSpace(mediaTypeFilter))
|
||||
|
||||
for _, artifact := range artifacts {
|
||||
switch filter {
|
||||
case "chart":
|
||||
if artifact.Type == entity.ArtifactTypeChart {
|
||||
filtered = append(filtered, artifact)
|
||||
}
|
||||
case "image":
|
||||
if artifact.Type == entity.ArtifactTypeImage {
|
||||
filtered = append(filtered, artifact)
|
||||
}
|
||||
case "other":
|
||||
if artifact.Type == entity.ArtifactTypeOther {
|
||||
filtered = append(filtered, artifact)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return filtered, nil
|
||||
}
|
||||
|
||||
func (c *OCIClientMock) GetArtifact(ctx context.Context, registry *entity.Registry, repository, reference string) (*entity.Artifact, error) {
|
||||
regArtifacts, exists := c.artifacts[registry.ID]
|
||||
if !exists {
|
||||
// Initialize artifacts for this registry if not exists
|
||||
c.initArtifactsForRegistry(registry.ID)
|
||||
regArtifacts = c.artifacts[registry.ID]
|
||||
}
|
||||
|
||||
artifacts, exists := regArtifacts[repository]
|
||||
if !exists {
|
||||
return nil, entity.ErrArtifactNotFound
|
||||
}
|
||||
|
||||
// 根据 tag 或 digest 查找
|
||||
for _, artifact := range artifacts {
|
||||
if artifact.Tag == reference || artifact.Digest == reference {
|
||||
return artifact, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, entity.ErrArtifactNotFound
|
||||
}
|
||||
|
||||
func (c *OCIClientMock) GetValuesSchema(ctx context.Context, registry *entity.Registry, repository, reference string) (string, error) {
|
||||
artifact, err := c.GetArtifact(ctx, registry, repository, reference)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if !artifact.IsChart() {
|
||||
return "", fmt.Errorf("not a helm chart")
|
||||
}
|
||||
|
||||
// 返回 Mock values schema
|
||||
mockSchema := `{
|
||||
"$schema": "http://json-schema.org/draft-07/schema#",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"replicaCount": {
|
||||
"type": "integer",
|
||||
"default": 1
|
||||
},
|
||||
"image": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"repository": {
|
||||
"type": "string"
|
||||
},
|
||||
"tag": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}`
|
||||
return mockSchema, nil
|
||||
}
|
||||
|
||||
func (c *OCIClientMock) PullArtifact(ctx context.Context, registry *entity.Registry, repository, reference, destPath string) error {
|
||||
_, err := c.GetArtifact(ctx, registry, repository, reference)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Mock 实现,不实际下载
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *OCIClientMock) PushArtifact(ctx context.Context, registry *entity.Registry, repository, tag, sourcePath string) error {
|
||||
// Mock 实现,不实际上传
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *OCIClientMock) CheckHealth(ctx context.Context, registry *entity.Registry) error {
|
||||
// Mock 实现,总是返回健康
|
||||
return nil
|
||||
}
|
||||
|
||||
468
backend/internal/adapter/output/oci/real/oci_client.go
Normal file
468
backend/internal/adapter/output/oci/real/oci_client.go
Normal file
@ -0,0 +1,468 @@
|
||||
package real
|
||||
|
||||
import (
|
||||
"archive/tar"
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/ocdp/cluster-service/internal/domain/entity"
|
||||
"github.com/ocdp/cluster-service/internal/domain/repository"
|
||||
"github.com/opencontainers/go-digest"
|
||||
ocispec "github.com/opencontainers/image-spec/specs-go/v1"
|
||||
"oras.land/oras-go/v2/registry/remote"
|
||||
"oras.land/oras-go/v2/registry/remote/auth"
|
||||
)
|
||||
|
||||
// OCIClient 真实的 OCI 客户端实现(使用 ORAS)
|
||||
type OCIClient struct {
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// NewOCIClient 创建真实的 OCI 客户端
|
||||
func NewOCIClient() repository.OCIClient {
|
||||
return &OCIClient{
|
||||
httpClient: &http.Client{},
|
||||
}
|
||||
}
|
||||
|
||||
// getRegistry 创建 ORAS Registry 客户端
|
||||
func (c *OCIClient) getRegistry(reg *entity.Registry) (*remote.Registry, error) {
|
||||
// 解析 Registry URL
|
||||
registryURL := strings.TrimPrefix(reg.URL, "https://")
|
||||
registryURL = strings.TrimPrefix(registryURL, "http://")
|
||||
|
||||
registry, err := remote.NewRegistry(registryURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create registry client: %w", err)
|
||||
}
|
||||
|
||||
// 设置认证
|
||||
if reg.Username != "" && reg.Password != "" {
|
||||
registry.Client = &auth.Client{
|
||||
Client: c.httpClient,
|
||||
Credential: auth.StaticCredential(registryURL, auth.Credential{
|
||||
Username: reg.Username,
|
||||
Password: reg.Password,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
// 设置 PlainHTTP(如果是 insecure)
|
||||
registry.PlainHTTP = reg.Insecure
|
||||
|
||||
return registry, nil
|
||||
}
|
||||
|
||||
// ListRepositories 列出 Registry 中的所有 repositories
|
||||
func (c *OCIClient) ListRepositories(ctx context.Context, registry *entity.Registry) ([]string, error) {
|
||||
reg, err := c.getRegistry(registry)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
repositories := make([]string, 0)
|
||||
|
||||
err = reg.Repositories(ctx, "", func(repos []string) error {
|
||||
repositories = append(repositories, repos...)
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list repositories: %w", err)
|
||||
}
|
||||
|
||||
return repositories, nil
|
||||
}
|
||||
|
||||
// ListArtifacts 列出指定 repository 的所有 artifacts
|
||||
// mediaTypeFilter: "all", "image", "chart", "other" - 使用模糊匹配过滤
|
||||
func (c *OCIClient) ListArtifacts(ctx context.Context, registry *entity.Registry, repository, mediaTypeFilter string) ([]*entity.Artifact, error) {
|
||||
reg, err := c.getRegistry(registry)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
repo, err := reg.Repository(ctx, repository)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get repository: %w", err)
|
||||
}
|
||||
|
||||
artifacts := make([]*entity.Artifact, 0)
|
||||
|
||||
err = repo.Tags(ctx, "", func(tags []string) error {
|
||||
for _, tag := range tags {
|
||||
// 获取 manifest 以获取更多信息
|
||||
desc, err := repo.Resolve(ctx, tag)
|
||||
if err != nil {
|
||||
// 跳过无法解析的 tag
|
||||
continue
|
||||
}
|
||||
|
||||
artifact := &entity.Artifact{
|
||||
Repository: repository,
|
||||
Tag: tag,
|
||||
Digest: desc.Digest.String(),
|
||||
MediaType: desc.MediaType,
|
||||
Size: desc.Size,
|
||||
}
|
||||
|
||||
// 尝试获取 config.mediaType 以更准确判断类型
|
||||
if manifestBytes, err := repo.Fetch(ctx, desc); err == nil {
|
||||
defer manifestBytes.Close()
|
||||
if manifestData, err := io.ReadAll(manifestBytes); err == nil {
|
||||
var manifest map[string]interface{}
|
||||
if err := json.Unmarshal(manifestData, &manifest); err == nil {
|
||||
// 获取 config.mediaType
|
||||
if config, ok := manifest["config"].(map[string]interface{}); ok {
|
||||
if configMediaType, ok := config["mediaType"].(string); ok {
|
||||
artifact.ConfigType = configMediaType
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 使用智能类型判断(综合多种信息)
|
||||
artifact.DetermineType()
|
||||
|
||||
// 应用 mediaType 过滤
|
||||
if c.shouldIncludeArtifact(artifact, mediaTypeFilter) {
|
||||
artifacts = append(artifacts, artifact)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list artifacts: %w", err)
|
||||
}
|
||||
|
||||
return artifacts, nil
|
||||
}
|
||||
|
||||
// shouldIncludeArtifact 判断是否应该包含该 artifact
|
||||
func (c *OCIClient) shouldIncludeArtifact(artifact *entity.Artifact, filter string) bool {
|
||||
// 默认或 "all" 返回所有
|
||||
if filter == "" || filter == "all" {
|
||||
return true
|
||||
}
|
||||
|
||||
filter = strings.ToLower(strings.TrimSpace(filter))
|
||||
|
||||
switch filter {
|
||||
case "chart":
|
||||
// 只返回 Helm Charts
|
||||
return artifact.Type == entity.ArtifactTypeChart
|
||||
case "image":
|
||||
// 返回 Docker 或 OCI images
|
||||
return artifact.Type == entity.ArtifactTypeImage
|
||||
case "other":
|
||||
// 返回其他类型
|
||||
return artifact.Type == entity.ArtifactTypeOther
|
||||
default:
|
||||
// 未知的 filter,返回所有
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// GetArtifact 获取指定 artifact 的详细信息
|
||||
func (c *OCIClient) GetArtifact(ctx context.Context, registry *entity.Registry, repository, reference string) (*entity.Artifact, error) {
|
||||
reg, err := c.getRegistry(registry)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
repo, err := reg.Repository(ctx, repository)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get repository: %w", err)
|
||||
}
|
||||
|
||||
// 解析 reference
|
||||
desc, err := repo.Resolve(ctx, reference)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to resolve artifact: %w", err)
|
||||
}
|
||||
|
||||
// 获取 manifest
|
||||
manifestBytes, err := repo.Fetch(ctx, desc)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch manifest: %w", err)
|
||||
}
|
||||
defer manifestBytes.Close()
|
||||
|
||||
manifestData, err := io.ReadAll(manifestBytes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read manifest: %w", err)
|
||||
}
|
||||
|
||||
// 解析 manifest 获取配置信息
|
||||
var manifest map[string]interface{}
|
||||
if err := json.Unmarshal(manifestData, &manifest); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal manifest: %w", err)
|
||||
}
|
||||
|
||||
artifact := &entity.Artifact{
|
||||
Repository: repository,
|
||||
Tag: reference,
|
||||
Digest: desc.Digest.String(),
|
||||
MediaType: desc.MediaType,
|
||||
Size: desc.Size,
|
||||
Annotations: make(map[string]string),
|
||||
}
|
||||
|
||||
// 获取 config.mediaType 和 annotations
|
||||
if config, ok := manifest["config"].(map[string]interface{}); ok {
|
||||
// 获取 config.mediaType(用于准确的类型判断)
|
||||
if configMediaType, ok := config["mediaType"].(string); ok {
|
||||
artifact.ConfigType = configMediaType
|
||||
}
|
||||
|
||||
// 获取 annotations
|
||||
if annotations, ok := config["annotations"].(map[string]interface{}); ok {
|
||||
for k, v := range annotations {
|
||||
if str, ok := v.(string); ok {
|
||||
artifact.Annotations[k] = str
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 使用智能类型判断(综合 ConfigType, Annotations, Repository 名称等)
|
||||
artifact.DetermineType()
|
||||
|
||||
return artifact, nil
|
||||
}
|
||||
|
||||
// GetValuesSchema 获取 Helm Chart 的 values schema
|
||||
func (c *OCIClient) GetValuesSchema(ctx context.Context, registry *entity.Registry, repository, reference string) (string, error) {
|
||||
reg, err := c.getRegistry(registry)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
repo, err := reg.Repository(ctx, repository)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get repository: %w", err)
|
||||
}
|
||||
|
||||
// 解析 reference (tag 或 digest)
|
||||
desc, err := repo.Resolve(ctx, reference)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to resolve artifact: %w", err)
|
||||
}
|
||||
|
||||
manifestReader, err := repo.Fetch(ctx, desc)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to fetch manifest: %w", err)
|
||||
}
|
||||
defer manifestReader.Close()
|
||||
|
||||
manifestBytes, err := io.ReadAll(manifestReader)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read manifest: %w", err)
|
||||
}
|
||||
|
||||
var manifest ocispec.Manifest
|
||||
if err := json.Unmarshal(manifestBytes, &manifest); err != nil {
|
||||
return "", fmt.Errorf("failed to unmarshal manifest: %w", err)
|
||||
}
|
||||
|
||||
// 优先查找是否存在独立的 values schema layer(一些 registry 会将 values.schema.json 作为单独的 layer 存储)
|
||||
var valuesSchemaLayer *ocispec.Descriptor
|
||||
for i := range manifest.Layers {
|
||||
layer := manifest.Layers[i]
|
||||
mediaType := strings.ToLower(layer.MediaType)
|
||||
|
||||
if strings.Contains(mediaType, "helm.values.schema") ||
|
||||
strings.Contains(mediaType, "values.schema") {
|
||||
valuesSchemaLayer = &manifest.Layers[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 如果存在独立的 values schema layer,直接返回
|
||||
if valuesSchemaLayer != nil {
|
||||
reader, err := repo.Fetch(ctx, *valuesSchemaLayer)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to fetch values schema layer: %w", err)
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
data, err := io.ReadAll(reader)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read values schema layer: %w", err)
|
||||
}
|
||||
|
||||
if len(data) == 0 {
|
||||
return "", entity.ErrValuesSchemaNotFound
|
||||
}
|
||||
|
||||
return string(data), nil
|
||||
}
|
||||
|
||||
// 回退:查找 Helm Chart layer(tar+gzip 包含 chart 内容)并从中读取 values.schema.json
|
||||
var chartLayer *ocispec.Descriptor
|
||||
for i := range manifest.Layers {
|
||||
layer := manifest.Layers[i]
|
||||
if strings.Contains(layer.MediaType, "cncf.helm.chart") ||
|
||||
strings.Contains(layer.MediaType, "helm.chart.content") {
|
||||
chartLayer = &manifest.Layers[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if chartLayer == nil {
|
||||
return "", entity.ErrValuesSchemaNotFound
|
||||
}
|
||||
|
||||
if chartLayer.Digest == "" {
|
||||
return "", fmt.Errorf("chart layer digest is empty")
|
||||
}
|
||||
if _, err := digest.Parse(string(chartLayer.Digest)); err != nil {
|
||||
return "", fmt.Errorf("invalid chart layer digest: %w", err)
|
||||
}
|
||||
|
||||
layerReader, err := repo.Fetch(ctx, *chartLayer)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to fetch chart layer: %w", err)
|
||||
}
|
||||
defer layerReader.Close()
|
||||
|
||||
gzipReader, err := gzip.NewReader(layerReader)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create gzip reader: %w", err)
|
||||
}
|
||||
defer gzipReader.Close()
|
||||
|
||||
tarReader := tar.NewReader(gzipReader)
|
||||
for {
|
||||
header, err := tarReader.Next()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read chart archive: %w", err)
|
||||
}
|
||||
|
||||
if header.Typeflag != tar.TypeReg {
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.HasSuffix(header.Name, "values.schema.json") {
|
||||
data, err := io.ReadAll(tarReader)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read values.schema.json: %w", err)
|
||||
}
|
||||
if len(data) == 0 {
|
||||
return "", entity.ErrValuesSchemaNotFound
|
||||
}
|
||||
return string(data), nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", entity.ErrValuesSchemaNotFound
|
||||
}
|
||||
|
||||
// PullArtifact 下载 artifact 到本地
|
||||
func (c *OCIClient) PullArtifact(ctx context.Context, registry *entity.Registry, repository, reference, destPath string) error {
|
||||
reg, err := c.getRegistry(registry)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
repo, err := reg.Repository(ctx, repository)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get repository: %w", err)
|
||||
}
|
||||
|
||||
// 解析 reference
|
||||
desc, err := repo.Resolve(ctx, reference)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to resolve artifact: %w", err)
|
||||
}
|
||||
|
||||
// 获取 manifest 内容
|
||||
manifestReader, err := repo.Fetch(ctx, desc)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to fetch manifest: %w", err)
|
||||
}
|
||||
defer manifestReader.Close()
|
||||
|
||||
manifestBytes, err := io.ReadAll(manifestReader)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read manifest: %w", err)
|
||||
}
|
||||
|
||||
var manifest ocispec.Manifest
|
||||
if err := json.Unmarshal(manifestBytes, &manifest); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal manifest: %w", err)
|
||||
}
|
||||
|
||||
var chartLayer *ocispec.Descriptor
|
||||
for i := range manifest.Layers {
|
||||
layer := manifest.Layers[i]
|
||||
if strings.Contains(layer.MediaType, "cncf.helm.chart") ||
|
||||
strings.Contains(layer.MediaType, "helm.chart.content") {
|
||||
chartLayer = &layer
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if chartLayer == nil {
|
||||
return fmt.Errorf("helm chart layer not found in manifest")
|
||||
}
|
||||
|
||||
content, err := repo.Fetch(ctx, *chartLayer)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to fetch chart layer: %w", err)
|
||||
}
|
||||
defer content.Close()
|
||||
|
||||
// 确保目标目录存在
|
||||
if err := os.MkdirAll(filepath.Dir(destPath), 0755); err != nil {
|
||||
return fmt.Errorf("failed to create destination directory: %w", err)
|
||||
}
|
||||
|
||||
// 写入文件
|
||||
file, err := os.Create(destPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create file: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
if _, err := io.Copy(file, content); err != nil {
|
||||
return fmt.Errorf("failed to write artifact: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// PushArtifact 推送 artifact 到 Registry
|
||||
func (c *OCIClient) PushArtifact(ctx context.Context, registry *entity.Registry, repository, tag, sourcePath string) error {
|
||||
// 这是一个简化实现
|
||||
// 实际应该实现完整的 OCI artifact push 流程
|
||||
return fmt.Errorf("push artifact not fully implemented yet")
|
||||
}
|
||||
|
||||
// CheckHealth 检查 Registry 健康状态
|
||||
func (c *OCIClient) CheckHealth(ctx context.Context, registry *entity.Registry) error {
|
||||
reg, err := c.getRegistry(registry)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 尝试 ping registry
|
||||
err = reg.Ping(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("registry health check failed: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -0,0 +1,174 @@
|
||||
package mock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/ocdp/cluster-service/internal/domain/entity"
|
||||
"github.com/ocdp/cluster-service/internal/domain/repository"
|
||||
"github.com/ocdp/cluster-service/internal/pkg/crypto"
|
||||
)
|
||||
|
||||
// ClusterRepositoryMock 集群仓储 Mock 实现(内存存储,支持加密)
|
||||
type ClusterRepositoryMock struct {
|
||||
mu sync.RWMutex
|
||||
clusters map[string]*entity.Cluster // key: cluster ID
|
||||
encryptor crypto.Encryptor // 加密器
|
||||
}
|
||||
|
||||
// NewClusterRepositoryMock 创建 Mock 实现
|
||||
func NewClusterRepositoryMock(encryptor crypto.Encryptor) repository.ClusterRepository {
|
||||
return &ClusterRepositoryMock{
|
||||
clusters: make(map[string]*entity.Cluster),
|
||||
encryptor: encryptor,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *ClusterRepositoryMock) Create(ctx context.Context, cluster *entity.Cluster) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
// 检查名称是否已存在
|
||||
for _, c := range r.clusters {
|
||||
if c.Name == cluster.Name {
|
||||
return entity.ErrClusterExists
|
||||
}
|
||||
}
|
||||
|
||||
// Mock 模式:如果没有提供认证信息,自动填充默认的 Mock 证书
|
||||
if (cluster.CertData == "" || cluster.KeyData == "") && cluster.Token == "" {
|
||||
cluster.CAData = "LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0tCk1vY2sgQ0EgQ2VydGlmaWNhdGUKLS0tLS1FTkQgQ0VSVElGSUNBVEUtLS0tLQ=="
|
||||
cluster.CertData = "LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0tCk1vY2sgQ2xpZW50IENlcnRpZmljYXRlCi0tLS0tRU5EIENFUlRJRklDQVRFLS0tLS0="
|
||||
cluster.KeyData = "LS0tLS1CRUdJTiBSU0EgUFJJVkFURSBLRVktLS0tLQpNb2NrIFByaXZhdGUgS2V5Ci0tLS0tRU5EIFJTQSBQUklWQVRFIEtFWS0tLS0t"
|
||||
}
|
||||
|
||||
// 加密敏感数据后存储
|
||||
encryptedCluster := r.encryptCluster(cluster)
|
||||
r.clusters[cluster.ID] = encryptedCluster
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *ClusterRepositoryMock) GetByID(ctx context.Context, id string) (*entity.Cluster, error) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
cluster, exists := r.clusters[id]
|
||||
if !exists {
|
||||
return nil, entity.ErrClusterNotFound
|
||||
}
|
||||
|
||||
// 解密敏感数据后返回
|
||||
return r.decryptCluster(cluster), nil
|
||||
}
|
||||
|
||||
func (r *ClusterRepositoryMock) GetByName(ctx context.Context, name string) (*entity.Cluster, error) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
for _, cluster := range r.clusters {
|
||||
if cluster.Name == name {
|
||||
// 解密敏感数据后返回
|
||||
return r.decryptCluster(cluster), nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, entity.ErrClusterNotFound
|
||||
}
|
||||
|
||||
func (r *ClusterRepositoryMock) Update(ctx context.Context, cluster *entity.Cluster) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if _, exists := r.clusters[cluster.ID]; !exists {
|
||||
return entity.ErrClusterNotFound
|
||||
}
|
||||
|
||||
// 加密敏感数据后存储
|
||||
encryptedCluster := r.encryptCluster(cluster)
|
||||
r.clusters[cluster.ID] = encryptedCluster
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *ClusterRepositoryMock) Delete(ctx context.Context, id string) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if _, exists := r.clusters[id]; !exists {
|
||||
return entity.ErrClusterNotFound
|
||||
}
|
||||
|
||||
delete(r.clusters, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *ClusterRepositoryMock) List(ctx context.Context) ([]*entity.Cluster, error) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
clusters := make([]*entity.Cluster, 0, len(r.clusters))
|
||||
for _, cluster := range r.clusters {
|
||||
// 解密敏感数据后返回
|
||||
clusters = append(clusters, r.decryptCluster(cluster))
|
||||
}
|
||||
|
||||
return clusters, nil
|
||||
}
|
||||
|
||||
// encryptCluster 加密 Cluster 的敏感数据
|
||||
func (r *ClusterRepositoryMock) encryptCluster(cluster *entity.Cluster) *entity.Cluster {
|
||||
encrypted := *cluster // 复制
|
||||
|
||||
// 加密证书数据
|
||||
if cluster.CAData != "" && !crypto.IsEncrypted(cluster.CAData) {
|
||||
if encryptedData, err := r.encryptor.Encrypt(cluster.CAData); err == nil {
|
||||
encrypted.CAData = encryptedData
|
||||
}
|
||||
}
|
||||
if cluster.CertData != "" && !crypto.IsEncrypted(cluster.CertData) {
|
||||
if encryptedData, err := r.encryptor.Encrypt(cluster.CertData); err == nil {
|
||||
encrypted.CertData = encryptedData
|
||||
}
|
||||
}
|
||||
if cluster.KeyData != "" && !crypto.IsEncrypted(cluster.KeyData) {
|
||||
if encryptedData, err := r.encryptor.Encrypt(cluster.KeyData); err == nil {
|
||||
encrypted.KeyData = encryptedData
|
||||
}
|
||||
}
|
||||
if cluster.Token != "" && !crypto.IsEncrypted(cluster.Token) {
|
||||
if encryptedData, err := r.encryptor.Encrypt(cluster.Token); err == nil {
|
||||
encrypted.Token = encryptedData
|
||||
}
|
||||
}
|
||||
|
||||
return &encrypted
|
||||
}
|
||||
|
||||
// decryptCluster 解密 Cluster 的敏感数据
|
||||
func (r *ClusterRepositoryMock) decryptCluster(cluster *entity.Cluster) *entity.Cluster {
|
||||
decrypted := *cluster // 复制
|
||||
|
||||
// 解密证书数据
|
||||
if cluster.CAData != "" && crypto.IsEncrypted(cluster.CAData) {
|
||||
if decryptedData, err := r.encryptor.Decrypt(cluster.CAData); err == nil {
|
||||
decrypted.CAData = decryptedData
|
||||
}
|
||||
}
|
||||
if cluster.CertData != "" && crypto.IsEncrypted(cluster.CertData) {
|
||||
if decryptedData, err := r.encryptor.Decrypt(cluster.CertData); err == nil {
|
||||
decrypted.CertData = decryptedData
|
||||
}
|
||||
}
|
||||
if cluster.KeyData != "" && crypto.IsEncrypted(cluster.KeyData) {
|
||||
if decryptedData, err := r.encryptor.Decrypt(cluster.KeyData); err == nil {
|
||||
decrypted.KeyData = decryptedData
|
||||
}
|
||||
}
|
||||
if cluster.Token != "" && crypto.IsEncrypted(cluster.Token) {
|
||||
if decryptedData, err := r.encryptor.Decrypt(cluster.Token); err == nil {
|
||||
decrypted.Token = decryptedData
|
||||
}
|
||||
}
|
||||
|
||||
return &decrypted
|
||||
}
|
||||
|
||||
@ -0,0 +1,113 @@
|
||||
package mock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/ocdp/cluster-service/internal/domain/entity"
|
||||
"github.com/ocdp/cluster-service/internal/domain/repository"
|
||||
)
|
||||
|
||||
// InstanceRepositoryMock 实例仓储 Mock 实现(内存存储)
|
||||
type InstanceRepositoryMock struct {
|
||||
mu sync.RWMutex
|
||||
instances map[string]*entity.Instance // key: instance ID
|
||||
}
|
||||
|
||||
// NewInstanceRepositoryMock 创建 Mock 实现
|
||||
func NewInstanceRepositoryMock() repository.InstanceRepository {
|
||||
return &InstanceRepositoryMock{
|
||||
instances: make(map[string]*entity.Instance),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *InstanceRepositoryMock) Create(ctx context.Context, instance *entity.Instance) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
// 检查同一集群中名称是否已存在
|
||||
for _, inst := range r.instances {
|
||||
if inst.ClusterID == instance.ClusterID && inst.Name == instance.Name {
|
||||
return entity.ErrInstanceExists
|
||||
}
|
||||
}
|
||||
|
||||
r.instances[instance.ID] = instance
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *InstanceRepositoryMock) GetByID(ctx context.Context, id string) (*entity.Instance, error) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
instance, exists := r.instances[id]
|
||||
if !exists {
|
||||
return nil, entity.ErrInstanceNotFound
|
||||
}
|
||||
|
||||
return instance, nil
|
||||
}
|
||||
|
||||
func (r *InstanceRepositoryMock) GetByClusterAndName(ctx context.Context, clusterID, name string) (*entity.Instance, error) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
for _, instance := range r.instances {
|
||||
if instance.ClusterID == clusterID && instance.Name == name {
|
||||
return instance, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, entity.ErrInstanceNotFound
|
||||
}
|
||||
|
||||
func (r *InstanceRepositoryMock) Update(ctx context.Context, instance *entity.Instance) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if _, exists := r.instances[instance.ID]; !exists {
|
||||
return entity.ErrInstanceNotFound
|
||||
}
|
||||
|
||||
r.instances[instance.ID] = instance
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *InstanceRepositoryMock) Delete(ctx context.Context, id string) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if _, exists := r.instances[id]; !exists {
|
||||
return entity.ErrInstanceNotFound
|
||||
}
|
||||
|
||||
delete(r.instances, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *InstanceRepositoryMock) ListByCluster(ctx context.Context, clusterID string) ([]*entity.Instance, error) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
instances := make([]*entity.Instance, 0)
|
||||
for _, instance := range r.instances {
|
||||
if instance.ClusterID == clusterID {
|
||||
instances = append(instances, instance)
|
||||
}
|
||||
}
|
||||
|
||||
return instances, nil
|
||||
}
|
||||
|
||||
func (r *InstanceRepositoryMock) List(ctx context.Context) ([]*entity.Instance, error) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
instances := make([]*entity.Instance, 0, len(r.instances))
|
||||
for _, instance := range r.instances {
|
||||
instances = append(instances, instance)
|
||||
}
|
||||
|
||||
return instances, nil
|
||||
}
|
||||
|
||||
@ -0,0 +1,137 @@
|
||||
package mock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/ocdp/cluster-service/internal/domain/entity"
|
||||
"github.com/ocdp/cluster-service/internal/domain/repository"
|
||||
"github.com/ocdp/cluster-service/internal/pkg/crypto"
|
||||
)
|
||||
|
||||
// RegistryRepositoryMock Registry 仓储 Mock 实现(内存存储,支持加密)
|
||||
type RegistryRepositoryMock struct {
|
||||
mu sync.RWMutex
|
||||
registries map[string]*entity.Registry // key: registry ID
|
||||
encryptor crypto.Encryptor // 加密器
|
||||
}
|
||||
|
||||
// NewRegistryRepositoryMock 创建 Mock 实现
|
||||
func NewRegistryRepositoryMock(encryptor crypto.Encryptor) repository.RegistryRepository {
|
||||
return &RegistryRepositoryMock{
|
||||
registries: make(map[string]*entity.Registry),
|
||||
encryptor: encryptor,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *RegistryRepositoryMock) Create(ctx context.Context, registry *entity.Registry) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
// 检查名称是否已存在
|
||||
for _, reg := range r.registries {
|
||||
if reg.Name == registry.Name {
|
||||
return entity.ErrRegistryExists
|
||||
}
|
||||
}
|
||||
|
||||
// 加密敏感数据后存储
|
||||
encryptedRegistry := r.encryptRegistry(registry)
|
||||
r.registries[registry.ID] = encryptedRegistry
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RegistryRepositoryMock) GetByID(ctx context.Context, id string) (*entity.Registry, error) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
registry, exists := r.registries[id]
|
||||
if !exists {
|
||||
return nil, entity.ErrRegistryNotFound
|
||||
}
|
||||
|
||||
// 解密敏感数据后返回
|
||||
return r.decryptRegistry(registry), nil
|
||||
}
|
||||
|
||||
func (r *RegistryRepositoryMock) GetByName(ctx context.Context, name string) (*entity.Registry, error) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
for _, registry := range r.registries {
|
||||
if registry.Name == name {
|
||||
// 解密敏感数据后返回
|
||||
return r.decryptRegistry(registry), nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, entity.ErrRegistryNotFound
|
||||
}
|
||||
|
||||
func (r *RegistryRepositoryMock) Update(ctx context.Context, registry *entity.Registry) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if _, exists := r.registries[registry.ID]; !exists {
|
||||
return entity.ErrRegistryNotFound
|
||||
}
|
||||
|
||||
// 加密敏感数据后存储
|
||||
encryptedRegistry := r.encryptRegistry(registry)
|
||||
r.registries[registry.ID] = encryptedRegistry
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RegistryRepositoryMock) Delete(ctx context.Context, id string) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if _, exists := r.registries[id]; !exists {
|
||||
return entity.ErrRegistryNotFound
|
||||
}
|
||||
|
||||
delete(r.registries, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RegistryRepositoryMock) List(ctx context.Context) ([]*entity.Registry, error) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
registries := make([]*entity.Registry, 0, len(r.registries))
|
||||
for _, registry := range r.registries {
|
||||
// 解密敏感数据后返回
|
||||
registries = append(registries, r.decryptRegistry(registry))
|
||||
}
|
||||
|
||||
return registries, nil
|
||||
}
|
||||
|
||||
// encryptRegistry 加密 Registry 的敏感数据
|
||||
func (r *RegistryRepositoryMock) encryptRegistry(registry *entity.Registry) *entity.Registry {
|
||||
encrypted := *registry // 复制
|
||||
|
||||
// 加密密码
|
||||
if registry.Password != "" && !crypto.IsEncrypted(registry.Password) {
|
||||
if encryptedPassword, err := r.encryptor.Encrypt(registry.Password); err == nil {
|
||||
encrypted.Password = encryptedPassword
|
||||
}
|
||||
}
|
||||
|
||||
return &encrypted
|
||||
}
|
||||
|
||||
// decryptRegistry 解密 Registry 的敏感数据
|
||||
func (r *RegistryRepositoryMock) decryptRegistry(registry *entity.Registry) *entity.Registry {
|
||||
decrypted := *registry // 复制
|
||||
|
||||
// 解密密码
|
||||
if registry.Password != "" && crypto.IsEncrypted(registry.Password) {
|
||||
if decryptedPassword, err := r.encryptor.Decrypt(registry.Password); err == nil {
|
||||
decrypted.Password = decryptedPassword
|
||||
}
|
||||
}
|
||||
|
||||
return &decrypted
|
||||
}
|
||||
|
||||
@ -0,0 +1,99 @@
|
||||
package mock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/ocdp/cluster-service/internal/domain/entity"
|
||||
"github.com/ocdp/cluster-service/internal/domain/repository"
|
||||
)
|
||||
|
||||
// UserRepositoryMock 用户仓储 Mock 实现(内存存储)
|
||||
type UserRepositoryMock struct {
|
||||
mu sync.RWMutex
|
||||
users map[string]*entity.User // key: user ID
|
||||
}
|
||||
|
||||
// NewUserRepositoryMock 创建 Mock 实现
|
||||
func NewUserRepositoryMock() repository.UserRepository {
|
||||
return &UserRepositoryMock{
|
||||
users: make(map[string]*entity.User),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *UserRepositoryMock) Create(ctx context.Context, user *entity.User) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
// 检查是否已存在
|
||||
for _, u := range r.users {
|
||||
if u.Username == user.Username {
|
||||
return entity.ErrUserExists
|
||||
}
|
||||
}
|
||||
|
||||
r.users[user.ID] = user
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *UserRepositoryMock) GetByID(ctx context.Context, id string) (*entity.User, error) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
user, exists := r.users[id]
|
||||
if !exists {
|
||||
return nil, entity.ErrUserNotFound
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (r *UserRepositoryMock) GetByUsername(ctx context.Context, username string) (*entity.User, error) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
for _, user := range r.users {
|
||||
if user.Username == username {
|
||||
return user, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, entity.ErrUserNotFound
|
||||
}
|
||||
|
||||
func (r *UserRepositoryMock) Update(ctx context.Context, user *entity.User) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if _, exists := r.users[user.ID]; !exists {
|
||||
return entity.ErrUserNotFound
|
||||
}
|
||||
|
||||
r.users[user.ID] = user
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *UserRepositoryMock) Delete(ctx context.Context, id string) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if _, exists := r.users[id]; !exists {
|
||||
return entity.ErrUserNotFound
|
||||
}
|
||||
|
||||
delete(r.users, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *UserRepositoryMock) List(ctx context.Context) ([]*entity.User, error) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
users := make([]*entity.User, 0, len(r.users))
|
||||
for _, user := range r.users {
|
||||
users = append(users, user)
|
||||
}
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
@ -0,0 +1,337 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/ocdp/cluster-service/internal/domain/entity"
|
||||
"github.com/ocdp/cluster-service/internal/domain/repository"
|
||||
"github.com/ocdp/cluster-service/internal/pkg/crypto"
|
||||
)
|
||||
|
||||
// ClusterRepository PostgreSQL 集群仓储实现
|
||||
type ClusterRepository struct {
|
||||
db *DB
|
||||
encryptor crypto.Encryptor
|
||||
}
|
||||
|
||||
// NewClusterRepository 创建 PostgreSQL 集群仓储
|
||||
func NewClusterRepository(db *DB, encryptor crypto.Encryptor) repository.ClusterRepository {
|
||||
return &ClusterRepository{
|
||||
db: db,
|
||||
encryptor: encryptor,
|
||||
}
|
||||
}
|
||||
|
||||
// Create 创建集群
|
||||
func (r *ClusterRepository) Create(ctx context.Context, cluster *entity.Cluster) error {
|
||||
if cluster.ID == "" {
|
||||
cluster.ID = uuid.New().String()
|
||||
}
|
||||
|
||||
// 加密敏感数据
|
||||
encryptedCAData, err := r.encryptor.Encrypt(cluster.CAData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt CA data: %w", err)
|
||||
}
|
||||
|
||||
encryptedCertData, err := r.encryptor.Encrypt(cluster.CertData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt cert data: %w", err)
|
||||
}
|
||||
|
||||
encryptedKeyData, err := r.encryptor.Encrypt(cluster.KeyData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt key data: %w", err)
|
||||
}
|
||||
|
||||
encryptedToken, err := r.encryptor.Encrypt(cluster.Token)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt token: %w", err)
|
||||
}
|
||||
|
||||
query := `
|
||||
INSERT INTO clusters (id, name, host, ca_data, cert_data, key_data, token, description, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
|
||||
`
|
||||
|
||||
_, err = r.db.conn.ExecContext(ctx, query,
|
||||
cluster.ID,
|
||||
cluster.Name,
|
||||
cluster.Host,
|
||||
encryptedCAData,
|
||||
encryptedCertData,
|
||||
encryptedKeyData,
|
||||
encryptedToken,
|
||||
cluster.Description,
|
||||
cluster.CreatedAt,
|
||||
cluster.UpdatedAt,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create cluster: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetByID 根据 ID 获取集群
|
||||
func (r *ClusterRepository) GetByID(ctx context.Context, id string) (*entity.Cluster, error) {
|
||||
query := `
|
||||
SELECT id, name, host, ca_data, cert_data, key_data, token, description, created_at, updated_at
|
||||
FROM clusters
|
||||
WHERE id = $1
|
||||
`
|
||||
|
||||
cluster := &entity.Cluster{}
|
||||
var encryptedCAData, encryptedCertData, encryptedKeyData, encryptedToken string
|
||||
|
||||
err := r.db.conn.QueryRowContext(ctx, query, id).Scan(
|
||||
&cluster.ID,
|
||||
&cluster.Name,
|
||||
&cluster.Host,
|
||||
&encryptedCAData,
|
||||
&encryptedCertData,
|
||||
&encryptedKeyData,
|
||||
&encryptedToken,
|
||||
&cluster.Description,
|
||||
&cluster.CreatedAt,
|
||||
&cluster.UpdatedAt,
|
||||
)
|
||||
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, entity.ErrClusterNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get cluster: %w", err)
|
||||
}
|
||||
|
||||
// 解密敏感数据
|
||||
cluster.CAData, err = r.encryptor.Decrypt(encryptedCAData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt CA data: %w", err)
|
||||
}
|
||||
|
||||
cluster.CertData, err = r.encryptor.Decrypt(encryptedCertData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt cert data: %w", err)
|
||||
}
|
||||
|
||||
cluster.KeyData, err = r.encryptor.Decrypt(encryptedKeyData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt key data: %w", err)
|
||||
}
|
||||
|
||||
cluster.Token, err = r.encryptor.Decrypt(encryptedToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt token: %w", err)
|
||||
}
|
||||
|
||||
return cluster, nil
|
||||
}
|
||||
|
||||
// GetByName 根据名称获取集群
|
||||
func (r *ClusterRepository) GetByName(ctx context.Context, name string) (*entity.Cluster, error) {
|
||||
query := `
|
||||
SELECT id, name, host, ca_data, cert_data, key_data, token, description, created_at, updated_at
|
||||
FROM clusters
|
||||
WHERE name = $1
|
||||
`
|
||||
|
||||
cluster := &entity.Cluster{}
|
||||
var encryptedCAData, encryptedCertData, encryptedKeyData, encryptedToken string
|
||||
|
||||
err := r.db.conn.QueryRowContext(ctx, query, name).Scan(
|
||||
&cluster.ID,
|
||||
&cluster.Name,
|
||||
&cluster.Host,
|
||||
&encryptedCAData,
|
||||
&encryptedCertData,
|
||||
&encryptedKeyData,
|
||||
&encryptedToken,
|
||||
&cluster.Description,
|
||||
&cluster.CreatedAt,
|
||||
&cluster.UpdatedAt,
|
||||
)
|
||||
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, entity.ErrClusterNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get cluster: %w", err)
|
||||
}
|
||||
|
||||
// 解密敏感数据
|
||||
cluster.CAData, err = r.encryptor.Decrypt(encryptedCAData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt CA data: %w", err)
|
||||
}
|
||||
|
||||
cluster.CertData, err = r.encryptor.Decrypt(encryptedCertData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt cert data: %w", err)
|
||||
}
|
||||
|
||||
cluster.KeyData, err = r.encryptor.Decrypt(encryptedKeyData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt key data: %w", err)
|
||||
}
|
||||
|
||||
cluster.Token, err = r.encryptor.Decrypt(encryptedToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt token: %w", err)
|
||||
}
|
||||
|
||||
return cluster, nil
|
||||
}
|
||||
|
||||
// Update 更新集群
|
||||
func (r *ClusterRepository) Update(ctx context.Context, cluster *entity.Cluster) error {
|
||||
cluster.UpdatedAt = time.Now()
|
||||
|
||||
// 加密敏感数据
|
||||
encryptedCAData, err := r.encryptor.Encrypt(cluster.CAData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt CA data: %w", err)
|
||||
}
|
||||
|
||||
encryptedCertData, err := r.encryptor.Encrypt(cluster.CertData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt cert data: %w", err)
|
||||
}
|
||||
|
||||
encryptedKeyData, err := r.encryptor.Encrypt(cluster.KeyData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt key data: %w", err)
|
||||
}
|
||||
|
||||
encryptedToken, err := r.encryptor.Encrypt(cluster.Token)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt token: %w", err)
|
||||
}
|
||||
|
||||
query := `
|
||||
UPDATE clusters
|
||||
SET name = $1, host = $2, ca_data = $3, cert_data = $4, key_data = $5,
|
||||
token = $6, description = $7, updated_at = $8
|
||||
WHERE id = $9
|
||||
`
|
||||
|
||||
result, err := r.db.conn.ExecContext(ctx, query,
|
||||
cluster.Name,
|
||||
cluster.Host,
|
||||
encryptedCAData,
|
||||
encryptedCertData,
|
||||
encryptedKeyData,
|
||||
encryptedToken,
|
||||
cluster.Description,
|
||||
cluster.UpdatedAt,
|
||||
cluster.ID,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update cluster: %w", err)
|
||||
}
|
||||
|
||||
rows, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get affected rows: %w", err)
|
||||
}
|
||||
|
||||
if rows == 0 {
|
||||
return entity.ErrClusterNotFound
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete 删除集群
|
||||
func (r *ClusterRepository) Delete(ctx context.Context, id string) error {
|
||||
query := `DELETE FROM clusters WHERE id = $1`
|
||||
|
||||
result, err := r.db.conn.ExecContext(ctx, query, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete cluster: %w", err)
|
||||
}
|
||||
|
||||
rows, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get affected rows: %w", err)
|
||||
}
|
||||
|
||||
if rows == 0 {
|
||||
return entity.ErrClusterNotFound
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// List 列出所有集群
|
||||
func (r *ClusterRepository) List(ctx context.Context) ([]*entity.Cluster, error) {
|
||||
query := `
|
||||
SELECT id, name, host, ca_data, cert_data, key_data, token, description, created_at, updated_at
|
||||
FROM clusters
|
||||
ORDER BY created_at DESC
|
||||
`
|
||||
|
||||
rows, err := r.db.conn.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list clusters: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
clusters := make([]*entity.Cluster, 0)
|
||||
for rows.Next() {
|
||||
cluster := &entity.Cluster{}
|
||||
var encryptedCAData, encryptedCertData, encryptedKeyData, encryptedToken string
|
||||
|
||||
err := rows.Scan(
|
||||
&cluster.ID,
|
||||
&cluster.Name,
|
||||
&cluster.Host,
|
||||
&encryptedCAData,
|
||||
&encryptedCertData,
|
||||
&encryptedKeyData,
|
||||
&encryptedToken,
|
||||
&cluster.Description,
|
||||
&cluster.CreatedAt,
|
||||
&cluster.UpdatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan cluster: %w", err)
|
||||
}
|
||||
|
||||
// 解密敏感数据
|
||||
cluster.CAData, err = r.encryptor.Decrypt(encryptedCAData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt CA data: %w", err)
|
||||
}
|
||||
|
||||
cluster.CertData, err = r.encryptor.Decrypt(encryptedCertData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt cert data: %w", err)
|
||||
}
|
||||
|
||||
cluster.KeyData, err = r.encryptor.Decrypt(encryptedKeyData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt key data: %w", err)
|
||||
}
|
||||
|
||||
cluster.Token, err = r.encryptor.Decrypt(encryptedToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt token: %w", err)
|
||||
}
|
||||
|
||||
clusters = append(clusters, cluster)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("rows iteration error: %w", err)
|
||||
}
|
||||
|
||||
return clusters, nil
|
||||
}
|
||||
|
||||
135
backend/internal/adapter/output/persistence/postgres/db.go
Normal file
135
backend/internal/adapter/output/persistence/postgres/db.go
Normal file
@ -0,0 +1,135 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
_ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
// DB 数据库连接包装器
|
||||
type DB struct {
|
||||
conn *sql.DB
|
||||
}
|
||||
|
||||
// NewDB 创建新的数据库连接
|
||||
func NewDB(connString string) (*DB, error) {
|
||||
if connString == "" {
|
||||
return nil, fmt.Errorf("database connection string cannot be empty")
|
||||
}
|
||||
|
||||
conn, err := sql.Open("postgres", connString)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open database: %w", err)
|
||||
}
|
||||
|
||||
// 配置连接池
|
||||
conn.SetMaxOpenConns(25)
|
||||
conn.SetMaxIdleConns(5)
|
||||
conn.SetConnMaxLifetime(5 * time.Minute)
|
||||
|
||||
// 测试连接
|
||||
if err := conn.Ping(); err != nil {
|
||||
return nil, fmt.Errorf("failed to ping database: %w", err)
|
||||
}
|
||||
|
||||
return &DB{conn: conn}, nil
|
||||
}
|
||||
|
||||
// Close 关闭数据库连接
|
||||
func (db *DB) Close() error {
|
||||
if db.conn != nil {
|
||||
return db.conn.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetConn 获取底层连接(用于事务等高级操作)
|
||||
func (db *DB) GetConn() *sql.DB {
|
||||
return db.conn
|
||||
}
|
||||
|
||||
// InitSchema 初始化数据库 schema
|
||||
func (db *DB) InitSchema() error {
|
||||
schema := `
|
||||
-- Users 表
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
id VARCHAR(36) PRIMARY KEY,
|
||||
username VARCHAR(255) NOT NULL UNIQUE,
|
||||
password_hash TEXT NOT NULL,
|
||||
email VARCHAR(255) NOT NULL,
|
||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_users_username ON users(username);
|
||||
|
||||
-- Clusters 表
|
||||
CREATE TABLE IF NOT EXISTS clusters (
|
||||
id VARCHAR(36) PRIMARY KEY,
|
||||
name VARCHAR(255) NOT NULL UNIQUE,
|
||||
host TEXT NOT NULL,
|
||||
ca_data TEXT,
|
||||
cert_data TEXT,
|
||||
key_data TEXT,
|
||||
token TEXT,
|
||||
description TEXT,
|
||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_clusters_name ON clusters(name);
|
||||
|
||||
-- Registries 表
|
||||
CREATE TABLE IF NOT EXISTS registries (
|
||||
id VARCHAR(36) PRIMARY KEY,
|
||||
name VARCHAR(255) NOT NULL UNIQUE,
|
||||
url TEXT NOT NULL,
|
||||
description TEXT,
|
||||
username VARCHAR(255),
|
||||
password TEXT,
|
||||
insecure BOOLEAN DEFAULT FALSE,
|
||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_registries_name ON registries(name);
|
||||
|
||||
-- Instances 表
|
||||
CREATE TABLE IF NOT EXISTS instances (
|
||||
id VARCHAR(36) PRIMARY KEY,
|
||||
cluster_id VARCHAR(36) NOT NULL,
|
||||
name VARCHAR(255) NOT NULL,
|
||||
namespace VARCHAR(255) NOT NULL,
|
||||
registry_id VARCHAR(36) NOT NULL,
|
||||
repository TEXT NOT NULL,
|
||||
chart VARCHAR(255) NOT NULL,
|
||||
version VARCHAR(255) NOT NULL,
|
||||
description TEXT,
|
||||
values JSONB,
|
||||
values_yaml TEXT,
|
||||
status VARCHAR(50) NOT NULL,
|
||||
status_reason TEXT,
|
||||
last_operation VARCHAR(50),
|
||||
last_error TEXT,
|
||||
revision INTEGER NOT NULL DEFAULT 1,
|
||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
CONSTRAINT fk_cluster FOREIGN KEY (cluster_id) REFERENCES clusters(id) ON DELETE CASCADE,
|
||||
CONSTRAINT fk_registry FOREIGN KEY (registry_id) REFERENCES registries(id) ON DELETE CASCADE,
|
||||
CONSTRAINT unique_cluster_name UNIQUE (cluster_id, name, namespace)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_instances_cluster ON instances(cluster_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_instances_registry ON instances(registry_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_instances_name ON instances(name);
|
||||
`
|
||||
|
||||
_, err := db.conn.Exec(schema)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize schema: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -0,0 +1,433 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/ocdp/cluster-service/internal/domain/entity"
|
||||
"github.com/ocdp/cluster-service/internal/domain/repository"
|
||||
)
|
||||
|
||||
// InstanceRepository PostgreSQL 实例仓储实现
|
||||
type InstanceRepository struct {
|
||||
db *DB
|
||||
}
|
||||
|
||||
// NewInstanceRepository 创建 PostgreSQL 实例仓储
|
||||
func NewInstanceRepository(db *DB) repository.InstanceRepository {
|
||||
return &InstanceRepository{db: db}
|
||||
}
|
||||
|
||||
// Create 创建实例
|
||||
func (r *InstanceRepository) Create(ctx context.Context, instance *entity.Instance) error {
|
||||
if instance.ID == "" {
|
||||
instance.ID = uuid.New().String()
|
||||
}
|
||||
|
||||
// 将 Values 转换为 JSON
|
||||
valuesJSON, err := json.Marshal(instance.Values)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal values: %w", err)
|
||||
}
|
||||
|
||||
query := `
|
||||
INSERT INTO instances (id, cluster_id, name, namespace, registry_id, repository, chart, version,
|
||||
description, values, values_yaml, status, status_reason, last_operation, last_error,
|
||||
revision, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18)
|
||||
`
|
||||
|
||||
_, err = r.db.conn.ExecContext(ctx, query,
|
||||
instance.ID,
|
||||
instance.ClusterID,
|
||||
instance.Name,
|
||||
instance.Namespace,
|
||||
instance.RegistryID,
|
||||
instance.Repository,
|
||||
instance.Chart,
|
||||
instance.Version,
|
||||
instance.Description,
|
||||
valuesJSON,
|
||||
instance.ValuesYAML,
|
||||
instance.Status,
|
||||
instance.StatusReason,
|
||||
instance.LastOperation,
|
||||
instance.LastError,
|
||||
instance.Revision,
|
||||
instance.CreatedAt,
|
||||
instance.UpdatedAt,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create instance: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetByID 根据 ID 获取实例
|
||||
func (r *InstanceRepository) GetByID(ctx context.Context, id string) (*entity.Instance, error) {
|
||||
query := `
|
||||
SELECT id, cluster_id, name, namespace, registry_id, repository, chart, version,
|
||||
description, values, values_yaml, status, status_reason, last_operation, last_error,
|
||||
revision, created_at, updated_at
|
||||
FROM instances
|
||||
WHERE id = $1
|
||||
`
|
||||
|
||||
instance := &entity.Instance{}
|
||||
var (
|
||||
valuesJSON []byte
|
||||
statusReason sql.NullString
|
||||
lastOperation sql.NullString
|
||||
lastError sql.NullString
|
||||
)
|
||||
|
||||
err := r.db.conn.QueryRowContext(ctx, query, id).Scan(
|
||||
&instance.ID,
|
||||
&instance.ClusterID,
|
||||
&instance.Name,
|
||||
&instance.Namespace,
|
||||
&instance.RegistryID,
|
||||
&instance.Repository,
|
||||
&instance.Chart,
|
||||
&instance.Version,
|
||||
&instance.Description,
|
||||
&valuesJSON,
|
||||
&instance.ValuesYAML,
|
||||
&instance.Status,
|
||||
&statusReason,
|
||||
&lastOperation,
|
||||
&lastError,
|
||||
&instance.Revision,
|
||||
&instance.CreatedAt,
|
||||
&instance.UpdatedAt,
|
||||
)
|
||||
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, entity.ErrInstanceNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get instance: %w", err)
|
||||
}
|
||||
|
||||
// 解析 JSON Values
|
||||
if len(valuesJSON) > 0 {
|
||||
if err := json.Unmarshal(valuesJSON, &instance.Values); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal values: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if statusReason.Valid {
|
||||
instance.StatusReason = statusReason.String
|
||||
}
|
||||
if lastOperation.Valid {
|
||||
instance.LastOperation = entity.InstanceOperation(lastOperation.String)
|
||||
}
|
||||
if lastError.Valid {
|
||||
instance.LastError = lastError.String
|
||||
}
|
||||
|
||||
return instance, nil
|
||||
}
|
||||
|
||||
// GetByClusterAndName 根据集群 ID 和名称获取实例
|
||||
func (r *InstanceRepository) GetByClusterAndName(ctx context.Context, clusterID, name string) (*entity.Instance, error) {
|
||||
query := `
|
||||
SELECT id, cluster_id, name, namespace, registry_id, repository, chart, version,
|
||||
description, values, values_yaml, status, status_reason, last_operation, last_error,
|
||||
revision, created_at, updated_at
|
||||
FROM instances
|
||||
WHERE cluster_id = $1 AND name = $2
|
||||
`
|
||||
|
||||
instance := &entity.Instance{}
|
||||
var (
|
||||
valuesJSON []byte
|
||||
statusReason sql.NullString
|
||||
lastOperation sql.NullString
|
||||
lastError sql.NullString
|
||||
)
|
||||
|
||||
err := r.db.conn.QueryRowContext(ctx, query, clusterID, name).Scan(
|
||||
&instance.ID,
|
||||
&instance.ClusterID,
|
||||
&instance.Name,
|
||||
&instance.Namespace,
|
||||
&instance.RegistryID,
|
||||
&instance.Repository,
|
||||
&instance.Chart,
|
||||
&instance.Version,
|
||||
&instance.Description,
|
||||
&valuesJSON,
|
||||
&instance.ValuesYAML,
|
||||
&instance.Status,
|
||||
&statusReason,
|
||||
&lastOperation,
|
||||
&lastError,
|
||||
&instance.Revision,
|
||||
&instance.CreatedAt,
|
||||
&instance.UpdatedAt,
|
||||
)
|
||||
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, entity.ErrInstanceNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get instance: %w", err)
|
||||
}
|
||||
|
||||
// 解析 JSON Values
|
||||
if len(valuesJSON) > 0 {
|
||||
if err := json.Unmarshal(valuesJSON, &instance.Values); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal values: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if statusReason.Valid {
|
||||
instance.StatusReason = statusReason.String
|
||||
}
|
||||
if lastOperation.Valid {
|
||||
instance.LastOperation = entity.InstanceOperation(lastOperation.String)
|
||||
}
|
||||
if lastError.Valid {
|
||||
instance.LastError = lastError.String
|
||||
}
|
||||
|
||||
return instance, nil
|
||||
}
|
||||
|
||||
// Update 更新实例
|
||||
func (r *InstanceRepository) Update(ctx context.Context, instance *entity.Instance) error {
|
||||
instance.UpdatedAt = time.Now()
|
||||
|
||||
// 将 Values 转换为 JSON
|
||||
valuesJSON, err := json.Marshal(instance.Values)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal values: %w", err)
|
||||
}
|
||||
|
||||
query := `
|
||||
UPDATE instances
|
||||
SET cluster_id = $1, name = $2, namespace = $3, registry_id = $4, repository = $5,
|
||||
chart = $6, version = $7, description = $8, values = $9, values_yaml = $10,
|
||||
status = $11, status_reason = $12, last_operation = $13, last_error = $14,
|
||||
revision = $15, updated_at = $16
|
||||
WHERE id = $17
|
||||
`
|
||||
|
||||
result, err := r.db.conn.ExecContext(ctx, query,
|
||||
instance.ClusterID,
|
||||
instance.Name,
|
||||
instance.Namespace,
|
||||
instance.RegistryID,
|
||||
instance.Repository,
|
||||
instance.Chart,
|
||||
instance.Version,
|
||||
instance.Description,
|
||||
valuesJSON,
|
||||
instance.ValuesYAML,
|
||||
instance.Status,
|
||||
instance.StatusReason,
|
||||
instance.LastOperation,
|
||||
instance.LastError,
|
||||
instance.Revision,
|
||||
instance.UpdatedAt,
|
||||
instance.ID,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update instance: %w", err)
|
||||
}
|
||||
|
||||
rows, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get affected rows: %w", err)
|
||||
}
|
||||
|
||||
if rows == 0 {
|
||||
return entity.ErrInstanceNotFound
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete 删除实例
|
||||
func (r *InstanceRepository) Delete(ctx context.Context, id string) error {
|
||||
query := `DELETE FROM instances WHERE id = $1`
|
||||
|
||||
result, err := r.db.conn.ExecContext(ctx, query, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete instance: %w", err)
|
||||
}
|
||||
|
||||
rows, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get affected rows: %w", err)
|
||||
}
|
||||
|
||||
if rows == 0 {
|
||||
return entity.ErrInstanceNotFound
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListByCluster 列出指定集群的所有实例
|
||||
func (r *InstanceRepository) ListByCluster(ctx context.Context, clusterID string) ([]*entity.Instance, error) {
|
||||
query := `
|
||||
SELECT id, cluster_id, name, namespace, registry_id, repository, chart, version,
|
||||
description, values, values_yaml, status, status_reason, last_operation, last_error,
|
||||
revision, created_at, updated_at
|
||||
FROM instances
|
||||
WHERE cluster_id = $1
|
||||
ORDER BY created_at DESC
|
||||
`
|
||||
|
||||
rows, err := r.db.conn.QueryContext(ctx, query, clusterID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list instances: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
instances := make([]*entity.Instance, 0)
|
||||
for rows.Next() {
|
||||
instance := &entity.Instance{}
|
||||
var (
|
||||
valuesJSON []byte
|
||||
statusReason sql.NullString
|
||||
lastOperation sql.NullString
|
||||
lastError sql.NullString
|
||||
)
|
||||
|
||||
err := rows.Scan(
|
||||
&instance.ID,
|
||||
&instance.ClusterID,
|
||||
&instance.Name,
|
||||
&instance.Namespace,
|
||||
&instance.RegistryID,
|
||||
&instance.Repository,
|
||||
&instance.Chart,
|
||||
&instance.Version,
|
||||
&instance.Description,
|
||||
&valuesJSON,
|
||||
&instance.ValuesYAML,
|
||||
&instance.Status,
|
||||
&statusReason,
|
||||
&lastOperation,
|
||||
&lastError,
|
||||
&instance.Revision,
|
||||
&instance.CreatedAt,
|
||||
&instance.UpdatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan instance: %w", err)
|
||||
}
|
||||
|
||||
// 解析 JSON Values
|
||||
if len(valuesJSON) > 0 {
|
||||
if err := json.Unmarshal(valuesJSON, &instance.Values); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal values: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if statusReason.Valid {
|
||||
instance.StatusReason = statusReason.String
|
||||
}
|
||||
if lastOperation.Valid {
|
||||
instance.LastOperation = entity.InstanceOperation(lastOperation.String)
|
||||
}
|
||||
if lastError.Valid {
|
||||
instance.LastError = lastError.String
|
||||
}
|
||||
|
||||
instances = append(instances, instance)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("rows iteration error: %w", err)
|
||||
}
|
||||
|
||||
return instances, nil
|
||||
}
|
||||
|
||||
// List 列出所有实例
|
||||
func (r *InstanceRepository) List(ctx context.Context) ([]*entity.Instance, error) {
|
||||
query := `
|
||||
SELECT id, cluster_id, name, namespace, registry_id, repository, chart, version,
|
||||
description, values, values_yaml, status, status_reason, last_operation, last_error,
|
||||
revision, created_at, updated_at
|
||||
FROM instances
|
||||
ORDER BY created_at DESC
|
||||
`
|
||||
|
||||
rows, err := r.db.conn.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list instances: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
instances := make([]*entity.Instance, 0)
|
||||
for rows.Next() {
|
||||
instance := &entity.Instance{}
|
||||
var (
|
||||
valuesJSON []byte
|
||||
statusReason sql.NullString
|
||||
lastOperation sql.NullString
|
||||
lastError sql.NullString
|
||||
)
|
||||
|
||||
err := rows.Scan(
|
||||
&instance.ID,
|
||||
&instance.ClusterID,
|
||||
&instance.Name,
|
||||
&instance.Namespace,
|
||||
&instance.RegistryID,
|
||||
&instance.Repository,
|
||||
&instance.Chart,
|
||||
&instance.Version,
|
||||
&instance.Description,
|
||||
&valuesJSON,
|
||||
&instance.ValuesYAML,
|
||||
&instance.Status,
|
||||
&statusReason,
|
||||
&lastOperation,
|
||||
&lastError,
|
||||
&instance.Revision,
|
||||
&instance.CreatedAt,
|
||||
&instance.UpdatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan instance: %w", err)
|
||||
}
|
||||
|
||||
// 解析 JSON Values
|
||||
if len(valuesJSON) > 0 {
|
||||
if err := json.Unmarshal(valuesJSON, &instance.Values); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal values: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if statusReason.Valid {
|
||||
instance.StatusReason = statusReason.String
|
||||
}
|
||||
if lastOperation.Valid {
|
||||
instance.LastOperation = entity.InstanceOperation(lastOperation.String)
|
||||
}
|
||||
if lastError.Valid {
|
||||
instance.LastError = lastError.String
|
||||
}
|
||||
|
||||
instances = append(instances, instance)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("rows iteration error: %w", err)
|
||||
}
|
||||
|
||||
return instances, nil
|
||||
}
|
||||
@ -0,0 +1,257 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/ocdp/cluster-service/internal/domain/entity"
|
||||
"github.com/ocdp/cluster-service/internal/domain/repository"
|
||||
"github.com/ocdp/cluster-service/internal/pkg/crypto"
|
||||
)
|
||||
|
||||
// RegistryRepository PostgreSQL Registry 仓储实现
|
||||
type RegistryRepository struct {
|
||||
db *DB
|
||||
encryptor crypto.Encryptor
|
||||
}
|
||||
|
||||
// NewRegistryRepository 创建 PostgreSQL Registry 仓储
|
||||
func NewRegistryRepository(db *DB, encryptor crypto.Encryptor) repository.RegistryRepository {
|
||||
return &RegistryRepository{
|
||||
db: db,
|
||||
encryptor: encryptor,
|
||||
}
|
||||
}
|
||||
|
||||
// Create 创建 Registry
|
||||
func (r *RegistryRepository) Create(ctx context.Context, registry *entity.Registry) error {
|
||||
if registry.ID == "" {
|
||||
registry.ID = uuid.New().String()
|
||||
}
|
||||
|
||||
// 加密密码
|
||||
encryptedPassword, err := r.encryptor.Encrypt(registry.Password)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt password: %w", err)
|
||||
}
|
||||
|
||||
query := `
|
||||
INSERT INTO registries (id, name, url, description, username, password, insecure, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
||||
`
|
||||
|
||||
_, err = r.db.conn.ExecContext(ctx, query,
|
||||
registry.ID,
|
||||
registry.Name,
|
||||
registry.URL,
|
||||
registry.Description,
|
||||
registry.Username,
|
||||
encryptedPassword,
|
||||
registry.Insecure,
|
||||
registry.CreatedAt,
|
||||
registry.UpdatedAt,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create registry: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetByID 根据 ID 获取 Registry
|
||||
func (r *RegistryRepository) GetByID(ctx context.Context, id string) (*entity.Registry, error) {
|
||||
query := `
|
||||
SELECT id, name, url, description, username, password, insecure, created_at, updated_at
|
||||
FROM registries
|
||||
WHERE id = $1
|
||||
`
|
||||
|
||||
registry := &entity.Registry{}
|
||||
var encryptedPassword string
|
||||
|
||||
err := r.db.conn.QueryRowContext(ctx, query, id).Scan(
|
||||
®istry.ID,
|
||||
®istry.Name,
|
||||
®istry.URL,
|
||||
®istry.Description,
|
||||
®istry.Username,
|
||||
&encryptedPassword,
|
||||
®istry.Insecure,
|
||||
®istry.CreatedAt,
|
||||
®istry.UpdatedAt,
|
||||
)
|
||||
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, entity.ErrRegistryNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get registry: %w", err)
|
||||
}
|
||||
|
||||
// 解密密码
|
||||
registry.Password, err = r.encryptor.Decrypt(encryptedPassword)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt password: %w", err)
|
||||
}
|
||||
|
||||
return registry, nil
|
||||
}
|
||||
|
||||
// GetByName 根据名称获取 Registry
|
||||
func (r *RegistryRepository) GetByName(ctx context.Context, name string) (*entity.Registry, error) {
|
||||
query := `
|
||||
SELECT id, name, url, description, username, password, insecure, created_at, updated_at
|
||||
FROM registries
|
||||
WHERE name = $1
|
||||
`
|
||||
|
||||
registry := &entity.Registry{}
|
||||
var encryptedPassword string
|
||||
|
||||
err := r.db.conn.QueryRowContext(ctx, query, name).Scan(
|
||||
®istry.ID,
|
||||
®istry.Name,
|
||||
®istry.URL,
|
||||
®istry.Description,
|
||||
®istry.Username,
|
||||
&encryptedPassword,
|
||||
®istry.Insecure,
|
||||
®istry.CreatedAt,
|
||||
®istry.UpdatedAt,
|
||||
)
|
||||
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, entity.ErrRegistryNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get registry: %w", err)
|
||||
}
|
||||
|
||||
// 解密密码
|
||||
registry.Password, err = r.encryptor.Decrypt(encryptedPassword)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt password: %w", err)
|
||||
}
|
||||
|
||||
return registry, nil
|
||||
}
|
||||
|
||||
// Update 更新 Registry
|
||||
func (r *RegistryRepository) Update(ctx context.Context, registry *entity.Registry) error {
|
||||
registry.UpdatedAt = time.Now()
|
||||
|
||||
// 加密密码
|
||||
encryptedPassword, err := r.encryptor.Encrypt(registry.Password)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt password: %w", err)
|
||||
}
|
||||
|
||||
query := `
|
||||
UPDATE registries
|
||||
SET name = $1, url = $2, description = $3, username = $4, password = $5,
|
||||
insecure = $6, updated_at = $7
|
||||
WHERE id = $8
|
||||
`
|
||||
|
||||
result, err := r.db.conn.ExecContext(ctx, query,
|
||||
registry.Name,
|
||||
registry.URL,
|
||||
registry.Description,
|
||||
registry.Username,
|
||||
encryptedPassword,
|
||||
registry.Insecure,
|
||||
registry.UpdatedAt,
|
||||
registry.ID,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update registry: %w", err)
|
||||
}
|
||||
|
||||
rows, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get affected rows: %w", err)
|
||||
}
|
||||
|
||||
if rows == 0 {
|
||||
return entity.ErrRegistryNotFound
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete 删除 Registry
|
||||
func (r *RegistryRepository) Delete(ctx context.Context, id string) error {
|
||||
query := `DELETE FROM registries WHERE id = $1`
|
||||
|
||||
result, err := r.db.conn.ExecContext(ctx, query, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete registry: %w", err)
|
||||
}
|
||||
|
||||
rows, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get affected rows: %w", err)
|
||||
}
|
||||
|
||||
if rows == 0 {
|
||||
return entity.ErrRegistryNotFound
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// List 列出所有 Registries
|
||||
func (r *RegistryRepository) List(ctx context.Context) ([]*entity.Registry, error) {
|
||||
query := `
|
||||
SELECT id, name, url, description, username, password, insecure, created_at, updated_at
|
||||
FROM registries
|
||||
ORDER BY created_at DESC
|
||||
`
|
||||
|
||||
rows, err := r.db.conn.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list registries: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
registries := make([]*entity.Registry, 0)
|
||||
for rows.Next() {
|
||||
registry := &entity.Registry{}
|
||||
var encryptedPassword string
|
||||
|
||||
err := rows.Scan(
|
||||
®istry.ID,
|
||||
®istry.Name,
|
||||
®istry.URL,
|
||||
®istry.Description,
|
||||
®istry.Username,
|
||||
&encryptedPassword,
|
||||
®istry.Insecure,
|
||||
®istry.CreatedAt,
|
||||
®istry.UpdatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan registry: %w", err)
|
||||
}
|
||||
|
||||
// 解密密码
|
||||
registry.Password, err = r.encryptor.Decrypt(encryptedPassword)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt password: %w", err)
|
||||
}
|
||||
|
||||
registries = append(registries, registry)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("rows iteration error: %w", err)
|
||||
}
|
||||
|
||||
return registries, nil
|
||||
}
|
||||
|
||||
@ -0,0 +1,204 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/ocdp/cluster-service/internal/domain/entity"
|
||||
"github.com/ocdp/cluster-service/internal/domain/repository"
|
||||
)
|
||||
|
||||
// UserRepository PostgreSQL 用户仓储实现
|
||||
type UserRepository struct {
|
||||
db *DB
|
||||
}
|
||||
|
||||
// NewUserRepository 创建 PostgreSQL 用户仓储
|
||||
func NewUserRepository(db *DB) repository.UserRepository {
|
||||
return &UserRepository{db: db}
|
||||
}
|
||||
|
||||
// Create 创建用户
|
||||
func (r *UserRepository) Create(ctx context.Context, user *entity.User) error {
|
||||
if user.ID == "" {
|
||||
user.ID = uuid.New().String()
|
||||
}
|
||||
|
||||
query := `
|
||||
INSERT INTO users (id, username, password_hash, email, revoked_after, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||
`
|
||||
|
||||
_, err := r.db.conn.ExecContext(ctx, query,
|
||||
user.ID,
|
||||
user.Username,
|
||||
user.PasswordHash,
|
||||
user.Email,
|
||||
user.RevokedAfter,
|
||||
user.CreatedAt,
|
||||
user.UpdatedAt,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create user: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetByID 根据 ID 获取用户
|
||||
func (r *UserRepository) GetByID(ctx context.Context, id string) (*entity.User, error) {
|
||||
query := `
|
||||
SELECT id, username, password_hash, email, revoked_after, created_at, updated_at
|
||||
FROM users
|
||||
WHERE id = $1
|
||||
`
|
||||
|
||||
user := &entity.User{}
|
||||
err := r.db.conn.QueryRowContext(ctx, query, id).Scan(
|
||||
&user.ID,
|
||||
&user.Username,
|
||||
&user.PasswordHash,
|
||||
&user.Email,
|
||||
&user.RevokedAfter,
|
||||
&user.CreatedAt,
|
||||
&user.UpdatedAt,
|
||||
)
|
||||
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, entity.ErrUserNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get user: %w", err)
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// GetByUsername 根据用户名获取用户
|
||||
func (r *UserRepository) GetByUsername(ctx context.Context, username string) (*entity.User, error) {
|
||||
query := `
|
||||
SELECT id, username, password_hash, email, revoked_after, created_at, updated_at
|
||||
FROM users
|
||||
WHERE username = $1
|
||||
`
|
||||
|
||||
user := &entity.User{}
|
||||
err := r.db.conn.QueryRowContext(ctx, query, username).Scan(
|
||||
&user.ID,
|
||||
&user.Username,
|
||||
&user.PasswordHash,
|
||||
&user.Email,
|
||||
&user.RevokedAfter,
|
||||
&user.CreatedAt,
|
||||
&user.UpdatedAt,
|
||||
)
|
||||
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, entity.ErrUserNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get user: %w", err)
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// Update 更新用户
|
||||
func (r *UserRepository) Update(ctx context.Context, user *entity.User) error {
|
||||
user.UpdatedAt = time.Now()
|
||||
|
||||
query := `
|
||||
UPDATE users
|
||||
SET username = $1, password_hash = $2, email = $3, revoked_after = $4, updated_at = $5
|
||||
WHERE id = $6
|
||||
`
|
||||
|
||||
result, err := r.db.conn.ExecContext(ctx, query,
|
||||
user.Username,
|
||||
user.PasswordHash,
|
||||
user.Email,
|
||||
user.RevokedAfter,
|
||||
user.UpdatedAt,
|
||||
user.ID,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update user: %w", err)
|
||||
}
|
||||
|
||||
rows, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get affected rows: %w", err)
|
||||
}
|
||||
|
||||
if rows == 0 {
|
||||
return entity.ErrUserNotFound
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete 删除用户
|
||||
func (r *UserRepository) Delete(ctx context.Context, id string) error {
|
||||
query := `DELETE FROM users WHERE id = $1`
|
||||
|
||||
result, err := r.db.conn.ExecContext(ctx, query, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete user: %w", err)
|
||||
}
|
||||
|
||||
rows, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get affected rows: %w", err)
|
||||
}
|
||||
|
||||
if rows == 0 {
|
||||
return entity.ErrUserNotFound
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// List 列出所有用户
|
||||
func (r *UserRepository) List(ctx context.Context) ([]*entity.User, error) {
|
||||
query := `
|
||||
SELECT id, username, password_hash, email, revoked_after, created_at, updated_at
|
||||
FROM users
|
||||
ORDER BY created_at DESC
|
||||
`
|
||||
|
||||
rows, err := r.db.conn.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list users: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
users := make([]*entity.User, 0)
|
||||
for rows.Next() {
|
||||
user := &entity.User{}
|
||||
err := rows.Scan(
|
||||
&user.ID,
|
||||
&user.Username,
|
||||
&user.PasswordHash,
|
||||
&user.Email,
|
||||
&user.RevokedAfter,
|
||||
&user.CreatedAt,
|
||||
&user.UpdatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan user: %w", err)
|
||||
}
|
||||
users = append(users, user)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("rows iteration error: %w", err)
|
||||
}
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user