diff --git a/client/src/main/java/org/apache/rocketmq/client/consumer/DefaultLitePullConsumer.java b/client/src/main/java/org/apache/rocketmq/client/consumer/DefaultLitePullConsumer.java index 5e5bd4daaaa..9c9d1116b2f 100644 --- a/client/src/main/java/org/apache/rocketmq/client/consumer/DefaultLitePullConsumer.java +++ b/client/src/main/java/org/apache/rocketmq/client/consumer/DefaultLitePullConsumer.java @@ -103,7 +103,7 @@ public class DefaultLitePullConsumer extends ClientConfig implements LitePullCon /** * Pull thread number */ - private int pullThreadNums = 20; + private int pullThreadNums = 1; /** * Minimum commit offset interval time in milliseconds. diff --git a/client/src/main/java/org/apache/rocketmq/client/impl/consumer/DefaultLitePullConsumerImpl.java b/client/src/main/java/org/apache/rocketmq/client/impl/consumer/DefaultLitePullConsumerImpl.java index 20ca4770086..6cac6a60505 100644 --- a/client/src/main/java/org/apache/rocketmq/client/impl/consumer/DefaultLitePullConsumerImpl.java +++ b/client/src/main/java/org/apache/rocketmq/client/impl/consumer/DefaultLitePullConsumerImpl.java @@ -39,6 +39,7 @@ import org.apache.rocketmq.client.consumer.DefaultLitePullConsumer; import org.apache.rocketmq.client.consumer.MessageQueueListener; import org.apache.rocketmq.client.consumer.MessageSelector; +import org.apache.rocketmq.client.consumer.PullCallback; import org.apache.rocketmq.client.consumer.PullResult; import org.apache.rocketmq.client.consumer.TopicMessageQueueChangeListener; import org.apache.rocketmq.client.consumer.listener.ConsumeConcurrentlyStatus; @@ -857,7 +858,7 @@ public long searchOffset(MessageQueue mq, long timestamp) throws MQClientExcepti return this.mQClientFactory.getMQAdminImpl().searchOffset(mq, timestamp); } - public class PullTaskImpl implements Runnable { + private class PullTaskImpl implements Runnable { private final MessageQueue messageQueue; private volatile boolean cancelled = false; private Thread currentThread; @@ -963,27 +964,40 @@ public void run() { subscriptionData = FilterAPI.buildSubscriptionData(topic, subExpression4Assign); } - PullResult pullResult = pull(messageQueue, subscriptionData, offset, defaultLitePullConsumer.getPullBatchSize()); - if (this.isCancelled() || processQueue.isDropped()) { - return; - } - switch (pullResult.getPullStatus()) { - case FOUND: - final Object objLock = messageQueueLock.fetchLockObject(messageQueue); - synchronized (objLock) { - if (pullResult.getMsgFoundList() != null && !pullResult.getMsgFoundList().isEmpty() && assignedMessageQueue.getSeekOffset(messageQueue) == -1) { - processQueue.putMessage(pullResult.getMsgFoundList()); - submitConsumeRequest(new ConsumeRequest(pullResult.getMsgFoundList(), messageQueue, processQueue)); - } + pullAsync(messageQueue, subscriptionData, offset, defaultLitePullConsumer.getPullBatchSize(), new PullCallback() { + @Override + public void onSuccess(PullResult pullResult) { + DefaultLitePullConsumerImpl.this.pullAPIWrapper.processPullResult(messageQueue, pullResult, subscriptionData); + if (PullTaskImpl.this.isCancelled() || processQueue.isDropped()) { + return; } - break; - case OFFSET_ILLEGAL: - log.warn("The pull request offset illegal, {}", pullResult.toString()); - break; - default: - break; - } - updatePullOffset(messageQueue, pullResult.getNextBeginOffset(), processQueue); + switch (pullResult.getPullStatus()) { + case FOUND: + final Object objLock = messageQueueLock.fetchLockObject(messageQueue); + synchronized (objLock) { + if (pullResult.getMsgFoundList() != null && !pullResult.getMsgFoundList().isEmpty() && assignedMessageQueue.getSeekOffset(messageQueue) == -1) { + processQueue.putMessage(pullResult.getMsgFoundList()); + submitConsumeRequest(new ConsumeRequest(pullResult.getMsgFoundList(), messageQueue, processQueue)); + } + } + break; + case OFFSET_ILLEGAL: + log.warn("The pull request offset illegal, {}", pullResult.toString()); + break; + default: + break; + } + updatePullOffset(messageQueue, pullResult.getNextBeginOffset(), processQueue); + DefaultLitePullConsumerImpl.this.scheduledThreadPoolExecutor.schedule(PullTaskImpl.this, 0L, TimeUnit.MILLISECONDS); + } + + @Override + public void onException(Throwable e) { + log.warn("execute the pull request exception", e); + DefaultLitePullConsumerImpl.this.scheduledThreadPoolExecutor.schedule(PullTaskImpl.this, pullTimeDelayMillsWhenException, TimeUnit.MILLISECONDS); + } + }); + } catch (InterruptedException interruptedException) { log.warn("Polling thread was interrupted.", interruptedException); } catch (Throwable e) { @@ -1016,19 +1030,20 @@ public MessageQueue getMessageQueue() { } } - private PullResult pull(MessageQueue mq, SubscriptionData subscriptionData, long offset, int maxNums) + private PullResult pullAsync(MessageQueue mq, SubscriptionData subscriptionData, long offset, int maxNums, + PullCallback pullCallback) throws MQClientException, RemotingException, MQBrokerException, InterruptedException { - return pull(mq, subscriptionData, offset, maxNums, this.defaultLitePullConsumer.getConsumerPullTimeoutMillis()); + return pull(mq, subscriptionData, offset, maxNums, this.defaultLitePullConsumer.getConsumerPullTimeoutMillis(), CommunicationMode.ASYNC, pullCallback); } - private PullResult pull(MessageQueue mq, SubscriptionData subscriptionData, long offset, int maxNums, long timeout) + private PullResult pull(MessageQueue mq, SubscriptionData subscriptionData, long offset, int maxNums, long timeout, + CommunicationMode communicationMode, PullCallback pullCallback) throws MQClientException, RemotingException, MQBrokerException, InterruptedException { - return this.pullSyncImpl(mq, subscriptionData, offset, maxNums, true, timeout); + return this.pullImpl(mq, subscriptionData, offset, maxNums, true, timeout, communicationMode, pullCallback); } - private PullResult pullSyncImpl(MessageQueue mq, SubscriptionData subscriptionData, long offset, int maxNums, - boolean block, - long timeout) + private PullResult pullImpl(MessageQueue mq, SubscriptionData subscriptionData, long offset, int maxNums, + boolean block, long timeout, CommunicationMode communicationMode, PullCallback pullCallback) throws MQClientException, RemotingException, MQBrokerException, InterruptedException { if (null == mq) { @@ -1043,6 +1058,10 @@ private PullResult pullSyncImpl(MessageQueue mq, SubscriptionData subscriptionDa throw new MQClientException("maxNums <= 0", null); } + if (CommunicationMode.ASYNC == communicationMode && pullCallback == null) { + throw new MQClientException("Async communication mode but callback is null", null); + } + int sysFlag = PullSysFlag.buildSysFlag(false, block, true, false, true); long timeoutMillis = block ? this.defaultLitePullConsumer.getConsumerTimeoutMillisWhenSuspend() : timeout; @@ -1059,8 +1078,8 @@ private PullResult pullSyncImpl(MessageQueue mq, SubscriptionData subscriptionDa 0, this.defaultLitePullConsumer.getBrokerSuspendMaxTimeMillis(), timeoutMillis, - CommunicationMode.SYNC, - null + communicationMode, + pullCallback ); this.pullAPIWrapper.processPullResult(mq, pullResult, subscriptionData); return pullResult; diff --git a/client/src/test/java/org/apache/rocketmq/client/consumer/DefaultLitePullConsumerTest.java b/client/src/test/java/org/apache/rocketmq/client/consumer/DefaultLitePullConsumerTest.java index 24e39f56689..2b1f0eee608 100644 --- a/client/src/test/java/org/apache/rocketmq/client/consumer/DefaultLitePullConsumerTest.java +++ b/client/src/test/java/org/apache/rocketmq/client/consumer/DefaultLitePullConsumerTest.java @@ -73,7 +73,6 @@ import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.anyString; -import static org.mockito.ArgumentMatchers.nullable; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.when; @@ -725,7 +724,7 @@ private void initDefaultLitePullConsumer(DefaultLitePullConsumer litePullConsume field.set(litePullConsumerImpl, offsetStore); when(mQClientFactory.getMQClientAPIImpl().pullMessage(anyString(), any(PullMessageRequestHeader.class), - anyLong(), any(CommunicationMode.class), nullable(PullCallback.class))) + anyLong(), any(CommunicationMode.class), any(PullCallback.class))) .thenAnswer(new Answer() { @Override public PullResult answer(InvocationOnMock mock) throws Throwable { @@ -739,6 +738,7 @@ public PullResult answer(InvocationOnMock mock) throws Throwable { messageClientExt.setBornHost(new InetSocketAddress(8080)); messageClientExt.setStoreHost(new InetSocketAddress(8080)); PullResult pullResult = createPullResult(requestHeader, PullStatus.FOUND, Collections.singletonList(messageClientExt)); + ((PullCallback) mock.getArgument(4)).onSuccess(pullResult); return pullResult; } }); @@ -785,7 +785,7 @@ private void initDefaultLitePullConsumerWithTag(DefaultLitePullConsumer litePull field.set(litePullConsumerImpl, offsetStore); when(mQClientFactory.getMQClientAPIImpl().pullMessage(anyString(), any(PullMessageRequestHeader.class), - anyLong(), any(CommunicationMode.class), nullable(PullCallback.class))) + anyLong(), any(CommunicationMode.class), any(PullCallback.class))) .thenAnswer(new Answer() { @Override public PullResult answer(InvocationOnMock mock) throws Throwable { @@ -800,6 +800,7 @@ public PullResult answer(InvocationOnMock mock) throws Throwable { messageClientExt.setBornHost(new InetSocketAddress(8080)); messageClientExt.setStoreHost(new InetSocketAddress(8080)); PullResult pullResult = createPullResult(requestHeader, PullStatus.FOUND, Collections.singletonList(messageClientExt)); + ((PullCallback) mock.getArgument(4)).onSuccess(pullResult); return pullResult; } }); diff --git a/client/src/test/java/org/apache/rocketmq/client/trace/DefaultMQLitePullConsumerWithTraceTest.java b/client/src/test/java/org/apache/rocketmq/client/trace/DefaultMQLitePullConsumerWithTraceTest.java index e0573bdfb0b..2871af14028 100644 --- a/client/src/test/java/org/apache/rocketmq/client/trace/DefaultMQLitePullConsumerWithTraceTest.java +++ b/client/src/test/java/org/apache/rocketmq/client/trace/DefaultMQLitePullConsumerWithTraceTest.java @@ -76,7 +76,6 @@ import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.anyString; -import static org.mockito.ArgumentMatchers.nullable; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.when; @@ -226,7 +225,7 @@ private void initDefaultLitePullConsumer(DefaultLitePullConsumer litePullConsume traceProducer.getDefaultMQProducerImpl().getMqClientFactory().registerProducer(producerGroupTraceTemp, traceProducer.getDefaultMQProducerImpl()); when(mQClientFactory.getMQClientAPIImpl().pullMessage(anyString(), any(PullMessageRequestHeader.class), - anyLong(), any(CommunicationMode.class), nullable(PullCallback.class))) + anyLong(), any(CommunicationMode.class), any(PullCallback.class))) .thenAnswer(new Answer() { @Override public Object answer(InvocationOnMock mock) throws Throwable { @@ -240,6 +239,7 @@ public Object answer(InvocationOnMock mock) throws Throwable { messageClientExt.setBornHost(new InetSocketAddress(8080)); messageClientExt.setStoreHost(new InetSocketAddress(8080)); PullResult pullResult = createPullResult(requestHeader, PullStatus.FOUND, Collections.singletonList(messageClientExt)); + ((PullCallback) mock.getArgument(4)).onSuccess(pullResult); return pullResult; } });