diff --git a/src/main/java/com/meta/cp4m/Service.java b/src/main/java/com/meta/cp4m/Service.java index b085567..f29d9ae 100644 --- a/src/main/java/com/meta/cp4m/Service.java +++ b/src/main/java/com/meta/cp4m/Service.java @@ -69,6 +69,14 @@ public MessageHandler messageHandler() { return this.handler; } + public ChatStore store() { + return this.store; + } + + public LLMPlugin plugin() { + return this.llmPlugin; + } + private void execute(ThreadState thread) { T llmResponse; try { diff --git a/src/main/java/com/meta/cp4m/llm/LLMConfig.java b/src/main/java/com/meta/cp4m/llm/LLMConfig.java index e5b9f8b..74e6f3d 100644 --- a/src/main/java/com/meta/cp4m/llm/LLMConfig.java +++ b/src/main/java/com/meta/cp4m/llm/LLMConfig.java @@ -16,6 +16,7 @@ @JsonSubTypes({ @JsonSubTypes.Type(value = OpenAIConfig.class, name = "openai"), @JsonSubTypes.Type(value = HuggingFaceConfig.class, name = "hugging_face"), + @JsonSubTypes.Type(value = MirrorPluginConfig.class, name = "mirror"), }) public interface LLMConfig { diff --git a/src/main/java/com/meta/cp4m/llm/MirrorPlugin.java b/src/main/java/com/meta/cp4m/llm/MirrorPlugin.java new file mode 100644 index 0000000..5d44952 --- /dev/null +++ b/src/main/java/com/meta/cp4m/llm/MirrorPlugin.java @@ -0,0 +1,22 @@ +/* + * + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +package com.meta.cp4m.llm; + +import com.meta.cp4m.message.Message; +import com.meta.cp4m.message.ThreadState; +import java.io.IOException; +import java.time.Instant; + +public class MirrorPlugin implements LLMPlugin { + + @Override + public T handle(ThreadState threadState) throws IOException { + return threadState.newMessageFromBot(Instant.now(), threadState.tail().message()); + } +} diff --git a/src/main/java/com/meta/cp4m/llm/MirrorPluginConfig.java b/src/main/java/com/meta/cp4m/llm/MirrorPluginConfig.java new file mode 100644 index 0000000..79a8342 --- /dev/null +++ b/src/main/java/com/meta/cp4m/llm/MirrorPluginConfig.java @@ -0,0 +1,19 @@ +/* + * + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +package com.meta.cp4m.llm; + +import com.meta.cp4m.message.Message; + +public record MirrorPluginConfig(String name) implements LLMConfig { + + @Override + public LLMPlugin toPlugin() { + return new MirrorPlugin<>(); + } +} diff --git a/src/main/java/com/meta/cp4m/store/NullStoreConfig.java b/src/main/java/com/meta/cp4m/store/NullStoreConfig.java new file mode 100644 index 0000000..5771fbd --- /dev/null +++ b/src/main/java/com/meta/cp4m/store/NullStoreConfig.java @@ -0,0 +1,19 @@ +/* + * + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +package com.meta.cp4m.store; + +import com.meta.cp4m.message.Message; + +public record NullStoreConfig(String name) implements StoreConfig { + + @Override + public ChatStore toStore() { + return new NullStore<>(); + } +} diff --git a/src/main/java/com/meta/cp4m/store/StoreConfig.java b/src/main/java/com/meta/cp4m/store/StoreConfig.java index bb5ee96..498f732 100644 --- a/src/main/java/com/meta/cp4m/store/StoreConfig.java +++ b/src/main/java/com/meta/cp4m/store/StoreConfig.java @@ -15,6 +15,7 @@ @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, property = "type") @JsonSubTypes({ @JsonSubTypes.Type(value = MemoryStoreConfig.class, name = "memory"), + @JsonSubTypes.Type(value = NullStoreConfig.class, name = "null"), }) public interface StoreConfig { diff --git a/src/test/java/com/meta/cp4m/llm/MirrorPluginTest.java b/src/test/java/com/meta/cp4m/llm/MirrorPluginTest.java new file mode 100644 index 0000000..494c340 --- /dev/null +++ b/src/test/java/com/meta/cp4m/llm/MirrorPluginTest.java @@ -0,0 +1,82 @@ +/* + * + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +package com.meta.cp4m.llm; + +import static org.assertj.core.api.Assertions.assertThat; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.dataformat.toml.TomlMapper; +import com.meta.cp4m.Identifier; +import com.meta.cp4m.configuration.ConfigurationUtils; +import com.meta.cp4m.configuration.RootConfiguration; +import com.meta.cp4m.message.FBMessage; +import com.meta.cp4m.message.Message; +import com.meta.cp4m.message.ThreadState; +import java.io.IOException; +import java.time.Instant; +import org.junit.jupiter.api.Test; + +class MirrorPluginTest { + + private static final String TOML = + """ +port = 8081 + +[[plugins]] +name = "mirror_test" +type = "mirror" + +[[stores]] +name = "memory_test" +type = "memory" +storage_duration_hours = 1 +storage_capacity_mbs = 1 + +[[handlers]] +type = "messenger" +name = "messenger_test" +verify_token = "imgibberish" +app_secret = "imnotasecret" +page_access_token = "imnotasecreteither" + +[[services]] +webhook_path = "/messenger" +plugin = "mirror_test" +store = "memory_test" +handler = "messenger_test" +"""; + + @Test + void sanity() throws IOException { + MirrorPlugin plugin = new MirrorPlugin<>(); + FBMessage output = + plugin.handle( + ThreadState.of( + new FBMessage( + Instant.now(), + Identifier.random(), + Identifier.random(), + Identifier.random(), + "test", + Message.Role.USER))); + assertThat(output.message()).isEqualTo("test"); + } + + @Test + void configLoads() throws JsonProcessingException { + TomlMapper mapper = ConfigurationUtils.tomlMapper(); + RootConfiguration config = + ConfigurationUtils.tomlMapper() + .convertValue(mapper.readTree(TOML), RootConfiguration.class); + + assertThat(config.toServicesRunner().services()) + .hasSize(1) + .allSatisfy(p -> assertThat(p.plugin()).isOfAnyClassIn(MirrorPlugin.class)); + } +} diff --git a/src/test/java/com/meta/cp4m/store/NullStoreTest.java b/src/test/java/com/meta/cp4m/store/NullStoreTest.java index 5a1c419..749da22 100644 --- a/src/test/java/com/meta/cp4m/store/NullStoreTest.java +++ b/src/test/java/com/meta/cp4m/store/NullStoreTest.java @@ -10,46 +10,103 @@ import static org.assertj.core.api.Assertions.assertThat; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.fasterxml.jackson.dataformat.toml.TomlMapper; import com.meta.cp4m.Identifier; +import com.meta.cp4m.configuration.ConfigurationUtils; +import com.meta.cp4m.configuration.RootConfiguration; import com.meta.cp4m.message.*; import java.time.Instant; import org.junit.jupiter.api.Test; -class NullStoreStoreTest { - - @Test - void test() { - Identifier senderId = Identifier.random(); - Identifier recipientId = Identifier.random(); - - MessageFactory messageFactory = MessageFactory.instance(FBMessage.class); - NullStore nullStore = new NullStore<>(); - - assertThat(nullStore.size()).isEqualTo(0); - FBMessage message = - messageFactory.newMessage( - Instant.now(), "", senderId, recipientId, Identifier.random(), Message.Role.ASSISTANT); - ThreadState thread = nullStore.add(message); - assertThat(nullStore.size()).isEqualTo(0); - assertThat(thread.messages()).hasSize(1).contains(message); - - FBMessage message2 = - messageFactory.newMessage( - Instant.now(), "", recipientId, senderId, Identifier.random(), Message.Role.USER); - thread = nullStore.add(message2); - assertThat(nullStore.size()).isEqualTo(0); - assertThat(thread.messages()).hasSize(1); - - FBMessage message3 = - messageFactory.newMessage( - Instant.now(), - "", - Identifier.random(), - Identifier.random(), - Identifier.random(), - Message.Role.USER); - thread = nullStore.add(message3); - assertThat(nullStore.size()).isEqualTo(0); - assertThat(thread.messages()).hasSize(1).contains(message3); - } -} \ No newline at end of file +class NullStoreTest { + private static final String TOML = + """ + port = 8081 + + [[plugins]] + name = "mirror_test" + type = "mirror" + + [[stores]] + name = "null_test" + type = "null" + + [[handlers]] + type = "messenger" + name = "messenger_test" + verify_token = "imgibberish" + app_secret = "imnotasecret" + page_access_token = "imnotasecreteither" + + [[services]] + webhook_path = "/messenger" + plugin = "mirror_test" + store = "null_test" + handler = "messenger_test" + """; + + @Test + void test() { + Identifier senderId = Identifier.random(); + Identifier recipientId = Identifier.random(); + + MessageFactory messageFactory = MessageFactory.instance(FBMessage.class); + NullStore nullStore = new NullStore<>(); + + assertThat(nullStore.size()).isEqualTo(0); + FBMessage message = + messageFactory.newMessage( + Instant.now(), "", senderId, recipientId, Identifier.random(), Message.Role.ASSISTANT); + ThreadState thread = nullStore.add(message); + assertThat(nullStore.size()).isEqualTo(0); + assertThat(thread.messages()).hasSize(1).contains(message); + + FBMessage message2 = + messageFactory.newMessage( + Instant.now(), "", recipientId, senderId, Identifier.random(), Message.Role.USER); + thread = nullStore.add(message2); + assertThat(nullStore.size()).isEqualTo(0); + assertThat(thread.messages()).hasSize(1); + + FBMessage message3 = + messageFactory.newMessage( + Instant.now(), + "", + Identifier.random(), + Identifier.random(), + Identifier.random(), + Message.Role.USER); + thread = nullStore.add(message3); + assertThat(nullStore.size()).isEqualTo(0); + assertThat(thread.messages()).hasSize(1).contains(message3); + } + + @Test + void configLoads() throws JsonProcessingException { + TomlMapper mapper = ConfigurationUtils.tomlMapper(); + RootConfiguration config = + ConfigurationUtils.tomlMapper() + .convertValue(mapper.readTree(TOML), RootConfiguration.class); + + assertThat(config.toServicesRunner().services()) + .hasSize(1) + .allSatisfy(p -> assertThat(p.store()).isOfAnyClassIn(NullStore.class)); + } + + @Test + void configLoadsWithoutDefinedStore() throws JsonProcessingException { + TomlMapper mapper = ConfigurationUtils.tomlMapper(); + ObjectNode node = (ObjectNode) mapper.readTree(TOML); + node.remove("stores"); + ((ObjectNode) node.get("services").get(0)).remove("store"); + RootConfiguration config = + ConfigurationUtils.tomlMapper() + .convertValue(mapper.readTree(TOML), RootConfiguration.class); + + assertThat(config.toServicesRunner().services()) + .hasSize(1) + .allSatisfy(p -> assertThat(p.store()).isOfAnyClassIn(NullStore.class)); + } +}