This commit is contained in:
mangomqy
2025-11-13 02:54:06 +00:00
commit c5e51ed069
254 changed files with 54901 additions and 0 deletions

View File

@ -0,0 +1,220 @@
package output
import (
"fmt"
helmMock "github.com/ocdp/cluster-service/internal/adapter/output/helm/mock"
helmReal "github.com/ocdp/cluster-service/internal/adapter/output/helm/real"
"github.com/ocdp/cluster-service/internal/adapter/output/k8s"
ociMock "github.com/ocdp/cluster-service/internal/adapter/output/oci/mock"
ociReal "github.com/ocdp/cluster-service/internal/adapter/output/oci/real"
"github.com/ocdp/cluster-service/internal/adapter/output/persistence/mock"
"github.com/ocdp/cluster-service/internal/adapter/output/persistence/postgres"
"github.com/ocdp/cluster-service/internal/domain/repository"
"github.com/ocdp/cluster-service/internal/pkg/crypto"
)
// AdapterMode 适配器模式
type AdapterMode string
const (
ModeMock AdapterMode = "mock" // Mock 模式(内存存储,用于开发调试)
// 默认模式:连接真实 PostgreSQL 和服务(任何非 "mock" 的值都是默认模式)
)
// AdapterFactory 适配器工厂
// 用于创建所有 Output Adapters支持 Mock 和真实实现切换
type AdapterFactory struct {
mode AdapterMode
encryptor crypto.Encryptor // 加密器(用于敏感数据加密)
// 数据库连接字符串(非 Mock 模式需要)
dbConnString string
// 数据库连接(非 Mock 模式)
db *postgres.DB
}
// NewAdapterFactory 创建适配器工厂
func NewAdapterFactory(mode AdapterMode, encryptor crypto.Encryptor, dbConnString string) *AdapterFactory {
return &AdapterFactory{
mode: mode,
encryptor: encryptor,
dbConnString: dbConnString,
}
}
// CreateUserRepository 创建用户仓储
func (f *AdapterFactory) CreateUserRepository() (repository.UserRepository, error) {
if f.mode == ModeMock {
return mock.NewUserRepositoryMock(), nil
}
// 默认真实实现PostgreSQL
if err := f.ensureDBConnection(); err != nil {
return nil, err
}
return postgres.NewUserRepository(f.db), nil
}
// CreateClusterRepository 创建集群仓储
func (f *AdapterFactory) CreateClusterRepository() (repository.ClusterRepository, error) {
if f.mode == ModeMock {
return mock.NewClusterRepositoryMock(f.encryptor), nil
}
// 默认真实实现PostgreSQL
if err := f.ensureDBConnection(); err != nil {
return nil, err
}
return postgres.NewClusterRepository(f.db, f.encryptor), nil
}
// CreateRegistryRepository 创建 Registry 仓储
func (f *AdapterFactory) CreateRegistryRepository() (repository.RegistryRepository, error) {
if f.mode == ModeMock {
return mock.NewRegistryRepositoryMock(f.encryptor), nil
}
// 默认真实实现PostgreSQL
if err := f.ensureDBConnection(); err != nil {
return nil, err
}
return postgres.NewRegistryRepository(f.db, f.encryptor), nil
}
// CreateInstanceRepository 创建实例仓储
func (f *AdapterFactory) CreateInstanceRepository() (repository.InstanceRepository, error) {
if f.mode == ModeMock {
return mock.NewInstanceRepositoryMock(), nil
}
// 默认真实实现PostgreSQL
if err := f.ensureDBConnection(); err != nil {
return nil, err
}
return postgres.NewInstanceRepository(f.db), nil
}
// CreateOCIClient 创建 OCI 客户端
func (f *AdapterFactory) CreateOCIClient() (repository.OCIClient, error) {
if f.mode == ModeMock {
return ociMock.NewOCIClientMock(), nil
}
// 默认真实实现ORAS SDK
return ociReal.NewOCIClient(), nil
}
// CreateHelmClient 创建 Helm 客户端
func (f *AdapterFactory) CreateHelmClient() (repository.HelmClient, error) {
if f.mode == ModeMock {
return helmMock.NewHelmClientMock(), nil
}
// 默认真实实现Helm SDK
return helmReal.NewHelmClient(), nil
}
// CreateMetricsClient 创建 Metrics 客户端
func (f *AdapterFactory) CreateMetricsClient(clusterRepo repository.ClusterRepository) repository.MetricsClient {
// Metrics client 总是使用真实的 Kubernetes API
return k8s.NewMetricsClient(clusterRepo)
}
// CreateEntryClient 创建实例入口查询客户端
func (f *AdapterFactory) CreateEntryClient() repository.InstanceEntryClient {
return k8s.NewEntryClient()
}
// CreateAllRepositories 一次性创建所有 Repositories
func (f *AdapterFactory) CreateAllRepositories() (*Repositories, error) {
userRepo, err := f.CreateUserRepository()
if err != nil {
return nil, fmt.Errorf("failed to create user repository: %w", err)
}
clusterRepo, err := f.CreateClusterRepository()
if err != nil {
return nil, fmt.Errorf("failed to create cluster repository: %w", err)
}
registryRepo, err := f.CreateRegistryRepository()
if err != nil {
return nil, fmt.Errorf("failed to create registry repository: %w", err)
}
instanceRepo, err := f.CreateInstanceRepository()
if err != nil {
return nil, fmt.Errorf("failed to create instance repository: %w", err)
}
ociClient, err := f.CreateOCIClient()
if err != nil {
return nil, fmt.Errorf("failed to create OCI client: %w", err)
}
helmClient, err := f.CreateHelmClient()
if err != nil {
return nil, fmt.Errorf("failed to create Helm client: %w", err)
}
// 创建 Metrics client依赖 clusterRepo
metricsClient := f.CreateMetricsClient(clusterRepo)
entryClient := f.CreateEntryClient()
return &Repositories{
UserRepo: userRepo,
ClusterRepo: clusterRepo,
RegistryRepo: registryRepo,
InstanceRepo: instanceRepo,
OCIClient: ociClient,
HelmClient: helmClient,
MetricsClient: metricsClient,
EntryClient: entryClient,
}, nil
}
// Repositories 所有仓储的集合
type Repositories struct {
UserRepo repository.UserRepository
ClusterRepo repository.ClusterRepository
RegistryRepo repository.RegistryRepository
InstanceRepo repository.InstanceRepository
OCIClient repository.OCIClient
HelmClient repository.HelmClient
MetricsClient repository.MetricsClient
EntryClient repository.InstanceEntryClient
}
// ensureDBConnection 确保数据库连接已建立
func (f *AdapterFactory) ensureDBConnection() error {
if f.db != nil {
return nil
}
if f.dbConnString == "" {
return fmt.Errorf("database connection string is required (set DATABASE_URL environment variable)")
}
db, err := postgres.NewDB(f.dbConnString)
if err != nil {
return fmt.Errorf("failed to connect to database: %w", err)
}
// 初始化数据库 schema
if err := db.InitSchema(); err != nil {
return fmt.Errorf("failed to initialize database schema: %w", err)
}
f.db = db
return nil
}
// Close 关闭工厂资源
func (f *AdapterFactory) Close() error {
if f.db != nil {
return f.db.Close()
}
return nil
}

View File

@ -0,0 +1,196 @@
package mock
import (
"context"
"fmt"
"time"
"github.com/ocdp/cluster-service/internal/domain/entity"
"github.com/ocdp/cluster-service/internal/domain/repository"
)
// HelmClientMock Helm 客户端 Mock 实现
type HelmClientMock struct {
// Mock 数据存储
releases map[string]map[string]*entity.Instance // clusterID -> releaseName -> instance
history map[string]map[string][]*entity.ReleaseHistory // clusterID -> releaseName -> []history
}
// NewHelmClientMock 创建 Mock 实现
func NewHelmClientMock() repository.HelmClient {
return &HelmClientMock{
releases: make(map[string]map[string]*entity.Instance),
history: make(map[string]map[string][]*entity.ReleaseHistory),
}
}
func (c *HelmClientMock) Install(ctx context.Context, cluster *entity.Cluster, instance *entity.Instance) error {
// 初始化集群数据
if c.releases[cluster.ID] == nil {
c.releases[cluster.ID] = make(map[string]*entity.Instance)
c.history[cluster.ID] = make(map[string][]*entity.ReleaseHistory)
}
// 检查是否已存在
key := fmt.Sprintf("%s/%s", instance.Namespace, instance.Name)
if _, exists := c.releases[cluster.ID][key]; exists {
return entity.ErrInstanceExists
}
// Mock 安装
instance.Status = entity.StatusDeployed
instance.Revision = 1
instance.UpdatedAt = time.Now()
c.releases[cluster.ID][key] = instance
// 添加历史记录
c.history[cluster.ID][key] = []*entity.ReleaseHistory{
{
Revision: 1,
Updated: time.Now(),
Status: entity.StatusDeployed,
Chart: fmt.Sprintf("%s-%s", instance.Chart, instance.Version),
AppVersion: instance.Version,
Description: "Install complete",
},
}
return nil
}
func (c *HelmClientMock) Upgrade(ctx context.Context, cluster *entity.Cluster, instance *entity.Instance) error {
key := fmt.Sprintf("%s/%s", instance.Namespace, instance.Name)
existing, exists := c.releases[cluster.ID][key]
if !exists {
return entity.ErrInstanceNotFound
}
// Mock 升级
instance.Revision = existing.Revision + 1
instance.Status = entity.StatusDeployed
instance.UpdatedAt = time.Now()
c.releases[cluster.ID][key] = instance
// 添加历史记录
history := &entity.ReleaseHistory{
Revision: instance.Revision,
Updated: time.Now(),
Status: entity.StatusDeployed,
Chart: fmt.Sprintf("%s-%s", instance.Chart, instance.Version),
AppVersion: instance.Version,
Description: "Upgrade complete",
}
c.history[cluster.ID][key] = append(c.history[cluster.ID][key], history)
return nil
}
func (c *HelmClientMock) Uninstall(ctx context.Context, cluster *entity.Cluster, releaseName, namespace string) error {
key := fmt.Sprintf("%s/%s", namespace, releaseName)
if _, exists := c.releases[cluster.ID][key]; !exists {
return entity.ErrInstanceNotFound
}
// Mock 卸载
delete(c.releases[cluster.ID], key)
return nil
}
func (c *HelmClientMock) Rollback(ctx context.Context, cluster *entity.Cluster, releaseName, namespace string, revision int) error {
key := fmt.Sprintf("%s/%s", namespace, releaseName)
instance, exists := c.releases[cluster.ID][key]
if !exists {
return entity.ErrInstanceNotFound
}
// 检查历史记录是否存在
histories := c.history[cluster.ID][key]
if revision > len(histories) || revision < 1 {
return fmt.Errorf("revision %d not found", revision)
}
// Mock 回滚
instance.Revision = len(histories) + 1
instance.Status = entity.StatusDeployed
instance.UpdatedAt = time.Now()
c.releases[cluster.ID][key] = instance
// 添加回滚历史记录
history := &entity.ReleaseHistory{
Revision: instance.Revision,
Updated: time.Now(),
Status: entity.StatusDeployed,
Chart: instance.Chart,
AppVersion: instance.Version,
Description: fmt.Sprintf("Rollback to revision %d", revision),
}
c.history[cluster.ID][key] = append(c.history[cluster.ID][key], history)
return nil
}
func (c *HelmClientMock) GetStatus(ctx context.Context, cluster *entity.Cluster, releaseName, namespace string) (*entity.Instance, error) {
key := fmt.Sprintf("%s/%s", namespace, releaseName)
instance, exists := c.releases[cluster.ID][key]
if !exists {
return nil, entity.ErrInstanceNotFound
}
return instance, nil
}
func (c *HelmClientMock) GetHistory(ctx context.Context, cluster *entity.Cluster, releaseName, namespace string) ([]*entity.ReleaseHistory, error) {
key := fmt.Sprintf("%s/%s", namespace, releaseName)
if _, exists := c.releases[cluster.ID][key]; !exists {
return nil, entity.ErrInstanceNotFound
}
histories := c.history[cluster.ID][key]
if histories == nil {
return []*entity.ReleaseHistory{}, nil
}
return histories, nil
}
func (c *HelmClientMock) List(ctx context.Context, cluster *entity.Cluster, namespace string) ([]*entity.Instance, error) {
clusterReleases := c.releases[cluster.ID]
if clusterReleases == nil {
return []*entity.Instance{}, nil
}
instances := make([]*entity.Instance, 0)
for key, instance := range clusterReleases {
// 如果指定了 namespace只返回该 namespace 的
if namespace != "" && namespace != "all" {
keyNamespace := instance.Namespace
if keyNamespace != namespace {
continue
}
}
instances = append(instances, c.releases[cluster.ID][key])
}
return instances, nil
}
func (c *HelmClientMock) GetValues(ctx context.Context, cluster *entity.Cluster, releaseName, namespace string) (map[string]interface{}, error) {
key := fmt.Sprintf("%s/%s", namespace, releaseName)
instance, exists := c.releases[cluster.ID][key]
if !exists {
return nil, entity.ErrInstanceNotFound
}
return instance.Values, nil
}

View File

