refactor: full-stack restructure with multi-tenancy, workspace management, and K8s diagnostics

- Add Workspace domain (entity, repository, service, handler, DTO)
- Add multi-tenant K8s client with tenant binding and quota management
- Add K8s diagnostics client (instance diagnostics)
- Add authorization middleware (authz package)
- Restructure frontend to feature-based architecture (features/)
- Add User Management page in configuration
- Add AccessDenied page and route guards
- Refactor shared components (form inputs, layout, UI)
- Update Tailwind config for new design system
- Add comprehensive documentation (docs/, tasks/, plans)
- Improve cluster service with better kubeconfig handling
- Add tests for crypto, config, helm client, tenant binding
This commit is contained in:
Ivan087
2026-05-12 16:15:14 +08:00
parent c5e51ed069
commit 7f238a3168
172 changed files with 15703 additions and 3162 deletions

View File

@ -96,6 +96,36 @@ func (f *AdapterFactory) CreateInstanceRepository() (repository.InstanceReposito
return postgres.NewInstanceRepository(f.db), nil
}
func (f *AdapterFactory) CreateWorkspaceRepository() (repository.WorkspaceRepository, error) {
if f.mode == ModeMock {
return mock.NewWorkspaceRepositoryMock(), nil
}
if err := f.ensureDBConnection(); err != nil {
return nil, err
}
return postgres.NewWorkspaceRepository(f.db), nil
}
func (f *AdapterFactory) CreateWorkspaceClusterBindingRepository() (repository.WorkspaceClusterBindingRepository, error) {
if f.mode == ModeMock {
return mock.NewWorkspaceClusterBindingRepositoryMock(), nil
}
if err := f.ensureDBConnection(); err != nil {
return nil, err
}
return postgres.NewWorkspaceClusterBindingRepository(f.db), nil
}
func (f *AdapterFactory) CreateAuditLogRepository() (repository.AuditLogRepository, error) {
if f.mode == ModeMock {
return mock.NewAuditLogRepositoryMock(), nil
}
if err := f.ensureDBConnection(); err != nil {
return nil, err
}
return postgres.NewAuditLogRepository(f.db), nil
}
// CreateOCIClient 创建 OCI 客户端
func (f *AdapterFactory) CreateOCIClient() (repository.OCIClient, error) {
if f.mode == ModeMock {
@ -127,6 +157,20 @@ func (f *AdapterFactory) CreateEntryClient() repository.InstanceEntryClient {
return k8s.NewEntryClient()
}
func (f *AdapterFactory) CreateDiagnosticsClient() repository.InstanceDiagnosticsClient {
if f.mode == ModeMock {
return k8s.NewMockDiagnosticsClient()
}
return k8s.NewDiagnosticsClient()
}
func (f *AdapterFactory) CreateTenantKubeClient() repository.TenantKubeClient {
if f.mode == ModeMock {
return k8s.NewMockTenantClient()
}
return k8s.NewTenantClient()
}
// CreateAllRepositories 一次性创建所有 Repositories
func (f *AdapterFactory) CreateAllRepositories() (*Repositories, error) {
userRepo, err := f.CreateUserRepository()
@ -149,6 +193,21 @@ func (f *AdapterFactory) CreateAllRepositories() (*Repositories, error) {
return nil, fmt.Errorf("failed to create instance repository: %w", err)
}
workspaceRepo, err := f.CreateWorkspaceRepository()
if err != nil {
return nil, fmt.Errorf("failed to create workspace repository: %w", err)
}
bindingRepo, err := f.CreateWorkspaceClusterBindingRepository()
if err != nil {
return nil, fmt.Errorf("failed to create workspace cluster binding repository: %w", err)
}
auditRepo, err := f.CreateAuditLogRepository()
if err != nil {
return nil, fmt.Errorf("failed to create audit log repository: %w", err)
}
ociClient, err := f.CreateOCIClient()
if err != nil {
return nil, fmt.Errorf("failed to create OCI client: %w", err)
@ -162,29 +221,41 @@ func (f *AdapterFactory) CreateAllRepositories() (*Repositories, error) {
// 创建 Metrics client依赖 clusterRepo
metricsClient := f.CreateMetricsClient(clusterRepo)
entryClient := f.CreateEntryClient()
diagnosticsClient := f.CreateDiagnosticsClient()
tenantClient := f.CreateTenantKubeClient()
return &Repositories{
UserRepo: userRepo,
ClusterRepo: clusterRepo,
RegistryRepo: registryRepo,
InstanceRepo: instanceRepo,
OCIClient: ociClient,
HelmClient: helmClient,
MetricsClient: metricsClient,
EntryClient: entryClient,
UserRepo: userRepo,
WorkspaceRepo: workspaceRepo,
BindingRepo: bindingRepo,
AuditRepo: auditRepo,
ClusterRepo: clusterRepo,
RegistryRepo: registryRepo,
InstanceRepo: instanceRepo,
OCIClient: ociClient,
HelmClient: helmClient,
MetricsClient: metricsClient,
EntryClient: entryClient,
DiagnosticsClient: diagnosticsClient,
TenantKubeClient: tenantClient,
}, nil
}
// Repositories 所有仓储的集合
type Repositories struct {
UserRepo repository.UserRepository
ClusterRepo repository.ClusterRepository
RegistryRepo repository.RegistryRepository
InstanceRepo repository.InstanceRepository
OCIClient repository.OCIClient
HelmClient repository.HelmClient
MetricsClient repository.MetricsClient
EntryClient repository.InstanceEntryClient
UserRepo repository.UserRepository
WorkspaceRepo repository.WorkspaceRepository
BindingRepo repository.WorkspaceClusterBindingRepository
AuditRepo repository.AuditLogRepository
ClusterRepo repository.ClusterRepository
RegistryRepo repository.RegistryRepository
InstanceRepo repository.InstanceRepository
OCIClient repository.OCIClient
HelmClient repository.HelmClient
MetricsClient repository.MetricsClient
EntryClient repository.InstanceEntryClient
DiagnosticsClient repository.InstanceDiagnosticsClient
TenantKubeClient repository.TenantKubeClient
}
// ensureDBConnection 确保数据库连接已建立

View File

@ -21,6 +21,7 @@ import (
"k8s.io/client-go/rest"
"k8s.io/client-go/restmapper"
"k8s.io/client-go/tools/clientcmd"
clientcmdapi "k8s.io/client-go/tools/clientcmd/api"
)
// HelmClient 真实的 Helm 客户端实现
@ -36,39 +37,45 @@ func NewHelmClient() repository.HelmClient {
}
// getActionConfig 获取 Helm action configuration
func (h *HelmClient) getActionConfig(cluster *entity.Cluster, namespace string) (*action.Configuration, error) {
func (h *HelmClient) getActionConfig(cluster *entity.Cluster, namespace string) (*action.Configuration, func(), error) {
actionConfig := new(action.Configuration)
// 创建临时 kubeconfig 文件
kubeconfigContent := cluster.GetKubeConfig()
tmpDir, err := os.MkdirTemp("", "helm-kubeconfig-*")
if err != nil {
return nil, fmt.Errorf("failed to create temp dir: %w", err)
return nil, nil, fmt.Errorf("failed to create temp dir: %w", err)
}
cleanup := func() {
_ = os.RemoveAll(tmpDir)
}
kubeconfigPath := filepath.Join(tmpDir, "kubeconfig")
if err := os.WriteFile(kubeconfigPath, []byte(kubeconfigContent), 0600); err != nil {
return nil, fmt.Errorf("failed to write kubeconfig: %w", err)
cleanup()
return nil, nil, fmt.Errorf("failed to write kubeconfig: %w", err)
}
// 使用 kubeconfig 初始化 action config
if err := actionConfig.Init(
&kubeconfigGetter{kubeconfigPath: kubeconfigPath},
&kubeconfigGetter{kubeconfigPath: kubeconfigPath, namespace: namespace},
namespace,
os.Getenv("HELM_DRIVER"), // storage driver: configmap, secret, memory
func(format string, v ...interface{}) {
// Log function
},
); err != nil {
return nil, fmt.Errorf("failed to initialize action config: %w", err)
cleanup()
return nil, nil, fmt.Errorf("failed to initialize action config: %w", err)
}
return actionConfig, nil
return actionConfig, cleanup, nil
}
// kubeconfigGetter implements RESTClientGetter
type kubeconfigGetter struct {
kubeconfigPath string
namespace string
}
func (k *kubeconfigGetter) ToRESTConfig() (*rest.Config, error) {
@ -95,25 +102,30 @@ func (k *kubeconfigGetter) ToRESTMapper() (meta.RESTMapper, error) {
}
func (k *kubeconfigGetter) ToRawKubeConfigLoader() clientcmd.ClientConfig {
overrides := &clientcmd.ConfigOverrides{}
if k.namespace != "" {
overrides.Context = clientcmdapi.Context{Namespace: k.namespace}
}
return clientcmd.NewNonInteractiveDeferredLoadingClientConfig(
&clientcmd.ClientConfigLoadingRules{ExplicitPath: k.kubeconfigPath},
&clientcmd.ConfigOverrides{},
overrides,
)
}
// Install 安装 Helm Chart
func (h *HelmClient) Install(ctx context.Context, cluster *entity.Cluster, instance *entity.Instance) error {
actionConfig, err := h.getActionConfig(cluster, instance.Namespace)
actionConfig, cleanup, err := h.getActionConfig(cluster, instance.Namespace)
if err != nil {
return err
}
defer cleanup()
install := action.NewInstall(actionConfig)
install.ReleaseName = instance.Name
install.Namespace = instance.Namespace
install.CreateNamespace = true
install.Wait = true
install.Timeout = 5 * time.Minute
install.Timeout = helmOperationTimeout()
// 加载 Chart从本地路径或 OCI registry
// 这里简化处理,假设 chart 已经被拉取到本地
@ -139,15 +151,16 @@ func (h *HelmClient) Install(ctx context.Context, cluster *entity.Cluster, insta
// Upgrade 升级 Helm Release
func (h *HelmClient) Upgrade(ctx context.Context, cluster *entity.Cluster, instance *entity.Instance) error {
actionConfig, err := h.getActionConfig(cluster, instance.Namespace)
actionConfig, cleanup, err := h.getActionConfig(cluster, instance.Namespace)
if err != nil {
return err
}
defer cleanup()
upgrade := action.NewUpgrade(actionConfig)
upgrade.Namespace = instance.Namespace
upgrade.Wait = true
upgrade.Timeout = 5 * time.Minute
upgrade.Timeout = helmOperationTimeout()
// 加载 Chart
chartPath := fmt.Sprintf("/tmp/charts/%s-%s.tgz", instance.Chart, instance.Version)
@ -172,14 +185,15 @@ func (h *HelmClient) Upgrade(ctx context.Context, cluster *entity.Cluster, insta
// Uninstall 卸载 Helm Release
func (h *HelmClient) Uninstall(ctx context.Context, cluster *entity.Cluster, releaseName, namespace string) error {
actionConfig, err := h.getActionConfig(cluster, namespace)
actionConfig, cleanup, err := h.getActionConfig(cluster, namespace)
if err != nil {
return err
}
defer cleanup()
uninstall := action.NewUninstall(actionConfig)
uninstall.Wait = true
uninstall.Timeout = 5 * time.Minute
uninstall.Timeout = helmOperationTimeout()
_, err = uninstall.Run(releaseName)
if err != nil {
@ -194,15 +208,16 @@ func (h *HelmClient) Uninstall(ctx context.Context, cluster *entity.Cluster, rel
// Rollback 回滚 Helm Release
func (h *HelmClient) Rollback(ctx context.Context, cluster *entity.Cluster, releaseName, namespace string, revision int) error {
actionConfig, err := h.getActionConfig(cluster, namespace)
actionConfig, cleanup, err := h.getActionConfig(cluster, namespace)
if err != nil {
return err
}
defer cleanup()
rollback := action.NewRollback(actionConfig)
rollback.Version = revision
rollback.Wait = true
rollback.Timeout = 5 * time.Minute
rollback.Timeout = helmOperationTimeout()
if err := rollback.Run(releaseName); err != nil {
return fmt.Errorf("failed to rollback release: %w", err)
@ -211,12 +226,25 @@ func (h *HelmClient) Rollback(ctx context.Context, cluster *entity.Cluster, rele
return nil
}
func helmOperationTimeout() time.Duration {
raw := os.Getenv("HELM_OPERATION_TIMEOUT")
if raw == "" {
return 15 * time.Minute
}
timeout, err := time.ParseDuration(raw)
if err != nil || timeout <= 0 {
return 15 * time.Minute
}
return timeout
}
// GetStatus 获取 Release 状态
func (h *HelmClient) GetStatus(ctx context.Context, cluster *entity.Cluster, releaseName, namespace string) (*entity.Instance, error) {
actionConfig, err := h.getActionConfig(cluster, namespace)
actionConfig, cleanup, err := h.getActionConfig(cluster, namespace)
if err != nil {
return nil, err
}
defer cleanup()
status := action.NewStatus(actionConfig)
rel, err := status.Run(releaseName)
@ -229,10 +257,11 @@ func (h *HelmClient) GetStatus(ctx context.Context, cluster *entity.Cluster, rel
// GetHistory 获取 Release 历史
func (h *HelmClient) GetHistory(ctx context.Context, cluster *entity.Cluster, releaseName, namespace string) ([]*entity.ReleaseHistory, error) {
actionConfig, err := h.getActionConfig(cluster, namespace)
actionConfig, cleanup, err := h.getActionConfig(cluster, namespace)
if err != nil {
return nil, err
}
defer cleanup()
history := action.NewHistory(actionConfig)
history.Max = 256
@ -259,10 +288,11 @@ func (h *HelmClient) GetHistory(ctx context.Context, cluster *entity.Cluster, re
// List 列出集群中的所有 Releases
func (h *HelmClient) List(ctx context.Context, cluster *entity.Cluster, namespace string) ([]*entity.Instance, error) {
actionConfig, err := h.getActionConfig(cluster, namespace)
actionConfig, cleanup, err := h.getActionConfig(cluster, namespace)
if err != nil {
return nil, err
}
defer cleanup()
list := action.NewList(actionConfig)
if namespace == "" {
@ -284,10 +314,11 @@ func (h *HelmClient) List(ctx context.Context, cluster *entity.Cluster, namespac
// GetValues 获取 Release 的 values
func (h *HelmClient) GetValues(ctx context.Context, cluster *entity.Cluster, releaseName, namespace string) (map[string]interface{}, error) {
actionConfig, err := h.getActionConfig(cluster, namespace)
actionConfig, cleanup, err := h.getActionConfig(cluster, namespace)
if err != nil {
return nil, err
}
defer cleanup()
getValues := action.NewGetValues(actionConfig)
values, err := getValues.Run(releaseName)

View File

@ -0,0 +1,45 @@
package real
import (
"os"
"path/filepath"
"testing"
)
func TestKubeconfigGetterOverridesNamespace(t *testing.T) {
t.Parallel()
kubeconfigPath := filepath.Join(t.TempDir(), "kubeconfig")
kubeconfig := `apiVersion: v1
kind: Config
clusters:
- cluster:
server: https://127.0.0.1:6443
name: test
contexts:
- context:
cluster: test
user: test
name: test
current-context: test
users:
- name: test
user:
token: test
`
if err := os.WriteFile(kubeconfigPath, []byte(kubeconfig), 0600); err != nil {
t.Fatalf("failed to write kubeconfig: %v", err)
}
getter := &kubeconfigGetter{
kubeconfigPath: kubeconfigPath,
namespace: "ocdp-u-alice",
}
namespace, _, err := getter.ToRawKubeConfigLoader().Namespace()
if err != nil {
t.Fatalf("Namespace returned error: %v", err)
}
if namespace != "ocdp-u-alice" {
t.Fatalf("expected namespace override %q, got %q", "ocdp-u-alice", namespace)
}
}

View File

@ -0,0 +1,294 @@
package k8s
import (
"context"
"fmt"
"io"
"sort"
"strings"
"time"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/client-go/kubernetes"
"github.com/ocdp/cluster-service/internal/domain/entity"
"github.com/ocdp/cluster-service/internal/domain/repository"
)
type DiagnosticsClient struct{}
func NewDiagnosticsClient() repository.InstanceDiagnosticsClient {
return &DiagnosticsClient{}
}
type MockDiagnosticsClient struct{}
func NewMockDiagnosticsClient() repository.InstanceDiagnosticsClient {
return &MockDiagnosticsClient{}
}
func (*MockDiagnosticsClient) GetDiagnostics(ctx context.Context, cluster *entity.Cluster, instance *entity.Instance, tailLines int64) (*entity.InstanceDiagnostics, error) {
return &entity.InstanceDiagnostics{
InstanceName: instance.Name,
Namespace: instance.Namespace,
CollectedAt: time.Now(),
}, nil
}
func (c *DiagnosticsClient) GetDiagnostics(ctx context.Context, cluster *entity.Cluster, instance *entity.Instance, tailLines int64) (*entity.InstanceDiagnostics, error) {
clientset, err := diagnosticsClientset(cluster)
if err != nil {
return nil, err
}
if tailLines <= 0 {
tailLines = 200
}
if tailLines > 2000 {
tailLines = 2000
}
pods, err := listInstancePods(ctx, clientset, instance)
if err != nil {
return nil, err
}
services, err := listInstanceServices(ctx, clientset, instance)
if err != nil {
return nil, err
}
events, err := listInstanceEvents(ctx, clientset, instance, pods, services)
if err != nil {
return nil, err
}
logs := collectPodLogs(ctx, clientset, pods, tailLines)
return &entity.InstanceDiagnostics{
InstanceName: instance.Name,
Namespace: instance.Namespace,
Pods: convertPodsToDiagnostics(pods),
Services: convertServicesToDiagnostics(services),
Events: convertEventsToDiagnostics(events),
Logs: logs,
CollectedAt: time.Now(),
}, nil
}
func diagnosticsClientset(cluster *entity.Cluster) (kubernetes.Interface, error) {
config, err := restConfigFromCluster(cluster)
if err != nil {
return nil, err
}
clientset, err := kubernetes.NewForConfig(config)
if err != nil {
return nil, fmt.Errorf("failed to create diagnostics kubernetes client: %w", err)
}
return clientset, nil
}
func listInstancePods(ctx context.Context, clientset kubernetes.Interface, instance *entity.Instance) ([]corev1.Pod, error) {
selector := fmt.Sprintf("app.kubernetes.io/instance=%s", instance.Name)
pods, err := clientset.CoreV1().Pods(instance.Namespace).List(ctx, metav1.ListOptions{LabelSelector: selector})
if err != nil {
return nil, fmt.Errorf("failed to list instance pods: %w", err)
}
if len(pods.Items) > 0 {
return pods.Items, nil
}
all, err := clientset.CoreV1().Pods(instance.Namespace).List(ctx, metav1.ListOptions{})
if err != nil {
return nil, fmt.Errorf("failed to list namespace pods: %w", err)
}
filtered := make([]corev1.Pod, 0)
for _, pod := range all.Items {
if resourceMatchesInstance(pod.ObjectMeta, instance) {
filtered = append(filtered, pod)
}
}
return filtered, nil
}
func listInstanceServices(ctx context.Context, clientset kubernetes.Interface, instance *entity.Instance) ([]corev1.Service, error) {
selector := fmt.Sprintf("app.kubernetes.io/instance=%s", instance.Name)
services, err := clientset.CoreV1().Services(instance.Namespace).List(ctx, metav1.ListOptions{LabelSelector: selector})
if err != nil {
return nil, fmt.Errorf("failed to list instance services: %w", err)
}
if len(services.Items) > 0 {
return services.Items, nil
}
all, err := clientset.CoreV1().Services(instance.Namespace).List(ctx, metav1.ListOptions{})
if err != nil {
return nil, fmt.Errorf("failed to list namespace services: %w", err)
}
filtered := make([]corev1.Service, 0)
for _, svc := range all.Items {
if resourceMatchesInstance(svc.ObjectMeta, instance) {
filtered = append(filtered, svc)
}
}
return filtered, nil
}
func listInstanceEvents(ctx context.Context, clientset kubernetes.Interface, instance *entity.Instance, pods []corev1.Pod, services []corev1.Service) ([]corev1.Event, error) {
events, err := clientset.CoreV1().Events(instance.Namespace).List(ctx, metav1.ListOptions{})
if err != nil {
return nil, fmt.Errorf("failed to list instance events: %w", err)
}
names := map[string]bool{instance.Name: true}
for _, pod := range pods {
names[pod.Name] = true
}
for _, svc := range services {
names[svc.Name] = true
}
filtered := make([]corev1.Event, 0)
for _, event := range events.Items {
if names[event.InvolvedObject.Name] || strings.Contains(event.Message, instance.Name) {
filtered = append(filtered, event)
}
}
sort.SliceStable(filtered, func(i, j int) bool {
return filtered[i].LastTimestamp.Time.After(filtered[j].LastTimestamp.Time)
})
if len(filtered) > 100 {
filtered = filtered[:100]
}
return filtered, nil
}
func collectPodLogs(ctx context.Context, clientset kubernetes.Interface, pods []corev1.Pod, tailLines int64) []entity.InstancePodLog {
logs := make([]entity.InstancePodLog, 0)
for _, pod := range pods {
for _, container := range pod.Spec.Containers {
item := entity.InstancePodLog{Pod: pod.Name, Container: container.Name, TailLines: tailLines}
req := clientset.CoreV1().Pods(pod.Namespace).GetLogs(pod.Name, &corev1.PodLogOptions{
Container: container.Name,
TailLines: &tailLines,
})
stream, err := req.Stream(ctx)
if err != nil {
item.Error = err.Error()
logs = append(logs, item)
continue
}
data, err := io.ReadAll(io.LimitReader(stream, 1<<20))
_ = stream.Close()
if err != nil {
item.Error = err.Error()
} else {
item.Log = string(data)
}
logs = append(logs, item)
}
}
return logs
}
func convertPodsToDiagnostics(pods []corev1.Pod) []entity.InstancePodDiagnostics {
out := make([]entity.InstancePodDiagnostics, 0, len(pods))
for _, pod := range pods {
containers := make([]entity.InstanceContainerDiagnostics, 0, len(pod.Status.ContainerStatuses))
var restarts int32
for _, status := range pod.Status.ContainerStatuses {
restarts += status.RestartCount
containers = append(containers, entity.InstanceContainerDiagnostics{
Name: status.Name,
Image: status.Image,
Ready: status.Ready,
RestartCount: status.RestartCount,
State: containerStateName(status.State),
Reason: containerStateReason(status.State),
Message: containerStateMessage(status.State),
})
}
conditions := make([]entity.InstanceConditionDiagnostics, 0, len(pod.Status.Conditions))
for _, condition := range pod.Status.Conditions {
conditions = append(conditions, entity.InstanceConditionDiagnostics{
Type: string(condition.Type),
Status: string(condition.Status),
Reason: condition.Reason,
Message: condition.Message,
})
}
out = append(out, entity.InstancePodDiagnostics{
Name: pod.Name,
Namespace: pod.Namespace,
Phase: string(pod.Status.Phase),
NodeName: pod.Spec.NodeName,
PodIP: pod.Status.PodIP,
HostIP: pod.Status.HostIP,
RestartCount: restarts,
Containers: containers,
Conditions: conditions,
CreationTimestamp: pod.CreationTimestamp.Time,
})
}
return out
}
func convertServicesToDiagnostics(services []corev1.Service) []entity.InstanceServiceDiagnostics {
out := make([]entity.InstanceServiceDiagnostics, 0, len(services))
for _, svc := range services {
entry := convertServiceToEntry(&svc)
out = append(out, entity.InstanceServiceDiagnostics{
Name: svc.Name,
Namespace: svc.Namespace,
Type: string(svc.Spec.Type),
ClusterIP: svc.Spec.ClusterIP,
Ports: entry.Ports,
})
}
return out
}
func convertEventsToDiagnostics(events []corev1.Event) []entity.InstanceEventDiagnostics {
out := make([]entity.InstanceEventDiagnostics, 0, len(events))
for _, event := range events {
out = append(out, entity.InstanceEventDiagnostics{
Type: event.Type,
Reason: event.Reason,
Message: event.Message,
InvolvedKind: event.InvolvedObject.Kind,
InvolvedName: event.InvolvedObject.Name,
Count: event.Count,
FirstTimestamp: event.FirstTimestamp.Time,
LastTimestamp: event.LastTimestamp.Time,
})
}
return out
}
func containerStateName(state corev1.ContainerState) string {
switch {
case state.Running != nil:
return "running"
case state.Waiting != nil:
return "waiting"
case state.Terminated != nil:
return "terminated"
default:
return "unknown"
}
}
func containerStateReason(state corev1.ContainerState) string {
switch {
case state.Waiting != nil:
return state.Waiting.Reason
case state.Terminated != nil:
return state.Terminated.Reason
default:
return ""
}
}
func containerStateMessage(state corev1.ContainerState) string {
switch {
case state.Waiting != nil:
return state.Waiting.Message
case state.Terminated != nil:
return state.Terminated.Message
default:
return ""
}
}

View File

@ -0,0 +1,388 @@
package k8s
import (
"context"
"encoding/base64"
"fmt"
"strings"
"time"
authenticationv1 "k8s.io/api/authentication/v1"
corev1 "k8s.io/api/core/v1"
rbacv1 "k8s.io/api/rbac/v1"
apierrors "k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/client-go/kubernetes"
"k8s.io/client-go/rest"
"k8s.io/client-go/tools/clientcmd"
clientcmdapi "k8s.io/client-go/tools/clientcmd/api"
"github.com/ocdp/cluster-service/internal/domain/entity"
"github.com/ocdp/cluster-service/internal/domain/repository"
)
// TenantClient provisions namespace-scoped tenant Kubernetes resources.
type TenantClient struct {
clientset kubernetes.Interface
}
// NewTenantClient creates a tenant provisioning client that builds Kubernetes
// clients from the supplied cluster entity for each call.
func NewTenantClient() repository.TenantKubeClient {
return &TenantClient{}
}
// NewTenantClientForClientset creates a tenant provisioning client for tests or
// callers that already own a Kubernetes client.
func NewTenantClientForClientset(clientset kubernetes.Interface) repository.TenantKubeClient {
return &TenantClient{clientset: clientset}
}
// EnsureTenant idempotently ensures Namespace, ServiceAccount, RoleBinding, and
// ResourceQuota resources for the tenant binding.
func (c *TenantClient) EnsureTenant(ctx context.Context, cluster *entity.Cluster, binding entity.TenantBinding) error {
binding = binding.WithDefaults()
if err := binding.Validate(); err != nil {
return err
}
clientset, _, err := c.clientsetForCluster(cluster)
if err != nil {
return err
}
if err := c.ensureNamespace(ctx, clientset, binding); err != nil {
return err
}
if err := c.ensureServiceAccount(ctx, clientset, binding); err != nil {
return err
}
if err := c.ensureRoleBinding(ctx, clientset, binding); err != nil {
return err
}
if err := c.ensureResourceQuota(ctx, clientset, binding); err != nil {
return err
}
return nil
}
// IssueKubeconfig returns a short-lived kubeconfig backed by a Kubernetes
// TokenRequest. The token exists only in the returned value and is never stored.
func (c *TenantClient) IssueKubeconfig(ctx context.Context, cluster *entity.Cluster, binding entity.TenantBinding, ttl time.Duration) (*entity.TenantKubeconfig, error) {
binding = binding.WithDefaults()
if err := binding.Validate(); err != nil {
return nil, err
}
clientset, restConfig, err := c.clientsetForCluster(cluster)
if err != nil {
return nil, err
}
cappedTTL := entity.TenantTokenTTL(ttl)
expirationSeconds := int64(cappedTTL.Seconds())
tokenRequest, err := clientset.CoreV1().
ServiceAccounts(binding.Namespace).
CreateToken(ctx, binding.ServiceAccountName, &authenticationv1.TokenRequest{
Spec: authenticationv1.TokenRequestSpec{
ExpirationSeconds: &expirationSeconds,
},
}, metav1.CreateOptions{})
if err != nil {
return nil, fmt.Errorf("failed to request tenant service account token: %w", err)
}
if tokenRequest.Status.Token == "" {
return nil, entity.ErrInvalidTenantKubeconfigToken
}
expiresAt := tokenRequest.Status.ExpirationTimestamp.Time
if expiresAt.IsZero() {
expiresAt = time.Now().Add(cappedTTL)
}
kubeconfig, err := buildTenantKubeconfig(cluster, restConfig, binding, tokenRequest.Status.Token)
if err != nil {
return nil, err
}
return &entity.TenantKubeconfig{
Kubeconfig: kubeconfig,
ExpiresAt: expiresAt,
}, nil
}
// SuspendTenant revokes tenant API access by deleting only the RoleBinding.
func (c *TenantClient) SuspendTenant(ctx context.Context, cluster *entity.Cluster, binding entity.TenantBinding) error {
binding = binding.WithDefaults()
if err := binding.Validate(); err != nil {
return err
}
clientset, _, err := c.clientsetForCluster(cluster)
if err != nil {
return err
}
err = clientset.RbacV1().
RoleBindings(binding.Namespace).
Delete(ctx, binding.RoleBindingName, metav1.DeleteOptions{})
if apierrors.IsNotFound(err) {
return nil
}
if err != nil {
return fmt.Errorf("failed to delete tenant role binding: %w", err)
}
return nil
}
func (c *TenantClient) clientsetForCluster(cluster *entity.Cluster) (kubernetes.Interface, *rest.Config, error) {
if c.clientset != nil {
config := &rest.Config{Host: "https://kubernetes.default.svc"}
if cluster != nil {
clusterConfig, err := restConfigFromCluster(cluster)
if err == nil {
config = clusterConfig
}
}
return c.clientset, config, nil
}
config, err := restConfigFromCluster(cluster)
if err != nil {
return nil, nil, err
}
clientset, err := kubernetes.NewForConfig(config)
if err != nil {
return nil, nil, fmt.Errorf("failed to create tenant kubernetes client: %w", err)
}
return clientset, config, nil
}
func restConfigFromCluster(cluster *entity.Cluster) (*rest.Config, error) {
if cluster == nil {
return nil, entity.ErrInvalidClusterHost
}
if looksLikeKubeconfig(cluster.CAData) {
config, err := clientcmd.RESTConfigFromKubeConfig([]byte(cluster.CAData))
if err != nil {
return nil, fmt.Errorf("failed to parse tenant kubeconfig: %w", err)
}
return config, nil
}
if strings.TrimSpace(cluster.Host) == "" {
return nil, entity.ErrInvalidClusterHost
}
return &rest.Config{
Host: cluster.Host,
TLSClientConfig: rest.TLSClientConfig{
CAData: decodePossiblyBase64(cluster.CAData),
CertData: decodePossiblyBase64(cluster.CertData),
KeyData: decodePossiblyBase64(cluster.KeyData),
},
BearerToken: cluster.Token,
}, nil
}
func (c *TenantClient) ensureNamespace(ctx context.Context, clientset kubernetes.Interface, binding entity.TenantBinding) error {
namespace := &corev1.Namespace{
ObjectMeta: metav1.ObjectMeta{
Name: binding.Namespace,
Labels: copyStringMap(binding.Labels),
Annotations: copyStringMap(binding.Annotations),
},
}
_, err := clientset.CoreV1().Namespaces().Create(ctx, namespace, metav1.CreateOptions{})
if apierrors.IsAlreadyExists(err) {
current, getErr := clientset.CoreV1().Namespaces().Get(ctx, binding.Namespace, metav1.GetOptions{})
if getErr != nil {
return fmt.Errorf("failed to get tenant namespace: %w", getErr)
}
mergeObjectMetadata(&current.ObjectMeta, binding.Labels, binding.Annotations)
if _, updateErr := clientset.CoreV1().Namespaces().Update(ctx, current, metav1.UpdateOptions{}); updateErr != nil {
return fmt.Errorf("failed to update tenant namespace: %w", updateErr)
}
return nil
}
if err != nil {
return fmt.Errorf("failed to create tenant namespace: %w", err)
}
return nil
}
func (c *TenantClient) ensureServiceAccount(ctx context.Context, clientset kubernetes.Interface, binding entity.TenantBinding) error {
serviceAccount := &corev1.ServiceAccount{
ObjectMeta: metav1.ObjectMeta{
Name: binding.ServiceAccountName,
Namespace: binding.Namespace,
Labels: copyStringMap(binding.Labels),
Annotations: copyStringMap(binding.Annotations),
},
}
_, err := clientset.CoreV1().ServiceAccounts(binding.Namespace).Create(ctx, serviceAccount, metav1.CreateOptions{})
if apierrors.IsAlreadyExists(err) {
current, getErr := clientset.CoreV1().ServiceAccounts(binding.Namespace).Get(ctx, binding.ServiceAccountName, metav1.GetOptions{})
if getErr != nil {
return fmt.Errorf("failed to get tenant service account: %w", getErr)
}
mergeObjectMetadata(&current.ObjectMeta, binding.Labels, binding.Annotations)
if _, updateErr := clientset.CoreV1().ServiceAccounts(binding.Namespace).Update(ctx, current, metav1.UpdateOptions{}); updateErr != nil {
return fmt.Errorf("failed to update tenant service account: %w", updateErr)
}
return nil
}
if err != nil {
return fmt.Errorf("failed to create tenant service account: %w", err)
}
return nil
}
func (c *TenantClient) ensureRoleBinding(ctx context.Context, clientset kubernetes.Interface, binding entity.TenantBinding) error {
roleBinding := desiredRoleBinding(binding)
_, err := clientset.RbacV1().RoleBindings(binding.Namespace).Create(ctx, roleBinding, metav1.CreateOptions{})
if apierrors.IsAlreadyExists(err) {
current, getErr := clientset.RbacV1().RoleBindings(binding.Namespace).Get(ctx, binding.RoleBindingName, metav1.GetOptions{})
if getErr != nil {
return fmt.Errorf("failed to get tenant role binding: %w", getErr)
}
mergeObjectMetadata(&current.ObjectMeta, binding.Labels, binding.Annotations)
current.Subjects = roleBinding.Subjects
current.RoleRef = roleBinding.RoleRef
if _, updateErr := clientset.RbacV1().RoleBindings(binding.Namespace).Update(ctx, current, metav1.UpdateOptions{}); updateErr != nil {
return fmt.Errorf("failed to update tenant role binding: %w", updateErr)
}
return nil
}
if err != nil {
return fmt.Errorf("failed to create tenant role binding: %w", err)
}
return nil
}
func (c *TenantClient) ensureResourceQuota(ctx context.Context, clientset kubernetes.Interface, binding entity.TenantBinding) error {
resourceQuota := &corev1.ResourceQuota{
ObjectMeta: metav1.ObjectMeta{
Name: binding.ResourceQuotaName,
Namespace: binding.Namespace,
Labels: copyStringMap(binding.Labels),
Annotations: copyStringMap(binding.Annotations),
},
Spec: corev1.ResourceQuotaSpec{
Hard: binding.ResourceQuotaHard.DeepCopy(),
},
}
_, err := clientset.CoreV1().ResourceQuotas(binding.Namespace).Create(ctx, resourceQuota, metav1.CreateOptions{})
if apierrors.IsAlreadyExists(err) {
current, getErr := clientset.CoreV1().ResourceQuotas(binding.Namespace).Get(ctx, binding.ResourceQuotaName, metav1.GetOptions{})
if getErr != nil {
return fmt.Errorf("failed to get tenant resource quota: %w", getErr)
}
mergeObjectMetadata(&current.ObjectMeta, binding.Labels, binding.Annotations)
current.Spec.Hard = binding.ResourceQuotaHard.DeepCopy()
if _, updateErr := clientset.CoreV1().ResourceQuotas(binding.Namespace).Update(ctx, current, metav1.UpdateOptions{}); updateErr != nil {
return fmt.Errorf("failed to update tenant resource quota: %w", updateErr)
}
return nil
}
if err != nil {
return fmt.Errorf("failed to create tenant resource quota: %w", err)
}
return nil
}
func desiredRoleBinding(binding entity.TenantBinding) *rbacv1.RoleBinding {
return &rbacv1.RoleBinding{
ObjectMeta: metav1.ObjectMeta{
Name: binding.RoleBindingName,
Namespace: binding.Namespace,
Labels: copyStringMap(binding.Labels),
Annotations: copyStringMap(binding.Annotations),
},
Subjects: []rbacv1.Subject{{
Kind: rbacv1.ServiceAccountKind,
Name: binding.ServiceAccountName,
Namespace: binding.Namespace,
}},
RoleRef: rbacv1.RoleRef{
APIGroup: rbacv1.GroupName,
Kind: "ClusterRole",
Name: binding.ClusterRoleName,
},
}
}
func buildTenantKubeconfig(cluster *entity.Cluster, restConfig *rest.Config, binding entity.TenantBinding, token string) (string, error) {
host := ""
var caData []byte
if restConfig != nil {
host = restConfig.Host
caData = append([]byte{}, restConfig.CAData...)
}
if host == "" && cluster != nil {
host = cluster.Host
}
if len(caData) == 0 && cluster != nil {
caData = decodePossiblyBase64(cluster.CAData)
}
if host == "" {
return "", entity.ErrInvalidClusterHost
}
clusterName := "tenant-cluster"
if cluster != nil && cluster.Name != "" {
clusterName = cluster.Name
}
userName := binding.ServiceAccountName
contextName := fmt.Sprintf("%s/%s", clusterName, binding.Namespace)
config := clientcmdapi.NewConfig()
config.Clusters[clusterName] = &clientcmdapi.Cluster{
Server: host,
CertificateAuthorityData: caData,
}
config.AuthInfos[userName] = &clientcmdapi.AuthInfo{
Token: token,
}
config.Contexts[contextName] = &clientcmdapi.Context{
Cluster: clusterName,
AuthInfo: userName,
Namespace: binding.Namespace,
}
config.CurrentContext = contextName
bytes, err := clientcmd.Write(*config)
if err != nil {
return "", fmt.Errorf("failed to build tenant kubeconfig: %w", err)
}
return string(bytes), nil
}
func mergeObjectMetadata(meta *metav1.ObjectMeta, labels, annotations map[string]string) {
if len(labels) > 0 && meta.Labels == nil {
meta.Labels = map[string]string{}
}
for key, value := range labels {
meta.Labels[key] = value
}
if len(annotations) > 0 && meta.Annotations == nil {
meta.Annotations = map[string]string{}
}
for key, value := range annotations {
meta.Annotations[key] = value
}
}
func copyStringMap(values map[string]string) map[string]string {
if len(values) == 0 {
return nil
}
copied := make(map[string]string, len(values))
for key, value := range values {
copied[key] = value
}
return copied
}
func decodePossiblyBase64(value string) []byte {
decoded, err := base64.StdEncoding.DecodeString(value)
if err == nil {
return decoded
}
return []byte(value)
}
func looksLikeKubeconfig(value string) bool {
trimmed := strings.TrimSpace(value)
return strings.HasPrefix(trimmed, "apiVersion:") || strings.HasPrefix(trimmed, "kind: Config")
}

View File

@ -0,0 +1,172 @@
package k8s
import (
"context"
"strings"
"testing"
"time"
authenticationv1 "k8s.io/api/authentication/v1"
corev1 "k8s.io/api/core/v1"
rbacv1 "k8s.io/api/rbac/v1"
apierrors "k8s.io/apimachinery/pkg/api/errors"
"k8s.io/apimachinery/pkg/api/resource"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/client-go/kubernetes/fake"
k8stesting "k8s.io/client-go/testing"
"github.com/ocdp/cluster-service/internal/domain/entity"
)
func TestTenantClientEnsureTenantCreatesResources(t *testing.T) {
ctx := context.Background()
clientset := fake.NewSimpleClientset()
client := NewTenantClientForClientset(clientset)
binding := tenantBinding()
if err := client.EnsureTenant(ctx, nil, binding); err != nil {
t.Fatalf("EnsureTenant returned error: %v", err)
}
if _, err := clientset.CoreV1().Namespaces().Get(ctx, binding.Namespace, metav1.GetOptions{}); err != nil {
t.Fatalf("expected namespace: %v", err)
}
if _, err := clientset.CoreV1().ServiceAccounts(binding.Namespace).Get(ctx, binding.ServiceAccountName, metav1.GetOptions{}); err != nil {
t.Fatalf("expected service account: %v", err)
}
roleBinding, err := clientset.RbacV1().RoleBindings(binding.Namespace).Get(ctx, binding.RoleBindingName, metav1.GetOptions{})
if err != nil {
t.Fatalf("expected role binding: %v", err)
}
if roleBinding.RoleRef.Kind != "ClusterRole" || roleBinding.RoleRef.Name != binding.ClusterRoleName {
t.Fatalf("unexpected role ref: %#v", roleBinding.RoleRef)
}
if len(roleBinding.Subjects) != 1 || roleBinding.Subjects[0].Name != binding.ServiceAccountName {
t.Fatalf("unexpected role binding subjects: %#v", roleBinding.Subjects)
}
quota, err := clientset.CoreV1().ResourceQuotas(binding.Namespace).Get(ctx, binding.ResourceQuotaName, metav1.GetOptions{})
if err != nil {
t.Fatalf("expected resource quota: %v", err)
}
if quota.Spec.Hard.Cpu().String() != "2" {
t.Fatalf("expected cpu quota 2, got %s", quota.Spec.Hard.Cpu().String())
}
}
func TestTenantClientEnsureTenantUpdatesExistingResources(t *testing.T) {
ctx := context.Background()
binding := tenantBinding()
clientset := fake.NewSimpleClientset(
&corev1.Namespace{ObjectMeta: metav1.ObjectMeta{Name: binding.Namespace}},
&corev1.ServiceAccount{ObjectMeta: metav1.ObjectMeta{Name: binding.ServiceAccountName, Namespace: binding.Namespace}},
&rbacv1.RoleBinding{
ObjectMeta: metav1.ObjectMeta{Name: binding.RoleBindingName, Namespace: binding.Namespace},
RoleRef: rbacv1.RoleRef{APIGroup: rbacv1.GroupName, Kind: "ClusterRole", Name: "view"},
},
&corev1.ResourceQuota{
ObjectMeta: metav1.ObjectMeta{Name: binding.ResourceQuotaName, Namespace: binding.Namespace},
Spec: corev1.ResourceQuotaSpec{Hard: corev1.ResourceList{
corev1.ResourceCPU: resource.MustParse("1"),
}},
},
)
client := NewTenantClientForClientset(clientset)
if err := client.EnsureTenant(ctx, nil, binding); err != nil {
t.Fatalf("EnsureTenant returned error: %v", err)
}
roleBinding, err := clientset.RbacV1().RoleBindings(binding.Namespace).Get(ctx, binding.RoleBindingName, metav1.GetOptions{})
if err != nil {
t.Fatalf("expected updated role binding: %v", err)
}
if roleBinding.RoleRef.Name != binding.ClusterRoleName {
t.Fatalf("expected role ref %q, got %q", binding.ClusterRoleName, roleBinding.RoleRef.Name)
}
if roleBinding.Labels["ocdp.io/tenant"] != binding.Namespace {
t.Fatalf("expected tenant label on updated role binding, got %#v", roleBinding.Labels)
}
quota, err := clientset.CoreV1().ResourceQuotas(binding.Namespace).Get(ctx, binding.ResourceQuotaName, metav1.GetOptions{})
if err != nil {
t.Fatalf("expected updated quota: %v", err)
}
if quota.Spec.Hard.Cpu().String() != "2" {
t.Fatalf("expected updated cpu quota 2, got %s", quota.Spec.Hard.Cpu().String())
}
}
func TestTenantClientSuspendTenantDeletesOnlyRoleBinding(t *testing.T) {
ctx := context.Background()
binding := tenantBinding()
clientset := fake.NewSimpleClientset(
&corev1.Namespace{ObjectMeta: metav1.ObjectMeta{Name: binding.Namespace}},
&corev1.ServiceAccount{ObjectMeta: metav1.ObjectMeta{Name: binding.ServiceAccountName, Namespace: binding.Namespace}},
desiredRoleBinding(binding),
)
client := NewTenantClientForClientset(clientset)
if err := client.SuspendTenant(ctx, nil, binding); err != nil {
t.Fatalf("SuspendTenant returned error: %v", err)
}
if _, err := clientset.RbacV1().RoleBindings(binding.Namespace).Get(ctx, binding.RoleBindingName, metav1.GetOptions{}); !apierrors.IsNotFound(err) {
t.Fatalf("expected deleted role binding, got err %v", err)
}
if _, err := clientset.CoreV1().ServiceAccounts(binding.Namespace).Get(ctx, binding.ServiceAccountName, metav1.GetOptions{}); err != nil {
t.Fatalf("service account should remain: %v", err)
}
}
func TestTenantClientIssueKubeconfigCapsTokenTTL(t *testing.T) {
ctx := context.Background()
binding := tenantBinding()
clientset := fake.NewSimpleClientset(&corev1.ServiceAccount{
ObjectMeta: metav1.ObjectMeta{Name: binding.ServiceAccountName, Namespace: binding.Namespace},
})
var requestedExpirationSeconds int64
expiresAt := time.Now().Add(entity.MaxTenantKubeconfigTTL).UTC()
clientset.Fake.PrependReactor("create", "serviceaccounts", func(action k8stesting.Action) (bool, runtime.Object, error) {
if action.GetSubresource() != "token" {
return false, nil, nil
}
createAction := action.(k8stesting.CreateAction)
tokenRequest := createAction.GetObject().(*authenticationv1.TokenRequest)
if tokenRequest.Spec.ExpirationSeconds != nil {
requestedExpirationSeconds = *tokenRequest.Spec.ExpirationSeconds
}
return true, &authenticationv1.TokenRequest{
Status: authenticationv1.TokenRequestStatus{
Token: "short-lived-token",
ExpirationTimestamp: metav1.NewTime(expiresAt),
},
}, nil
})
client := NewTenantClientForClientset(clientset)
kubeconfig, err := client.IssueKubeconfig(ctx, &entity.Cluster{Name: "test", Host: "https://example.invalid"}, binding, 24*time.Hour)
if err != nil {
t.Fatalf("IssueKubeconfig returned error: %v", err)
}
if requestedExpirationSeconds != int64(entity.MaxTenantKubeconfigTTL.Seconds()) {
t.Fatalf("expected capped ttl %d, got %d", int64(entity.MaxTenantKubeconfigTTL.Seconds()), requestedExpirationSeconds)
}
if !kubeconfig.ExpiresAt.Equal(expiresAt) {
t.Fatalf("expected expiration %s, got %s", expiresAt, kubeconfig.ExpiresAt)
}
if !strings.Contains(kubeconfig.Kubeconfig, "short-lived-token") {
t.Fatal("expected kubeconfig to contain issued token")
}
if !strings.Contains(kubeconfig.Kubeconfig, "namespace: tenant-a") {
t.Fatalf("expected kubeconfig namespace, got:\n%s", kubeconfig.Kubeconfig)
}
}
func tenantBinding() entity.TenantBinding {
binding := entity.NewTenantBinding("tenant-a")
binding.ResourceQuotaHard = corev1.ResourceList{
corev1.ResourceCPU: resource.MustParse("2"),
corev1.ResourceMemory: resource.MustParse("4Gi"),
}
return binding
}

View File

@ -0,0 +1,36 @@
package k8s
import (
"context"
"fmt"
"time"
"github.com/ocdp/cluster-service/internal/domain/entity"
"github.com/ocdp/cluster-service/internal/domain/repository"
)
type MockTenantClient struct{}
func NewMockTenantClient() repository.TenantKubeClient {
return &MockTenantClient{}
}
func (c *MockTenantClient) EnsureTenant(ctx context.Context, cluster *entity.Cluster, binding entity.TenantBinding) error {
return binding.Validate()
}
func (c *MockTenantClient) IssueKubeconfig(ctx context.Context, cluster *entity.Cluster, binding entity.TenantBinding, ttl time.Duration) (*entity.TenantKubeconfig, error) {
if err := binding.Validate(); err != nil {
return nil, err
}
expiresAt := time.Now().Add(entity.TenantTokenTTL(ttl))
return &entity.TenantKubeconfig{
Kubeconfig: fmt.Sprintf("apiVersion: v1\nkind: Config\nclusters:\n- name: %s\n cluster:\n server: %s\ncontexts:\n- name: %s\n context:\n cluster: %s\n namespace: %s\n user: %s\ncurrent-context: %s\nusers:\n- name: %s\n user:\n token: mock-ephemeral-token\n",
cluster.Name, cluster.Host, binding.Namespace, cluster.Name, binding.Namespace, binding.ServiceAccountName, binding.Namespace, binding.ServiceAccountName),
ExpiresAt: expiresAt,
}, nil
}
func (c *MockTenantClient) SuspendTenant(ctx context.Context, cluster *entity.Cluster, binding entity.TenantBinding) error {
return binding.Validate()
}

View File

@ -5,7 +5,7 @@ import (
"fmt"
"strings"
"time"
"github.com/ocdp/cluster-service/internal/domain/entity"
"github.com/ocdp/cluster-service/internal/domain/repository"
)
@ -13,7 +13,7 @@ import (
// OCIClientMock OCI Registry 客户端 Mock 实现
type OCIClientMock struct {
// Mock 数据存储
repositories map[string][]string // registryID -> []repositoryName
repositories map[string][]string // registryID -> []repositoryName
artifacts map[string]map[string][]*entity.Artifact // registryID -> repository -> []artifact
}
@ -23,10 +23,10 @@ func NewOCIClientMock() repository.OCIClient {
repositories: make(map[string][]string),
artifacts: make(map[string]map[string][]*entity.Artifact),
}
// 初始化一些测试数据
mock.initMockData()
return mock
}
@ -38,18 +38,18 @@ func (c *OCIClientMock) initMockData() {
// initArtifactsForRegistry initializes mock artifacts for a given registry ID
func (c *OCIClientMock) initArtifactsForRegistry(registryID string) {
c.artifacts[registryID] = make(map[string][]*entity.Artifact)
// vllm-serve artifacts (OCI 格式的 Helm Chart)
c.artifacts[registryID]["charts/vllm-serve"] = []*entity.Artifact{
{
RegistryID: registryID,
Repository: "charts/vllm-serve",
Tag: "0.1.0",
Digest: "sha256:abc123def456",
Type: entity.ArtifactTypeChart,
Size: 12345678,
MediaType: "application/vnd.oci.image.manifest.v1+json",
ConfigType: "application/vnd.cncf.helm.config.v1+json", // Helm Chart 的 config type
RegistryID: registryID,
Repository: "charts/vllm-serve",
Tag: "0.1.0",
Digest: "sha256:abc123def456",
Type: entity.ArtifactTypeChart,
Size: 12345678,
MediaType: "application/vnd.oci.image.manifest.v1+json",
ConfigType: "application/vnd.cncf.helm.config.v1+json", // Helm Chart 的 config type
Annotations: map[string]string{
"org.opencontainers.image.title": "vllm-serve",
"org.opencontainers.image.version": "0.1.0",
@ -57,14 +57,14 @@ func (c *OCIClientMock) initArtifactsForRegistry(registryID string) {
CreatedAt: time.Now().Add(-24 * time.Hour),
},
{
RegistryID: registryID,
Repository: "charts/vllm-serve",
Tag: "0.2.0",
Digest: "sha256:xyz789uvw012",
Type: entity.ArtifactTypeChart,
Size: 13456789,
MediaType: "application/vnd.oci.image.manifest.v1+json",
ConfigType: "application/vnd.cncf.helm.config.v1+json", // Helm Chart 的 config type
RegistryID: registryID,
Repository: "charts/vllm-serve",
Tag: "0.2.0",
Digest: "sha256:xyz789uvw012",
Type: entity.ArtifactTypeChart,
Size: 13456789,
MediaType: "application/vnd.oci.image.manifest.v1+json",
ConfigType: "application/vnd.cncf.helm.config.v1+json", // Helm Chart 的 config type
Annotations: map[string]string{
"org.opencontainers.image.title": "vllm-serve",
"org.opencontainers.image.version": "0.2.0",
@ -72,36 +72,36 @@ func (c *OCIClientMock) initArtifactsForRegistry(registryID string) {
CreatedAt: time.Now(),
},
}
// nginx artifacts (OCI 格式的 Helm Chart)
c.artifacts[registryID]["charts/nginx"] = []*entity.Artifact{
{
RegistryID: registryID,
Repository: "charts/nginx",
Tag: "1.0.0",
Digest: "sha256:nginx123456",
Type: entity.ArtifactTypeChart,
Size: 5678901,
MediaType: "application/vnd.oci.image.manifest.v1+json",
ConfigType: "application/vnd.cncf.helm.config.v1+json", // Helm Chart 的 config type
RegistryID: registryID,
Repository: "charts/nginx",
Tag: "1.0.0",
Digest: "sha256:nginx123456",
Type: entity.ArtifactTypeChart,
Size: 5678901,
MediaType: "application/vnd.oci.image.manifest.v1+json",
ConfigType: "application/vnd.cncf.helm.config.v1+json", // Helm Chart 的 config type
Annotations: map[string]string{
"org.opencontainers.image.title": "nginx",
},
CreatedAt: time.Now().Add(-48 * time.Hour),
},
}
// redis artifacts (OCI 格式的 Helm Chart)
c.artifacts[registryID]["charts/redis"] = []*entity.Artifact{
{
RegistryID: registryID,
Repository: "charts/redis",
Tag: "6.2.0",
Digest: "sha256:redis789abc",
Type: entity.ArtifactTypeChart,
Size: 8901234,
MediaType: "application/vnd.oci.image.manifest.v1+json",
ConfigType: "application/vnd.cncf.helm.config.v1+json", // Helm Chart 的 config type
RegistryID: registryID,
Repository: "charts/redis",
Tag: "6.2.0",
Digest: "sha256:redis789abc",
Type: entity.ArtifactTypeChart,
Size: 8901234,
MediaType: "application/vnd.oci.image.manifest.v1+json",
ConfigType: "application/vnd.cncf.helm.config.v1+json", // Helm Chart 的 config type
Annotations: map[string]string{
"org.opencontainers.image.title": "redis",
"org.opencontainers.image.version": "6.2.0",
@ -109,18 +109,18 @@ func (c *OCIClientMock) initArtifactsForRegistry(registryID string) {
CreatedAt: time.Now().Add(-72 * time.Hour),
},
}
// alpine artifacts (Docker Image)
c.artifacts[registryID]["library/alpine"] = []*entity.Artifact{
{
RegistryID: registryID,
Repository: "library/alpine",
Tag: "3.18",
Digest: "sha256:alpine123",
Type: entity.ArtifactTypeImage,
Size: 2345678,
MediaType: "application/vnd.docker.distribution.manifest.v2+json",
ConfigType: "application/vnd.docker.container.image.v1+json", // Docker Image 的 config type
RegistryID: registryID,
Repository: "library/alpine",
Tag: "3.18",
Digest: "sha256:alpine123",
Type: entity.ArtifactTypeImage,
Size: 2345678,
MediaType: "application/vnd.docker.distribution.manifest.v2+json",
ConfigType: "application/vnd.docker.container.image.v1+json", // Docker Image 的 config type
Annotations: map[string]string{
"org.opencontainers.image.title": "alpine",
"org.opencontainers.image.version": "3.18",
@ -128,14 +128,14 @@ func (c *OCIClientMock) initArtifactsForRegistry(registryID string) {
CreatedAt: time.Now().Add(-96 * time.Hour),
},
{
RegistryID: registryID,
Repository: "library/alpine",
Tag: "latest",
Digest: "sha256:alpine456",
Type: entity.ArtifactTypeImage,
Size: 2456789,
MediaType: "application/vnd.docker.distribution.manifest.v2+json",
ConfigType: "application/vnd.docker.container.image.v1+json", // Docker Image 的 config type
RegistryID: registryID,
Repository: "library/alpine",
Tag: "latest",
Digest: "sha256:alpine456",
Type: entity.ArtifactTypeImage,
Size: 2456789,
MediaType: "application/vnd.docker.distribution.manifest.v2+json",
ConfigType: "application/vnd.docker.container.image.v1+json", // Docker Image 的 config type
Annotations: map[string]string{
"org.opencontainers.image.title": "alpine",
},
@ -144,7 +144,7 @@ func (c *OCIClientMock) initArtifactsForRegistry(registryID string) {
}
}
func (c *OCIClientMock) ListRepositories(ctx context.Context, registry *entity.Registry) ([]string, error) {
func (c *OCIClientMock) ListRepositories(ctx context.Context, registry *entity.Registry, artifactType string) ([]string, error) {
// Check if we have cached data for this registry
repos, exists := c.repositories[registry.ID]
if !exists {
@ -156,10 +156,20 @@ func (c *OCIClientMock) ListRepositories(ctx context.Context, registry *entity.R
"library/alpine",
}
c.repositories[registry.ID] = repos
// Also initialize artifacts for this registry
c.initArtifactsForRegistry(registry.ID)
}
if strings.EqualFold(strings.TrimSpace(artifactType), "chart") {
chartRepos := make([]string, 0)
for _, repo := range repos {
artifacts, _ := c.ListArtifacts(ctx, registry, repo, "chart")
if len(artifacts) > 0 {
chartRepos = append(chartRepos, repo)
}
}
return chartRepos, nil
}
return repos, nil
}
@ -170,20 +180,20 @@ func (c *OCIClientMock) ListArtifacts(ctx context.Context, registry *entity.Regi
c.initArtifactsForRegistry(registry.ID)
regArtifacts = c.artifacts[registry.ID]
}
artifacts, exists := regArtifacts[repository]
if !exists {
return []*entity.Artifact{}, nil
}
// 应用 mediaType 过滤
if mediaTypeFilter == "" || mediaTypeFilter == "all" {
return artifacts, nil
}
filtered := make([]*entity.Artifact, 0)
filter := strings.ToLower(strings.TrimSpace(mediaTypeFilter))
for _, artifact := range artifacts {
switch filter {
case "chart":
@ -200,7 +210,7 @@ func (c *OCIClientMock) ListArtifacts(ctx context.Context, registry *entity.Regi
}
}
}
return filtered, nil
}
@ -211,19 +221,19 @@ func (c *OCIClientMock) GetArtifact(ctx context.Context, registry *entity.Regist
c.initArtifactsForRegistry(registry.ID)
regArtifacts = c.artifacts[registry.ID]
}
artifacts, exists := regArtifacts[repository]
if !exists {
return nil, entity.ErrArtifactNotFound
}
// 根据 tag 或 digest 查找
for _, artifact := range artifacts {
if artifact.Tag == reference || artifact.Digest == reference {
return artifact, nil
}
}
return nil, entity.ErrArtifactNotFound
}
@ -232,11 +242,11 @@ func (c *OCIClientMock) GetValuesSchema(ctx context.Context, registry *entity.Re
if err != nil {
return "", err
}
if !artifact.IsChart() {
return "", fmt.Errorf("not a helm chart")
}
// 返回 Mock values schema
mockSchema := `{
"$schema": "http://json-schema.org/draft-07/schema#",
@ -262,12 +272,23 @@ func (c *OCIClientMock) GetValuesSchema(ctx context.Context, registry *entity.Re
return mockSchema, nil
}
func (c *OCIClientMock) GetValuesYAML(ctx context.Context, registry *entity.Registry, repository, reference string) (string, error) {
artifact, err := c.GetArtifact(ctx, registry, repository, reference)
if err != nil {
return "", err
}
if !artifact.IsChart() {
return "", fmt.Errorf("not a helm chart")
}
return "replicaCount: 1\nimage:\n repository: nginx\n tag: latest\nservice:\n type: ClusterIP\n", nil
}
func (c *OCIClientMock) PullArtifact(ctx context.Context, registry *entity.Registry, repository, reference, destPath string) error {
_, err := c.GetArtifact(ctx, registry, repository, reference)
if err != nil {
return err
}
// Mock 实现,不实际下载
return nil
}
@ -281,4 +302,3 @@ func (c *OCIClientMock) CheckHealth(ctx context.Context, registry *entity.Regist
// Mock 实现,总是返回健康
return nil
}

View File

@ -8,9 +8,13 @@ import (
"fmt"
"io"
"net/http"
"net/url"
"os"
"path/filepath"
"sort"
"strconv"
"strings"
"time"
"github.com/ocdp/cluster-service/internal/domain/entity"
"github.com/ocdp/cluster-service/internal/domain/repository"
@ -25,6 +29,30 @@ type OCIClient struct {
httpClient *http.Client
}
type harborProject struct {
Name string `json:"name"`
}
type harborRepository struct {
Name string `json:"name"`
ArtifactCount int `json:"artifact_count"`
}
type harborTag struct {
Name string `json:"name"`
PushTime string `json:"push_time"`
}
type harborArtifact struct {
Digest string `json:"digest"`
MediaType string `json:"media_type"`
ArtifactType string `json:"artifact_type"`
Size int64 `json:"size"`
PushTime string `json:"push_time"`
Tags []harborTag `json:"tags"`
Annotations map[string]string `json:"annotations"`
}
// NewOCIClient 创建真实的 OCI 客户端
func NewOCIClient() repository.OCIClient {
return &OCIClient{
@ -60,8 +88,34 @@ func (c *OCIClient) getRegistry(reg *entity.Registry) (*remote.Registry, error)
return registry, nil
}
// ListRepositories 列出 Registry 中的所有 repositories
func (c *OCIClient) ListRepositories(ctx context.Context, registry *entity.Registry) ([]string, error) {
// ListRepositories 列出 Registry 中的 repositories.
// Harbor registry 优先使用 Harbor v2.0 API避免 robot 账号依赖 /v2/_catalog 全局权限。
func (c *OCIClient) ListRepositories(ctx context.Context, registry *entity.Registry, artifactType string) ([]string, error) {
repositories, harborErr := c.listHarborRepositories(ctx, registry, artifactType)
if harborErr == nil {
return repositories, nil
}
repositories, catalogErr := c.listOCIRepositories(ctx, registry)
if catalogErr != nil {
return nil, fmt.Errorf("failed to list repositories via Harbor API: %v; OCI catalog fallback also failed: %w", harborErr, catalogErr)
}
if strings.EqualFold(strings.TrimSpace(artifactType), "chart") {
chartRepos := make([]string, 0)
for _, repo := range repositories {
artifacts, err := c.ListArtifacts(ctx, registry, repo, "chart")
if err == nil && len(artifacts) > 0 {
chartRepos = append(chartRepos, repo)
}
}
return chartRepos, nil
}
return repositories, nil
}
func (c *OCIClient) listOCIRepositories(ctx context.Context, registry *entity.Registry) ([]string, error) {
reg, err := c.getRegistry(registry)
if err != nil {
return nil, err
@ -81,9 +135,278 @@ func (c *OCIClient) ListRepositories(ctx context.Context, registry *entity.Regis
return repositories, nil
}
func (c *OCIClient) listHarborRepositories(ctx context.Context, registry *entity.Registry, artifactType string) ([]string, error) {
projects, err := c.harborListProjects(ctx, registry)
if err != nil {
return nil, err
}
repositorySet := make(map[string]struct{})
chartOnly := strings.EqualFold(strings.TrimSpace(artifactType), "chart") || strings.TrimSpace(artifactType) == ""
for _, project := range projects {
projectName := strings.TrimSpace(project.Name)
if projectName == "" {
continue
}
repositories, err := c.harborListProjectRepositories(ctx, registry, projectName)
if err != nil {
return nil, err
}
for _, harborRepo := range repositories {
repoName := normalizeHarborRepositoryName(projectName, harborRepo.Name)
if repoName == "" {
continue
}
if chartOnly {
artifacts, err := c.listHarborArtifacts(ctx, registry, repoName, "chart")
if err != nil || len(artifacts) == 0 {
continue
}
}
repositorySet[repoName] = struct{}{}
}
}
repositories := make([]string, 0, len(repositorySet))
for repo := range repositorySet {
repositories = append(repositories, repo)
}
sort.Strings(repositories)
return repositories, nil
}
func (c *OCIClient) harborListProjects(ctx context.Context, registry *entity.Registry) ([]harborProject, error) {
var projects []harborProject
if err := c.harborGetPaged(ctx, registry, "/api/v2.0/projects", url.Values{"member": []string{"true"}}, &projects); err != nil {
return nil, err
}
return projects, nil
}
func (c *OCIClient) harborListProjectRepositories(ctx context.Context, registry *entity.Registry, projectName string) ([]harborRepository, error) {
var repositories []harborRepository
path := "/api/v2.0/projects/" + url.PathEscape(projectName) + "/repositories"
if err := c.harborGetPaged(ctx, registry, path, nil, &repositories); err != nil {
return nil, err
}
return repositories, nil
}
func (c *OCIClient) listHarborArtifacts(ctx context.Context, registry *entity.Registry, repository, mediaTypeFilter string) ([]*entity.Artifact, error) {
projectName, repoName, ok := splitHarborRepository(repository)
if !ok {
return nil, fmt.Errorf("repository %q is not a Harbor project repository path", repository)
}
var harborArtifacts []harborArtifact
path := "/api/v2.0/projects/" + url.PathEscape(projectName) + "/repositories/" + url.PathEscape(repoName) + "/artifacts"
query := url.Values{
"with_tag": []string{"true"},
"with_label": []string{"false"},
}
if err := c.harborGetPaged(ctx, registry, path, query, &harborArtifacts); err != nil {
return nil, err
}
artifacts := make([]*entity.Artifact, 0)
for _, harborArtifact := range harborArtifacts {
tags := harborArtifact.Tags
if len(tags) == 0 {
continue
}
for _, tag := range tags {
if strings.TrimSpace(tag.Name) == "" {
continue
}
artifact := &entity.Artifact{
Repository: repository,
Tag: tag.Name,
Digest: harborArtifact.Digest,
MediaType: harborArtifact.MediaType,
ConfigType: harborArtifact.ArtifactType,
Size: harborArtifact.Size,
Annotations: harborArtifact.Annotations,
CreatedAt: parseHarborTime(firstNonEmpty(tag.PushTime, harborArtifact.PushTime)),
}
if artifact.Annotations == nil {
artifact.Annotations = make(map[string]string)
}
artifact.DetermineType()
if isHarborChartArtifact(harborArtifact) {
artifact.Type = entity.ArtifactTypeChart
}
if c.shouldIncludeArtifact(artifact, mediaTypeFilter) {
artifacts = append(artifacts, artifact)
}
}
}
return artifacts, nil
}
func (c *OCIClient) harborGetPaged(ctx context.Context, registry *entity.Registry, path string, query url.Values, target interface{}) error {
const pageSize = 100
accumulated := make([]json.RawMessage, 0)
for page := 1; ; page++ {
pageQuery := cloneValues(query)
pageQuery.Set("page", fmt.Sprintf("%d", page))
pageQuery.Set("page_size", fmt.Sprintf("%d", pageSize))
body, total, err := c.harborGet(ctx, registry, path, pageQuery)
if err != nil {
return err
}
var pageItems []json.RawMessage
if err := json.Unmarshal(body, &pageItems); err != nil {
return fmt.Errorf("failed to decode Harbor response for %s: %w", path, err)
}
accumulated = append(accumulated, pageItems...)
if len(pageItems) < pageSize || (total >= 0 && len(accumulated) >= total) {
break
}
}
combined, err := json.Marshal(accumulated)
if err != nil {
return fmt.Errorf("failed to combine Harbor pages: %w", err)
}
if err := json.Unmarshal(combined, target); err != nil {
return fmt.Errorf("failed to decode Harbor pages: %w", err)
}
return nil
}
func (c *OCIClient) harborGet(ctx context.Context, registry *entity.Registry, path string, query url.Values) ([]byte, int, error) {
baseURL, err := harborBaseURL(registry)
if err != nil {
return nil, -1, err
}
requestURL := strings.TrimRight(baseURL, "/") + path
if len(query) > 0 {
requestURL += "?" + query.Encode()
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL, nil)
if err != nil {
return nil, -1, err
}
req.Header.Set("Accept", "application/json")
if registry.Username != "" || registry.Password != "" {
req.SetBasicAuth(registry.Username, registry.Password)
}
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, -1, fmt.Errorf("Harbor API request failed: %w", err)
}
defer resp.Body.Close()
body, readErr := io.ReadAll(io.LimitReader(resp.Body, 10*1024*1024))
if readErr != nil {
return nil, -1, fmt.Errorf("failed to read Harbor API response: %w", readErr)
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil, -1, fmt.Errorf("Harbor API %s returned %d: %s", path, resp.StatusCode, strings.TrimSpace(string(body)))
}
total := -1
if value := strings.TrimSpace(resp.Header.Get("X-Total-Count")); value != "" {
if parsed, err := strconv.Atoi(value); err == nil {
total = parsed
}
}
return body, total, nil
}
func harborBaseURL(registry *entity.Registry) (string, error) {
rawURL := strings.TrimSpace(registry.URL)
if rawURL == "" {
return "", fmt.Errorf("registry URL is empty")
}
if !strings.Contains(rawURL, "://") {
rawURL = "https://" + rawURL
}
parsed, err := url.Parse(rawURL)
if err != nil {
return "", fmt.Errorf("invalid registry URL %q: %w", registry.URL, err)
}
if parsed.Scheme == "" || parsed.Host == "" {
return "", fmt.Errorf("invalid registry URL %q", registry.URL)
}
return parsed.Scheme + "://" + parsed.Host, nil
}
func splitHarborRepository(repository string) (string, string, bool) {
projectName, repoName, ok := strings.Cut(strings.Trim(repository, "/"), "/")
if !ok || projectName == "" || repoName == "" {
return "", "", false
}
return projectName, repoName, true
}
func normalizeHarborRepositoryName(projectName, repositoryName string) string {
repositoryName = strings.Trim(repositoryName, "/")
if repositoryName == "" {
return ""
}
if strings.HasPrefix(repositoryName, projectName+"/") {
return repositoryName
}
return projectName + "/" + repositoryName
}
func isHarborChartArtifact(artifact harborArtifact) bool {
typeInfo := strings.ToLower(strings.TrimSpace(artifact.ArtifactType + " " + artifact.MediaType))
return strings.Contains(typeInfo, "chart") || strings.Contains(typeInfo, "helm")
}
func cloneValues(values url.Values) url.Values {
cloned := make(url.Values)
for key, items := range values {
cloned[key] = append([]string(nil), items...)
}
return cloned
}
func firstNonEmpty(values ...string) string {
for _, value := range values {
if strings.TrimSpace(value) != "" {
return value
}
}
return ""
}
func parseHarborTime(value string) time.Time {
value = strings.TrimSpace(value)
if value == "" {
return time.Time{}
}
for _, layout := range []string{time.RFC3339Nano, time.RFC3339, "2006-01-02T15:04:05.999999", "2006-01-02T15:04:05"} {
if parsed, err := time.Parse(layout, value); err == nil {
return parsed
}
}
return time.Time{}
}
// ListArtifacts 列出指定 repository 的所有 artifacts
// mediaTypeFilter: "all", "image", "chart", "other" - 使用模糊匹配过滤
func (c *OCIClient) ListArtifacts(ctx context.Context, registry *entity.Registry, repository, mediaTypeFilter string) ([]*entity.Artifact, error) {
if artifacts, err := c.listHarborArtifacts(ctx, registry, repository, mediaTypeFilter); err == nil {
return artifacts, nil
}
reg, err := c.getRegistry(registry)
if err != nil {
return nil, err
@ -370,6 +693,113 @@ func (c *OCIClient) GetValuesSchema(ctx context.Context, registry *entity.Regist
return "", entity.ErrValuesSchemaNotFound
}
// GetValuesYAML 获取 Helm Chart 包内原始 values.yaml
func (c *OCIClient) GetValuesYAML(ctx context.Context, registry *entity.Registry, repository, reference string) (string, error) {
data, err := c.readChartFile(ctx, registry, repository, reference, "values.yaml")
if err != nil {
return "", err
}
if strings.TrimSpace(data) == "" {
return "", entity.ErrArtifactNotFound
}
return data, nil
}
func (c *OCIClient) readChartFile(ctx context.Context, registry *entity.Registry, repository, reference, filename string) (string, error) {
reg, err := c.getRegistry(registry)
if err != nil {
return "", err
}
repo, err := reg.Repository(ctx, repository)
if err != nil {
return "", fmt.Errorf("failed to get repository: %w", err)
}
desc, err := repo.Resolve(ctx, reference)
if err != nil {
return "", fmt.Errorf("failed to resolve artifact: %w", err)
}
manifestReader, err := repo.Fetch(ctx, desc)
if err != nil {
return "", fmt.Errorf("failed to fetch manifest: %w", err)
}
defer manifestReader.Close()
manifestBytes, err := io.ReadAll(manifestReader)
if err != nil {
return "", fmt.Errorf("failed to read manifest: %w", err)
}
var manifest ocispec.Manifest
if err := json.Unmarshal(manifestBytes, &manifest); err != nil {
return "", fmt.Errorf("failed to unmarshal manifest: %w", err)
}
var chartLayer *ocispec.Descriptor
for i := range manifest.Layers {
layer := manifest.Layers[i]
if strings.Contains(layer.MediaType, "cncf.helm.chart") ||
strings.Contains(layer.MediaType, "helm.chart.content") {
chartLayer = &manifest.Layers[i]
break
}
}
if chartLayer == nil {
return "", fmt.Errorf("helm chart layer not found in manifest")
}
if chartLayer.Digest == "" {
return "", fmt.Errorf("chart layer digest is empty")
}
if _, err := digest.Parse(string(chartLayer.Digest)); err != nil {
return "", fmt.Errorf("invalid chart layer digest: %w", err)
}
layerReader, err := repo.Fetch(ctx, *chartLayer)
if err != nil {
return "", fmt.Errorf("failed to fetch chart layer: %w", err)
}
defer layerReader.Close()
gzipReader, err := gzip.NewReader(layerReader)
if err != nil {
return "", fmt.Errorf("failed to create gzip reader: %w", err)
}
defer gzipReader.Close()
tarReader := tar.NewReader(gzipReader)
bestDepth := int(^uint(0) >> 1)
var bestData []byte
for {
header, err := tarReader.Next()
if err == io.EOF {
break
}
if err != nil {
return "", fmt.Errorf("failed to read chart archive: %w", err)
}
if header.Typeflag != tar.TypeReg {
continue
}
if strings.HasSuffix(header.Name, filename) {
data, err := io.ReadAll(tarReader)
if err != nil {
return "", fmt.Errorf("failed to read %s: %w", filename, err)
}
depth := strings.Count(strings.Trim(header.Name, "/"), "/")
if depth < bestDepth {
bestDepth = depth
bestData = data
}
}
}
if len(bestData) > 0 {
return string(bestData), nil
}
return "", fmt.Errorf("%s not found in chart", filename)
}
// PullArtifact 下载 artifact 到本地
func (c *OCIClient) PullArtifact(ctx context.Context, registry *entity.Registry, repository, reference, destPath string) error {
reg, err := c.getRegistry(registry)

View File

@ -3,7 +3,7 @@ package mock
import (
"context"
"sync"
"github.com/ocdp/cluster-service/internal/domain/entity"
"github.com/ocdp/cluster-service/internal/domain/repository"
"github.com/ocdp/cluster-service/internal/pkg/crypto"
@ -27,21 +27,21 @@ func NewClusterRepositoryMock(encryptor crypto.Encryptor) repository.ClusterRepo
func (r *ClusterRepositoryMock) Create(ctx context.Context, cluster *entity.Cluster) error {
r.mu.Lock()
defer r.mu.Unlock()
// 检查名称是否已存在
for _, c := range r.clusters {
if c.Name == cluster.Name {
return entity.ErrClusterExists
}
}
// Mock 模式:如果没有提供认证信息,自动填充默认的 Mock 证书
if (cluster.CertData == "" || cluster.KeyData == "") && cluster.Token == "" {
cluster.CAData = "LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0tCk1vY2sgQ0EgQ2VydGlmaWNhdGUKLS0tLS1FTkQgQ0VSVElGSUNBVEUtLS0tLQ=="
cluster.CertData = "LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0tCk1vY2sgQ2xpZW50IENlcnRpZmljYXRlCi0tLS0tRU5EIENFUlRJRklDQVRFLS0tLS0="
cluster.KeyData = "LS0tLS1CRUdJTiBSU0EgUFJJVkFURSBLRVktLS0tLQpNb2NrIFByaXZhdGUgS2V5Ci0tLS0tRU5EIFJTQSBQUklWQVRFIEtFWS0tLS0t"
}
// 加密敏感数据后存储
encryptedCluster := r.encryptCluster(cluster)
r.clusters[cluster.ID] = encryptedCluster
@ -51,12 +51,12 @@ func (r *ClusterRepositoryMock) Create(ctx context.Context, cluster *entity.Clus
func (r *ClusterRepositoryMock) GetByID(ctx context.Context, id string) (*entity.Cluster, error) {
r.mu.RLock()
defer r.mu.RUnlock()
cluster, exists := r.clusters[id]
if !exists {
return nil, entity.ErrClusterNotFound
}
// 解密敏感数据后返回
return r.decryptCluster(cluster), nil
}
@ -64,25 +64,25 @@ func (r *ClusterRepositoryMock) GetByID(ctx context.Context, id string) (*entity
func (r *ClusterRepositoryMock) GetByName(ctx context.Context, name string) (*entity.Cluster, error) {
r.mu.RLock()
defer r.mu.RUnlock()
for _, cluster := range r.clusters {
if cluster.Name == name {
// 解密敏感数据后返回
return r.decryptCluster(cluster), nil
}
}
return nil, entity.ErrClusterNotFound
}
func (r *ClusterRepositoryMock) Update(ctx context.Context, cluster *entity.Cluster) error {
r.mu.Lock()
defer r.mu.Unlock()
if _, exists := r.clusters[cluster.ID]; !exists {
return entity.ErrClusterNotFound
}
// 加密敏感数据后存储
encryptedCluster := r.encryptCluster(cluster)
r.clusters[cluster.ID] = encryptedCluster
@ -92,11 +92,11 @@ func (r *ClusterRepositoryMock) Update(ctx context.Context, cluster *entity.Clus
func (r *ClusterRepositoryMock) Delete(ctx context.Context, id string) error {
r.mu.Lock()
defer r.mu.Unlock()
if _, exists := r.clusters[id]; !exists {
return entity.ErrClusterNotFound
}
delete(r.clusters, id)
return nil
}
@ -104,20 +104,20 @@ func (r *ClusterRepositoryMock) Delete(ctx context.Context, id string) error {
func (r *ClusterRepositoryMock) List(ctx context.Context) ([]*entity.Cluster, error) {
r.mu.RLock()
defer r.mu.RUnlock()
clusters := make([]*entity.Cluster, 0, len(r.clusters))
for _, cluster := range r.clusters {
// 解密敏感数据后返回
clusters = append(clusters, r.decryptCluster(cluster))
}
return clusters, nil
}
// encryptCluster 加密 Cluster 的敏感数据
func (r *ClusterRepositoryMock) encryptCluster(cluster *entity.Cluster) *entity.Cluster {
encrypted := *cluster // 复制
// 加密证书数据
if cluster.CAData != "" && !crypto.IsEncrypted(cluster.CAData) {
if encryptedData, err := r.encryptor.Encrypt(cluster.CAData); err == nil {
@ -139,14 +139,14 @@ func (r *ClusterRepositoryMock) encryptCluster(cluster *entity.Cluster) *entity.
encrypted.Token = encryptedData
}
}
return &encrypted
}
// decryptCluster 解密 Cluster 的敏感数据
func (r *ClusterRepositoryMock) decryptCluster(cluster *entity.Cluster) *entity.Cluster {
decrypted := *cluster // 复制
// 解密证书数据
if cluster.CAData != "" && crypto.IsEncrypted(cluster.CAData) {
if decryptedData, err := r.encryptor.Decrypt(cluster.CAData); err == nil {
@ -168,7 +168,6 @@ func (r *ClusterRepositoryMock) decryptCluster(cluster *entity.Cluster) *entity.
decrypted.Token = decryptedData
}
}
return &decrypted
}

View File

@ -3,7 +3,7 @@ package mock
import (
"context"
"sync"
"github.com/ocdp/cluster-service/internal/domain/entity"
"github.com/ocdp/cluster-service/internal/domain/repository"
)
@ -24,14 +24,14 @@ func NewInstanceRepositoryMock() repository.InstanceRepository {
func (r *InstanceRepositoryMock) Create(ctx context.Context, instance *entity.Instance) error {
r.mu.Lock()
defer r.mu.Unlock()
// 检查同一集群中名称是否已存在
for _, inst := range r.instances {
if inst.ClusterID == instance.ClusterID && inst.Name == instance.Name {
return entity.ErrInstanceExists
}
}
r.instances[instance.ID] = instance
return nil
}
@ -39,36 +39,36 @@ func (r *InstanceRepositoryMock) Create(ctx context.Context, instance *entity.In
func (r *InstanceRepositoryMock) GetByID(ctx context.Context, id string) (*entity.Instance, error) {
r.mu.RLock()
defer r.mu.RUnlock()
instance, exists := r.instances[id]
if !exists {
return nil, entity.ErrInstanceNotFound
}
return instance, nil
}
func (r *InstanceRepositoryMock) GetByClusterAndName(ctx context.Context, clusterID, name string) (*entity.Instance, error) {
r.mu.RLock()
defer r.mu.RUnlock()
for _, instance := range r.instances {
if instance.ClusterID == clusterID && instance.Name == name {
return instance, nil
}
}
return nil, entity.ErrInstanceNotFound
}
func (r *InstanceRepositoryMock) Update(ctx context.Context, instance *entity.Instance) error {
r.mu.Lock()
defer r.mu.Unlock()
if _, exists := r.instances[instance.ID]; !exists {
return entity.ErrInstanceNotFound
}
r.instances[instance.ID] = instance
return nil
}
@ -76,11 +76,11 @@ func (r *InstanceRepositoryMock) Update(ctx context.Context, instance *entity.In
func (r *InstanceRepositoryMock) Delete(ctx context.Context, id string) error {
r.mu.Lock()
defer r.mu.Unlock()
if _, exists := r.instances[id]; !exists {
return entity.ErrInstanceNotFound
}
delete(r.instances, id)
return nil
}
@ -88,26 +88,25 @@ func (r *InstanceRepositoryMock) Delete(ctx context.Context, id string) error {
func (r *InstanceRepositoryMock) ListByCluster(ctx context.Context, clusterID string) ([]*entity.Instance, error) {
r.mu.RLock()
defer r.mu.RUnlock()
instances := make([]*entity.Instance, 0)
for _, instance := range r.instances {
if instance.ClusterID == clusterID {
instances = append(instances, instance)
}
}
return instances, nil
}
func (r *InstanceRepositoryMock) List(ctx context.Context) ([]*entity.Instance, error) {
r.mu.RLock()
defer r.mu.RUnlock()
instances := make([]*entity.Instance, 0, len(r.instances))
for _, instance := range r.instances {
instances = append(instances, instance)
}
return instances, nil
}

View File

@ -3,7 +3,7 @@ package mock
import (
"context"
"sync"
"github.com/ocdp/cluster-service/internal/domain/entity"
"github.com/ocdp/cluster-service/internal/domain/repository"
"github.com/ocdp/cluster-service/internal/pkg/crypto"
@ -27,14 +27,14 @@ func NewRegistryRepositoryMock(encryptor crypto.Encryptor) repository.RegistryRe
func (r *RegistryRepositoryMock) Create(ctx context.Context, registry *entity.Registry) error {
r.mu.Lock()
defer r.mu.Unlock()
// 检查名称是否已存在
for _, reg := range r.registries {
if reg.Name == registry.Name {
return entity.ErrRegistryExists
}
}
// 加密敏感数据后存储
encryptedRegistry := r.encryptRegistry(registry)
r.registries[registry.ID] = encryptedRegistry
@ -44,12 +44,12 @@ func (r *RegistryRepositoryMock) Create(ctx context.Context, registry *entity.Re
func (r *RegistryRepositoryMock) GetByID(ctx context.Context, id string) (*entity.Registry, error) {
r.mu.RLock()
defer r.mu.RUnlock()
registry, exists := r.registries[id]
if !exists {
return nil, entity.ErrRegistryNotFound
}
// 解密敏感数据后返回
return r.decryptRegistry(registry), nil
}
@ -57,25 +57,25 @@ func (r *RegistryRepositoryMock) GetByID(ctx context.Context, id string) (*entit
func (r *RegistryRepositoryMock) GetByName(ctx context.Context, name string) (*entity.Registry, error) {
r.mu.RLock()
defer r.mu.RUnlock()
for _, registry := range r.registries {
if registry.Name == name {
// 解密敏感数据后返回
return r.decryptRegistry(registry), nil
}
}
return nil, entity.ErrRegistryNotFound
}
func (r *RegistryRepositoryMock) Update(ctx context.Context, registry *entity.Registry) error {
r.mu.Lock()
defer r.mu.Unlock()
if _, exists := r.registries[registry.ID]; !exists {
return entity.ErrRegistryNotFound
}
// 加密敏感数据后存储
encryptedRegistry := r.encryptRegistry(registry)
r.registries[registry.ID] = encryptedRegistry
@ -85,11 +85,11 @@ func (r *RegistryRepositoryMock) Update(ctx context.Context, registry *entity.Re
func (r *RegistryRepositoryMock) Delete(ctx context.Context, id string) error {
r.mu.Lock()
defer r.mu.Unlock()
if _, exists := r.registries[id]; !exists {
return entity.ErrRegistryNotFound
}
delete(r.registries, id)
return nil
}
@ -97,41 +97,40 @@ func (r *RegistryRepositoryMock) Delete(ctx context.Context, id string) error {
func (r *RegistryRepositoryMock) List(ctx context.Context) ([]*entity.Registry, error) {
r.mu.RLock()
defer r.mu.RUnlock()
registries := make([]*entity.Registry, 0, len(r.registries))
for _, registry := range r.registries {
// 解密敏感数据后返回
registries = append(registries, r.decryptRegistry(registry))
}
return registries, nil
}
// encryptRegistry 加密 Registry 的敏感数据
func (r *RegistryRepositoryMock) encryptRegistry(registry *entity.Registry) *entity.Registry {
encrypted := *registry // 复制
// 加密密码
if registry.Password != "" && !crypto.IsEncrypted(registry.Password) {
if encryptedPassword, err := r.encryptor.Encrypt(registry.Password); err == nil {
encrypted.Password = encryptedPassword
}
}
return &encrypted
}
// decryptRegistry 解密 Registry 的敏感数据
func (r *RegistryRepositoryMock) decryptRegistry(registry *entity.Registry) *entity.Registry {
decrypted := *registry // 复制
// 解密密码
if registry.Password != "" && crypto.IsEncrypted(registry.Password) {
if decryptedPassword, err := r.encryptor.Decrypt(registry.Password); err == nil {
decrypted.Password = decryptedPassword
}
}
return &decrypted
}

View File

@ -3,7 +3,7 @@ package mock
import (
"context"
"sync"
"github.com/ocdp/cluster-service/internal/domain/entity"
"github.com/ocdp/cluster-service/internal/domain/repository"
)
@ -24,14 +24,14 @@ func NewUserRepositoryMock() repository.UserRepository {
func (r *UserRepositoryMock) Create(ctx context.Context, user *entity.User) error {
r.mu.Lock()
defer r.mu.Unlock()
// 检查是否已存在
for _, u := range r.users {
if u.Username == user.Username {
return entity.ErrUserExists
}
}
r.users[user.ID] = user
return nil
}
@ -39,36 +39,36 @@ func (r *UserRepositoryMock) Create(ctx context.Context, user *entity.User) erro
func (r *UserRepositoryMock) GetByID(ctx context.Context, id string) (*entity.User, error) {
r.mu.RLock()
defer r.mu.RUnlock()
user, exists := r.users[id]
if !exists {
return nil, entity.ErrUserNotFound
}
return user, nil
}
func (r *UserRepositoryMock) GetByUsername(ctx context.Context, username string) (*entity.User, error) {
r.mu.RLock()
defer r.mu.RUnlock()
for _, user := range r.users {
if user.Username == username {
return user, nil
}
}
return nil, entity.ErrUserNotFound
}
func (r *UserRepositoryMock) Update(ctx context.Context, user *entity.User) error {
r.mu.Lock()
defer r.mu.Unlock()
if _, exists := r.users[user.ID]; !exists {
return entity.ErrUserNotFound
}
r.users[user.ID] = user
return nil
}
@ -76,11 +76,11 @@ func (r *UserRepositoryMock) Update(ctx context.Context, user *entity.User) erro
func (r *UserRepositoryMock) Delete(ctx context.Context, id string) error {
r.mu.Lock()
defer r.mu.Unlock()
if _, exists := r.users[id]; !exists {
return entity.ErrUserNotFound
}
delete(r.users, id)
return nil
}
@ -88,12 +88,11 @@ func (r *UserRepositoryMock) Delete(ctx context.Context, id string) error {
func (r *UserRepositoryMock) List(ctx context.Context) ([]*entity.User, error) {
r.mu.RLock()
defer r.mu.RUnlock()
users := make([]*entity.User, 0, len(r.users))
for _, user := range r.users {
users = append(users, user)
}
return users, nil
}

View File

@ -0,0 +1,162 @@
package mock
import (
"context"
"sync"
"github.com/google/uuid"
"github.com/ocdp/cluster-service/internal/domain/entity"
"github.com/ocdp/cluster-service/internal/domain/repository"
)
type WorkspaceRepositoryMock struct {
mu sync.RWMutex
workspaces map[string]*entity.Workspace
}
func NewWorkspaceRepositoryMock() repository.WorkspaceRepository {
repo := &WorkspaceRepositoryMock{workspaces: make(map[string]*entity.Workspace)}
defaultWorkspace := entity.NewWorkspace(entity.DefaultWorkspaceName, "")
defaultWorkspace.ID = entity.DefaultWorkspaceID
repo.workspaces[defaultWorkspace.ID] = defaultWorkspace
return repo
}
func (r *WorkspaceRepositoryMock) Create(ctx context.Context, workspace *entity.Workspace) error {
r.mu.Lock()
defer r.mu.Unlock()
if workspace.ID == "" {
workspace.ID = uuid.New().String()
}
for _, existing := range r.workspaces {
if existing.Name == workspace.Name {
return entity.ErrWorkspaceExists
}
}
copy := *workspace
r.workspaces[workspace.ID] = &copy
return nil
}
func (r *WorkspaceRepositoryMock) GetByID(ctx context.Context, id string) (*entity.Workspace, error) {
r.mu.RLock()
defer r.mu.RUnlock()
workspace, ok := r.workspaces[id]
if !ok {
return nil, entity.ErrWorkspaceNotFound
}
copy := *workspace
return &copy, nil
}
func (r *WorkspaceRepositoryMock) GetByName(ctx context.Context, name string) (*entity.Workspace, error) {
r.mu.RLock()
defer r.mu.RUnlock()
for _, workspace := range r.workspaces {
if workspace.Name == name {
copy := *workspace
return &copy, nil
}
}
return nil, entity.ErrWorkspaceNotFound
}
func (r *WorkspaceRepositoryMock) Update(ctx context.Context, workspace *entity.Workspace) error {
r.mu.Lock()
defer r.mu.Unlock()
if _, ok := r.workspaces[workspace.ID]; !ok {
return entity.ErrWorkspaceNotFound
}
copy := *workspace
r.workspaces[workspace.ID] = &copy
return nil
}
func (r *WorkspaceRepositoryMock) List(ctx context.Context) ([]*entity.Workspace, error) {
r.mu.RLock()
defer r.mu.RUnlock()
result := make([]*entity.Workspace, 0, len(r.workspaces))
for _, workspace := range r.workspaces {
copy := *workspace
result = append(result, &copy)
}
return result, nil
}
type WorkspaceClusterBindingRepositoryMock struct {
mu sync.RWMutex
bindings map[string]*entity.WorkspaceClusterBinding
}
func NewWorkspaceClusterBindingRepositoryMock() repository.WorkspaceClusterBindingRepository {
return &WorkspaceClusterBindingRepositoryMock{bindings: make(map[string]*entity.WorkspaceClusterBinding)}
}
func bindingKey(workspaceID, clusterID string) string {
return workspaceID + "/" + clusterID
}
func (r *WorkspaceClusterBindingRepositoryMock) Upsert(ctx context.Context, binding *entity.WorkspaceClusterBinding) error {
r.mu.Lock()
defer r.mu.Unlock()
if binding.ID == "" {
binding.ID = uuid.New().String()
}
copy := *binding
r.bindings[bindingKey(binding.WorkspaceID, binding.ClusterID)] = &copy
return nil
}
func (r *WorkspaceClusterBindingRepositoryMock) Get(ctx context.Context, workspaceID, clusterID string) (*entity.WorkspaceClusterBinding, error) {
r.mu.RLock()
defer r.mu.RUnlock()
binding, ok := r.bindings[bindingKey(workspaceID, clusterID)]
if !ok {
return nil, entity.ErrWorkspaceNotFound
}
copy := *binding
return &copy, nil
}
func (r *WorkspaceClusterBindingRepositoryMock) Delete(ctx context.Context, workspaceID, clusterID string) error {
r.mu.Lock()
defer r.mu.Unlock()
delete(r.bindings, bindingKey(workspaceID, clusterID))
return nil
}
type AuditLogRepositoryMock struct {
mu sync.RWMutex
logs []*entity.AuditLog
}
func NewAuditLogRepositoryMock() repository.AuditLogRepository {
return &AuditLogRepositoryMock{logs: make([]*entity.AuditLog, 0)}
}
func (r *AuditLogRepositoryMock) Create(ctx context.Context, logEntry *entity.AuditLog) error {
r.mu.Lock()
defer r.mu.Unlock()
if logEntry.ID == "" {
logEntry.ID = uuid.New().String()
}
copy := *logEntry
r.logs = append(r.logs, &copy)
return nil
}
func (r *AuditLogRepositoryMock) ListByWorkspace(ctx context.Context, workspaceID string, limit int) ([]*entity.AuditLog, error) {
r.mu.RLock()
defer r.mu.RUnlock()
result := make([]*entity.AuditLog, 0)
for i := len(r.logs) - 1; i >= 0; i-- {
if r.logs[i].WorkspaceID == workspaceID {
copy := *r.logs[i]
result = append(result, &copy)
if limit > 0 && len(result) >= limit {
break
}
}
}
return result, nil
}

View File

@ -12,54 +12,33 @@ import (
"github.com/ocdp/cluster-service/internal/pkg/crypto"
)
// ClusterRepository PostgreSQL 集群仓储实现
type ClusterRepository struct {
db *DB
encryptor crypto.Encryptor
}
// NewClusterRepository 创建 PostgreSQL 集群仓储
func NewClusterRepository(db *DB, encryptor crypto.Encryptor) repository.ClusterRepository {
return &ClusterRepository{
db: db,
encryptor: encryptor,
}
return &ClusterRepository{db: db, encryptor: encryptor}
}
// Create 创建集群
func (r *ClusterRepository) Create(ctx context.Context, cluster *entity.Cluster) error {
if cluster.ID == "" {
cluster.ID = uuid.New().String()
}
// 加密敏感数据
encryptedCAData, err := r.encryptor.Encrypt(cluster.CAData)
encryptedCAData, encryptedCertData, encryptedKeyData, encryptedToken, err := r.encryptClusterSecrets(cluster)
if err != nil {
return fmt.Errorf("failed to encrypt CA data: %w", err)
return err
}
encryptedCertData, err := r.encryptor.Encrypt(cluster.CertData)
if err != nil {
return fmt.Errorf("failed to encrypt cert data: %w", err)
}
encryptedKeyData, err := r.encryptor.Encrypt(cluster.KeyData)
if err != nil {
return fmt.Errorf("failed to encrypt key data: %w", err)
}
encryptedToken, err := r.encryptor.Encrypt(cluster.Token)
if err != nil {
return fmt.Errorf("failed to encrypt token: %w", err)
}
query := `
INSERT INTO clusters (id, name, host, ca_data, cert_data, key_data, token, description, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
INSERT INTO clusters
(id, workspace_id, owner_id, visibility, name, host, ca_data, cert_data, key_data, token, description, default_namespace, created_at, updated_at)
VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14)
`
_, err = r.db.conn.ExecContext(ctx, query,
cluster.ID,
cluster.WorkspaceID,
cluster.OwnerID,
cluster.Visibility,
cluster.Name,
cluster.Host,
encryptedCAData,
@ -67,160 +46,62 @@ func (r *ClusterRepository) Create(ctx context.Context, cluster *entity.Cluster)
encryptedKeyData,
encryptedToken,
cluster.Description,
cluster.DefaultNamespace,
cluster.CreatedAt,
cluster.UpdatedAt,
)
if err != nil {
return fmt.Errorf("failed to create cluster: %w", err)
}
return nil
}
// GetByID 根据 ID 获取集群
func (r *ClusterRepository) GetByID(ctx context.Context, id string) (*entity.Cluster, error) {
query := `
SELECT id, name, host, ca_data, cert_data, key_data, token, description, created_at, updated_at
FROM clusters
WHERE id = $1
`
cluster := &entity.Cluster{}
var encryptedCAData, encryptedCertData, encryptedKeyData, encryptedToken string
err := r.db.conn.QueryRowContext(ctx, query, id).Scan(
&cluster.ID,
&cluster.Name,
&cluster.Host,
&encryptedCAData,
&encryptedCertData,
&encryptedKeyData,
&encryptedToken,
&cluster.Description,
&cluster.CreatedAt,
&cluster.UpdatedAt,
)
if err == sql.ErrNoRows {
return nil, entity.ErrClusterNotFound
}
if err != nil {
return nil, fmt.Errorf("failed to get cluster: %w", err)
}
// 解密敏感数据
cluster.CAData, err = r.encryptor.Decrypt(encryptedCAData)
if err != nil {
return nil, fmt.Errorf("failed to decrypt CA data: %w", err)
}
cluster.CertData, err = r.encryptor.Decrypt(encryptedCertData)
if err != nil {
return nil, fmt.Errorf("failed to decrypt cert data: %w", err)
}
cluster.KeyData, err = r.encryptor.Decrypt(encryptedKeyData)
if err != nil {
return nil, fmt.Errorf("failed to decrypt key data: %w", err)
}
cluster.Token, err = r.encryptor.Decrypt(encryptedToken)
if err != nil {
return nil, fmt.Errorf("failed to decrypt token: %w", err)
}
return cluster, nil
return r.get(ctx, "id = $1", id)
}
// GetByName 根据名称获取集群
func (r *ClusterRepository) GetByName(ctx context.Context, name string) (*entity.Cluster, error) {
query := `
SELECT id, name, host, ca_data, cert_data, key_data, token, description, created_at, updated_at
return r.get(ctx, "name = $1", name)
}
func (r *ClusterRepository) get(ctx context.Context, where string, arg interface{}) (*entity.Cluster, error) {
query := fmt.Sprintf(`
SELECT id, workspace_id, owner_id, visibility, name, host, ca_data, cert_data, key_data, token, description, default_namespace, created_at, updated_at
FROM clusters
WHERE name = $1
`
cluster := &entity.Cluster{}
var encryptedCAData, encryptedCertData, encryptedKeyData, encryptedToken string
err := r.db.conn.QueryRowContext(ctx, query, name).Scan(
&cluster.ID,
&cluster.Name,
&cluster.Host,
&encryptedCAData,
&encryptedCertData,
&encryptedKeyData,
&encryptedToken,
&cluster.Description,
&cluster.CreatedAt,
&cluster.UpdatedAt,
)
if err == sql.ErrNoRows {
return nil, entity.ErrClusterNotFound
}
WHERE %s
`, where)
rows, err := r.db.conn.QueryContext(ctx, query, arg)
if err != nil {
return nil, fmt.Errorf("failed to get cluster: %w", err)
}
// 解密敏感数据
cluster.CAData, err = r.encryptor.Decrypt(encryptedCAData)
if err != nil {
return nil, fmt.Errorf("failed to decrypt CA data: %w", err)
defer rows.Close()
if !rows.Next() {
return nil, entity.ErrClusterNotFound
}
cluster.CertData, err = r.encryptor.Decrypt(encryptedCertData)
cluster, err := r.scanCluster(rows)
if err != nil {
return nil, fmt.Errorf("failed to decrypt cert data: %w", err)
return nil, err
}
cluster.KeyData, err = r.encryptor.Decrypt(encryptedKeyData)
if err != nil {
return nil, fmt.Errorf("failed to decrypt key data: %w", err)
}
cluster.Token, err = r.encryptor.Decrypt(encryptedToken)
if err != nil {
return nil, fmt.Errorf("failed to decrypt token: %w", err)
}
return cluster, nil
}
// Update 更新集群
func (r *ClusterRepository) Update(ctx context.Context, cluster *entity.Cluster) error {
cluster.UpdatedAt = time.Now()
// 加密敏感数据
encryptedCAData, err := r.encryptor.Encrypt(cluster.CAData)
encryptedCAData, encryptedCertData, encryptedKeyData, encryptedToken, err := r.encryptClusterSecrets(cluster)
if err != nil {
return fmt.Errorf("failed to encrypt CA data: %w", err)
return err
}
encryptedCertData, err := r.encryptor.Encrypt(cluster.CertData)
if err != nil {
return fmt.Errorf("failed to encrypt cert data: %w", err)
}
encryptedKeyData, err := r.encryptor.Encrypt(cluster.KeyData)
if err != nil {
return fmt.Errorf("failed to encrypt key data: %w", err)
}
encryptedToken, err := r.encryptor.Encrypt(cluster.Token)
if err != nil {
return fmt.Errorf("failed to encrypt token: %w", err)
}
query := `
UPDATE clusters
SET name = $1, host = $2, ca_data = $3, cert_data = $4, key_data = $5,
token = $6, description = $7, updated_at = $8
WHERE id = $9
SET workspace_id = $1, owner_id = $2, visibility = $3, name = $4, host = $5,
ca_data = $6, cert_data = $7, key_data = $8, token = $9, description = $10,
default_namespace = $11, updated_at = $12
WHERE id = $13
`
result, err := r.db.conn.ExecContext(ctx, query,
cluster.WorkspaceID,
cluster.OwnerID,
cluster.Visibility,
cluster.Name,
cluster.Host,
encryptedCAData,
@ -228,110 +109,134 @@ func (r *ClusterRepository) Update(ctx context.Context, cluster *entity.Cluster)
encryptedKeyData,
encryptedToken,
cluster.Description,
cluster.DefaultNamespace,
cluster.UpdatedAt,
cluster.ID,
)
if err != nil {
return fmt.Errorf("failed to update cluster: %w", err)
}
rows, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get affected rows: %w", err)
}
if rows == 0 {
return entity.ErrClusterNotFound
}
return nil
}
// Delete 删除集群
func (r *ClusterRepository) Delete(ctx context.Context, id string) error {
query := `DELETE FROM clusters WHERE id = $1`
result, err := r.db.conn.ExecContext(ctx, query, id)
result, err := r.db.conn.ExecContext(ctx, `DELETE FROM clusters WHERE id = $1`, id)
if err != nil {
return fmt.Errorf("failed to delete cluster: %w", err)
}
rows, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get affected rows: %w", err)
}
if rows == 0 {
return entity.ErrClusterNotFound
}
return nil
}
// List 列出所有集群
func (r *ClusterRepository) List(ctx context.Context) ([]*entity.Cluster, error) {
query := `
SELECT id, name, host, ca_data, cert_data, key_data, token, description, created_at, updated_at
SELECT id, workspace_id, owner_id, visibility, name, host, ca_data, cert_data, key_data, token, description, default_namespace, created_at, updated_at
FROM clusters
ORDER BY created_at DESC
`
rows, err := r.db.conn.QueryContext(ctx, query)
if err != nil {
return nil, fmt.Errorf("failed to list clusters: %w", err)
}
defer rows.Close()
clusters := make([]*entity.Cluster, 0)
for rows.Next() {
cluster := &entity.Cluster{}
var encryptedCAData, encryptedCertData, encryptedKeyData, encryptedToken string
err := rows.Scan(
&cluster.ID,
&cluster.Name,
&cluster.Host,
&encryptedCAData,
&encryptedCertData,
&encryptedKeyData,
&encryptedToken,
&cluster.Description,
&cluster.CreatedAt,
&cluster.UpdatedAt,
)
cluster, err := r.scanCluster(rows)
if err != nil {
return nil, fmt.Errorf("failed to scan cluster: %w", err)
return nil, err
}
// 解密敏感数据
cluster.CAData, err = r.encryptor.Decrypt(encryptedCAData)
if err != nil {
return nil, fmt.Errorf("failed to decrypt CA data: %w", err)
}
cluster.CertData, err = r.encryptor.Decrypt(encryptedCertData)
if err != nil {
return nil, fmt.Errorf("failed to decrypt cert data: %w", err)
}
cluster.KeyData, err = r.encryptor.Decrypt(encryptedKeyData)
if err != nil {
return nil, fmt.Errorf("failed to decrypt key data: %w", err)
}
cluster.Token, err = r.encryptor.Decrypt(encryptedToken)
if err != nil {
return nil, fmt.Errorf("failed to decrypt token: %w", err)
}
clusters = append(clusters, cluster)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("rows iteration error: %w", err)
}
return clusters, nil
}
type clusterScanner interface {
Scan(dest ...interface{}) error
}
func (r *ClusterRepository) scanCluster(scanner clusterScanner) (*entity.Cluster, error) {
cluster := &entity.Cluster{}
var encryptedCAData, encryptedCertData, encryptedKeyData, encryptedToken sql.NullString
var defaultNamespace sql.NullString
err := scanner.Scan(
&cluster.ID,
&cluster.WorkspaceID,
&cluster.OwnerID,
&cluster.Visibility,
&cluster.Name,
&cluster.Host,
&encryptedCAData,
&encryptedCertData,
&encryptedKeyData,
&encryptedToken,
&cluster.Description,
&defaultNamespace,
&cluster.CreatedAt,
&cluster.UpdatedAt,
)
if err != nil {
return nil, fmt.Errorf("failed to scan cluster: %w", err)
}
cluster.DefaultNamespace = defaultNamespace.String
var decryptErr error
cluster.CAData, decryptErr = decryptMaybe(r.encryptor, encryptedCAData.String)
if decryptErr != nil {
return nil, fmt.Errorf("failed to decrypt CA data: %w", decryptErr)
}
cluster.CertData, decryptErr = decryptMaybe(r.encryptor, encryptedCertData.String)
if decryptErr != nil {
return nil, fmt.Errorf("failed to decrypt cert data: %w", decryptErr)
}
cluster.KeyData, decryptErr = decryptMaybe(r.encryptor, encryptedKeyData.String)
if decryptErr != nil {
return nil, fmt.Errorf("failed to decrypt key data: %w", decryptErr)
}
cluster.Token, decryptErr = decryptMaybe(r.encryptor, encryptedToken.String)
if decryptErr != nil {
return nil, fmt.Errorf("failed to decrypt token: %w", decryptErr)
}
return cluster, nil
}
func (r *ClusterRepository) encryptClusterSecrets(cluster *entity.Cluster) (string, string, string, string, error) {
ca, err := r.encryptor.Encrypt(cluster.CAData)
if err != nil {
return "", "", "", "", fmt.Errorf("failed to encrypt CA data: %w", err)
}
cert, err := r.encryptor.Encrypt(cluster.CertData)
if err != nil {
return "", "", "", "", fmt.Errorf("failed to encrypt cert data: %w", err)
}
key, err := r.encryptor.Encrypt(cluster.KeyData)
if err != nil {
return "", "", "", "", fmt.Errorf("failed to encrypt key data: %w", err)
}
token, err := r.encryptor.Encrypt(cluster.Token)
if err != nil {
return "", "", "", "", fmt.Errorf("failed to encrypt token: %w", err)
}
return ca, cert, key, token, nil
}
func decryptMaybe(encryptor crypto.Encryptor, value string) (string, error) {
if value == "" {
return "", nil
}
return encryptor.Decrypt(value)
}

View File

@ -53,21 +53,69 @@ func (db *DB) GetConn() *sql.DB {
// InitSchema 初始化数据库 schema
func (db *DB) InitSchema() error {
schema := `
-- Workspaces 表
CREATE TABLE IF NOT EXISTS workspaces (
id VARCHAR(36) PRIMARY KEY,
name VARCHAR(255) NOT NULL UNIQUE,
status VARCHAR(50) NOT NULL DEFAULT 'active',
k8s_namespace VARCHAR(255) NOT NULL,
k8s_sa_name VARCHAR(255) NOT NULL,
default_cluster_id VARCHAR(36),
quota_cpu VARCHAR(50),
quota_memory VARCHAR(50),
quota_gpu VARCHAR(50),
quota_gpu_memory VARCHAR(50),
created_by VARCHAR(36),
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);
ALTER TABLE workspaces
ADD COLUMN IF NOT EXISTS default_cluster_id VARCHAR(36),
ADD COLUMN IF NOT EXISTS quota_cpu VARCHAR(50),
ADD COLUMN IF NOT EXISTS quota_memory VARCHAR(50),
ADD COLUMN IF NOT EXISTS quota_gpu VARCHAR(50),
ADD COLUMN IF NOT EXISTS quota_gpu_memory VARCHAR(50);
INSERT INTO workspaces (id, name, status, k8s_namespace, k8s_sa_name, created_at, updated_at)
VALUES ('00000000-0000-0000-0000-000000000010', 'default', 'active', 'ocdp-ws-default', 'ocdp-ws-default', CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
ON CONFLICT (id) DO NOTHING;
-- Users 表
CREATE TABLE IF NOT EXISTS users (
id VARCHAR(36) PRIMARY KEY,
username VARCHAR(255) NOT NULL UNIQUE,
password_hash TEXT NOT NULL,
email VARCHAR(255) NOT NULL,
role VARCHAR(50) NOT NULL DEFAULT 'user',
workspace_id VARCHAR(36) NOT NULL DEFAULT '00000000-0000-0000-0000-000000000010',
is_active BOOLEAN NOT NULL DEFAULT TRUE,
must_change_password BOOLEAN NOT NULL DEFAULT FALSE,
revoked_after TIMESTAMP NOT NULL DEFAULT '1970-01-01 00:00:00',
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);
ALTER TABLE users
ADD COLUMN IF NOT EXISTS role VARCHAR(50) NOT NULL DEFAULT 'user',
ADD COLUMN IF NOT EXISTS workspace_id VARCHAR(36) NOT NULL DEFAULT '00000000-0000-0000-0000-000000000010',
ADD COLUMN IF NOT EXISTS is_active BOOLEAN NOT NULL DEFAULT TRUE,
ADD COLUMN IF NOT EXISTS must_change_password BOOLEAN NOT NULL DEFAULT FALSE,
ADD COLUMN IF NOT EXISTS revoked_after TIMESTAMP NOT NULL DEFAULT '1970-01-01 00:00:00';
UPDATE users SET role = 'admin' WHERE username = 'admin';
UPDATE users SET workspace_id = '00000000-0000-0000-0000-000000000010' WHERE workspace_id = '';
CREATE INDEX IF NOT EXISTS idx_users_username ON users(username);
CREATE INDEX IF NOT EXISTS idx_users_workspace ON users(workspace_id);
CREATE INDEX IF NOT EXISTS idx_users_revoked_after ON users(revoked_after);
-- Clusters 表
CREATE TABLE IF NOT EXISTS clusters (
id VARCHAR(36) PRIMARY KEY,
workspace_id VARCHAR(36) NOT NULL DEFAULT '00000000-0000-0000-0000-000000000010',
owner_id VARCHAR(36) NOT NULL DEFAULT '',
visibility VARCHAR(50) NOT NULL DEFAULT 'private',
name VARCHAR(255) NOT NULL UNIQUE,
host TEXT NOT NULL,
ca_data TEXT,
@ -75,15 +123,29 @@ func (db *DB) InitSchema() error {
key_data TEXT,
token TEXT,
description TEXT,
default_namespace VARCHAR(255),
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);
ALTER TABLE clusters
ADD COLUMN IF NOT EXISTS workspace_id VARCHAR(36) NOT NULL DEFAULT '00000000-0000-0000-0000-000000000010',
ADD COLUMN IF NOT EXISTS owner_id VARCHAR(36) NOT NULL DEFAULT '',
ADD COLUMN IF NOT EXISTS visibility VARCHAR(50) NOT NULL DEFAULT 'private',
ADD COLUMN IF NOT EXISTS default_namespace VARCHAR(255);
UPDATE clusters SET visibility = 'global_shared' WHERE visibility = 'private' AND owner_id = '';
CREATE INDEX IF NOT EXISTS idx_clusters_name ON clusters(name);
CREATE INDEX IF NOT EXISTS idx_clusters_workspace ON clusters(workspace_id);
CREATE INDEX IF NOT EXISTS idx_clusters_owner ON clusters(owner_id);
CREATE INDEX IF NOT EXISTS idx_clusters_visibility ON clusters(visibility);
-- Registries 表
CREATE TABLE IF NOT EXISTS registries (
id VARCHAR(36) PRIMARY KEY,
workspace_id VARCHAR(36) NOT NULL DEFAULT '00000000-0000-0000-0000-000000000010',
owner_id VARCHAR(36) NOT NULL DEFAULT '',
visibility VARCHAR(50) NOT NULL DEFAULT 'private',
name VARCHAR(255) NOT NULL UNIQUE,
url TEXT NOT NULL,
description TEXT,
@ -94,11 +156,22 @@ func (db *DB) InitSchema() error {
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);
ALTER TABLE registries
ADD COLUMN IF NOT EXISTS workspace_id VARCHAR(36) NOT NULL DEFAULT '00000000-0000-0000-0000-000000000010',
ADD COLUMN IF NOT EXISTS owner_id VARCHAR(36) NOT NULL DEFAULT '',
ADD COLUMN IF NOT EXISTS visibility VARCHAR(50) NOT NULL DEFAULT 'private';
UPDATE registries SET visibility = 'global_shared' WHERE visibility = 'private' AND owner_id = '';
CREATE INDEX IF NOT EXISTS idx_registries_name ON registries(name);
CREATE INDEX IF NOT EXISTS idx_registries_workspace ON registries(workspace_id);
CREATE INDEX IF NOT EXISTS idx_registries_owner ON registries(owner_id);
CREATE INDEX IF NOT EXISTS idx_registries_visibility ON registries(visibility);
-- Instances 表
CREATE TABLE IF NOT EXISTS instances (
id VARCHAR(36) PRIMARY KEY,
workspace_id VARCHAR(36) NOT NULL DEFAULT '00000000-0000-0000-0000-000000000010',
owner_id VARCHAR(36) NOT NULL DEFAULT '',
cluster_id VARCHAR(36) NOT NULL,
name VARCHAR(255) NOT NULL,
namespace VARCHAR(255) NOT NULL,
@ -121,9 +194,63 @@ func (db *DB) InitSchema() error {
CONSTRAINT unique_cluster_name UNIQUE (cluster_id, name, namespace)
);
ALTER TABLE instances
ADD COLUMN IF NOT EXISTS workspace_id VARCHAR(36) NOT NULL DEFAULT '00000000-0000-0000-0000-000000000010',
ADD COLUMN IF NOT EXISTS owner_id VARCHAR(36) NOT NULL DEFAULT '';
CREATE INDEX IF NOT EXISTS idx_instances_cluster ON instances(cluster_id);
CREATE INDEX IF NOT EXISTS idx_instances_registry ON instances(registry_id);
CREATE INDEX IF NOT EXISTS idx_instances_name ON instances(name);
CREATE INDEX IF NOT EXISTS idx_instances_workspace ON instances(workspace_id);
CREATE INDEX IF NOT EXISTS idx_instances_owner ON instances(owner_id);
CREATE TABLE IF NOT EXISTS workspace_cluster_bindings (
id VARCHAR(36) PRIMARY KEY,
workspace_id VARCHAR(36) NOT NULL REFERENCES workspaces(id) ON DELETE CASCADE,
cluster_id VARCHAR(36) NOT NULL REFERENCES clusters(id) ON DELETE CASCADE,
namespace VARCHAR(255) NOT NULL,
service_account VARCHAR(255) NOT NULL,
quota_cpu VARCHAR(50),
quota_memory VARCHAR(50),
quota_gpu VARCHAR(50),
quota_gpu_memory VARCHAR(50),
status VARCHAR(50) NOT NULL DEFAULT 'active',
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
UNIQUE (workspace_id, cluster_id)
);
ALTER TABLE workspace_cluster_bindings
ADD COLUMN IF NOT EXISTS quota_gpu_memory VARCHAR(50);
CREATE INDEX IF NOT EXISTS idx_workspace_cluster_bindings_workspace ON workspace_cluster_bindings(workspace_id);
CREATE INDEX IF NOT EXISTS idx_workspace_cluster_bindings_cluster ON workspace_cluster_bindings(cluster_id);
CREATE TABLE IF NOT EXISTS workspace_quotas (
id VARCHAR(36) PRIMARY KEY,
workspace_id VARCHAR(36) NOT NULL REFERENCES workspaces(id) ON DELETE CASCADE,
resource_type VARCHAR(50) NOT NULL,
hard_limit VARCHAR(100) NOT NULL,
soft_limit VARCHAR(100),
used VARCHAR(100),
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
UNIQUE (workspace_id, resource_type)
);
CREATE TABLE IF NOT EXISTS audit_logs (
id VARCHAR(36) PRIMARY KEY,
workspace_id VARCHAR(36),
user_id VARCHAR(36),
action VARCHAR(100) NOT NULL,
resource_type VARCHAR(50) NOT NULL,
resource_id VARCHAR(36),
resource_name VARCHAR(255),
details JSONB,
ip_address VARCHAR(50),
user_agent TEXT,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);
CREATE INDEX IF NOT EXISTS idx_audit_logs_workspace ON audit_logs(workspace_id);
CREATE INDEX IF NOT EXISTS idx_audit_logs_user ON audit_logs(user_id);
`
_, err := db.conn.Exec(schema)

View File

@ -12,37 +12,32 @@ import (
"github.com/ocdp/cluster-service/internal/domain/repository"
)
// InstanceRepository PostgreSQL 实例仓储实现
type InstanceRepository struct {
db *DB
}
// NewInstanceRepository 创建 PostgreSQL 实例仓储
func NewInstanceRepository(db *DB) repository.InstanceRepository {
return &InstanceRepository{db: db}
}
// Create 创建实例
func (r *InstanceRepository) Create(ctx context.Context, instance *entity.Instance) error {
if instance.ID == "" {
instance.ID = uuid.New().String()
}
// 将 Values 转换为 JSON
valuesJSON, err := json.Marshal(instance.Values)
if err != nil {
return fmt.Errorf("failed to marshal values: %w", err)
}
query := `
INSERT INTO instances (id, cluster_id, name, namespace, registry_id, repository, chart, version,
description, values, values_yaml, status, status_reason, last_operation, last_error,
revision, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18)
INSERT INTO instances
(id, workspace_id, owner_id, cluster_id, name, namespace, registry_id, repository, chart, version,
description, values, values_yaml, status, status_reason, last_operation, last_error, revision, created_at, updated_at)
VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20)
`
_, err = r.db.conn.ExecContext(ctx, query,
instance.ID,
instance.WorkspaceID,
instance.OwnerID,
instance.ClusterID,
instance.Name,
instance.Namespace,
@ -61,166 +56,71 @@ func (r *InstanceRepository) Create(ctx context.Context, instance *entity.Instan
instance.CreatedAt,
instance.UpdatedAt,
)
if err != nil {
return fmt.Errorf("failed to create instance: %w", err)
}
return nil
}
// GetByID 根据 ID 获取实例
func (r *InstanceRepository) GetByID(ctx context.Context, id string) (*entity.Instance, error) {
query := `
SELECT id, cluster_id, name, namespace, registry_id, repository, chart, version,
description, values, values_yaml, status, status_reason, last_operation, last_error,
revision, created_at, updated_at
FROM instances
WHERE id = $1
`
instance := &entity.Instance{}
var (
valuesJSON []byte
statusReason sql.NullString
lastOperation sql.NullString
lastError sql.NullString
)
err := r.db.conn.QueryRowContext(ctx, query, id).Scan(
&instance.ID,
&instance.ClusterID,
&instance.Name,
&instance.Namespace,
&instance.RegistryID,
&instance.Repository,
&instance.Chart,
&instance.Version,
&instance.Description,
&valuesJSON,
&instance.ValuesYAML,
&instance.Status,
&statusReason,
&lastOperation,
&lastError,
&instance.Revision,
&instance.CreatedAt,
&instance.UpdatedAt,
)
if err == sql.ErrNoRows {
return nil, entity.ErrInstanceNotFound
}
if err != nil {
return nil, fmt.Errorf("failed to get instance: %w", err)
}
// 解析 JSON Values
if len(valuesJSON) > 0 {
if err := json.Unmarshal(valuesJSON, &instance.Values); err != nil {
return nil, fmt.Errorf("failed to unmarshal values: %w", err)
}
}
if statusReason.Valid {
instance.StatusReason = statusReason.String
}
if lastOperation.Valid {
instance.LastOperation = entity.InstanceOperation(lastOperation.String)
}
if lastError.Valid {
instance.LastError = lastError.String
}
return instance, nil
return r.get(ctx, "id = $1", id)
}
// GetByClusterAndName 根据集群 ID 和名称获取实例
func (r *InstanceRepository) GetByClusterAndName(ctx context.Context, clusterID, name string) (*entity.Instance, error) {
query := `
SELECT id, cluster_id, name, namespace, registry_id, repository, chart, version,
SELECT id, workspace_id, owner_id, cluster_id, name, namespace, registry_id, repository, chart, version,
description, values, values_yaml, status, status_reason, last_operation, last_error,
revision, created_at, updated_at
FROM instances
WHERE cluster_id = $1 AND name = $2
`
instance := &entity.Instance{}
var (
valuesJSON []byte
statusReason sql.NullString
lastOperation sql.NullString
lastError sql.NullString
)
err := r.db.conn.QueryRowContext(ctx, query, clusterID, name).Scan(
&instance.ID,
&instance.ClusterID,
&instance.Name,
&instance.Namespace,
&instance.RegistryID,
&instance.Repository,
&instance.Chart,
&instance.Version,
&instance.Description,
&valuesJSON,
&instance.ValuesYAML,
&instance.Status,
&statusReason,
&lastOperation,
&lastError,
&instance.Revision,
&instance.CreatedAt,
&instance.UpdatedAt,
)
if err == sql.ErrNoRows {
return nil, entity.ErrInstanceNotFound
}
rows, err := r.db.conn.QueryContext(ctx, query, clusterID, name)
if err != nil {
return nil, fmt.Errorf("failed to get instance: %w", err)
}
// 解析 JSON Values
if len(valuesJSON) > 0 {
if err := json.Unmarshal(valuesJSON, &instance.Values); err != nil {
return nil, fmt.Errorf("failed to unmarshal values: %w", err)
}
defer rows.Close()
if !rows.Next() {
return nil, entity.ErrInstanceNotFound
}
if statusReason.Valid {
instance.StatusReason = statusReason.String
}
if lastOperation.Valid {
instance.LastOperation = entity.InstanceOperation(lastOperation.String)
}
if lastError.Valid {
instance.LastError = lastError.String
}
return instance, nil
return r.scanInstance(rows)
}
func (r *InstanceRepository) get(ctx context.Context, where string, arg interface{}) (*entity.Instance, error) {
query := fmt.Sprintf(`
SELECT id, workspace_id, owner_id, cluster_id, name, namespace, registry_id, repository, chart, version,
description, values, values_yaml, status, status_reason, last_operation, last_error,
revision, created_at, updated_at
FROM instances
WHERE %s
`, where)
rows, err := r.db.conn.QueryContext(ctx, query, arg)
if err != nil {
return nil, fmt.Errorf("failed to get instance: %w", err)
}
defer rows.Close()
if !rows.Next() {
return nil, entity.ErrInstanceNotFound
}
return r.scanInstance(rows)
}
// Update 更新实例
func (r *InstanceRepository) Update(ctx context.Context, instance *entity.Instance) error {
instance.UpdatedAt = time.Now()
// 将 Values 转换为 JSON
valuesJSON, err := json.Marshal(instance.Values)
if err != nil {
return fmt.Errorf("failed to marshal values: %w", err)
}
query := `
UPDATE instances
SET cluster_id = $1, name = $2, namespace = $3, registry_id = $4, repository = $5,
chart = $6, version = $7, description = $8, values = $9, values_yaml = $10,
status = $11, status_reason = $12, last_operation = $13, last_error = $14,
revision = $15, updated_at = $16
WHERE id = $17
SET workspace_id = $1, owner_id = $2, cluster_id = $3, name = $4, namespace = $5,
registry_id = $6, repository = $7, chart = $8, version = $9, description = $10,
values = $11, values_yaml = $12, status = $13, status_reason = $14,
last_operation = $15, last_error = $16, revision = $17, updated_at = $18
WHERE id = $19
`
result, err := r.db.conn.ExecContext(ctx, query,
instance.WorkspaceID,
instance.OwnerID,
instance.ClusterID,
instance.Name,
instance.Namespace,
@ -239,195 +139,126 @@ func (r *InstanceRepository) Update(ctx context.Context, instance *entity.Instan
instance.UpdatedAt,
instance.ID,
)
if err != nil {
return fmt.Errorf("failed to update instance: %w", err)
}
rows, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get affected rows: %w", err)
}
if rows == 0 {
return entity.ErrInstanceNotFound
}
return nil
}
// Delete 删除实例
func (r *InstanceRepository) Delete(ctx context.Context, id string) error {
query := `DELETE FROM instances WHERE id = $1`
result, err := r.db.conn.ExecContext(ctx, query, id)
result, err := r.db.conn.ExecContext(ctx, `DELETE FROM instances WHERE id = $1`, id)
if err != nil {
return fmt.Errorf("failed to delete instance: %w", err)
}
rows, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get affected rows: %w", err)
}
if rows == 0 {
return entity.ErrInstanceNotFound
}
return nil
}
// ListByCluster 列出指定集群的所有实例
func (r *InstanceRepository) ListByCluster(ctx context.Context, clusterID string) ([]*entity.Instance, error) {
query := `
SELECT id, cluster_id, name, namespace, registry_id, repository, chart, version,
description, values, values_yaml, status, status_reason, last_operation, last_error,
revision, created_at, updated_at
FROM instances
WHERE cluster_id = $1
ORDER BY created_at DESC
`
rows, err := r.db.conn.QueryContext(ctx, query, clusterID)
if err != nil {
return nil, fmt.Errorf("failed to list instances: %w", err)
}
defer rows.Close()
instances := make([]*entity.Instance, 0)
for rows.Next() {
instance := &entity.Instance{}
var (
valuesJSON []byte
statusReason sql.NullString
lastOperation sql.NullString
lastError sql.NullString
)
err := rows.Scan(
&instance.ID,
&instance.ClusterID,
&instance.Name,
&instance.Namespace,
&instance.RegistryID,
&instance.Repository,
&instance.Chart,
&instance.Version,
&instance.Description,
&valuesJSON,
&instance.ValuesYAML,
&instance.Status,
&statusReason,
&lastOperation,
&lastError,
&instance.Revision,
&instance.CreatedAt,
&instance.UpdatedAt,
)
if err != nil {
return nil, fmt.Errorf("failed to scan instance: %w", err)
}
// 解析 JSON Values
if len(valuesJSON) > 0 {
if err := json.Unmarshal(valuesJSON, &instance.Values); err != nil {
return nil, fmt.Errorf("failed to unmarshal values: %w", err)
}
}
if statusReason.Valid {
instance.StatusReason = statusReason.String
}
if lastOperation.Valid {
instance.LastOperation = entity.InstanceOperation(lastOperation.String)
}
if lastError.Valid {
instance.LastError = lastError.String
}
instances = append(instances, instance)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("rows iteration error: %w", err)
}
return instances, nil
return r.list(ctx, "WHERE cluster_id = $1", clusterID)
}
// List 列出所有实例
func (r *InstanceRepository) List(ctx context.Context) ([]*entity.Instance, error) {
return r.list(ctx, "", nil)
}
func (r *InstanceRepository) list(ctx context.Context, where string, arg interface{}) ([]*entity.Instance, error) {
query := `
SELECT id, cluster_id, name, namespace, registry_id, repository, chart, version,
SELECT id, workspace_id, owner_id, cluster_id, name, namespace, registry_id, repository, chart, version,
description, values, values_yaml, status, status_reason, last_operation, last_error,
revision, created_at, updated_at
FROM instances
` + where + `
ORDER BY created_at DESC
`
rows, err := r.db.conn.QueryContext(ctx, query)
var rows *sql.Rows
var err error
if where == "" {
rows, err = r.db.conn.QueryContext(ctx, query)
} else {
rows, err = r.db.conn.QueryContext(ctx, query, arg)
}
if err != nil {
return nil, fmt.Errorf("failed to list instances: %w", err)
}
defer rows.Close()
instances := make([]*entity.Instance, 0)
for rows.Next() {
instance := &entity.Instance{}
var (
valuesJSON []byte
statusReason sql.NullString
lastOperation sql.NullString
lastError sql.NullString
)
err := rows.Scan(
&instance.ID,
&instance.ClusterID,
&instance.Name,
&instance.Namespace,
&instance.RegistryID,
&instance.Repository,
&instance.Chart,
&instance.Version,
&instance.Description,
&valuesJSON,
&instance.ValuesYAML,
&instance.Status,
&statusReason,
&lastOperation,
&lastError,
&instance.Revision,
&instance.CreatedAt,
&instance.UpdatedAt,
)
instance, err := r.scanInstance(rows)
if err != nil {
return nil, fmt.Errorf("failed to scan instance: %w", err)
return nil, err
}
// 解析 JSON Values
if len(valuesJSON) > 0 {
if err := json.Unmarshal(valuesJSON, &instance.Values); err != nil {
return nil, fmt.Errorf("failed to unmarshal values: %w", err)
}
}
if statusReason.Valid {
instance.StatusReason = statusReason.String
}
if lastOperation.Valid {
instance.LastOperation = entity.InstanceOperation(lastOperation.String)
}
if lastError.Valid {
instance.LastError = lastError.String
}
instances = append(instances, instance)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("rows iteration error: %w", err)
}
return instances, nil
}
type instanceScanner interface {
Scan(dest ...interface{}) error
}
func (r *InstanceRepository) scanInstance(scanner instanceScanner) (*entity.Instance, error) {
instance := &entity.Instance{}
var (
valuesJSON []byte
statusReason sql.NullString
lastOperation sql.NullString
lastError sql.NullString
)
err := scanner.Scan(
&instance.ID,
&instance.WorkspaceID,
&instance.OwnerID,
&instance.ClusterID,
&instance.Name,
&instance.Namespace,
&instance.RegistryID,
&instance.Repository,
&instance.Chart,
&instance.Version,
&instance.Description,
&valuesJSON,
&instance.ValuesYAML,
&instance.Status,
&statusReason,
&lastOperation,
&lastError,
&instance.Revision,
&instance.CreatedAt,
&instance.UpdatedAt,
)
if err != nil {
return nil, fmt.Errorf("failed to scan instance: %w", err)
}
if len(valuesJSON) > 0 {
if err := json.Unmarshal(valuesJSON, &instance.Values); err != nil {
return nil, fmt.Errorf("failed to unmarshal values: %w", err)
}
}
if statusReason.Valid {
instance.StatusReason = statusReason.String
}
if lastOperation.Valid {
instance.LastOperation = entity.InstanceOperation(lastOperation.String)
}
if lastError.Valid {
instance.LastError = lastError.String
}
return instance, nil
}

View File

@ -12,39 +12,32 @@ import (
"github.com/ocdp/cluster-service/internal/pkg/crypto"
)
// RegistryRepository PostgreSQL Registry 仓储实现
type RegistryRepository struct {
db *DB
encryptor crypto.Encryptor
}
// NewRegistryRepository 创建 PostgreSQL Registry 仓储
func NewRegistryRepository(db *DB, encryptor crypto.Encryptor) repository.RegistryRepository {
return &RegistryRepository{
db: db,
encryptor: encryptor,
}
return &RegistryRepository{db: db, encryptor: encryptor}
}
// Create 创建 Registry
func (r *RegistryRepository) Create(ctx context.Context, registry *entity.Registry) error {
if registry.ID == "" {
registry.ID = uuid.New().String()
}
// 加密密码
encryptedPassword, err := r.encryptor.Encrypt(registry.Password)
if err != nil {
return fmt.Errorf("failed to encrypt password: %w", err)
}
query := `
INSERT INTO registries (id, name, url, description, username, password, insecure, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
INSERT INTO registries (id, workspace_id, owner_id, visibility, name, url, description, username, password, insecure, created_at, updated_at)
VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12)
`
_, err = r.db.conn.ExecContext(ctx, query,
registry.ID,
registry.WorkspaceID,
registry.OwnerID,
registry.Visibility,
registry.Name,
registry.URL,
registry.Description,
@ -54,110 +47,57 @@ func (r *RegistryRepository) Create(ctx context.Context, registry *entity.Regist
registry.CreatedAt,
registry.UpdatedAt,
)
if err != nil {
return fmt.Errorf("failed to create registry: %w", err)
}
return nil
}
// GetByID 根据 ID 获取 Registry
func (r *RegistryRepository) GetByID(ctx context.Context, id string) (*entity.Registry, error) {
query := `
SELECT id, name, url, description, username, password, insecure, created_at, updated_at
FROM registries
WHERE id = $1
`
registry := &entity.Registry{}
var encryptedPassword string
err := r.db.conn.QueryRowContext(ctx, query, id).Scan(
&registry.ID,
&registry.Name,
&registry.URL,
&registry.Description,
&registry.Username,
&encryptedPassword,
&registry.Insecure,
&registry.CreatedAt,
&registry.UpdatedAt,
)
if err == sql.ErrNoRows {
return nil, entity.ErrRegistryNotFound
}
if err != nil {
return nil, fmt.Errorf("failed to get registry: %w", err)
}
// 解密密码
registry.Password, err = r.encryptor.Decrypt(encryptedPassword)
if err != nil {
return nil, fmt.Errorf("failed to decrypt password: %w", err)
}
return registry, nil
return r.get(ctx, "id = $1", id)
}
// GetByName 根据名称获取 Registry
func (r *RegistryRepository) GetByName(ctx context.Context, name string) (*entity.Registry, error) {
query := `
SELECT id, name, url, description, username, password, insecure, created_at, updated_at
return r.get(ctx, "name = $1", name)
}
func (r *RegistryRepository) get(ctx context.Context, where string, arg interface{}) (*entity.Registry, error) {
query := fmt.Sprintf(`
SELECT id, workspace_id, owner_id, visibility, name, url, description, username, password, insecure, created_at, updated_at
FROM registries
WHERE name = $1
`
registry := &entity.Registry{}
var encryptedPassword string
err := r.db.conn.QueryRowContext(ctx, query, name).Scan(
&registry.ID,
&registry.Name,
&registry.URL,
&registry.Description,
&registry.Username,
&encryptedPassword,
&registry.Insecure,
&registry.CreatedAt,
&registry.UpdatedAt,
)
if err == sql.ErrNoRows {
return nil, entity.ErrRegistryNotFound
}
WHERE %s
`, where)
rows, err := r.db.conn.QueryContext(ctx, query, arg)
if err != nil {
return nil, fmt.Errorf("failed to get registry: %w", err)
}
// 解密密码
registry.Password, err = r.encryptor.Decrypt(encryptedPassword)
if err != nil {
return nil, fmt.Errorf("failed to decrypt password: %w", err)
defer rows.Close()
if !rows.Next() {
return nil, entity.ErrRegistryNotFound
}
registry, err := r.scanRegistry(rows)
if err != nil {
return nil, err
}
return registry, nil
}
// Update 更新 Registry
func (r *RegistryRepository) Update(ctx context.Context, registry *entity.Registry) error {
registry.UpdatedAt = time.Now()
// 加密密码
encryptedPassword, err := r.encryptor.Encrypt(registry.Password)
if err != nil {
return fmt.Errorf("failed to encrypt password: %w", err)
}
query := `
UPDATE registries
SET name = $1, url = $2, description = $3, username = $4, password = $5,
insecure = $6, updated_at = $7
WHERE id = $8
SET workspace_id = $1, owner_id = $2, visibility = $3, name = $4, url = $5,
description = $6, username = $7, password = $8, insecure = $9, updated_at = $10
WHERE id = $11
`
result, err := r.db.conn.ExecContext(ctx, query,
registry.WorkspaceID,
registry.OwnerID,
registry.Visibility,
registry.Name,
registry.URL,
registry.Description,
@ -167,91 +107,86 @@ func (r *RegistryRepository) Update(ctx context.Context, registry *entity.Regist
registry.UpdatedAt,
registry.ID,
)
if err != nil {
return fmt.Errorf("failed to update registry: %w", err)
}
rows, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get affected rows: %w", err)
}
if rows == 0 {
return entity.ErrRegistryNotFound
}
return nil
}
// Delete 删除 Registry
func (r *RegistryRepository) Delete(ctx context.Context, id string) error {
query := `DELETE FROM registries WHERE id = $1`
result, err := r.db.conn.ExecContext(ctx, query, id)
result, err := r.db.conn.ExecContext(ctx, `DELETE FROM registries WHERE id = $1`, id)
if err != nil {
return fmt.Errorf("failed to delete registry: %w", err)
}
rows, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get affected rows: %w", err)
}
if rows == 0 {
return entity.ErrRegistryNotFound
}
return nil
}
// List 列出所有 Registries
func (r *RegistryRepository) List(ctx context.Context) ([]*entity.Registry, error) {
query := `
SELECT id, name, url, description, username, password, insecure, created_at, updated_at
SELECT id, workspace_id, owner_id, visibility, name, url, description, username, password, insecure, created_at, updated_at
FROM registries
ORDER BY created_at DESC
`
rows, err := r.db.conn.QueryContext(ctx, query)
if err != nil {
return nil, fmt.Errorf("failed to list registries: %w", err)
}
defer rows.Close()
registries := make([]*entity.Registry, 0)
for rows.Next() {
registry := &entity.Registry{}
var encryptedPassword string
err := rows.Scan(
&registry.ID,
&registry.Name,
&registry.URL,
&registry.Description,
&registry.Username,
&encryptedPassword,
&registry.Insecure,
&registry.CreatedAt,
&registry.UpdatedAt,
)
registry, err := r.scanRegistry(rows)
if err != nil {
return nil, fmt.Errorf("failed to scan registry: %w", err)
return nil, err
}
// 解密密码
registry.Password, err = r.encryptor.Decrypt(encryptedPassword)
if err != nil {
return nil, fmt.Errorf("failed to decrypt password: %w", err)
}
registries = append(registries, registry)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("rows iteration error: %w", err)
}
return registries, nil
}
type registryScanner interface {
Scan(dest ...interface{}) error
}
func (r *RegistryRepository) scanRegistry(scanner registryScanner) (*entity.Registry, error) {
registry := &entity.Registry{}
var encryptedPassword sql.NullString
err := scanner.Scan(
&registry.ID,
&registry.WorkspaceID,
&registry.OwnerID,
&registry.Visibility,
&registry.Name,
&registry.URL,
&registry.Description,
&registry.Username,
&encryptedPassword,
&registry.Insecure,
&registry.CreatedAt,
&registry.UpdatedAt,
)
if err != nil {
return nil, fmt.Errorf("failed to scan registry: %w", err)
}
registry.Password, err = decryptMaybe(r.encryptor, encryptedPassword.String)
if err != nil {
return nil, fmt.Errorf("failed to decrypt password: %w", err)
}
return registry, nil
}

View File

@ -28,8 +28,8 @@ func (r *UserRepository) Create(ctx context.Context, user *entity.User) error {
}
query := `
INSERT INTO users (id, username, password_hash, email, revoked_after, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $7)
INSERT INTO users (id, username, password_hash, email, role, workspace_id, is_active, must_change_password, revoked_after, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
`
_, err := r.db.conn.ExecContext(ctx, query,
@ -37,6 +37,10 @@ func (r *UserRepository) Create(ctx context.Context, user *entity.User) error {
user.Username,
user.PasswordHash,
user.Email,
user.Role,
user.WorkspaceID,
user.IsActive,
user.MustChangePassword,
user.RevokedAfter,
user.CreatedAt,
user.UpdatedAt,
@ -52,7 +56,7 @@ func (r *UserRepository) Create(ctx context.Context, user *entity.User) error {
// GetByID 根据 ID 获取用户
func (r *UserRepository) GetByID(ctx context.Context, id string) (*entity.User, error) {
query := `
SELECT id, username, password_hash, email, revoked_after, created_at, updated_at
SELECT id, username, password_hash, email, role, workspace_id, is_active, must_change_password, revoked_after, created_at, updated_at
FROM users
WHERE id = $1
`
@ -63,6 +67,10 @@ func (r *UserRepository) GetByID(ctx context.Context, id string) (*entity.User,
&user.Username,
&user.PasswordHash,
&user.Email,
&user.Role,
&user.WorkspaceID,
&user.IsActive,
&user.MustChangePassword,
&user.RevokedAfter,
&user.CreatedAt,
&user.UpdatedAt,
@ -81,7 +89,7 @@ func (r *UserRepository) GetByID(ctx context.Context, id string) (*entity.User,
// GetByUsername 根据用户名获取用户
func (r *UserRepository) GetByUsername(ctx context.Context, username string) (*entity.User, error) {
query := `
SELECT id, username, password_hash, email, revoked_after, created_at, updated_at
SELECT id, username, password_hash, email, role, workspace_id, is_active, must_change_password, revoked_after, created_at, updated_at
FROM users
WHERE username = $1
`
@ -92,6 +100,10 @@ func (r *UserRepository) GetByUsername(ctx context.Context, username string) (*e
&user.Username,
&user.PasswordHash,
&user.Email,
&user.Role,
&user.WorkspaceID,
&user.IsActive,
&user.MustChangePassword,
&user.RevokedAfter,
&user.CreatedAt,
&user.UpdatedAt,
@ -113,14 +125,19 @@ func (r *UserRepository) Update(ctx context.Context, user *entity.User) error {
query := `
UPDATE users
SET username = $1, password_hash = $2, email = $3, revoked_after = $4, updated_at = $5
WHERE id = $6
SET username = $1, password_hash = $2, email = $3, role = $4, workspace_id = $5,
is_active = $6, must_change_password = $7, revoked_after = $8, updated_at = $9
WHERE id = $10
`
result, err := r.db.conn.ExecContext(ctx, query,
user.Username,
user.PasswordHash,
user.Email,
user.Role,
user.WorkspaceID,
user.IsActive,
user.MustChangePassword,
user.RevokedAfter,
user.UpdatedAt,
user.ID,
@ -166,7 +183,7 @@ func (r *UserRepository) Delete(ctx context.Context, id string) error {
// List 列出所有用户
func (r *UserRepository) List(ctx context.Context) ([]*entity.User, error) {
query := `
SELECT id, username, password_hash, email, revoked_after, created_at, updated_at
SELECT id, username, password_hash, email, role, workspace_id, is_active, must_change_password, revoked_after, created_at, updated_at
FROM users
ORDER BY created_at DESC
`
@ -185,6 +202,10 @@ func (r *UserRepository) List(ctx context.Context) ([]*entity.User, error) {
&user.Username,
&user.PasswordHash,
&user.Email,
&user.Role,
&user.WorkspaceID,
&user.IsActive,
&user.MustChangePassword,
&user.RevokedAfter,
&user.CreatedAt,
&user.UpdatedAt,
@ -201,4 +222,3 @@ func (r *UserRepository) List(ctx context.Context) ([]*entity.User, error) {
return users, nil
}

View File

@ -0,0 +1,345 @@
package postgres
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"time"
"github.com/google/uuid"
"github.com/ocdp/cluster-service/internal/domain/entity"
"github.com/ocdp/cluster-service/internal/domain/repository"
)
type WorkspaceRepository struct {
db *DB
}
func NewWorkspaceRepository(db *DB) repository.WorkspaceRepository {
return &WorkspaceRepository{db: db}
}
func (r *WorkspaceRepository) Create(ctx context.Context, workspace *entity.Workspace) error {
if workspace.ID == "" {
workspace.ID = uuid.New().String()
}
query := `
INSERT INTO workspaces (id, name, status, k8s_namespace, k8s_sa_name, default_cluster_id, quota_cpu, quota_memory, quota_gpu, quota_gpu_memory, created_by, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
`
_, err := r.db.conn.ExecContext(ctx, query,
workspace.ID,
workspace.Name,
workspace.Status,
workspace.K8sNamespace,
workspace.K8sSAName,
workspace.DefaultClusterID,
workspace.QuotaCPU,
workspace.QuotaMemory,
workspace.QuotaGPU,
workspace.QuotaGPUMem,
workspace.CreatedBy,
workspace.CreatedAt,
workspace.UpdatedAt,
)
if err != nil {
return fmt.Errorf("failed to create workspace: %w", err)
}
return nil
}
func (r *WorkspaceRepository) GetByID(ctx context.Context, id string) (*entity.Workspace, error) {
return r.get(ctx, "id = $1", id)
}
func (r *WorkspaceRepository) GetByName(ctx context.Context, name string) (*entity.Workspace, error) {
return r.get(ctx, "name = $1", name)
}
func (r *WorkspaceRepository) get(ctx context.Context, where string, arg interface{}) (*entity.Workspace, error) {
query := fmt.Sprintf(`
SELECT id, name, status, k8s_namespace, k8s_sa_name, default_cluster_id, quota_cpu, quota_memory, quota_gpu, quota_gpu_memory, created_by, created_at, updated_at
FROM workspaces
WHERE %s
`, where)
workspace := &entity.Workspace{}
var createdBy, defaultClusterID, quotaCPU, quotaMemory, quotaGPU, quotaGPUMem sql.NullString
err := r.db.conn.QueryRowContext(ctx, query, arg).Scan(
&workspace.ID,
&workspace.Name,
&workspace.Status,
&workspace.K8sNamespace,
&workspace.K8sSAName,
&defaultClusterID,
&quotaCPU,
&quotaMemory,
&quotaGPU,
&quotaGPUMem,
&createdBy,
&workspace.CreatedAt,
&workspace.UpdatedAt,
)
if err == sql.ErrNoRows {
return nil, entity.ErrWorkspaceNotFound
}
if err != nil {
return nil, fmt.Errorf("failed to get workspace: %w", err)
}
workspace.CreatedBy = createdBy.String
workspace.DefaultClusterID = defaultClusterID.String
workspace.QuotaCPU = quotaCPU.String
workspace.QuotaMemory = quotaMemory.String
workspace.QuotaGPU = quotaGPU.String
workspace.QuotaGPUMem = quotaGPUMem.String
return workspace, nil
}
func (r *WorkspaceRepository) Update(ctx context.Context, workspace *entity.Workspace) error {
workspace.UpdatedAt = time.Now()
query := `
UPDATE workspaces
SET name = $1, status = $2, k8s_namespace = $3, k8s_sa_name = $4,
default_cluster_id = $5,
quota_cpu = $6, quota_memory = $7, quota_gpu = $8, quota_gpu_memory = $9,
created_by = $10, updated_at = $11
WHERE id = $12
`
result, err := r.db.conn.ExecContext(ctx, query,
workspace.Name,
workspace.Status,
workspace.K8sNamespace,
workspace.K8sSAName,
workspace.DefaultClusterID,
workspace.QuotaCPU,
workspace.QuotaMemory,
workspace.QuotaGPU,
workspace.QuotaGPUMem,
workspace.CreatedBy,
workspace.UpdatedAt,
workspace.ID,
)
if err != nil {
return fmt.Errorf("failed to update workspace: %w", err)
}
rows, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get affected rows: %w", err)
}
if rows == 0 {
return entity.ErrWorkspaceNotFound
}
return nil
}
func (r *WorkspaceRepository) List(ctx context.Context) ([]*entity.Workspace, error) {
query := `
SELECT id, name, status, k8s_namespace, k8s_sa_name, default_cluster_id, quota_cpu, quota_memory, quota_gpu, quota_gpu_memory, created_by, created_at, updated_at
FROM workspaces
ORDER BY created_at DESC
`
rows, err := r.db.conn.QueryContext(ctx, query)
if err != nil {
return nil, fmt.Errorf("failed to list workspaces: %w", err)
}
defer rows.Close()
workspaces := make([]*entity.Workspace, 0)
for rows.Next() {
workspace := &entity.Workspace{}
var createdBy, defaultClusterID, quotaCPU, quotaMemory, quotaGPU, quotaGPUMem sql.NullString
if err := rows.Scan(
&workspace.ID,
&workspace.Name,
&workspace.Status,
&workspace.K8sNamespace,
&workspace.K8sSAName,
&defaultClusterID,
&quotaCPU,
&quotaMemory,
&quotaGPU,
&quotaGPUMem,
&createdBy,
&workspace.CreatedAt,
&workspace.UpdatedAt,
); err != nil {
return nil, fmt.Errorf("failed to scan workspace: %w", err)
}
workspace.CreatedBy = createdBy.String
workspace.DefaultClusterID = defaultClusterID.String
workspace.QuotaCPU = quotaCPU.String
workspace.QuotaMemory = quotaMemory.String
workspace.QuotaGPU = quotaGPU.String
workspace.QuotaGPUMem = quotaGPUMem.String
workspaces = append(workspaces, workspace)
}
return workspaces, rows.Err()
}
type WorkspaceClusterBindingRepository struct {
db *DB
}
func NewWorkspaceClusterBindingRepository(db *DB) repository.WorkspaceClusterBindingRepository {
return &WorkspaceClusterBindingRepository{db: db}
}
func (r *WorkspaceClusterBindingRepository) Upsert(ctx context.Context, binding *entity.WorkspaceClusterBinding) error {
if binding.ID == "" {
binding.ID = uuid.New().String()
}
now := time.Now()
if binding.CreatedAt.IsZero() {
binding.CreatedAt = now
}
binding.UpdatedAt = now
query := `
INSERT INTO workspace_cluster_bindings
(id, workspace_id, cluster_id, namespace, service_account, quota_cpu, quota_memory, quota_gpu, quota_gpu_memory, status, created_at, updated_at)
VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12)
ON CONFLICT (workspace_id, cluster_id)
DO UPDATE SET namespace = EXCLUDED.namespace,
service_account = EXCLUDED.service_account,
quota_cpu = EXCLUDED.quota_cpu,
quota_memory = EXCLUDED.quota_memory,
quota_gpu = EXCLUDED.quota_gpu,
quota_gpu_memory = EXCLUDED.quota_gpu_memory,
status = EXCLUDED.status,
updated_at = EXCLUDED.updated_at
`
_, err := r.db.conn.ExecContext(ctx, query,
binding.ID,
binding.WorkspaceID,
binding.ClusterID,
binding.Namespace,
binding.ServiceAccount,
binding.QuotaCPU,
binding.QuotaMemory,
binding.QuotaGPU,
binding.QuotaGPUMem,
binding.Status,
binding.CreatedAt,
binding.UpdatedAt,
)
if err != nil {
return fmt.Errorf("failed to upsert workspace cluster binding: %w", err)
}
return nil
}
func (r *WorkspaceClusterBindingRepository) Get(ctx context.Context, workspaceID, clusterID string) (*entity.WorkspaceClusterBinding, error) {
query := `
SELECT id, workspace_id, cluster_id, namespace, service_account, quota_cpu, quota_memory, quota_gpu, quota_gpu_memory, status, created_at, updated_at
FROM workspace_cluster_bindings
WHERE workspace_id = $1 AND cluster_id = $2
`
binding := &entity.WorkspaceClusterBinding{}
err := r.db.conn.QueryRowContext(ctx, query, workspaceID, clusterID).Scan(
&binding.ID,
&binding.WorkspaceID,
&binding.ClusterID,
&binding.Namespace,
&binding.ServiceAccount,
&binding.QuotaCPU,
&binding.QuotaMemory,
&binding.QuotaGPU,
&binding.QuotaGPUMem,
&binding.Status,
&binding.CreatedAt,
&binding.UpdatedAt,
)
if err == sql.ErrNoRows {
return nil, entity.ErrWorkspaceNotFound
}
if err != nil {
return nil, fmt.Errorf("failed to get workspace cluster binding: %w", err)
}
return binding, nil
}
func (r *WorkspaceClusterBindingRepository) Delete(ctx context.Context, workspaceID, clusterID string) error {
_, err := r.db.conn.ExecContext(ctx, `DELETE FROM workspace_cluster_bindings WHERE workspace_id = $1 AND cluster_id = $2`, workspaceID, clusterID)
return err
}
type AuditLogRepository struct {
db *DB
}
func NewAuditLogRepository(db *DB) repository.AuditLogRepository {
return &AuditLogRepository{db: db}
}
func (r *AuditLogRepository) Create(ctx context.Context, logEntry *entity.AuditLog) error {
if logEntry.ID == "" {
logEntry.ID = uuid.New().String()
}
details, err := json.Marshal(logEntry.Details)
if err != nil {
return fmt.Errorf("failed to marshal audit details: %w", err)
}
if logEntry.CreatedAt.IsZero() {
logEntry.CreatedAt = time.Now()
}
query := `
INSERT INTO audit_logs (id, workspace_id, user_id, action, resource_type, resource_id, resource_name, details, ip_address, user_agent, created_at)
VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11)
`
_, err = r.db.conn.ExecContext(ctx, query,
logEntry.ID,
logEntry.WorkspaceID,
logEntry.UserID,
logEntry.Action,
logEntry.ResourceType,
logEntry.ResourceID,
logEntry.ResourceName,
string(details),
logEntry.IPAddress,
logEntry.UserAgent,
logEntry.CreatedAt,
)
if err != nil {
return fmt.Errorf("failed to create audit log: %w", err)
}
return nil
}
func (r *AuditLogRepository) ListByWorkspace(ctx context.Context, workspaceID string, limit int) ([]*entity.AuditLog, error) {
if limit <= 0 || limit > 500 {
limit = 100
}
query := `
SELECT id, workspace_id, user_id, action, resource_type, resource_id, resource_name, details, ip_address, user_agent, created_at
FROM audit_logs
WHERE workspace_id = $1
ORDER BY created_at DESC
LIMIT $2
`
rows, err := r.db.conn.QueryContext(ctx, query, workspaceID, limit)
if err != nil {
return nil, fmt.Errorf("failed to list audit logs: %w", err)
}
defer rows.Close()
result := make([]*entity.AuditLog, 0)
for rows.Next() {
logEntry := &entity.AuditLog{}
var details []byte
if err := rows.Scan(
&logEntry.ID,
&logEntry.WorkspaceID,
&logEntry.UserID,
&logEntry.Action,
&logEntry.ResourceType,
&logEntry.ResourceID,
&logEntry.ResourceName,
&details,
&logEntry.IPAddress,
&logEntry.UserAgent,
&logEntry.CreatedAt,
); err != nil {
return nil, fmt.Errorf("failed to scan audit log: %w", err)
}
_ = json.Unmarshal(details, &logEntry.Details)
result = append(result, logEntry)
}
return result, rows.Err()
}