Skip to content

Commit

Permalink
add mirror plugin, null store explicitly configurable, and add test f…
Browse files Browse the repository at this point in the history
…or null store
  • Loading branch information
hunterjackson committed Jun 7, 2024
1 parent 699f447 commit 4aef434
Show file tree
Hide file tree
Showing 8 changed files with 247 additions and 38 deletions.
8 changes: 8 additions & 0 deletions src/main/java/com/meta/cp4m/Service.java
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,14 @@ public MessageHandler<T> messageHandler() {
return this.handler;
}

public ChatStore<T> store() {
return this.store;
}

public LLMPlugin<T> plugin() {
return this.llmPlugin;
}

private void execute(ThreadState<T> thread) {
T llmResponse;
try {
Expand Down
1 change: 1 addition & 0 deletions src/main/java/com/meta/cp4m/llm/LLMConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
@JsonSubTypes({
@JsonSubTypes.Type(value = OpenAIConfig.class, name = "openai"),
@JsonSubTypes.Type(value = HuggingFaceConfig.class, name = "hugging_face"),
@JsonSubTypes.Type(value = MirrorPluginConfig.class, name = "mirror"),
})
public interface LLMConfig {

Expand Down
22 changes: 22 additions & 0 deletions src/main/java/com/meta/cp4m/llm/MirrorPlugin.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
/*
*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

package com.meta.cp4m.llm;

import com.meta.cp4m.message.Message;
import com.meta.cp4m.message.ThreadState;
import java.io.IOException;
import java.time.Instant;

public class MirrorPlugin<T extends Message> implements LLMPlugin<T> {

@Override
public T handle(ThreadState<T> threadState) throws IOException {
return threadState.newMessageFromBot(Instant.now(), threadState.tail().message());
}
}
19 changes: 19 additions & 0 deletions src/main/java/com/meta/cp4m/llm/MirrorPluginConfig.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
/*
*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

package com.meta.cp4m.llm;

import com.meta.cp4m.message.Message;

public record MirrorPluginConfig(String name) implements LLMConfig {

@Override
public <T extends Message> LLMPlugin<T> toPlugin() {
return new MirrorPlugin<>();
}
}
19 changes: 19 additions & 0 deletions src/main/java/com/meta/cp4m/store/NullStoreConfig.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
/*
*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

package com.meta.cp4m.store;

import com.meta.cp4m.message.Message;

public record NullStoreConfig(String name) implements StoreConfig {

@Override
public <T extends Message> ChatStore<T> toStore() {
return new NullStore<>();
}
}
1 change: 1 addition & 0 deletions src/main/java/com/meta/cp4m/store/StoreConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, property = "type")
@JsonSubTypes({
@JsonSubTypes.Type(value = MemoryStoreConfig.class, name = "memory"),
@JsonSubTypes.Type(value = NullStoreConfig.class, name = "null"),
})
public interface StoreConfig {

Expand Down
82 changes: 82 additions & 0 deletions src/test/java/com/meta/cp4m/llm/MirrorPluginTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/*
*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

package com.meta.cp4m.llm;

import static org.assertj.core.api.Assertions.assertThat;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.dataformat.toml.TomlMapper;
import com.meta.cp4m.Identifier;
import com.meta.cp4m.configuration.ConfigurationUtils;
import com.meta.cp4m.configuration.RootConfiguration;
import com.meta.cp4m.message.FBMessage;
import com.meta.cp4m.message.Message;
import com.meta.cp4m.message.ThreadState;
import java.io.IOException;
import java.time.Instant;
import org.junit.jupiter.api.Test;

class MirrorPluginTest {

private static final String TOML =
"""
port = 8081
[[plugins]]
name = "mirror_test"
type = "mirror"
[[stores]]
name = "memory_test"
type = "memory"
storage_duration_hours = 1
storage_capacity_mbs = 1
[[handlers]]
type = "messenger"
name = "messenger_test"
verify_token = "imgibberish"
app_secret = "imnotasecret"
page_access_token = "imnotasecreteither"
[[services]]
webhook_path = "/messenger"
plugin = "mirror_test"
store = "memory_test"
handler = "messenger_test"
""";

@Test
void sanity() throws IOException {
MirrorPlugin<FBMessage> plugin = new MirrorPlugin<>();
FBMessage output =
plugin.handle(
ThreadState.of(
new FBMessage(
Instant.now(),
Identifier.random(),
Identifier.random(),
Identifier.random(),
"test",
Message.Role.USER)));
assertThat(output.message()).isEqualTo("test");
}

@Test
void configLoads() throws JsonProcessingException {
TomlMapper mapper = ConfigurationUtils.tomlMapper();
RootConfiguration config =
ConfigurationUtils.tomlMapper()
.convertValue(mapper.readTree(TOML), RootConfiguration.class);

assertThat(config.toServicesRunner().services())
.hasSize(1)
.allSatisfy(p -> assertThat(p.plugin()).isOfAnyClassIn(MirrorPlugin.class));
}
}
133 changes: 95 additions & 38 deletions src/test/java/com/meta/cp4m/store/NullStoreTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,46 +10,103 @@

import static org.assertj.core.api.Assertions.assertThat;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.fasterxml.jackson.dataformat.toml.TomlMapper;
import com.meta.cp4m.Identifier;
import com.meta.cp4m.configuration.ConfigurationUtils;
import com.meta.cp4m.configuration.RootConfiguration;
import com.meta.cp4m.message.*;
import java.time.Instant;
import org.junit.jupiter.api.Test;

class NullStoreStoreTest {

@Test
void test() {
Identifier senderId = Identifier.random();
Identifier recipientId = Identifier.random();

MessageFactory<FBMessage> messageFactory = MessageFactory.instance(FBMessage.class);
NullStore<FBMessage> nullStore = new NullStore<>();

assertThat(nullStore.size()).isEqualTo(0);
FBMessage message =
messageFactory.newMessage(
Instant.now(), "", senderId, recipientId, Identifier.random(), Message.Role.ASSISTANT);
ThreadState<FBMessage> thread = nullStore.add(message);
assertThat(nullStore.size()).isEqualTo(0);
assertThat(thread.messages()).hasSize(1).contains(message);

FBMessage message2 =
messageFactory.newMessage(
Instant.now(), "", recipientId, senderId, Identifier.random(), Message.Role.USER);
thread = nullStore.add(message2);
assertThat(nullStore.size()).isEqualTo(0);
assertThat(thread.messages()).hasSize(1);

FBMessage message3 =
messageFactory.newMessage(
Instant.now(),
"",
Identifier.random(),
Identifier.random(),
Identifier.random(),
Message.Role.USER);
thread = nullStore.add(message3);
assertThat(nullStore.size()).isEqualTo(0);
assertThat(thread.messages()).hasSize(1).contains(message3);
}
}
class NullStoreTest {
private static final String TOML =
"""
port = 8081
[[plugins]]
name = "mirror_test"
type = "mirror"
[[stores]]
name = "null_test"
type = "null"
[[handlers]]
type = "messenger"
name = "messenger_test"
verify_token = "imgibberish"
app_secret = "imnotasecret"
page_access_token = "imnotasecreteither"
[[services]]
webhook_path = "/messenger"
plugin = "mirror_test"
store = "null_test"
handler = "messenger_test"
""";

@Test
void test() {
Identifier senderId = Identifier.random();
Identifier recipientId = Identifier.random();

MessageFactory<FBMessage> messageFactory = MessageFactory.instance(FBMessage.class);
NullStore<FBMessage> nullStore = new NullStore<>();

assertThat(nullStore.size()).isEqualTo(0);
FBMessage message =
messageFactory.newMessage(
Instant.now(), "", senderId, recipientId, Identifier.random(), Message.Role.ASSISTANT);
ThreadState<FBMessage> thread = nullStore.add(message);
assertThat(nullStore.size()).isEqualTo(0);
assertThat(thread.messages()).hasSize(1).contains(message);

FBMessage message2 =
messageFactory.newMessage(
Instant.now(), "", recipientId, senderId, Identifier.random(), Message.Role.USER);
thread = nullStore.add(message2);
assertThat(nullStore.size()).isEqualTo(0);
assertThat(thread.messages()).hasSize(1);

FBMessage message3 =
messageFactory.newMessage(
Instant.now(),
"",
Identifier.random(),
Identifier.random(),
Identifier.random(),
Message.Role.USER);
thread = nullStore.add(message3);
assertThat(nullStore.size()).isEqualTo(0);
assertThat(thread.messages()).hasSize(1).contains(message3);
}

@Test
void configLoads() throws JsonProcessingException {
TomlMapper mapper = ConfigurationUtils.tomlMapper();
RootConfiguration config =
ConfigurationUtils.tomlMapper()
.convertValue(mapper.readTree(TOML), RootConfiguration.class);

assertThat(config.toServicesRunner().services())
.hasSize(1)
.allSatisfy(p -> assertThat(p.store()).isOfAnyClassIn(NullStore.class));
}

@Test
void configLoadsWithoutDefinedStore() throws JsonProcessingException {
TomlMapper mapper = ConfigurationUtils.tomlMapper();
ObjectNode node = (ObjectNode) mapper.readTree(TOML);
node.remove("stores");
((ObjectNode) node.get("services").get(0)).remove("store");
RootConfiguration config =
ConfigurationUtils.tomlMapper()
.convertValue(mapper.readTree(TOML), RootConfiguration.class);

assertThat(config.toServicesRunner().services())
.hasSize(1)
.allSatisfy(p -> assertThat(p.store()).isOfAnyClassIn(NullStore.class));
}
}

0 comments on commit 4aef434

Please sign in to comment.