diff --git a/examples/streaming/multi-sinks/main.go b/examples/streaming/multi-sinks/main.go index 9be3735..3a8ae08 100644 --- a/examples/streaming/multi-sinks/main.go +++ b/examples/streaming/multi-sinks/main.go @@ -54,7 +54,7 @@ func main() { } // Don't forget to close the sink when done - defer sink1.Close() + defer sink1.Close(ctx) // Read and acknowlege event ev := <-sink1.Subscribe() @@ -70,7 +70,7 @@ func main() { if err != nil { panic(err) } - defer sink2.Close() + defer sink2.Close(ctx) // Read second event ev = <-sink2.Subscribe() diff --git a/examples/streaming/multi-streams/main.go b/examples/streaming/multi-streams/main.go index b1b04c5..482a22e 100644 --- a/examples/streaming/multi-streams/main.go +++ b/examples/streaming/multi-streams/main.go @@ -40,7 +40,7 @@ func main() { } // Don't forget to close the sink when done - defer sink.Close() + defer sink.Close(ctx) // Subscribe to events c := sink.Subscribe() diff --git a/examples/streaming/pub-sub/main.go b/examples/streaming/pub-sub/main.go index e001f8f..b595128 100644 --- a/examples/streaming/pub-sub/main.go +++ b/examples/streaming/pub-sub/main.go @@ -56,7 +56,7 @@ func main() { } // Don't forget to close the sink when done - defer sink.Close() + defer sink.Close(ctx) // Read both events c := sink.Subscribe() diff --git a/examples/streaming/single-sink/main.go b/examples/streaming/single-sink/main.go index 861999f..b11a470 100644 --- a/examples/streaming/single-sink/main.go +++ b/examples/streaming/single-sink/main.go @@ -45,7 +45,7 @@ func main() { } // Don't forget to close the sink when done - defer sink.Close() + defer sink.Close(ctx) // Consume event ev := <-sink.Subscribe() diff --git a/pool/node.go b/pool/node.go index 200f1f5..366a0ee 100644 --- a/pool/node.go +++ b/pool/node.go @@ -27,27 +27,29 @@ import ( type ( // Node is a pool of workers. 
Node struct { - NodeID string - PoolName string - poolStream *streaming.Stream // pool event stream for dispatching jobs - poolSink *streaming.Sink // pool event sink - nodeStream *streaming.Stream // node event stream for receiving worker events - nodeReader *streaming.Reader // node event reader - workerMap *rmap.Map // worker creation times by ID - jobsMap *rmap.Map // jobs by worker ID - jobPayloadsMap *rmap.Map // job payloads by job key - keepAliveMap *rmap.Map // worker keep-alive timestamps indexed by ID - shutdownMap *rmap.Map // key is node ID that requested shutdown - tickerMap *rmap.Map // ticker next tick time indexed by name - workerTTL time.Duration // Worker considered dead if keep-alive not updated after this duration - workerShutdownTTL time.Duration // Worker considered dead if not shutdown after this duration - ackGracePeriod time.Duration // Wait for return status up to this duration - clientOnly bool - logger pulse.Logger - h hasher - stop chan struct{} // closed when node is stopped - wg sync.WaitGroup // allows to wait until all goroutines exit - rdb *redis.Client + ID string + PoolName string + poolStream *streaming.Stream // pool event stream for dispatching jobs + poolSink *streaming.Sink // pool event sink + nodeStream *streaming.Stream // node event stream for receiving worker events + nodeReader *streaming.Reader // node event reader + workerMap *rmap.Map // worker creation times by ID + jobsMap *rmap.Map // jobs by worker ID + jobPayloadsMap *rmap.Map // job payloads by job key + nodeKeepAliveMap *rmap.Map // node keep-alive timestamps indexed by ID + workerKeepAliveMap *rmap.Map // worker keep-alive timestamps indexed by ID + shutdownMap *rmap.Map // key is node ID that requested shutdown + tickerMap *rmap.Map // ticker next tick time indexed by name + workerTTL time.Duration // Worker considered dead if keep-alive not updated after this duration + workerShutdownTTL time.Duration // Worker considered dead if not shutdown after this duration + ackGracePeriod time.Duration // Wait for return status up to this duration + clientOnly bool + logger pulse.Logger + h hasher + stop chan struct{} // closed when node is stopped + closed chan struct{} // closed when node is closed + wg sync.WaitGroup // allows to wait until all goroutines exit + rdb *redis.Client localWorkers sync.Map // workers created by this node workerStreams sync.Map // worker streams indexed by ID @@ -110,6 +112,7 @@ func AddNode(ctx context.Context, poolName string, rdb *redis.Client, opts ...No "worker_ttl", o.workerTTL, "worker_shutdown_ttl", o.workerShutdownTTL, "ack_grace_period", o.ackGracePeriod) + wsm, err := rmap.Join(ctx, shutdownMapName(poolName), rdb, rmap.WithLogger(logger)) if err != nil { return nil, fmt.Errorf("AddNode: failed to join shutdown replicated map %q: %w", shutdownMapName(poolName), err) @@ -117,22 +120,35 @@ func AddNode(ctx context.Context, poolName string, rdb *redis.Client, opts ...No if wsm.Len() > 0 { return nil, fmt.Errorf("AddNode: pool %q is shutting down", poolName) } + + nkm, err := rmap.Join(ctx, nodeKeepAliveMapName(poolName), rdb, rmap.WithLogger(logger)) + if err != nil { + return nil, fmt.Errorf("AddNode: failed to join node keep-alive map %q: %w", nodeKeepAliveMapName(poolName), err) + } + if _, err := nkm.Set(ctx, nodeID, strconv.FormatInt(time.Now().UnixNano(), 10)); err != nil { + return nil, fmt.Errorf("AddNode: failed to set initial node keep-alive: %w", err) + } + poolStream, err := streaming.NewStream(poolStreamName(poolName), rdb, 
soptions.WithStreamMaxLen(o.maxQueuedJobs), soptions.WithStreamLogger(logger)) if err != nil { return nil, fmt.Errorf("AddNode: failed to create pool job stream %q: %w", poolStreamName(poolName), err) } + var ( - wm *rmap.Map - jm *rmap.Map - jpm *rmap.Map - km *rmap.Map - tm *rmap.Map + wm *rmap.Map + jm *rmap.Map + jpm *rmap.Map + km *rmap.Map + tm *rmap.Map + poolSink *streaming.Sink nodeStream *streaming.Stream nodeReader *streaming.Reader + closed chan struct{} ) + if !o.clientOnly { wm, err = rmap.Join(ctx, workerMapName(poolName), rdb, rmap.WithLogger(logger)) if err != nil { @@ -140,80 +156,94 @@ func AddNode(ctx context.Context, poolName string, rdb *redis.Client, opts ...No } workerIDs := wm.Keys() logger.Info("joined", "workers", workerIDs) + jm, err = rmap.Join(ctx, jobsMapName(poolName), rdb, rmap.WithLogger(logger)) if err != nil { return nil, fmt.Errorf("AddNode: failed to join pool jobs replicated map %q: %w", jobsMapName(poolName), err) } + jpm, err = rmap.Join(ctx, jobPayloadsMapName(poolName), rdb, rmap.WithLogger(logger)) if err != nil { return nil, fmt.Errorf("AddNode: failed to join pool job payloads replicated map %q: %w", jobPayloadsMapName(poolName), err) } - km, err = rmap.Join(ctx, keepAliveMapName(poolName), rdb, rmap.WithLogger(logger)) + + km, err = rmap.Join(ctx, workerKeepAliveMapName(poolName), rdb, rmap.WithLogger(logger)) if err != nil { - return nil, fmt.Errorf("AddNode: failed to join pool keep-alive replicated map %q: %w", keepAliveMapName(poolName), err) + return nil, fmt.Errorf("AddNode: failed to join worker keep-alive replicated map %q: %w", workerKeepAliveMapName(poolName), err) } + tm, err = rmap.Join(ctx, tickerMapName(poolName), rdb, rmap.WithLogger(logger)) if err != nil { return nil, fmt.Errorf("AddNode: failed to join pool ticker replicated map %q: %w", tickerMapName(poolName), err) } + poolSink, err = poolStream.NewSink(ctx, "events", soptions.WithSinkBlockDuration(o.jobSinkBlockDuration), soptions.WithSinkAckGracePeriod(o.ackGracePeriod)) if err != nil { return nil, fmt.Errorf("AddNode: failed to create events sink for stream %q: %w", poolStreamName(poolName), err) } + closed = make(chan struct{}) } + nodeStream, err = streaming.NewStream(nodeStreamName(poolName, nodeID), rdb, soptions.WithStreamLogger(logger)) if err != nil { return nil, fmt.Errorf("AddNode: failed to create node event stream %q: %w", nodeStreamName(poolName, nodeID), err) } + nodeReader, err = nodeStream.NewReader(ctx, soptions.WithReaderBlockDuration(o.jobSinkBlockDuration), soptions.WithReaderStartAtOldest()) if err != nil { return nil, fmt.Errorf("AddNode: failed to create node event reader for stream %q: %w", nodeStreamName(poolName, nodeID), err) } p := &Node{ - NodeID: nodeID, - PoolName: poolName, - keepAliveMap: km, - workerMap: wm, - jobsMap: jm, - jobPayloadsMap: jpm, - shutdownMap: wsm, - tickerMap: tm, - workerStreams: sync.Map{}, - pendingJobs: sync.Map{}, - pendingEvents: sync.Map{}, - poolStream: poolStream, - poolSink: poolSink, - nodeStream: nodeStream, - nodeReader: nodeReader, - clientOnly: o.clientOnly, - workerTTL: o.workerTTL, - workerShutdownTTL: o.workerShutdownTTL, - ackGracePeriod: o.ackGracePeriod, - h: jumpHash{crc64.New(crc64.MakeTable(crc64.ECMA))}, - stop: make(chan struct{}), - rdb: rdb, - logger: logger, + ID: nodeID, + PoolName: poolName, + nodeKeepAliveMap: nkm, + workerKeepAliveMap: km, + workerMap: wm, + jobsMap: jm, + jobPayloadsMap: jpm, + shutdownMap: wsm, + tickerMap: tm, + workerStreams: sync.Map{}, + pendingJobs: sync.Map{}, + 
pendingEvents: sync.Map{}, + poolStream: poolStream, + poolSink: poolSink, + nodeStream: nodeStream, + nodeReader: nodeReader, + clientOnly: o.clientOnly, + workerTTL: o.workerTTL, + workerShutdownTTL: o.workerShutdownTTL, + ackGracePeriod: o.ackGracePeriod, + h: jumpHash{crc64.New(crc64.MakeTable(crc64.ECMA))}, + stop: make(chan struct{}), + closed: closed, + rdb: rdb, + logger: logger, } nch := nodeReader.Subscribe() if o.clientOnly { logger.Info("client-only") - p.wg.Add(1) - go p.handleNodeEvents(nch) // to handle job acks + p.wg.Add(3) + pulse.Go(ctx, func() { p.handleNodeEvents(ctx, nch) }) // to handle job acks + pulse.Go(ctx, func() { p.processInactiveNodes(ctx) }) + pulse.Go(ctx, func() { p.updateNodeKeepAlive(ctx) }) return p, nil } - p.wg.Add(5) - pch := poolSink.Subscribe() - pulse.Go(ctx, func() { p.handlePoolEvents(pch) }) // handleXXX handles streaming events - pulse.Go(ctx, func() { p.handleNodeEvents(nch) }) - pulse.Go(ctx, func() { p.manageWorkers(ctx) }) // manageXXX handles map updates - pulse.Go(ctx, func() { p.manageShutdown(ctx) }) - pulse.Go(ctx, func() { p.manageInactiveWorkers(ctx) }) + p.wg.Add(7) + pulse.Go(ctx, func() { p.handlePoolEvents(ctx, poolSink.Subscribe()) }) + pulse.Go(ctx, func() { p.handleNodeEvents(ctx, nch) }) + pulse.Go(ctx, func() { p.watchWorkers(ctx) }) + pulse.Go(ctx, func() { p.watchShutdown(ctx) }) + pulse.Go(ctx, func() { p.processInactiveNodes(ctx) }) + pulse.Go(ctx, func() { p.processInactiveWorkers(ctx) }) + pulse.Go(ctx, func() { p.updateNodeKeepAlive(ctx) }) + return p, nil } @@ -239,12 +269,11 @@ func (node *Node) AddWorker(ctx context.Context, handler JobHandler) (*Worker, e // RemoveWorker stops the worker, removes it from the pool and requeues all its // jobs. func (node *Node) RemoveWorker(ctx context.Context, w *Worker) error { - w.stopAndWait(ctx) + w.stop(ctx) if err := w.requeueJobs(ctx); err != nil { node.logger.Error(fmt.Errorf("RemoveWorker: failed to requeue jobs for worker %q: %w", w.ID, err)) } - w.cleanup(ctx) - node.workerStreams.Delete(w.ID) + node.cleanupWorker(ctx, w.ID) node.localWorkers.Delete(w.ID) node.logger.Info("removed worker", "worker", w.ID) return nil @@ -258,14 +287,13 @@ func (node *Node) Workers() []*Worker { workers = append(workers, &Worker{ ID: w.ID, CreatedAt: w.CreatedAt, - Node: node, }) return true }) return workers } -// PoolWorkers returns the list of workers running in the pool. +// PoolWorkers returns the list of workers running in the entire pool. func (node *Node) PoolWorkers() []*Worker { workers := node.workerMap.Map() poolWorkers := make([]*Worker, 0, len(workers)) @@ -275,16 +303,16 @@ func (node *Node) PoolWorkers() []*Worker { node.logger.Error(fmt.Errorf("PoolWorkers: failed to parse createdAt %q for worker %q: %w", createdAt, id, err)) continue } - poolWorkers = append(poolWorkers, &Worker{ID: id, CreatedAt: time.Unix(0, cat), Node: node}) + poolWorkers = append(poolWorkers, &Worker{ID: id, CreatedAt: time.Unix(0, cat)}) } return poolWorkers } -// DispatchJob dispatches a job to the proper worker in the pool. +// DispatchJob dispatches a job to the worker in the pool that is assigned to +// the job key using consistent hashing. 
// It returns: // - nil if the job is successfully dispatched and started by a worker // - an error returned by the worker's start handler if the job fails to start -// - the context error if the context is canceled before the job is started // - an error if the pool is closed or if there's a failure in adding the job // // The method blocks until one of the above conditions is met. @@ -293,7 +321,7 @@ func (node *Node) DispatchJob(ctx context.Context, key string, payload []byte) e return fmt.Errorf("DispatchJob: pool %q is closed", node.PoolName) } - job := marshalJob(&Job{Key: key, Payload: payload, CreatedAt: time.Now(), NodeID: node.NodeID}) + job := marshalJob(&Job{Key: key, Payload: payload, CreatedAt: time.Now(), NodeID: node.ID}) eventID, err := node.poolStream.Add(ctx, evStartJob, job) if err != nil { return fmt.Errorf("DispatchJob: failed to add job to stream %q: %w", node.poolStream.Name, err) @@ -316,12 +344,13 @@ func (node *Node) DispatchJob(ctx context.Context, key string, payload []byte) e node.pendingJobs.Delete(eventID) close(cherr) - if err == nil { - node.logger.Info("dispatched", "key", key) - } else { + if err != nil { node.logger.Error(fmt.Errorf("DispatchJob: failed to dispatch job: %w", err), "key", key) + return err } - return err + + node.logger.Info("dispatched", "key", key) + return nil } // StopJob stops the job with the given key. @@ -349,7 +378,7 @@ func (node *Node) JobKeys() []string { // JobPayload returns the payload of the job with the given key. // It returns: // - (payload, true) if the job exists and has a payload -// - (nil, true) if the job exists but has no payload (empty payload) +// - (nil, true) if the job exists but has an empty payload // - (nil, false) if the job does not exist func (node *Node) JobPayload(key string) ([]byte, bool) { payload, ok := node.jobPayloadsMap.Get(key) @@ -375,7 +404,10 @@ func (node *Node) NotifyWorker(ctx context.Context, key string, payload []byte) // Shutdown stops the pool workers gracefully across all nodes. It notifies all // workers and waits until they are completed. Shutdown prevents the pool nodes -// from creating new workers and the pool workers from accepting new jobs. +// from creating new workers and the pool workers from accepting new jobs. After +// Shutdown returns, the node object cannot be used anymore and should be +// discarded. One of Shutdown or Close should be called before the node is +// garbage collected unless it is client-only. func (node *Node) Shutdown(ctx context.Context) error { if node.IsClosed() { return nil @@ -383,45 +415,24 @@ func (node *Node) Shutdown(ctx context.Context) error { if node.clientOnly { return fmt.Errorf("Shutdown: client-only node cannot shutdown worker pool") } - node.logger.Info("shutting down") // Signal all nodes to shutdown. - if _, err := node.shutdownMap.SetAndWait(ctx, "shutdown", node.NodeID); err != nil { + if _, err := node.shutdownMap.SetAndWait(ctx, "shutdown", node.ID); err != nil { node.logger.Error(fmt.Errorf("Shutdown: failed to set shutdown status in shutdown map: %w", err)) } + <-node.closed // Wait for this node to be closed + node.cleanupPool(ctx) - <-node.stop // Wait for this node to be closed - - // Destroy the pool stream. - if err := node.poolStream.Destroy(ctx); err != nil { - node.logger.Error(fmt.Errorf("Shutdown: failed to destroy pool stream: %w", err)) - } - - // Cleanup the jobs payloads map as ongoing requeues could prevent it - // from being cleaned up by the workers. 
- if err := node.jobPayloadsMap.Reset(ctx); err != nil { - node.logger.Error(fmt.Errorf("Shutdown: failed to reset job payloads map: %w", err)) - } - - // Now clean up the shutdown replicated map. - wsm, err := rmap.Join(ctx, shutdownMapName(node.PoolName), node.rdb, rmap.WithLogger(node.logger)) - if err != nil { - node.logger.Error(fmt.Errorf("Shutdown: failed to join shutdown map for cleanup: %w", err)) - } - if err := wsm.Reset(ctx); err != nil { - node.logger.Error(fmt.Errorf("Shutdown: failed to reset shutdown map: %w", err)) - } - - node.logger.Info("shutdown complete") + node.logger.Info("shutdown") return nil } -// Close stops the pool node workers and closes the Redis connection but does +// Close stops the node workers and closes the Redis connection but does // not stop workers running in other nodes. It requeues all the jobs run by -// workers of the node. One of Shutdown or Close should be called before the +// workers of the node. One of Shutdown or Close should be called before the // node is garbage collected unless it is client-only. func (node *Node) Close(ctx context.Context) error { - return node.close(ctx, true) + return node.close(ctx, false) } // IsShutdown returns true if the pool is shutdown. @@ -438,9 +449,12 @@ func (node *Node) IsClosed() bool { return node.closing } -// close is the internal implementation of Close. It handles the actual closing -// process and optionally requeues jobs. -func (node *Node) close(ctx context.Context, requeue bool) error { +// close stops the node and its workers, optionally requeuing jobs. If shutdown +// is true, jobs are not requeued as the pool is being shutdown. Otherwise, jobs +// are requeued to be picked up by other nodes. The method stops all workers, +// waits for background goroutines to complete, cleans up resources and closes +// connections. It is idempotent and can be called multiple times safely. +func (node *Node) close(ctx context.Context, shutdown bool) error { node.lock.Lock() if node.closing { node.lock.Unlock() @@ -449,101 +463,81 @@ func (node *Node) close(ctx context.Context, requeue bool) error { node.closing = true node.lock.Unlock() - node.logger.Info("closing") - - // If we're not requeuing then stop all the jobs. - if !requeue { - node.logger.Info("stopping all jobs") - var wg sync.WaitGroup - var total atomic.Int32 - node.localWorkers.Range(func(key, value any) bool { - wg.Add(1) - worker := value.(*Worker) - pulse.Go(ctx, func() { - defer wg.Done() - for _, job := range worker.Jobs() { - if err := worker.stopJob(ctx, job.Key, false); err != nil { - node.logger.Error(fmt.Errorf("Close: failed to stop job %q for worker %q: %w", job.Key, worker.ID, err)) - } - total.Add(1) - } - }) - return true - }) - wg.Wait() - node.logger.Info("stopped all jobs", "total", total.Load()) + // If we're shutting down then stop all the jobs. + if shutdown { + node.stopAllJobs(ctx) } - // Need to stop workers before requeueing jobs to prevent - // requeued jobs from being handled by this node. + // Stop all workers before waiting for goroutines var wg sync.WaitGroup node.localWorkers.Range(func(key, value any) bool { wg.Add(1) - worker := value.(*Worker) pulse.Go(ctx, func() { defer wg.Done() - worker.stopAndWait(ctx) + value.(*Worker).stop(ctx) + // Remove worker immediately to avoid job requeuing by other nodes + node.cleanupWorker(ctx, value.(*Worker).ID) }) return true }) wg.Wait() - node.logger.Debug("workers stopped") - - // Requeue jobs. 
- if requeue { - var wg sync.WaitGroup - node.localWorkers.Range(func(key, value any) bool { - wg.Add(1) - worker := value.(*Worker) - pulse.Go(ctx, func() { - defer wg.Done() - if err := worker.requeueJobs(ctx); err != nil { - node.logger.Error(fmt.Errorf("Close: failed to requeue jobs for worker %q: %w", worker.ID, err)) - return - } - }) - return true - }) - wg.Wait() - } - // Cleanup - node.localWorkers.Range(func(key, value any) bool { - worker := value.(*Worker) - worker.cleanup(ctx) - return true - }) - if !node.clientOnly { - node.poolSink.Close() - node.tickerMap.Close() - node.keepAliveMap.Close() - } - node.nodeReader.Close() - if err := node.nodeStream.Destroy(ctx); err != nil { - node.logger.Error(fmt.Errorf("Close: failed to destroy node event stream: %w", err)) - } + // Stop all goroutines close(node.stop) - - // Wait for all goroutines to exit. node.wg.Wait() + // Requeue jobs if not shutting down, after stopping goroutines to avoid receiving new jobs + if !shutdown { + if err := node.requeueAllJobs(ctx); err != nil { + node.logger.Error(fmt.Errorf("close: failed to requeue jobs: %w", err)) + } + } + + // Cleanup resources + node.cleanupNode(ctx) + + // Signal that the node is closed + close(node.closed) + node.logger.Info("closed") return nil } +// stopAllJobs stops all jobs running on the node. +func (node *Node) stopAllJobs(ctx context.Context) { + var wg sync.WaitGroup + var total atomic.Int32 + node.localWorkers.Range(func(key, value any) bool { + wg.Add(1) + worker := value.(*Worker) + pulse.Go(ctx, func() { + defer wg.Done() + for _, job := range worker.Jobs() { + if err := worker.stopJob(ctx, job.Key, false); err != nil { + node.logger.Error(fmt.Errorf("Close: failed to stop job %q for worker %q: %w", job.Key, worker.ID, err)) + } + total.Add(1) + } + }) + return true + }) + wg.Wait() + node.logger.Info("stopped all jobs", "total", total.Load()) +} + // handlePoolEvents reads events from the pool job stream. 
-func (node *Node) handlePoolEvents(c <-chan *streaming.Event) { +func (node *Node) handlePoolEvents(ctx context.Context, c <-chan *streaming.Event) { defer node.wg.Done() - defer node.logger.Debug("handlePoolEvents: exiting") - ctx := context.Background() - for ev := range c { - if node.IsClosed() { - node.logger.Info("ignoring event, node is closed", "event", ev.EventName, "id", ev.ID) - continue - } - node.logger.Debug("routing", "event", ev.EventName, "id", ev.ID) - if err := node.routeWorkerEvent(ctx, ev); err != nil { - node.logger.Error(fmt.Errorf("handlePoolEvents: failed to route event: %w, will retry after %v", err, node.ackGracePeriod), "event", ev.EventName, "id", ev.ID) + + for { + select { + case ev := <-c: + if err := node.routeWorkerEvent(ctx, ev); err != nil { + node.logger.Error(fmt.Errorf("handlePoolEvents: failed to route event: %w", err)) + } + case <-node.stop: + node.poolSink.Close(ctx) + return } } } @@ -565,7 +559,7 @@ func (node *Node) routeWorkerEvent(ctx context.Context, ev *streaming.Event) err } var eventID string - eventID, err = stream.Add(ctx, ev.EventName, marshalEnvelope(node.NodeID, ev.Payload)) + eventID, err = stream.Add(ctx, ev.EventName, marshalEnvelope(node.ID, ev.Payload)) if err != nil { return fmt.Errorf("routeWorkerEvent: failed to add event %s to worker stream %q: %w", ev.EventName, workerStreamName(wid), err) } @@ -579,32 +573,34 @@ func (node *Node) routeWorkerEvent(ctx context.Context, ev *streaming.Event) err // handleNodeEvents reads events from the node event stream and acks the pending // events that correspond to jobs that are now running or done. -func (node *Node) handleNodeEvents(c <-chan *streaming.Event) { +func (node *Node) handleNodeEvents(ctx context.Context, c <-chan *streaming.Event) { defer node.wg.Done() - defer node.logger.Debug("handleNodeEvents: exiting") - ctx := context.Background() + for { select { - case ev, ok := <-c: - if !ok { - return - } - switch ev.EventName { - case evAck: - // Event sent by worker to ack a dispatched job. - node.logger.Debug("handleNodeEvents: received ack", "event", ev.EventName, "id", ev.ID) - node.ackWorkerEvent(ctx, ev) - case evDispatchReturn: - // Event sent by pool node to node that originally dispatched the job. - node.logger.Debug("handleNodeEvents: received dispatch return", "event", ev.EventName, "id", ev.ID) - node.returnDispatchStatus(ctx, ev) - } + case ev := <-c: + node.processNodeEvent(ctx, ev) case <-node.stop: - go node.nodeReader.Close() // Close nodeReader in a separate goroutine to avoid blocking + node.nodeReader.Close() + return } } } +// processNodeEvent processes a node event. +func (node *Node) processNodeEvent(ctx context.Context, ev *streaming.Event) { + switch ev.EventName { + case evAck: + // Event sent by worker to ack a dispatched job. + node.logger.Debug("handleNodeEvents: received ack", "event", ev.EventName, "id", ev.ID) + node.ackWorkerEvent(ctx, ev) + case evDispatchReturn: + // Event sent by pool node to node that originally dispatched the job. + node.logger.Debug("handleNodeEvents: received dispatch return", "event", ev.EventName, "id", ev.ID) + node.returnDispatchStatus(ctx, ev) + } +} + // ackWorkerEvent acks the pending event that corresponds to the acked job. If // the event was a dispatched job then it sends a dispatch return event to the // node that dispatched the job. 
@@ -629,7 +625,7 @@ func (node *Node) ackWorkerEvent(ctx context.Context, ev *streaming.Event) {
 			return
 		}
 		ack.EventID = pending.ID
-		if _, err := stream.Add(ctx, evDispatchReturn, marshalAck(ack), soptions.WithOnlyIfStreamExists()); err != nil {
+		if _, err := stream.Add(ctx, evDispatchReturn, marshalAck(ack)); err != nil {
 			node.logger.Error(fmt.Errorf("ackWorkerEvent: failed to dispatch return to stream %q: %w", nodeStreamName(node.PoolName, nodeID), err))
 		}
 	}
@@ -646,11 +642,11 @@ func (node *Node) ackWorkerEvent(ctx context.Context, ev *streaming.Event) {
 		ev := value.(*streaming.Event)
 		if time.Since(ev.CreatedAt()) > pendingEventTTL {
 			staleKeys = append(staleKeys, key.(string))
+			node.logger.Error(fmt.Errorf("ackWorkerEvent: stale event, removing from pending events"), "event", ev.EventName, "id", ev.ID, "since", time.Since(ev.CreatedAt()), "TTL", pendingEventTTL)
 		}
 		return true
 	})
 	for _, key := range staleKeys {
-		node.logger.Error(fmt.Errorf("ackWorkerEvent: stale event, removing from pending events"), "event", ev.EventName, "id", ev.ID, "since", time.Since(ev.CreatedAt()), "TTL", pendingEventTTL)
 		node.pendingEvents.Delete(key)
 	}
 }
@@ -664,9 +660,8 @@ func (node *Node) returnDispatchStatus(_ context.Context, ev *streaming.Event) {
 		return
 	}
 	node.logger.Debug("dispatch return", "event", ev.EventName, "id", ev.ID, "ack-id", ack.EventID)
-	cherr := val.(chan error)
-	if cherr == nil {
-		// Event was requeued.
+	if val == nil {
+		// Event was requeued, just clean up
 		node.pendingJobs.Delete(ack.EventID)
 		return
 	}
@@ -674,24 +669,20 @@ func (node *Node) returnDispatchStatus(_ context.Context, ev *streaming.Event) {
 	if ack.Error != "" {
 		err = errors.New(ack.Error)
 	}
-	cherr <- err
+	val.(chan error) <- err
 }
 
-// manageWorkers monitors the workers replicated map and triggers job rebalancing
+// watchWorkers monitors the workers replicated map and triggers job rebalancing
 // when workers are added or removed from the pool.
-func (node *Node) manageWorkers(ctx context.Context) {
+func (node *Node) watchWorkers(ctx context.Context) {
 	defer node.wg.Done()
-	defer node.logger.Debug("manageWorkers: exiting")
-	defer node.workerMap.Close()
-
-	ch := node.workerMap.Subscribe()
 	for {
 		select {
-		case <-ch:
-			node.logger.Debug("manageWorkers: worker map updated")
-			node.handleWorkerMapUpdate(ctx)
 		case <-node.stop:
 			return
+		case <-node.workerMap.Subscribe():
			node.logger.Debug("watchWorkers: worker map updated")
+			node.handleWorkerMapUpdate(ctx)
 		}
 	}
 }
@@ -708,8 +699,10 @@ func (node *Node) handleWorkerMapUpdate(ctx context.Context) {
 			// If it's not in the worker map, then it's not active and its jobs
 			// have already been requeued.
node.logger.Info("handleWorkerMapUpdate: removing inactive local worker", "worker", worker.ID) - node.deleteWorker(ctx, worker.ID) - worker.stopAndWait(ctx) + if err := node.deleteWorker(ctx, worker.ID); err != nil { + node.logger.Error(fmt.Errorf("handleWorkerMapUpdate: failed to delete inactive worker %q: %w", worker.ID, err), "worker", worker.ID) + } + worker.stop(ctx) node.localWorkers.Delete(key) return true } @@ -737,7 +730,7 @@ func (node *Node) requeueJob(ctx context.Context, workerID string, job *Job) (ch return nil, nil } node.logger.Debug("requeuing job", "key", job.Key, "worker", workerID) - job.NodeID = node.NodeID + job.NodeID = node.ID eventID, err := node.poolStream.Add(ctx, evStartJob, marshalJob(job)) if err != nil { @@ -751,27 +744,23 @@ func (node *Node) requeueJob(ctx context.Context, workerID string, job *Job) (ch return cherr, nil } -// manageShutdown monitors the pool shutdown map and initiates node shutdown when updated. -func (node *Node) manageShutdown(ctx context.Context) { +// watchShutdown monitors the pool shutdown map and initiates node shutdown when updated. +func (node *Node) watchShutdown(ctx context.Context) { defer node.wg.Done() - defer node.logger.Debug("manageShutdown: exiting") - defer node.shutdownMap.Close() - - ch := node.shutdownMap.Subscribe() for { select { - case <-ch: - node.logger.Debug("manageShutdown: shutdown map updated, initiating shutdown") - node.handleShutdownMapUpdate(ctx) case <-node.stop: return + case <-node.shutdownMap.Subscribe(): + node.logger.Debug("watchShutdown: shutdown map updated") + // Handle shutdown in a separate goroutine to allow this one to exit + pulse.Go(ctx, func() { node.handleShutdown(ctx) }) } } } -// handleShutdownMapUpdate is called when the shutdown map is updated and closes -// the node. -func (node *Node) handleShutdownMapUpdate(ctx context.Context) { +// handleShutdown closes the node. +func (node *Node) handleShutdown(ctx context.Context) { if node.IsClosed() { return } @@ -781,44 +770,77 @@ func (node *Node) handleShutdownMapUpdate(ctx context.Context) { // There is only one value in the map requestingNode = node } - node.logger.Info("shutdown", "requested-by", requestingNode) - node.close(ctx, false) + node.logger.Debug("handleShutdown: shutting down", "requested-by", requestingNode) + node.close(ctx, true) node.lock.Lock() node.shutdown = true node.lock.Unlock() - node.logger.Info("shutdown") + node.logger.Info("shutdown", "requested-by", requestingNode) } -// manageInactiveWorkers periodically checks for inactive workers and requeues their jobs. -func (node *Node) manageInactiveWorkers(ctx context.Context) { +// processInactiveNodes periodically checks for inactive nodes and destroys their streams. +func (node *Node) processInactiveNodes(ctx context.Context) { defer node.wg.Done() - defer node.logger.Debug("manageInactiveWorkers: exiting") ticker := time.NewTicker(node.workerTTL) defer ticker.Stop() + for { select { case <-node.stop: return case <-ticker.C: - node.processInactiveWorkers(ctx) + node.cleanupInactiveNodes(ctx) } } } -// processInactiveWorkers identifies and removes workers that have been inactive -// for longer than workerTTL. It then requeues any jobs associated with these -// inactive workers, ensuring that no work is lost. 
+func (node *Node) cleanupInactiveNodes(ctx context.Context) { + nodeMap := node.nodeKeepAliveMap.Map() + for nodeID, lastSeen := range nodeMap { + if nodeID == node.ID || node.isActive(lastSeen, node.workerTTL) { + continue + } + + node.logger.Info("cleaning up inactive node", "node", nodeID) + + // Clean up node's stream + stream := nodeStreamName(node.PoolName, nodeID) + if s, err := streaming.NewStream(stream, node.rdb, soptions.WithStreamLogger(node.logger)); err == nil { + if err := s.Destroy(ctx); err != nil { + node.logger.Error(fmt.Errorf("cleanupInactiveNodes: failed to destroy stream: %w", err)) + } + } + + // Remove from keep-alive map + if _, err := node.nodeKeepAliveMap.Delete(ctx, nodeID); err != nil { + node.logger.Error(fmt.Errorf("cleanupInactiveNodes: failed to delete node: %w", err)) + } + } +} + +// processInactiveWorkers periodically checks for inactive workers and requeues their jobs. func (node *Node) processInactiveWorkers(ctx context.Context) { - if node.IsClosed() { - return + defer node.wg.Done() + ticker := time.NewTicker(node.workerTTL) + defer ticker.Stop() + + for { + select { + case <-node.stop: + return + case <-ticker.C: + node.cleanupInactiveWorkers(ctx) + } } +} - alive := node.keepAliveMap.Map() +func (node *Node) cleanupInactiveWorkers(ctx context.Context) { + alive := node.workerKeepAliveMap.Map() for id, ls := range alive { lsi, err := strconv.ParseInt(ls, 10, 64) if err != nil { - node.logger.Error(fmt.Errorf("processInactiveWorkers: failed to parse last seen timestamp: %w", err), "worker", id) + node.logger.Error(fmt.Errorf("cleanupInactiveWorkers: failed to parse last seen timestamp: %w", err), "worker", id) continue } lastSeen := time.Unix(0, lsi) @@ -826,18 +848,18 @@ func (node *Node) processInactiveWorkers(ctx context.Context) { if lsd <= node.workerTTL { continue } - node.logger.Debug("processInactiveWorkers: removing worker", "worker", id, "last-seen", lsd, "ttl", node.workerTTL) + node.logger.Debug("cleanupInactiveWorkers: removing worker", "worker", id, "last-seen", lsd, "ttl", node.workerTTL) // Use optimistic locking to set the keep-alive timestamp to a value // in the future so that another node does not also requeue the jobs. next := lsi + node.workerTTL.Nanoseconds() - last, err := node.keepAliveMap.TestAndSet(ctx, id, ls, strconv.FormatInt(next, 10)) + last, err := node.workerKeepAliveMap.TestAndSet(ctx, id, ls, strconv.FormatInt(next, 10)) if err != nil { - node.logger.Error(fmt.Errorf("processInactiveWorkers: failed to set keep-alive timestamp: %w", err), "worker", id) + node.logger.Error(fmt.Errorf("cleanupInactiveWorkers: failed to set keep-alive timestamp: %w", err), "worker", id) continue } if last != ls { - node.logger.Debug("processInactiveWorkers: keep-alive timestamp for worker already set by another node", "worker", id) + node.logger.Debug("cleanupInactiveWorkers: keep-alive timestamp for worker already set by another node", "worker", id) continue } @@ -845,7 +867,7 @@ func (node *Node) processInactiveWorkers(ctx context.Context) { if !ok { // Worker has no jobs, so delete it right away. 
 			if err := node.deleteWorker(ctx, id); err != nil {
-				node.logger.Error(fmt.Errorf("processInactiveWorkers: failed to delete worker %q: %w", id, err), "worker", id)
+				node.logger.Error(fmt.Errorf("cleanupInactiveWorkers: failed to delete worker %q: %w", id, err), "worker", id)
 			}
 			continue
 		}
@@ -856,11 +878,11 @@ func (node *Node) processInactiveWorkers(ctx context.Context) {
 				Key:       key,
 				Payload:   []byte(payload),
 				CreatedAt: time.Now(),
-				NodeID:    node.NodeID,
+				NodeID:    node.ID,
 			}
 			cherr, err := node.requeueJob(ctx, id, job)
 			if err != nil {
-				node.logger.Error(fmt.Errorf("processInactiveWorkers: failed to requeue inactive job: %w", err), "job", job.Key, "worker", id)
+				node.logger.Error(fmt.Errorf("cleanupInactiveWorkers: failed to requeue inactive job: %w", err), "job", job.Key, "worker", id)
 				continue
 			}
 			requeued[job.Key] = cherr
@@ -868,9 +890,40 @@ func (node *Node) processInactiveWorkers(ctx context.Context) {
 
 		allRequeued := len(requeued) == len(keys)
 		if !allRequeued {
-			node.logger.Error(fmt.Errorf("processInactiveWorkers: failed to requeue inactive jobs: %d/%d, will retry later", len(requeued), len(keys)), "worker", id)
+			node.logger.Error(fmt.Errorf("cleanupInactiveWorkers: failed to requeue inactive jobs: %d/%d, will retry later", len(requeued), len(keys)), "worker", id)
+		}
+		if len(requeued) > 0 {
+			pulse.Go(ctx, func() { node.processRequeuedJobs(ctx, id, requeued, allRequeued) })
+		}
+	}
+}
+
+// isActive reports whether the given last-seen timestamp is within ttl.
+func (node *Node) isActive(lastSeen string, ttl time.Duration) bool {
+	lsi, err := strconv.ParseInt(lastSeen, 10, 64)
+	if err != nil {
+		node.logger.Error(fmt.Errorf("isActive: failed to parse last seen timestamp: %w", err))
+		return false
+	}
+	return time.Since(time.Unix(0, lsi)) <= ttl
+}
+
+// updateNodeKeepAlive periodically refreshes this node's keep-alive timestamp until the node stops.
+func (node *Node) updateNodeKeepAlive(ctx context.Context) {
+	defer node.wg.Done()
+	ticker := time.NewTicker(node.workerTTL / 2)
+	defer ticker.Stop()
+
+	for {
+		select {
+		case <-node.stop:
+			return
+		case <-ticker.C:
+			if _, err := node.nodeKeepAliveMap.Set(ctx, node.ID,
+				strconv.FormatInt(time.Now().UnixNano(), 10)); err != nil {
+				node.logger.Error(fmt.Errorf("updateNodeKeepAlive: failed to update timestamp: %w", err))
+			}
 		}
-		go node.processRequeuedJobs(ctx, id, requeued, allRequeued)
 	}
 }
 
@@ -880,7 +933,7 @@ func (node *Node) processRequeuedJobs(ctx context.Context, id string, requeued m
 	var succeeded int64
 	for key, cherr := range requeued {
 		wg.Add(1)
-		go func(key string, cherr chan error) {
+		pulse.Go(ctx, func() {
 			defer wg.Done()
 			select {
 			case err := <-cherr:
@@ -890,9 +943,9 @@ func (node *Node) processRequeuedJobs(ctx context.Context, id string, requeued m
 				}
 				atomic.AddInt64(&succeeded, 1)
 			case <-time.After(node.workerTTL):
-				node.logger.Error(fmt.Errorf("processRequeuedJobs: timeout waiting for requeue result for job"), "job", key, "worker", id)
+				node.logger.Error(fmt.Errorf("processRequeuedJobs: timeout waiting for requeue result"), "job", key, "worker", id, "timeout", node.workerTTL)
 			}
-		}(key, cherr)
+		})
 	}
 	wg.Wait()
 
@@ -931,7 +984,7 @@ func (node *Node) activeWorkers() []string {
 	})
 
 	// Then filter out workers that have not been seen for more than workerTTL.
-	alive := node.keepAliveMap.Map()
+	alive := node.workerKeepAliveMap.Map()
 	var activeIDs []string
 	for _, id := range sortedIDs {
 		ls, ok := alive[id]
@@ -957,7 +1010,7 @@ func (node *Node) activeWorkers() []string {
 // deleteWorker removes a remote worker from the pool deleting the worker stream.
 func (node *Node) deleteWorker(ctx context.Context, id string) error {
 	node.logger.Debug("deleteWorker: deleting worker", "worker", id)
-	if _, err := node.keepAliveMap.Delete(ctx, id); err != nil {
+	if _, err := node.workerKeepAliveMap.Delete(ctx, id); err != nil {
 		node.logger.Error(fmt.Errorf("deleteWorker: failed to delete worker %q from keep-alive map: %w", id, err))
 	}
 	if _, err := node.workerMap.Delete(ctx, id); err != nil {
@@ -991,6 +1044,88 @@ func (node *Node) workerStream(_ context.Context, id string) (*streaming.Stream,
 	return val.(*streaming.Stream), nil
 }
 
+// cleanupWorker removes the worker from all pool maps.
+func (node *Node) cleanupWorker(ctx context.Context, id string) {
+	if _, err := node.workerMap.Delete(ctx, id); err != nil {
+		node.logger.Error(fmt.Errorf("failed to remove worker %s from worker map: %w", id, err))
+	}
+	if _, err := node.workerKeepAliveMap.Delete(ctx, id); err != nil {
+		node.logger.Error(fmt.Errorf("failed to remove worker %s from keep-alive map: %w", id, err))
+	}
+	if _, err := node.jobsMap.Delete(ctx, id); err != nil {
+		node.logger.Error(fmt.Errorf("failed to remove worker %s from jobs map: %w", id, err))
+	}
+	node.workerStreams.Delete(id)
+}
+
+// requeueAllJobs requeues all jobs from all local workers in parallel. It waits for all
+// requeue operations to complete before returning. If any requeue operations fail, it
+// collects all errors and returns them as a single error. This is typically called
+// during node close to ensure no jobs are lost.
+func (node *Node) requeueAllJobs(ctx context.Context) error {
+	var wg sync.WaitGroup
+	var errs []error
+	var errLock sync.Mutex
+
+	node.localWorkers.Range(func(key, value any) bool {
+		wg.Add(1)
+		pulse.Go(ctx, func() {
+			defer wg.Done()
+			if err := value.(*Worker).requeueJobs(ctx); err != nil {
+				errLock.Lock()
+				errs = append(errs, err)
+				errLock.Unlock()
+			}
+		})
+		return true
+	})
+	wg.Wait()
+
+	if len(errs) > 0 {
+		return fmt.Errorf("failed to requeue jobs for %d workers: %v", len(errs), errs)
+	}
+	return nil
+}
+
+// cleanupPool removes the pool resources from Redis.
+func (node *Node) cleanupPool(ctx context.Context) {
+	for _, m := range node.maps() {
+		if m != nil {
+			if err := m.Reset(ctx); err != nil {
+				node.logger.Error(fmt.Errorf("cleanupPool: failed to reset map: %w", err))
+			}
+		}
+	}
+	if err := node.poolStream.Destroy(ctx); err != nil {
+		node.logger.Error(fmt.Errorf("cleanupPool: failed to destroy pool stream: %w", err))
+	}
+}
+
+// cleanupNode closes the node resources.
+func (node *Node) cleanupNode(ctx context.Context) {
+	for _, m := range node.maps() {
+		if m != nil {
+			m.Close()
+		}
+	}
+	if err := node.nodeStream.Destroy(ctx); err != nil {
+		node.logger.Error(fmt.Errorf("cleanupNode: failed to destroy node stream: %w", err))
+	}
+}
+
+// maps returns the maps managed by the node.
+func (node *Node) maps() []*rmap.Map {
+	return []*rmap.Map{
+		node.jobPayloadsMap,
+		node.jobsMap,
+		node.nodeKeepAliveMap,
+		node.workerKeepAliveMap,
+		node.shutdownMap,
+		node.tickerMap,
+		node.workerMap,
+	}
+}
+
 // Hash implements the Jump Consistent Hash algorithm.
 // See https://arxiv.org/ftp/arxiv/papers/1406/1406.2294.pdf for details.
 func (jh jumpHash) Hash(key string, numBuckets int64) int64 {
@@ -1009,6 +1144,12 @@ func (jh jumpHash) Hash(key string, numBuckets int64) int64 {
 	return b
 }
 
+// nodeKeepAliveMapName returns the name of the replicated map used to store the
+// node keep-alive timestamps.
+func nodeKeepAliveMapName(pool string) string { + return fmt.Sprintf("%s:node-keepalive", pool) +} + // workerMapName returns the name of the replicated map used to store the // worker creation timestamps. func workerMapName(pool string) string { @@ -1027,10 +1168,10 @@ func jobPayloadsMapName(pool string) string { return fmt.Sprintf("%s:job-payloads", pool) } -// keepAliveMapName returns the name of the replicated map used to store the +// workerKeepAliveMapName returns the name of the replicated map used to store the // worker keep-alive timestamps. -func keepAliveMapName(pool string) string { - return fmt.Sprintf("%s:keepalive", pool) +func workerKeepAliveMapName(pool string) string { + return fmt.Sprintf("%s:worker-keepalive", pool) } // tickerMapName returns the name of the replicated map used to store ticker diff --git a/pool/node_test.go b/pool/node_test.go index df2a226..23b6023 100644 --- a/pool/node_test.go +++ b/pool/node_test.go @@ -3,7 +3,9 @@ package pool import ( "context" "fmt" + "strconv" "strings" + "sync" "testing" "time" @@ -563,6 +565,146 @@ func TestStaleEventsAreRemoved(t *testing.T) { }, max, delay, "Fresh event should still be present") } +func TestStaleNodeStreamCleanup(t *testing.T) { + var ( + ctx = ptesting.NewTestContext(t) + testName = strings.Replace(t.Name(), "/", "_", -1) + rdb = ptesting.NewRedisClient(t) + node1 = newTestNode(t, ctx, rdb, testName) + node2 = newTestNode(t, ctx, rdb, testName) + numJobs = 0 + ) + defer ptesting.CleanupRedis(t, rdb, false, testName) + + // Configure nodes to send jobs to specific workers + node1.h = &ptesting.Hasher{IndexFunc: func(key string, numBuckets int64) int64 { + numJobs++ + if numJobs > 2 { + return 0 // to avoid panics on cleanup where jobs get requeued + } + if key == "job1" { + return 0 // job1 goes to worker1 + } + return 1 // job2 goes to worker2 + }} + node2.h = node1.h + + // Create workers and dispatch jobs to both nodes to ensure streams exist + newTestWorker(t, ctx, node1) + newTestWorker(t, ctx, node2) + + // Make sure workers are registered with both nodes + require.Eventually(t, func() bool { + return len(node1.PoolWorkers()) == 2 && len(node2.PoolWorkers()) == 2 + }, max, delay, "Workers were not registered with both nodes") + + // Dispatch jobs to both nodes + assert.NoError(t, node1.DispatchJob(ctx, "job1", []byte("payload1"))) + assert.NoError(t, node2.DispatchJob(ctx, "job2", []byte("payload2"))) + + // Verify both streams exist initially + name1 := "pulse:stream:" + nodeStreamName(node1.PoolName, node1.ID) + name2 := "pulse:stream:" + nodeStreamName(node2.PoolName, node2.ID) + assert.Eventually(t, func() bool { + exists1, err1 := rdb.Exists(ctx, name1).Result() + exists2, err2 := rdb.Exists(ctx, name2).Result() + return err1 == nil && err2 == nil && exists1 == 1 && exists2 == 1 + }, max, delay, "Node streams should exist initially") + + // Set node2's last seen time to a stale value + close(node2.stop) + _, err := node2.nodeKeepAliveMap.Set(ctx, node2.ID, + strconv.FormatInt(time.Now().Add(-3*node2.workerTTL).UnixNano(), 10)) + assert.NoError(t, err) + node2.wg.Wait() + node2.stop = make(chan struct{}) // so we can close + + // Verify node2's stream gets cleaned up + assert.Eventually(t, func() bool { + exists, err := rdb.Exists(ctx, name2).Result() + return err == nil && exists == 0 + }, max, delay, "Stale node stream should have been cleaned up") + + // Verify node1's stream still exists + assert.Eventually(t, func() bool { + exists, err := rdb.Exists(ctx, name1).Result() + return err == nil 
&& exists == 1 + }, max, delay, "Active node stream should still exist") + + // Verify node2 was removed from keep-alive map + assert.Eventually(t, func() bool { + _, exists := node1.nodeKeepAliveMap.Get(node2.ID) + return !exists + }, max, delay, "Stale node should have been removed from keep-alive map") + + // Clean up + assert.NoError(t, node2.Close(ctx)) + assert.NoError(t, node1.Shutdown(ctx)) +} + +func TestShutdownStopsAllJobs(t *testing.T) { + testName := strings.Replace(t.Name(), "/", "_", -1) + ctx := ptesting.NewTestContext(t) + rdb := ptesting.NewRedisClient(t) + defer ptesting.CleanupRedis(t, rdb, true, testName) + + // Create node and workers + node := newTestNode(t, ctx, rdb, testName) + worker1 := newTestWorker(t, ctx, node) + worker2 := newTestWorker(t, ctx, node) + + // Track stopped jobs + var stoppedJobs sync.Map + stopHandler := func(key string) error { + stoppedJobs.Store(key, true) + return nil + } + worker1.handler.(*mockHandler).stopFunc = stopHandler + worker2.handler.(*mockHandler).stopFunc = stopHandler + + // Dispatch multiple jobs + jobs := []struct { + key string + payload []byte + }{ + {key: "job1", payload: []byte("payload1")}, + {key: "job2", payload: []byte("payload2")}, + {key: "job3", payload: []byte("payload3")}, + {key: "job4", payload: []byte("payload4")}, + } + + // Configure node to distribute jobs between workers + node.h = &ptesting.Hasher{IndexFunc: func(key string, numBuckets int64) int64 { + if strings.HasSuffix(key, "1") || strings.HasSuffix(key, "2") { + return 0 // jobs 1 and 2 go to worker1 + } + return 1 // jobs 3 and 4 go to worker2 + }} + + // Dispatch all jobs + for _, job := range jobs { + require.NoError(t, node.DispatchJob(ctx, job.key, job.payload)) + } + + // Wait for jobs to be distributed + require.Eventually(t, func() bool { + return len(worker1.Jobs()) == 2 && len(worker2.Jobs()) == 2 + }, max, delay, "Jobs were not distributed correctly") + + // Shutdown the node + assert.NoError(t, node.Shutdown(ctx)) + + // Verify all jobs were stopped + for _, job := range jobs { + _, ok := stoppedJobs.Load(job.key) + assert.True(t, ok, "Job %s was not stopped during shutdown", job.key) + } + + // Verify workers have no remaining jobs + assert.Empty(t, worker1.Jobs(), "Worker1 should have no remaining jobs") + assert.Empty(t, worker2.Jobs(), "Worker2 should have no remaining jobs") +} + type mockAcker struct { XAckFunc func(ctx context.Context, streamKey, sinkName string, ids ...string) *redis.IntCmd } diff --git a/pool/testing.go b/pool/testing.go index 75476d7..3163131 100644 --- a/pool/testing.go +++ b/pool/testing.go @@ -26,7 +26,7 @@ type mockHandlerWithoutNotify struct { const ( testWorkerShutdownTTL = 100 * time.Millisecond testJobSinkBlockDuration = 100 * time.Millisecond - testWorkerTTL = 100 * time.Millisecond + testWorkerTTL = 150 * time.Millisecond testAckGracePeriod = 50 * time.Millisecond ) diff --git a/pool/worker.go b/pool/worker.go index 5a4ef0e..fe37501 100644 --- a/pool/worker.go +++ b/pool/worker.go @@ -21,11 +21,10 @@ type ( Worker struct { // Unique worker ID ID string - // Worker pool node where worker is running. - Node *Node // Time worker was created. 
CreatedAt time.Time + node *Node handler JobHandler stream *streaming.Stream reader *streaming.Reader @@ -92,7 +91,7 @@ func newWorker(ctx context.Context, node *Node, h JobHandler) (*Worker, error) { return nil, fmt.Errorf("failed to add worker %q to pool %q: %w", wid, node.PoolName, err) } now := strconv.FormatInt(time.Now().UnixNano(), 10) - if _, err := node.keepAliveMap.SetAndWait(ctx, wid, now); err != nil { + if _, err := node.workerKeepAliveMap.SetAndWait(ctx, wid, now); err != nil { return nil, fmt.Errorf("failed to update worker keep-alive: %w", err) } stream, err := streaming.NewStream(workerStreamName(wid), node.rdb, soptions.WithStreamLogger(node.logger)) @@ -105,7 +104,7 @@ func newWorker(ctx context.Context, node *Node, h JobHandler) (*Worker, error) { } w := &Worker{ ID: wid, - Node: node, + node: node, handler: h, CreatedAt: time.Now(), stream: stream, @@ -113,7 +112,7 @@ func newWorker(ctx context.Context, node *Node, h JobHandler) (*Worker, error) { done: make(chan struct{}), jobsMap: node.jobsMap, jobPayloadsMap: node.jobPayloadsMap, - keepAliveMap: node.keepAliveMap, + keepAliveMap: node.workerKeepAliveMap, shutdownMap: node.shutdownMap, workerTTL: node.workerTTL, workerShutdownTTL: node.workerShutdownTTL, @@ -127,7 +126,7 @@ func newWorker(ctx context.Context, node *Node, h JobHandler) (*Worker, error) { "worker_shutdown_ttl", w.workerShutdownTTL) w.wg.Add(2) - pulse.Go(ctx, func() { w.handleEvents(reader.Subscribe()) }) + pulse.Go(ctx, func() { w.handleEvents(ctx, reader.Subscribe()) }) pulse.Go(ctx, func() { w.keepAlive(ctx) }) return w, nil @@ -152,7 +151,7 @@ func (w *Worker) Jobs() []*Job { Key: key, Payload: job.Payload, CreatedAt: job.CreatedAt, - Worker: &Worker{ID: w.ID, Node: w.Node, CreatedAt: w.CreatedAt}, + Worker: &Worker{ID: w.ID, node: w.node, CreatedAt: w.CreatedAt}, NodeID: job.NodeID, }) } @@ -167,9 +166,9 @@ func (w *Worker) IsStopped() bool { } // handleEvents is the worker loop. -func (w *Worker) handleEvents(c <-chan *streaming.Event) { +func (w *Worker) handleEvents(ctx context.Context, c <-chan *streaming.Event) { defer w.wg.Done() - ctx := context.Background() + for { select { case ev, ok := <-c: @@ -201,14 +200,12 @@ func (w *Worker) handleEvents(c <-chan *streaming.Event) { } w.ackPoolEvent(ctx, nodeID, ev.ID, nil) case <-w.done: - w.logger.Debug("handleEvents: exiting") return } } } -// stop stops the reader, the worker goroutines and removes the worker from the -// workers and keep-alive maps. +// stop stops the reader, destroys the stream and closes the worker. func (w *Worker) stop(ctx context.Context) { w.lock.Lock() if w.stopped { @@ -222,23 +219,7 @@ func (w *Worker) stop(ctx context.Context) { w.logger.Error(fmt.Errorf("failed to destroy stream for worker: %w", err)) } close(w.done) -} - -// stopAndWait stops the worker and waits for its goroutines to exit up to -// w.workerShutdownTTL time. -func (w *Worker) stopAndWait(ctx context.Context) { - w.stop(ctx) - c := make(chan struct{}) - go func() { - w.wg.Wait() - close(c) - }() - select { - case <-c: - w.logger.Debug("stopAndWait: worker stopped") - case <-time.After(w.workerShutdownTTL): - w.logger.Error(fmt.Errorf("stop timeout"), "after", w.workerShutdownTTL) - } + w.wg.Wait() } // startJob starts a job. 
@@ -278,6 +259,7 @@ func (w *Worker) stopJob(ctx context.Context, key string, forRequeue bool) error if err := w.handler.Stop(key); err != nil { return fmt.Errorf("failed to stop job %q: %w", key, err) } + w.logger.Debug("stopped job", "job", key) w.jobs.Delete(key) if _, _, err := w.jobsMap.RemoveValues(ctx, w.ID, key); err != nil { w.logger.Error(fmt.Errorf("stop job: failed to remove job %q from jobs map: %w", key, err)) @@ -312,7 +294,7 @@ func (w *Worker) ackPoolEvent(ctx context.Context, nodeID, eventID string, acker stream, ok := w.nodeStreams.Load(nodeID) if !ok { var err error - stream, err = streaming.NewStream(nodeStreamName(w.Node.PoolName, nodeID), w.Node.rdb, soptions.WithStreamLogger(w.logger)) + stream, err = streaming.NewStream(nodeStreamName(w.node.PoolName, nodeID), w.node.rdb, soptions.WithStreamLogger(w.logger)) if err != nil { w.logger.Error(fmt.Errorf("failed to create stream for node %q: %w", nodeID, err)) return @@ -332,21 +314,20 @@ func (w *Worker) ackPoolEvent(ctx context.Context, nodeID, eventID string, acker // keepAlive keeps the worker registration up-to-date until ctx is cancelled. func (w *Worker) keepAlive(ctx context.Context) { defer w.wg.Done() + ticker := time.NewTicker(w.workerTTL / 2) defer ticker.Stop() for { select { case <-ticker.C: if w.IsStopped() { - // Let's not recreate the map if we just deleted it - return + return // Let's not recreate the map if we just deleted it } now := strconv.FormatInt(time.Now().UnixNano(), 10) if _, err := w.keepAliveMap.Set(ctx, w.ID, now); err != nil { w.logger.Error(fmt.Errorf("failed to update worker keep-alive: %w", err)) } case <-w.done: - w.logger.Debug("keepAlive: exiting") return } } @@ -358,7 +339,7 @@ func (w *Worker) rebalance(ctx context.Context, activeWorkers []string) { rebalanced := make(map[string]*Job) w.jobs.Range(func(key, value any) bool { job := value.(*Job) - wid := activeWorkers[w.Node.h.Hash(job.Key, int64(len(activeWorkers)))] + wid := activeWorkers[w.node.h.Hash(job.Key, int64(len(activeWorkers)))] if wid != w.ID { rebalanced[job.Key] = job } @@ -375,8 +356,9 @@ func (w *Worker) rebalance(ctx context.Context, activeWorkers []string) { w.logger.Error(fmt.Errorf("rebalance: failed to stop job: %w", err), "job", key) continue } + w.logger.Debug("stopped job", "job", key) w.jobs.Delete(key) - cherr, err := w.Node.requeueJob(ctx, w.ID, job) + cherr, err := w.node.requeueJob(ctx, w.ID, job) if err != nil { w.logger.Error(fmt.Errorf("rebalance: failed to requeue job: %w", err), "job", key) if err := w.handler.Start(job); err != nil { @@ -387,7 +369,7 @@ func (w *Worker) rebalance(ctx context.Context, activeWorkers []string) { delete(rebalanced, key) cherrs[key] = cherr } - go w.Node.processRequeuedJobs(ctx, w.ID, cherrs, false) + pulse.Go(ctx, func() { w.node.processRequeuedJobs(ctx, w.ID, cherrs, false) }) } // requeueJobs requeues the jobs handled by the worker. @@ -410,12 +392,12 @@ func (w *Worker) requeueJobs(ctx context.Context) error { // First mark the worker as inactive so that requeued jobs are not assigned to this worker // Use optimistic locking to avoid race conditions. 
- prev, err := w.Node.workerMap.TestAndSet(ctx, w.ID, createdAt, "-") + prev, err := w.node.workerMap.TestAndSet(ctx, w.ID, createdAt, "-") if err != nil { return fmt.Errorf("requeueJobs: failed to mark worker as inactive: %w", err) } - if prev == "-" || prev == "" { - w.logger.Debug("requeueJobs: worker already marked as inactive, skipping requeue") + if prev == "-" { + w.logger.Debug("requeueJobs: jobs already requeued, skipping requeue") return nil } @@ -446,28 +428,27 @@ func (w *Worker) attemptRequeue(ctx context.Context, jobsToRequeue map[string]*J err error } resultChan := make(chan result, len(jobsToRequeue)) + defer close(resultChan) + wg.Add(len(jobsToRequeue)) for key, job := range jobsToRequeue { - wg.Add(1) - go func(k string, j *Job) { + pulse.Go(ctx, func() { defer wg.Done() - err := w.requeueJob(ctx, j) - resultChan <- result{key: k, err: err} - }(key, job) + err := w.requeueJob(ctx, job) + if err != nil { + w.logger.Error(fmt.Errorf("failed to requeue job: %w", err), "job", key) + } else { + w.logger.Debug("requeueJobs: requeued", "job", key) + } + resultChan <- result{key: key, err: err} + }) } - - go func() { - wg.Wait() - close(resultChan) - }() + wg.Wait() remainingJobs := make(map[string]*Job) for { select { - case res, ok := <-resultChan: - if !ok { - return remainingJobs - } + case res := <-resultChan: if res.err != nil { w.logger.Error(fmt.Errorf("requeueJobs: failed to requeue job %q: %w", res.key, res.err)) remainingJobs[res.key] = jobsToRequeue[res.key] @@ -475,7 +456,11 @@ func (w *Worker) attemptRequeue(ctx context.Context, jobsToRequeue map[string]*J } delete(remainingJobs, res.key) w.logger.Info("requeued", "job", res.key) - case <-time.After(w.workerTTL): + if len(remainingJobs) == 0 { + w.logger.Debug("requeueJobs: all jobs requeued") + return remainingJobs + } + case <-time.After(w.workerShutdownTTL): w.logger.Error(fmt.Errorf("requeueJobs: timeout reached, some jobs may not have been processed")) return remainingJobs } @@ -484,31 +469,17 @@ func (w *Worker) attemptRequeue(ctx context.Context, jobsToRequeue map[string]*J // requeueJob requeues a job. func (w *Worker) requeueJob(ctx context.Context, job *Job) error { - eventID, err := w.Node.poolStream.Add(ctx, evStartJob, marshalJob(job)) + eventID, err := w.node.poolStream.Add(ctx, evStartJob, marshalJob(job)) if err != nil { return fmt.Errorf("requeueJob: failed to add job to pool stream: %w", err) } - w.Node.pendingJobs.Store(eventID, nil) + w.node.pendingJobs.Store(eventID, nil) if err := w.stopJob(ctx, job.Key, true); err != nil { return fmt.Errorf("failed to stop job: %w", err) } return nil } -// cleanup removes the worker from the workers, keep-alive and jobs maps. -func (w *Worker) cleanup(ctx context.Context) { - if _, err := w.Node.workerMap.Delete(ctx, w.ID); err != nil { - w.logger.Error(fmt.Errorf("failed to remove worker from worker map: %w", err)) - } - if _, err := w.keepAliveMap.Delete(ctx, w.ID); err != nil { - w.logger.Error(fmt.Errorf("failed to remove worker from keep alive map: %w", err)) - } - _, err := w.jobsMap.Delete(ctx, w.ID) - if err != nil { - w.logger.Error(fmt.Errorf("failed to remove worker from jobs map: %w", err)) - } -} - // workerStreamName returns the name of the stream used to communicate with the // worker with the given ID. 
diff --git a/pool/worker_test.go b/pool/worker_test.go
index 8e5fb7a..d8c000e 100644
--- a/pool/worker_test.go
+++ b/pool/worker_test.go
@@ -33,7 +33,7 @@ func TestWorkerRequeueJobs(t *testing.T) {
 
 	// Emulate the worker failing by preventing it from refreshing its keepalive
 	// This means we can't cleanup cleanly, hence "false" in CleanupRedis
-	worker.stopAndWait(ctx)
+	worker.stop(ctx)
 
 	// Create a new worker to pick up the requeued job
 	newWorker := newTestWorker(t, ctx, node)
@@ -73,7 +73,7 @@ func TestStaleWorkerCleanupInNode(t *testing.T) {
 		staleWorkers[i] = newTestWorker(t, ctx, node)
 		staleWorkers[i].stop(ctx)
 		// Set the last seen time to a past time
-		_, err := node.keepAliveMap.Set(ctx, staleWorkers[i].ID, strconv.FormatInt(time.Now().Add(-2*node.workerTTL).UnixNano(), 10))
+		_, err := node.workerKeepAliveMap.Set(ctx, staleWorkers[i].ID, strconv.FormatInt(time.Now().Add(-2*node.workerTTL).UnixNano(), 10))
 		assert.NoError(t, err)
 	}
@@ -108,7 +108,7 @@ func TestStaleWorkerCleanupAcrossNodes(t *testing.T) {
 		staleWorkers[i] = newTestWorker(t, ctx, node2)
 		staleWorkers[i].stop(ctx)
 		// Set the last seen time to a past time
-		_, err := node2.keepAliveMap.Set(ctx, staleWorkers[i].ID, strconv.FormatInt(time.Now().Add(-2*node2.workerTTL).UnixNano(), 10))
+		_, err := node2.workerKeepAliveMap.Set(ctx, staleWorkers[i].ID, strconv.FormatInt(time.Now().Add(-2*node2.workerTTL).UnixNano(), 10))
 		assert.NoError(t, err)
 	}
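The two stale-worker tests above back-date entries in workerKeepAliveMap by writing a UnixNano timestamp that is older than twice workerTTL. As a rough sketch of how such an entry can be judged stale (the actual cleanup logic lives in pool/node.go and may differ), assuming only that values are UnixNano encoded as decimal strings:

```go
// Hypothetical helper: decide whether a keep-alive entry is stale.
// Assumes the value format used in the tests above: UnixNano as a decimal string.
package main

import (
	"fmt"
	"strconv"
	"time"
)

func isStale(lastSeen string, ttl time.Duration) (bool, error) {
	ns, err := strconv.ParseInt(lastSeen, 10, 64)
	if err != nil {
		return false, fmt.Errorf("invalid keep-alive timestamp %q: %w", lastSeen, err)
	}
	return time.Since(time.Unix(0, ns)) > ttl, nil
}

func main() {
	ttl := 10 * time.Second
	// Mirrors the back-dated value the tests write: last seen 2*ttl ago.
	backdated := strconv.FormatInt(time.Now().Add(-2*ttl).UnixNano(), 10)
	stale, err := isStale(backdated, ttl)
	fmt.Println(stale, err) // true <nil>
}
```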
diff --git a/rmap/map.go b/rmap/map.go
index f61648e..0c2034a 100644
--- a/rmap/map.go
+++ b/rmap/map.go
@@ -406,7 +406,8 @@ func (sm *Map) TestAndDelete(ctx context.Context, key, test string) (string, err
 	return prev.(string), nil
 }
 
-// Reset clears the map content.
+// Reset clears the map content. Reset is the only method that can be called
+// after the map is closed.
 func (sm *Map) Reset(ctx context.Context) error {
 	_, err := sm.runLuaScript(ctx, "reset", sm.resetScript, "*")
 	return err
@@ -544,7 +545,7 @@ func (sm *Map) run() {
 			sm.lock.Unlock()
 
 		case <-sm.done:
-			sm.logger.Info("stopped")
+			sm.logger.Info("closed")
 			// no need to lock, stopping is true
 			for _, c := range sm.chans {
 				close(c)
@@ -565,7 +566,7 @@ func (sm *Map) run() {
 // It is the caller's responsibility to make sure the map is locked.
 func (sm *Map) runLuaScript(ctx context.Context, name string, script *redis.Script, args ...any) (any, error) {
 	sm.lock.RLock()
-	if sm.closing {
+	if sm.closing && name != "reset" {
 		sm.lock.RUnlock()
 		return "", fmt.Errorf("pulse map: %s is stopped", sm.Name)
 	}
diff --git a/rmap/map_test.go b/rmap/map_test.go
index e497077..41dbbd5 100644
--- a/rmap/map_test.go
+++ b/rmap/map_test.go
@@ -160,7 +160,9 @@ func TestMapLocal(t *testing.T) {
 	old, err = m.Delete(ctx, key)
 	assert.Error(t, err)
 	assert.Equal(t, "", old)
-	assert.Error(t, m.Reset(ctx))
+
+	// Reset should work after the map is closed
+	assert.NoError(t, m.Reset(ctx))
 
 	// Cleanup
 	m, err = Join(ctx, "test", rdb)
@@ -430,7 +432,7 @@ func TestLogs(t *testing.T) {
 	assert.Contains(t, buf.String(), `key=foo val=bar`)
 	assert.Contains(t, buf.String(), `msg=deleted key=foo`)
 	assert.Contains(t, buf.String(), `reset`)
-	assert.Contains(t, buf.String(), `stopped`)
+	assert.Contains(t, buf.String(), `closed`)
 }
 
 func TestJoinErrors(t *testing.T) {
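The rmap change above relaxes runLuaScript so the reset script is allowed through even after the map is closed, which is exactly what the updated doc comment and the TestMapLocal assertion exercise. A rough usage sketch follows; it assumes the usual entry points (rmap.Join, (*Map).Set, (*Map).Close) keep their current shapes, so treat the exact signatures as illustrative rather than authoritative.

```go
// Sketch of Reset-after-Close, under the assumption that Join, Set, Close and
// Reset have the shapes used elsewhere in the repository.
package main

import (
	"context"
	"log"

	"github.com/redis/go-redis/v9"
	"goa.design/pulse/rmap"
)

func main() {
	ctx := context.Background()
	rdb := redis.NewClient(&redis.Options{Addr: "localhost:6379"})

	m, err := rmap.Join(ctx, "example", rdb)
	if err != nil {
		log.Fatal(err)
	}
	if _, err := m.Set(ctx, "foo", "bar"); err != nil {
		log.Fatal(err)
	}

	m.Close() // writes such as Set now fail with "is stopped"
	if err := m.Reset(ctx); err != nil {
		log.Fatal(err) // Reset alone is still allowed and clears the content
	}
}
```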
diff --git a/streaming/reader.go b/streaming/reader.go
index 299aef6..0ca6492 100644
--- a/streaming/reader.go
+++ b/streaming/reader.go
@@ -272,6 +272,7 @@ func (r *Reader) cleanup() {
 	for _, c := range r.chans {
 		close(c)
 	}
+	r.chans = nil
 	r.wait.Done()
 }
 
diff --git a/streaming/sink.go b/streaming/sink.go
index cd582b2..1ecc482 100644
--- a/streaming/sink.go
+++ b/streaming/sink.go
@@ -266,6 +266,15 @@ func (s *Sink) RemoveStream(ctx context.Context, stream *Stream) error {
 		s.streamCursors[i] = stream.key
 		s.streamCursors[len(s.streams)+i] = ">"
 	}
+	if err := s.removeStreamConsumer(ctx, stream); err != nil {
+		return err
+	}
+	s.logger.Info("removed", "stream", stream.Name)
+	return nil
+}
+
+// removeStreamConsumer removes the stream consumer from the sink.
+func (s *Sink) removeStreamConsumer(ctx context.Context, stream *Stream) error {
 	remains, _, err := s.consumersMap[stream.Name].RemoveValues(ctx, s.Name, s.consumer)
 	if err != nil {
 		return fmt.Errorf("failed to remove consumer %s from replicated map for stream %s: %w", s.consumer, stream.Name, err)
@@ -275,13 +284,12 @@ func (s *Sink) RemoveStream(ctx context.Context, stream *Stream) error {
 			return err
 		}
 	}
-	s.logger.Info("removed", "stream", stream.Name)
 	return nil
 }
 
 // Close stops event polling, waits for all events to be processed, and closes the sink channel.
 // It is safe to call Close multiple times.
-func (s *Sink) Close() {
+func (s *Sink) Close(ctx context.Context) {
 	s.lock.Lock()
 	if s.closing {
 		s.lock.Unlock()
diff --git a/streaming/sink_test.go b/streaming/sink_test.go
index e73295c..c865096 100644
--- a/streaming/sink_test.go
+++ b/streaming/sink_test.go
@@ -86,7 +86,7 @@ func TestReadSinceLastEvent(t *testing.T) {
 		options.WithSinkStartAfter(eventID),
 		options.WithSinkBlockDuration(testBlockDuration))
 	require.NoError(t, err)
-	defer sink2.Close()
+	defer cleanupSink(t, ctx, s, sink2)
 	c2 := sink2.Subscribe()
 	read = readOneEvent(t, ctx, c2, sink2)
 	assert.Equal(t, "event", read.EventName)
@@ -97,7 +97,7 @@
 		options.WithSinkStartAfter("0"),
 		options.WithSinkBlockDuration(testBlockDuration))
 	require.NoError(t, err)
-	defer sink3.Close()
+	defer cleanupSink(t, ctx, s, sink3)
 	c3 := sink3.Subscribe()
 	read = readOneEvent(t, ctx, c3, sink3)
 	assert.Equal(t, "event", read.EventName)
@@ -128,7 +128,7 @@ func TestCleanup(t *testing.T) {
 	assert.Equal(t, []byte("payload"), read.Payload)
 
 	// Stop sink, destroy stream and check Redis keys are gone
-	sink.Close()
+	sink.Close(ctx)
 	assert.Eventually(t, func() bool { return sink.IsClosed() }, max, delay)
 	assert.Equal(t, rdb.Exists(ctx, s.key).Val(), int64(1))
 	assert.NoError(t, s.Destroy(ctx))
@@ -245,7 +245,7 @@ func TestMultipleConsumers(t *testing.T) {
 		options.WithSinkAckGracePeriod(testAckDuration))
 	require.NoError(t, err)
 	defer func() {
-		sink2.Close()
+		sink2.Close(ctx)
 		assert.Eventually(t, func() bool { return sink2.IsClosed() }, max, delay)
 	}()
 
@@ -355,7 +355,7 @@ func TestNonAckMessageDeliveredToAnotherConsumer(t *testing.T) {
 		options.WithSinkBlockDuration(testBlockDuration),
 		options.WithSinkAckGracePeriod(testAckDuration))
 	require.NoError(t, err)
-	defer sink2.Close()
+	defer sink2.Close(ctx)
 
 	// Subscribe to both sinks
 	c1 := sink1.Subscribe()
@@ -387,7 +387,7 @@
 	assert.Equal(t, []byte("test_payload"), read1.Payload)
 
 	// Close the receiver sink
-	receiverSink.Close()
+	receiverSink.Close(ctx)
 	assert.Eventually(t, func() bool { return receiverSink.IsClosed() }, max, delay)
 
 	// The message should now be redelivered to the other sink
@@ -461,7 +461,7 @@ func TestStaleConsumerDeletionAndMessageClaiming(t *testing.T) {
 	}, max, delay, "Expected two consumers")
 
 	// Close the sink to stop keep-alive refresh
-	sink1.Close()
+	sink1.Close(ctx)
 	assert.Eventually(t, func() bool { return sink1.IsClosed() }, max, delay)
 
 	// Verify that the stale consumer is deleted
diff --git a/streaming/testing.go b/streaming/testing.go
index 1caf2e9..8bda437 100644
--- a/streaming/testing.go
+++ b/streaming/testing.go
@@ -64,7 +64,7 @@ func readOneReaderEvent(t *testing.T, c <-chan *Event) *Event {
 func cleanupSink(t *testing.T, ctx context.Context, s *Stream, sink *Sink) {
 	t.Helper()
 	if sink != nil {
-		sink.Close()
+		sink.Close(ctx)
 		assert.Eventually(t, func() bool { return sink.IsClosed() }, max, delay)
 	}
 	if s != nil {
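Since Sink.Close now takes a context, call sites outside the test suite need the same one-argument form. Below is a small helper in the spirit of the cleanupSink test helper above; it uses only Close(ctx) and IsClosed() from this diff, and the polling interval and timeout are arbitrary choices for the sketch, not library defaults.

```go
// Hypothetical convenience wrapper around the new Close(ctx) signature.
package sinkutil

import (
	"context"
	"time"

	"goa.design/pulse/streaming"
)

// CloseAndWait closes the sink and polls IsClosed until it reports true or the
// timeout elapses. It returns the final IsClosed value.
func CloseAndWait(ctx context.Context, sink *streaming.Sink, timeout time.Duration) bool {
	sink.Close(ctx)
	deadline := time.Now().Add(timeout)
	for time.Now().Before(deadline) {
		if sink.IsClosed() {
			return true
		}
		time.Sleep(10 * time.Millisecond)
	}
	return sink.IsClosed()
}
```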
diff --git a/testing/redis.go b/testing/redis.go
index 78cfa82..cfcca8f 100644
--- a/testing/redis.go
+++ b/testing/redis.go
@@ -3,6 +3,7 @@ package testing
 import (
 	"context"
 	"os"
+	"regexp"
 	"strings"
 	"testing"
 	"time"
@@ -40,8 +41,12 @@ func CleanupRedis(t *testing.T, rdb *redis.Client, checkClean bool, testName str
 	require.NoError(t, err)
 	var filtered []string
 	for _, k := range keys {
-		// Sinks content gets reused, so ignore it
 		if strings.HasSuffix(k, ":sinks:content") {
+			// Sinks content is cleaned up asynchronously, so ignore it
+			continue
+		}
+		if regexp.MustCompile(`^pulse:stream:[^:]+:node:.*`).MatchString(k) {
+			// Node streams are cleaned up asynchronously, so ignore them
 			continue
 		}
 		if strings.Contains(k, testName) {