diff --git a/activator/activator.go b/activator/activator.go
index d76685e..b49129f 100644
--- a/activator/activator.go
+++ b/activator/activator.go
@@ -31,7 +31,6 @@ type Server struct {
 	proxyTimeout time.Duration
 	proxyCancel  context.CancelFunc
 	ns           ns.NetNS
-	firstAccept  sync.Once
 	maps         bpfMaps
 	sandboxPid   int
 	started      bool
@@ -95,7 +94,6 @@ func (s *Server) Started() bool {
 }
 
 func (s *Server) Reset() error {
-	s.firstAccept = sync.Once{}
 	for _, port := range s.ports {
 		if err := s.enableRedirect(port); err != nil {
 			return err
@@ -134,7 +132,6 @@ func (s *Server) listen(ctx context.Context, port uint16, onAccept OnAccept) (in
 
 	log.G(ctx).Debugf("listening on %s in ns %s", listener.Addr(), s.ns.Path())
 
-	s.firstAccept = sync.Once{}
 	s.onAccept = onAccept
 
 	s.wg.Add(1)
@@ -213,12 +210,10 @@ func (s *Server) handleConection(ctx context.Context, conn net.Conn, port uint16
 		return
 	}
 
-	s.firstAccept.Do(func() {
-		if err := s.onAccept(); err != nil {
-			log.G(ctx).Errorf("accept function: %s", err)
-			return
-		}
-	})
+	if err := s.onAccept(); err != nil {
+		log.G(ctx).Errorf("accept function: %s", err)
+		return
+	}
 
 	backendConn, err := s.connect(ctx, port)
 	if err != nil {
diff --git a/activator/activator_test.go b/activator/activator_test.go
index ddbe0b8..32d3850 100644
--- a/activator/activator_test.go
+++ b/activator/activator_test.go
@@ -40,26 +40,28 @@ func TestActivator(t *testing.T) {
 		fmt.Fprint(w, response)
 	}))
 
+	once := sync.Once{}
 	err = s.Start(ctx, []uint16{uint16(port)}, func() error {
-		// simulate a delay until our server is started
-		time.Sleep(time.Millisecond * 200)
-		l, err := net.Listen("tcp4", fmt.Sprintf(":%d", port))
-		require.NoError(t, err)
-
-		if err := s.DisableRedirects(); err != nil {
-			return fmt.Errorf("could not disable redirects: %w", err)
-		}
-
-		// replace listener of server
-		ts.Listener.Close()
-		ts.Listener = l
-		ts.Start()
-		t.Logf("listening on :%d", port)
-
-		t.Cleanup(func() {
-			ts.Close()
+		once.Do(func() {
+			// simulate a delay until our server is started
+			time.Sleep(time.Millisecond * 200)
+			l, err := net.Listen("tcp4", fmt.Sprintf(":%d", port))
+			require.NoError(t, err)
+
+			if err := s.DisableRedirects(); err != nil {
+				t.Errorf("could not disable redirects: %s", err)
+			}
+
+			// replace listener of server
+			ts.Listener.Close()
+			ts.Listener = l
+			ts.Start()
+			t.Logf("listening on :%d", port)
+
+			t.Cleanup(func() {
+				ts.Close()
+			})
 		})
-
 		return nil
 	})
 	require.NoError(t, err)
diff --git a/runc/task/service_zeropod.go b/runc/task/service_zeropod.go
index 6544d6b..ebe7ce7 100644
--- a/runc/task/service_zeropod.go
+++ b/runc/task/service_zeropod.go
@@ -197,7 +197,6 @@ func (w *wrapper) Exec(ctx context.Context, r *taskAPI.ExecProcessRequest) (*emp
 			os.Exit(1)
 		}
 
-		zeropodContainer.SetScaledDown(false)
 		log.G(ctx).Printf("restored process for exec: %d in %s", p.Pid(), time.Since(beforeRestore))
 	}
 
diff --git a/zeropod/container.go b/zeropod/container.go
index 93ef148..8de2655 100644
--- a/zeropod/container.go
+++ b/zeropod/container.go
@@ -290,6 +290,10 @@ func (c *Container) restoreHandler(ctx context.Context) activator.OnAccept {
 		beforeRestore := time.Now()
 		restoredContainer, p, err := c.Restore(ctx)
 		if err != nil {
+			if errors.Is(err, ErrAlreadyRestored) {
+				log.G(ctx).Info("container is already restored, ignoring request")
+				return nil
+			}
 			// restore failed, this is currently unrecoverable, so we shutdown
 			// our shim and let containerd recreate it.
 			log.G(ctx).Fatalf("error restoring container, exiting shim: %s", err)
@@ -301,7 +305,6 @@ func (c *Container) restoreHandler(ctx context.Context) activator.OnAccept {
 			return fmt.Errorf("unable to track pid %d: %w", p.Pid(), err)
 		}
 
-		c.SetScaledDown(false)
 		log.G(ctx).Printf("restored process: %d in %s", p.Pid(), time.Since(beforeRestore))
 
 		return c.ScheduleScaleDown()
diff --git a/zeropod/restore.go b/zeropod/restore.go
index 36bac07..c26d66b 100644
--- a/zeropod/restore.go
+++ b/zeropod/restore.go
@@ -2,6 +2,7 @@ package zeropod
 
 import (
 	"context"
+	"errors"
 	"fmt"
 	"io"
 	"os"
@@ -19,9 +20,14 @@ import (
 	"github.com/containerd/log"
 )
 
+var ErrAlreadyRestored = errors.New("container is already restored")
+
 func (c *Container) Restore(ctx context.Context) (*runc.Container, process.Process, error) {
 	c.checkpointRestore.Lock()
 	defer c.checkpointRestore.Unlock()
+	if !c.ScaledDown() {
+		return nil, nil, ErrAlreadyRestored
+	}
 
 	beforeRestore := time.Now()
 	go func() {
@@ -80,6 +86,7 @@ func (c *Container) Restore(ctx context.Context) (*runc.Container, process.Proce
 
 	c.Container = container
 	c.process = p
+	c.SetScaledDown(false)
 
 	if c.postRestore != nil {
 		c.postRestore(container, handleStarted)