Skip to content

Commit

Permalink
Changed internal message representation to an object that includes me…
Browse files Browse the repository at this point in the history
…ssage as well as a pointer to the parent message.
  • Loading branch information
nandita727 committed Feb 12, 2024
1 parent df06d48 commit c61b40d
Show file tree
Hide file tree
Showing 12 changed files with 88 additions and 61 deletions.
2 changes: 1 addition & 1 deletion src/main/java/com/meta/cp4m/Service.java
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ private void execute(ThreadState<T> thread) {
LOGGER.error("failed to communicate with LLM", e);
return;
}
store.add(llmResponse);
store.add(thread,llmResponse);
try {
handler.respond(llmResponse);
} catch (Exception e) {
Expand Down
10 changes: 2 additions & 8 deletions src/main/java/com/meta/cp4m/message/FBMessage.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,5 @@ public record FBMessage(
Identifier senderId,
Identifier recipientId,
String message,
Role role,
Message parentMessage)
implements Message {
@Override
public @NewObject FBMessage withParentMessage(Message parentMessage) {
return new FBMessage(timestamp(),instanceId(),senderId(),recipientId(),message(), role(), parentMessage);
}
}
Role role)
implements Message {}
6 changes: 1 addition & 5 deletions src/main/java/com/meta/cp4m/message/Message.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@ private static Identifier threadId(Identifier id1, Identifier id2) {
return Identifier.from(id2.toString() + '|' + id1);
}

public <T extends Message> @NewObject T withParentMessage(Message parentMessage);

Instant timestamp();

