Skip to content

Commit

Permalink
Addressed more code review comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
nandita727 committed Jan 25, 2024
1 parent e3cac15 commit eeefdbe
Show file tree
Hide file tree
Showing 14 changed files with 69 additions and 64 deletions.
4 changes: 2 additions & 2 deletions src/main/java/com/meta/cp4m/llm/HuggingFaceLlamaPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ public T handle(ThreadState<T> threadState) throws IOException {
Optional<String> prompt = promptCreator.createPrompt(threadState);
if (prompt.isEmpty()) {
return threadState.newMessageFromBot(
Instant.now(), "I'm sorry but that request was too long for me.", threadState.tail());
Instant.now(), "I'm sorry but that request was too long for me.");
}

body.put("inputs", prompt.get());
Expand All @@ -72,6 +72,6 @@ public T handle(ThreadState<T> threadState) throws IOException {
String llmResponse = allGeneratedText.strip().replace(prompt.get().strip(), "");
Instant timestamp = Instant.now();

return threadState.newMessageFromBot(timestamp, llmResponse, threadState.tail());
return threadState.newMessageFromBot(timestamp, llmResponse);
}
}
4 changes: 2 additions & 2 deletions src/main/java/com/meta/cp4m/llm/OpenAIPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ public T handle(ThreadState<T> threadState) throws IOException {
Optional<ArrayNode> prunedMessages = pruneMessages(messages, null);
if (prunedMessages.isEmpty()) {
return threadState.newMessageFromBot(
Instant.now(), "I'm sorry but that request was too long for me.",fromUser);
Instant.now(), "I'm sorry but that request was too long for me.");
}
body.set("messages", prunedMessages.get());

