Skip to content

Commit

Permalink
graceful exit for worker goroutines
Browse files Browse the repository at this point in the history
  • Loading branch information
JyotinderSingh committed Dec 23, 2023
1 parent aff922d commit 2523d5d
Showing 1 changed file with 39 additions and 11 deletions.
50 changes: 39 additions & 11 deletions pkg/worker/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"log"
"net"
"os"
"sync"
"time"

"github.com/JyotinderSingh/task-queue/pkg/common"
Expand All @@ -28,50 +29,60 @@ type WorkerServer struct {
coordinatorAddress string
listener net.Listener
grpcServer *grpc.Server
coordinatorConnection *grpc.ClientConn
coordinatorServiceClient pb.CoordinatorServiceClient
heartbeatInterval time.Duration
taskQueue chan *pb.TaskRequest
ctx context.Context // The root context for all goroutines
cancel context.CancelFunc // Function to cancel the context
wg sync.WaitGroup // WaitGroup to wait for all goroutines to finish
}

// NewServer constructs a WorkerServer for the given port and coordinator
// address. The server gets a freshly generated ID, the default heartbeat
// interval, a task queue buffered to 100 entries, and a cancellable root
// context (with its cancel function) derived from context.Background().
func NewServer(port string, coordinator string) *WorkerServer {
	rootCtx, stop := context.WithCancel(context.Background())

	srv := new(WorkerServer)
	srv.id = uuid.New().ID()
	srv.serverPort = port
	srv.coordinatorAddress = coordinator
	srv.heartbeatInterval = common.DefaultHeartbeat
	srv.taskQueue = make(chan *pb.TaskRequest, 100) // Buffered channel
	srv.ctx = rootCtx
	srv.cancel = stop

	return srv
}

// Start brings the WorkerServer up: it connects to the coordinator, launches
// the worker pool and the heartbeat goroutine, and then hands control to
// startGRPCServer, whose result it returns.
//
// The coordinator connection is established before any goroutine is spawned,
// so a failed dial returns an error without leaving pool workers running in
// the background. Resources opened here are released through
// closeGRPCConnection when Start returns.
func (w *WorkerServer) Start() error {
	if err := w.connectToCoordinator(); err != nil {
		return fmt.Errorf("failed to connect to coordinator: %w", err)
	}
	defer w.closeGRPCConnection()

	// Goroutines take their cancellation signal from the server's own root
	// context (w.ctx), so no context is threaded through these calls.
	w.startWorkerPool(workerPoolSize)

	// NOTE(review): periodicHeartbeat registers itself with w.wg from inside
	// the goroutine; a Stop that runs before that registration could return
	// from wg.Wait without waiting for the heartbeat loop — confirm and
	// consider moving the Add in front of the go statement.
	go w.periodicHeartbeat()

	return w.startGRPCServer()
}

// connectToCoordinator dials the coordinator at w.coordinatorAddress using
// insecure transport credentials and blocks until the connection is up.
// On success it stores the connection in w.coordinatorConnection (so it can
// be closed later) and builds w.coordinatorServiceClient on top of it.
//
// The dial is bound to the server's root context, so cancelling that context
// aborts a dial that would otherwise block indefinitely on an unreachable
// coordinator. It returns the dial error unchanged on failure.
func (w *WorkerServer) connectToCoordinator() error {
	log.Println("Connecting to coordinator...")

	// Fall back to a background context for a WorkerServer that was not
	// built through NewServer and therefore has no root context.
	ctx := w.ctx
	if ctx == nil {
		ctx = context.Background()
	}

	conn, err := grpc.DialContext(
		ctx,
		w.coordinatorAddress,
		grpc.WithTransportCredentials(insecure.NewCredentials()),
		grpc.WithBlock(),
	)
	if err != nil {
		return err
	}

	w.coordinatorConnection = conn
	w.coordinatorServiceClient = pb.NewCoordinatorServiceClient(conn)
	log.Println("Connected to coordinator!")
	return nil
}

func (w *WorkerServer) periodicHeartbeat(ctx context.Context) {
func (w *WorkerServer) periodicHeartbeat() {
w.wg.Add(1) // Add this goroutine to the waitgroup.
defer w.wg.Done() // Signal this goroutine is done when the function returns

ticker := time.NewTicker(w.heartbeatInterval)
defer ticker.Stop()

Expand All @@ -82,7 +93,7 @@ func (w *WorkerServer) periodicHeartbeat(ctx context.Context) {
log.Printf("Failed to send heartbeat: %v", err)
return
}
case <-ctx.Done():
case <-w.ctx.Done():
return
}
}
Expand Down Expand Up @@ -128,6 +139,11 @@ func (w *WorkerServer) startGRPCServer() error {

// Stop gracefully shuts down the WorkerServer.
func (w *WorkerServer) Stop() error {
// Signal all goroutines to stop
w.cancel()
// Wait for all goroutines to finish
w.wg.Wait()

w.closeGRPCConnection()
log.Println("Worker server stopped")
return nil
Expand All @@ -137,6 +153,16 @@ func (w *WorkerServer) closeGRPCConnection() {
if w.grpcServer != nil {
w.grpcServer.GracefulStop()
}

if w.listener != nil {
if err := w.listener.Close(); err != nil {
log.Printf("Error while closing the listener: %v", err)
}
}

if err := w.coordinatorConnection.Close(); err != nil {
log.Printf("Error while closing client connection with coordinator: %v", err)
}
}

// SubmitTask handles the submission of a task to the worker server.
Expand All @@ -153,14 +179,16 @@ func (w *WorkerServer) SubmitTask(ctx context.Context, req *pb.TaskRequest) (*pb
}

// startWorkerPool starts a pool of worker goroutines.
func (w *WorkerServer) startWorkerPool(ctx context.Context, numWorkers int) {
func (w *WorkerServer) startWorkerPool(numWorkers int) {
for i := 0; i < numWorkers; i++ {
go w.worker(ctx)
w.wg.Add(1)
go w.worker()
}
}

// worker is the function run by each worker goroutine.
func (w *WorkerServer) worker(ctx context.Context) {
func (w *WorkerServer) worker() {
defer w.wg.Done() // Signal this worker is done when the function returns.
for {
select {
case task := <-w.taskQueue:
Expand All @@ -170,7 +198,7 @@ func (w *WorkerServer) worker(ctx context.Context) {
Status: pb.TaskStatus_PROCESSING,
})
w.processTask(task)
case <-ctx.Done():
case <-w.ctx.Done():
return
}
}
Expand Down

0 comments on commit 2523d5d

Please sign in to comment.