@ -0,0 +1,313 @@
package real
import (
"context"
"errors"
"fmt"
"os"
"path/filepath"
"time"
"github.com/ocdp/cluster-service/internal/domain/entity"
"github.com/ocdp/cluster-service/internal/domain/repository"
"helm.sh/helm/v3/pkg/action"
"helm.sh/helm/v3/pkg/chart/loader"
"helm.sh/helm/v3/pkg/cli"
"helm.sh/helm/v3/pkg/release"
"helm.sh/helm/v3/pkg/storage/driver"
"k8s.io/apimachinery/pkg/api/meta"
"k8s.io/client-go/discovery"
"k8s.io/client-go/discovery/cached/memory"
"k8s.io/client-go/rest"
"k8s.io/client-go/restmapper"
"k8s.io/client-go/tools/clientcmd"
)
// HelmClient 真实的 Helm 客户端实现
type HelmClient struct {
settings *cli.EnvSettings
}
// NewHelmClient 创建真实的 Helm 客户端
func NewHelmClient() repository.HelmClient {
return &HelmClient{
settings: cli.New(),
}
}
// getActionConfig 获取 Helm action configuration
func (h *HelmClient) getActionConfig(cluster *entity.Cluster, namespace string) (*action.Configuration, error) {
actionConfig := new(action.Configuration)
// 创建临时 kubeconfig 文件
kubeconfigContent := cluster.GetKubeConfig()
tmpDir, err := os.MkdirTemp("", "helm-kubeconfig-*")
if err != nil {
return nil, fmt.Errorf("failed to create temp dir: %w", err)
}
kubeconfigPath := filepath.Join(tmpDir, "kubeconfig")
if err := os.WriteFile(kubeconfigPath, []byte(kubeconfigContent), 0600); err != nil {
return nil, fmt.Errorf("failed to write kubeconfig: %w", err)
}
// 使用 kubeconfig 初始化 action config
if err := actionConfig.Init(
&kubeconfigGetter{kubeconfigPath: kubeconfigPath},
namespace,
os.Getenv("HELM_DRIVER"), // storage driver: configmap, secret, memory
func(format string, v ...interface{}) {
// Log function
},
); err != nil {
return nil, fmt.Errorf("failed to initialize action config: %w", err)
}
return actionConfig, nil
}
// kubeconfigGetter implements RESTClientGetter
type kubeconfigGetter struct {
kubeconfigPath string
}
func (k *kubeconfigGetter) ToRESTConfig() (*rest.Config, error) {
return clientcmd.BuildConfigFromFlags("", k.kubeconfigPath)
}
func (k *kubeconfigGetter) ToDiscoveryClient() (discovery.CachedDiscoveryInterface, error) {
config, err := k.ToRESTConfig()
if err != nil {
return nil, err
}
discoveryClient := discovery.NewDiscoveryClientForConfigOrDie(config)
// Wrap in a memory cache
return memory.NewMemCacheClient(discoveryClient), nil
}
func (k *kubeconfigGetter) ToRESTMapper() (meta.RESTMapper, error) {
discoveryClient, err := k.ToDiscoveryClient()
if err != nil {
return nil, err
}
mapper := restmapper.NewDeferredDiscoveryRESTMapper(discoveryClient)
return mapper, nil
}
func (k *kubeconfigGetter) ToRawKubeConfigLoader() clientcmd.ClientConfig {
return clientcmd.NewNonInteractiveDeferredLoadingClientConfig(
&clientcmd.ClientConfigLoadingRules{ExplicitPath: k.kubeconfigPath},
&clientcmd.ConfigOverrides{},
)
}
// Install 安装 Helm Chart
func (h *HelmClient) Install(ctx context.Context, cluster *entity.Cluster, instance *entity.Instance) error {
actionConfig, err := h.getActionConfig(cluster, instance.Namespace)
if err != nil {
return err
}
install := action.NewInstall(actionConfig)
install.ReleaseName = instance.Name
install.Namespace = instance.Namespace
install.CreateNamespace = true
install.Wait = true
install.Timeout = 5 * time.Minute
// 加载 Chart从本地路径或 OCI registry
// 这里简化处理,假设 chart 已经被拉取到本地
chartPath := fmt.Sprintf("/tmp/charts/%s-%s.tgz", instance.Chart, instance.Version)
chart, err := loader.Load(chartPath)
if err != nil {
return fmt.Errorf("failed to load chart: %w", err)
}
// 执行安装
rel, err := install.Run(chart, instance.Values)
if err != nil {
return fmt.Errorf("failed to install release: %w", err)
}
// 更新 revision状态由调用方根据操作结果设置
instance.Revision = rel.Version
// 注意:不在这里设置 Status让调用方通过 MarkSuccess/MarkFailure 来设置
return nil
}
// Upgrade 升级 Helm Release
func (h *HelmClient) Upgrade(ctx context.Context, cluster *entity.Cluster, instance *entity.Instance) error {
actionConfig, err := h.getActionConfig(cluster, instance.Namespace)
if err != nil {
return err
}
upgrade := action.NewUpgrade(actionConfig)
upgrade.Namespace = instance.Namespace
upgrade.Wait = true
upgrade.Timeout = 5 * time.Minute
// 加载 Chart
chartPath := fmt.Sprintf("/tmp/charts/%s-%s.tgz", instance.Chart, instance.Version)
chart, err := loader.Load(chartPath)
if err != nil {
return fmt.Errorf("failed to load chart: %w", err)
}
// 执行升级
rel, err := upgrade.Run(instance.Name, chart, instance.Values)
if err != nil {
return fmt.Errorf("failed to upgrade release: %w", err)
}
// 更新 revision状态由调用方根据操作结果设置
instance.Revision = rel.Version
// 注意:不在这里设置 Status让调用方通过 MarkSuccess/MarkFailure 来设置
return nil
}
// Uninstall 卸载 Helm Release
func (h *HelmClient) Uninstall(ctx context.Context, cluster *entity.Cluster, releaseName, namespace string) error {
actionConfig, err := h.getActionConfig(cluster, namespace)
if err != nil {
return err
}
uninstall := action.NewUninstall(actionConfig)
uninstall.Wait = true
uninstall.Timeout = 5 * time.Minute
_, err = uninstall.Run(releaseName)
if err != nil {
if errors.Is(err, driver.ErrReleaseNotFound) {
return entity.ErrInstanceNotFound
}
return fmt.Errorf("failed to uninstall release: %w", err)
}
return nil
}
// Rollback 回滚 Helm Release
func (h *HelmClient) Rollback(ctx context.Context, cluster *entity.Cluster, releaseName, namespace string, revision int) error {
actionConfig, err := h.getActionConfig(cluster, namespace)
if err != nil {
return err
}
rollback := action.NewRollback(actionConfig)
rollback.Version = revision
rollback.Wait = true
rollback.Timeout = 5 * time.Minute
if err := rollback.Run(releaseName); err != nil {
return fmt.Errorf("failed to rollback release: %w", err)
}
return nil
}
// GetStatus 获取 Release 状态
func (h *HelmClient) GetStatus(ctx context.Context, cluster *entity.Cluster, releaseName, namespace string) (*entity.Instance, error) {
actionConfig, err := h.getActionConfig(cluster, namespace)
if err != nil {
return nil, err
}
status := action.NewStatus(actionConfig)
rel, err := status.Run(releaseName)
if err != nil {
return nil, fmt.Errorf("failed to get release status: %w", err)
}
return h.convertReleaseToInstance(rel), nil
}
// GetHistory 获取 Release 历史
func (h *HelmClient) GetHistory(ctx context.Context, cluster *entity.Cluster, releaseName, namespace string) ([]*entity.ReleaseHistory, error) {
actionConfig, err := h.getActionConfig(cluster, namespace)
if err != nil {
return nil, err
}
history := action.NewHistory(actionConfig)
history.Max = 256
releases, err := history.Run(releaseName)
if err != nil {
return nil, fmt.Errorf("failed to get release history: %w", err)
}
result := make([]*entity.ReleaseHistory, 0, len(releases))
for _, rel := range releases {
result = append(result, &entity.ReleaseHistory{
Revision: rel.Version,
Updated: rel.Info.LastDeployed.Time,
Status: entity.InstanceStatus(rel.Info.Status),
Chart: rel.Chart.Metadata.Name,
AppVersion: rel.Chart.Metadata.AppVersion,
Description: rel.Info.Description,
})
}
return result, nil
}
// List 列出集群中的所有 Releases
func (h *HelmClient) List(ctx context.Context, cluster *entity.Cluster, namespace string) ([]*entity.Instance, error) {
actionConfig, err := h.getActionConfig(cluster, namespace)
if err != nil {
return nil, err
}
list := action.NewList(actionConfig)
if namespace == "" {
list.AllNamespaces = true
}
releases, err := list.Run()
if err != nil {
return nil, fmt.Errorf("failed to list releases: %w", err)
}
instances := make([]*entity.Instance, 0, len(releases))
for _, rel := range releases {
instances = append(instances, h.convertReleaseToInstance(rel))
}
return instances, nil
}
// GetValues 获取 Release 的 values
func (h *HelmClient) GetValues(ctx context.Context, cluster *entity.Cluster, releaseName, namespace string) (map[string]interface{}, error) {
actionConfig, err := h.getActionConfig(cluster, namespace)
if err != nil {
return nil, err
}
getValues := action.NewGetValues(actionConfig)
values, err := getValues.Run(releaseName)
if err != nil {
return nil, fmt.Errorf("failed to get values: %w", err)
}
return values, nil
}
// convertReleaseToInstance 转换 Helm Release 为 Instance
func (h *HelmClient) convertReleaseToInstance(rel *release.Release) *entity.Instance {
return &entity.Instance{
Name: rel.Name,
Namespace: rel.Namespace,
Chart: rel.Chart.Metadata.Name,
Version: rel.Chart.Metadata.Version,
Status: entity.InstanceStatus(rel.Info.Status),
Revision: rel.Version,
Values: rel.Config,
UpdatedAt: rel.Info.LastDeployed.Time,
}
}

View File

@ -0,0 +1,321 @@
package k8s
import (
"context"
"fmt"
"strings"
corev1 "k8s.io/api/core/v1"
networkingv1 "k8s.io/api/networking/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/util/intstr"
"k8s.io/client-go/kubernetes"
"k8s.io/client-go/rest"
"k8s.io/client-go/tools/clientcmd"
"github.com/ocdp/cluster-service/internal/domain/entity"
"github.com/ocdp/cluster-service/internal/domain/repository"
)
// EntryClient 使用 Kubernetes API 查询实例相关 Service/Ingress
type EntryClient struct{}
// NewEntryClient 创建 EntryClient
func NewEntryClient() repository.InstanceEntryClient {
return &EntryClient{}
}
// ListEntries 查询实例的 Service/Ingress 入口
func (c *EntryClient) ListEntries(
ctx context.Context,
cluster *entity.Cluster,
instance *entity.Instance,
) ([]*entity.InstanceEntry, error) {
clientset, err := c.createClientset(cluster)
if err != nil {
return nil, err
}
selector := fmt.Sprintf("app.kubernetes.io/instance=%s", instance.Name)
serviceEntries, err := c.collectServiceEntries(ctx, clientset, instance, selector)
if err != nil {
return nil, err
}
ingressEntries, err := c.collectIngressEntries(ctx, clientset, instance, selector)
if err != nil {
return nil, err
}
return append(serviceEntries, ingressEntries...), nil
}
func (c *EntryClient) collectServiceEntries(
ctx context.Context,
clientset *kubernetes.Clientset,
instance *entity.Instance,
selector string,
) ([]*entity.InstanceEntry, error) {
services, err := c.listServices(ctx, clientset, instance.Namespace, selector)
if err != nil {
return nil, err
}
entries := convertServicesToEntries(services, instance, selector == "")
if len(entries) == 0 && selector != "" {
// Fallback: widen the search scope and filter manually.
services, err = c.listServices(ctx, clientset, instance.Namespace, "")
if err != nil {
return nil, err
}
entries = convertServicesToEntries(services, instance, true)
}
return entries, nil
}
func (c *EntryClient) collectIngressEntries(
ctx context.Context,
clientset *kubernetes.Clientset,
instance *entity.Instance,
selector string,
) ([]*entity.InstanceEntry, error) {
ingresses, err := c.listIngresses(ctx, clientset, instance.Namespace, selector)
if err != nil {
return nil, err
}
entries := convertIngressesToEntries(ingresses, instance, selector == "")
if len(entries) == 0 && selector != "" {
ingresses, err = c.listIngresses(ctx, clientset, instance.Namespace, "")
if err != nil {
return nil, err
}
entries = convertIngressesToEntries(ingresses, instance, true)
}
return entries, nil
}
func (c *EntryClient) listServices(
ctx context.Context,
clientset *kubernetes.Clientset,
namespace, selector string,
) ([]corev1.Service, error) {
listOptions := metav1.ListOptions{}
if selector != "" {
listOptions.LabelSelector = selector
}
services, err := clientset.CoreV1().
Services(namespace).
List(ctx, listOptions)
if err != nil {
return nil, fmt.Errorf("failed to list services: %w", err)
}
return services.Items, nil
}
func (c *EntryClient) listIngresses(
ctx context.Context,
clientset *kubernetes.Clientset,
namespace, selector string,
) ([]networkingv1.Ingress, error) {
listOptions := metav1.ListOptions{}
if selector != "" {
listOptions.LabelSelector = selector
}
ingresses, err := clientset.NetworkingV1().
Ingresses(namespace).
List(ctx, listOptions)
if err != nil {
return nil, fmt.Errorf("failed to list ingresses: %w", err)
}
return ingresses.Items, nil
}
func convertServicesToEntries(services []corev1.Service, instance *entity.Instance, enforceMatch bool) []*entity.InstanceEntry {
entries := make([]*entity.InstanceEntry, 0, len(services))
for _, svc := range services {
if enforceMatch && !resourceMatchesInstance(svc.ObjectMeta, instance) {
continue
}
entries = append(entries, convertServiceToEntry(&svc))
}
return entries
}
func convertIngressesToEntries(ingresses []networkingv1.Ingress, instance *entity.Instance, enforceMatch bool) []*entity.InstanceEntry {
entries := make([]*entity.InstanceEntry, 0, len(ingresses))
for _, ing := range ingresses {
if enforceMatch && !resourceMatchesInstance(ing.ObjectMeta, instance) {
continue
}
entries = append(entries, convertIngressToEntry(&ing))
}
return entries
}
func (c *EntryClient) createClientset(cluster *entity.Cluster) (*kubernetes.Clientset, error) {
config, err := clientcmd.RESTConfigFromKubeConfig([]byte(cluster.GetKubeConfig()))
if err != nil {
config = &rest.Config{
Host: cluster.Host,
TLSClientConfig: rest.TLSClientConfig{
CAData: []byte(cluster.CAData),
CertData: []byte(cluster.CertData),
KeyData: []byte(cluster.KeyData),
},
BearerToken: cluster.Token,
}
}
clientset, err := kubernetes.NewForConfig(config)
if err != nil {
return nil, fmt.Errorf("failed to create kubernetes client: %w", err)
}
return clientset, nil
}
func convertServiceToEntry(svc *corev1.Service) *entity.InstanceEntry {
clusterIP := svc.Spec.ClusterIP
if clusterIP == corev1.ClusterIPNone {
clusterIP = ""
}
lbIngress := make([]string, 0, len(svc.Status.LoadBalancer.Ingress))
for _, ing := range svc.Status.LoadBalancer.Ingress {
if ing.IP != "" {
lbIngress = append(lbIngress, ing.IP)
}
if ing.Hostname != "" {
lbIngress = append(lbIngress, ing.Hostname)
}
}
ports := make([]entity.InstanceEntryPort, 0, len(svc.Spec.Ports))
for _, port := range svc.Spec.Ports {
ports = append(ports, entity.InstanceEntryPort{
Name: port.Name,
Protocol: string(port.Protocol),
Port: port.Port,
TargetPort: intOrStringToString(port.TargetPort),
NodePort: port.NodePort,
})
}
return &entity.InstanceEntry{
Kind: "Service",
Name: svc.Name,
Namespace: svc.Namespace,
Type: string(svc.Spec.Type),
ClusterIP: clusterIP,
ExternalIPs: append([]string{}, svc.Spec.ExternalIPs...),
LoadBalancerIngress: lbIngress,
Ports: ports,
}
}
func convertIngressToEntry(ing *networkingv1.Ingress) *entity.InstanceEntry {
lbIngress := make([]string, 0, len(ing.Status.LoadBalancer.Ingress))
for _, addr := range ing.Status.LoadBalancer.Ingress {
if addr.IP != "" {
lbIngress = append(lbIngress, addr.IP)
}
if addr.Hostname != "" {
lbIngress = append(lbIngress, addr.Hostname)
}
}
hosts := make([]entity.InstanceEntryHost, 0, len(ing.Spec.Rules))
for _, rule := range ing.Spec.Rules {
hostEntry := entity.InstanceEntryHost{
Host: rule.Host,
}
if rule.HTTP != nil {
paths := make([]entity.InstanceEntryPath, 0, len(rule.HTTP.Paths))
for _, path := range rule.HTTP.Paths {
name := ""
port := ""
if path.Backend.Service != nil {
name = path.Backend.Service.Name
port = serviceBackendPortString(path.Backend.Service.Port)
}
paths = append(paths, entity.InstanceEntryPath{
Path: path.Path,
ServiceName: name,
ServicePort: port,
})
}
hostEntry.Paths = paths
}
hosts = append(hosts, hostEntry)
}
tlsEntries := make([]entity.InstanceEntryTLS, 0, len(ing.Spec.TLS))
for _, tls := range ing.Spec.TLS {
tlsEntries = append(tlsEntries, entity.InstanceEntryTLS{
Hosts: append([]string{}, tls.Hosts...),
SecretName: tls.SecretName,
})
}
entryType := "Ingress"
if ing.Spec.IngressClassName != nil {
entryType = *ing.Spec.IngressClassName
}
return &entity.InstanceEntry{
Kind: "Ingress",
Name: ing.Name,
Namespace: ing.Namespace,
Type: entryType,
LoadBalancerIngress: lbIngress,
Hosts: hosts,
TLS: tlsEntries,
}
}
func intOrStringToString(v intstr.IntOrString) string {
if v.Type == intstr.String {
return v.StrVal
}
return fmt.Sprintf("%d", v.IntValue())
}
func serviceBackendPortString(port networkingv1.ServiceBackendPort) string {
if port.Name != "" {
return port.Name
}
if port.Number != 0 {
return fmt.Sprintf("%d", port.Number)
}
return ""
}
func resourceMatchesInstance(meta metav1.ObjectMeta, instance *entity.Instance) bool {
if instance == nil {
return false
}
labels := meta.GetLabels()
if labels != nil {
if labels["app.kubernetes.io/instance"] == instance.Name {
return true
}
labelKeys := []string{"app", "app.kubernetes.io/name", "app.kubernetes.io/component", "release"}
for _, key := range labelKeys {
if labels[key] == instance.Name {
return true
}
}
}
annotations := meta.GetAnnotations()
if annotations != nil {
if annotations["meta.helm.sh/release-name"] == instance.Name {
if ns := annotations["meta.helm.sh/release-namespace"]; ns == "" || ns == instance.Namespace {
return true
}
}
}
name := meta.GetName()
if name == instance.Name || strings.HasPrefix(name, instance.Name+"-") {
return true
}
return false
}

