diff --git a/server.go b/server.go index b198f30..8b74823 100644 --- a/server.go +++ b/server.go @@ -164,17 +164,28 @@ func (s *Server) Stop() error { const op = "gldap.(Server).Stop" s.mu.RLock() defer s.mu.RUnlock() - if s.listener == nil { - return fmt.Errorf("%s: no listener: %w", op, ErrInvalidState) + + s.logger.Debug("shutting down") + if s.listener == nil && s.shutdownCancel == nil { + s.logger.Debug("nothing to do for shutdown") + return nil } - if s.shutdownCancel == nil { - return fmt.Errorf("%s: no shutdown context cancel func: %w", op, ErrInvalidState) + + if s.listener != nil { + s.logger.Debug("closing listener") + if err := s.listener.Close(); err != nil { + switch { + case !strings.Contains(err.Error(), "use of closed network connection"): + return fmt.Errorf("%s: %w", op, err) + default: + s.logger.Debug("listener already closed") + } + } } - if err := s.listener.Close(); err != nil { - return fmt.Errorf("%s: %w", op, err) + if s.shutdownCancel != nil { + s.logger.Debug("shutdown cancel func") + s.shutdownCancel() } - s.logger.Debug("shutting down") - s.shutdownCancel() s.logger.Debug("waiting on connections to close") s.connWg.Wait() s.logger.Debug("stopped") diff --git a/server_internal_test.go b/server_internal_test.go index d0b530b..fb276e2 100644 --- a/server_internal_test.go +++ b/server_internal_test.go @@ -2,16 +2,27 @@ package gldap import ( "context" + "errors" "fmt" "net" + "os" + "strconv" "testing" + "github.com/hashicorp/go-hclog" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestServer_Stop(t *testing.T) { t.Parallel() + var testLogger hclog.Logger + if ok, _ := strconv.ParseBool(os.Getenv("DEBUG")); ok { + testLogger = hclog.New(&hclog.LoggerOptions{ + Name: "TestServer_Run-logger", + Level: hclog.Debug, + }) + } tests := []struct { name string server *Server @@ -21,15 +32,13 @@ func TestServer_Stop(t *testing.T) { { name: "missing-listener", server: func() *Server { - s, err := NewServer() + s, err := NewServer(WithLogger(testLogger)) require.NoError(t, err) s.mu.Lock() defer s.mu.Unlock() s.listener = nil return s }(), - wantErr: true, - wantErrContains: "no listener", }, { name: "missing-cancel", @@ -37,7 +46,7 @@ func TestServer_Stop(t *testing.T) { p := freePort(t) l, err := net.Listen("tcp", fmt.Sprintf(":%d", p)) require.NoError(t, err) - s, err := NewServer() + s, err := NewServer(WithLogger(testLogger)) require.NoError(t, err) s.mu.Lock() defer s.mu.Unlock() @@ -45,8 +54,18 @@ func TestServer_Stop(t *testing.T) { s.shutdownCancel = nil return s }(), - wantErr: true, - wantErrContains: "no shutdown context cancel func", + }, + { + name: "nothing-to-do", + server: func() *Server { + s, err := NewServer(WithLogger(testLogger)) + require.NoError(t, err) + s.mu.Lock() + defer s.mu.Unlock() + s.listener = nil + s.shutdownCancel = nil + return s + }(), }, { name: "listener-closed", @@ -55,7 +74,7 @@ func TestServer_Stop(t *testing.T) { p := freePort(t) l, err := net.Listen("tcp", fmt.Sprintf(":%d", p)) require.NoError(t, err) - s, err := NewServer() + s, err := NewServer(WithLogger(testLogger)) require.NoError(t, err) s.mu.Lock() defer s.mu.Unlock() @@ -64,8 +83,21 @@ func TestServer_Stop(t *testing.T) { l.Close() return s }(), + }, + { + name: "listener-close-err", + server: func() *Server { + _, cancel := context.WithCancel(context.Background()) + s, err := NewServer(WithLogger(testLogger)) + require.NoError(t, err) + s.mu.Lock() + defer s.mu.Unlock() + s.listener = &mockListener{} + s.shutdownCancel = cancel + return s + }(), wantErr: true, - wantErrContains: "use of closed network connection", + wantErrContains: "mockListener.Close error", }, } for _, tc := range tests { @@ -83,3 +115,11 @@ func TestServer_Stop(t *testing.T) { }) } } + +type mockListener struct { + net.Listener +} + +func (*mockListener) Close() error { + return errors.New("mockListener.Close error") +}