diff --git a/pkg/server/server.go b/pkg/server/server.go index 5911affe9..c006543aa 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -1650,9 +1650,10 @@ func (s *BgpServer) handleFSMMessage(peer *peer, e *fsmMsg) { if allEnd { for _, p := range s.neighborMap { p.fsm.lock.Lock() + peerLocalRestarting := p.fsm.pConf.GracefulRestart.State.LocalRestarting p.fsm.pConf.GracefulRestart.State.LocalRestarting = false p.fsm.lock.Unlock() - if !p.isGracefulRestartEnabled() { + if !p.isGracefulRestartEnabled() && !peerLocalRestarting { continue } paths, _ := s.getBestFromLocal(p, p.configuredRFlist()) @@ -1791,9 +1792,10 @@ func (s *BgpServer) handleFSMMessage(peer *peer, e *fsmMsg) { if allEnd { for _, p := range s.neighborMap { p.fsm.lock.Lock() + peerLocalRestarting := p.fsm.pConf.GracefulRestart.State.LocalRestarting p.fsm.pConf.GracefulRestart.State.LocalRestarting = false p.fsm.lock.Unlock() - if !p.isGracefulRestartEnabled() { + if !p.isGracefulRestartEnabled() && !peerLocalRestarting { continue } paths, _ := s.getBestFromLocal(p, p.negotiatedRFList()) diff --git a/test/scenario_test/graceful_restart_test.py b/test/scenario_test/graceful_restart_test.py index 798cde4f4..f7617865e 100644 --- a/test/scenario_test/graceful_restart_test.py +++ b/test/scenario_test/graceful_restart_test.py @@ -127,11 +127,25 @@ def test_04_add_non_graceful_restart_enabled_peer(self): self.bgpds['g3'] = g3 time.sleep(g3.run()) g3.add_route('10.10.30.0/24') - g1.add_peer(g3) + g1.add_peer(g3, graceful_restart=True) g3.add_peer(g1) g1.wait_for(expected_state=BGP_FSM_ESTABLISHED, peer=g3) time.sleep(1) self.assertEqual(len(g3.get_global_rib('10.10.20.0/24')), 1) + self.assertEqual(len(g1.get_global_rib('10.10.30.0/24')), 1) + + # Restart g1 with graceful restart flag and check that g1 routes are still propagated to g3. + g1.stop_gobgp() + g3.stop_gobgp() + g1.routes = {} + g3.routes = {} + g1.start_gobgp(graceful_restart=True) + g1.add_route('10.10.20.0/24') + g3.start_gobgp() + g1.wait_for(expected_state=BGP_FSM_ESTABLISHED, peer=g3) + time.sleep(1) + self.assertEqual(len(g3.get_global_rib('10.10.20.0/24')), 1) + g3.add_route('10.10.30.0/24') def test_05_holdtime_expiry_graceful_restart(self): g1 = self.bgpds['g1']