View File

@ -0,0 +1,54 @@
package k8s
import (
"testing"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"github.com/ocdp/cluster-service/internal/domain/entity"
)
func TestResourceMatchesInstance(t *testing.T) {
instance := &entity.Instance{
Name: "demo",
Namespace: "default",
}
testCases := []struct {
name string
meta metav1.ObjectMeta
want bool
}{
{
name: "matches by standard label",
meta: metav1.ObjectMeta{Labels: map[string]string{
"app.kubernetes.io/instance": "demo",
}},
want: true,
},
{
name: "matches by helm annotations",
meta: metav1.ObjectMeta{Annotations: map[string]string{
"meta.helm.sh/release-name": "demo",
"meta.helm.sh/release-namespace": "default",
}},
want: true,
},
{
name: "matches by resource name prefix",
meta: metav1.ObjectMeta{Name: "demo-nginx"},
want: true,
},
{
name: "does not match unrelated resource",
meta: metav1.ObjectMeta{Name: "other"},
want: false,
},
}
for _, tc := range testCases {
if got := resourceMatchesInstance(tc.meta, instance); got != tc.want {
t.Fatalf("%s: expected %v, got %v", tc.name, tc.want, got)
}
}
}

View File

@ -0,0 +1,370 @@
package k8s
import (
"context"
"fmt"
"time"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/client-go/kubernetes"
"k8s.io/client-go/rest"
"k8s.io/client-go/tools/clientcmd"
metricsv "k8s.io/metrics/pkg/client/clientset/versioned"
"github.com/ocdp/cluster-service/internal/domain/entity"
"github.com/ocdp/cluster-service/internal/domain/repository"
)
// MetricsClient 实现从 Kubernetes 集群获取监控指标
type MetricsClient struct {
clusterRepo repository.ClusterRepository
}
// NewMetricsClient 创建 MetricsClient
func NewMetricsClient(clusterRepo repository.ClusterRepository) *MetricsClient {
return &MetricsClient{
clusterRepo: clusterRepo,
}
}
// GetClusterMetrics 获取集群监控指标
func (c *MetricsClient) GetClusterMetrics(ctx context.Context, clusterID string) (*entity.ClusterMetrics, error) {
// 获取集群信息
cluster, err := c.clusterRepo.GetByID(ctx, clusterID)
if err != nil {
return nil, fmt.Errorf("failed to get cluster: %w", err)
}
// 创建 Kubernetes 客户端
clientset, metricsClient, err := c.createK8sClients(cluster)
if err != nil {
return nil, fmt.Errorf("failed to create k8s client: %w", err)
}
// 获取节点列表
nodes, err := clientset.CoreV1().Nodes().List(ctx, metav1.ListOptions{})
if err != nil {
return nil, fmt.Errorf("failed to list nodes: %w", err)
}
// 获取所有 Pods
pods, err := clientset.CoreV1().Pods("").List(ctx, metav1.ListOptions{})
if err != nil {
return nil, fmt.Errorf("failed to list pods: %w", err)
}
// 获取节点指标CPU/内存使用情况)
nodeMetrics, err := c.getNodeMetricsData(ctx, clientset, metricsClient, nodes.Items)
if err != nil {
// 如果无法获取 metrics记录错误但继续
fmt.Printf("Warning: failed to get node metrics: %v\n", err)
}
// 计算集群级别汇总
metrics := c.aggregateClusterMetrics(cluster, nodes.Items, pods.Items, nodeMetrics)
return metrics, nil
}
// GetNodeMetrics 获取集群节点指标
func (c *MetricsClient) GetNodeMetrics(ctx context.Context, clusterID string) ([]*entity.NodeMetrics, error) {
cluster, err := c.clusterRepo.GetByID(ctx, clusterID)
if err != nil {
return nil, fmt.Errorf("failed to get cluster: %w", err)
}
clientset, metricsClient, err := c.createK8sClients(cluster)
if err != nil {
return nil, fmt.Errorf("failed to create k8s client: %w", err)
}
nodes, err := clientset.CoreV1().Nodes().List(ctx, metav1.ListOptions{})
if err != nil {
return nil, fmt.Errorf("failed to list nodes: %w", err)
}
return c.getNodeMetricsData(ctx, clientset, metricsClient, nodes.Items)
}
// createK8sClients 创建 Kubernetes 客户端
func (c *MetricsClient) createK8sClients(cluster *entity.Cluster) (*kubernetes.Clientset, *metricsv.Clientset, error) {
config, err := clientcmd.RESTConfigFromKubeConfig([]byte(cluster.GetKubeConfig()))
if err != nil {
// 如果无法从 kubeconfig 创建,尝试使用集群配置
config = &rest.Config{
Host: cluster.Host,
TLSClientConfig: rest.TLSClientConfig{
CAData: []byte(cluster.CAData),
CertData: []byte(cluster.CertData),
KeyData: []byte(cluster.KeyData),
},
}
}
clientset, err := kubernetes.NewForConfig(config)
if err != nil {
return nil, nil, fmt.Errorf("failed to create clientset: %w", err)
}
metricsClient, err := metricsv.NewForConfig(config)
if err != nil {
// Metrics API 可能不可用,返回 nil 但不报错
return clientset, nil, nil
}
return clientset, metricsClient, nil
}
// getNodeMetricsData 获取节点详细指标
func (c *MetricsClient) getNodeMetricsData(
ctx context.Context,
clientset *kubernetes.Clientset,
metricsClient *metricsv.Clientset,
nodes []corev1.Node,
) ([]*entity.NodeMetrics, error) {
result := make([]*entity.NodeMetrics, 0, len(nodes))
for _, node := range nodes {
nodeMetric := &entity.NodeMetrics{
NodeName: node.Name,
Status: getNodeStatus(&node),
Role: getNodeRole(&node),
Age: getNodeAge(&node),
OSImage: node.Status.NodeInfo.OSImage,
KernelVersion: node.Status.NodeInfo.KernelVersion,
ContainerRuntime: node.Status.NodeInfo.ContainerRuntimeVersion,
KubeletVersion: node.Status.NodeInfo.KubeletVersion,
}
// CPU
cpuCapacity := node.Status.Capacity.Cpu()
cpuAllocatable := node.Status.Allocatable.Cpu()
nodeMetric.CPUCapacity = fmt.Sprintf("%.2f cores", float64(cpuCapacity.MilliValue())/1000.0)
nodeMetric.CPUAllocatable = fmt.Sprintf("%.2f cores", float64(cpuAllocatable.MilliValue())/1000.0)
// Memory
memCapacity := node.Status.Capacity.Memory()
memAllocatable := node.Status.Allocatable.Memory()
nodeMetric.MemoryCapacity = formatBytes(memCapacity.Value())
nodeMetric.MemoryAllocatable = formatBytes(memAllocatable.Value())
// GPU (从 node allocatable 中查找)
if gpu, ok := node.Status.Allocatable["nvidia.com/gpu"]; ok {
nodeMetric.GPUCapacity = int(gpu.Value())
// 尝试获取 GPU 类型
if gpuType, ok := node.Labels["nvidia.com/gpu.product"]; ok {
nodeMetric.GPUType = gpuType
}
}
// 获取 Pod 数量
pods, err := clientset.CoreV1().Pods("").List(ctx, metav1.ListOptions{
FieldSelector: fmt.Sprintf("spec.nodeName=%s", node.Name),
})
if err == nil {
nodeMetric.PodCount = len(pods.Items)
}
// 如果有 metrics client获取实时使用情况
if metricsClient != nil {
nodeMetricData, err := metricsClient.MetricsV1beta1().NodeMetricses().Get(ctx, node.Name, metav1.GetOptions{})
if err == nil {
// CPU 使用
cpuUsage := nodeMetricData.Usage.Cpu()
nodeMetric.CPUUsage = fmt.Sprintf("%.2f cores", float64(cpuUsage.MilliValue())/1000.0)
if cpuAllocatable.MilliValue() > 0 {
nodeMetric.CPUPercent = float64(cpuUsage.MilliValue()) / float64(cpuAllocatable.MilliValue()) * 100
}
// Memory 使用
memUsage := nodeMetricData.Usage.Memory()
nodeMetric.MemoryUsage = formatBytes(memUsage.Value())
if memAllocatable.Value() > 0 {
nodeMetric.MemoryPercent = float64(memUsage.Value()) / float64(memAllocatable.Value()) * 100
}
}
}
result = append(result, nodeMetric)
}
return result, nil
}
// aggregateClusterMetrics 聚合集群级别指标
func (c *MetricsClient) aggregateClusterMetrics(
cluster *entity.Cluster,
nodes []corev1.Node,
pods []corev1.Pod,
nodeMetrics []*entity.NodeMetrics,
) *entity.ClusterMetrics {
metrics := &entity.ClusterMetrics{
ClusterID: cluster.ID,
ClusterName: cluster.Name,
Status: "healthy",
NodeCount: len(nodes),
PodCount: len(pods),
LastCheck: time.Now(),
Nodes: make([]entity.NodeMetrics, 0),
}
// 汇总资源
var totalCPU, totalMem, usedCPU, usedMem int64
var totalGPU, usedGPU int
healthyNodes := 0
// 单机最大值
var maxNodeCPU, maxNodeMem int64
var maxNodeGPU int
var maxNodeCPUUsage, maxNodeMemUsage, maxNodeGPUUsage float64
for i, node := range nodes {
// CPU
cpuCap := node.Status.Capacity.Cpu()
totalCPU += cpuCap.MilliValue()
if cpuCap.MilliValue() > maxNodeCPU {
maxNodeCPU = cpuCap.MilliValue()
}
// Memory
memCap := node.Status.Capacity.Memory()
totalMem += memCap.Value()
if memCap.Value() > maxNodeMem {
maxNodeMem = memCap.Value()
}
// GPU
if gpu, ok := node.Status.Allocatable["nvidia.com/gpu"]; ok {
gpuCount := int(gpu.Value())
totalGPU += gpuCount
if gpuCount > maxNodeGPU {
maxNodeGPU = gpuCount
}
}
// Node status
if getNodeStatus(&node) == "Ready" {
healthyNodes++
}
// 从 nodeMetrics 获取使用情况
if i < len(nodeMetrics) && nodeMetrics[i] != nil {
metrics.Nodes = append(metrics.Nodes, *nodeMetrics[i])
// 更新单机最大使用率
if nodeMetrics[i].CPUPercent > maxNodeCPUUsage {
maxNodeCPUUsage = nodeMetrics[i].CPUPercent
}
if nodeMetrics[i].MemoryPercent > maxNodeMemUsage {
maxNodeMemUsage = nodeMetrics[i].MemoryPercent
}
if nodeMetrics[i].GPUPercent > maxNodeGPUUsage {
maxNodeGPUUsage = nodeMetrics[i].GPUPercent
}
}
}
// 计算集群 uptime简化使用最老节点的年龄
if len(nodes) > 0 {
metrics.Uptime = getNodeAge(&nodes[0])
}
// 格式化总资源
metrics.TotalCPU = fmt.Sprintf("%.2f cores", float64(totalCPU)/1000.0)
metrics.TotalMemory = formatBytes(totalMem)
metrics.TotalGPU = totalGPU
// 格式化单机最大值
metrics.MaxNodeCPU = fmt.Sprintf("%.2f cores", float64(maxNodeCPU)/1000.0)
metrics.MaxNodeMemory = formatBytes(maxNodeMem)
metrics.MaxNodeGPU = maxNodeGPU
metrics.MaxNodeCPUUsage = maxNodeCPUUsage
metrics.MaxNodeMemUsage = maxNodeMemUsage
metrics.MaxNodeGPUUsage = maxNodeGPUUsage
// 使用情况(简化处理)
if len(nodeMetrics) > 0 {
for _, nm := range nodeMetrics {
// 解析使用的 CPU 和内存
// 这里简化处理,实际应该解析字符串
usedCPU += int64(nm.CPUPercent * float64(totalCPU) / 100.0)
usedMem += int64(nm.MemoryPercent * float64(totalMem) / 100.0)
usedGPU += nm.GPUUsage
}
if totalCPU > 0 {
metrics.CPUUsage = float64(usedCPU) / float64(totalCPU) * 100
}
if totalMem > 0 {
metrics.MemoryUsage = float64(usedMem) / float64(totalMem) * 100
}
if totalGPU > 0 {
metrics.GPUUsage = float64(usedGPU) / float64(totalGPU) * 100
}
metrics.UsedCPU = fmt.Sprintf("%.2f cores", float64(usedCPU)/1000.0)
metrics.UsedMemory = formatBytes(usedMem)
metrics.UsedGPU = usedGPU
}
// 确定集群状态
if healthyNodes == len(nodes) {
metrics.Status = "healthy"
} else if healthyNodes > 0 {
metrics.Status = "warning"
} else {
metrics.Status = "error"
}
return metrics
}
// Helper functions
func getNodeStatus(node *corev1.Node) string {
for _, condition := range node.Status.Conditions {
if condition.Type == corev1.NodeReady {
if condition.Status == corev1.ConditionTrue {
return "Ready"
}
return "NotReady"
}
}
return "Unknown"
}
func getNodeRole(node *corev1.Node) string {
if _, ok := node.Labels["node-role.kubernetes.io/control-plane"]; ok {
return "control-plane"
}
if _, ok := node.Labels["node-role.kubernetes.io/master"]; ok {
return "control-plane"
}
return "worker"
}
func getNodeAge(node *corev1.Node) string {
age := time.Since(node.CreationTimestamp.Time)
days := int(age.Hours() / 24)
hours := int(age.Hours()) % 24
if days > 0 {
return fmt.Sprintf("%dd %dh", days, hours)
}
return fmt.Sprintf("%dh", hours)
}
func formatBytes(bytes int64) string {
const unit = 1024
if bytes < unit {
return fmt.Sprintf("%d B", bytes)
}
div, exp := int64(unit), 0
for n := bytes / unit; n >= unit; n /= unit {
div *= unit
exp++
}
return fmt.Sprintf("%.1f %ciB", float64(bytes)/float64(div), "KMGTPE"[exp])
}