Identifier instanceId();
Expand All @@ -36,8 +34,6 @@ private static Identifier threadId(Identifier id1, Identifier id2) {

Role role();

@Nullable Message parentMessage();

default Identifier threadId() {
return threadId(senderId(), recipientId());
}
Expand All @@ -49,7 +45,7 @@ enum Role {

private final int priority;

private Role(Integer priority){
Role(Integer priority){
this.priority = priority;
}

Expand Down
4 changes: 2 additions & 2 deletions src/main/java/com/meta/cp4m/message/MessageFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ public interface MessageFactory<T extends Message> {
Map<Class<? extends Message>, MessageFactory<? extends Message>> FACTORY_MAP =
Stream.<FactoryContainer<?>>of(
new FactoryContainer<>(
FBMessage.class, (t, m, si, ri, ii, r) -> new FBMessage(t, ii, si, ri, m, r,null)),
FBMessage.class, (t, m, si, ri, ii, r) -> new FBMessage(t, ii, si, ri, m, r)),
new FactoryContainer<>(
WAMessage.class, (t, m, si, ri, ii, r) -> new WAMessage(t, ii, si, ri, m, r,null)))
WAMessage.class, (t, m, si, ri, ii, r) -> new WAMessage(t, ii, si, ri, m, r)))
.collect(
Collectors.toUnmodifiableMap(FactoryContainer::clazz, FactoryContainer::factory));

Expand Down
35 changes: 35 additions & 0 deletions src/main/java/com/meta/cp4m/message/MessageNode.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

package com.meta.cp4m.message;

public class MessageNode <T extends Message>{
T message;
T parentMessage;

public MessageNode(T message){
this.message = message;
this.parentMessage = null;
}
public MessageNode(T message, T parentMessage){
this.message = message;
this.parentMessage = parentMessage;
}
public T getMessage() {
return message;
}
public T getParentMessage() {
return parentMessage;
}
public void setMessage(T message) {
this.message = message;
}
public void setParentMessage(T parentMessage) {
this.parentMessage = parentMessage;
}
}
49 changes: 25 additions & 24 deletions src/main/java/com/meta/cp4m/message/ThreadState.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,43 +18,41 @@
import java.util.stream.Stream;

public class ThreadState<T extends Message> {
private final List<T> messages;
private final List<MessageNode<T>> messageNodes;
private final MessageFactory<T> messageFactory;

private ThreadState(T message) {
Objects.requireNonNull(message);
Preconditions.checkArgument(
message.role() != Role.SYSTEM, "ThreadState should never hold a system message");
this.messages = ImmutableList.of(message);
MessageNode<T> messageNode = new MessageNode<>(message,null);
this.messageNodes = ImmutableList.of(messageNode);
messageFactory = MessageFactory.instance(message);
}

/** Constructor that exists to support the with method */
private ThreadState(ThreadState<T> old, T newMessage) {
private ThreadState(ThreadState<T> current, ThreadState<T> old, T newMessage) {
Objects.requireNonNull(newMessage);
Preconditions.checkArgument(
newMessage.role() != Role.SYSTEM, "ThreadState should never hold a system message");
messageFactory = old.messageFactory;
messageFactory = current.messageFactory;
Preconditions.checkArgument(
old.tail().threadId().equals(newMessage.threadId()),
"all messages in a thread must have the same thread id");
List<T> messages = old.messages;
T mWithParentMessage = newMessage.role() == Role.USER ? newMessage.withParentMessage(old.tail()): newMessage;
this.messages =
Stream.concat(messages.stream(), Stream.of(mWithParentMessage))
.sorted((m1,m2) -> m1.parentMessage() == m2.parentMessage() ? compare(m1.role().getPriority(),m2.role().getPriority()) : (m1.timestamp().compareTo(m2.timestamp())))
.collect(Collectors.toUnmodifiableList());
Identifier oldUserId = old.userId();
Identifier userId = userId();
Identifier oldBotId = old.botId();
Identifier botId = botId();
old.tail().threadId().equals(newMessage.threadId()),
"all messages in a thread must have the same thread id");
List<MessageNode<T>> messageNodes = current.messageNodes;
MessageNode<T> mWithParentMessage = new MessageNode<>(newMessage,old.tail());
this.messageNodes =
Stream.concat(messageNodes.stream(), Stream.of(mWithParentMessage))
.sorted((m1,m2) -> m1.getParentMessage() == m2.getParentMessage() ? compare(m1.getMessage().role().getPriority(),m2.getMessage().role().getPriority()) : (m1.getMessage().timestamp().compareTo(m2.getMessage().timestamp())))
.collect(Collectors.toUnmodifiableList());

Preconditions.checkArgument(
old.userId().equals(userId()) && old.botId().equals(botId()),
"userId and botId not consistent with this thread state");
old.userId().equals(userId()) && old.botId().equals(botId()),
"userId and botId not consistent with this thread state");
}

private int compare(int priority1, int priority2){
return priority1 > priority2 ? +1 : priority1 < priority2 ? -1 : 0;
return Integer.compare(priority1, priority2);
}

public static <T extends Message> ThreadState<T> of(T message) {
Expand Down Expand Up @@ -82,25 +80,28 @@ public Identifier botId() {
}

public T newMessageFromBot(Instant timestamp, String message) {
T newMessage = messageFactory.newMessage(
return messageFactory.newMessage(
timestamp, message, botId(), userId(), Identifier.random(), Role.ASSISTANT);
return newMessage.withParentMessage(tail());
}

public T newMessageFromUser(Instant timestamp, String message, Identifier instanceId) {
return messageFactory.newMessage(timestamp, message, userId(), botId(), instanceId, Role.USER);
}

public ThreadState<T> with(T message) {
return new ThreadState<>(this, message);
return new ThreadState<>(this,this, message);
}

public ThreadState<T> with(ThreadState<T> thread,T message) {
return new ThreadState<>(this,thread, message);
}

public List<T> messages() {
return messages;
return messageNodes.stream().map(MessageNode::getMessage).collect(Collectors.toList());
}

public T tail() {
return messages.get(messages.size() - 1);
return messageNodes.get(messageNodes.size() - 1).getMessage();
}

}
10 changes: 2 additions & 8 deletions src/main/java/com/meta/cp4m/message/WAMessage.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,5 @@ public record WAMessage(
Identifier senderId,
Identifier recipientId,
String message,
Role role,
Message parentMessage)
implements Message {
@Override
public @NewObject WAMessage withParentMessage(Message parentMessage) {
return new WAMessage(timestamp(),instanceId(),senderId(),recipientId(),message(), role(), parentMessage);
}
}
Role role)
implements Message {}
2 changes: 2 additions & 0 deletions src/main/java/com/meta/cp4m/store/ChatStore.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ public interface ChatStore<T extends Message> {

ThreadState<T> add(T message);

ThreadState<T> add(ThreadState<T> thread,T message);

long size();

List<ThreadState<T>> list();
Expand Down
5 changes: 5 additions & 0 deletions src/main/java/com/meta/cp4m/store/MemoryStore.java
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ public ThreadState<T> add(T message) {
});
}

@Override
public ThreadState<T> add(ThreadState<T> thread, T message){
return this.store.asMap().compute(message.threadId(), (k,v) -> {return v.with(thread,message);});
}

