Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert headers from being read-only upon message creation #1123

Merged
merged 3 commits into from
Apr 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 48 additions & 14 deletions src/main/java/io/nats/client/impl/Headers.java
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,13 @@ public Headers(Headers headers, boolean readOnly, String[] keysNotToCopy) {
* -or- if any value contains invalid characters
*/
public Headers add(String key, String... values) {
if (values != null) {
_add(key, Arrays.asList(values));
if (readOnly) {
throw new UnsupportedOperationException();
}
return this;
if (values == null || values.length == 0) {
return this;
}
return _add(key, Arrays.asList(values));
}

/**
Expand All @@ -109,12 +112,17 @@ public Headers add(String key, String... values) {
* -or- if any value contains invalid characters
*/
public Headers add(String key, Collection<String> values) {
_add(key, values);
return this;
if (readOnly) {
throw new UnsupportedOperationException();
}
if (values == null || values.isEmpty()) {
return this;
}
return _add(key, values);
}

// the add delegate
private void _add(String key, Collection<String> values) {
private Headers _add(String key, Collection<String> values) {
if (values != null) {
Checker checked = new Checker(key, values);
if (checked.hasValues()) {
Expand All @@ -129,6 +137,7 @@ private void _add(String key, Collection<String> values) {
serialized = null; // since the data changed, clear this so it's rebuilt
}
}
return this;
}

/**
Expand All @@ -143,10 +152,13 @@ private void _add(String key, Collection<String> values) {
* -or- if any value contains invalid characters
*/
public Headers put(String key, String... values) {
if (values != null) {
_put(key, Arrays.asList(values));
if (readOnly) {
throw new UnsupportedOperationException();
}
return this;
if (values == null || values.length == 0) {
return this;
}
return _put(key, Arrays.asList(values));
}

/**
Expand All @@ -161,8 +173,13 @@ public Headers put(String key, String... values) {
* -or- if any value contains invalid characters
*/
public Headers put(String key, Collection<String> values) {
_put(key, values);
return this;
if (readOnly) {
throw new UnsupportedOperationException();
}
if (values == null || values.isEmpty()) {
return this;
}
return _put(key, values);
}

/**
Expand All @@ -173,14 +190,20 @@ public Headers put(String key, Collection<String> values) {
* @return the Headers object
*/
public Headers put(Map<String, List<String>> map) {
if (readOnly) {
throw new UnsupportedOperationException();
}
if (map == null || map.isEmpty()) {
return this;
}
for (String key : map.keySet() ) {
put(key, map.get(key));
_put(key, map.get(key));
}
return this;
}

// the put delegate that all puts call
private void _put(String key, Collection<String> values) {
// the put delegate
private Headers _put(String key, Collection<String> values) {
if (key == null || key.isEmpty()) {
throw new IllegalArgumentException("Key cannot be null or empty.");
}
Expand All @@ -195,6 +218,7 @@ private void _put(String key, Collection<String> values) {
serialized = null; // since the data changed, clear this so it's rebuilt
}
}
return this;
}

/**
Expand All @@ -203,6 +227,9 @@ private void _put(String key, Collection<String> values) {
* @param keys the key or keys to remove
*/
public void remove(String... keys) {
if (readOnly) {
throw new UnsupportedOperationException();
}
for (String key : keys) {
_remove(key);
}
Expand All @@ -215,12 +242,16 @@ public void remove(String... keys) {
* @param keys the key or keys to remove
*/
public void remove(Collection<String> keys) {
if (readOnly) {
throw new UnsupportedOperationException();
}
for (String key : keys) {
_remove(key);
}
serialized = null; // since the data changed, clear this so it's rebuilt
}

// the remove delegate
private void _remove(String key) {
// if the values had a key, then the data length had a length
if (valuesMap.remove(key) != null) {
Expand Down Expand Up @@ -250,6 +281,9 @@ public boolean isEmpty() {
* Removes all the keys The object map will be empty after this call returns.
*/
public void clear() {
if (readOnly) {
throw new UnsupportedOperationException();
}
valuesMap.clear();
lengthMap.clear();
dataLength = 0;
Expand Down
16 changes: 15 additions & 1 deletion src/main/java/io/nats/client/impl/IncomingMessage.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,27 @@

package io.nats.client.impl;

import io.nats.client.support.ByteArrayBuilder;

public class IncomingMessage extends NatsMessage {
IncomingMessage() {}
IncomingMessage() {
super((byte[])null);
}

IncomingMessage(byte[] data) {
super(data);
}

@Override
protected void calculate() {
// intentionally does nothing
}

@Override
ByteArrayBuilder getProtocolBab() {
throw new IllegalStateException("getProtocolBab not supported for this type of message.");
}

@Override
byte[] getProtocolBytes() {
throw new IllegalStateException("getProtocolBytes not supported for this type of message.");
Expand Down
55 changes: 25 additions & 30 deletions src/main/java/io/nats/client/impl/NatsConnection.java
Original file line number Diff line number Diff line change
Expand Up @@ -836,31 +836,31 @@ void cleanUpPongQueue() {
*/
@Override
public void publish(String subject, byte[] body) {
publishInternal(subject, null, null, body);
publishInternal(subject, null, null, body, true);
}

/**
* {@inheritDoc}
*/
@Override
public void publish(String subject, Headers headers, byte[] body) {
publishInternal(subject, null, headers, body);
publishInternal(subject, null, headers, body, true);
}

/**
* {@inheritDoc}
*/
@Override
public void publish(String subject, String replyTo, byte[] body) {
publishInternal(subject, replyTo, null, body);
publishInternal(subject, replyTo, null, body, true);
}

/**
* {@inheritDoc}
*/
@Override
public void publish(String subject, String replyTo, Headers headers, byte[] body) {
publishInternal(subject, replyTo, headers, body);
publishInternal(subject, replyTo, headers, body, true);
}

/**
Expand All @@ -869,35 +869,30 @@ public void publish(String subject, String replyTo, Headers headers, byte[] body
@Override
public void publish(Message message) {
validateNotNull(message, "Message");
publishInternal(message.getSubject(), message.getReplyTo(), message.getHeaders(), message.getData());
publishInternal(message.getSubject(), message.getReplyTo(), message.getHeaders(), message.getData(), false);
}

void publishInternal(String subject, String replyTo, Headers headers, byte[] data) {
checkIfNeedsHeaderSupport(headers);
void publishInternal(String subject, String replyTo, Headers headers, byte[] data, boolean validateSubRep) {
checkPayloadSize(data);
NatsPublishableMessage npm = new NatsPublishableMessage(subject, replyTo, headers, data, validateSubRep);
if (npm.hasHeaders && !serverInfo.get().isHeadersSupported()) {
throw new IllegalArgumentException("Headers are not supported by the server, version: " + serverInfo.get().getVersion());
}

if (isClosed()) {
throw new IllegalStateException("Connection is Closed");
} else if (blockPublishForDrain.get()) {
throw new IllegalStateException("Connection is Draining"); // Ok to publish while waiting on subs
}

NatsMessage nm = new NatsMessage(subject, replyTo, new Headers(headers), data);

Connection.Status stat = this.status;
if ((stat == Status.RECONNECTING || stat == Status.DISCONNECTED)
&& !this.writer.canQueueDuringReconnect(nm)) {
&& !this.writer.canQueueDuringReconnect(npm)) {
throw new IllegalStateException(
"Unable to queue any more messages during reconnect, max buffer is " + options.getReconnectBufferSize());
}
queueOutgoing(nm);
}

private void checkIfNeedsHeaderSupport(Headers headers) {
if (headers != null && !headers.isEmpty() && !serverInfo.get().isHeadersSupported()) {
throw new IllegalArgumentException(
"Headers are not supported by the server, version: " + serverInfo.get().getVersion());
}
queueOutgoing(npm);
}

private void checkPayloadSize(byte[] body) {
Expand Down Expand Up @@ -1099,15 +1094,15 @@ else if (future.isDone()) {
*/
@Override
public Message request(String subject, byte[] body, Duration timeout) throws InterruptedException {
return requestInternal(subject, null, body, timeout, cancelAction);
return requestInternal(subject, null, body, timeout, cancelAction, true);
}

/**
* {@inheritDoc}
*/
@Override
public Message request(String subject, Headers headers, byte[] body, Duration timeout) throws InterruptedException {
return requestInternal(subject, headers, body, timeout, cancelAction);
return requestInternal(subject, headers, body, timeout, cancelAction, true);
}

/**
Expand All @@ -1116,11 +1111,11 @@ public Message request(String subject, Headers headers, byte[] body, Duration ti
@Override
public Message request(Message message, Duration timeout) throws InterruptedException {
validateNotNull(message, "Message");
return requestInternal(message.getSubject(), message.getHeaders(), message.getData(), timeout, cancelAction);
return requestInternal(message.getSubject(), message.getHeaders(), message.getData(), timeout, cancelAction, false);
}

Message requestInternal(String subject, Headers headers, byte[] data, Duration timeout, CancelAction cancelAction) throws InterruptedException {
CompletableFuture<Message> incoming = requestFutureInternal(subject, headers, data, timeout, cancelAction);
Message requestInternal(String subject, Headers headers, byte[] data, Duration timeout, CancelAction cancelAction, boolean validateSubRep) throws InterruptedException {
CompletableFuture<Message> incoming = requestFutureInternal(subject, headers, data, timeout, cancelAction, validateSubRep);
try {
return incoming.get(timeout.toNanos(), TimeUnit.NANOSECONDS);
} catch (TimeoutException | ExecutionException | CancellationException e) {
Expand All @@ -1133,31 +1128,31 @@ Message requestInternal(String subject, Headers headers, byte[] data, Duration t
*/
@Override
public CompletableFuture<Message> request(String subject, byte[] body) {
return requestFutureInternal(subject, null, body, null, cancelAction);
return requestFutureInternal(subject, null, body, null, cancelAction, true);
}

/**
* {@inheritDoc}
*/
@Override
public CompletableFuture<Message> request(String subject, Headers headers, byte[] body) {
return requestFutureInternal(subject, headers, body, null, cancelAction);
return requestFutureInternal(subject, headers, body, null, cancelAction, true);
}

/**
* {@inheritDoc}
*/
@Override
public CompletableFuture<Message> requestWithTimeout(String subject, byte[] body, Duration timeout) {
return requestFutureInternal(subject, null, body, timeout, cancelAction);
return requestFutureInternal(subject, null, body, timeout, cancelAction, true);
}

/**
* {@inheritDoc}
*/
@Override
public CompletableFuture<Message> requestWithTimeout(String subject, Headers headers, byte[] body, Duration timeout) {
return requestFutureInternal(subject, headers, body, timeout, cancelAction);
return requestFutureInternal(subject, headers, body, timeout, cancelAction, true);
}

/**
Expand All @@ -1166,7 +1161,7 @@ public CompletableFuture<Message> requestWithTimeout(String subject, Headers hea
@Override
public CompletableFuture<Message> requestWithTimeout(Message message, Duration timeout) {
validateNotNull(message, "Message");
return requestFutureInternal(message.getSubject(), message.getHeaders(), message.getData(), timeout, cancelAction);
return requestFutureInternal(message.getSubject(), message.getHeaders(), message.getData(), timeout, cancelAction, false);
}

/**
Expand All @@ -1175,10 +1170,10 @@ public CompletableFuture<Message> requestWithTimeout(Message message, Duration t
@Override
public CompletableFuture<Message> request(Message message) {
validateNotNull(message, "Message");
return requestFutureInternal(message.getSubject(), message.getHeaders(), message.getData(), null, cancelAction);
return requestFutureInternal(message.getSubject(), message.getHeaders(), message.getData(), null, cancelAction, false);
}

CompletableFuture<Message> requestFutureInternal(String subject, Headers headers, byte[] data, Duration futureTimeout, CancelAction cancelAction) {
CompletableFuture<Message> requestFutureInternal(String subject, Headers headers, byte[] data, Duration futureTimeout, CancelAction cancelAction, boolean validateSubRep) {
checkPayloadSize(data);

if (isClosed()) {
Expand Down Expand Up @@ -1230,7 +1225,7 @@ CompletableFuture<Message> requestFutureInternal(String subject, Headers headers
responsesAwaiting.put(sub.getSID(), future);
}

publishInternal(subject, responseInbox, headers, data);
publishInternal(subject, responseInbox, headers, data, validateSubRep);
writer.flushBuffer();
statistics.incrementRequestsSent();

Expand Down
11 changes: 6 additions & 5 deletions src/main/java/io/nats/client/impl/NatsConnectionWriter.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import io.nats.client.Options;
import io.nats.client.StatisticsCollector;
import io.nats.client.support.ByteArrayBuilder;

import java.io.IOException;
import java.nio.BufferOverflowException;
Expand Down Expand Up @@ -125,8 +126,7 @@ Future<Boolean> stop() {
// Clear old ping/pong requests
this.outgoing.filter((msg) ->
msg.isProtocol() &&
(msg.protocolBab.equals(OP_PING_BYTES) || msg.protocolBab.equals(OP_PONG_BYTES)));

(msg.getProtocolBab().equals(OP_PING_BYTES) || msg.getProtocolBab().equals(OP_PONG_BYTES)));
}
finally {
this.startStopLock.unlock();
Expand Down Expand Up @@ -162,9 +162,10 @@ void sendMessageBatch(NatsMessage msg, DataPort dataPort, StatisticsCollector st
}
}

int blen = msg.protocolBab.length();
System.arraycopy(msg.protocolBab.internalArray(), 0, sendBuffer, sendPosition, blen);
sendPosition += blen;
ByteArrayBuilder bab = msg.getProtocolBab();
int babLen = bab.length();
System.arraycopy(bab.internalArray(), 0, sendBuffer, sendPosition, babLen);
sendPosition += babLen;

sendBuffer[sendPosition++] = CR;
sendBuffer[sendPosition++] = LF;
Expand Down
Loading
Loading