diff --git a/pkg/networkservice/chains/nsmgr/heal_test.go b/pkg/networkservice/chains/nsmgr/heal_test.go index 509947151..459cd1134 100644 --- a/pkg/networkservice/chains/nsmgr/heal_test.go +++ b/pkg/networkservice/chains/nsmgr/heal_test.go @@ -1,6 +1,6 @@ // Copyright (c) 2020-2022 Doc.ai and/or its affiliates. // -// Copyright (c) 2023 Cisco and/or its affiliates. +// Copyright (c) 2023-2024 Cisco and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -872,12 +872,17 @@ func TestNSMGR_RefreshFailed_ControlPlaneBroken(t *testing.T) { // refresh interval in this test is expected to be 3 minutes and a few milliseconds clk.Add(time.Second * 190) - // kill the forwarder during the healing Request (it is stopped by syncCh). Then continue - the healing process will fail. - for _, forwarder := range domain.Nodes[0].Forwarders { + // kill the forwarder during the refresh (it is stopped by syncCh). Then continue - the refresh will fail. + for idx := range domain.Nodes[0].Forwarders { + forwarder := domain.Nodes[0].Forwarders[idx] forwarder.Cancel() - break + // wait until the forwarder dies + require.Eventually(t, func() bool { + return sandbox.CheckURLFree(forwarder.URL) + }, timeout, tick) } syncCh <- struct{}{} + close(syncCh) // create a new forwarder and allow the healing Request forwarderReg := ®istry.NetworkServiceEndpoint{ @@ -885,7 +890,6 @@ func TestNSMGR_RefreshFailed_ControlPlaneBroken(t *testing.T) { NetworkServiceNames: []string{"forwarder"}, } domain.Nodes[0].NewForwarder(ctx, forwarderReg, sandbox.GenerateTestToken) - syncCh <- struct{}{} // wait till Request reached NSE require.Eventually(t, func() bool { diff --git a/pkg/networkservice/common/refresh/client.go b/pkg/networkservice/common/refresh/client.go index 191cc5fe9..79ffd56a5 100644 --- a/pkg/networkservice/common/refresh/client.go +++ b/pkg/networkservice/common/refresh/client.go @@ -1,6 +1,6 @@ -// Copyright (c) 2020 Cisco Systems, Inc. +// Copyright (c) 2020-2024 Cisco Systems, Inc. // -// Copyright (c) 2020-2022 Doc.ai and/or its affiliates. +// Copyright (c) 2020-2024 Doc.ai and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -68,17 +68,15 @@ func (t *refreshClient) Request(ctx context.Context, request *networkservice.Net store(ctx, metadata.IsClient(t), cancel) eventFactory := begin.FromContext(ctx) - clockTime := clock.FromContext(ctx) // Create the afterCh *outside* the go routine. This must be done to avoid picking up a later 'now' // from mockClock in testing - afterTicker := clockTime.Ticker(refreshAfter) + afterCh := clock.FromContext(ctx).After(refreshAfter) go func() { - defer afterTicker.Stop() - for { - select { - case <-cancelCtx.Done(): - return - case <-afterTicker.C(): + select { + case <-cancelCtx.Done(): + return + case <-afterCh: + for cancelCtx.Err() == nil { if err := <-eventFactory.Request(begin.CancelContext(cancelCtx)); err != nil { logger.Warnf("refresh failed: %s", err.Error()) continue diff --git a/pkg/networkservice/common/refresh/client_test.go b/pkg/networkservice/common/refresh/client_test.go index 0a571ac12..78aa29a9f 100644 --- a/pkg/networkservice/common/refresh/client_test.go +++ b/pkg/networkservice/common/refresh/client_test.go @@ -1,5 +1,7 @@ // Copyright (c) 2020-2021 Doc.ai and/or its affiliates. // +// Copyright (c) 2024 Cisco and/or its affiliates. +// // SPDX-License-Identifier: Apache-2.0 // // Licensed under the Apache License, Version 2.0 (the "License"); @@ -69,19 +71,19 @@ func testTokenFuncWithTimeout(clockTime clock.Clock, timeout time.Duration) toke } } -type captureTickerDuration struct { +type captureAfterDuration struct { *clockmock.Mock - tickerDuration time.Duration + afterDuration time.Duration } -func (m *captureTickerDuration) Ticker(d time.Duration) clock.Ticker { - m.tickerDuration = d - return m.Mock.Ticker(d) +func (m *captureAfterDuration) After(d time.Duration) <-chan time.Time { + m.afterDuration = d + return m.Mock.After(d) } -func (m *captureTickerDuration) Reset(t time.Time) { - m.tickerDuration = 0 +func (m *captureAfterDuration) Reset(t time.Time) { + m.afterDuration = 0 m.Set(t) } @@ -355,7 +357,7 @@ func TestRefreshClient_CalculatesShortestTokenTimeout(t *testing.T) { timeNow := time.Date(2009, 11, 10, 23, 0, 0, 0, time.Local) - clockMock := captureTickerDuration{ + clockMock := captureAfterDuration{ Mock: clockmock.New(ctx), } @@ -389,8 +391,8 @@ func TestRefreshClient_CalculatesShortestTokenTimeout(t *testing.T) { }) require.NoError(t, err) - require.Less(t, clockMock.tickerDuration, testDataElement.ExpectedRefreshTimeout+timeoutDelta) - require.Greater(t, clockMock.tickerDuration, testDataElement.ExpectedRefreshTimeout-timeoutDelta) + require.Less(t, clockMock.afterDuration, testDataElement.ExpectedRefreshTimeout+timeoutDelta) + require.Greater(t, clockMock.afterDuration, testDataElement.ExpectedRefreshTimeout-timeoutDelta) } require.Equal(t, countClient.Requests(), len(testData))