View File

@ -0,0 +1,284 @@
package mock
import (
"context"
"fmt"
"strings"
"time"
"github.com/ocdp/cluster-service/internal/domain/entity"
"github.com/ocdp/cluster-service/internal/domain/repository"
)
// OCIClientMock OCI Registry 客户端 Mock 实现
type OCIClientMock struct {
// Mock 数据存储
repositories map[string][]string // registryID -> []repositoryName
artifacts map[string]map[string][]*entity.Artifact // registryID -> repository -> []artifact
}
// NewOCIClientMock 创建 Mock 实现
func NewOCIClientMock() repository.OCIClient {
mock := &OCIClientMock{
repositories: make(map[string][]string),
artifacts: make(map[string]map[string][]*entity.Artifact),
}
// 初始化一些测试数据
mock.initMockData()
return mock
}
func (c *OCIClientMock) initMockData() {
// Note: This method intentionally left empty
// Mock data will be generated dynamically per registry to support any registry ID
}
// initArtifactsForRegistry initializes mock artifacts for a given registry ID
func (c *OCIClientMock) initArtifactsForRegistry(registryID string) {
c.artifacts[registryID] = make(map[string][]*entity.Artifact)
// vllm-serve artifacts (OCI 格式的 Helm Chart)
c.artifacts[registryID]["charts/vllm-serve"] = []*entity.Artifact{
{
RegistryID: registryID,
Repository: "charts/vllm-serve",
Tag: "0.1.0",
Digest: "sha256:abc123def456",
Type: entity.ArtifactTypeChart,
Size: 12345678,
MediaType: "application/vnd.oci.image.manifest.v1+json",
ConfigType: "application/vnd.cncf.helm.config.v1+json", // Helm Chart 的 config type
Annotations: map[string]string{
"org.opencontainers.image.title": "vllm-serve",
"org.opencontainers.image.version": "0.1.0",
},
CreatedAt: time.Now().Add(-24 * time.Hour),
},
{
RegistryID: registryID,
Repository: "charts/vllm-serve",
Tag: "0.2.0",
Digest: "sha256:xyz789uvw012",
Type: entity.ArtifactTypeChart,
Size: 13456789,
MediaType: "application/vnd.oci.image.manifest.v1+json",
ConfigType: "application/vnd.cncf.helm.config.v1+json", // Helm Chart 的 config type
Annotations: map[string]string{
"org.opencontainers.image.title": "vllm-serve",
"org.opencontainers.image.version": "0.2.0",
},
CreatedAt: time.Now(),
},
}
// nginx artifacts (OCI 格式的 Helm Chart)
c.artifacts[registryID]["charts/nginx"] = []*entity.Artifact{
{
RegistryID: registryID,
Repository: "charts/nginx",
Tag: "1.0.0",
Digest: "sha256:nginx123456",
Type: entity.ArtifactTypeChart,
Size: 5678901,
MediaType: "application/vnd.oci.image.manifest.v1+json",
ConfigType: "application/vnd.cncf.helm.config.v1+json", // Helm Chart 的 config type
Annotations: map[string]string{
"org.opencontainers.image.title": "nginx",
},
CreatedAt: time.Now().Add(-48 * time.Hour),
},
}
// redis artifacts (OCI 格式的 Helm Chart)
c.artifacts[registryID]["charts/redis"] = []*entity.Artifact{
{
RegistryID: registryID,
Repository: "charts/redis",
Tag: "6.2.0",
Digest: "sha256:redis789abc",
Type: entity.ArtifactTypeChart,
Size: 8901234,
MediaType: "application/vnd.oci.image.manifest.v1+json",
ConfigType: "application/vnd.cncf.helm.config.v1+json", // Helm Chart 的 config type
Annotations: map[string]string{
"org.opencontainers.image.title": "redis",
"org.opencontainers.image.version": "6.2.0",
},
CreatedAt: time.Now().Add(-72 * time.Hour),
},
}
// alpine artifacts (Docker Image)
c.artifacts[registryID]["library/alpine"] = []*entity.Artifact{
{
RegistryID: registryID,
Repository: "library/alpine",
Tag: "3.18",
Digest: "sha256:alpine123",
Type: entity.ArtifactTypeImage,
Size: 2345678,
MediaType: "application/vnd.docker.distribution.manifest.v2+json",
ConfigType: "application/vnd.docker.container.image.v1+json", // Docker Image 的 config type
Annotations: map[string]string{
"org.opencontainers.image.title": "alpine",
"org.opencontainers.image.version": "3.18",
},
CreatedAt: time.Now().Add(-96 * time.Hour),
},
{
RegistryID: registryID,
Repository: "library/alpine",
Tag: "latest",
Digest: "sha256:alpine456",
Type: entity.ArtifactTypeImage,
Size: 2456789,
MediaType: "application/vnd.docker.distribution.manifest.v2+json",
ConfigType: "application/vnd.docker.container.image.v1+json", // Docker Image 的 config type
Annotations: map[string]string{
"org.opencontainers.image.title": "alpine",
},
CreatedAt: time.Now().Add(-24 * time.Hour),
},
}
}
func (c *OCIClientMock) ListRepositories(ctx context.Context, registry *entity.Registry) ([]string, error) {
// Check if we have cached data for this registry
repos, exists := c.repositories[registry.ID]
if !exists {
// Generate mock data dynamically for any registry
repos = []string{
"charts/vllm-serve",
"charts/nginx",
"charts/redis",
"library/alpine",
}
c.repositories[registry.ID] = repos
// Also initialize artifacts for this registry
c.initArtifactsForRegistry(registry.ID)
}
return repos, nil
}
func (c *OCIClientMock) ListArtifacts(ctx context.Context, registry *entity.Registry, repository, mediaTypeFilter string) ([]*entity.Artifact, error) {
regArtifacts, exists := c.artifacts[registry.ID]
if !exists {
// Initialize artifacts for this registry if not exists
c.initArtifactsForRegistry(registry.ID)
regArtifacts = c.artifacts[registry.ID]
}
artifacts, exists := regArtifacts[repository]
if !exists {
return []*entity.Artifact{}, nil
}
// 应用 mediaType 过滤
if mediaTypeFilter == "" || mediaTypeFilter == "all" {
return artifacts, nil
}
filtered := make([]*entity.Artifact, 0)
filter := strings.ToLower(strings.TrimSpace(mediaTypeFilter))
for _, artifact := range artifacts {
switch filter {
case "chart":
if artifact.Type == entity.ArtifactTypeChart {
filtered = append(filtered, artifact)
}
case "image":
if artifact.Type == entity.ArtifactTypeImage {
filtered = append(filtered, artifact)
}
case "other":
if artifact.Type == entity.ArtifactTypeOther {
filtered = append(filtered, artifact)
}
}
}
return filtered, nil
}
func (c *OCIClientMock) GetArtifact(ctx context.Context, registry *entity.Registry, repository, reference string) (*entity.Artifact, error) {
regArtifacts, exists := c.artifacts[registry.ID]
if !exists {
// Initialize artifacts for this registry if not exists
c.initArtifactsForRegistry(registry.ID)
regArtifacts = c.artifacts[registry.ID]
}
artifacts, exists := regArtifacts[repository]
if !exists {
return nil, entity.ErrArtifactNotFound
}
// 根据 tag 或 digest 查找
for _, artifact := range artifacts {
if artifact.Tag == reference || artifact.Digest == reference {
return artifact, nil
}
}
return nil, entity.ErrArtifactNotFound
}
func (c *OCIClientMock) GetValuesSchema(ctx context.Context, registry *entity.Registry, repository, reference string) (string, error) {
artifact, err := c.GetArtifact(ctx, registry, repository, reference)
if err != nil {
return "", err
}
if !artifact.IsChart() {
return "", fmt.Errorf("not a helm chart")
}
// 返回 Mock values schema
mockSchema := `{
"$schema": "http://json-schema.org/draft-07/schema#",
"type": "object",
"properties": {
"replicaCount": {
"type": "integer",
"default": 1
},
"image": {
"type": "object",
"properties": {
"repository": {
"type": "string"
},
"tag": {
"type": "string"
}
}
}
}
}`
return mockSchema, nil
}
func (c *OCIClientMock) PullArtifact(ctx context.Context, registry *entity.Registry, repository, reference, destPath string) error {
_, err := c.GetArtifact(ctx, registry, repository, reference)
if err != nil {
return err
}
// Mock 实现,不实际下载
return nil
}
func (c *OCIClientMock) PushArtifact(ctx context.Context, registry *entity.Registry, repository, tag, sourcePath string) error {
// Mock 实现,不实际上传
return nil
}
func (c *OCIClientMock) CheckHealth(ctx context.Context, registry *entity.Registry) error {
// Mock 实现,总是返回健康
return nil
}

View File

