diff --git a/cmd/amazon-cloudwatch-agent-target-allocator/config/config.go b/cmd/amazon-cloudwatch-agent-target-allocator/config/config.go index ce2a1841..26592946 100644 --- a/cmd/amazon-cloudwatch-agent-target-allocator/config/config.go +++ b/cmd/amazon-cloudwatch-agent-target-allocator/config/config.go @@ -11,9 +11,10 @@ import ( "fmt" "io/fs" "os" - "sigs.k8s.io/controller-runtime/pkg/certwatcher" "time" + "sigs.k8s.io/controller-runtime/pkg/certwatcher" + "github.com/go-logr/logr" "github.com/prometheus/common/model" promconfig "github.com/prometheus/prometheus/config" diff --git a/cmd/amazon-cloudwatch-agent-target-allocator/server/server.go b/cmd/amazon-cloudwatch-agent-target-allocator/server/server.go index cabee44d..ed81a955 100644 --- a/cmd/amazon-cloudwatch-agent-target-allocator/server/server.go +++ b/cmd/amazon-cloudwatch-agent-target-allocator/server/server.go @@ -73,6 +73,8 @@ func WithTLSConfig(tlsConfig *tls.Config, httpsListenAddr string) Option { s.setRouter(httpsRouter) s.httpsServer = &http.Server{Addr: httpsListenAddr, Handler: httpsRouter, ReadHeaderTimeout: 90 * time.Second, TLSConfig: tlsConfig} + s.server.Shutdown(context.Background()) + s.server = s.httpsServer } } diff --git a/cmd/amazon-cloudwatch-agent-target-allocator/server/server_test.go b/cmd/amazon-cloudwatch-agent-target-allocator/server/server_test.go index dc27bf21..9e082219 100644 --- a/cmd/amazon-cloudwatch-agent-target-allocator/server/server_test.go +++ b/cmd/amazon-cloudwatch-agent-target-allocator/server/server_test.go @@ -4,13 +4,22 @@ package server import ( + "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" "encoding/json" + "encoding/pem" "fmt" "io" + "math/big" "net/http" "net/http/httptest" "net/url" + "os" "testing" "time" @@ -185,7 +194,7 @@ func TestServer_TargetsHandler(t *testing.T) { func TestServer_ScrapeConfigsHandler(t *testing.T) { svrConfig := allocatorconfig.HTTPSServerConfig{} - tlsConfig, _ := svrConfig.NewTLSConfig() + tlsConfig, _ := svrConfig.NewTLSConfig(context.TODO()) tests := []struct { description string scrapeConfigs map[string]*promconfig.ScrapeConfig @@ -605,6 +614,7 @@ func TestServer_JobHandler(t *testing.T) { }) } } + func TestServer_Readiness(t *testing.T) { tests := []struct { description string @@ -669,6 +679,269 @@ func TestServer_Readiness(t *testing.T) { } } +func TestServer_ValidCAonTLS(t *testing.T) { + listenAddr := ":8443" + server, clientTlsConfig, err := createTestTLSServer(listenAddr) + assert.NoError(t, err) + go func() { + assert.ErrorIs(t, server.StartHTTPS(), http.ErrServerClosed) + }() + time.Sleep(100 * time.Millisecond) // wait for server to launch + defer func() { + err := server.ShutdownHTTPS(context.Background()) + if err != nil { + assert.NoError(t, err) + } + }() + tests := []struct { + description string + endpoint string + expectedCode int + }{ + { + description: "with tls test for scrape config", + endpoint: "scrape_configs", + expectedCode: http.StatusOK, + }, + { + description: "with tls test for jobs", + endpoint: "jobs", + expectedCode: http.StatusOK, + }, + } + for _, tc := range tests { + t.Run(tc.description, func(t *testing.T) { + // Create a custom HTTP client with TLS transport + client := &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: clientTlsConfig, + }, + } + + // Make the GET request + request, err := client.Get(fmt.Sprintf("https://localhost%s/%s", listenAddr, tc.endpoint)) + + // Verify if a certificate verification error occurred + require.NoError(t, err) + + // Only check the status code if there was no error + if err == nil { + assert.Equal(t, tc.expectedCode, request.StatusCode) + } else { + t.Log(err) + } + }) + } +} + +func TestServer_MissingCAonTLS(t *testing.T) { + listenAddr := ":8443" + server, _, err := createTestTLSServer(listenAddr) + assert.NoError(t, err) + go func() { + assert.ErrorIs(t, server.StartHTTPS(), http.ErrServerClosed) + }() + time.Sleep(100 * time.Millisecond) // wait for server to launch + defer func() { + err := server.ShutdownHTTPS(context.Background()) + if err != nil { + assert.NoError(t, err) + } + }() + tests := []struct { + description string + endpoint string + expectedCode int + }{ + { + description: "no tls test for scrape config", + endpoint: "scrape_configs", + expectedCode: http.StatusBadRequest, + }, + { + description: "no tls test for jobs", + endpoint: "jobs", + expectedCode: http.StatusBadRequest, + }, + } + for _, tc := range tests { + t.Run(tc.description, func(t *testing.T) { + request, err := http.Get(fmt.Sprintf("https://localhost%s/%s", listenAddr, tc.endpoint)) + + // Verify if a certificate verification error occurred + require.Error(t, err) + + // Only check the status code if there was no error + if err == nil { + assert.Equal(t, tc.expectedCode, request.StatusCode) + } + }) + } +} + +func TestServer_HTTPOnTLS(t *testing.T) { + listenAddr := ":8443" + server, _, err := createTestTLSServer(listenAddr) + assert.NoError(t, err) + go func() { + assert.NoError(t, server.StartHTTPS()) + }() + time.Sleep(100 * time.Millisecond) // wait for server to launch + + defer func(s *Server, ctx context.Context) { + err := s.Shutdown(ctx) + if err != nil { + assert.NoError(t, err) + } + }(server, context.Background()) + tests := []struct { + description string + endpoint string + expectedCode int + }{ + { + description: "no tls test for scrape config", + endpoint: "scrape_configs", + expectedCode: http.StatusBadRequest, + }, + { + description: "no tls test for jobs", + endpoint: "jobs", + expectedCode: http.StatusBadRequest, + }, + } + for _, tc := range tests { + t.Run(tc.description, func(t *testing.T) { + request, err := http.Get(fmt.Sprintf("http://localhost%s/%s", listenAddr, tc.endpoint)) + + // Only check the status code if there was no error + if err == nil { + assert.Equal(t, tc.expectedCode, request.StatusCode) + } + }) + } +} + +func createTestTLSServer(listenAddr string) (*Server, *tls.Config, error) { + //testing using this function replicates customer environment + svrConfig := allocatorconfig.HTTPSServerConfig{} + caBundle, caCert, caKey, err := generateTestingCerts() + if err != nil { + return nil, nil, err + } + svrConfig.TLSKeyFilePath = caKey + svrConfig.TLSCertFilePath = caCert + tlsConfig, err := svrConfig.NewTLSConfig(context.TODO()) + if err != nil { + return nil, nil, err + } + httpOptions := []Option{} + httpOptions = append(httpOptions, WithTLSConfig(tlsConfig, listenAddr)) + + //generate ca bundle + bundle, err := readCABundle(caBundle) + if err != nil { + return nil, nil, err + } + allocator := &mockAllocator{targetItems: map[string]*target.Item{ + "a": target.NewItem("job1", "", model.LabelSet{}, ""), + }} + + return NewServer(logger, allocator, listenAddr, httpOptions...), bundle, nil +} + func newLink(jobName string) target.LinkJSON { return target.LinkJSON{Link: fmt.Sprintf("/jobs/%s/targets", url.QueryEscape(jobName))} } + +func readCABundle(caBundlePath string) (*tls.Config, error) { + // Load the CA bundle + caCert, err := os.ReadFile(caBundlePath) + if err != nil { + return nil, fmt.Errorf("failed to read CA bundle: %w", err) + } + + // Create a CA pool and add the CA certificate(s) + caCertPool := x509.NewCertPool() + if !caCertPool.AppendCertsFromPEM(caCert) { + return nil, fmt.Errorf("failed to add CA certificates to pool") + } + + // Set up TLS configuration with the CA pool + tlsConfig := &tls.Config{ + RootCAs: caCertPool, + } + return tlsConfig, nil +} + +func generateTestingCerts() (caBundlePath, caCertPath, caKeyPath string, err error) { + // Generate private key + privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return "", "", "", fmt.Errorf("error generating private key: %w", err) + } + + // Set up certificate template + template := x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + CommonName: "localhost", + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(365 * 24 * time.Hour), // 1 year validity + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{ + x509.ExtKeyUsageServerAuth, + }, + DNSNames: []string{"localhost"}, + } + + // Self-sign the certificate + certBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey) + if err != nil { + return "", "", "", fmt.Errorf("error creating certificate: %w", err) + } + + // Create temporary files + tempDir := os.TempDir() + + caCertFile, err := os.CreateTemp(tempDir, "ca-cert-*.crt") + if err != nil { + return "", "", "", fmt.Errorf("error creating temp CA cert file: %w", err) + } + defer caCertFile.Close() + + caKeyFile, err := os.CreateTemp(tempDir, "ca-key-*.key") + if err != nil { + return "", "", "", fmt.Errorf("error creating temp CA key file: %w", err) + } + defer caKeyFile.Close() + + caBundleFile, err := os.CreateTemp(tempDir, "ca-bundle-*.crt") + if err != nil { + return "", "", "", fmt.Errorf("error creating temp CA bundle file: %w", err) + } + defer caBundleFile.Close() + + // Write the private key to the key file + privateKeyBytes, err := x509.MarshalECPrivateKey(privateKey) + if err != nil { + return "", "", "", fmt.Errorf("error writing private key: %w", err) + } + err = pem.Encode(caKeyFile, &pem.Block{Type: "EC PRIVATE KEY", Bytes: privateKeyBytes}) + if err != nil { + return "", "", "", fmt.Errorf("error writing private key: %w", err) + } + + // Write the certificate to the certificate and bundle files + certPEM := &pem.Block{Type: "CERTIFICATE", Bytes: certBytes} + if err = pem.Encode(caCertFile, certPEM); err != nil { + return "", "", "", fmt.Errorf("error writing certificate: %w", err) + } + if err = pem.Encode(caBundleFile, certPEM); err != nil { + return "", "", "", fmt.Errorf("error writing bundle certificate: %w", err) + } + + // Return the file paths + return caBundleFile.Name(), caCertFile.Name(), caKeyFile.Name(), nil +}