Skip to content

Commit

Permalink
[TA] Target Allocator TLS Unit-tests (#265)
Browse files Browse the repository at this point in the history
* TLS tests
  • Loading branch information
okankoAMZ authored Oct 31, 2024
1 parent 31db083 commit 7ae202e
Show file tree
Hide file tree
Showing 3 changed files with 278 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@ import (
"fmt"
"io/fs"
"os"
"sigs.k8s.io/controller-runtime/pkg/certwatcher"
"time"

"sigs.k8s.io/controller-runtime/pkg/certwatcher"

"github.com/go-logr/logr"
"github.com/prometheus/common/model"
promconfig "github.com/prometheus/prometheus/config"
Expand Down
2 changes: 2 additions & 0 deletions cmd/amazon-cloudwatch-agent-target-allocator/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ func WithTLSConfig(tlsConfig *tls.Config, httpsListenAddr string) Option {
s.setRouter(httpsRouter)

s.httpsServer = &http.Server{Addr: httpsListenAddr, Handler: httpsRouter, ReadHeaderTimeout: 90 * time.Second, TLSConfig: tlsConfig}
s.server.Shutdown(context.Background())
s.server = s.httpsServer
}
}

Expand Down
275 changes: 274 additions & 1 deletion cmd/amazon-cloudwatch-agent-target-allocator/server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,22 @@
package server

import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/json"
"encoding/pem"
"fmt"
"io"
"math/big"
"net/http"
"net/http/httptest"
"net/url"
"os"
"testing"
"time"

Expand Down Expand Up @@ -185,7 +194,7 @@ func TestServer_TargetsHandler(t *testing.T) {

func TestServer_ScrapeConfigsHandler(t *testing.T) {
svrConfig := allocatorconfig.HTTPSServerConfig{}
tlsConfig, _ := svrConfig.NewTLSConfig()
tlsConfig, _ := svrConfig.NewTLSConfig(context.TODO())
tests := []struct {
description string
scrapeConfigs map[string]*promconfig.ScrapeConfig
Expand Down Expand Up @@ -605,6 +614,7 @@ func TestServer_JobHandler(t *testing.T) {
})
}
}

func TestServer_Readiness(t *testing.T) {
tests := []struct {
description string
Expand Down Expand Up @@ -669,6 +679,269 @@ func TestServer_Readiness(t *testing.T) {
}
}

func TestServer_ValidCAonTLS(t *testing.T) {
listenAddr := ":8443"
server, clientTlsConfig, err := createTestTLSServer(listenAddr)
assert.NoError(t, err)
go func() {
assert.ErrorIs(t, server.StartHTTPS(), http.ErrServerClosed)
}()
time.Sleep(100 * time.Millisecond) // wait for server to launch
defer func() {
err := server.ShutdownHTTPS(context.Background())
if err != nil {
assert.NoError(t, err)
}
}()
tests := []struct {
description string
endpoint string
expectedCode int
}{
{
description: "with tls test for scrape config",
endpoint: "scrape_configs",
expectedCode: http.StatusOK,
},
{
description: "with tls test for jobs",
endpoint: "jobs",
expectedCode: http.StatusOK,
},
}
for _, tc := range tests {
t.Run(tc.description, func(t *testing.T) {
// Create a custom HTTP client with TLS transport
client := &http.Client{
Transport: &http.Transport{
TLSClientConfig: clientTlsConfig,
},
}

// Make the GET request
request, err := client.Get(fmt.Sprintf("https://localhost%s/%s", listenAddr, tc.endpoint))

// Verify if a certificate verification error occurred
require.NoError(t, err)

// Only check the status code if there was no error
if err == nil {
assert.Equal(t, tc.expectedCode, request.StatusCode)
} else {
t.Log(err)
}
})
}
}

func TestServer_MissingCAonTLS(t *testing.T) {
listenAddr := ":8443"
server, _, err := createTestTLSServer(listenAddr)
assert.NoError(t, err)
go func() {
assert.ErrorIs(t, server.StartHTTPS(), http.ErrServerClosed)
}()
time.Sleep(100 * time.Millisecond) // wait for server to launch
defer func() {
err := server.ShutdownHTTPS(context.Background())
if err != nil {
assert.NoError(t, err)
}
}()
tests := []struct {
description string
endpoint string
expectedCode int
}{
{
description: "no tls test for scrape config",
endpoint: "scrape_configs",
expectedCode: http.StatusBadRequest,
},
{
description: "no tls test for jobs",
endpoint: "jobs",
expectedCode: http.StatusBadRequest,
},
}
for _, tc := range tests {
t.Run(tc.description, func(t *testing.T) {
request, err := http.Get(fmt.Sprintf("https://localhost%s/%s", listenAddr, tc.endpoint))

// Verify if a certificate verification error occurred
require.Error(t, err)

// Only check the status code if there was no error
if err == nil {
assert.Equal(t, tc.expectedCode, request.StatusCode)
}
})
}
}

func TestServer_HTTPOnTLS(t *testing.T) {
listenAddr := ":8443"
server, _, err := createTestTLSServer(listenAddr)
assert.NoError(t, err)
go func() {
assert.NoError(t, server.StartHTTPS())
}()
time.Sleep(100 * time.Millisecond) // wait for server to launch

defer func(s *Server, ctx context.Context) {
err := s.Shutdown(ctx)
if err != nil {
assert.NoError(t, err)
}
}(server, context.Background())
tests := []struct {
description string
endpoint string
expectedCode int
}{
{
description: "no tls test for scrape config",
endpoint: "scrape_configs",
expectedCode: http.StatusBadRequest,
},
{
description: "no tls test for jobs",
endpoint: "jobs",
expectedCode: http.StatusBadRequest,
},
}
for _, tc := range tests {
t.Run(tc.description, func(t *testing.T) {
request, err := http.Get(fmt.Sprintf("http://localhost%s/%s", listenAddr, tc.endpoint))

// Only check the status code if there was no error
if err == nil {
assert.Equal(t, tc.expectedCode, request.StatusCode)
}
})
}
}

func createTestTLSServer(listenAddr string) (*Server, *tls.Config, error) {
//testing using this function replicates customer environment
svrConfig := allocatorconfig.HTTPSServerConfig{}
caBundle, caCert, caKey, err := generateTestingCerts()
if err != nil {
return nil, nil, err
}
svrConfig.TLSKeyFilePath = caKey
svrConfig.TLSCertFilePath = caCert
tlsConfig, err := svrConfig.NewTLSConfig(context.TODO())
if err != nil {
return nil, nil, err
}
httpOptions := []Option{}
httpOptions = append(httpOptions, WithTLSConfig(tlsConfig, listenAddr))

//generate ca bundle
bundle, err := readCABundle(caBundle)
if err != nil {
return nil, nil, err
}
allocator := &mockAllocator{targetItems: map[string]*target.Item{
"a": target.NewItem("job1", "", model.LabelSet{}, ""),
}}

return NewServer(logger, allocator, listenAddr, httpOptions...), bundle, nil
}

func newLink(jobName string) target.LinkJSON {
return target.LinkJSON{Link: fmt.Sprintf("/jobs/%s/targets", url.QueryEscape(jobName))}
}

func readCABundle(caBundlePath string) (*tls.Config, error) {
// Load the CA bundle
caCert, err := os.ReadFile(caBundlePath)
if err != nil {
return nil, fmt.Errorf("failed to read CA bundle: %w", err)
}

// Create a CA pool and add the CA certificate(s)
caCertPool := x509.NewCertPool()
if !caCertPool.AppendCertsFromPEM(caCert) {
return nil, fmt.Errorf("failed to add CA certificates to pool")
}

// Set up TLS configuration with the CA pool
tlsConfig := &tls.Config{
RootCAs: caCertPool,
}
return tlsConfig, nil
}

func generateTestingCerts() (caBundlePath, caCertPath, caKeyPath string, err error) {
// Generate private key
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return "", "", "", fmt.Errorf("error generating private key: %w", err)
}

// Set up certificate template
template := x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
CommonName: "localhost",
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(365 * 24 * time.Hour), // 1 year validity
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{
x509.ExtKeyUsageServerAuth,
},
DNSNames: []string{"localhost"},
}

// Self-sign the certificate
certBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
if err != nil {
return "", "", "", fmt.Errorf("error creating certificate: %w", err)
}

// Create temporary files
tempDir := os.TempDir()

caCertFile, err := os.CreateTemp(tempDir, "ca-cert-*.crt")
if err != nil {
return "", "", "", fmt.Errorf("error creating temp CA cert file: %w", err)
}
defer caCertFile.Close()

caKeyFile, err := os.CreateTemp(tempDir, "ca-key-*.key")
if err != nil {
return "", "", "", fmt.Errorf("error creating temp CA key file: %w", err)
}
defer caKeyFile.Close()

caBundleFile, err := os.CreateTemp(tempDir, "ca-bundle-*.crt")
if err != nil {
return "", "", "", fmt.Errorf("error creating temp CA bundle file: %w", err)
}
defer caBundleFile.Close()

// Write the private key to the key file
privateKeyBytes, err := x509.MarshalECPrivateKey(privateKey)
if err != nil {
return "", "", "", fmt.Errorf("error writing private key: %w", err)
}
err = pem.Encode(caKeyFile, &pem.Block{Type: "EC PRIVATE KEY", Bytes: privateKeyBytes})
if err != nil {
return "", "", "", fmt.Errorf("error writing private key: %w", err)
}

// Write the certificate to the certificate and bundle files
certPEM := &pem.Block{Type: "CERTIFICATE", Bytes: certBytes}
if err = pem.Encode(caCertFile, certPEM); err != nil {
return "", "", "", fmt.Errorf("error writing certificate: %w", err)
}
if err = pem.Encode(caBundleFile, certPEM); err != nil {
return "", "", "", fmt.Errorf("error writing bundle certificate: %w", err)
}

// Return the file paths
return caBundleFile.Name(), caCertFile.Name(), caKeyFile.Name(), nil
}

0 comments on commit 7ae202e

Please sign in to comment.