@ -0,0 +1,468 @@
package real
import (
"archive/tar"
"compress/gzip"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"strings"
"github.com/ocdp/cluster-service/internal/domain/entity"
"github.com/ocdp/cluster-service/internal/domain/repository"
"github.com/opencontainers/go-digest"
ocispec "github.com/opencontainers/image-spec/specs-go/v1"
"oras.land/oras-go/v2/registry/remote"
"oras.land/oras-go/v2/registry/remote/auth"
)
// OCIClient 真实的 OCI 客户端实现(使用 ORAS
type OCIClient struct {
httpClient *http.Client
}
// NewOCIClient 创建真实的 OCI 客户端
func NewOCIClient() repository.OCIClient {
return &OCIClient{
httpClient: &http.Client{},
}
}
// getRegistry 创建 ORAS Registry 客户端
func (c *OCIClient) getRegistry(reg *entity.Registry) (*remote.Registry, error) {
// 解析 Registry URL
registryURL := strings.TrimPrefix(reg.URL, "https://")
registryURL = strings.TrimPrefix(registryURL, "http://")
registry, err := remote.NewRegistry(registryURL)
if err != nil {
return nil, fmt.Errorf("failed to create registry client: %w", err)
}
// 设置认证
if reg.Username != "" && reg.Password != "" {
registry.Client = &auth.Client{
Client: c.httpClient,
Credential: auth.StaticCredential(registryURL, auth.Credential{
Username: reg.Username,
Password: reg.Password,
}),
}
}
// 设置 PlainHTTP如果是 insecure
registry.PlainHTTP = reg.Insecure
return registry, nil
}
// ListRepositories 列出 Registry 中的所有 repositories
func (c *OCIClient) ListRepositories(ctx context.Context, registry *entity.Registry) ([]string, error) {
reg, err := c.getRegistry(registry)
if err != nil {
return nil, err
}
repositories := make([]string, 0)
err = reg.Repositories(ctx, "", func(repos []string) error {
repositories = append(repositories, repos...)
return nil
})
if err != nil {
return nil, fmt.Errorf("failed to list repositories: %w", err)
}
return repositories, nil
}
// ListArtifacts 列出指定 repository 的所有 artifacts
// mediaTypeFilter: "all", "image", "chart", "other" - 使用模糊匹配过滤
func (c *OCIClient) ListArtifacts(ctx context.Context, registry *entity.Registry, repository, mediaTypeFilter string) ([]*entity.Artifact, error) {
reg, err := c.getRegistry(registry)
if err != nil {
return nil, err
}
repo, err := reg.Repository(ctx, repository)
if err != nil {
return nil, fmt.Errorf("failed to get repository: %w", err)
}
artifacts := make([]*entity.Artifact, 0)
err = repo.Tags(ctx, "", func(tags []string) error {
for _, tag := range tags {
// 获取 manifest 以获取更多信息
desc, err := repo.Resolve(ctx, tag)
if err != nil {
// 跳过无法解析的 tag
continue
}
artifact := &entity.Artifact{
Repository: repository,
Tag: tag,
Digest: desc.Digest.String(),
MediaType: desc.MediaType,
Size: desc.Size,
}
// 尝试获取 config.mediaType 以更准确判断类型
if manifestBytes, err := repo.Fetch(ctx, desc); err == nil {
defer manifestBytes.Close()
if manifestData, err := io.ReadAll(manifestBytes); err == nil {
var manifest map[string]interface{}
if err := json.Unmarshal(manifestData, &manifest); err == nil {
// 获取 config.mediaType
if config, ok := manifest["config"].(map[string]interface{}); ok {
if configMediaType, ok := config["mediaType"].(string); ok {
artifact.ConfigType = configMediaType
}
}
}
}
}
// 使用智能类型判断(综合多种信息)
artifact.DetermineType()
// 应用 mediaType 过滤
if c.shouldIncludeArtifact(artifact, mediaTypeFilter) {
artifacts = append(artifacts, artifact)
}
}
return nil
})
if err != nil {
return nil, fmt.Errorf("failed to list artifacts: %w", err)
}
return artifacts, nil
}
// shouldIncludeArtifact 判断是否应该包含该 artifact
func (c *OCIClient) shouldIncludeArtifact(artifact *entity.Artifact, filter string) bool {
// 默认或 "all" 返回所有
if filter == "" || filter == "all" {
return true
}
filter = strings.ToLower(strings.TrimSpace(filter))
switch filter {
case "chart":
// 只返回 Helm Charts
return artifact.Type == entity.ArtifactTypeChart
case "image":
// 返回 Docker 或 OCI images
return artifact.Type == entity.ArtifactTypeImage
case "other":
// 返回其他类型
return artifact.Type == entity.ArtifactTypeOther
default:
// 未知的 filter返回所有
return true
}
}
// GetArtifact 获取指定 artifact 的详细信息
func (c *OCIClient) GetArtifact(ctx context.Context, registry *entity.Registry, repository, reference string) (*entity.Artifact, error) {
reg, err := c.getRegistry(registry)
if err != nil {
return nil, err
}
repo, err := reg.Repository(ctx, repository)
if err != nil {
return nil, fmt.Errorf("failed to get repository: %w", err)
}
// 解析 reference
desc, err := repo.Resolve(ctx, reference)
if err != nil {
return nil, fmt.Errorf("failed to resolve artifact: %w", err)
}
// 获取 manifest
manifestBytes, err := repo.Fetch(ctx, desc)
if err != nil {
return nil, fmt.Errorf("failed to fetch manifest: %w", err)
}
defer manifestBytes.Close()
manifestData, err := io.ReadAll(manifestBytes)
if err != nil {
return nil, fmt.Errorf("failed to read manifest: %w", err)
}
// 解析 manifest 获取配置信息
var manifest map[string]interface{}
if err := json.Unmarshal(manifestData, &manifest); err != nil {
return nil, fmt.Errorf("failed to unmarshal manifest: %w", err)
}
artifact := &entity.Artifact{
Repository: repository,
Tag: reference,
Digest: desc.Digest.String(),
MediaType: desc.MediaType,
Size: desc.Size,
Annotations: make(map[string]string),
}
// 获取 config.mediaType 和 annotations
if config, ok := manifest["config"].(map[string]interface{}); ok {
// 获取 config.mediaType用于准确的类型判断
if configMediaType, ok := config["mediaType"].(string); ok {
artifact.ConfigType = configMediaType
}
// 获取 annotations
if annotations, ok := config["annotations"].(map[string]interface{}); ok {
for k, v := range annotations {
if str, ok := v.(string); ok {
artifact.Annotations[k] = str
}
}
}
}
// 使用智能类型判断(综合 ConfigType, Annotations, Repository 名称等)
artifact.DetermineType()
return artifact, nil
}
// GetValuesSchema 获取 Helm Chart 的 values schema
func (c *OCIClient) GetValuesSchema(ctx context.Context, registry *entity.Registry, repository, reference string) (string, error) {
reg, err := c.getRegistry(registry)
if err != nil {
return "", err
}
repo, err := reg.Repository(ctx, repository)
if err != nil {
return "", fmt.Errorf("failed to get repository: %w", err)
}
// 解析 reference (tag 或 digest)
desc, err := repo.Resolve(ctx, reference)
if err != nil {
return "", fmt.Errorf("failed to resolve artifact: %w", err)
}
manifestReader, err := repo.Fetch(ctx, desc)
if err != nil {
return "", fmt.Errorf("failed to fetch manifest: %w", err)
}
defer manifestReader.Close()
manifestBytes, err := io.ReadAll(manifestReader)
if err != nil {
return "", fmt.Errorf("failed to read manifest: %w", err)
}
var manifest ocispec.Manifest
if err := json.Unmarshal(manifestBytes, &manifest); err != nil {
return "", fmt.Errorf("failed to unmarshal manifest: %w", err)
}
// 优先查找是否存在独立的 values schema layer一些 registry 会将 values.schema.json 作为单独的 layer 存储)
var valuesSchemaLayer *ocispec.Descriptor
for i := range manifest.Layers {
layer := manifest.Layers[i]
mediaType := strings.ToLower(layer.MediaType)
if strings.Contains(mediaType, "helm.values.schema") ||
strings.Contains(mediaType, "values.schema") {
valuesSchemaLayer = &manifest.Layers[i]
break
}
}
// 如果存在独立的 values schema layer直接返回
if valuesSchemaLayer != nil {
reader, err := repo.Fetch(ctx, *valuesSchemaLayer)
if err != nil {
return "", fmt.Errorf("failed to fetch values schema layer: %w", err)
}
defer reader.Close()
data, err := io.ReadAll(reader)
if err != nil {
return "", fmt.Errorf("failed to read values schema layer: %w", err)
}
if len(data) == 0 {
return "", entity.ErrValuesSchemaNotFound
}
return string(data), nil
}
// 回退:查找 Helm Chart layertar+gzip 包含 chart 内容)并从中读取 values.schema.json
var chartLayer *ocispec.Descriptor
for i := range manifest.Layers {
layer := manifest.Layers[i]
if strings.Contains(layer.MediaType, "cncf.helm.chart") ||
strings.Contains(layer.MediaType, "helm.chart.content") {
chartLayer = &manifest.Layers[i]
break
}
}
if chartLayer == nil {
return "", entity.ErrValuesSchemaNotFound
}
if chartLayer.Digest == "" {
return "", fmt.Errorf("chart layer digest is empty")
}
if _, err := digest.Parse(string(chartLayer.Digest)); err != nil {
return "", fmt.Errorf("invalid chart layer digest: %w", err)
}
layerReader, err := repo.Fetch(ctx, *chartLayer)
if err != nil {
return "", fmt.Errorf("failed to fetch chart layer: %w", err)
}
defer layerReader.Close()
gzipReader, err := gzip.NewReader(layerReader)
if err != nil {
return "", fmt.Errorf("failed to create gzip reader: %w", err)
}
defer gzipReader.Close()
tarReader := tar.NewReader(gzipReader)
for {
header, err := tarReader.Next()
if err == io.EOF {
break
}
if err != nil {
return "", fmt.Errorf("failed to read chart archive: %w", err)
}
if header.Typeflag != tar.TypeReg {
continue
}
if strings.HasSuffix(header.Name, "values.schema.json") {
data, err := io.ReadAll(tarReader)
if err != nil {
return "", fmt.Errorf("failed to read values.schema.json: %w", err)
}
if len(data) == 0 {
return "", entity.ErrValuesSchemaNotFound
}
return string(data), nil
}
}
return "", entity.ErrValuesSchemaNotFound
}
// PullArtifact 下载 artifact 到本地
func (c *OCIClient) PullArtifact(ctx context.Context, registry *entity.Registry, repository, reference, destPath string) error {
reg, err := c.getRegistry(registry)
if err != nil {
return err
}
repo, err := reg.Repository(ctx, repository)
if err != nil {
return fmt.Errorf("failed to get repository: %w", err)
}
// 解析 reference
desc, err := repo.Resolve(ctx, reference)
if err != nil {
return fmt.Errorf("failed to resolve artifact: %w", err)
}
// 获取 manifest 内容
manifestReader, err := repo.Fetch(ctx, desc)
if err != nil {
return fmt.Errorf("failed to fetch manifest: %w", err)
}
defer manifestReader.Close()
manifestBytes, err := io.ReadAll(manifestReader)
if err != nil {
return fmt.Errorf("failed to read manifest: %w", err)
}
var manifest ocispec.Manifest
if err := json.Unmarshal(manifestBytes, &manifest); err != nil {
return fmt.Errorf("failed to unmarshal manifest: %w", err)
}
var chartLayer *ocispec.Descriptor
for i := range manifest.Layers {
layer := manifest.Layers[i]
if strings.Contains(layer.MediaType, "cncf.helm.chart") ||
strings.Contains(layer.MediaType, "helm.chart.content") {
chartLayer = &layer
break
}
}
if chartLayer == nil {
return fmt.Errorf("helm chart layer not found in manifest")
}
content, err := repo.Fetch(ctx, *chartLayer)
if err != nil {
return fmt.Errorf("failed to fetch chart layer: %w", err)
}
defer content.Close()
// 确保目标目录存在
if err := os.MkdirAll(filepath.Dir(destPath), 0755); err != nil {
return fmt.Errorf("failed to create destination directory: %w", err)
}
// 写入文件
file, err := os.Create(destPath)
if err != nil {
return fmt.Errorf("failed to create file: %w", err)
}
defer file.Close()
if _, err := io.Copy(file, content); err != nil {
return fmt.Errorf("failed to write artifact: %w", err)
}
return nil
}
// PushArtifact 推送 artifact 到 Registry
func (c *OCIClient) PushArtifact(ctx context.Context, registry *entity.Registry, repository, tag, sourcePath string) error {
// 这是一个简化实现
// 实际应该实现完整的 OCI artifact push 流程
return fmt.Errorf("push artifact not fully implemented yet")
}
// CheckHealth 检查 Registry 健康状态
func (c *OCIClient) CheckHealth(ctx context.Context, registry *entity.Registry) error {
reg, err := c.getRegistry(registry)
if err != nil {
return err
}
// 尝试 ping registry
err = reg.Ping(ctx)
if err != nil {
return fmt.Errorf("registry health check failed: %w", err)
}
return nil
}

View File

@ -0,0 +1,174 @@
package mock
import (
"context"
"sync"
"github.com/ocdp/cluster-service/internal/domain/entity"
"github.com/ocdp/cluster-service/internal/domain/repository"
"github.com/ocdp/cluster-service/internal/pkg/crypto"
)
// ClusterRepositoryMock 集群仓储 Mock 实现(内存存储,支持加密)
type ClusterRepositoryMock struct {
mu sync.RWMutex
clusters map[string]*entity.Cluster // key: cluster ID
encryptor crypto.Encryptor // 加密器
}
// NewClusterRepositoryMock 创建 Mock 实现
func NewClusterRepositoryMock(encryptor crypto.Encryptor) repository.ClusterRepository {
return &ClusterRepositoryMock{
clusters: make(map[string]*entity.Cluster),
encryptor: encryptor,
}
}
func (r *ClusterRepositoryMock) Create(ctx context.Context, cluster *entity.Cluster) error {
r.mu.Lock()
defer r.mu.Unlock()
// 检查名称是否已存在
for _, c := range r.clusters {
if c.Name == cluster.Name {
return entity.ErrClusterExists
}
}
// Mock 模式:如果没有提供认证信息,自动填充默认的 Mock 证书
if (cluster.CertData == "" || cluster.KeyData == "") && cluster.Token == "" {
cluster.CAData = "LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0tCk1vY2sgQ0EgQ2VydGlmaWNhdGUKLS0tLS1FTkQgQ0VSVElGSUNBVEUtLS0tLQ=="
cluster.CertData = "LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0tCk1vY2sgQ2xpZW50IENlcnRpZmljYXRlCi0tLS0tRU5EIENFUlRJRklDQVRFLS0tLS0="
cluster.KeyData = "LS0tLS1CRUdJTiBSU0EgUFJJVkFURSBLRVktLS0tLQpNb2NrIFByaXZhdGUgS2V5Ci0tLS0tRU5EIFJTQSBQUklWQVRFIEtFWS0tLS0t"
}
// 加密敏感数据后存储
encryptedCluster := r.encryptCluster(cluster)
r.clusters[cluster.ID] = encryptedCluster
return nil
}
func (r *ClusterRepositoryMock) GetByID(ctx context.Context, id string) (*entity.Cluster, error) {
r.mu.RLock()
defer r.mu.RUnlock()
cluster, exists := r.clusters[id]
if !exists {
return nil, entity.ErrClusterNotFound
}
// 解密敏感数据后返回
return r.decryptCluster(cluster), nil
}
func (r *ClusterRepositoryMock) GetByName(ctx context.Context, name string) (*entity.Cluster, error) {
r.mu.RLock()
defer r.mu.RUnlock()
for _, cluster := range r.clusters {
if cluster.Name == name {
// 解密敏感数据后返回
return r.decryptCluster(cluster), nil
}
}
return nil, entity.ErrClusterNotFound
}
func (r *ClusterRepositoryMock) Update(ctx context.Context, cluster *entity.Cluster) error {
r.mu.Lock()
defer r.mu.Unlock()
if _, exists := r.clusters[cluster.ID]; !exists {
return entity.ErrClusterNotFound
}
// 加密敏感数据后存储
encryptedCluster := r.encryptCluster(cluster)
r.clusters[cluster.ID] = encryptedCluster
return nil
}
func (r *ClusterRepositoryMock) Delete(ctx context.Context, id string) error {
r.mu.Lock()
defer r.mu.Unlock()
if _, exists := r.clusters[id]; !exists {
return entity.ErrClusterNotFound
}
delete(r.clusters, id)
return nil
}
func (r *ClusterRepositoryMock) List(ctx context.Context) ([]*entity.Cluster, error) {
r.mu.RLock()
defer r.mu.RUnlock()
clusters := make([]*entity.Cluster, 0, len(r.clusters))
for _, cluster := range r.clusters {
// 解密敏感数据后返回
clusters = append(clusters, r.decryptCluster(cluster))
}
return clusters, nil
}
// encryptCluster 加密 Cluster 的敏感数据
func (r *ClusterRepositoryMock) encryptCluster(cluster *entity.Cluster) *entity.Cluster {
encrypted := *cluster // 复制
// 加密证书数据
if cluster.CAData != "" && !crypto.IsEncrypted(cluster.CAData) {
if encryptedData, err := r.encryptor.Encrypt(cluster.CAData); err == nil {
encrypted.CAData = encryptedData
}
}
if cluster.CertData != "" && !crypto.IsEncrypted(cluster.CertData) {
if encryptedData, err := r.encryptor.Encrypt(cluster.CertData); err == nil {
encrypted.CertData = encryptedData
}
}
if cluster.KeyData != "" && !crypto.IsEncrypted(cluster.KeyData) {
if encryptedData, err := r.encryptor.Encrypt(cluster.KeyData); err == nil {
encrypted.KeyData = encryptedData
}
}
if cluster.Token != "" && !crypto.IsEncrypted(cluster.Token) {
if encryptedData, err := r.encryptor.Encrypt(cluster.Token); err == nil {
encrypted.Token = encryptedData
}
}
return &encrypted
}
// decryptCluster 解密 Cluster 的敏感数据
func (r *ClusterRepositoryMock) decryptCluster(cluster *entity.Cluster) *entity.Cluster {
decrypted := *cluster // 复制
// 解密证书数据
if cluster.CAData != "" && crypto.IsEncrypted(cluster.CAData) {
if decryptedData, err := r.encryptor.Decrypt(cluster.CAData); err == nil {
decrypted.CAData = decryptedData
}
}
if cluster.CertData != "" && crypto.IsEncrypted(cluster.CertData) {
if decryptedData, err := r.encryptor.Decrypt(cluster.CertData); err == nil {
decrypted.CertData = decryptedData
}
}
if cluster.KeyData != "" && crypto.IsEncrypted(cluster.KeyData) {
if decryptedData, err := r.encryptor.Decrypt(cluster.KeyData); err == nil {
decrypted.KeyData = decryptedData
}
}
if cluster.Token != "" && crypto.IsEncrypted(cluster.Token) {
if decryptedData, err := r.encryptor.Decrypt(cluster.Token); err == nil {
decrypted.Token = decryptedData
}
}
return &decrypted
}