Expand All @@ -182,6 +182,6 @@ public T handle(ThreadState<T> threadState) throws IOException {
Instant timestamp = Instant.ofEpochSecond(responseBody.get("created").longValue());
JsonNode choice = responseBody.get("choices").get(0);
String messageContent = choice.get("message").get("content").textValue();
return threadState.newMessageFromBot(timestamp, messageContent,fromUser);
return threadState.newMessageFromBot(timestamp, messageContent);
}
}
5 changes: 2 additions & 3 deletions src/main/java/com/meta/cp4m/message/FBMessage.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,8 @@ public record FBMessage(
Role role,
Message parentMessage)
implements Message {
private static final MessageFactory<FBMessage> MESSAGE_FACTORY = MessageFactory.instance(FBMessage.class);
@Override
public Message addParentMessage(Message parentMessage) {
return MESSAGE_FACTORY.newMessage(timestamp(),message(),senderId(),recipientId(),instanceId(), Role.USER, parentMessage);
public Message withParentMessage(Message parentMessage) {
return new FBMessage(timestamp(),instanceId(),senderId(),recipientId(),message(), role(), parentMessage);
}
}
2 changes: 1 addition & 1 deletion src/main/java/com/meta/cp4m/message/FBMessageHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ private List<FBMessage> postHandler(Context ctx, JsonNode body) {

@Nullable JsonNode textObject = messageObject.get("text");
if (textObject != null && textObject.isTextual()) {
FBMessage m = MESSAGE_FACTORY.newMessage(timestamp, textObject.textValue(), senderId, recipientId,messageId, Message.Role.USER,null);
FBMessage m = MESSAGE_FACTORY.newMessage(timestamp, textObject.textValue(), senderId, recipientId,messageId, Message.Role.USER);
output.add(m);
} else {
LOGGER
Expand Down
18 changes: 10 additions & 8 deletions src/main/java/com/meta/cp4m/message/Message.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,19 @@
package com.meta.cp4m.message;

import com.meta.cp4m.Identifier;
import org.checkerframework.checker.nullness.qual.Nullable;

import java.time.Instant;

public interface Message {
public static Identifier threadId(Identifier id1, Identifier id2) {
private static Identifier threadId(Identifier id1, Identifier id2) {
if (id1.compareTo(id2) <= 0) {
return Identifier.from(id1.toString() + '|' + id2);
}
return Identifier.from(id2.toString() + '|' + id1);
}

static void parentMessage(Message parentMessage){

}

public Message addParentMessage(Message parentMessage);
public <T extends Message> T withParentMessage(Message parentMessage);

Instant timestamp();

Expand All @@ -37,7 +35,7 @@ static void parentMessage(Message parentMessage){

Role role();

Message parentMessage();
@Nullable Message parentMessage();

default Identifier threadId() {
return threadId(senderId(), recipientId());
Expand All @@ -48,11 +46,15 @@ enum Role {
USER(1),
SYSTEM(2);

public final Integer priority;
private final int priority;

private Role(Integer priority){
this.priority = priority;
}

public int getPriority(){
return this.priority;
}

}
}
7 changes: 3 additions & 4 deletions src/main/java/com/meta/cp4m/message/MessageFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ public interface MessageFactory<T extends Message> {
Map<Class<? extends Message>, MessageFactory<? extends Message>> FACTORY_MAP =
Stream.<FactoryContainer<?>>of(
new FactoryContainer<>(
FBMessage.class, (t, m, si, ri, ii, r, pm) -> new FBMessage(t, ii, si, ri, m, r, pm)),
FBMessage.class, (t, m, si, ri, ii, r) -> new FBMessage(t, ii, si, ri, m, r,null)),
new FactoryContainer<>(
WAMessage.class, (t, m, si, ri, ii, r, pm) -> new WAMessage(t, ii, si, ri, m, r, pm)))
WAMessage.class, (t, m, si, ri, ii, r) -> new WAMessage(t, ii, si, ri, m, r,null)))
.collect(
Collectors.toUnmodifiableMap(FactoryContainer::clazz, FactoryContainer::factory));

Expand All @@ -46,8 +46,7 @@ T newMessage(
Identifier senderId,
Identifier recipientId,
Identifier instanceId,
Role role,
Message parentMessage);
Role role);

/** this exists to provide compiler guarantees for type matching in the FACTORY_MAP */
record FactoryContainer<T extends Message>(Class<T> clazz, MessageFactory<T> factory) {}
Expand Down
22 changes: 15 additions & 7 deletions src/main/java/com/meta/cp4m/message/ThreadState.java
Original file line number Diff line number Diff line change
Expand Up @@ -39,17 +39,24 @@ private ThreadState(ThreadState<T> old, T newMessage) {
old.tail().threadId().equals(newMessage.threadId()),
"all messages in a thread must have the same thread id");
List<T> messages = old.messages;
T mWithParentMessage = newMessage.role() == Role.USER ? (T) newMessage.addParentMessage(old.tail()): newMessage;
T mWithParentMessage = newMessage.role() == Role.USER ? newMessage.withParentMessage(old.tail()): newMessage;
this.messages =
Stream.concat(messages.stream(), Stream.of(mWithParentMessage))
.sorted((m1,m2) -> m1.parentMessage() == m2.parentMessage() ? (m1.role().priority.compareTo(m2.role().priority)) : (m1.timestamp().compareTo(m2.timestamp())))
.sorted((m1,m2) -> m1.parentMessage() == m2.parentMessage() ? compare(m1.role().getPriority(),m2.role().getPriority()) : (m1.timestamp().compareTo(m2.timestamp())))
.collect(Collectors.toUnmodifiableList());

Identifier oldUserId = old.userId();
Identifier userId = userId();
Identifier oldBotId = old.botId();
Identifier botId = botId();
Preconditions.checkArgument(
old.userId().equals(userId()) && old.botId().equals(botId()),
"userId and botId not consistent with this thread state");
}

private int compare(int priority1, int priority2){
return priority1 > priority2 ? +1 : priority1 < priority2 ? -1 : 0;
}

public static <T extends Message> ThreadState<T> of(T message) {
return new ThreadState<>(message);
}
Expand All @@ -74,13 +81,14 @@ public Identifier botId() {
};
}

public T newMessageFromBot(Instant timestamp, String message, T parentMessage) {
return messageFactory.newMessage(
timestamp, message, botId(), userId(), Identifier.random(), Role.ASSISTANT, parentMessage);
public T newMessageFromBot(Instant timestamp, String message) {
T newMessage = messageFactory.newMessage(
timestamp, message, botId(), userId(), Identifier.random(), Role.ASSISTANT);
return newMessage.withParentMessage(tail());
}

public T newMessageFromUser(Instant timestamp, String message, Identifier instanceId) {
return messageFactory.newMessage(timestamp, message, userId(), botId(), instanceId, Role.USER, this.tail());
return messageFactory.newMessage(timestamp, message, userId(), botId(), instanceId, Role.USER);
}

public ThreadState<T> with(T message) {
Expand Down
5 changes: 2 additions & 3 deletions src/main/java/com/meta/cp4m/message/WAMessage.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,8 @@ public record WAMessage(
Role role,
Message parentMessage)
implements Message {
private static final MessageFactory<WAMessage> MESSAGE_FACTORY = MessageFactory.instance(WAMessage.class);
@Override
public Message addParentMessage(Message parentMessage) {
return MESSAGE_FACTORY.newMessage(timestamp(),message(),senderId(),recipientId(),instanceId(), Role.USER, parentMessage);
public Message withParentMessage(Message parentMessage) {
return new WAMessage(timestamp(),instanceId(),senderId(),recipientId(),message(), Role.USER, parentMessage);
}
}
2 changes: 1 addition & 1 deletion src/main/java/com/meta/cp4m/message/WAMessageHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ private List<WAMessage> post(Context ctx, WebhookPayload payload) {
continue;
}
TextWebhookMessage textMessage = (TextWebhookMessage) message;
WAMessage waMessage = MESSAGE_FACTORY.newMessage(message.timestamp(), textMessage.text().body(), message.from(), phoneNumberId,message.id(), Message.Role.USER,null);
WAMessage waMessage = MESSAGE_FACTORY.newMessage(message.timestamp(), textMessage.text().body(), message.from(), phoneNumberId,message.id(), Message.Role.USER);
readExecutor.execute(() -> markRead(phoneNumberId, textMessage.id().toString()));
waMessages.add(waMessage);
}
Expand Down
2 changes: 1 addition & 1 deletion src/test/java/com/meta/cp4m/llm/DummyLLMPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,6 @@ public String dummyResponse() {
@Override
public T handle(ThreadState<T> threadState) {
receivedThreadStates.add(threadState);
return threadState.newMessageFromBot(Instant.now(), dummyLLMResponse, threadState.tail());
return threadState.newMessageFromBot(Instant.now(), dummyLLMResponse);
}
}
11 changes: 5 additions & 6 deletions src/test/java/com/meta/cp4m/llm/HuggingFaceLlamaPluginTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,7 @@ public class HuggingFaceLlamaPluginTest {
Identifier.random(),
Identifier.random(),
Identifier.random(),
Role.USER,
null));
Role.USER));

