diff --git a/consumer_with_idle_trigger.go b/consumer_with_idle_trigger.go index 5c6e54a..0219a79 100644 --- a/consumer_with_idle_trigger.go +++ b/consumer_with_idle_trigger.go @@ -13,20 +13,22 @@ import ( ) type ConsumerWithIdleTrigger struct { - sqs *sqs.Client - handler HandlerWithIdleTrigger - wg *sync.WaitGroup - cfg Config - idleDurationTimeout time.Duration + sqs *sqs.Client + handler HandlerWithIdleTrigger + wg *sync.WaitGroup + cfg Config + idleDurationTimeout time.Duration + sqsReceiveWaitTimeSeconds int32 } -func NewConsumerWithIdleTrigger(awsCfg aws.Config, cfg Config, handler HandlerWithIdleTrigger, idleDurationTimeout time.Duration) *ConsumerWithIdleTrigger { +func NewConsumerWithIdleTrigger(awsCfg aws.Config, cfg Config, handler HandlerWithIdleTrigger, idleDurationTimeout time.Duration, sqsReceiveWaitTimeSeconds int32) *ConsumerWithIdleTrigger { return &ConsumerWithIdleTrigger{ - sqs: sqs.NewFromConfig(awsCfg), - handler: handler, - wg: &sync.WaitGroup{}, - cfg: cfg, - idleDurationTimeout: idleDurationTimeout, + sqs: sqs.NewFromConfig(awsCfg), + handler: handler, + wg: &sync.WaitGroup{}, + cfg: cfg, + idleDurationTimeout: idleDurationTimeout, + sqsReceiveWaitTimeSeconds: sqsReceiveWaitTimeSeconds, } } @@ -57,7 +59,7 @@ loop: output, err := c.sqs.ReceiveMessage(ctx, &sqs.ReceiveMessageInput{ QueueUrl: &c.cfg.QueueURL, MaxNumberOfMessages: c.cfg.BatchSize, - WaitTimeSeconds: int32(c.idleDurationTimeout.Seconds()), + WaitTimeSeconds: c.sqsReceiveWaitTimeSeconds, MessageAttributeNames: []string{"TraceID", "SpanID"}, }) if err != nil { diff --git a/consumer_with_idle_trigger_test.go b/consumer_with_idle_trigger_test.go index 4594645..a7a78b3 100644 --- a/consumer_with_idle_trigger_test.go +++ b/consumer_with_idle_trigger_test.go @@ -28,7 +28,8 @@ type MsgHandlerWithIdleTrigger struct { } const ( - IdleTimeout = 500 * time.Millisecond + IdleTimeout = 500 * time.Millisecond + SqsReceiveWaitTimeSeconds = int32(1) ) func TestConsumeWithIdleTrigger(t *testing.T) { @@ -58,7 +59,7 @@ func TestConsumeWithIdleTrigger(t *testing.T) { BatchSize: batchSize, ExtendEnabled: true, } - consumer := NewConsumerWithIdleTrigger(awsCfg, config, msgHandler, IdleTimeout) + consumer := NewConsumerWithIdleTrigger(awsCfg, config, msgHandler, IdleTimeout, SqsReceiveWaitTimeSeconds) go consumer.Consume(ctx) t.Cleanup(func() { @@ -104,7 +105,7 @@ func TestConsumeWithIdleTimeout_GracefulShutdown(t *testing.T) { t: t, msgsReceivedCount: 0, } - consumer := NewConsumerWithIdleTrigger(awsCfg, config, &msgHandler, IdleTimeout) + consumer := NewConsumerWithIdleTrigger(awsCfg, config, &msgHandler, IdleTimeout, SqsReceiveWaitTimeSeconds) var wg sync.WaitGroup wg.Add(2) @@ -155,7 +156,7 @@ func TestConsumeWithIdleTimeout_TimesOut(t *testing.T) { t: t, msgsReceivedCount: 0, } - consumer := NewConsumerWithIdleTrigger(awsCfg, config, &msgHandler, IdleTimeout) + consumer := NewConsumerWithIdleTrigger(awsCfg, config, &msgHandler, IdleTimeout, SqsReceiveWaitTimeSeconds) go consumer.Consume(ctx) t.Cleanup(func() { @@ -163,7 +164,7 @@ func TestConsumeWithIdleTimeout_TimesOut(t *testing.T) { }) // Wait for the timeout - time.Sleep(time.Second * 2) + time.Sleep(time.Second * 3) // ensure that it gets called multiple times assert.GreaterOrEqual(t, msgHandler.idleTimeoutTriggeredCount, 2) @@ -195,7 +196,7 @@ func TestConsumeWithIdleTimeout_TimesOutAndConsumes(t *testing.T) { ExtendEnabled: true, } msgHandler := handlerWithIdleTrigger(t, expectedMsg, expectedMsgAttributes) - consumer := NewConsumerWithIdleTrigger(awsCfg, config, msgHandler, IdleTimeout) + consumer := NewConsumerWithIdleTrigger(awsCfg, config, msgHandler, IdleTimeout, SqsReceiveWaitTimeSeconds) go consumer.Consume(ctx) t.Cleanup(func() { @@ -206,7 +207,7 @@ func TestConsumeWithIdleTimeout_TimesOutAndConsumes(t *testing.T) { } cancel() }) - time.Sleep(time.Second * 1) + time.Sleep(time.Second * 2) // ensure that it gets called first before receiving a message assert.GreaterOrEqual(t, msgHandler.idleTimeoutTriggeredCount, 1) @@ -218,7 +219,7 @@ func TestConsumeWithIdleTimeout_TimesOutAndConsumes(t *testing.T) { time.Sleep(time.Second * 2) // Check that the message arrived assert.Equal(t, 1, msgHandler.msgsReceivedCount) - assert.GreaterOrEqual(t, msgHandler.idleTimeoutTriggeredCount, 3) + assert.GreaterOrEqual(t, msgHandler.idleTimeoutTriggeredCount, 2) }