View File

@ -0,0 +1,113 @@
package mock
import (
"context"
"sync"
"github.com/ocdp/cluster-service/internal/domain/entity"
"github.com/ocdp/cluster-service/internal/domain/repository"
)
// InstanceRepositoryMock 实例仓储 Mock 实现(内存存储)
type InstanceRepositoryMock struct {
mu sync.RWMutex
instances map[string]*entity.Instance // key: instance ID
}
// NewInstanceRepositoryMock 创建 Mock 实现
func NewInstanceRepositoryMock() repository.InstanceRepository {
return &InstanceRepositoryMock{
instances: make(map[string]*entity.Instance),
}
}
func (r *InstanceRepositoryMock) Create(ctx context.Context, instance *entity.Instance) error {
r.mu.Lock()
defer r.mu.Unlock()
// 检查同一集群中名称是否已存在
for _, inst := range r.instances {
if inst.ClusterID == instance.ClusterID && inst.Name == instance.Name {
return entity.ErrInstanceExists
}
}
r.instances[instance.ID] = instance
return nil
}
func (r *InstanceRepositoryMock) GetByID(ctx context.Context, id string) (*entity.Instance, error) {
r.mu.RLock()
defer r.mu.RUnlock()
instance, exists := r.instances[id]
if !exists {
return nil, entity.ErrInstanceNotFound
}
return instance, nil
}
func (r *InstanceRepositoryMock) GetByClusterAndName(ctx context.Context, clusterID, name string) (*entity.Instance, error) {
r.mu.RLock()
defer r.mu.RUnlock()
for _, instance := range r.instances {
if instance.ClusterID == clusterID && instance.Name == name {
return instance, nil
}
}
return nil, entity.ErrInstanceNotFound
}
func (r *InstanceRepositoryMock) Update(ctx context.Context, instance *entity.Instance) error {
r.mu.Lock()
defer r.mu.Unlock()
if _, exists := r.instances[instance.ID]; !exists {
return entity.ErrInstanceNotFound
}
r.instances[instance.ID] = instance
return nil
}
func (r *InstanceRepositoryMock) Delete(ctx context.Context, id string) error {
r.mu.Lock()
defer r.mu.Unlock()
if _, exists := r.instances[id]; !exists {
return entity.ErrInstanceNotFound
}
delete(r.instances, id)
return nil
}
func (r *InstanceRepositoryMock) ListByCluster(ctx context.Context, clusterID string) ([]*entity.Instance, error) {
r.mu.RLock()
defer r.mu.RUnlock()
instances := make([]*entity.Instance, 0)
for _, instance := range r.instances {
if instance.ClusterID == clusterID {
instances = append(instances, instance)
}
}
return instances, nil
}
func (r *InstanceRepositoryMock) List(ctx context.Context) ([]*entity.Instance, error) {
r.mu.RLock()
defer r.mu.RUnlock()
instances := make([]*entity.Instance, 0, len(r.instances))
for _, instance := range r.instances {
instances = append(instances, instance)
}
return instances, nil
}

View File

@ -0,0 +1,137 @@
package mock
import (
"context"
"sync"
"github.com/ocdp/cluster-service/internal/domain/entity"
"github.com/ocdp/cluster-service/internal/domain/repository"
"github.com/ocdp/cluster-service/internal/pkg/crypto"
)
// RegistryRepositoryMock Registry 仓储 Mock 实现(内存存储,支持加密)
type RegistryRepositoryMock struct {
mu sync.RWMutex
registries map[string]*entity.Registry // key: registry ID
encryptor crypto.Encryptor // 加密器
}
// NewRegistryRepositoryMock 创建 Mock 实现
func NewRegistryRepositoryMock(encryptor crypto.Encryptor) repository.RegistryRepository {
return &RegistryRepositoryMock{
registries: make(map[string]*entity.Registry),
encryptor: encryptor,
}
}
func (r *RegistryRepositoryMock) Create(ctx context.Context, registry *entity.Registry) error {
r.mu.Lock()
defer r.mu.Unlock()
// 检查名称是否已存在
for _, reg := range r.registries {
if reg.Name == registry.Name {
return entity.ErrRegistryExists
}
}
// 加密敏感数据后存储
encryptedRegistry := r.encryptRegistry(registry)
r.registries[registry.ID] = encryptedRegistry
return nil
}
func (r *RegistryRepositoryMock) GetByID(ctx context.Context, id string) (*entity.Registry, error) {
r.mu.RLock()
defer r.mu.RUnlock()
registry, exists := r.registries[id]
if !exists {
return nil, entity.ErrRegistryNotFound
}
// 解密敏感数据后返回
return r.decryptRegistry(registry), nil
}
func (r *RegistryRepositoryMock) GetByName(ctx context.Context, name string) (*entity.Registry, error) {
r.mu.RLock()
defer r.mu.RUnlock()
for _, registry := range r.registries {
if registry.Name == name {
// 解密敏感数据后返回
return r.decryptRegistry(registry), nil
}
}
return nil, entity.ErrRegistryNotFound
}
func (r *RegistryRepositoryMock) Update(ctx context.Context, registry *entity.Registry) error {
r.mu.Lock()
defer r.mu.Unlock()
if _, exists := r.registries[registry.ID]; !exists {
return entity.ErrRegistryNotFound
}
// 加密敏感数据后存储
encryptedRegistry := r.encryptRegistry(registry)
r.registries[registry.ID] = encryptedRegistry
return nil
}
func (r *RegistryRepositoryMock) Delete(ctx context.Context, id string) error {
r.mu.Lock()
defer r.mu.Unlock()
if _, exists := r.registries[id]; !exists {
return entity.ErrRegistryNotFound
}
delete(r.registries, id)
return nil
}
func (r *RegistryRepositoryMock) List(ctx context.Context) ([]*entity.Registry, error) {
r.mu.RLock()
defer r.mu.RUnlock()
registries := make([]*entity.Registry, 0, len(r.registries))
for _, registry := range r.registries {
// 解密敏感数据后返回
registries = append(registries, r.decryptRegistry(registry))
}
return registries, nil
}
// encryptRegistry 加密 Registry 的敏感数据
func (r *RegistryRepositoryMock) encryptRegistry(registry *entity.Registry) *entity.Registry {
encrypted := *registry // 复制
// 加密密码
if registry.Password != "" && !crypto.IsEncrypted(registry.Password) {
if encryptedPassword, err := r.encryptor.Encrypt(registry.Password); err == nil {
encrypted.Password = encryptedPassword
}
}
return &encrypted
}
// decryptRegistry 解密 Registry 的敏感数据
func (r *RegistryRepositoryMock) decryptRegistry(registry *entity.Registry) *entity.Registry {
decrypted := *registry // 复制
// 解密密码
if registry.Password != "" && crypto.IsEncrypted(registry.Password) {
if decryptedPassword, err := r.encryptor.Decrypt(registry.Password); err == nil {
decrypted.Password = decryptedPassword
}
}
return &decrypted
}

View File

@ -0,0 +1,99 @@
package mock
import (
"context"
"sync"
"github.com/ocdp/cluster-service/internal/domain/entity"
"github.com/ocdp/cluster-service/internal/domain/repository"
)
// UserRepositoryMock 用户仓储 Mock 实现(内存存储)
type UserRepositoryMock struct {
mu sync.RWMutex
users map[string]*entity.User // key: user ID
}
// NewUserRepositoryMock 创建 Mock 实现
func NewUserRepositoryMock() repository.UserRepository {
return &UserRepositoryMock{
users: make(map[string]*entity.User),
}
}
func (r *UserRepositoryMock) Create(ctx context.Context, user *entity.User) error {
r.mu.Lock()
defer r.mu.Unlock()
// 检查是否已存在
for _, u := range r.users {
if u.Username == user.Username {
return entity.ErrUserExists
}
}
r.users[user.ID] = user
return nil
}
func (r *UserRepositoryMock) GetByID(ctx context.Context, id string) (*entity.User, error) {
r.mu.RLock()
defer r.mu.RUnlock()
user, exists := r.users[id]
if !exists {
return nil, entity.ErrUserNotFound
}
return user, nil
}
func (r *UserRepositoryMock) GetByUsername(ctx context.Context, username string) (*entity.User, error) {
r.mu.RLock()
defer r.mu.RUnlock()
for _, user := range r.users {
if user.Username == username {
return user, nil
}
}
return nil, entity.ErrUserNotFound
}
func (r *UserRepositoryMock) Update(ctx context.Context, user *entity.User) error {
r.mu.Lock()
defer r.mu.Unlock()
if _, exists := r.users[user.ID]; !exists {
return entity.ErrUserNotFound
}
r.users[user.ID] = user
return nil
}
func (r *UserRepositoryMock) Delete(ctx context.Context, id string) error {
r.mu.Lock()
defer r.mu.Unlock()
if _, exists := r.users[id]; !exists {
return entity.ErrUserNotFound
}
delete(r.users, id)
return nil
}
func (r *UserRepositoryMock) List(ctx context.Context) ([]*entity.User, error) {
r.mu.RLock()
defer r.mu.RUnlock()
users := make([]*entity.User, 0, len(r.users))
for _, user := range r.users {
users = append(users, user)
}
return users, nil
}

View File

