Skip to content

Commit

Permalink
feat: LangChain4j support
Browse files Browse the repository at this point in the history
  • Loading branch information
xmxnt committed Dec 17, 2024
1 parent ee0c95f commit c9e04e1
Show file tree
Hide file tree
Showing 6 changed files with 527 additions and 79 deletions.
10 changes: 8 additions & 2 deletions lmos-router-llm/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,13 @@
dependencies {
    api(project(":lmos-router-core"))
    implementation("org.slf4j:slf4j-api:1.7.25")
    // NOTE(review): this commit migrates to LangChain4j; confirm whether the Azure
    // SDK dependencies below are still required or were slated for removal.
    implementation("com.azure:azure-ai-openai:1.0.0-beta.10")
    implementation("com.azure:azure-identity:1.14.0")
    implementation("org.jetbrains.kotlinx:kotlinx-serialization-json-jvm:1.7.3")
    // LangChain4j OpenAI support is always available on the runtime classpath.
    implementation("dev.langchain4j:langchain4j-open-ai:0.36.2")
    // Optional providers are compileOnly: consumers opt in by adding the
    // matching dependency to their own build.
    compileOnly("dev.langchain4j:langchain4j-anthropic:0.36.2")
    compileOnly("dev.langchain4j:langchain4j-google-ai-gemini:0.36.2")
    compileOnly("dev.langchain4j:langchain4j-ollama:0.36.2")

    // The optional providers must still be present at test runtime.
    testImplementation("dev.langchain4j:langchain4j-anthropic:0.36.2")
    testImplementation("dev.langchain4j:langchain4j-google-ai-gemini:0.36.2")
    testImplementation("dev.langchain4j:langchain4j-ollama:0.36.2")
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
// SPDX-FileCopyrightText: 2024 Deutsche Telekom AG
//
// SPDX-License-Identifier: Apache-2.0

package ai.ancf.lmos.router.llm

import ai.ancf.lmos.router.core.*
import dev.langchain4j.data.message.AiMessage
import dev.langchain4j.model.anthropic.AnthropicChatModel
import dev.langchain4j.model.chat.ChatLanguageModel
import dev.langchain4j.model.chat.request.ResponseFormat
import dev.langchain4j.model.chat.request.ResponseFormatType
import dev.langchain4j.model.googleai.GoogleAiGeminiChatModel
import dev.langchain4j.model.ollama.OllamaChatModel
import dev.langchain4j.model.openai.OpenAiChatModel.OpenAiChatModelBuilder

/**
 * A model client backed by LangChain4j (https://docs.langchain4j.dev/).
 *
 * Router messages are translated into their LangChain4j equivalents, sent to the
 * configured [ChatLanguageModel], and the model's reply is wrapped in an
 * [AssistantMessage]. Any exception raised during the call is converted into a
 * [Failure] carrying an [AgentRoutingSpecResolverException].
 *
 * @param chatLanguageModel The LangChain4j chat model used to generate responses.
 */
class LangChainModelClient(
    private val chatLanguageModel: ChatLanguageModel,
) : ModelClient {
    override fun call(messages: List<ChatMessage>): Result<ChatMessage, AgentRoutingSpecResolverException> {
        return try {
            val langChainMessages = messages.map(::toLangChainMessage)
            val response = chatLanguageModel.generate(langChainMessages)
            Success(AssistantMessage(response.content().text()))
        } catch (e: Exception) {
            Failure(AgentRoutingSpecResolverException("Failed to call language model", e))
        }
    }

    // Maps the router's message types onto their LangChain4j counterparts.
    // An unrecognized subtype is surfaced as a Failure via the catch above.
    private fun toLangChainMessage(message: ChatMessage): dev.langchain4j.data.message.ChatMessage =
        when (message) {
            is UserMessage -> dev.langchain4j.data.message.UserMessage(message.content)
            is AssistantMessage -> AiMessage(message.content)
            is SystemMessage -> dev.langchain4j.data.message.SystemMessage(message.content)
            else -> throw AgentRoutingSpecResolverException("Unknown message type")
        }
}

/**
 * A factory that creates a LangChain4j [ChatLanguageModel] from the given
 * [ModelClientProperties].
 *
 * The factory supports the following providers:
 * - OPENAI
 * - ANTHROPIC
 * - GEMINI
 * - OLLAMA
 * - GROQ
 * - OTHER (for other providers that expose an OpenAI-compatible API)
 */
class LangChainChatModelFactory private constructor() {
    companion object {
        /**
         * Creates a provider-specific chat model configured from [properties].
         *
         * @param properties Connection and sampling settings (url, apiKey, model,
         *   maxTokens, temperature, format, topK, topP).
         * @return A configured LangChain4j [ChatLanguageModel].
         * @throws IllegalArgumentException If the provider is not one of the supported values.
         */
        fun createClient(properties: ModelClientProperties): ChatLanguageModel {
            return when (properties.provider) {
                // GROQ and OTHER expose OpenAI-compatible APIs, so all three share the
                // OpenAI client and differ only in the configured base URL.
                LangChainClientProvider.OPENAI.name.lowercase(),
                LangChainClientProvider.GROQ.name.lowercase(),
                LangChainClientProvider.OTHER.name.lowercase(),
                -> {
                    OpenAiChatModelBuilder().baseUrl(properties.url)
                        .apiKey(properties.apiKey)
                        .modelName(properties.model)
                        .maxTokens(properties.maxTokens)
                        .temperature(properties.temperature)
                        .responseFormat(properties.format)
                        .topP(properties.topP)
                        .build()
                }

                LangChainClientProvider.ANTHROPIC.name.lowercase() -> {
                    AnthropicChatModel.builder().baseUrl(properties.url)
                        .apiKey(properties.apiKey)
                        .modelName(properties.model)
                        .maxTokens(properties.maxTokens)
                        .temperature(properties.temperature)
                        .topP(properties.topP)
                        .topK(properties.topK)
                        .build()
                }

                LangChainClientProvider.GEMINI.name.lowercase() -> {
                    // Gemini takes a structured ResponseFormat rather than a raw string;
                    // properties.format must therefore name a valid ResponseFormatType
                    // constant (e.g. "JSON"), or valueOf throws.
                    GoogleAiGeminiChatModel.builder()
                        .apiKey(properties.apiKey)
                        .modelName(properties.model)
                        .maxOutputTokens(properties.maxTokens)
                        .temperature(properties.temperature)
                        .topK(properties.topK)
                        .topP(properties.topP)
                        .responseFormat(ResponseFormat.builder().type(ResponseFormatType.valueOf(properties.format)).build())
                        .build()
                }

                LangChainClientProvider.OLLAMA.name.lowercase() -> {
                    // NOTE(review): only url/model/temperature are forwarded here;
                    // maxTokens, topK, topP and format are ignored for Ollama —
                    // confirm this is intentional.
                    OllamaChatModel.builder().baseUrl(properties.url)
                        .modelName(properties.model)
                        .temperature(properties.temperature)
                        .build()
                }

                else -> {
                    throw IllegalArgumentException("Unknown model client properties: $properties")
                }
            }
        }
    }
}

/**
 * Supported LangChain4j chat model providers.
 *
 * Values are matched case-insensitively (via `name.lowercase()`) against
 * [ModelClientProperties.provider] in [LangChainChatModelFactory.createClient].
 * GROQ and OTHER are served through the OpenAI-compatible client.
 */
enum class LangChainClientProvider {
    OPENAI,
    ANTHROPIC,
    GEMINI,
    OLLAMA,
    GROQ,
    OTHER,
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,6 @@
package ai.ancf.lmos.router.llm

import ai.ancf.lmos.router.core.*
import com.azure.ai.openai.OpenAIClientBuilder
import com.azure.ai.openai.models.*
import com.azure.core.credential.AzureKeyCredential

/**
* The [ModelClient] interface represents a client that can call a model.
Expand All @@ -29,48 +26,28 @@ interface ModelClient {
*
* @param defaultModelClientProperties The properties for the default model client.
*/
/**
 * The default [ModelClient] implementation.
 *
 * All calls are delegated to a [LangChainModelClient] built from the given
 * [DefaultModelClientProperties] (an OpenAI-compatible configuration by default).
 *
 * Note: this span of the scraped diff interleaved the removed Azure-SDK
 * implementation with the new delegating one; this is the post-commit form.
 *
 * @param defaultModelClientProperties The properties for the default model client.
 * @param delegate The underlying LangChain4j-backed client; injectable for testing.
 */
class DefaultModelClient(
    private val defaultModelClientProperties: DefaultModelClientProperties,
    private val delegate: LangChainModelClient =
        LangChainModelClient(LangChainChatModelFactory.createClient(defaultModelClientProperties)),
) : ModelClient {
    override fun call(messages: List<ChatMessage>): Result<ChatMessage, AgentRoutingSpecResolverException> =
        delegate.call(messages)
}

/**
 * Base configuration shared by all LangChain4j model clients.
 *
 * Consumed by [LangChainChatModelFactory.createClient]; not every provider uses
 * every field (e.g. the Ollama branch reads only url/model/temperature).
 *
 * @param provider Provider identifier, matched against lowercase [LangChainClientProvider] names.
 * @param apiKey API key for the provider; nullable — presumably for providers
 *   that need no key (e.g. local Ollama) — TODO confirm.
 * @param url Base URL of the provider endpoint.
 * @param model Model name to request.
 * @param maxTokens Maximum number of tokens to generate.
 * @param temperature Sampling temperature.
 * @param format Response format; a raw string for OpenAI-compatible clients, a
 *   [ResponseFormatType] constant name for Gemini.
 * @param topK Top-K sampling parameter (default 0).
 * @param topP Top-P (nucleus) sampling parameter (default 0.0).
 */
abstract class ModelClientProperties(
    open val provider: String,
    open val apiKey: String?,
    open val url: String,
    open val model: String,
    open val maxTokens: Int,
    open val temperature: Double,
    open val format: String,
    open val topK: Int = 0,
    open val topP: Double = 0.0,
)

/**
* The [DefaultModelClientProperties] data class represents the properties for the default model client.
*
Expand All @@ -84,8 +61,19 @@ class DefaultModelClient(private val defaultModelClientProperties: DefaultModelC
data class DefaultModelClientProperties(
    // OpenAI-specific convenience fields; kept for backward compatibility.
    val openAiUrl: String = "https://api.openai.com/v1/chat/completions",
    val openAiApiKey: String,
    override val model: String = "gpt-4o-mini",
    override val maxTokens: Int = 200,
    override val temperature: Double = 0.0,
    override val format: String = "json_object",
    // Mirror the OpenAI-specific fields into the generic base-class slots so
    // LangChainChatModelFactory can consume them uniformly.
    override val apiKey: String? = openAiApiKey,
    override val url: String = openAiUrl,
    override val provider: String = "openai",
) : ModelClientProperties(
    provider,
    apiKey,
    url,
    model,
    maxTokens,
    temperature,
    format,
)
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,16 @@
package ai.ancf.lmos.router.llm

import ai.ancf.lmos.router.core.*
import com.azure.ai.openai.OpenAIClient
import com.azure.ai.openai.OpenAIClientBuilder
import com.azure.ai.openai.models.ChatChoice
import com.azure.ai.openai.models.ChatCompletions
import com.azure.ai.openai.models.ChatResponseMessage
import com.azure.core.credential.AzureKeyCredential
import io.mockk.every
import io.mockk.mockk
import io.mockk.mockkConstructor
import io.mockk.verify
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import kotlin.test.assertEquals

class DefaultModelClientTest {
private lateinit var defaultModelClientProperties: DefaultModelClientProperties
private lateinit var client: OpenAIClient
private lateinit var defaultModelClient: DefaultModelClient
private lateinit var delegate: LangChainModelClient

@BeforeEach
fun setUp() {
Expand All @@ -36,14 +28,8 @@ class DefaultModelClientTest {
format = "json_object",
)

client = mockk()
mockkConstructor(OpenAIClientBuilder::class)
every {
anyConstructed<OpenAIClientBuilder>().credential(any<AzureKeyCredential>()).endpoint(any<String>())
.buildClient()
} returns client

defaultModelClient = DefaultModelClient(defaultModelClientProperties)
delegate = mockk()
defaultModelClient = DefaultModelClient(defaultModelClientProperties, delegate)
}

@Test
Expand All @@ -54,20 +40,10 @@ class DefaultModelClientTest {
AssistantMessage("I am here to help you."),
)

val mockResponse = mockk<ChatCompletions>()
val mockChoice = mockk<ChatChoice>()
val mockChatMessage = mockk<ChatResponseMessage>()
every { mockResponse.choices } returns listOf(mockChoice)
every { mockChoice.message } returns mockChatMessage
every { mockChatMessage.content } returns "This is a response from the assistant."

every { client.getChatCompletions(any(), any()) } returns mockResponse

every { delegate.call(any()) } returns Success(AssistantMessage("This is a response from the assistant."))
val result = defaultModelClient.call(messages)

assertEquals(result, Success(AssistantMessage("This is a response from the assistant.")))

verify { client.getChatCompletions(any(), any()) }
}

@Test
Expand All @@ -77,11 +53,10 @@ class DefaultModelClientTest {
UserMessage("Hello, how can I assist you today?"),
)

every { client.getChatCompletions(any(), any()) } throws RuntimeException("API failure")

every { delegate.call(any()) } returns Failure(AgentRoutingSpecResolverException("Failed to call language model"))
val result = defaultModelClient.call(messages)

assert(result is Failure)
assert((result as Failure).exceptionOrNull()?.message == "Failed to call model")
assert((result as Failure).exceptionOrNull()?.message == "Failed to call language model")
}
}
Loading

0 comments on commit c9e04e1

Please sign in to comment.