From 3321117075756eb715d20c41ec82c634586db157 Mon Sep 17 00:00:00 2001 From: Helder Santana Date: Fri, 8 Nov 2024 20:15:26 +0100 Subject: [PATCH] private endpoint translation layer (#1904) --- .mockery.yaml | 1 + .../translation/private_endpoint_service.go | 381 +++++++++++ .../translation/privateendpoint/conversion.go | 475 ++++++++++++++ .../privateendpoint/privateendpoint.go | 181 +++++ .../privateendpoint/privateendpoint_test.go | 616 ++++++++++++++++++ 5 files changed, 1654 insertions(+) create mode 100644 internal/mocks/translation/private_endpoint_service.go create mode 100644 internal/translation/privateendpoint/conversion.go create mode 100644 internal/translation/privateendpoint/privateendpoint.go create mode 100644 internal/translation/privateendpoint/privateendpoint_test.go diff --git a/.mockery.yaml b/.mockery.yaml index a82f64d301..3d642854fb 100644 --- a/.mockery.yaml +++ b/.mockery.yaml @@ -13,3 +13,4 @@ packages: github.com/mongodb/mongodb-atlas-kubernetes/v2/internal/translation/customroles: github.com/mongodb/mongodb-atlas-kubernetes/v2/internal/translation/datafederation: github.com/mongodb/mongodb-atlas-kubernetes/v2/internal/translation/teams: + github.com/mongodb/mongodb-atlas-kubernetes/v2/internal/translation/privateendpoint: diff --git a/internal/mocks/translation/private_endpoint_service.go b/internal/mocks/translation/private_endpoint_service.go new file mode 100644 index 0000000000..29d4d676de --- /dev/null +++ b/internal/mocks/translation/private_endpoint_service.go @@ -0,0 +1,381 @@ +// Code generated by mockery. DO NOT EDIT. + +package translation + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" + + privateendpoint "github.com/mongodb/mongodb-atlas-kubernetes/v2/internal/translation/privateendpoint" +) + +// PrivateEndpointServiceMock is an autogenerated mock type for the PrivateEndpointService type +type PrivateEndpointServiceMock struct { + mock.Mock +} + +type PrivateEndpointServiceMock_Expecter struct { + mock *mock.Mock +} + +func (_m *PrivateEndpointServiceMock) EXPECT() *PrivateEndpointServiceMock_Expecter { + return &PrivateEndpointServiceMock_Expecter{mock: &_m.Mock} +} + +// CreatePrivateEndpointInterface provides a mock function with given fields: ctx, projectID, provider, serviceID, gcpProjectID, peInterface +func (_m *PrivateEndpointServiceMock) CreatePrivateEndpointInterface(ctx context.Context, projectID string, provider string, serviceID string, gcpProjectID string, peInterface privateendpoint.EndpointInterface) (privateendpoint.EndpointInterface, error) { + ret := _m.Called(ctx, projectID, provider, serviceID, gcpProjectID, peInterface) + + if len(ret) == 0 { + panic("no return value specified for CreatePrivateEndpointInterface") + } + + var r0 privateendpoint.EndpointInterface + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string, string, privateendpoint.EndpointInterface) (privateendpoint.EndpointInterface, error)); ok { + return rf(ctx, projectID, provider, serviceID, gcpProjectID, peInterface) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string, string, string, privateendpoint.EndpointInterface) privateendpoint.EndpointInterface); ok { + r0 = rf(ctx, projectID, provider, serviceID, gcpProjectID, peInterface) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(privateendpoint.EndpointInterface) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string, string, string, privateendpoint.EndpointInterface) error); ok { + r1 = rf(ctx, projectID, provider, serviceID, gcpProjectID, peInterface) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// PrivateEndpointServiceMock_CreatePrivateEndpointInterface_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreatePrivateEndpointInterface' +type PrivateEndpointServiceMock_CreatePrivateEndpointInterface_Call struct { + *mock.Call +} + +// CreatePrivateEndpointInterface is a helper method to define mock.On call +// - ctx context.Context +// - projectID string +// - provider string +// - serviceID string +// - gcpProjectID string +// - peInterface privateendpoint.EndpointInterface +func (_e *PrivateEndpointServiceMock_Expecter) CreatePrivateEndpointInterface(ctx interface{}, projectID interface{}, provider interface{}, serviceID interface{}, gcpProjectID interface{}, peInterface interface{}) *PrivateEndpointServiceMock_CreatePrivateEndpointInterface_Call { + return &PrivateEndpointServiceMock_CreatePrivateEndpointInterface_Call{Call: _e.mock.On("CreatePrivateEndpointInterface", ctx, projectID, provider, serviceID, gcpProjectID, peInterface)} +} + +func (_c *PrivateEndpointServiceMock_CreatePrivateEndpointInterface_Call) Run(run func(ctx context.Context, projectID string, provider string, serviceID string, gcpProjectID string, peInterface privateendpoint.EndpointInterface)) *PrivateEndpointServiceMock_CreatePrivateEndpointInterface_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string), args[4].(string), args[5].(privateendpoint.EndpointInterface)) + }) + return _c +} + +func (_c *PrivateEndpointServiceMock_CreatePrivateEndpointInterface_Call) Return(_a0 privateendpoint.EndpointInterface, _a1 error) *PrivateEndpointServiceMock_CreatePrivateEndpointInterface_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *PrivateEndpointServiceMock_CreatePrivateEndpointInterface_Call) RunAndReturn(run func(context.Context, string, string, string, string, privateendpoint.EndpointInterface) (privateendpoint.EndpointInterface, error)) *PrivateEndpointServiceMock_CreatePrivateEndpointInterface_Call { + _c.Call.Return(run) + return _c +} + +// CreatePrivateEndpointService provides a mock function with given fields: ctx, projectID, peService +func (_m *PrivateEndpointServiceMock) CreatePrivateEndpointService(ctx context.Context, projectID string, peService privateendpoint.EndpointService) (privateendpoint.EndpointService, error) { + ret := _m.Called(ctx, projectID, peService) + + if len(ret) == 0 { + panic("no return value specified for CreatePrivateEndpointService") + } + + var r0 privateendpoint.EndpointService + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, privateendpoint.EndpointService) (privateendpoint.EndpointService, error)); ok { + return rf(ctx, projectID, peService) + } + if rf, ok := ret.Get(0).(func(context.Context, string, privateendpoint.EndpointService) privateendpoint.EndpointService); ok { + r0 = rf(ctx, projectID, peService) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(privateendpoint.EndpointService) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string, privateendpoint.EndpointService) error); ok { + r1 = rf(ctx, projectID, peService) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// PrivateEndpointServiceMock_CreatePrivateEndpointService_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreatePrivateEndpointService' +type PrivateEndpointServiceMock_CreatePrivateEndpointService_Call struct { + *mock.Call +} + +// CreatePrivateEndpointService is a helper method to define mock.On call +// - ctx context.Context +// - projectID string +// - peService privateendpoint.EndpointService +func (_e *PrivateEndpointServiceMock_Expecter) CreatePrivateEndpointService(ctx interface{}, projectID interface{}, peService interface{}) *PrivateEndpointServiceMock_CreatePrivateEndpointService_Call { + return &PrivateEndpointServiceMock_CreatePrivateEndpointService_Call{Call: _e.mock.On("CreatePrivateEndpointService", ctx, projectID, peService)} +} + +func (_c *PrivateEndpointServiceMock_CreatePrivateEndpointService_Call) Run(run func(ctx context.Context, projectID string, peService privateendpoint.EndpointService)) *PrivateEndpointServiceMock_CreatePrivateEndpointService_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(privateendpoint.EndpointService)) + }) + return _c +} + +func (_c *PrivateEndpointServiceMock_CreatePrivateEndpointService_Call) Return(_a0 privateendpoint.EndpointService, _a1 error) *PrivateEndpointServiceMock_CreatePrivateEndpointService_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *PrivateEndpointServiceMock_CreatePrivateEndpointService_Call) RunAndReturn(run func(context.Context, string, privateendpoint.EndpointService) (privateendpoint.EndpointService, error)) *PrivateEndpointServiceMock_CreatePrivateEndpointService_Call { + _c.Call.Return(run) + return _c +} + +// DeleteEndpointInterface provides a mock function with given fields: ctx, projectID, provider, serviceID, ID +func (_m *PrivateEndpointServiceMock) DeleteEndpointInterface(ctx context.Context, projectID string, provider string, serviceID string, ID string) error { + ret := _m.Called(ctx, projectID, provider, serviceID, ID) + + if len(ret) == 0 { + panic("no return value specified for DeleteEndpointInterface") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string, string) error); ok { + r0 = rf(ctx, projectID, provider, serviceID, ID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// PrivateEndpointServiceMock_DeleteEndpointInterface_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DeleteEndpointInterface' +type PrivateEndpointServiceMock_DeleteEndpointInterface_Call struct { + *mock.Call +} + +// DeleteEndpointInterface is a helper method to define mock.On call +// - ctx context.Context +// - projectID string +// - provider string +// - serviceID string +// - ID string +func (_e *PrivateEndpointServiceMock_Expecter) DeleteEndpointInterface(ctx interface{}, projectID interface{}, provider interface{}, serviceID interface{}, ID interface{}) *PrivateEndpointServiceMock_DeleteEndpointInterface_Call { + return &PrivateEndpointServiceMock_DeleteEndpointInterface_Call{Call: _e.mock.On("DeleteEndpointInterface", ctx, projectID, provider, serviceID, ID)} +} + +func (_c *PrivateEndpointServiceMock_DeleteEndpointInterface_Call) Run(run func(ctx context.Context, projectID string, provider string, serviceID string, ID string)) *PrivateEndpointServiceMock_DeleteEndpointInterface_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string), args[4].(string)) + }) + return _c +} + +func (_c *PrivateEndpointServiceMock_DeleteEndpointInterface_Call) Return(_a0 error) *PrivateEndpointServiceMock_DeleteEndpointInterface_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *PrivateEndpointServiceMock_DeleteEndpointInterface_Call) RunAndReturn(run func(context.Context, string, string, string, string) error) *PrivateEndpointServiceMock_DeleteEndpointInterface_Call { + _c.Call.Return(run) + return _c +} + +// DeleteEndpointService provides a mock function with given fields: ctx, projectID, provider, ID +func (_m *PrivateEndpointServiceMock) DeleteEndpointService(ctx context.Context, projectID string, provider string, ID string) error { + ret := _m.Called(ctx, projectID, provider, ID) + + if len(ret) == 0 { + panic("no return value specified for DeleteEndpointService") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string) error); ok { + r0 = rf(ctx, projectID, provider, ID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// PrivateEndpointServiceMock_DeleteEndpointService_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DeleteEndpointService' +type PrivateEndpointServiceMock_DeleteEndpointService_Call struct { + *mock.Call +} + +// DeleteEndpointService is a helper method to define mock.On call +// - ctx context.Context +// - projectID string +// - provider string +// - ID string +func (_e *PrivateEndpointServiceMock_Expecter) DeleteEndpointService(ctx interface{}, projectID interface{}, provider interface{}, ID interface{}) *PrivateEndpointServiceMock_DeleteEndpointService_Call { + return &PrivateEndpointServiceMock_DeleteEndpointService_Call{Call: _e.mock.On("DeleteEndpointService", ctx, projectID, provider, ID)} +} + +func (_c *PrivateEndpointServiceMock_DeleteEndpointService_Call) Run(run func(ctx context.Context, projectID string, provider string, ID string)) *PrivateEndpointServiceMock_DeleteEndpointService_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string)) + }) + return _c +} + +func (_c *PrivateEndpointServiceMock_DeleteEndpointService_Call) Return(_a0 error) *PrivateEndpointServiceMock_DeleteEndpointService_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *PrivateEndpointServiceMock_DeleteEndpointService_Call) RunAndReturn(run func(context.Context, string, string, string) error) *PrivateEndpointServiceMock_DeleteEndpointService_Call { + _c.Call.Return(run) + return _c +} + +// GetPrivateEndpoint provides a mock function with given fields: ctx, projectID, provider, ID +func (_m *PrivateEndpointServiceMock) GetPrivateEndpoint(ctx context.Context, projectID string, provider string, ID string) (privateendpoint.EndpointService, error) { + ret := _m.Called(ctx, projectID, provider, ID) + + if len(ret) == 0 { + panic("no return value specified for GetPrivateEndpoint") + } + + var r0 privateendpoint.EndpointService + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string) (privateendpoint.EndpointService, error)); ok { + return rf(ctx, projectID, provider, ID) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string, string) privateendpoint.EndpointService); ok { + r0 = rf(ctx, projectID, provider, ID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(privateendpoint.EndpointService) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string, string) error); ok { + r1 = rf(ctx, projectID, provider, ID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// PrivateEndpointServiceMock_GetPrivateEndpoint_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetPrivateEndpoint' +type PrivateEndpointServiceMock_GetPrivateEndpoint_Call struct { + *mock.Call +} + +// GetPrivateEndpoint is a helper method to define mock.On call +// - ctx context.Context +// - projectID string +// - provider string +// - ID string +func (_e *PrivateEndpointServiceMock_Expecter) GetPrivateEndpoint(ctx interface{}, projectID interface{}, provider interface{}, ID interface{}) *PrivateEndpointServiceMock_GetPrivateEndpoint_Call { + return &PrivateEndpointServiceMock_GetPrivateEndpoint_Call{Call: _e.mock.On("GetPrivateEndpoint", ctx, projectID, provider, ID)} +} + +func (_c *PrivateEndpointServiceMock_GetPrivateEndpoint_Call) Run(run func(ctx context.Context, projectID string, provider string, ID string)) *PrivateEndpointServiceMock_GetPrivateEndpoint_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(string)) + }) + return _c +} + +func (_c *PrivateEndpointServiceMock_GetPrivateEndpoint_Call) Return(_a0 privateendpoint.EndpointService, _a1 error) *PrivateEndpointServiceMock_GetPrivateEndpoint_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *PrivateEndpointServiceMock_GetPrivateEndpoint_Call) RunAndReturn(run func(context.Context, string, string, string) (privateendpoint.EndpointService, error)) *PrivateEndpointServiceMock_GetPrivateEndpoint_Call { + _c.Call.Return(run) + return _c +} + +// ListPrivateEndpoints provides a mock function with given fields: ctx, projectID, provider +func (_m *PrivateEndpointServiceMock) ListPrivateEndpoints(ctx context.Context, projectID string, provider string) ([]privateendpoint.EndpointService, error) { + ret := _m.Called(ctx, projectID, provider) + + if len(ret) == 0 { + panic("no return value specified for ListPrivateEndpoints") + } + + var r0 []privateendpoint.EndpointService + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) ([]privateendpoint.EndpointService, error)); ok { + return rf(ctx, projectID, provider) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string) []privateendpoint.EndpointService); ok { + r0 = rf(ctx, projectID, provider) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]privateendpoint.EndpointService) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok { + r1 = rf(ctx, projectID, provider) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// PrivateEndpointServiceMock_ListPrivateEndpoints_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListPrivateEndpoints' +type PrivateEndpointServiceMock_ListPrivateEndpoints_Call struct { + *mock.Call +} + +// ListPrivateEndpoints is a helper method to define mock.On call +// - ctx context.Context +// - projectID string +// - provider string +func (_e *PrivateEndpointServiceMock_Expecter) ListPrivateEndpoints(ctx interface{}, projectID interface{}, provider interface{}) *PrivateEndpointServiceMock_ListPrivateEndpoints_Call { + return &PrivateEndpointServiceMock_ListPrivateEndpoints_Call{Call: _e.mock.On("ListPrivateEndpoints", ctx, projectID, provider)} +} + +func (_c *PrivateEndpointServiceMock_ListPrivateEndpoints_Call) Run(run func(ctx context.Context, projectID string, provider string)) *PrivateEndpointServiceMock_ListPrivateEndpoints_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string)) + }) + return _c +} + +func (_c *PrivateEndpointServiceMock_ListPrivateEndpoints_Call) Return(_a0 []privateendpoint.EndpointService, _a1 error) *PrivateEndpointServiceMock_ListPrivateEndpoints_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *PrivateEndpointServiceMock_ListPrivateEndpoints_Call) RunAndReturn(run func(context.Context, string, string) ([]privateendpoint.EndpointService, error)) *PrivateEndpointServiceMock_ListPrivateEndpoints_Call { + _c.Call.Return(run) + return _c +} + +// NewPrivateEndpointServiceMock creates a new instance of PrivateEndpointServiceMock. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewPrivateEndpointServiceMock(t interface { + mock.TestingT + Cleanup(func()) +}) *PrivateEndpointServiceMock { + mock := &PrivateEndpointServiceMock{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/translation/privateendpoint/conversion.go b/internal/translation/privateendpoint/conversion.go new file mode 100644 index 0000000000..9501f68156 --- /dev/null +++ b/internal/translation/privateendpoint/conversion.go @@ -0,0 +1,475 @@ +/* +Copyright 2024 MongoDB. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package privateendpoint + +import ( + "strings" + + "go.mongodb.org/atlas-sdk/v20231115008/admin" + + "github.com/mongodb/mongodb-atlas-kubernetes/v2/internal/cmp" + "github.com/mongodb/mongodb-atlas-kubernetes/v2/internal/pointer" + akov2 "github.com/mongodb/mongodb-atlas-kubernetes/v2/pkg/api/v1" + "github.com/mongodb/mongodb-atlas-kubernetes/v2/pkg/api/v1/status" +) + +const ( + ProviderAWS = "AWS" + ProviderAzure = "AZURE" + ProviderGCP = "GCP" + + StatusInitiating = "INITIATING" + StatusPending = "PENDING" + StatusPendingAcceptance = "PENDING_ACCEPTANCE" + StatusWaitingForUser = "WAITING_FOR_USER" + StatusVerified = "VERIFIED" + StatusFailed = "FAILED" + StatusRejected = "REJECTED" + StatusDeleting = "DELETING" +) + +type EndpointService interface { + ServiceID() string + EndpointInterfaces() EndpointInterfaces + Provider() string + Region() string + Status() string + ErrorMessage() string +} + +type CommonEndpointService struct { + ID string + CloudRegion string + ServiceStatus string + Error string + Interfaces EndpointInterfaces +} + +func (s *CommonEndpointService) ServiceID() string { + return s.ID +} + +func (s *CommonEndpointService) EndpointInterfaces() EndpointInterfaces { + return s.Interfaces +} + +func (s *CommonEndpointService) Region() string { + return s.CloudRegion +} + +func (s *CommonEndpointService) Status() string { + return s.ServiceStatus +} + +func (s *CommonEndpointService) ErrorMessage() string { + return s.Error +} + +type AWSService struct { + CommonEndpointService + + ServiceName string +} + +func (s *AWSService) Provider() string { + return ProviderAWS +} + +type AzureService struct { + CommonEndpointService + ServiceName string + ResourceID string +} + +func (s *AzureService) Provider() string { + return ProviderAzure +} + +type GCPService struct { + CommonEndpointService + AttachmentNames []string +} + +func (s *GCPService) Provider() string { + return ProviderGCP +} + +type EndpointInterface interface { + InterfaceID() string + Status() string + ErrorMessage() string +} + +type CommonEndpointInterface struct { + ID string + InterfaceStatus string + Error string +} + +func (i *CommonEndpointInterface) InterfaceID() string { + return i.ID +} + +func (i *CommonEndpointInterface) Status() string { + return i.InterfaceStatus +} + +func (i *CommonEndpointInterface) ErrorMessage() string { + return i.Error +} + +type AWSInterface struct { + CommonEndpointInterface +} + +type AzureInterface struct { + CommonEndpointInterface + IP string + ConnectionName string +} + +type GCPInterface struct { + CommonEndpointInterface + Endpoints []GCPInterfaceEndpoint +} + +type GCPInterfaceEndpoint struct { + Name string + IP string + Status string +} + +type EndpointInterfaces []EndpointInterface + +func (ei EndpointInterfaces) Get(ID string) EndpointInterface { + if ei == nil { + return nil + } + + for _, i := range ei { + if i.InterfaceID() == ID { + return i + } + } + + return nil +} + +func NewPrivateEndpoint(akoPrivateEndpoint *akov2.AtlasPrivateEndpoint) EndpointService { + switch akoPrivateEndpoint.Spec.Provider { + case ProviderAWS: + return &AWSService{ + CommonEndpointService: CommonEndpointService{ + ID: akoPrivateEndpoint.Status.ServiceID, + CloudRegion: akoPrivateEndpoint.Spec.Region, + ServiceStatus: akoPrivateEndpoint.Status.ServiceStatus, + Error: akoPrivateEndpoint.Status.Error, + Interfaces: newPrivateEndpointInterface(akoPrivateEndpoint), + }, + ServiceName: akoPrivateEndpoint.Status.ServiceName, + } + case ProviderAzure: + return &AzureService{ + CommonEndpointService: CommonEndpointService{ + ID: akoPrivateEndpoint.Status.ServiceID, + CloudRegion: akoPrivateEndpoint.Spec.Region, + ServiceStatus: akoPrivateEndpoint.Status.ServiceStatus, + Error: akoPrivateEndpoint.Status.Error, + Interfaces: newPrivateEndpointInterface(akoPrivateEndpoint), + }, + ServiceName: akoPrivateEndpoint.Status.ServiceName, + ResourceID: akoPrivateEndpoint.Status.ResourceID, + } + case ProviderGCP: + return &GCPService{ + CommonEndpointService: CommonEndpointService{ + ID: akoPrivateEndpoint.Status.ServiceID, + CloudRegion: akoPrivateEndpoint.Spec.Region, + ServiceStatus: akoPrivateEndpoint.Status.ServiceStatus, + Error: akoPrivateEndpoint.Status.Error, + Interfaces: newPrivateEndpointInterface(akoPrivateEndpoint), + }, + AttachmentNames: akoPrivateEndpoint.Status.ServiceAttachmentNames, + } + } + + return nil +} + +func NewPrivateEndpointStatus(peService EndpointService) status.AtlasPrivateEndpointStatusOption { + return func(s *status.AtlasPrivateEndpointStatus) { + endpoints := make([]status.EndpointInterfaceStatus, 0, len(peService.EndpointInterfaces())) + for _, i := range peService.EndpointInterfaces() { + connName := "" + if azureInterface, ok := i.(*AzureInterface); ok { + connName = azureInterface.ConnectionName + } + + var gcpForwardRules []status.GCPForwardingRule + if gcpInterface, ok := i.(*GCPInterface); ok { + for _, fr := range gcpInterface.Endpoints { + gcpForwardRules = append( + gcpForwardRules, + status.GCPForwardingRule{ + Name: fr.Name, + Status: fr.Status, + }, + ) + } + } + + endpoints = append( + endpoints, + status.EndpointInterfaceStatus{ + ID: i.InterfaceID(), + ConnectionName: connName, + GCPForwardingRules: gcpForwardRules, + Status: i.Status(), + Error: i.ErrorMessage(), + }, + ) + } + + s.ServiceID = peService.ServiceID() + s.Endpoints = endpoints + s.ServiceStatus = peService.Status() + s.Error = peService.ErrorMessage() + + switch pe := peService.(type) { + case *AWSService: + s.ServiceName = pe.ServiceName + case *AzureService: + s.ServiceName = pe.ServiceName + s.ResourceID = pe.ResourceID + case *GCPService: + s.ServiceAttachmentNames = pe.AttachmentNames + } + } +} + +func newPrivateEndpointInterface(akoPrivateEndpoint *akov2.AtlasPrivateEndpoint) EndpointInterfaces { + endpoints := EndpointInterfaces{} + + switch akoPrivateEndpoint.Spec.Provider { + case ProviderAWS: + for _, endpoint := range akoPrivateEndpoint.Spec.AWSConfiguration { + ep := &AWSInterface{ + CommonEndpointInterface: CommonEndpointInterface{ + ID: endpoint.ID, + }, + } + + for _, epStatus := range akoPrivateEndpoint.Status.Endpoints { + if epStatus.ID == endpoint.ID { + ep.InterfaceStatus = epStatus.Status + ep.Error = epStatus.Error + } + } + + endpoints = append(endpoints, ep) + } + case ProviderAzure: + for _, endpoint := range akoPrivateEndpoint.Spec.AzureConfiguration { + ep := &AzureInterface{ + CommonEndpointInterface: CommonEndpointInterface{ + ID: endpoint.ID, + }, + IP: endpoint.IP, + } + + for _, epStatus := range akoPrivateEndpoint.Status.Endpoints { + if epStatus.ID == endpoint.ID { + ep.InterfaceStatus = epStatus.Status + ep.Error = epStatus.Error + ep.ConnectionName = epStatus.ConnectionName + } + } + + endpoints = append(endpoints, ep) + } + case ProviderGCP: + for _, endpoint := range akoPrivateEndpoint.Spec.GCPConfiguration { + gcpEPs := make([]GCPInterfaceEndpoint, 0, len(endpoint.Endpoints)) + for _, gcpEP := range endpoint.Endpoints { + gcpEPs = append( + gcpEPs, + GCPInterfaceEndpoint{ + Name: gcpEP.Name, + IP: gcpEP.IP, + }, + ) + } + + ep := &GCPInterface{ + CommonEndpointInterface: CommonEndpointInterface{ + ID: endpoint.GroupName, + }, + Endpoints: gcpEPs, + } + + for _, epStatus := range akoPrivateEndpoint.Status.Endpoints { + if epStatus.ID == endpoint.GroupName { + ep.InterfaceStatus = epStatus.Status + ep.Error = epStatus.Error + + for i, gcpEP := range ep.Endpoints { + for _, gcpEPStatus := range epStatus.GCPForwardingRules { + if gcpEP.Name == gcpEPStatus.Name { + ep.Endpoints[i].Status = gcpEPStatus.Status + } + } + } + } + } + + endpoints = append(endpoints, ep) + } + } + + cmp.NormalizeSlice(endpoints, func(a, b EndpointInterface) int { + return strings.Compare(a.InterfaceID(), b.InterfaceID()) + }) + + return endpoints +} + +func serviceFromAtlas(peService *admin.EndpointService, endpoints EndpointInterfaces) EndpointService { + switch peService.GetCloudProvider() { + case ProviderAWS: + return &AWSService{ + CommonEndpointService: CommonEndpointService{ + ID: peService.GetId(), + CloudRegion: peService.GetRegionName(), + ServiceStatus: peService.GetStatus(), + Error: peService.GetErrorMessage(), + Interfaces: endpoints, + }, + ServiceName: peService.GetEndpointServiceName(), + } + case ProviderAzure: + return &AzureService{ + CommonEndpointService: CommonEndpointService{ + ID: peService.GetId(), + CloudRegion: peService.GetRegionName(), + ServiceStatus: peService.GetStatus(), + Error: peService.GetErrorMessage(), + Interfaces: endpoints, + }, + ServiceName: peService.GetPrivateLinkServiceName(), + ResourceID: peService.GetPrivateLinkServiceResourceId(), + } + case ProviderGCP: + return &GCPService{ + CommonEndpointService: CommonEndpointService{ + ID: peService.GetId(), + CloudRegion: peService.GetRegionName(), + ServiceStatus: peService.GetStatus(), + Error: peService.GetErrorMessage(), + Interfaces: endpoints, + }, + AttachmentNames: peService.GetServiceAttachmentNames(), + } + } + + return nil +} + +func serviceCreateToAtlas(peService EndpointService) *admin.CloudProviderEndpointServiceRequest { + return &admin.CloudProviderEndpointServiceRequest{ + ProviderName: peService.Provider(), + Region: peService.Region(), + } +} + +func interfaceFromAtlas(peInterface *admin.PrivateLinkEndpoint) EndpointInterface { + switch peInterface.GetCloudProvider() { + case ProviderAWS: + return &AWSInterface{ + CommonEndpointInterface: CommonEndpointInterface{ + ID: peInterface.GetInterfaceEndpointId(), + InterfaceStatus: peInterface.GetConnectionStatus(), + Error: peInterface.GetErrorMessage(), + }, + } + case ProviderAzure: + return &AzureInterface{ + CommonEndpointInterface: CommonEndpointInterface{ + ID: peInterface.GetPrivateEndpointResourceId(), + InterfaceStatus: peInterface.GetStatus(), + Error: peInterface.GetErrorMessage(), + }, + IP: peInterface.GetPrivateEndpointIPAddress(), + ConnectionName: peInterface.GetPrivateEndpointConnectionName(), + } + case ProviderGCP: + endpoints := make([]GCPInterfaceEndpoint, 0, len(peInterface.GetEndpoints())) + for _, ep := range peInterface.GetEndpoints() { + endpoints = append( + endpoints, + GCPInterfaceEndpoint{ + Name: ep.GetEndpointName(), + IP: ep.GetIpAddress(), + Status: ep.GetStatus(), + }, + ) + } + + return &GCPInterface{ + CommonEndpointInterface: CommonEndpointInterface{ + ID: peInterface.GetEndpointGroupName(), + InterfaceStatus: peInterface.GetStatus(), + Error: peInterface.GetErrorMessage(), + }, + Endpoints: endpoints, + } + } + + return nil +} + +func interfaceCreateToAtlas(peInterface EndpointInterface, gcpProjectID string) *admin.CreateEndpointRequest { + switch i := peInterface.(type) { + case *AWSInterface: + return &admin.CreateEndpointRequest{ + Id: pointer.MakePtr(i.InterfaceID()), + } + case *AzureInterface: + return &admin.CreateEndpointRequest{ + Id: pointer.MakePtr(i.InterfaceID()), + PrivateEndpointIPAddress: pointer.MakePtr(i.IP), + } + case *GCPInterface: + gcpEPs := make([]admin.CreateGCPForwardingRuleRequest, 0, len(i.Endpoints)) + for _, ep := range i.Endpoints { + gcpEPs = append( + gcpEPs, + admin.CreateGCPForwardingRuleRequest{ + EndpointName: pointer.MakePtr(ep.Name), + IpAddress: pointer.MakePtr(ep.IP), + }, + ) + } + + return &admin.CreateEndpointRequest{ + GcpProjectId: pointer.MakePtr(gcpProjectID), + EndpointGroupName: pointer.MakePtr(i.InterfaceID()), + Endpoints: &gcpEPs, + } + } + + return nil +} diff --git a/internal/translation/privateendpoint/privateendpoint.go b/internal/translation/privateendpoint/privateendpoint.go new file mode 100644 index 0000000000..b2b3f17794 --- /dev/null +++ b/internal/translation/privateendpoint/privateendpoint.go @@ -0,0 +1,181 @@ +/* +Copyright 2024 MongoDB. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package privateendpoint + +import ( + "context" + "fmt" + "strings" + + "go.mongodb.org/atlas-sdk/v20231115008/admin" + "golang.org/x/exp/slices" +) + +const ( + ErrorServiceNotFound = "PRIVATE_ENDPOINT_SERVICE_NOT_FOUND" +) + +type PrivateEndpointService interface { + ListPrivateEndpoints(ctx context.Context, projectID, provider string) ([]EndpointService, error) + GetPrivateEndpoint(ctx context.Context, projectID, provider, ID string) (EndpointService, error) + CreatePrivateEndpointService(ctx context.Context, projectID string, peService EndpointService) (EndpointService, error) + DeleteEndpointService(ctx context.Context, projectID, provider, ID string) error + CreatePrivateEndpointInterface(ctx context.Context, projectID, provider, serviceID, gcpProjectID string, peInterface EndpointInterface) (EndpointInterface, error) + DeleteEndpointInterface(ctx context.Context, projectID, provider, serviceID, ID string) error +} + +type PrivateEndpoint struct { + api admin.PrivateEndpointServicesApi +} + +func (pe *PrivateEndpoint) ListPrivateEndpoints(ctx context.Context, projectID, provider string) ([]EndpointService, error) { + services, _, err := pe.api.ListPrivateEndpointServices(ctx, projectID, provider). + Execute() + if err != nil { + return nil, fmt.Errorf("failed to retrieve the list of private endpoints: %w", err) + } + + peServices := make([]EndpointService, 0, len(services)) + for _, service := range services { + interfaceIDs := getInterfacesIDs(&service) + peInterfaces := make([]EndpointInterface, 0, len(interfaceIDs)) + + for _, interfaceID := range interfaceIDs { + peInterface, err := pe.getEndpointInterfaces(ctx, projectID, provider, service.GetId(), interfaceID) + if err != nil { + return nil, err + } + + if peInterface != nil { + peInterfaces = append(peInterfaces, peInterface) + } + } + + slices.SortFunc(peInterfaces, func(a, b EndpointInterface) int { + return strings.Compare(a.InterfaceID(), b.InterfaceID()) + }) + + peServices = append( + peServices, + serviceFromAtlas(&service, peInterfaces), + ) + } + + return peServices, nil +} + +func (pe *PrivateEndpoint) GetPrivateEndpoint(ctx context.Context, projectID, provider, ID string) (EndpointService, error) { + service, _, err := pe.api.GetPrivateEndpointService(ctx, projectID, provider, ID). + Execute() + if admin.IsErrorCode(err, ErrorServiceNotFound) { + return nil, nil + } + + if err != nil { + return nil, fmt.Errorf("failed to retrieve the private endpoint: %w", err) + } + + interfaceIDs := getInterfacesIDs(service) + peInterfaces := make([]EndpointInterface, 0, len(interfaceIDs)) + + for _, interfaceID := range interfaceIDs { + peInterface, err := pe.getEndpointInterfaces(ctx, projectID, provider, service.GetId(), interfaceID) + if err != nil { + return nil, err + } + + if peInterface != nil { + peInterfaces = append(peInterfaces, peInterface) + } + } + + slices.SortFunc(peInterfaces, func(a, b EndpointInterface) int { + return strings.Compare(a.InterfaceID(), b.InterfaceID()) + }) + + return serviceFromAtlas(service, peInterfaces), nil +} + +func (pe *PrivateEndpoint) CreatePrivateEndpointService(ctx context.Context, projectID string, peService EndpointService) (EndpointService, error) { + service, _, err := pe.api.CreatePrivateEndpointService(ctx, projectID, serviceCreateToAtlas(peService)). + Execute() + if err != nil { + return nil, fmt.Errorf("failed to create the private endpoint service: %w", err) + } + + return serviceFromAtlas(service, []EndpointInterface{}), nil +} + +func (pe *PrivateEndpoint) DeleteEndpointService(ctx context.Context, projectID, provider, ID string) error { + _, _, err := pe.api.DeletePrivateEndpointService(ctx, projectID, provider, ID).Execute() + if err != nil { + return fmt.Errorf("failed to delete the private endpoint service: %w", err) + } + + return nil +} + +func (pe *PrivateEndpoint) CreatePrivateEndpointInterface(ctx context.Context, projectID, provider, serviceID, gcpProjectID string, peInterface EndpointInterface) (EndpointInterface, error) { + i, _, err := pe.api.CreatePrivateEndpoint(ctx, projectID, provider, serviceID, interfaceCreateToAtlas(peInterface, gcpProjectID)). + Execute() + if err != nil { + return nil, fmt.Errorf("failed to create the private endpoint interface: %w", err) + } + + return interfaceFromAtlas(i), nil +} + +func (pe *PrivateEndpoint) DeleteEndpointInterface(ctx context.Context, projectID, provider, serviceID, ID string) error { + _, _, err := pe.api.DeletePrivateEndpoint(ctx, projectID, provider, ID, serviceID).Execute() + if err != nil { + return fmt.Errorf("failed to delete the private endpoint interface: %w", err) + } + + return nil +} + +func (pe *PrivateEndpoint) getEndpointInterfaces(ctx context.Context, projectID, provider, serviceID, ID string) (EndpointInterface, error) { + i, _, err := pe.api.GetPrivateEndpoint(ctx, projectID, provider, ID, serviceID).Execute() + if admin.IsErrorCode(err, ErrorServiceNotFound) { + return nil, nil + } + + if err != nil { + return nil, fmt.Errorf("failed to retrieve the private endpoint interface: %w", err) + } + + return interfaceFromAtlas(i), nil +} + +func getInterfacesIDs(peService *admin.EndpointService) []string { + switch peService.GetCloudProvider() { + case ProviderAWS: + return peService.GetInterfaceEndpoints() + case ProviderAzure: + return peService.GetPrivateEndpoints() + case ProviderGCP: + return peService.GetEndpointGroupNames() + } + + return nil +} + +func NewPrivateEndpointAPI(api admin.PrivateEndpointServicesApi) PrivateEndpointService { + return &PrivateEndpoint{ + api: api, + } +} diff --git a/internal/translation/privateendpoint/privateendpoint_test.go b/internal/translation/privateendpoint/privateendpoint_test.go new file mode 100644 index 0000000000..7da6c404fd --- /dev/null +++ b/internal/translation/privateendpoint/privateendpoint_test.go @@ -0,0 +1,616 @@ +package privateendpoint + +import ( + "context" + "errors" + "fmt" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "go.mongodb.org/atlas-sdk/v20231115008/admin" + "go.mongodb.org/atlas-sdk/v20231115008/mockadmin" + + "github.com/mongodb/mongodb-atlas-kubernetes/v2/internal/pointer" +) + +func TestListPrivateEndpoints(t *testing.T) { + tests := map[string]struct { + provider string + mockListReturnFunc func() ([]admin.EndpointService, *http.Response, error) + mockInterfaceReturnFunc func() (*admin.PrivateLinkEndpoint, *http.Response, error) + expectedPEs []EndpointService + expectedErr error + }{ + "failed to retrieve data": { + provider: "AWS", + mockListReturnFunc: func() ([]admin.EndpointService, *http.Response, error) { + return nil, &http.Response{}, errors.New("atlas failed to list") + }, + expectedErr: fmt.Errorf("failed to retrieve the list of private endpoints: %w", errors.New("atlas failed to list")), + }, + "failed to retrieve existing interface for listed private endpoint service": { + provider: "AWS", + mockListReturnFunc: func() ([]admin.EndpointService, *http.Response, error) { + return []admin.EndpointService{ + { + CloudProvider: "AWS", + Id: pointer.MakePtr("pe-service-ID"), + RegionName: pointer.MakePtr("US_EAST_1"), + Status: pointer.MakePtr("AVAILABLE"), + EndpointServiceName: pointer.MakePtr("aws/pe-service/name"), + InterfaceEndpoints: &[]string{"vpcpe-123456"}, + }, + }, &http.Response{}, nil + }, + mockInterfaceReturnFunc: func() (*admin.PrivateLinkEndpoint, *http.Response, error) { + return nil, &http.Response{}, errors.New("atlas failed to get") + }, + expectedErr: fmt.Errorf("failed to retrieve the private endpoint interface: %w", errors.New("atlas failed to get")), + }, + "list AWS private endpoints": { + provider: "AWS", + mockListReturnFunc: func() ([]admin.EndpointService, *http.Response, error) { + return []admin.EndpointService{ + { + CloudProvider: "AWS", + Id: pointer.MakePtr("pe-service-ID-1"), + RegionName: pointer.MakePtr("US_EAST_1"), + Status: pointer.MakePtr("AVAILABLE"), + EndpointServiceName: pointer.MakePtr("aws/pe-service/name"), + InterfaceEndpoints: &[]string{"vpcpe-123456"}, + }, + { + CloudProvider: "AWS", + Id: pointer.MakePtr("pe-service-ID-2"), + RegionName: pointer.MakePtr("US_EAST_2"), + Status: pointer.MakePtr("AVAILABLE"), + EndpointServiceName: pointer.MakePtr("aws/pe-service/name"), + }, + }, &http.Response{}, nil + }, + mockInterfaceReturnFunc: func() (*admin.PrivateLinkEndpoint, *http.Response, error) { + return &admin.PrivateLinkEndpoint{ + CloudProvider: "AWS", + ConnectionStatus: pointer.MakePtr("AVAILABLE"), + InterfaceEndpointId: pointer.MakePtr("vpcpe-123456"), + }, &http.Response{}, nil + }, + expectedPEs: []EndpointService{ + &AWSService{ + CommonEndpointService: CommonEndpointService{ + ID: "pe-service-ID-1", + CloudRegion: "US_EAST_1", + ServiceStatus: "AVAILABLE", + Interfaces: EndpointInterfaces{ + &AWSInterface{ + CommonEndpointInterface{ + ID: "vpcpe-123456", + InterfaceStatus: "AVAILABLE", + }, + }, + }, + }, + ServiceName: "aws/pe-service/name", + }, + &AWSService{ + CommonEndpointService: CommonEndpointService{ + ID: "pe-service-ID-2", + CloudRegion: "US_EAST_2", + ServiceStatus: "AVAILABLE", + Interfaces: EndpointInterfaces{}, + }, + ServiceName: "aws/pe-service/name", + }, + }, + }, + } + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + ctx := context.Background() + projectID := "project-ID" + api := mockadmin.NewPrivateEndpointServicesApi(t) + api.EXPECT().ListPrivateEndpointServices(ctx, projectID, tt.provider). + Return(admin.ListPrivateEndpointServicesApiRequest{ApiService: api}) + api.EXPECT().ListPrivateEndpointServicesExecute(mock.AnythingOfType("admin.ListPrivateEndpointServicesApiRequest")). + Return(tt.mockListReturnFunc()) + + if tt.mockInterfaceReturnFunc != nil { + api.EXPECT().GetPrivateEndpoint(ctx, projectID, tt.provider, mock.AnythingOfType("string"), mock.AnythingOfType("string")). + Return(admin.GetPrivateEndpointApiRequest{ApiService: api}) + api.EXPECT().GetPrivateEndpointExecute(mock.AnythingOfType("admin.GetPrivateEndpointApiRequest")). + Return(tt.mockInterfaceReturnFunc()) + } + + pe := &PrivateEndpoint{ + api: api, + } + + items, err := pe.ListPrivateEndpoints(ctx, projectID, tt.provider) + assert.Equal(t, tt.expectedErr, err) + assert.Equal(t, tt.expectedPEs, items) + }) + } +} + +func TestGetPrivateEndpoint(t *testing.T) { + notFoundErr := &admin.GenericOpenAPIError{} + notFoundErr.SetModel(admin.ApiError{ErrorCode: pointer.MakePtr("PRIVATE_ENDPOINT_SERVICE_NOT_FOUND")}) + tests := map[string]struct { + provider string + mockGetReturnFunc func() (*admin.EndpointService, *http.Response, error) + mockInterfaceReturnFunc func() (*admin.PrivateLinkEndpoint, *http.Response, error) + expectedPE EndpointService + expectedErr error + }{ + "failed to retrieve data": { + provider: "AWS", + mockGetReturnFunc: func() (*admin.EndpointService, *http.Response, error) { + return nil, &http.Response{}, errors.New("atlas failed to get") + }, + expectedErr: fmt.Errorf("failed to retrieve the private endpoint: %w", errors.New("atlas failed to get")), + }, + "service was not found": { + provider: "AWS", + mockGetReturnFunc: func() (*admin.EndpointService, *http.Response, error) { + return nil, &http.Response{}, notFoundErr + }, + }, + "failed to get interface for the service": { + provider: "AZURE", + mockGetReturnFunc: func() (*admin.EndpointService, *http.Response, error) { + return &admin.EndpointService{ + CloudProvider: "AZURE", + Id: pointer.MakePtr("pe-service-ID"), + RegionName: pointer.MakePtr("GERMANY_NORTH"), + Status: pointer.MakePtr("AVAILABLE"), + PrivateEndpoints: &[]string{"long-azure-resource-ID"}, + PrivateLinkServiceName: pointer.MakePtr("pls_name"), + PrivateLinkServiceResourceId: pointer.MakePtr("long-azure-resource-ID"), + }, &http.Response{}, nil + }, + mockInterfaceReturnFunc: func() (*admin.PrivateLinkEndpoint, *http.Response, error) { + return nil, &http.Response{}, errors.New("atlas failed to get") + }, + expectedErr: fmt.Errorf("failed to retrieve the private endpoint interface: %w", errors.New("atlas failed to get")), + }, + "get AZURE private endpoint": { + provider: "AZURE", + mockGetReturnFunc: func() (*admin.EndpointService, *http.Response, error) { + return &admin.EndpointService{ + CloudProvider: "AZURE", + Id: pointer.MakePtr("pe-service-ID"), + RegionName: pointer.MakePtr("GERMANY_NORTH"), + Status: pointer.MakePtr("AVAILABLE"), + PrivateEndpoints: &[]string{"long-azure-resource-ID"}, + PrivateLinkServiceName: pointer.MakePtr("pls_name"), + PrivateLinkServiceResourceId: pointer.MakePtr("long-azure-resource-ID"), + }, &http.Response{}, nil + }, + mockInterfaceReturnFunc: func() (*admin.PrivateLinkEndpoint, *http.Response, error) { + return &admin.PrivateLinkEndpoint{ + CloudProvider: "AZURE", + PrivateEndpointConnectionName: pointer.MakePtr("atlas-resource-name"), + PrivateEndpointIPAddress: pointer.MakePtr("10.0.0.4"), + PrivateEndpointResourceId: pointer.MakePtr("long-azure-resource-ID"), + Status: pointer.MakePtr("AVAILABLE"), + }, &http.Response{}, nil + }, + expectedPE: &AzureService{ + CommonEndpointService: CommonEndpointService{ + ID: "pe-service-ID", + CloudRegion: "GERMANY_NORTH", + ServiceStatus: "AVAILABLE", + Interfaces: EndpointInterfaces{ + &AzureInterface{ + CommonEndpointInterface: CommonEndpointInterface{ + ID: "long-azure-resource-ID", + InterfaceStatus: "AVAILABLE", + }, + IP: "10.0.0.4", + ConnectionName: "atlas-resource-name", + }, + }, + }, + ServiceName: "pls_name", + ResourceID: "long-azure-resource-ID", + }, + }, + "get GCP private endpoint": { + provider: "GCP", + mockGetReturnFunc: func() (*admin.EndpointService, *http.Response, error) { + return &admin.EndpointService{ + CloudProvider: "GCP", + Id: pointer.MakePtr("pe-service-ID"), + RegionName: pointer.MakePtr("EUROPE_WEST_3"), + Status: pointer.MakePtr("AVAILABLE"), + EndpointGroupNames: &[]string{"group-name"}, + ServiceAttachmentNames: &[]string{"service/attachment1", "service/attachment2", "service/attachment3"}, + }, &http.Response{}, nil + }, + mockInterfaceReturnFunc: func() (*admin.PrivateLinkEndpoint, *http.Response, error) { + return &admin.PrivateLinkEndpoint{ + CloudProvider: "GCP", + Status: pointer.MakePtr("AVAILABLE"), + EndpointGroupName: pointer.MakePtr("group-name"), + Endpoints: &[]admin.GCPConsumerForwardingRule{ + { + EndpointName: pointer.MakePtr("group-name-pe-1"), + IpAddress: pointer.MakePtr("10.0.0.1"), + Status: pointer.MakePtr("AVAILABLE"), + }, + { + EndpointName: pointer.MakePtr("group-name-pe-2"), + IpAddress: pointer.MakePtr("10.0.0.3"), + Status: pointer.MakePtr("AVAILABLE"), + }, + { + EndpointName: pointer.MakePtr("group-name-pe-3"), + IpAddress: pointer.MakePtr("10.0.0.3"), + Status: pointer.MakePtr("AVAILABLE"), + }, + }, + }, &http.Response{}, nil + }, + expectedPE: &GCPService{ + CommonEndpointService: CommonEndpointService{ + ID: "pe-service-ID", + CloudRegion: "EUROPE_WEST_3", + ServiceStatus: "AVAILABLE", + Interfaces: EndpointInterfaces{ + &GCPInterface{ + CommonEndpointInterface: CommonEndpointInterface{ + ID: "group-name", + InterfaceStatus: "AVAILABLE", + }, + Endpoints: []GCPInterfaceEndpoint{ + { + Name: "group-name-pe-1", + IP: "10.0.0.1", + Status: "AVAILABLE", + }, + { + Name: "group-name-pe-2", + IP: "10.0.0.3", + Status: "AVAILABLE", + }, + { + Name: "group-name-pe-3", + IP: "10.0.0.3", + Status: "AVAILABLE", + }, + }, + }, + }, + }, + AttachmentNames: []string{"service/attachment1", "service/attachment2", "service/attachment3"}, + }, + }, + } + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + ctx := context.Background() + projectID := "project-ID" + api := mockadmin.NewPrivateEndpointServicesApi(t) + api.EXPECT().GetPrivateEndpointService(ctx, projectID, tt.provider, "pe-service-ID"). + Return(admin.GetPrivateEndpointServiceApiRequest{ApiService: api}) + api.EXPECT().GetPrivateEndpointServiceExecute(mock.AnythingOfType("admin.GetPrivateEndpointServiceApiRequest")). + Return(tt.mockGetReturnFunc()) + + if tt.mockInterfaceReturnFunc != nil { + api.EXPECT().GetPrivateEndpoint(ctx, projectID, tt.provider, mock.AnythingOfType("string"), mock.AnythingOfType("string")). + Return(admin.GetPrivateEndpointApiRequest{ApiService: api}) + api.EXPECT().GetPrivateEndpointExecute(mock.AnythingOfType("admin.GetPrivateEndpointApiRequest")). + Return(tt.mockInterfaceReturnFunc()) + } + + pe := &PrivateEndpoint{ + api: api, + } + + result, err := pe.GetPrivateEndpoint(ctx, projectID, tt.provider, "pe-service-ID") + assert.Equal(t, tt.expectedErr, err) + assert.Equal(t, tt.expectedPE, result) + }) + } +} + +func TestCreatePrivateEndpointService(t *testing.T) { + tests := map[string]struct { + service EndpointService + mockCreateReturnFunc func() (*admin.EndpointService, *http.Response, error) + expectedPE EndpointService + expectedErr error + }{ + "failed to create the service": { + service: &AWSService{ + CommonEndpointService: CommonEndpointService{ + CloudRegion: "US_EAST_1", + }, + }, + mockCreateReturnFunc: func() (*admin.EndpointService, *http.Response, error) { + return nil, &http.Response{}, errors.New("atlas failed to create") + }, + expectedErr: fmt.Errorf("failed to create the private endpoint service: %w", errors.New("atlas failed to create")), + }, + "create private endpoint service": { + service: &AWSService{ + CommonEndpointService: CommonEndpointService{ + CloudRegion: "US_EAST_1", + }, + }, + mockCreateReturnFunc: func() (*admin.EndpointService, *http.Response, error) { + return &admin.EndpointService{ + CloudProvider: "AWS", + Id: pointer.MakePtr("pe-service-ID"), + RegionName: pointer.MakePtr("US_EAST_1"), + Status: pointer.MakePtr("INITIALIZING"), + }, &http.Response{}, nil + }, + expectedPE: &AWSService{ + CommonEndpointService: CommonEndpointService{ + ID: "pe-service-ID", + CloudRegion: "US_EAST_1", + ServiceStatus: "INITIALIZING", + Interfaces: EndpointInterfaces{}, + }, + }, + }, + } + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + ctx := context.Background() + projectID := "project-ID" + api := mockadmin.NewPrivateEndpointServicesApi(t) + api.EXPECT().CreatePrivateEndpointService(ctx, projectID, mock.AnythingOfType("*admin.CloudProviderEndpointServiceRequest")). + Return(admin.CreatePrivateEndpointServiceApiRequest{ApiService: api}) + api.EXPECT().CreatePrivateEndpointServiceExecute(mock.AnythingOfType("admin.CreatePrivateEndpointServiceApiRequest")). + Return(tt.mockCreateReturnFunc()) + + pe := &PrivateEndpoint{ + api: api, + } + + result, err := pe.CreatePrivateEndpointService(ctx, projectID, tt.service) + assert.Equal(t, tt.expectedErr, err) + assert.Equal(t, tt.expectedPE, result) + }) + } +} + +func TestDeletePrivateEndpointService(t *testing.T) { + tests := map[string]struct { + mockDeleteReturnFunc func() (map[string]interface{}, *http.Response, error) + expectedErr error + }{ + "failed to delete the service": { + mockDeleteReturnFunc: func() (map[string]interface{}, *http.Response, error) { + return nil, &http.Response{}, errors.New("atlas failed to delete") + }, + expectedErr: fmt.Errorf("failed to delete the private endpoint service: %w", errors.New("atlas failed to delete")), + }, + "delete private endpoint service": { + mockDeleteReturnFunc: func() (map[string]interface{}, *http.Response, error) { + return nil, &http.Response{}, nil + }, + }, + } + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + ctx := context.Background() + projectID := "project-ID" + api := mockadmin.NewPrivateEndpointServicesApi(t) + api.EXPECT().DeletePrivateEndpointService(ctx, projectID, "AWS", "pe-service-ID"). + Return(admin.DeletePrivateEndpointServiceApiRequest{ApiService: api}) + api.EXPECT().DeletePrivateEndpointServiceExecute(mock.AnythingOfType("admin.DeletePrivateEndpointServiceApiRequest")). + Return(tt.mockDeleteReturnFunc()) + + pe := &PrivateEndpoint{ + api: api, + } + + err := pe.DeleteEndpointService(ctx, projectID, "AWS", "pe-service-ID") + assert.Equal(t, tt.expectedErr, err) + }) + } +} + +func TestCreatePrivateEndpointInterface(t *testing.T) { + tests := map[string]struct { + provider string + gcpProjectID string + endpointInterface EndpointInterface + mockCreateReturnFunc func() (*admin.PrivateLinkEndpoint, *http.Response, error) + expectedPE EndpointInterface + expectedErr error + }{ + "failed to create the endpoint interface": { + provider: "AWS", + endpointInterface: &AWSInterface{ + CommonEndpointInterface{ + ID: "vpcpe-123456", + }, + }, + mockCreateReturnFunc: func() (*admin.PrivateLinkEndpoint, *http.Response, error) { + return nil, &http.Response{}, errors.New("atlas failed to create") + }, + expectedErr: fmt.Errorf("failed to create the private endpoint interface: %w", errors.New("atlas failed to create")), + }, + "create AWS private endpoint": { + provider: "AWS", + endpointInterface: &AWSInterface{ + CommonEndpointInterface{ + ID: "vpcpe-123456", + }, + }, + mockCreateReturnFunc: func() (*admin.PrivateLinkEndpoint, *http.Response, error) { + return &admin.PrivateLinkEndpoint{ + CloudProvider: "AWS", + InterfaceEndpointId: pointer.MakePtr("vpcpe-123456"), + ConnectionStatus: pointer.MakePtr("INITIALIZING"), + }, &http.Response{}, nil + }, + expectedPE: &AWSInterface{ + CommonEndpointInterface{ + ID: "vpcpe-123456", + InterfaceStatus: "INITIALIZING", + }, + }, + }, + "create AZURE private endpoint": { + provider: "AZURE", + endpointInterface: &AzureInterface{ + CommonEndpointInterface: CommonEndpointInterface{ + ID: "long-azure-resource-ID", + InterfaceStatus: "INITIALIZING", + }, + IP: "10.0.0.2", + }, + mockCreateReturnFunc: func() (*admin.PrivateLinkEndpoint, *http.Response, error) { + return &admin.PrivateLinkEndpoint{ + CloudProvider: "AZURE", + PrivateEndpointResourceId: pointer.MakePtr("long-azure-resource-ID"), + PrivateEndpointIPAddress: pointer.MakePtr("10.0.0.2"), + PrivateEndpointConnectionName: pointer.MakePtr("atlas-resource-name"), + Status: pointer.MakePtr("INITIALIZING"), + }, &http.Response{}, nil + }, + expectedPE: &AzureInterface{ + CommonEndpointInterface: CommonEndpointInterface{ + ID: "long-azure-resource-ID", + InterfaceStatus: "INITIALIZING", + }, + IP: "10.0.0.2", + ConnectionName: "atlas-resource-name", + }, + }, + "create GCP private endpoint": { + provider: "GCP", + gcpProjectID: "customer-project-ID", + endpointInterface: &GCPInterface{ + CommonEndpointInterface: CommonEndpointInterface{ + ID: "group-name", + }, + Endpoints: []GCPInterfaceEndpoint{ + { + Name: "group-name-pe-1", + IP: "10.0.0.1", + }, + { + Name: "group-name-pe-2", + IP: "10.0.0.2", + }, + { + Name: "group-name-pe-3", + IP: "10.0.0.3", + }, + }, + }, + mockCreateReturnFunc: func() (*admin.PrivateLinkEndpoint, *http.Response, error) { + return &admin.PrivateLinkEndpoint{ + CloudProvider: "GCP", + EndpointGroupName: pointer.MakePtr("group-name"), + Endpoints: &[]admin.GCPConsumerForwardingRule{ + { + EndpointName: pointer.MakePtr("group-name-pe-1"), + IpAddress: pointer.MakePtr("10.0.0.1"), + Status: pointer.MakePtr("INITIALIZING"), + }, + { + EndpointName: pointer.MakePtr("group-name-pe-2"), + IpAddress: pointer.MakePtr("10.0.0.2"), + Status: pointer.MakePtr("INITIALIZING"), + }, + { + EndpointName: pointer.MakePtr("group-name-pe-3"), + IpAddress: pointer.MakePtr("10.0.0.3"), + Status: pointer.MakePtr("INITIALIZING"), + }, + }, + Status: pointer.MakePtr("INITIALIZING"), + }, &http.Response{}, nil + }, + expectedPE: &GCPInterface{ + CommonEndpointInterface: CommonEndpointInterface{ + ID: "group-name", + InterfaceStatus: "INITIALIZING", + }, + Endpoints: []GCPInterfaceEndpoint{ + { + Name: "group-name-pe-1", + IP: "10.0.0.1", + Status: "INITIALIZING", + }, + { + Name: "group-name-pe-2", + IP: "10.0.0.2", + Status: "INITIALIZING", + }, + { + Name: "group-name-pe-3", + IP: "10.0.0.3", + Status: "INITIALIZING", + }, + }, + }, + }, + } + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + ctx := context.Background() + projectID := "project-ID" + serviceID := "pe-service-ID" + api := mockadmin.NewPrivateEndpointServicesApi(t) + api.EXPECT().CreatePrivateEndpoint(ctx, projectID, tt.provider, serviceID, mock.AnythingOfType("*admin.CreateEndpointRequest")). + Return(admin.CreatePrivateEndpointApiRequest{ApiService: api}) + api.EXPECT().CreatePrivateEndpointExecute(mock.AnythingOfType("admin.CreatePrivateEndpointApiRequest")). + Return(tt.mockCreateReturnFunc()) + + pe := &PrivateEndpoint{ + api: api, + } + + result, err := pe.CreatePrivateEndpointInterface(ctx, projectID, tt.provider, serviceID, tt.gcpProjectID, tt.endpointInterface) + assert.Equal(t, tt.expectedErr, err) + assert.Equal(t, tt.expectedPE, result) + }) + } +} + +func TestDeletePrivateEndpointInterface(t *testing.T) { + tests := map[string]struct { + mockDeleteReturnFunc func() (map[string]interface{}, *http.Response, error) + expectedErr error + }{ + "failed to delete the interface": { + mockDeleteReturnFunc: func() (map[string]interface{}, *http.Response, error) { + return nil, &http.Response{}, errors.New("atlas failed to delete") + }, + expectedErr: fmt.Errorf("failed to delete the private endpoint interface: %w", errors.New("atlas failed to delete")), + }, + "delete private endpoint service": { + mockDeleteReturnFunc: func() (map[string]interface{}, *http.Response, error) { + return nil, &http.Response{}, nil + }, + }, + } + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + ctx := context.Background() + projectID := "project-ID" + api := mockadmin.NewPrivateEndpointServicesApi(t) + api.EXPECT().DeletePrivateEndpoint(ctx, projectID, "AWS", "endpoint-ID", "pe-service-ID"). + Return(admin.DeletePrivateEndpointApiRequest{ApiService: api}) + api.EXPECT().DeletePrivateEndpointExecute(mock.AnythingOfType("admin.DeletePrivateEndpointApiRequest")). + Return(tt.mockDeleteReturnFunc()) + + pe := &PrivateEndpoint{ + api: api, + } + + err := pe.DeleteEndpointInterface(ctx, projectID, "AWS", "pe-service-ID", "endpoint-ID") + assert.Equal(t, tt.expectedErr, err) + }) + } +}