@ -0,0 +1,337 @@
package postgres
import (
"context"
"database/sql"
"fmt"
"time"
"github.com/google/uuid"
"github.com/ocdp/cluster-service/internal/domain/entity"
"github.com/ocdp/cluster-service/internal/domain/repository"
"github.com/ocdp/cluster-service/internal/pkg/crypto"
)
// ClusterRepository PostgreSQL 集群仓储实现
type ClusterRepository struct {
db *DB
encryptor crypto.Encryptor
}
// NewClusterRepository 创建 PostgreSQL 集群仓储
func NewClusterRepository(db *DB, encryptor crypto.Encryptor) repository.ClusterRepository {
return &ClusterRepository{
db: db,
encryptor: encryptor,
}
}
// Create 创建集群
func (r *ClusterRepository) Create(ctx context.Context, cluster *entity.Cluster) error {
if cluster.ID == "" {
cluster.ID = uuid.New().String()
}
// 加密敏感数据
encryptedCAData, err := r.encryptor.Encrypt(cluster.CAData)
if err != nil {
return fmt.Errorf("failed to encrypt CA data: %w", err)
}
encryptedCertData, err := r.encryptor.Encrypt(cluster.CertData)
if err != nil {
return fmt.Errorf("failed to encrypt cert data: %w", err)
}
encryptedKeyData, err := r.encryptor.Encrypt(cluster.KeyData)
if err != nil {
return fmt.Errorf("failed to encrypt key data: %w", err)
}
encryptedToken, err := r.encryptor.Encrypt(cluster.Token)
if err != nil {
return fmt.Errorf("failed to encrypt token: %w", err)
}
query := `
INSERT INTO clusters (id, name, host, ca_data, cert_data, key_data, token, description, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
`
_, err = r.db.conn.ExecContext(ctx, query,
cluster.ID,
cluster.Name,
cluster.Host,
encryptedCAData,
encryptedCertData,
encryptedKeyData,
encryptedToken,
cluster.Description,
cluster.CreatedAt,
cluster.UpdatedAt,
)
if err != nil {
return fmt.Errorf("failed to create cluster: %w", err)
}
return nil
}
// GetByID 根据 ID 获取集群
func (r *ClusterRepository) GetByID(ctx context.Context, id string) (*entity.Cluster, error) {
query := `
SELECT id, name, host, ca_data, cert_data, key_data, token, description, created_at, updated_at
FROM clusters
WHERE id = $1
`
cluster := &entity.Cluster{}
var encryptedCAData, encryptedCertData, encryptedKeyData, encryptedToken string
err := r.db.conn.QueryRowContext(ctx, query, id).Scan(
&cluster.ID,
&cluster.Name,
&cluster.Host,
&encryptedCAData,
&encryptedCertData,
&encryptedKeyData,
&encryptedToken,
&cluster.Description,
&cluster.CreatedAt,
&cluster.UpdatedAt,
)
if err == sql.ErrNoRows {
return nil, entity.ErrClusterNotFound
}
if err != nil {
return nil, fmt.Errorf("failed to get cluster: %w", err)
}
// 解密敏感数据
cluster.CAData, err = r.encryptor.Decrypt(encryptedCAData)
if err != nil {
return nil, fmt.Errorf("failed to decrypt CA data: %w", err)
}
cluster.CertData, err = r.encryptor.Decrypt(encryptedCertData)
if err != nil {
return nil, fmt.Errorf("failed to decrypt cert data: %w", err)
}
cluster.KeyData, err = r.encryptor.Decrypt(encryptedKeyData)
if err != nil {
return nil, fmt.Errorf("failed to decrypt key data: %w", err)
}
cluster.Token, err = r.encryptor.Decrypt(encryptedToken)
if err != nil {
return nil, fmt.Errorf("failed to decrypt token: %w", err)
}
return cluster, nil
}
// GetByName 根据名称获取集群
func (r *ClusterRepository) GetByName(ctx context.Context, name string) (*entity.Cluster, error) {
query := `
SELECT id, name, host, ca_data, cert_data, key_data, token, description, created_at, updated_at
FROM clusters
WHERE name = $1
`
cluster := &entity.Cluster{}
var encryptedCAData, encryptedCertData, encryptedKeyData, encryptedToken string
err := r.db.conn.QueryRowContext(ctx, query, name).Scan(
&cluster.ID,
&cluster.Name,
&cluster.Host,
&encryptedCAData,
&encryptedCertData,
&encryptedKeyData,
&encryptedToken,
&cluster.Description,
&cluster.CreatedAt,
&cluster.UpdatedAt,
)
if err == sql.ErrNoRows {
return nil, entity.ErrClusterNotFound
}
if err != nil {
return nil, fmt.Errorf("failed to get cluster: %w", err)
}
// 解密敏感数据
cluster.CAData, err = r.encryptor.Decrypt(encryptedCAData)
if err != nil {
return nil, fmt.Errorf("failed to decrypt CA data: %w", err)
}
cluster.CertData, err = r.encryptor.Decrypt(encryptedCertData)
if err != nil {
return nil, fmt.Errorf("failed to decrypt cert data: %w", err)
}
cluster.KeyData, err = r.encryptor.Decrypt(encryptedKeyData)
if err != nil {
return nil, fmt.Errorf("failed to decrypt key data: %w", err)
}
cluster.Token, err = r.encryptor.Decrypt(encryptedToken)
if err != nil {
return nil, fmt.Errorf("failed to decrypt token: %w", err)
}
return cluster, nil
}
// Update 更新集群
func (r *ClusterRepository) Update(ctx context.Context, cluster *entity.Cluster) error {
cluster.UpdatedAt = time.Now()
// 加密敏感数据
encryptedCAData, err := r.encryptor.Encrypt(cluster.CAData)
if err != nil {
return fmt.Errorf("failed to encrypt CA data: %w", err)
}
encryptedCertData, err := r.encryptor.Encrypt(cluster.CertData)
if err != nil {
return fmt.Errorf("failed to encrypt cert data: %w", err)
}
encryptedKeyData, err := r.encryptor.Encrypt(cluster.KeyData)
if err != nil {
return fmt.Errorf("failed to encrypt key data: %w", err)
}
encryptedToken, err := r.encryptor.Encrypt(cluster.Token)
if err != nil {
return fmt.Errorf("failed to encrypt token: %w", err)
}
query := `
UPDATE clusters
SET name = $1, host = $2, ca_data = $3, cert_data = $4, key_data = $5,
token = $6, description = $7, updated_at = $8
WHERE id = $9
`
result, err := r.db.conn.ExecContext(ctx, query,
cluster.Name,
cluster.Host,
encryptedCAData,
encryptedCertData,
encryptedKeyData,
encryptedToken,
cluster.Description,
cluster.UpdatedAt,
cluster.ID,
)
if err != nil {
return fmt.Errorf("failed to update cluster: %w", err)
}
rows, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get affected rows: %w", err)
}
if rows == 0 {
return entity.ErrClusterNotFound
}
return nil
}
// Delete 删除集群
func (r *ClusterRepository) Delete(ctx context.Context, id string) error {
query := `DELETE FROM clusters WHERE id = $1`
result, err := r.db.conn.ExecContext(ctx, query, id)
if err != nil {
return fmt.Errorf("failed to delete cluster: %w", err)
}
rows, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get affected rows: %w", err)
}
if rows == 0 {
return entity.ErrClusterNotFound
}
return nil
}
// List 列出所有集群
func (r *ClusterRepository) List(ctx context.Context) ([]*entity.Cluster, error) {
query := `
SELECT id, name, host, ca_data, cert_data, key_data, token, description, created_at, updated_at
FROM clusters
ORDER BY created_at DESC
`
rows, err := r.db.conn.QueryContext(ctx, query)
if err != nil {
return nil, fmt.Errorf("failed to list clusters: %w", err)
}
defer rows.Close()
clusters := make([]*entity.Cluster, 0)
for rows.Next() {
cluster := &entity.Cluster{}
var encryptedCAData, encryptedCertData, encryptedKeyData, encryptedToken string
err := rows.Scan(
&cluster.ID,
&cluster.Name,
&cluster.Host,
&encryptedCAData,
&encryptedCertData,
&encryptedKeyData,
&encryptedToken,
&cluster.Description,
&cluster.CreatedAt,
&cluster.UpdatedAt,
)
if err != nil {
return nil, fmt.Errorf("failed to scan cluster: %w", err)
}
// 解密敏感数据
cluster.CAData, err = r.encryptor.Decrypt(encryptedCAData)
if err != nil {
return nil, fmt.Errorf("failed to decrypt CA data: %w", err)
}
cluster.CertData, err = r.encryptor.Decrypt(encryptedCertData)
if err != nil {
return nil, fmt.Errorf("failed to decrypt cert data: %w", err)
}
cluster.KeyData, err = r.encryptor.Decrypt(encryptedKeyData)
if err != nil {
return nil, fmt.Errorf("failed to decrypt key data: %w", err)
}
cluster.Token, err = r.encryptor.Decrypt(encryptedToken)
if err != nil {
return nil, fmt.Errorf("failed to decrypt token: %w", err)
}
clusters = append(clusters, cluster)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("rows iteration error: %w", err)
}
return clusters, nil
}

View File

@ -0,0 +1,135 @@
package postgres
import (
"database/sql"
"fmt"
"time"
_ "github.com/lib/pq"
)
// DB 数据库连接包装器
type DB struct {
conn *sql.DB
}
// NewDB 创建新的数据库连接
func NewDB(connString string) (*DB, error) {
if connString == "" {
return nil, fmt.Errorf("database connection string cannot be empty")
}
conn, err := sql.Open("postgres", connString)
if err != nil {
return nil, fmt.Errorf("failed to open database: %w", err)
}
// 配置连接池
conn.SetMaxOpenConns(25)
conn.SetMaxIdleConns(5)
conn.SetConnMaxLifetime(5 * time.Minute)
// 测试连接
if err := conn.Ping(); err != nil {
return nil, fmt.Errorf("failed to ping database: %w", err)
}
return &DB{conn: conn}, nil
}
// Close 关闭数据库连接
func (db *DB) Close() error {
if db.conn != nil {
return db.conn.Close()
}
return nil
}
// GetConn 获取底层连接(用于事务等高级操作)
func (db *DB) GetConn() *sql.DB {
return db.conn
}
// InitSchema 初始化数据库 schema
func (db *DB) InitSchema() error {
schema := `
-- Users 表
CREATE TABLE IF NOT EXISTS users (
id VARCHAR(36) PRIMARY KEY,
username VARCHAR(255) NOT NULL UNIQUE,
password_hash TEXT NOT NULL,
email VARCHAR(255) NOT NULL,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);
CREATE INDEX IF NOT EXISTS idx_users_username ON users(username);
-- Clusters 表
CREATE TABLE IF NOT EXISTS clusters (
id VARCHAR(36) PRIMARY KEY,
name VARCHAR(255) NOT NULL UNIQUE,
host TEXT NOT NULL,
ca_data TEXT,
cert_data TEXT,
key_data TEXT,
token TEXT,
description TEXT,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);
CREATE INDEX IF NOT EXISTS idx_clusters_name ON clusters(name);
-- Registries 表
CREATE TABLE IF NOT EXISTS registries (
id VARCHAR(36) PRIMARY KEY,
name VARCHAR(255) NOT NULL UNIQUE,
url TEXT NOT NULL,
description TEXT,
username VARCHAR(255),
password TEXT,
insecure BOOLEAN DEFAULT FALSE,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);
CREATE INDEX IF NOT EXISTS idx_registries_name ON registries(name);
-- Instances 表
CREATE TABLE IF NOT EXISTS instances (
id VARCHAR(36) PRIMARY KEY,
cluster_id VARCHAR(36) NOT NULL,
name VARCHAR(255) NOT NULL,
namespace VARCHAR(255) NOT NULL,
registry_id VARCHAR(36) NOT NULL,
repository TEXT NOT NULL,
chart VARCHAR(255) NOT NULL,
version VARCHAR(255) NOT NULL,
description TEXT,
values JSONB,
values_yaml TEXT,
status VARCHAR(50) NOT NULL,
status_reason TEXT,
last_operation VARCHAR(50),
last_error TEXT,
revision INTEGER NOT NULL DEFAULT 1,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
CONSTRAINT fk_cluster FOREIGN KEY (cluster_id) REFERENCES clusters(id) ON DELETE CASCADE,
CONSTRAINT fk_registry FOREIGN KEY (registry_id) REFERENCES registries(id) ON DELETE CASCADE,
CONSTRAINT unique_cluster_name UNIQUE (cluster_id, name, namespace)
);
CREATE INDEX IF NOT EXISTS idx_instances_cluster ON instances(cluster_id);
CREATE INDEX IF NOT EXISTS idx_instances_registry ON instances(registry_id);
CREATE INDEX IF NOT EXISTS idx_instances_name ON instances(name);
`
_, err := db.conn.Exec(schema)
if err != nil {
return fmt.Errorf("failed to initialize schema: %w", err)
}
return nil
}

View File

@ -0,0 +1,433 @@
package postgres
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"time"
"github.com/google/uuid"
"github.com/ocdp/cluster-service/internal/domain/entity"
"github.com/ocdp/cluster-service/internal/domain/repository"
)
// InstanceRepository PostgreSQL 实例仓储实现
type InstanceRepository struct {
db *DB
}
// NewInstanceRepository 创建 PostgreSQL 实例仓储
func NewInstanceRepository(db *DB) repository.InstanceRepository {
return &InstanceRepository{db: db}
}
// Create 创建实例
func (r *InstanceRepository) Create(ctx context.Context, instance *entity.Instance) error {
if instance.ID == "" {
instance.ID = uuid.New().String()
}
// 将 Values 转换为 JSON
valuesJSON, err := json.Marshal(instance.Values)
if err != nil {
return fmt.Errorf("failed to marshal values: %w", err)
}
query := `
INSERT INTO instances (id, cluster_id, name, namespace, registry_id, repository, chart, version,
description, values, values_yaml, status, status_reason, last_operation, last_error,
revision, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18)
`
_, err = r.db.conn.ExecContext(ctx, query,
instance.ID,
instance.ClusterID,
instance.Name,
instance.Namespace,
instance.RegistryID,
instance.Repository,
instance.Chart,
instance.Version,
instance.Description,
valuesJSON,
instance.ValuesYAML,
instance.Status,
instance.StatusReason,
instance.LastOperation,
instance.LastError,
instance.Revision,
instance.CreatedAt,
instance.UpdatedAt,
)
if err != nil {
return fmt.Errorf("failed to create instance: %w", err)
}
return nil
}
// GetByID 根据 ID 获取实例
func (r *InstanceRepository) GetByID(ctx context.Context, id string) (*entity.Instance, error) {
query := `
SELECT id, cluster_id, name, namespace, registry_id, repository, chart, version,
description, values, values_yaml, status, status_reason, last_operation, last_error,
revision, created_at, updated_at
FROM instances
WHERE id = $1
`
instance := &entity.Instance{}
var (
valuesJSON []byte
statusReason sql.NullString
lastOperation sql.NullString
lastError sql.NullString
)
err := r.db.conn.QueryRowContext(ctx, query, id).Scan(
&instance.ID,
&instance.ClusterID,
&instance.Name,
&instance.Namespace,
&instance.RegistryID,
&instance.Repository,
&instance.Chart,
&instance.Version,
&instance.Description,
&valuesJSON,
&instance.ValuesYAML,
&instance.Status,
&statusReason,
&lastOperation,
&lastError,
&instance.Revision,
&instance.CreatedAt,
&instance.UpdatedAt,
)
if err == sql.ErrNoRows {
return nil, entity.ErrInstanceNotFound
}
if err != nil {
return nil, fmt.Errorf("failed to get instance: %w", err)
}
// 解析 JSON Values
if len(valuesJSON) > 0 {
if err := json.Unmarshal(valuesJSON, &instance.Values); err != nil {
return nil, fmt.Errorf("failed to unmarshal values: %w", err)
}
}
if statusReason.Valid {
instance.StatusReason = statusReason.String
}
if lastOperation.Valid {
instance.LastOperation = entity.InstanceOperation(lastOperation.String)
}
if lastError.Valid {
instance.LastError = lastError.String
}
return instance, nil
}
// GetByClusterAndName 根据集群 ID 和名称获取实例
func (r *InstanceRepository) GetByClusterAndName(ctx context.Context, clusterID, name string) (*entity.Instance, error) {
query := `
SELECT id, cluster_id, name, namespace, registry_id, repository, chart, version,
description, values, values_yaml, status, status_reason, last_operation, last_error,
revision, created_at, updated_at
FROM instances
WHERE cluster_id = $1 AND name = $2
`
instance := &entity.Instance{}
var (
valuesJSON []byte
statusReason sql.NullString
lastOperation sql.NullString
lastError sql.NullString
)
err := r.db.conn.QueryRowContext(ctx, query, clusterID, name).Scan(
&instance.ID,
&instance.ClusterID,
&instance.Name,
&instance.Namespace,
&instance.RegistryID,
&instance.Repository,
&instance.Chart,
&instance.Version,
&instance.Description,
&valuesJSON,
&instance.ValuesYAML,
&instance.Status,
&statusReason,
&lastOperation,
&lastError,
&instance.Revision,
&instance.CreatedAt,
&instance.UpdatedAt,
)
if err == sql.ErrNoRows {
return nil, entity.ErrInstanceNotFound
}
if err != nil {
return nil, fmt.Errorf("failed to get instance: %w", err)
}
// 解析 JSON Values
if len(valuesJSON) > 0 {
if err := json.Unmarshal(valuesJSON, &instance.Values); err != nil {
return nil, fmt.Errorf("failed to unmarshal values: %w", err)
}
}
if statusReason.Valid {
instance.StatusReason = statusReason.String
}
if lastOperation.Valid {
instance.LastOperation = entity.InstanceOperation(lastOperation.String)
}
if lastError.Valid {
instance.LastError = lastError.String
}
return instance, nil
}
// Update 更新实例
func (r *InstanceRepository) Update(ctx context.Context, instance *entity.Instance) error {
instance.UpdatedAt = time.Now()
// 将 Values 转换为 JSON
valuesJSON, err := json.Marshal(instance.Values)
if err != nil {
return fmt.Errorf("failed to marshal values: %w", err)
}
query := `
UPDATE instances
SET cluster_id = $1, name = $2, namespace = $3, registry_id = $4, repository = $5,
chart = $6, version = $7, description = $8, values = $9, values_yaml = $10,
status = $11, status_reason = $12, last_operation = $13, last_error = $14,
revision = $15, updated_at = $16
WHERE id = $17
`
result, err := r.db.conn.ExecContext(ctx, query,
instance.ClusterID,
instance.Name,
instance.Namespace,
instance.RegistryID,
instance.Repository,
instance.Chart,
instance.Version,
instance.Description,
valuesJSON,
instance.ValuesYAML,
instance.Status,
instance.StatusReason,
instance.LastOperation,
instance.LastError,
instance.Revision,
instance.UpdatedAt,
instance.ID,
)
if err != nil {
return fmt.Errorf("failed to update instance: %w", err)
}
rows, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get affected rows: %w", err)
}
if rows == 0 {
return entity.ErrInstanceNotFound
}
return nil
}
// Delete 删除实例
func (r *InstanceRepository) Delete(ctx context.Context, id string) error {
query := `DELETE FROM instances WHERE id = $1`
result, err := r.db.conn.ExecContext(ctx, query, id)
if err != nil {
return fmt.Errorf("failed to delete instance: %w", err)
}
rows, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get affected rows: %w", err)
}
if rows == 0 {
return entity.ErrInstanceNotFound
}
return nil
}
// ListByCluster 列出指定集群的所有实例
func (r *InstanceRepository) ListByCluster(ctx context.Context, clusterID string) ([]*entity.Instance, error) {
query := `
SELECT id, cluster_id, name, namespace, registry_id, repository, chart, version,
description, values, values_yaml, status, status_reason, last_operation, last_error,
revision, created_at, updated_at
FROM instances
WHERE cluster_id = $1
ORDER BY created_at DESC
`
rows, err := r.db.conn.QueryContext(ctx, query, clusterID)
if err != nil {
return nil, fmt.Errorf("failed to list instances: %w", err)
}
defer rows.Close()
instances := make([]*entity.Instance, 0)
for rows.Next() {
instance := &entity.Instance{}
var (
valuesJSON []byte
statusReason sql.NullString
lastOperation sql.NullString
lastError sql.NullString
)
err := rows.Scan(
&instance.ID,
&instance.ClusterID,
&instance.Name,
&instance.Namespace,
&instance.RegistryID,
&instance.Repository,
&instance.Chart,
&instance.Version,
&instance.Description,
&valuesJSON,
&instance.ValuesYAML,
&instance.Status,
&statusReason,
&lastOperation,
&lastError,
&instance.Revision,
&instance.CreatedAt,
&instance.UpdatedAt,
)
if err != nil {
return nil, fmt.Errorf("failed to scan instance: %w", err)
}
// 解析 JSON Values
if len(valuesJSON) > 0 {
if err := json.Unmarshal(valuesJSON, &instance.Values); err != nil {
return nil, fmt.Errorf("failed to unmarshal values: %w", err)
}
}
if statusReason.Valid {
instance.StatusReason = statusReason.String
}
if lastOperation.Valid {
instance.LastOperation = entity.InstanceOperation(lastOperation.String)
}
if lastError.Valid {
instance.LastError = lastError.String
}
instances = append(instances, instance)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("rows iteration error: %w", err)
}
return instances, nil
}
// List 列出所有实例
func (r *InstanceRepository) List(ctx context.Context) ([]*entity.Instance, error) {
query := `
SELECT id, cluster_id, name, namespace, registry_id, repository, chart, version,
description, values, values_yaml, status, status_reason, last_operation, last_error,
revision, created_at, updated_at
FROM instances
ORDER BY created_at DESC
`
rows, err := r.db.conn.QueryContext(ctx, query)
if err != nil {
return nil, fmt.Errorf("failed to list instances: %w", err)
}
defer rows.Close()
instances := make([]*entity.Instance, 0)
for rows.Next() {
instance := &entity.Instance{}
var (
valuesJSON []byte
statusReason sql.NullString
lastOperation sql.NullString
lastError sql.NullString
)
err := rows.Scan(
&instance.ID,
&instance.ClusterID,
&instance.Name,
&instance.Namespace,
&instance.RegistryID,
&instance.Repository,
&instance.Chart,
&instance.Version,
&instance.Description,
&valuesJSON,
&instance.ValuesYAML,
&instance.Status,
&statusReason,
&lastOperation,
&lastError,
&instance.Revision,
&instance.CreatedAt,
&instance.UpdatedAt,
)
if err != nil {
return nil, fmt.Errorf("failed to scan instance: %w", err)
}
// 解析 JSON Values
if len(valuesJSON) > 0 {
if err := json.Unmarshal(valuesJSON, &instance.Values); err != nil {
return nil, fmt.Errorf("failed to unmarshal values: %w", err)
}
}
if statusReason.Valid {
instance.StatusReason = statusReason.String
}
if lastOperation.Valid {
instance.LastOperation = entity.InstanceOperation(lastOperation.String)
}
if lastError.Valid {
instance.LastError = lastError.String
}
instances = append(instances, instance)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("rows iteration error: %w", err)
}
return instances, nil
}