static {
SAMPLE_RESPONSE.addObject().put("generated_text", TEST_MESSAGE);
Expand Down Expand Up @@ -146,7 +145,7 @@ void createPayloadWithSystemMessage() {
Identifier.random(),
Identifier.random(),
Identifier.random(),
Role.USER, null));
Role.USER));
HuggingFaceLlamaPrompt<FBMessage> promptBuilder =
new HuggingFaceLlamaPrompt<>(config.systemMessage(), config.maxInputTokens());
Optional<String> createdPayload = promptBuilder.createPrompt(stack);
Expand Down Expand Up @@ -174,7 +173,7 @@ void contextTooBig() throws IOException {
Identifier.random(),
Identifier.random(),
Identifier.random(),
Role.USER, null));
Role.USER));
FBMessage response = plugin.handle(thread);
assertThat(response.message()).isEqualTo("I'm sorry but that request was too long for me.");
}
Expand All @@ -198,7 +197,7 @@ void truncatesContext() throws IOException {
Identifier.random(),
Identifier.random(),
Identifier.random(),
Role.USER,null));
Role.USER));
thread =
thread.with(thread.newMessageFromUser(Instant.now(), "test message", Identifier.from(2)));
HuggingFaceLlamaPrompt<FBMessage> promptBuilder =
Expand Down Expand Up @@ -257,7 +256,7 @@ void orderedCorrectly() throws IOException, InterruptedException {
Identifier.random(),
Identifier.random(),
Identifier.random(),
Role.USER, null));
Role.USER));
stack = stack.with(stack.newMessageFromUser(Instant.now(), "2", Identifier.from(2)));
stack = stack.with(stack.newMessageFromUser(Instant.now(), "3", Identifier.from(3)));
stack = stack.with(stack.newMessageFromUser(Instant.now(), "4", Identifier.from(4)));
Expand Down
4 changes: 2 additions & 2 deletions src/test/java/com/meta/cp4m/llm/OpenAIPluginTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ public class OpenAIPluginTest {
Identifier.random(),
Identifier.random(),
Identifier.random(),
Role.USER, null));
Role.USER));

static {
((ObjectNode) SAMPLE_RESPONSE)
Expand Down Expand Up @@ -198,7 +198,7 @@ void orderedCorrectly() throws IOException, InterruptedException {
Identifier.random(),
Identifier.random(),
Identifier.random(),
Role.USER, null));
Role.USER));
thread = thread.with(thread.newMessageFromUser(Instant.now(), "2", Identifier.from(2)));
thread = thread.with(thread.newMessageFromUser(Instant.now(), "3", Identifier.from(3)));
thread = thread.with(thread.newMessageFromUser(Instant.now(), "4", Identifier.from(4)));
Expand Down
6 changes: 3 additions & 3 deletions src/test/java/com/meta/cp4m/store/MemoryStoreTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,14 @@ void test() {
assertThat(memoryStore.size()).isEqualTo(0);
FBMessage message =
messageFactory.newMessage(
Instant.now(), "", senderId, recipientId, Identifier.random(), Message.Role.ASSISTANT, null);
Instant.now(), "", senderId, recipientId, Identifier.random(), Message.Role.ASSISTANT);
ThreadState<FBMessage> thread = memoryStore.add(message);
assertThat(memoryStore.size()).isEqualTo(1);
assertThat(thread.messages()).hasSize(1).contains(message);

FBMessage message2 =
messageFactory.newMessage(
Instant.now(), "", recipientId, senderId, Identifier.random(), Message.Role.USER, null);
Instant.now(), "", recipientId, senderId, Identifier.random(), Message.Role.USER);
thread = memoryStore.add(message2);
assertThat(memoryStore.size()).isEqualTo(1);
assertThat(thread.messages()).hasSize(2);
Expand All @@ -48,7 +48,7 @@ void test() {
Identifier.random(),
Identifier.random(),
Identifier.random(),
Message.Role.USER, null);
Message.Role.USER);
thread = memoryStore.add(message3);
assertThat(memoryStore.size()).isEqualTo(2);
assertThat(thread.messages()).hasSize(1).contains(message3);
Expand Down
Loading

0 comments on commit eeefdbe

Please sign in to comment.