@Override
public long size() {
return store.size();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ void chunkingHappens() throws IOException {
Stream.generate(() -> "0123456789.").limit(300).collect(Collectors.joining(" "));
FBMessage bigMessage =
new FBMessage(
Instant.now(), Identifier.random(), pageId, Identifier.random(), bigText, Role.USER,null);
Instant.now(), Identifier.random(), pageId, Identifier.random(), bigText, Role.USER);
messageHandler.respond(bigMessage);
assertThat(requests.size()).isEqualTo(300);
assertThat(requests).allSatisfy(m -> assertThat(m.body()).contains("0123456789"));
Expand Down
8 changes: 4 additions & 4 deletions src/test/java/com/meta/cp4m/message/MessageTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ void threadId() {
Identifier id0 = Identifier.from("0");
Identifier id1 = Identifier.from("1");
Identifier id2 = Identifier.from("2");
Message message = new FBMessage(timestamp, id0, id1, id2, "", Message.Role.ASSISTANT, null);
Message response = new FBMessage(timestamp, id0, id2, id1, "", Message.Role.ASSISTANT, null);
Message message = new FBMessage(timestamp, id0, id1, id2, "", Message.Role.ASSISTANT);
Message response = new FBMessage(timestamp, id0, id2, id1, "", Message.Role.ASSISTANT);
assertThat(message.threadId()).isEqualTo(response.threadId());

message =
Expand All @@ -33,15 +33,15 @@ void threadId() {
Identifier.from("12"),
Identifier.from("34"),
"",
Message.Role.ASSISTANT, null);
Message.Role.ASSISTANT);
response =
new FBMessage(
timestamp,
id0,
Identifier.from("1"),
Identifier.from("234"),
"",
Message.Role.ASSISTANT, null);
Message.Role.ASSISTANT);
assertThat(message.threadId()).isNotEqualTo(response.threadId());
}
}
16 changes: 8 additions & 8 deletions src/test/java/com/meta/cp4m/store/ThreadStateTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -98,12 +98,12 @@ void botAndUserId() {
Message.Role.ASSISTANT);

final ThreadState<FBMessage> finalMs = ms;
assertThatCode(() -> finalMs.with(message2)).doesNotThrowAnyException();
assertThatCode(() -> finalMs.with(finalMs.newMessageFromBot(start, "")))
assertThatCode(() -> finalMs.with(finalMs,message2)).doesNotThrowAnyException();
assertThatCode(() -> finalMs.with(finalMs,finalMs.newMessageFromBot(start, "")))
.doesNotThrowAnyException();
assertThatCode(() -> finalMs.with(finalMs.newMessageFromUser(start, "", Identifier.random())))
.doesNotThrowAnyException();
ms = ms.with(message2);
ms = ms.with(ms,message2);
assertThat(ms.userId()).isEqualTo(message1.senderId());
assertThat(ms.botId()).isEqualTo(message1.recipientId());
FBMessage mDifferentSenderId =
Expand Down Expand Up @@ -149,7 +149,7 @@ void botAndUserId() {
message1.recipientId(),
Identifier.random(),
Message.Role.ASSISTANT);
assertThatThrownBy(() -> finalMs1.with(illegalRecipientId.withParentMessage(message1)))
assertThatThrownBy(() -> finalMs1.with(finalMs1,illegalRecipientId))
.isInstanceOf(IllegalArgumentException.class);
}

Expand Down Expand Up @@ -214,9 +214,9 @@ void orderPreservationWhenUserSendsTwoMessagesInARowFBMessage() {
ThreadState<FBMessage> finalMs = ms;
ms = ms.with(userMessage2);
FBMessage botMessage1 = finalMs.newMessageFromBot(start.plusSeconds(4), "bot sample message 1");
ms = ms.with(botMessage1);
ms = ms.with(finalMs,botMessage1);
FBMessage botMessage2 = ms.newMessageFromBot(start.plusSeconds(8), "bot sample message 2");
ms = ms.with(botMessage2);
ms = ms.with(ms, botMessage2);
assertThat(ms.messages()).hasSize(4);
assertThat(ms.messages().get(0).instanceId()).isSameAs(userMessage1.instanceId());
assertThat(ms.messages().get(1).instanceId()).isSameAs(botMessage1.instanceId());
Expand Down Expand Up @@ -249,9 +249,9 @@ void orderPreservationWhenUserSendsTwoMessagesInARowWAMessage() {
ThreadState<WAMessage> finalMs = ms;
ms = ms.with(userMessage2);
WAMessage botMessage1 = finalMs.newMessageFromBot(start.plusSeconds(4), "bot sample message 1");
ms = ms.with(botMessage1);
ms = ms.with(finalMs,botMessage1);
WAMessage botMessage2 = ms.newMessageFromBot(start.plusSeconds(8), "bot sample message 2");
ms = ms.with(botMessage2);
ms = ms.with(ms,botMessage2);
assertThat(ms.messages()).hasSize(4);
assertThat(ms.messages().get(0).instanceId()).isSameAs(userMessage1.instanceId());
assertThat(ms.messages().get(1).instanceId()).isSameAs(botMessage1.instanceId());
Expand Down

0 comments on commit c61b40d

Please sign in to comment.