Skip to content

Commit

Permalink
Improve test coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
cdoucy committed Oct 26, 2024
1 parent a79d65f commit c6a6d2c
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 14 deletions.
61 changes: 47 additions & 14 deletions cert_watcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,14 @@ func TestClient_SetRootCertificateWatcher(t *testing.T) {
generateCerts(t, paths)
startHTTPSServer(fmt.Sprintf(":%d", port), paths)

poolingInterval := time.Second * 1

//client := New().SetRootCertificate(paths.RootCACert).SetDebug(true)
client := New().SetRootCertificateWatcher(paths.RootCACert, &CertWatcherOptions{
PoolInterval: time.Second * 1,
}).SetDebug(true)
PoolInterval: poolingInterval,
}).SetClientRootCertificateWatcher(paths.RootCACert, &CertWatcherOptions{
PoolInterval: poolingInterval,
}).SetDebug(false)

tr, err := client.Transport()
if err != nil {
Expand All @@ -63,21 +67,50 @@ func TestClient_SetRootCertificateWatcher(t *testing.T) {

url := fmt.Sprintf("https://localhost:%d/", port)

for i := 0; i < 5; i++ {
t.Logf("i = %d", i)
res, err := client.R().Get(url)
if err != nil {
t.Fatal(err)
}
t.Run("Cert Watcher should handle certs rotation", func(t *testing.T) {
for i := 0; i < 5; i++ {
res, err := client.R().Get(url)
if err != nil {
t.Fatal(err)
}

assertEqual(t, res.StatusCode(), http.StatusOK)
assertEqual(t, res.StatusCode(), http.StatusOK)

if i%2 == 1 {
// Re-generate certs to simulate renewal scenario
generateCerts(t, paths)
if i%2 == 1 {
// Re-generate certs to simulate renewal scenario
generateCerts(t, paths)
}
time.Sleep(poolingInterval)
}
time.Sleep(time.Second * 1)
}
})

t.Run("Cert Watcher should recover on failure", func(t *testing.T) {
// Delete root cert and re-create it to ensure that cert watcher is able to recover

// Re-generate certs to invalidate existing cert
generateCerts(t, paths)
// Delete root cert so that Cert Watcher will fail
err = os.RemoveAll(paths.RootCACert)
assertNil(t, err)

// Reset TLS config to ensure that previous root cert is not re-used
tr, err = client.Transport()
assertNil(t, err)
tr.TLSClientConfig = nil
client.SetTransport(tr)

time.Sleep(poolingInterval)

_, err = client.R().Get(url)
// We expect an error since root cert has been deleted
assertNotNil(t, err)

// Re-generate certs. We except cert watcher to reload the new root cert.
generateCerts(t, paths)
time.Sleep(poolingInterval)
_, err = client.R().Get(url)
assertNil(t, err)
})

err = client.Close()
assertNil(t, err)
Expand Down
41 changes: 41 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"errors"
"fmt"
"io"
"log"
"math"
"net"
"net/http"
Expand Down Expand Up @@ -353,6 +354,30 @@ func TestClientSetClientRootCertificateNotExists(t *testing.T) {
assertNil(t, transport.TLSClientConfig)
}

func TestClientSetClientRootCertificateWatcher(t *testing.T) {
t.Run("Cert exists", func(t *testing.T) {
client := dcnl()
client.SetClientRootCertificateWatcher(filepath.Join(getTestDataPath(), "sample-root.pem"), &CertWatcherOptions{
PoolInterval: time.Second * 1,
})

transport, err := client.Transport()

assertNil(t, err)
assertNotNil(t, transport.TLSClientConfig.ClientCAs)
})

t.Run("Cert does not exist", func(t *testing.T) {
client := dcnl()
client.SetClientRootCertificateWatcher(filepath.Join(getTestDataPath(), "not-exists-sample-root.pem"), nil)

transport, err := client.Transport()

assertNil(t, err)
assertNil(t, transport.TLSClientConfig)
})
}

func TestClientSetClientRootCertificateFromString(t *testing.T) {
client := dcnl()
rootPemData, err := os.ReadFile(filepath.Join(getTestDataPath(), "sample-root.pem"))
Expand Down Expand Up @@ -1334,3 +1359,19 @@ func TestResponseBodyLimit(t *testing.T) {
assertErrorIs(t, gzip.ErrHeader, err)
})
}

func TestClientDebugf(t *testing.T) {
t.Run("Debug mode enabled", func(t *testing.T) {
var b bytes.Buffer
c := New().SetLogger(&logger{l: log.New(&b, "", 0)}).SetDebug(true)
c.debugf("hello")
assertEqual(t, "DEBUG RESTY hello\n", b.String())
})

t.Run("Debug mode disabled", func(t *testing.T) {
var b bytes.Buffer
c := New().SetLogger(&logger{l: log.New(&b, "", 0)})
c.debugf("hello")
assertEqual(t, "", b.String())
})
}

0 comments on commit c6a6d2c

Please sign in to comment.