View File

@ -0,0 +1,257 @@
package postgres
import (
"context"
"database/sql"
"fmt"
"time"
"github.com/google/uuid"
"github.com/ocdp/cluster-service/internal/domain/entity"
"github.com/ocdp/cluster-service/internal/domain/repository"
"github.com/ocdp/cluster-service/internal/pkg/crypto"
)
// RegistryRepository PostgreSQL Registry 仓储实现
type RegistryRepository struct {
db *DB
encryptor crypto.Encryptor
}
// NewRegistryRepository 创建 PostgreSQL Registry 仓储
func NewRegistryRepository(db *DB, encryptor crypto.Encryptor) repository.RegistryRepository {
return &RegistryRepository{
db: db,
encryptor: encryptor,
}
}
// Create 创建 Registry
func (r *RegistryRepository) Create(ctx context.Context, registry *entity.Registry) error {
if registry.ID == "" {
registry.ID = uuid.New().String()
}
// 加密密码
encryptedPassword, err := r.encryptor.Encrypt(registry.Password)
if err != nil {
return fmt.Errorf("failed to encrypt password: %w", err)
}
query := `
INSERT INTO registries (id, name, url, description, username, password, insecure, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
`
_, err = r.db.conn.ExecContext(ctx, query,
registry.ID,
registry.Name,
registry.URL,
registry.Description,
registry.Username,
encryptedPassword,
registry.Insecure,
registry.CreatedAt,
registry.UpdatedAt,
)
if err != nil {
return fmt.Errorf("failed to create registry: %w", err)
}
return nil
}
// GetByID 根据 ID 获取 Registry
func (r *RegistryRepository) GetByID(ctx context.Context, id string) (*entity.Registry, error) {
query := `
SELECT id, name, url, description, username, password, insecure, created_at, updated_at
FROM registries
WHERE id = $1
`
registry := &entity.Registry{}
var encryptedPassword string
err := r.db.conn.QueryRowContext(ctx, query, id).Scan(
&registry.ID,
&registry.Name,
&registry.URL,
&registry.Description,
&registry.Username,
&encryptedPassword,
&registry.Insecure,
&registry.CreatedAt,
&registry.UpdatedAt,
)
if err == sql.ErrNoRows {
return nil, entity.ErrRegistryNotFound
}
if err != nil {
return nil, fmt.Errorf("failed to get registry: %w", err)
}
// 解密密码
registry.Password, err = r.encryptor.Decrypt(encryptedPassword)
if err != nil {
return nil, fmt.Errorf("failed to decrypt password: %w", err)
}
return registry, nil
}
// GetByName 根据名称获取 Registry
func (r *RegistryRepository) GetByName(ctx context.Context, name string) (*entity.Registry, error) {
query := `
SELECT id, name, url, description, username, password, insecure, created_at, updated_at
FROM registries
WHERE name = $1
`
registry := &entity.Registry{}
var encryptedPassword string
err := r.db.conn.QueryRowContext(ctx, query, name).Scan(
&registry.ID,
&registry.Name,
&registry.URL,
&registry.Description,
&registry.Username,
&encryptedPassword,
&registry.Insecure,
&registry.CreatedAt,
&registry.UpdatedAt,
)
if err == sql.ErrNoRows {
return nil, entity.ErrRegistryNotFound
}
if err != nil {
return nil, fmt.Errorf("failed to get registry: %w", err)
}
// 解密密码
registry.Password, err = r.encryptor.Decrypt(encryptedPassword)
if err != nil {
return nil, fmt.Errorf("failed to decrypt password: %w", err)
}
return registry, nil
}
// Update 更新 Registry
func (r *RegistryRepository) Update(ctx context.Context, registry *entity.Registry) error {
registry.UpdatedAt = time.Now()
// 加密密码
encryptedPassword, err := r.encryptor.Encrypt(registry.Password)
if err != nil {
return fmt.Errorf("failed to encrypt password: %w", err)
}
query := `
UPDATE registries
SET name = $1, url = $2, description = $3, username = $4, password = $5,
insecure = $6, updated_at = $7
WHERE id = $8
`
result, err := r.db.conn.ExecContext(ctx, query,
registry.Name,
registry.URL,
registry.Description,
registry.Username,
encryptedPassword,
registry.Insecure,
registry.UpdatedAt,
registry.ID,
)
if err != nil {
return fmt.Errorf("failed to update registry: %w", err)
}
rows, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get affected rows: %w", err)
}
if rows == 0 {
return entity.ErrRegistryNotFound
}
return nil
}
// Delete 删除 Registry
func (r *RegistryRepository) Delete(ctx context.Context, id string) error {
query := `DELETE FROM registries WHERE id = $1`
result, err := r.db.conn.ExecContext(ctx, query, id)
if err != nil {
return fmt.Errorf("failed to delete registry: %w", err)
}
rows, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get affected rows: %w", err)
}
if rows == 0 {
return entity.ErrRegistryNotFound
}
return nil
}
// List 列出所有 Registries
func (r *RegistryRepository) List(ctx context.Context) ([]*entity.Registry, error) {
query := `
SELECT id, name, url, description, username, password, insecure, created_at, updated_at
FROM registries
ORDER BY created_at DESC
`
rows, err := r.db.conn.QueryContext(ctx, query)
if err != nil {
return nil, fmt.Errorf("failed to list registries: %w", err)
}
defer rows.Close()
registries := make([]*entity.Registry, 0)
for rows.Next() {
registry := &entity.Registry{}
var encryptedPassword string
err := rows.Scan(
&registry.ID,
&registry.Name,
&registry.URL,
&registry.Description,
&registry.Username,
&encryptedPassword,
&registry.Insecure,
&registry.CreatedAt,
&registry.UpdatedAt,
)
if err != nil {
return nil, fmt.Errorf("failed to scan registry: %w", err)
}
// 解密密码
registry.Password, err = r.encryptor.Decrypt(encryptedPassword)
if err != nil {
return nil, fmt.Errorf("failed to decrypt password: %w", err)
}
registries = append(registries, registry)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("rows iteration error: %w", err)
}
return registries, nil
}

View File

@ -0,0 +1,204 @@
package postgres
import (
"context"
"database/sql"
"fmt"
"time"
"github.com/google/uuid"
"github.com/ocdp/cluster-service/internal/domain/entity"
"github.com/ocdp/cluster-service/internal/domain/repository"
)
// UserRepository PostgreSQL 用户仓储实现
type UserRepository struct {
db *DB
}
// NewUserRepository 创建 PostgreSQL 用户仓储
func NewUserRepository(db *DB) repository.UserRepository {
return &UserRepository{db: db}
}
// Create 创建用户
func (r *UserRepository) Create(ctx context.Context, user *entity.User) error {
if user.ID == "" {
user.ID = uuid.New().String()
}
query := `
INSERT INTO users (id, username, password_hash, email, revoked_after, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $7)
`
_, err := r.db.conn.ExecContext(ctx, query,
user.ID,
user.Username,
user.PasswordHash,
user.Email,
user.RevokedAfter,
user.CreatedAt,
user.UpdatedAt,
)
if err != nil {
return fmt.Errorf("failed to create user: %w", err)
}
return nil
}
// GetByID 根据 ID 获取用户
func (r *UserRepository) GetByID(ctx context.Context, id string) (*entity.User, error) {
query := `
SELECT id, username, password_hash, email, revoked_after, created_at, updated_at
FROM users
WHERE id = $1
`
user := &entity.User{}
err := r.db.conn.QueryRowContext(ctx, query, id).Scan(
&user.ID,
&user.Username,
&user.PasswordHash,
&user.Email,
&user.RevokedAfter,
&user.CreatedAt,
&user.UpdatedAt,
)
if err == sql.ErrNoRows {
return nil, entity.ErrUserNotFound
}
if err != nil {
return nil, fmt.Errorf("failed to get user: %w", err)
}
return user, nil
}
// GetByUsername 根据用户名获取用户
func (r *UserRepository) GetByUsername(ctx context.Context, username string) (*entity.User, error) {
query := `
SELECT id, username, password_hash, email, revoked_after, created_at, updated_at
FROM users
WHERE username = $1
`
user := &entity.User{}
err := r.db.conn.QueryRowContext(ctx, query, username).Scan(
&user.ID,
&user.Username,
&user.PasswordHash,
&user.Email,
&user.RevokedAfter,
&user.CreatedAt,
&user.UpdatedAt,
)
if err == sql.ErrNoRows {
return nil, entity.ErrUserNotFound
}
if err != nil {
return nil, fmt.Errorf("failed to get user: %w", err)
}
return user, nil
}
// Update 更新用户
func (r *UserRepository) Update(ctx context.Context, user *entity.User) error {
user.UpdatedAt = time.Now()
query := `
UPDATE users
SET username = $1, password_hash = $2, email = $3, revoked_after = $4, updated_at = $5
WHERE id = $6
`
result, err := r.db.conn.ExecContext(ctx, query,
user.Username,
user.PasswordHash,
user.Email,
user.RevokedAfter,
user.UpdatedAt,
user.ID,
)
if err != nil {
return fmt.Errorf("failed to update user: %w", err)
}
rows, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get affected rows: %w", err)
}
if rows == 0 {
return entity.ErrUserNotFound
}
return nil
}
// Delete 删除用户
func (r *UserRepository) Delete(ctx context.Context, id string) error {
query := `DELETE FROM users WHERE id = $1`
result, err := r.db.conn.ExecContext(ctx, query, id)
if err != nil {
return fmt.Errorf("failed to delete user: %w", err)
}
rows, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get affected rows: %w", err)
}
if rows == 0 {
return entity.ErrUserNotFound
}
return nil
}
// List 列出所有用户
func (r *UserRepository) List(ctx context.Context) ([]*entity.User, error) {
query := `
SELECT id, username, password_hash, email, revoked_after, created_at, updated_at
FROM users
ORDER BY created_at DESC
`
rows, err := r.db.conn.QueryContext(ctx, query)
if err != nil {
return nil, fmt.Errorf("failed to list users: %w", err)
}
defer rows.Close()
users := make([]*entity.User, 0)
for rows.Next() {
user := &entity.User{}
err := rows.Scan(
&user.ID,
&user.Username,
&user.PasswordHash,
&user.Email,
&user.RevokedAfter,
&user.CreatedAt,
&user.UpdatedAt,
)
if err != nil {
return nil, fmt.Errorf("failed to scan user: %w", err)
}
users = append(users, user)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("rows iteration error: %w", err)
}
return users, nil
}