Skip to content

Commit

Permalink
Fix heartbeating implementation on stomp
Browse files Browse the repository at this point in the history
  • Loading branch information
trickl committed Jan 19, 2020
1 parent c116063 commit 90cd891
Show file tree
Hide file tree
Showing 17 changed files with 226 additions and 42 deletions.
18 changes: 18 additions & 0 deletions src/main/java/com/trickl/exceptions/MissingHeartbeatException.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package com.trickl.exceptions;

public class MissingHeartbeatException extends Exception {

private static final long serialVersionUID = -1761231643713163261L;

public MissingHeartbeatException(String message) {
super(message);
}

public MissingHeartbeatException(String message, Throwable cause) {
super(message, cause);
}

public MissingHeartbeatException(Throwable cause) {
super(cause);
}
}
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
package com.trickl.exceptions;

import lombok.NoArgsConstructor;

@NoArgsConstructor
public class NoSuchStreamException extends Exception {

private static final long serialVersionUID = -1761231643713163261L;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
package com.trickl.exceptions;

import lombok.NoArgsConstructor;

@NoArgsConstructor
public class NoSupportingDelegateException extends Exception {

private static final long serialVersionUID = -1761231643713163261L;
Expand Down
18 changes: 18 additions & 0 deletions src/main/java/com/trickl/exceptions/RemoteStreamException.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package com.trickl.exceptions;

public class RemoteStreamException extends Exception {

private static final long serialVersionUID = -1761231643713163261L;

public RemoteStreamException(String message) {
super(message);
}

public RemoteStreamException(String message, Throwable cause) {
super(message, cause);
}

public RemoteStreamException(Throwable cause) {
super(cause);
}
}
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
package com.trickl.exceptions;

import lombok.NoArgsConstructor;

@NoArgsConstructor
public class SubscriptionFailedException extends Exception {

private static final long serialVersionUID = -1761231643713163261L;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import java.time.Duration;
import java.util.concurrent.TimeoutException;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.function.Supplier;
import java.util.logging.Level;
Expand All @@ -22,7 +23,7 @@ public class ConditionalTimeoutPublisher<T> implements Supplier<Mono<T>> {
private final Publisher<T> source;
private final Duration timeout;
private final Predicate<? super T> condition;
private final Supplier<Throwable> onTimeoutThrow;
private final Function<TimeoutException, Throwable> onTimeoutThrow;
private final Runnable onTimeoutDo;
private final Scheduler scheduler;

Expand All @@ -49,7 +50,7 @@ public Mono<T> get() {
.onErrorMap(error -> {
if (error instanceof TimeoutException
&& onTimeoutThrow != null) {
return onTimeoutThrow.get();
return onTimeoutThrow.apply((TimeoutException) error);
}
return error;
})
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package com.trickl.flux.websocket;

import com.trickl.exceptions.NoSuchStreamException;
import com.trickl.exceptions.SubscriptionFailedException;
import com.trickl.flux.consumers.SimpMessageSender;
import com.trickl.model.streams.StreamDetails;
Expand Down Expand Up @@ -29,8 +28,6 @@
import org.springframework.messaging.Message;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.SimpMessagingTemplate;
import org.springframework.messaging.simp.user.SimpSubscription;
import org.springframework.messaging.simp.user.SimpUserRegistry;
import org.springframework.messaging.support.MessageHeaderAccessor;
import org.springframework.util.Assert;
import org.springframework.web.socket.messaging.AbstractSubProtocolEvent;
Expand All @@ -46,8 +43,6 @@ public class WebSocketRequestRouter<T> implements SmartApplicationListener {

private final Function<StreamId, Optional<Flux<T>>> fluxFactory;

private final SimpUserRegistry simpUserRegistry;

private final SimpMessagingTemplate messagingTemplate;

private final Map<StreamId, Optional<Flux<T>>> fluxes = new ConcurrentHashMap<>();
Expand Down Expand Up @@ -182,6 +177,7 @@ public void unsubscribeAll() {
* @param subscriptionId The subscription to unsubscribe
*/
public void unsubscribe(String subscriptionId) {
log.info("Unsubscribing " + subscriptionId);
subscriptions.computeIfPresent(subscriptionId,
(id, subscription) -> {
subscription.cancel();
Expand Down Expand Up @@ -220,13 +216,11 @@ public void onApplicationEvent(ApplicationEvent event) {
log.log(Level.WARNING, "Subscription failed", ex);
}
} else if (event instanceof SessionDisconnectEvent) {
// Unsubscribe any hanging subscriptions
// TODO: Not very efficient
String sessionId = accessor.getSessionId();
simpUserRegistry
.findSubscriptions(
subscription -> subscription.getSession().getId().equals(sessionId))
.stream()
.map(SimpSubscription::getId)
subscriptionDetails.values().stream()
.filter(sub -> sub.getSessionId().equals(sessionId))
.map(SubscriptionDetails::getId)
.forEach(this::unsubscribe);

} else if (event instanceof SessionUnsubscribeEvent) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import com.trickl.flux.websocket.stomp.frames.StompDisconnectFrame;

import java.net.URI;
import java.time.Duration;

import lombok.RequiredArgsConstructor;

Expand All @@ -23,6 +24,8 @@ public class RawStompFluxClient {
private final WebSocketClient webSocketClient;
private final URI transportUri;
private final Mono<HttpHeaders> webSocketHeadersProvider;
private final Duration heartbeatSendFrequency;
private final Duration heartbeatReceiveFrequency;

/**
* Connect to a stomp service.
Expand Down Expand Up @@ -54,6 +57,8 @@ public Flux<StompFrame> get(Publisher<StompFrame> send) {
protected void onConnect(FluxSink<StompFrame> frameSink) {
StompConnectFrame connectFrame = StompConnectFrame.builder()
.acceptVersion("1.0,1.1,1.2")
.heartbeatSendFrequency(heartbeatSendFrequency)
.heartbeatReceiveFrequency(heartbeatReceiveFrequency)
.host(transportUri.getHost())
.build();
frameSink.next(connectFrame);
Expand Down
92 changes: 86 additions & 6 deletions src/main/java/com/trickl/flux/websocket/stomp/StompFluxClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,14 @@

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.trickl.exceptions.MissingHeartbeatException;
import com.trickl.exceptions.RemoteStreamException;
import com.trickl.flux.mappers.ThrowableMapper;
import com.trickl.flux.publishers.ConditionalTimeoutPublisher;
import com.trickl.flux.websocket.stomp.StompFrame;
import com.trickl.flux.websocket.stomp.frames.StompConnectedFrame;
import com.trickl.flux.websocket.stomp.frames.StompErrorFrame;
import com.trickl.flux.websocket.stomp.frames.StompHeartbeatFrame;
import com.trickl.flux.websocket.stomp.frames.StompMessageFrame;
import com.trickl.flux.websocket.stomp.frames.StompSendFrame;
import com.trickl.flux.websocket.stomp.frames.StompSubscribeFrame;
Expand Down Expand Up @@ -34,6 +39,7 @@
import reactor.core.publisher.Flux;
import reactor.core.publisher.FluxSink;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;

@Log
@RequiredArgsConstructor
Expand All @@ -42,13 +48,25 @@ public class StompFluxClient {
private final URI transportUri;
private final Mono<HttpHeaders> webSocketHeadersProvider;
private final ObjectMapper objectMapper;
private final Duration heartbeatSendFrequency;
private final Duration heartbeatReceiveFrequency;

private final EmitterProcessor<StompFrame> responseProcessor = EmitterProcessor.create();

private final EmitterProcessor<StompFrame> streamRequestProcessor = EmitterProcessor.create();

private final FluxSink<StompFrame> streamRequestSink = streamRequestProcessor.sink();

private final EmitterProcessor<Duration> heartbeatSendProcessor = EmitterProcessor.create();

private final FluxSink<Duration> heartbeatSendSink = heartbeatSendProcessor.sink();

private final EmitterProcessor<Duration> heartbeatExpectationProcessor
= EmitterProcessor.create();

private final FluxSink<Duration> heartbeatExpectationSink
= heartbeatExpectationProcessor.sink();

private final AtomicInteger maxSubscriptionNumber = new AtomicInteger(0);

private final Map<String, String> subscriptionDestinationIdMap = new HashMap<>();
Expand Down Expand Up @@ -77,23 +95,34 @@ public void connect() {
try {
RawStompFluxClient stompFluxClient =
new RawStompFluxClient(
webSocketClient, transportUri, webSocketHeadersProvider);
webSocketClient,
transportUri,
webSocketHeadersProvider,
heartbeatSendFrequency,
heartbeatReceiveFrequency);

Flux<StompFrame> heartbeats = heartbeatSendProcessor.switchMap(this::createHeartbeats);
Flux<StompMessageFrame> heartbeatExpectation = heartbeatExpectationProcessor
.switchMap(this::listenHeartbeats);

Publisher<StompFrame> sendWithResponse =
Flux.merge(streamRequestProcessor, responseProcessor);
Flux.merge(streamRequestProcessor, heartbeats, responseProcessor);

Flux<StompMessageFrame> stream = stompFluxClient.get(sendWithResponse)
.doOnNext(frame -> {
log.info("Got frame " + frame.getClass());
if (StompConnectedFrame.class.equals(frame.getClass())) {
handleConnectStream();
handleConnectStream((StompConnectedFrame) frame);
}
})
.flatMap(new ThrowableMapper<StompFrame, StompFrame>(this::handleErrorFrame))
.mergeWith(heartbeatExpectation)
.filter(frame -> frame.getHeaderAccessor().getCommand().equals(StompCommand.MESSAGE))
.cast(StompMessageFrame.class)
.onErrorContinue(JsonProcessingException.class, this::warnAndDropError)
.onErrorContinue(JsonProcessingException.class, this::warnAndDropError)
.doOnError(this::sendErrorFrame)
.doAfterTerminate(this::handleTerminateStream)
.retryBackoff(maxRetriesOnError, retryOnErrorFirstBackoff)
.retryBackoff(maxRetriesOnError, retryOnErrorFirstBackoff)
.publish()
.refCount();

Expand All @@ -109,11 +138,62 @@ protected void handleTerminateStream() {
}


protected void handleConnectStream() {
protected void handleConnectStream(StompConnectedFrame frame) {
expectHeartbeats(frame.getHeartbeatSendFrequency());
sendHeartbeats(frame.getHeartbeatReceiveFrequency());
isConnected.set(true);
resubscribeAll();
}

protected StompFrame handleErrorFrame(StompFrame frame) throws RemoteStreamException {
if (StompErrorFrame.class.equals(frame.getClass())) {
throw new RemoteStreamException(((StompErrorFrame) frame).getMessage());
}
return frame;
}

protected void sendErrorFrame(Throwable error) {
StompFrame frame = StompErrorFrame.builder()
.message(error.getLocalizedMessage())
.build();
streamRequestSink.next(frame);
}

protected Publisher<StompHeartbeatFrame> createHeartbeats(Duration frequency) {
if (frequency.isZero()) {
return Flux.empty();
}

return Flux.interval(frequency)
.log("HEARTBEAT")
.map(count -> new StompHeartbeatFrame())
.startWith(new StompHeartbeatFrame());
}

protected Publisher<StompMessageFrame> listenHeartbeats(Duration frequency) {
if (frequency.isZero() || sharedStream.get() == null) {
return Flux.empty();
}

return new ConditionalTimeoutPublisher<>(
sharedStream.get(),
frequency,
value -> true,
error -> new MissingHeartbeatException("No heartbeat within " + frequency, error),
null,
Schedulers.parallel()).get();
}

protected void sendHeartbeats(Duration frequency) {
log.info("Sending heartbeats every " + frequency.toString());
heartbeatSendSink.next(frequency);
}

protected void expectHeartbeats(Duration frequency) {
log.info("Expecting heartbeats every " + frequency.toString());
heartbeatExpectationSink.next(frequency);
}

protected void warnAndDropError(Throwable ex, Object value) {
log.log(
Level.WARNING,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import com.trickl.flux.websocket.stomp.StompFrame;
import com.trickl.flux.websocket.stomp.frames.StompConnectedFrame;
import com.trickl.flux.websocket.stomp.frames.StompErrorFrame;
import com.trickl.flux.websocket.stomp.frames.StompHeartbeatFrame;
import com.trickl.flux.websocket.stomp.frames.StompMessageFrame;
import com.trickl.flux.websocket.stomp.frames.StompReceiptFrame;

Expand All @@ -11,6 +12,7 @@
import lombok.RequiredArgsConstructor;

import org.springframework.messaging.Message;
import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.messaging.simp.stomp.StompConversionException;
import org.springframework.messaging.simp.stomp.StompHeaderAccessor;

Expand All @@ -24,18 +26,20 @@ public class StompFrameBuilder implements Function<Message<byte[]>, StompFrame>
* @throws StompConversionException If
*/
@Override
public StompFrame apply(Message<byte[]> message) {
public StompFrame apply(Message<byte[]> message) {
StompHeaderAccessor headerAccessor = StompHeaderAccessor.wrap(message);
if (headerAccessor.getMessageType().equals(SimpMessageType.HEARTBEAT)) {
return new StompHeartbeatFrame();
}
switch (headerAccessor.getCommand()) {
case MESSAGE:
return StompMessageFrame.create(
headerAccessor, message.getPayload());
return StompMessageFrame.from(headerAccessor, message.getPayload());
case CONNECTED:
return StompConnectedFrame.create(headerAccessor);
return StompConnectedFrame.from(headerAccessor);
case RECEIPT:
return StompReceiptFrame.create(headerAccessor);
return StompReceiptFrame.from(headerAccessor);
case ERROR:
return StompErrorFrame.create(headerAccessor);
return StompErrorFrame.from(headerAccessor);
default:
throw new StompConversionException("Unable to decode STOMP message"
+ headerAccessor.toMap());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import com.trickl.flux.websocket.stomp.StompFrame;

import java.time.Duration;

import lombok.Builder;
import lombok.Data;

Expand All @@ -16,13 +18,22 @@ public class StompConnectFrame implements StompFrame {
protected String acceptVersion;
protected String host;

@Builder.Default
protected Duration heartbeatSendFrequency = Duration.ZERO;

@Builder.Default
protected Duration heartbeatReceiveFrequency = Duration.ZERO;

/**
* Get the stomp headers for this message.
*/
public StompHeaderAccessor getHeaderAccessor() {
StompHeaderAccessor stompHeaderAccessor = StompHeaderAccessor.create(StompCommand.CONNECT);
stompHeaderAccessor.setAcceptVersion(acceptVersion);
stompHeaderAccessor.setHost(host);
stompHeaderAccessor.setHeartbeat(
heartbeatSendFrequency.toMillis(),
heartbeatReceiveFrequency.toMillis());
return stompHeaderAccessor;
}

Expand Down
Loading

0 comments on commit 90cd891

Please sign in to comment.