diff --git a/lmos-router-llm/ReadMe.md b/lmos-router-llm/ReadMe.md
index ab9e7b9..16b169b 100644
--- a/lmos-router-llm/ReadMe.md
+++ b/lmos-router-llm/ReadMe.md
@@ -7,32 +7,49 @@ SPDX-License-Identifier: CC-BY-4.0

 ## Overview

-The LLM submodule is responsible for resolving agent routing specifications using a language model. It includes classes and interfaces for interacting with the OpenAI API by default, providing prompts to the model, and resolving agent routing specifications based on the model's responses.
+The **LLM Submodule** is responsible for resolving agent routing specifications using a language model. It includes classes and interfaces for interacting with the OpenAI API by default, providing prompts to the model, and resolving agent routing specifications based on the model's responses. Additionally, it supports multiple language model providers such as Anthropic, Gemini, Ollama, and other OpenAI-compatible APIs through the LangChain4j integration.

 ## Table of Contents

 1. [Introduction](#introduction)
 2. [Classes and Interfaces](#classes-and-interfaces)
+   - [ModelClient](#modelclient)
+   - [DefaultModelClient](#defaultmodelclient)
+   - [DefaultModelClientProperties](#defaultmodelclientproperties)
+   - [LLMAgentRoutingSpecsResolver](#llmagentroutingspecsresolver)
+   - [ModelPromptProvider](#modelpromptprovider)
+   - [DefaultModelPromptProvider](#defaultmodelpromptprovider)
+   - [ExternalModelPromptProvider](#externalmodelpromptprovider)
+   - [AgentRoutingSpecListType](#agentroutingspeclisttype)
+   - [ModelClientResponse](#modelclientresponse)
+   - [ModelClientResponseProcessor](#modelclientresponseprocessor)
+   - [DefaultModelClientResponseProcessor](#defaultmodelclientresponseprocessor)
+   - [LangChainModelClient](#langchainmodelclient)
+   - [LangChainChatModelFactory](#langchainchatmodelfactory)
+   - [LangChainClientProvider](#langchainclientprovider)
 3. [Usage](#usage)
+   - [Step 1: Initialize the DefaultModelClient](#step-1-initialize-the-defaultmodelclient)
+   - [Step 2: Initialize the LLMAgentRoutingSpecsResolver](#step-2-initialize-the-llmagentroutingspecsresolver)
+   - [Step 3: Resolve the Agent](#step-3-resolve-the-agent)
+   - [Advanced: Using LangChainModelClient](#advanced-using-langchainmodelclient)
 4. [Configuration](#configuration)
 5. [Error Handling](#error-handling)

 ## Introduction

-The LLM submodule leverages advanced language models to understand and match user queries with agent capabilities. It includes a default implementation for calling the OpenAI model and resolving agent routing specifications using the model's responses.
+The **LLM Submodule** leverages advanced language models to understand and match user queries with agent capabilities. It includes a default implementation for calling the OpenAI model and resolving agent routing specifications using the model's responses. Through the integration with **LangChain4j**, the submodule extends support to additional providers such as Anthropic, Gemini, Ollama, and other OpenAI-compatible APIs, offering flexibility in choosing the underlying language model service.

 ## Classes and Interfaces

 ### ModelClient

-The `ModelClient` interface represents a client that can call a model.
+The `ModelClient` interface represents a client that can communicate with a language model.

 - **Method:**
   - `call(messages: List<ChatMessage>): Result<AssistantMessage, AgentRoutingSpecResolverException>`
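+
+A custom `ModelClient` can wrap any backend. As a minimal, illustrative sketch (the stub below is hypothetical and assumes the `ChatMessage`, `AssistantMessage`, `Success`, and `Result` types from `lmos-router-core`), a client that always returns a canned answer (handy in tests) could look like this:
+
+```kotlin
+// Hypothetical stub: not part of the submodule.
+class StaticModelClient(private val reply: String) : ModelClient {
+    override fun call(messages: List<ChatMessage>): Result<AssistantMessage, AgentRoutingSpecResolverException> =
+        Success(AssistantMessage(reply))
+}
+```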

 ### DefaultModelClient

-The `DefaultModelClient` class is a default implementation of the `ModelClient` interface. It calls the OpenAI model with the given messages.
+The `DefaultModelClient` class is the default implementation of the `ModelClient` interface. It interacts with the OpenAI API by delegating to a `LangChainModelClient` configured for OpenAI.

 - **Constructor:**
   - `DefaultModelClient(defaultModelClientProperties: DefaultModelClientProperties)`
@@ -42,7 +59,7 @@ The `DefaultModelClient` class is a default implementation of the `ModelClient`

 ### DefaultModelClientProperties

-The `DefaultModelClientProperties` data class represents the properties for the default model client.
+The `DefaultModelClientProperties` data class encapsulates the configuration properties required by the `DefaultModelClient`.

 - **Fields:**
   - `openAiUrl: String`
@@ -54,10 +71,10 @@ The `DefaultModelClientProperties` data class represents the properties for the

 ### LLMAgentRoutingSpecsResolver

-The `LLMAgentRoutingSpecsResolver` class resolves agent routing specifications using a language model.
+The `LLMAgentRoutingSpecsResolver` class is responsible for resolving agent routing specifications using a language model.

 - **Constructor:**
-  - `LLMAgentRoutingSpecsResolver(agentRoutingSpecsProvider: AgentRoutingSpecsProvider, modelPromptProvider: ModelPromptProvider, modelClient: ModelClient, serializer: Json)`
+  - `LLMAgentRoutingSpecsResolver(agentRoutingSpecsProvider: AgentRoutingSpecsProvider, modelPromptProvider: ModelPromptProvider, modelClient: ModelClient, serializer: Json, modelClientResponseProcessor: ModelClientResponseProcessor)`

 - **Methods:**
   - `resolve(context: Context, input: UserMessage): Result<AgentRoutingSpec?, AgentRoutingSpecResolverException>`
@@ -65,21 +82,21 @@ The `LLMAgentRoutingSpecsResolver` class resolves agent routing specifications u

 ### ModelPromptProvider

-The `ModelPromptProvider` interface represents a provider of model prompts.
+The `ModelPromptProvider` interface defines a provider that generates prompts for the language model based on the context and user input.

 - **Method:**
   - `providePrompt(context: Context, agentRoutingSpecs: Set<AgentRoutingSpec>, input: UserMessage): Result<String, AgentRoutingSpecResolverException>`

 ### DefaultModelPromptProvider

-The `DefaultModelPromptProvider` class provides a generic prompt for agent resolution.
+The `DefaultModelPromptProvider` class offers a generic implementation of the `ModelPromptProvider`, generating standard prompts for agent resolution.

 - **Method:**
   - `providePrompt(context: Context, agentRoutingSpecs: Set<AgentRoutingSpec>, input: UserMessage): Result<String, AgentRoutingSpecResolverException>`

 ### ExternalModelPromptProvider

-The `ExternalModelPromptProvider` class provides a prompt from an external file. The agent routing specifications for the prompt can be in XML or JSON format.
+The `ExternalModelPromptProvider` class generates prompts from an external file, supporting agent routing specifications in XML or JSON format.

 - **Constructor:**
   - `ExternalModelPromptProvider(promptFilePath: String, agentRoutingSpecsListType: AgentRoutingSpecListType)`
@@ -89,7 +106,7 @@ The `ExternalModelPromptProvider` class provides a prompt from an external file.

 ### AgentRoutingSpecListType

-The `AgentRoutingSpecListType` enum represents the format of the agent routing specs list.
+The `AgentRoutingSpecListType` enum defines the supported formats for agent routing specifications.

 - **Values:**
   - `XML`
@@ -97,18 +114,65 @@ The `AgentRoutingSpecListType` enum represents the format of the agent routing s

 ### ModelClientResponse

-The `ModelClientResponse` class represents a model client response.
+The `ModelClientResponse` class encapsulates the response from the language model client.

 - **Field:**
   - `agentName: String`
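+
+### ModelClientResponseProcessor
+
+The `ModelClientResponseProcessor` interface post-processes the raw model output into agent-spec-compliant JSON before it is deserialized.
+
+- **Method:**
+  - `process(modelResponse: String): String`
+
+### DefaultModelClientResponseProcessor
+
+The `DefaultModelClientResponseProcessor` class is the default implementation; it trims the response and unwraps JSON from markdown code fences.
+
+A custom processor only needs to implement `process`. For illustration (the `<result>` envelope below is hypothetical, not something the default prompt produces):
+
+```kotlin
+// Hypothetical sketch: unwrap a <result>...</result> envelope around the JSON.
+class TagStrippingResponseProcessor : ModelClientResponseProcessor {
+    override fun process(modelResponse: String): String =
+        modelResponse.substringAfter("<result>").substringBefore("</result>").trim()
+}
+```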
+### LangChainModelClient
+
+The `LangChainModelClient` class is an implementation of the `ModelClient` interface that uses **LangChain4j** to interact with a variety of language models.
+
+- **Constructor:**
+  - `LangChainModelClient(chatLanguageModel: ChatLanguageModel)`
+
+- **Method:**
+  - `call(messages: List<ChatMessage>): Result<AssistantMessage, AgentRoutingSpecResolverException>`
+
+**Details:**
+- Converts the internal `ChatMessage` types (`UserMessage`, `AssistantMessage`, `SystemMessage`) to **LangChain4j**-compatible message types.
+- Handles exceptions by wrapping them in `AgentRoutingSpecResolverException`.
+
+### LangChainChatModelFactory
+
+The `LangChainChatModelFactory` is a factory class responsible for creating instances of `ChatLanguageModel` based on the provided configuration.
+
+- **Companion Object Method:**
+  - `createClient(properties: ModelClientProperties): ChatLanguageModel`
+
+**Supported Providers:**
+- `OPENAI`
+- `ANTHROPIC`
+- `GEMINI`
+- `OLLAMA`
+- `OTHER` (for OpenAI-compatible APIs)
+
+**Details:**
+- Configures the language model client with appropriate settings such as API keys, model names, token limits, temperature, and response formats based on the selected provider.
+
+### LangChainClientProvider
+
+The `LangChainClientProvider` enum lists the supported language model providers.
+
+- **Values:**
+  - `OPENAI`
+  - `ANTHROPIC`
+  - `GEMINI`
+  - `OLLAMA`
+  - `OTHER`
+
 ## Usage

 ### Step 1: Initialize the DefaultModelClient

 ```kotlin
 val defaultModelClientProperties = DefaultModelClientProperties(
-    openAiApiKey = "your-openai-api-key"
+    openAiUrl = "https://api.openai.com/v1",
+    openAiApiKey = "your-openai-api-key",
+    model = "gpt-4o",
+    maxTokens = 1500,
+    temperature = 0.7,
+    format = "json_object"
 )
 val modelClient = DefaultModelClient(defaultModelClientProperties)
 ```
@@ -118,10 +182,15 @@ val modelClient = DefaultModelClient(defaultModelClientProperties)

 ```kotlin
 val agentRoutingSpecsProvider = SimpleAgentRoutingSpecProvider()
 val modelPromptProvider = DefaultModelPromptProvider()
+val modelClientResponseProcessor = DefaultModelClientResponseProcessor()
+val serializer = Json { ignoreUnknownKeys = true }
+
 val llmAgentRoutingSpecsResolver = LLMAgentRoutingSpecsResolver(
     agentRoutingSpecsProvider,
     modelPromptProvider,
-    modelClient
+    modelClient,
+    serializer,
+    modelClientResponseProcessor
 )
 ```
@@ -131,4 +200,91 @@ val llmAgentRoutingSpecsResolver = LLMAgentRoutingSpecsResolver(
 val context = Context(listOf(AssistantMessage("Hello")))
 val input = UserMessage("Can you help me find a new phone?")
 val result = llmAgentRoutingSpecsResolver.resolve(context, input)
-```
\ No newline at end of file
+```
+
+### Advanced: Using LangChainModelClient
+
+For support of multiple language model providers, you can use the `LangChainModelClient` together with the `LangChainChatModelFactory`.
+
+```kotlin
+// Define the model client properties
+val langChainProperties = ModelClientProperties(
+    provider = LangChainClientProvider.OPENAI.name.lowercase(),
+    apiKey = "your-openai-api-key",
+    model = "gpt-4o",
+    maxTokens = 1500,
+    temperature = 0.7,
+    topP = 0.9,
+    topK = 40,
+    format = "json_object",
+    baseUrl = null // Only required for the OTHER provider
+)
+
+// Create a ChatLanguageModel using the factory
+val chatLanguageModel = LangChainChatModelFactory.createClient(langChainProperties)
+
+// Initialize the LangChainModelClient
+val langChainModelClient = LangChainModelClient(chatLanguageModel)
+
+// Use the LangChainModelClient with the LLMAgentRoutingSpecsResolver
+val llmAgentRoutingSpecsResolverAdvanced = LLMAgentRoutingSpecsResolver(
+    agentRoutingSpecsProvider,
+    modelPromptProvider,
+    langChainModelClient,
+    serializer,
+    modelClientResponseProcessor
+)
+
+// Resolve the agent as before
+val advancedResult = llmAgentRoutingSpecsResolverAdvanced.resolve(context, input)
+```
+
+**Supported Providers via LangChainModelClient:**
+- **OpenAI:** Default provider with full configuration support.
+- **Anthropic:** Supports models from Anthropic.
+- **Gemini:** Integrates with Google AI Gemini models.
+- **Ollama:** Connects to Ollama-based models.
+- **Other:** Any OpenAI-compatible API, selected by specifying the `baseUrl`.
+
+## Configuration
+
+Configure `DefaultModelClientProperties` or `ModelClientProperties` according to the chosen provider. Ensure that all required fields such as `apiKey`, `model`, and `baseUrl` (if applicable) are set correctly.
+
+**Example Configuration for OpenAI:**
+
+```kotlin
+val properties = ModelClientProperties(
+    provider = LangChainClientProvider.OPENAI.name.lowercase(),
+    apiKey = "your-openai-api-key",
+    model = "gpt-4o",
+    maxTokens = 1500,
+    temperature = 0.7,
+    topP = 0.9,
+    topK = 40,
+    format = "json_object",
+    baseUrl = null
+)
+```
+
+**Example Configuration for Anthropic:**
+
+```kotlin
+val properties = ModelClientProperties(
+    provider = LangChainClientProvider.ANTHROPIC.name.lowercase(),
+    apiKey = "your-anthropic-api-key",
+    model = "claude-3-5-sonnet-20241022",
+    maxTokens = 1500,
+    temperature = 0.7
+)
+```
+
+## Error Handling
+
+The submodule surfaces failures during model interactions and agent resolution through explicit types.
+
+- **Exceptions:**
+  - `AgentRoutingSpecResolverException`: Thrown when there are issues in resolving agent routing specifications or interacting with the language model.
+
+- **Handling Strategy:**
+  - Use the `Result` type to handle successes and failures gracefully (see the sketch below).
+  - Implement appropriate fallback mechanisms or user notifications in case of failures.
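+
+A minimal sketch of this strategy, using the `Success`/`Failure` checks that the module's own tests rely on:
+
+```kotlin
+val result = llmAgentRoutingSpecsResolver.resolve(context, input)
+if (result is Success) {
+    println("Routing to agent: ${result.getOrNull()?.name}")
+} else {
+    println("Resolution failed: ${(result as Failure).exceptionOrNull()?.message}")
+}
+```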
diff --git a/lmos-router-llm/build.gradle.kts b/lmos-router-llm/build.gradle.kts
index 05fafc9..a8c297b 100644
--- a/lmos-router-llm/build.gradle.kts
+++ b/lmos-router-llm/build.gradle.kts
@@ -5,7 +5,14 @@ dependencies {
     api(project(":lmos-router-core"))
     implementation("org.slf4j:slf4j-api:1.7.25")
-    implementation("com.azure:azure-ai-openai:1.0.0-beta.10")
-    implementation("com.azure:azure-identity:1.14.0")
     implementation("org.jetbrains.kotlinx:kotlinx-serialization-json-jvm:1.7.3")
+    compileOnly("dev.langchain4j:langchain4j-open-ai:0.36.2")
+    compileOnly("dev.langchain4j:langchain4j-anthropic:0.36.2")
+    compileOnly("dev.langchain4j:langchain4j-google-ai-gemini:0.36.2")
+    compileOnly("dev.langchain4j:langchain4j-ollama:0.36.2")
+
+    testImplementation("dev.langchain4j:langchain4j-open-ai:0.36.2")
+    testImplementation("dev.langchain4j:langchain4j-anthropic:0.36.2")
+    testImplementation("dev.langchain4j:langchain4j-google-ai-gemini:0.36.2")
+    testImplementation("dev.langchain4j:langchain4j-ollama:0.36.2")
 }
diff --git a/lmos-router-llm/src/main/kotlin/ai/ancf/lmos/router/llm/LLMAgentRoutingSpecsResolver.kt b/lmos-router-llm/src/main/kotlin/ai/ancf/lmos/router/llm/LLMAgentRoutingSpecsResolver.kt
index 85491fc..328b5c4 100644
--- a/lmos-router-llm/src/main/kotlin/ai/ancf/lmos/router/llm/LLMAgentRoutingSpecsResolver.kt
+++ b/lmos-router-llm/src/main/kotlin/ai/ancf/lmos/router/llm/LLMAgentRoutingSpecsResolver.kt
@@ -16,6 +16,7 @@
  * @param modelPromptProvider The provider of model prompts.
  * @param modelClient The client for the language model.
  * @param serializer The JSON serializer.
+ * @param modelClientResponseProcessor The processor for the model client response.
  */
 class LLMAgentRoutingSpecsResolver(
     override val agentRoutingSpecsProvider: AgentRoutingSpecsProvider,
@@ -34,6 +35,7 @@ class LLMAgentRoutingSpecsResolver(
             ignoreUnknownKeys = true
             isLenient = true
         },
+    private val modelClientResponseProcessor: ModelClientResponseProcessor = DefaultModelClientResponseProcessor(),
 ) : AgentRoutingSpecsResolver {
     private val log = LoggerFactory.getLogger(LLMAgentRoutingSpecsResolver::class.java)

@@ -62,7 +64,10 @@ class LLMAgentRoutingSpecsResolver(
         messages.add(input)

         log.trace("Fetching agent spec completion")
-        val response: String = modelClient.call(messages).getOrThrow().content
+        var response: String = modelClient.call(messages).getOrThrow().content
+
+        response = modelClientResponseProcessor.process(response)
+
         val agent: ModelClientResponse = serializer.decodeFromString(serializer(), response)

         log.trace("Agent resolved: ${agent.agentName}")
diff --git a/lmos-router-llm/src/main/kotlin/ai/ancf/lmos/router/llm/LangChainModelClient.kt b/lmos-router-llm/src/main/kotlin/ai/ancf/lmos/router/llm/LangChainModelClient.kt
new file mode 100644
index 0000000..c658677
--- /dev/null
+++ b/lmos-router-llm/src/main/kotlin/ai/ancf/lmos/router/llm/LangChainModelClient.kt
@@ -0,0 +1,148 @@
+// SPDX-FileCopyrightText: 2024 Deutsche Telekom AG
+//
+// SPDX-License-Identifier: Apache-2.0
+
+package ai.ancf.lmos.router.llm
+
+import ai.ancf.lmos.router.core.*
+import dev.langchain4j.data.message.AiMessage
+import dev.langchain4j.model.anthropic.AnthropicChatModel
+import dev.langchain4j.model.chat.ChatLanguageModel
+import dev.langchain4j.model.chat.request.ResponseFormat
+import dev.langchain4j.model.chat.request.ResponseFormatType
+import dev.langchain4j.model.googleai.GoogleAiGeminiChatModel
+import dev.langchain4j.model.ollama.OllamaChatModel
+import dev.langchain4j.model.openai.OpenAiChatModel.OpenAiChatModelBuilder
+
+/**
+ * A model client that uses LangChain4j (https://docs.langchain4j.dev/) to call a language model.
+ *
+ * @param chatLanguageModel The language model.
+ */
+class LangChainModelClient(
+    private val chatLanguageModel: ChatLanguageModel,
+) : ModelClient {
+    override fun call(messages: List<ChatMessage>): Result<AssistantMessage, AgentRoutingSpecResolverException> {
+        try {
+            val response =
+                chatLanguageModel.generate(
+                    messages.map {
+                        when (it) {
+                            is UserMessage -> dev.langchain4j.data.message.UserMessage(it.content)
+                            is AssistantMessage -> AiMessage(it.content)
+                            is SystemMessage -> dev.langchain4j.data.message.SystemMessage(it.content)
+                            else -> throw AgentRoutingSpecResolverException("Unknown message type")
+                        }
+                    },
+                )
+            return Success(AssistantMessage(response.content().text()))
+        } catch (e: Exception) {
+            return Failure(AgentRoutingSpecResolverException("Failed to call language model", e))
+        }
+    }
+}
+
+/**
+ * A factory class to create a LangChain4j language model client. The factory creates a client based on the given properties.
+ *
+ * The factory supports the following providers:
+ * - OPENAI
+ * - ANTHROPIC
+ * - GEMINI
+ * - OLLAMA
+ * - OTHER (for other providers which expose an OpenAI-compatible API)
+ */
+class LangChainChatModelFactory private constructor() {
+    companion object {
+        fun createClient(properties: ModelClientProperties): ChatLanguageModel {
+            return when (properties.provider) {
+                LangChainClientProvider.OPENAI.name.lowercase(),
+                -> {
+                    val model =
+                        OpenAiChatModelBuilder()
+                            .apiKey(properties.apiKey)
+                            .modelName(properties.model)
+                            .maxTokens(properties.maxTokens)
+                            .temperature(properties.temperature)
+
+                    properties.topP?.let { model.topP(it) }
+                    properties.format?.let { model.responseFormat(it) }
+
+                    model.build()
+                }
+
+                LangChainClientProvider.ANTHROPIC.name.lowercase() -> {
+                    val model =
+                        AnthropicChatModel.builder()
+                            .apiKey(properties.apiKey)
+                            .modelName(properties.model)
+                            .maxTokens(properties.maxTokens)
+                            .temperature(properties.temperature)
+
+                    properties.topP?.let { model.topP(it) }
+                    properties.topK?.let { model.topK(it) }
+
+                    model.build()
+                }
+
+                LangChainClientProvider.GEMINI.name.lowercase() -> {
+                    val model =
+                        GoogleAiGeminiChatModel.builder()
+                            .modelName(properties.model)
+                            .maxOutputTokens(properties.maxTokens)
+                            .temperature(properties.temperature)
+
+                    properties.format?.let {
+                        model.responseFormat(
+                            ResponseFormat.builder()
+                                .type(ResponseFormatType.valueOf(it))
+                                .build(),
+                        )
+                    }
+                    properties.apiKey?.let { model.apiKey(it) }
+                    properties.topK?.let { model.topK(it) }
+                    properties.topP?.let { model.topP(it) }
+
+                    model.build()
+                }
+
+                LangChainClientProvider.OLLAMA.name.lowercase() -> {
+                    OllamaChatModel.builder().baseUrl(properties.baseUrl)
+                        .modelName(properties.model)
+                        .temperature(properties.temperature)
+                        .build()
+                }
+
+                LangChainClientProvider.OTHER.name.lowercase() -> {
+                    require(properties.baseUrl != null) { "baseUrl is required for OTHER provider" }
+
+                    val model =
+                        OpenAiChatModelBuilder()
+                            .baseUrl(properties.baseUrl)
+                            .apiKey(properties.apiKey)
+                            .modelName(properties.model)
+                            .maxTokens(properties.maxTokens)
+                            .temperature(properties.temperature)
+                    properties.topP?.let { model.topP(it) }
+                    properties.format?.let { model.responseFormat(it) }
+
+                    model.build()
+                }
+
+                else -> {
+                    throw IllegalArgumentException("Unknown model client properties: $properties")
+                }
+            }
+        }
+    }
+}
+
+enum class LangChainClientProvider {
+    OPENAI,
+    ANTHROPIC,
+    GEMINI,
+    OLLAMA,
+    OTHER,
+}
diff --git a/lmos-router-llm/src/main/kotlin/ai/ancf/lmos/router/llm/ModelClient.kt b/lmos-router-llm/src/main/kotlin/ai/ancf/lmos/router/llm/ModelClient.kt
index 57b37ea..e428d40 100644
--- a/lmos-router-llm/src/main/kotlin/ai/ancf/lmos/router/llm/ModelClient.kt
+++ b/lmos-router-llm/src/main/kotlin/ai/ancf/lmos/router/llm/ModelClient.kt
@@ -5,9 +5,6 @@ package ai.ancf.lmos.router.llm

 import ai.ancf.lmos.router.core.*
-import com.azure.ai.openai.OpenAIClientBuilder
-import com.azure.ai.openai.models.*
-import com.azure.core.credential.AzureKeyCredential

 /**
  * The [ModelClient] interface represents a client that can call a model.
@@ -29,48 +26,28 @@ interface ModelClient {
  *
  * @param defaultModelClientProperties The properties for the default model client.
  */
-class DefaultModelClient(private val defaultModelClientProperties: DefaultModelClientProperties) : ModelClient {
-    private var client =
-        OpenAIClientBuilder()
-            .credential(AzureKeyCredential(defaultModelClientProperties.openAiApiKey))
-            .endpoint(defaultModelClientProperties.openAiUrl)
-            .buildClient()
-
+class DefaultModelClient(
+    private val defaultModelClientProperties: DefaultModelClientProperties,
+    private val delegate: LangChainModelClient =
+        LangChainModelClient(LangChainChatModelFactory.createClient(defaultModelClientProperties)),
+) : ModelClient {
     override fun call(messages: List<ChatMessage>): Result<AssistantMessage, AgentRoutingSpecResolverException> {
-        try {
-            val chatCompletionsOptions =
-                ChatCompletionsOptions(
-                    messages.map {
-                        when (it) {
-                            is UserMessage -> ChatRequestUserMessage(it.content)
-                            is AssistantMessage -> ChatRequestAssistantMessage(it.content)
-                            is SystemMessage -> ChatRequestSystemMessage(it.content)
-                            else -> throw AgentRoutingSpecResolverException("Unknown message type")
-                        }
-                    },
-                ).setTemperature(defaultModelClientProperties.temperature)
-                    .setModel(defaultModelClientProperties.model)
-                    .setMaxTokens(defaultModelClientProperties.maxTokens)
-                    .apply {
-                        defaultModelClientProperties.format.let {
-                            responseFormat = ChatCompletionsJsonResponseFormat()
-                        }
-                    }
-
-            return Success(
-                AssistantMessage(
-                    client.getChatCompletions(
-                        defaultModelClientProperties.model,
-                        chatCompletionsOptions,
-                    ).choices.first().message.content,
-                ),
-            )
-        } catch (e: Exception) {
-            return Failure(AgentRoutingSpecResolverException("Failed to call model", e))
-        }
+        return delegate.call(messages)
     }
 }

+open class ModelClientProperties(
+    open val provider: String,
+    open val apiKey: String? = null,
+    open val baseUrl: String? = null,
+    open val model: String,
+    open val maxTokens: Int = 2000,
+    open val temperature: Double = 0.0,
+    open val format: String? = null,
+    open val topK: Int? = null,
+    open val topP: Double? = null,
+)
+
 /**
  * The [DefaultModelClientProperties] data class represents the properties for the default model client.
  *
@@ -82,10 +59,52 @@ class DefaultModelClient(private val defaultModelClientProperties: DefaultModelC
  * @param format The format.
  */
 data class DefaultModelClientProperties(
-    val openAiUrl: String = "https://api.openai.com/v1/chat/completions",
+    val openAiUrl: String = "https://api.openai.com/v1",
     val openAiApiKey: String,
-    val model: String = "gpt-4o-mini",
-    val maxTokens: Int = 200,
-    val temperature: Double = 0.0,
-    val format: String = "json_object",
-)
+    override val model: String = "gpt-4o-mini",
+    override val maxTokens: Int = 200,
+    override val temperature: Double = 0.0,
+    override val format: String = "json_object",
+    override val apiKey: String? = openAiApiKey,
+    override val baseUrl: String = openAiUrl,
+    override val provider: String = "openai",
+) : ModelClientProperties(
+        provider,
+        apiKey,
+        baseUrl,
+        model,
+        maxTokens,
+        temperature,
+        format,
+    )
+
+/**
+ * This interface represents a model response processor.
+ *
+ * Its objective is to process the response from the model and return agent-spec-compliant JSON.
+ */
+interface ModelClientResponseProcessor {
+    fun process(modelResponse: String): String
+}
+
+/**
+ * This class is the default implementation of the [ModelClientResponseProcessor] interface.
+ *
+ * The [process] method processes the response from the model.
+ *
+ * By default, it trims the response and removes ```json fences and <answer> tags. Refer to the default prompt for more information.
+ */
+class DefaultModelClientResponseProcessor : ModelClientResponseProcessor {
+    override fun process(modelResponse: String): String {
+        var response = modelResponse.trim()
+
+        if (response.contains("```json")) {
+            response = response.substringAfter("```json").substringBefore("```").trim()
+        }
+
+        // The tag name below assumes the default prompt wraps the JSON in <answer> tags;
+        // adjust it if your prompt uses a different envelope.
+        if (response.contains("<answer>")) {
+            response = response.substringAfter("<answer>").substringBefore("</answer>").trim()
+        }
+        return response
+    }
+}
diff --git a/lmos-router-llm/src/test/kotlin/SampleLLMFlow.kt b/lmos-router-llm/src/test/kotlin/SampleLLMFlow.kt
index 5f0cbf6..1971cbd 100644
--- a/lmos-router-llm/src/test/kotlin/SampleLLMFlow.kt
+++ b/lmos-router-llm/src/test/kotlin/SampleLLMFlow.kt
@@ -6,13 +6,16 @@ package ai.ancf.lmos.router.llm

 import ai.ancf.lmos.router.core.*
 import io.mockk.clearAllMocks
-import org.junit.jupiter.api.BeforeAll
 import org.junit.jupiter.api.BeforeEach
 import org.junit.jupiter.api.Test

 class SampleLLMFlow {
     @Test
     fun `test sample llm resolver flow`() {
+        require(System.getenv("OPENAI_API_KEY") != null) {
+            "Please set the OPENAI_API_KEY environment variable to run the tests"
+        }
+
         val agentRoutingSpecsProvider =
             JsonAgentRoutingSpecsProvider(jsonFilePath = "src/test/resources/agentRoutingSpecs.json")
         val agentRoutingSpecResolver =
@@ -40,6 +43,10 @@ class SampleLLMFlow {

     @Test
     fun `sample test with external prompt agent spec in xml format`() {
+        require(System.getenv("OPENAI_API_KEY") != null) {
+            "Please set the OPENAI_API_KEY environment variable to run the tests"
+        }
+
         val agentRoutingSpecsResolver =
             LLMAgentRoutingSpecsResolver(
                 JsonAgentRoutingSpecsProvider(jsonFilePath = "src/test/resources/agentRoutingSpecs.json"),
@@ -54,6 +61,31 @@ class SampleLLMFlow {

     @Test
     fun `sample test with external prompt agent spec in json format`() {
+        require(System.getenv("OPENAI_API_KEY") != null) {
+            "Please set the OPENAI_API_KEY environment variable to run the tests"
+        }
+
+        val routingSpecsResolver =
+            LLMAgentRoutingSpecsResolver(
+                JsonAgentRoutingSpecsProvider(jsonFilePath = "src/test/resources/agentRoutingSpecs.json"),
+                ExternalModelPromptProvider(
+                    "src/test/resources/prompt_agentRoutingSpec_json.txt",
+                    AgentRoutingSpecListType.JSON,
+                ),
+            )
+        val context = Context(listOf(AssistantMessage("Hello")))
+        val input = UserMessage("Can you help me find a new phone?")
+        val result = routingSpecsResolver.resolve(context, input)
+        assert(result is Success)
+        assert((result as Success).getOrNull()?.name == "offer-agent")
+    }
+
+    @Test
+    fun `sample test with external prompt agent spec in json format with gemini model`() {
+        require(System.getenv("GEMINI_API_KEY") != null) {
+            "Please set the GEMINI_API_KEY environment variable to run the tests"
+        }
+
         val routingSpecsResolver =
             LLMAgentRoutingSpecsResolver(
                 JsonAgentRoutingSpecsProvider(jsonFilePath = "src/test/resources/agentRoutingSpecs.json"),
@@ -61,6 +93,107 @@ class SampleLLMFlow {
                     "src/test/resources/prompt_agentRoutingSpec_json.txt",
                     AgentRoutingSpecListType.JSON,
                 ),
+                modelClient =
+                    LangChainModelClient(
+                        LangChainChatModelFactory.createClient(
+                            ModelClientProperties(
+                                provider = "gemini",
+                                apiKey = System.getenv("GEMINI_API_KEY"),
+                                model = "gemini-1.5-flash",
+                            ),
+                        ),
+                    ),
+            )
+        val context = Context(listOf(AssistantMessage("Hello")))
+        val input = UserMessage("Can you help me find a new phone?")
+        val result = routingSpecsResolver.resolve(context, input)
+        assert(result is Success)
+        assert((result as Success).getOrNull()?.name == "offer-agent")
+    }
+
+    @Test
+    fun `sample test with external prompt agent spec in json format with anthropic model`() {
+        require(System.getenv("ANTHROPIC_API_KEY") != null) {
+            "Please set the ANTHROPIC_API_KEY environment variable to run the tests"
+        }
+
+        val routingSpecsResolver =
+            LLMAgentRoutingSpecsResolver(
+                JsonAgentRoutingSpecsProvider(jsonFilePath = "src/test/resources/agentRoutingSpecs.json"),
+                ExternalModelPromptProvider(
+                    "src/test/resources/prompt_agentRoutingSpec_json.txt",
+                    AgentRoutingSpecListType.JSON,
+                ),
+                modelClient =
+                    LangChainModelClient(
+                        LangChainChatModelFactory.createClient(
+                            ModelClientProperties(
+                                provider = "anthropic",
+                                apiKey = System.getenv("ANTHROPIC_API_KEY"),
+                                baseUrl = "https://api.anthropic.com/v1",
+                                model = "claude-3-5-sonnet-20241022",
+                            ),
+                        ),
+                    ),
+            )
+        val context = Context(listOf(AssistantMessage("Hello")))
+        val input = UserMessage("Can you help me find a new phone?")
+        val result = routingSpecsResolver.resolve(context, input)
+        assert(result is Success)
+        assert((result as Success).getOrNull()?.name == "offer-agent")
+    }
+
+    @Test
+    fun `sample test with external prompt agent spec in json format with ollama model`() {
+        val routingSpecsResolver =
+            LLMAgentRoutingSpecsResolver(
+                JsonAgentRoutingSpecsProvider(jsonFilePath = "src/test/resources/agentRoutingSpecs.json"),
+                ExternalModelPromptProvider(
+                    "src/test/resources/prompt_agentRoutingSpec_json.txt",
+                    AgentRoutingSpecListType.JSON,
+                ),
+                modelClient =
+                    LangChainModelClient(
+                        LangChainChatModelFactory.createClient(
+                            ModelClientProperties(
+                                provider = "ollama",
+                                baseUrl = "http://localhost:11434",
+                                model = "qwen2.5-coder:7b",
+                            ),
+                        ),
+                    ),
+            )
+        val context = Context(listOf(AssistantMessage("Hello")))
+        val input = UserMessage("Can you help me find a new phone?")
+        val result = routingSpecsResolver.resolve(context, input)
+        assert(result is Success)
+        assert((result as Success).getOrNull()?.name == "offer-agent")
+    }
+
+    @Test
+    fun `sample test with external prompt agent spec in json format with other(groq) model`() {
+        require(System.getenv("GROQ_API_KEY") != null) {
+            "Please set the GROQ_API_KEY environment variable to run the tests"
+        }
+
+        val routingSpecsResolver =
+            LLMAgentRoutingSpecsResolver(
+                JsonAgentRoutingSpecsProvider(jsonFilePath = "src/test/resources/agentRoutingSpecs.json"),
+                ExternalModelPromptProvider(
+                    "src/test/resources/prompt_agentRoutingSpec_json.txt",
+                    AgentRoutingSpecListType.JSON,
+                ),
+                modelClient =
+                    LangChainModelClient(
+                        LangChainChatModelFactory.createClient(
+                            ModelClientProperties(
+                                provider = "other",
+                                baseUrl = "https://api.groq.com/openai/v1",
+                                model = "llama3-8b-8192",
+                                apiKey = System.getenv("GROQ_API_KEY"),
+                            ),
+                        ),
+                    ),
             )
         val context = Context(listOf(AssistantMessage("Hello")))
         val input = UserMessage("Can you help me find a new phone?")
find a new phone?") @@ -74,14 +207,4 @@ class SampleLLMFlow { // Clear all mock clearAllMocks() } - - companion object { - @JvmStatic - @BeforeAll - fun setup() { - require(System.getenv("OPENAI_API_KEY") != null) { - "Please set the OPENAI_API_KEY environment variable to run the tests" - } - } - } } diff --git a/lmos-router-llm/src/test/kotlin/ai/ancf/lmos/router/llm/DefaultModelClientTest.kt b/lmos-router-llm/src/test/kotlin/ai/ancf/lmos/router/llm/DefaultModelClientTest.kt index 0c7d716..bb58f26 100644 --- a/lmos-router-llm/src/test/kotlin/ai/ancf/lmos/router/llm/DefaultModelClientTest.kt +++ b/lmos-router-llm/src/test/kotlin/ai/ancf/lmos/router/llm/DefaultModelClientTest.kt @@ -5,24 +5,16 @@ package ai.ancf.lmos.router.llm import ai.ancf.lmos.router.core.* -import com.azure.ai.openai.OpenAIClient -import com.azure.ai.openai.OpenAIClientBuilder -import com.azure.ai.openai.models.ChatChoice -import com.azure.ai.openai.models.ChatCompletions -import com.azure.ai.openai.models.ChatResponseMessage -import com.azure.core.credential.AzureKeyCredential import io.mockk.every import io.mockk.mockk -import io.mockk.mockkConstructor -import io.mockk.verify import org.junit.jupiter.api.Assertions.assertEquals import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.Test class DefaultModelClientTest { private lateinit var defaultModelClientProperties: DefaultModelClientProperties - private lateinit var client: OpenAIClient private lateinit var defaultModelClient: DefaultModelClient + private lateinit var delegate: LangChainModelClient @BeforeEach fun setUp() { @@ -36,14 +28,8 @@ class DefaultModelClientTest { format = "json_object", ) - client = mockk() - mockkConstructor(OpenAIClientBuilder::class) - every { - anyConstructed().credential(any()).endpoint(any()) - .buildClient() - } returns client - - defaultModelClient = DefaultModelClient(defaultModelClientProperties) + delegate = mockk() + defaultModelClient = DefaultModelClient(defaultModelClientProperties, delegate) } @Test @@ -54,20 +40,10 @@ class DefaultModelClientTest { AssistantMessage("I am here to help you."), ) - val mockResponse = mockk() - val mockChoice = mockk() - val mockChatMessage = mockk() - every { mockResponse.choices } returns listOf(mockChoice) - every { mockChoice.message } returns mockChatMessage - every { mockChatMessage.content } returns "This is a response from the assistant." 
-
-        every { client.getChatCompletions(any(), any()) } returns mockResponse
-
+        every { delegate.call(any()) } returns Success(AssistantMessage("This is a response from the assistant."))
         val result = defaultModelClient.call(messages)

         assertEquals(result, Success(AssistantMessage("This is a response from the assistant.")))
-
-        verify { client.getChatCompletions(any(), any()) }
     }

     @Test
@@ -77,11 +53,10 @@
             UserMessage("Hello, how can I assist you today?"),
         )

-        every { client.getChatCompletions(any(), any()) } throws RuntimeException("API failure")
-
+        every { delegate.call(any()) } returns Failure(AgentRoutingSpecResolverException("Failed to call language model"))
         val result = defaultModelClient.call(messages)

         assert(result is Failure)
-        assert((result as Failure).exceptionOrNull()?.message == "Failed to call model")
+        assert((result as Failure).exceptionOrNull()?.message == "Failed to call language model")
     }
 }
diff --git a/lmos-router-llm/src/test/kotlin/ai/ancf/lmos/router/llm/LangChainChatModelFactoryTest.kt b/lmos-router-llm/src/test/kotlin/ai/ancf/lmos/router/llm/LangChainChatModelFactoryTest.kt
new file mode 100644
index 0000000..9a2cac1
--- /dev/null
+++ b/lmos-router-llm/src/test/kotlin/ai/ancf/lmos/router/llm/LangChainChatModelFactoryTest.kt
@@ -0,0 +1,172 @@
+// SPDX-FileCopyrightText: 2024 Deutsche Telekom AG
+//
+// SPDX-License-Identifier: Apache-2.0
+
+package ai.ancf.lmos.router.llm.ai.ancf.lmos.router.llm
+
+import ai.ancf.lmos.router.llm.LangChainChatModelFactory
+import ai.ancf.lmos.router.llm.LangChainClientProvider
+import ai.ancf.lmos.router.llm.ModelClientProperties
+import dev.langchain4j.model.anthropic.AnthropicChatModel
+import dev.langchain4j.model.googleai.GoogleAiGeminiChatModel
+import dev.langchain4j.model.ollama.OllamaChatModel
+import dev.langchain4j.model.openai.OpenAiChatModel
+import io.mockk.every
+import io.mockk.mockk
+import org.junit.jupiter.api.Assertions.*
+import org.junit.jupiter.api.Test
+import org.junit.jupiter.api.assertThrows
+
+/**
+ * Tests for [LangChainChatModelFactory].
+ */
+class LangChainChatModelFactoryTest {
+    @Test
+    fun `createClient should return OpenAiChatModel for OPENAI provider`() {
+        // Arrange
+        val properties = mockk<ModelClientProperties>()
+        every { properties.provider } returns LangChainClientProvider.OPENAI.name.lowercase()
+        every { properties.baseUrl } returns "https://api.openai.com"
+        every { properties.apiKey } returns "openai-api-key"
+        every { properties.model } returns "gpt-4o-mini"
+        every { properties.maxTokens } returns 1000
+        every { properties.temperature } returns 0.7
+        every { properties.topP } returns 0.9
+        every { properties.topK } returns 50
+        every { properties.format } returns "json_object"
+
+        // Act
+        val client = LangChainChatModelFactory.createClient(properties)
+
+        // Assert
+        assertTrue(client is OpenAiChatModel)
+    }
+
+    @Test
+    fun `createClient should return AnthropicChatModel for ANTHROPIC provider`() {
+        // Arrange
+        val properties = mockk<ModelClientProperties>()
+        every { properties.provider } returns LangChainClientProvider.ANTHROPIC.name.lowercase()
+        every { properties.baseUrl } returns "https://api.anthropic.com"
+        every { properties.apiKey } returns "anthropic-api-key"
+        every { properties.model } returns "claude-v1"
+        every { properties.maxTokens } returns 1500
+        every { properties.temperature } returns 0.6
+        every { properties.topP } returns 0.8
+        every { properties.topK } returns 40
+        every { properties.format } returns "text"
+
+        // Act
+        val client = LangChainChatModelFactory.createClient(properties)
+
+        // Assert
+        assertTrue(client is AnthropicChatModel)
+    }
+
+    @Test
+    fun `createClient should return GoogleAiGeminiChatModel for GEMINI provider`() {
+        // Arrange
+        val properties = mockk<ModelClientProperties>()
+        every { properties.provider } returns LangChainClientProvider.GEMINI.name.lowercase()
+        every { properties.baseUrl } returns "" // Gemini model does not use baseUrl
+        every { properties.apiKey } returns "gemini-api-key"
+        every { properties.model } returns "gemini-1"
+        every { properties.maxTokens } returns 2000
+        every { properties.temperature } returns 0.5
+        every { properties.topP } returns 0.85
+        every { properties.topK } returns 30
+        every { properties.format } returns "JSON"
+
+        // Act
+        val client = LangChainChatModelFactory.createClient(properties)
+
+        // Assert
+        assertTrue(client is GoogleAiGeminiChatModel)
+    }
+
+    @Test
+    fun `createClient should return OllamaChatModel for OLLAMA provider`() {
+        // Arrange
+        val properties = mockk<ModelClientProperties>()
+        every { properties.provider } returns LangChainClientProvider.OLLAMA.name.lowercase()
+        every { properties.baseUrl } returns "http://localhost:11434"
+        every { properties.apiKey } returns "" // Ollama model does not require an API key
+        every { properties.model } returns "ollama-model"
+        every { properties.maxTokens } returns 0 // Ollama model does not use maxTokens
+        every { properties.temperature } returns 0.4
+        every { properties.topP } returns 0.0 // Ollama model does not use topP
+        every { properties.topK } returns 0 // Ollama model does not use topK
+        every { properties.format } returns "" // Ollama model does not use format
+
+        // Act
+        val client = LangChainChatModelFactory.createClient(properties)
+
+        // Assert
+        assertTrue(client is OllamaChatModel)
+    }
+
+    @Test
+    fun `createClient should return OpenAiChatModel for OTHER provider`() {
+        // Arrange
+        val properties = mockk<ModelClientProperties>()
+        every { properties.provider } returns LangChainClientProvider.OTHER.name.lowercase()
+        every { properties.baseUrl } returns "https://api.other.com"
+        every { properties.apiKey } returns "other-api-key"
+        every { properties.model } returns "other-model"
returns "other-model" + every { properties.maxTokens } returns 800 + every { properties.temperature } returns 0.8 + every { properties.topP } returns 0.95 + every { properties.topK } returns 60 + every { properties.format } returns "json_object" + + // Act + val client = LangChainChatModelFactory.createClient(properties) + + // Assert + assertTrue(client is OpenAiChatModel) + } + + @Test + fun `createClient should throw IllegalArgumentException for unknown provider`() { + // Arrange + val properties = mockk() + every { properties.provider } returns "unknown_provider" + every { properties.baseUrl } returns "https://api.unknown.com" + every { properties.apiKey } returns "unknown-api-key" + every { properties.model } returns "unknown-model" + every { properties.maxTokens } returns 500 + every { properties.temperature } returns 0.9 + every { properties.topP } returns 0.99 + every { properties.topK } returns 70 + every { properties.format } returns "plain" + + // Act & Assert + val exception = + assertThrows { + LangChainChatModelFactory.createClient(properties) + } + assertEquals("Unknown model client properties: $properties", exception.message) + } + + @Test + fun `createClient should throw IllegalArgumentException when provider is empty`() { + // Arrange + val properties = mockk() + every { properties.provider } returns "" + every { properties.baseUrl } returns "https://api.null.com" + every { properties.apiKey } returns "null-api-key" + every { properties.model } returns "null-model" + every { properties.maxTokens } returns 500 + every { properties.temperature } returns 0.9 + every { properties.topP } returns 0.99 + every { properties.topK } returns 70 + every { properties.format } returns "plain" + + // Act & Assert + val exception = + assertThrows { + LangChainChatModelFactory.createClient(properties) + } + assertEquals("Unknown model client properties: $properties", exception.message) + } +} diff --git a/lmos-router-llm/src/test/kotlin/ai/ancf/lmos/router/llm/LangChainModelClientTest.kt b/lmos-router-llm/src/test/kotlin/ai/ancf/lmos/router/llm/LangChainModelClientTest.kt new file mode 100644 index 0000000..a6de987 --- /dev/null +++ b/lmos-router-llm/src/test/kotlin/ai/ancf/lmos/router/llm/LangChainModelClientTest.kt @@ -0,0 +1,167 @@ +// SPDX-FileCopyrightText: 2024 Deutsche Telekom AG +// +// SPDX-License-Identifier: Apache-2.0 + +package ai.ancf.lmos.router.llm.ai.ancf.lmos.router.llm + +import ai.ancf.lmos.router.core.* +import ai.ancf.lmos.router.llm.LangChainModelClient +import dev.langchain4j.model.chat.ChatLanguageModel +import io.mockk.every +import io.mockk.mockk +import io.mockk.verify +import org.junit.jupiter.api.Assertions.* +import org.junit.jupiter.api.Test +import dev.langchain4j.data.message.SystemMessage as LangchainSystemMessage +import dev.langchain4j.data.message.UserMessage as LangchainUserMessage + +/** + * Tests for [LangChainModelClient]. + */ +class LangChainModelClientTest { + private val mockChatLanguageModel = mockk() + private val modelClient = LangChainModelClient(mockChatLanguageModel) + + @Test + fun `call should return Success with AssistantMessage when model generates response successfully`() { + // Arrange + val userMessage = UserMessage("Hello") + val assistantResponse = "Hi there!" 
+        val mockResponse = mockk<Response<AiMessage>>()
+
+        every { mockResponse.content().text() } returns assistantResponse
+        every {
+            mockChatLanguageModel.generate(
+                listOf(LangchainUserMessage("Hello")),
+            )
+        } returns mockResponse
+
+        // Act
+        val result = modelClient.call(listOf(userMessage))
+
+        // Assert
+        assertTrue(result is Success)
+        val success = result as Success
+        assertEquals(AssistantMessage(assistantResponse), success.value)
+        verify(exactly = 1) {
+            mockChatLanguageModel.generate(listOf(LangchainUserMessage("Hello")))
+        }
+    }
+
+    @Test
+    fun `call should return Failure when model generate throws exception`() {
+        // Arrange
+        val userMessage = UserMessage("Hello")
+        val exception = RuntimeException("Model error")
+
+        every {
+            mockChatLanguageModel.generate(
+                listOf(LangchainUserMessage("Hello")),
+            )
+        } throws exception
+
+        // Act
+        val result = modelClient.call(listOf(userMessage))
+
+        // Assert
+        assertTrue(result is Failure)
+        val failure = result as Failure
+        assertEquals("Failed to call language model", failure.reason.message)
+        assertEquals(exception, failure.reason.cause)
+        verify(exactly = 1) {
+            mockChatLanguageModel.generate(listOf(LangchainUserMessage("Hello")))
+        }
+    }
+
+    @Test
+    fun `call should handle AssistantMessage and SystemMessage correctly`() {
+        // Arrange
+        val assistantMsg = AssistantMessage("I can help you with that.")
+        val systemMsg = SystemMessage("System initialized.")
+        val mockResponse = mockk<Response<AiMessage>>()
+
+        every { mockResponse.content().text() } returns "Sure, let's proceed."
+        every {
+            mockChatLanguageModel.generate(
+                listOf(
+                    LangchainUserMessage("Hello"),
+                    LangchainSystemMessage("System initialized."),
+                    dev.langchain4j.data.message.AiMessage("I can help you with that."),
+                ),
+            )
+        } returns mockResponse
+
+        val messages =
+            listOf(
+                UserMessage("Hello"),
+                systemMsg,
+                assistantMsg,
+            )
+
+        // Act
+        val result = modelClient.call(messages)
+
+        // Assert
+        assertTrue(result is Success)
+        val success = result as Success
+        assertEquals(AssistantMessage("Sure, let's proceed."), success.value)
+        verify(exactly = 1) {
+            mockChatLanguageModel.generate(
+                listOf(
+                    LangchainUserMessage("Hello"),
+                    LangchainSystemMessage("System initialized."),
+                    dev.langchain4j.data.message.AiMessage("I can help you with that."),
+                ),
+            )
+        }
+    }
+
+    @Test
+    fun `call should handle empty message list`() {
+        // Arrange
+        val messages = emptyList<ChatMessage>()
+
+        // Assuming that the language model can handle empty input and returns a valid response
+        val assistantResponse = "Hello! How can I assist you today?"
+        val mockResponse = mockk<Response<AiMessage>>()
+
+        every { mockResponse.content().text() } returns assistantResponse
+        every { mockChatLanguageModel.generate(emptyList()) } returns mockResponse
+
+        // Act
+        val result = modelClient.call(messages)
+
+        // Assert
+        assertTrue(result is Success)
+        val success = result as Success
+        assertEquals(AssistantMessage(assistantResponse), success.value)
+        verify(exactly = 1) {
+            mockChatLanguageModel.generate(emptyList())
+        }
+    }
+
+    @Test
+    fun `call should propagate AgentRoutingSpecResolverException when model generate throws it`() {
+        // Arrange
+        val userMessage = UserMessage("Hello")
+        val exception = AgentRoutingSpecResolverException("Custom exception")
+
+        every {
+            mockChatLanguageModel.generate(
+                listOf(LangchainUserMessage("Hello")),
+            )
+        } throws exception
+
+        // Act
+        val result = modelClient.call(listOf(userMessage))
+
+        // Assert
+        assertTrue(result is Failure)
+        val failure = result as Failure
+        assertEquals("Failed to call language model", failure.reason.message)
+        assertEquals(exception, failure.reason.cause)
+        verify(exactly = 1) {
+            mockChatLanguageModel.generate(listOf(LangchainUserMessage("Hello")))
+        }
+    }
+}