diff --git a/docs/changelog/118871.yaml b/docs/changelog/118871.yaml new file mode 100644 index 0000000000000..3c1a06d450f39 --- /dev/null +++ b/docs/changelog/118871.yaml @@ -0,0 +1,5 @@ +pr: 118871 +summary: "[Elastic Inference Service] Add ElasticInferenceService Unified ChatCompletions Integration" +area: Inference +type: enhancement +issues: [] diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index da0c91861596d..c5bb47ce1e4f7 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -154,6 +154,7 @@ static TransportVersion def(int id) { public static final TransportVersion JINA_AI_INTEGRATION_ADDED = def(8_819_00_0); public static final TransportVersion TRACK_INDEX_FAILED_DUE_TO_VERSION_CONFLICT_METRIC = def(8_820_00_0); public static final TransportVersion REPLACE_FAILURE_STORE_OPTIONS_WITH_SELECTOR_SYNTAX = def(8_821_00_0); + public static final TransportVersion ELASTIC_INFERENCE_SERVICE_UNIFIED_CHAT_COMPLETIONS_INTEGRATION = def(8_822_00_0); /* * STOP! READ THIS FIRST! No, really, diff --git a/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java b/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java index 32ed68953041a..e2f47f1a7a343 100644 --- a/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java +++ b/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java @@ -111,9 +111,13 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalFloat(topP); } - public record Message(Content content, String role, @Nullable String name, @Nullable String toolCallId, List toolCalls) - implements - Writeable { + public record Message( + Content content, + String role, + @Nullable String name, + @Nullable String toolCallId, + @Nullable List toolCalls + ) implements Writeable { @SuppressWarnings("unchecked") static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java index 58d870ceed6f2..6634eecc2c959 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java @@ -242,7 +242,7 @@ public void testGetServicesWithRerankTaskType() throws IOException { @SuppressWarnings("unchecked") public void testGetServicesWithCompletionTaskType() throws IOException { List services = getServices(TaskType.COMPLETION); - assertThat(services.size(), equalTo(9)); + assertThat(services.size(), equalTo(10)); String[] providers = new String[services.size()]; for (int i = 0; i < services.size(); i++) { @@ -259,6 +259,7 @@ public void testGetServicesWithCompletionTaskType() throws IOException { "azureaistudio", "azureopenai", "cohere", + "elastic", "googleaistudio", "openai", "streaming_completion_test_service" diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/elastic/ElasticInferenceServiceUnifiedChatCompletionResponseHandler.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/elastic/ElasticInferenceServiceUnifiedChatCompletionResponseHandler.java new file mode 100644 index 0000000000000..c0bccb9b2cd49 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/elastic/ElasticInferenceServiceUnifiedChatCompletionResponseHandler.java @@ -0,0 +1,40 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.elastic; + +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; +import org.elasticsearch.xpack.inference.external.openai.OpenAiUnifiedStreamingProcessor; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventParser; +import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventProcessor; + +import java.util.concurrent.Flow; + +public class ElasticInferenceServiceUnifiedChatCompletionResponseHandler extends ElasticInferenceServiceResponseHandler { + public ElasticInferenceServiceUnifiedChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) { + super(requestType, parseFunction); + } + + @Override + public boolean canHandleStreamingResponses() { + return true; + } + + @Override + public InferenceServiceResults parseResult(Request request, Flow.Publisher flow) { + var serverSentEventProcessor = new ServerSentEventProcessor(new ServerSentEventParser()); + var openAiProcessor = new OpenAiUnifiedStreamingProcessor(); // EIS uses the unified API spec + + flow.subscribe(serverSentEventProcessor); + serverSentEventProcessor.subscribe(openAiProcessor); + return new StreamingUnifiedChatCompletionResults(openAiProcessor); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ElasticInferenceServiceUnifiedCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ElasticInferenceServiceUnifiedCompletionRequestManager.java new file mode 100644 index 0000000000000..66314db1e05bd --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ElasticInferenceServiceUnifiedCompletionRequestManager.java @@ -0,0 +1,82 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.external.elastic.ElasticInferenceServiceUnifiedChatCompletionResponseHandler; +import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.request.elastic.ElasticInferenceServiceUnifiedChatCompletionRequest; +import org.elasticsearch.xpack.inference.external.response.openai.OpenAiChatCompletionResponseEntity; +import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel; +import org.elasticsearch.xpack.inference.telemetry.TraceContext; + +import java.util.Objects; +import java.util.function.Supplier; + +public class ElasticInferenceServiceUnifiedCompletionRequestManager extends ElasticInferenceServiceRequestManager { + + private static final Logger logger = LogManager.getLogger(ElasticInferenceServiceUnifiedCompletionRequestManager.class); + + private static final ResponseHandler HANDLER = createCompletionHandler(); + + public static ElasticInferenceServiceUnifiedCompletionRequestManager of( + ElasticInferenceServiceCompletionModel model, + ThreadPool threadPool, + TraceContext traceContext + ) { + return new ElasticInferenceServiceUnifiedCompletionRequestManager( + Objects.requireNonNull(model), + Objects.requireNonNull(threadPool), + Objects.requireNonNull(traceContext) + ); + } + + private final ElasticInferenceServiceCompletionModel model; + private final TraceContext traceContext; + + private ElasticInferenceServiceUnifiedCompletionRequestManager( + ElasticInferenceServiceCompletionModel model, + ThreadPool threadPool, + TraceContext traceContext + ) { + super(threadPool, model); + this.model = model; + this.traceContext = traceContext; + } + + @Override + public void execute( + InferenceInputs inferenceInputs, + RequestSender requestSender, + Supplier hasRequestCompletedFunction, + ActionListener listener + ) { + + ElasticInferenceServiceUnifiedChatCompletionRequest request = new ElasticInferenceServiceUnifiedChatCompletionRequest( + inferenceInputs.castTo(UnifiedChatInput.class), + model, + traceContext + ); + + execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); + } + + private static ResponseHandler createCompletionHandler() { + return new ElasticInferenceServiceUnifiedChatCompletionResponseHandler( + "elastic inference service completion", + // We use OpenAiChatCompletionResponseEntity here as the ElasticInferenceServiceResponseEntity fields are a subset of the OpenAI + // one. 
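+            // Streamed chunks are parsed separately by OpenAiUnifiedStreamingProcessor (see ElasticInferenceServiceUnifiedChatCompletionResponseHandler).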
+ OpenAiChatCompletionResponseEntity::fromResponse + ); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceSparseEmbeddingsRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceSparseEmbeddingsRequest.java index d445a779f8230..d1aaa6d5f984f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceSparseEmbeddingsRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceSparseEmbeddingsRequest.java @@ -12,13 +12,13 @@ import org.apache.http.entity.ByteArrayEntity; import org.apache.http.message.BasicHeader; import org.elasticsearch.common.Strings; -import org.elasticsearch.tasks.Task; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.inference.common.Truncator; import org.elasticsearch.xpack.inference.external.request.HttpRequest; import org.elasticsearch.xpack.inference.external.request.Request; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsModel; import org.elasticsearch.xpack.inference.telemetry.TraceContext; +import org.elasticsearch.xpack.inference.telemetry.TraceContextHandler; import java.net.URI; import java.nio.charset.StandardCharsets; @@ -27,13 +27,10 @@ public class ElasticInferenceServiceSparseEmbeddingsRequest implements ElasticInferenceServiceRequest { private final URI uri; - private final ElasticInferenceServiceSparseEmbeddingsModel model; - private final Truncator.TruncationResult truncationResult; private final Truncator truncator; - - private final TraceContext traceContext; + private final TraceContextHandler traceContextHandler; public ElasticInferenceServiceSparseEmbeddingsRequest( Truncator truncator, @@ -45,7 +42,7 @@ public ElasticInferenceServiceSparseEmbeddingsRequest( this.truncationResult = truncationResult; this.model = Objects.requireNonNull(model); this.uri = model.uri(); - this.traceContext = traceContext; + this.traceContextHandler = new TraceContextHandler(traceContext); } @Override @@ -56,15 +53,16 @@ public HttpRequest createHttpRequest() { ByteArrayEntity byteEntity = new ByteArrayEntity(requestEntity.getBytes(StandardCharsets.UTF_8)); httpPost.setEntity(byteEntity); - if (traceContext != null) { - propagateTraceContext(httpPost); - } - + traceContextHandler.propagateTraceContext(httpPost); httpPost.setHeader(new BasicHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType())); return new HttpRequest(httpPost, getInferenceEntityId()); } + public TraceContext getTraceContext() { + return traceContextHandler.traceContext(); + } + @Override public String getInferenceEntityId() { return model.getInferenceEntityId(); @@ -75,15 +73,10 @@ public URI getURI() { return this.uri; } - public TraceContext getTraceContext() { - return traceContext; - } - @Override public Request truncate() { var truncatedInput = truncator.truncate(truncationResult.input()); - - return new ElasticInferenceServiceSparseEmbeddingsRequest(truncator, truncatedInput, model, traceContext); + return new ElasticInferenceServiceSparseEmbeddingsRequest(truncator, truncatedInput, model, traceContextHandler.traceContext()); } @Override @@ -91,16 +84,4 @@ public boolean[] getTruncationInfo() { return truncationResult.truncated().clone(); } - private void propagateTraceContext(HttpPost 
httpPost) { - var traceParent = traceContext.traceParent(); - var traceState = traceContext.traceState(); - - if (traceParent != null) { - httpPost.setHeader(Task.TRACE_PARENT_HTTP_HEADER, traceParent); - } - - if (traceState != null) { - httpPost.setHeader(Task.TRACE_STATE, traceState); - } - } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceUnifiedChatCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceUnifiedChatCompletionRequest.java new file mode 100644 index 0000000000000..112ead7057933 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceUnifiedChatCompletionRequest.java @@ -0,0 +1,85 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.elastic; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.entity.ByteArrayEntity; +import org.apache.http.message.BasicHeader; +import org.elasticsearch.common.Strings; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.external.request.HttpRequest; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel; +import org.elasticsearch.xpack.inference.telemetry.TraceContext; +import org.elasticsearch.xpack.inference.telemetry.TraceContextHandler; + +import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.util.Objects; + +public class ElasticInferenceServiceUnifiedChatCompletionRequest implements Request { + + private final ElasticInferenceServiceCompletionModel model; + private final UnifiedChatInput unifiedChatInput; + private final TraceContextHandler traceContextHandler; + + public ElasticInferenceServiceUnifiedChatCompletionRequest( + UnifiedChatInput unifiedChatInput, + ElasticInferenceServiceCompletionModel model, + TraceContext traceContext + ) { + this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput); + this.model = Objects.requireNonNull(model); + this.traceContextHandler = new TraceContextHandler(traceContext); + } + + @Override + public HttpRequest createHttpRequest() { + var httpPost = new HttpPost(model.uri()); + var requestEntity = Strings.toString( + new ElasticInferenceServiceUnifiedChatCompletionRequestEntity(unifiedChatInput, model.getServiceSettings().modelId()) + ); + + ByteArrayEntity byteEntity = new ByteArrayEntity(requestEntity.getBytes(StandardCharsets.UTF_8)); + httpPost.setEntity(byteEntity); + + traceContextHandler.propagateTraceContext(httpPost); + httpPost.setHeader(new BasicHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType())); + + return new HttpRequest(httpPost, getInferenceEntityId()); + } + + @Override + public URI getURI() { + return model.uri(); + } + + @Override + public Request truncate() { + // No truncation + return this; + } + + @Override + public boolean[] getTruncationInfo() { + // No truncation + return null; + } + + @Override + public String 
getInferenceEntityId() { + return model.getInferenceEntityId(); + } + + @Override + public boolean isStreaming() { + return true; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceUnifiedChatCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceUnifiedChatCompletionRequestEntity.java new file mode 100644 index 0000000000000..ded8a074478cf --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceUnifiedChatCompletionRequestEntity.java @@ -0,0 +1,38 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.elastic; + +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.external.unified.UnifiedChatCompletionRequestEntity; + +import java.io.IOException; +import java.util.Objects; + +public class ElasticInferenceServiceUnifiedChatCompletionRequestEntity implements ToXContentObject { + private static final String MODEL_FIELD = "model"; + + private final UnifiedChatCompletionRequestEntity unifiedRequestEntity; + private final String modelId; + + public ElasticInferenceServiceUnifiedChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, String modelId) { + this.unifiedRequestEntity = new UnifiedChatCompletionRequestEntity(Objects.requireNonNull(unifiedChatInput)); + this.modelId = Objects.requireNonNull(modelId); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + unifiedRequestEntity.toXContent(builder, params); + builder.field(MODEL_FIELD, modelId); + builder.endObject(); + + return builder; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiEmbeddingsRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiEmbeddingsRequest.java index f82e7ff3f5260..7f8626dacc684 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiEmbeddingsRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiEmbeddingsRequest.java @@ -27,7 +27,7 @@ import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader; import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUtils.createOrgHeader; -public class OpenAiEmbeddingsRequest implements OpenAiRequest { +public class OpenAiEmbeddingsRequest implements Request { private final Truncator truncator; private final OpenAiAccount account; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequest.java index 2e6bdb748fd33..e5b85633a499b 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequest.java @@ -27,7 +27,7 @@ import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader; import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUtils.createOrgHeader; -public class OpenAiUnifiedChatCompletionRequest implements OpenAiRequest { +public class OpenAiUnifiedChatCompletionRequest implements Request { private final OpenAiAccount account; private final OpenAiChatCompletionModel model; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java index f28c1b3fe8a55..b80100c9e2f79 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java @@ -8,10 +8,10 @@ package org.elasticsearch.xpack.inference.external.request.openai; import org.elasticsearch.common.Strings; -import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.external.unified.UnifiedChatCompletionRequestEntity; import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModel; import java.io.IOException; @@ -19,168 +19,28 @@ public class OpenAiUnifiedChatCompletionRequestEntity implements ToXContentObject { - public static final String NAME_FIELD = "name"; - public static final String TOOL_CALL_ID_FIELD = "tool_call_id"; - public static final String TOOL_CALLS_FIELD = "tool_calls"; - public static final String ID_FIELD = "id"; - public static final String FUNCTION_FIELD = "function"; - public static final String ARGUMENTS_FIELD = "arguments"; - public static final String DESCRIPTION_FIELD = "description"; - public static final String PARAMETERS_FIELD = "parameters"; - public static final String STRICT_FIELD = "strict"; - public static final String TOP_P_FIELD = "top_p"; public static final String USER_FIELD = "user"; - public static final String STREAM_FIELD = "stream"; - private static final String NUMBER_OF_RETURNED_CHOICES_FIELD = "n"; private static final String MODEL_FIELD = "model"; - public static final String MESSAGES_FIELD = "messages"; - private static final String ROLE_FIELD = "role"; - private static final String CONTENT_FIELD = "content"; - private static final String MAX_COMPLETION_TOKENS_FIELD = "max_completion_tokens"; - private static final String STOP_FIELD = "stop"; - private static final String TEMPERATURE_FIELD = "temperature"; - private static final String TOOL_CHOICE_FIELD = "tool_choice"; - private static final String TOOL_FIELD = "tools"; - private static final String TEXT_FIELD = "text"; - private static final String TYPE_FIELD = "type"; - private static final String STREAM_OPTIONS_FIELD = "stream_options"; - private static final String 
INCLUDE_USAGE_FIELD = "include_usage"; - private final UnifiedCompletionRequest unifiedRequest; - private final boolean stream; private final OpenAiChatCompletionModel model; + private final UnifiedChatCompletionRequestEntity unifiedRequestEntity; public OpenAiUnifiedChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, OpenAiChatCompletionModel model) { - Objects.requireNonNull(unifiedChatInput); - - this.unifiedRequest = unifiedChatInput.getRequest(); - this.stream = unifiedChatInput.stream(); + this.unifiedRequestEntity = new UnifiedChatCompletionRequestEntity(Objects.requireNonNull(unifiedChatInput)); this.model = Objects.requireNonNull(model); } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.startArray(MESSAGES_FIELD); - { - for (UnifiedCompletionRequest.Message message : unifiedRequest.messages()) { - builder.startObject(); - { - switch (message.content()) { - case UnifiedCompletionRequest.ContentString contentString -> builder.field(CONTENT_FIELD, contentString.content()); - case UnifiedCompletionRequest.ContentObjects contentObjects -> { - builder.startArray(CONTENT_FIELD); - for (UnifiedCompletionRequest.ContentObject contentObject : contentObjects.contentObjects()) { - builder.startObject(); - builder.field(TEXT_FIELD, contentObject.text()); - builder.field(TYPE_FIELD, contentObject.type()); - builder.endObject(); - } - builder.endArray(); - } - case null -> { - // do nothing - } - } - - builder.field(ROLE_FIELD, message.role()); - if (message.name() != null) { - builder.field(NAME_FIELD, message.name()); - } - if (message.toolCallId() != null) { - builder.field(TOOL_CALL_ID_FIELD, message.toolCallId()); - } - if (message.toolCalls() != null) { - builder.startArray(TOOL_CALLS_FIELD); - for (UnifiedCompletionRequest.ToolCall toolCall : message.toolCalls()) { - builder.startObject(); - { - builder.field(ID_FIELD, toolCall.id()); - builder.startObject(FUNCTION_FIELD); - { - builder.field(ARGUMENTS_FIELD, toolCall.function().arguments()); - builder.field(NAME_FIELD, toolCall.function().name()); - } - builder.endObject(); - builder.field(TYPE_FIELD, toolCall.type()); - } - builder.endObject(); - } - builder.endArray(); - } - } - builder.endObject(); - } - } - builder.endArray(); + unifiedRequestEntity.toXContent(builder, params); builder.field(MODEL_FIELD, model.getServiceSettings().modelId()); - if (unifiedRequest.maxCompletionTokens() != null) { - builder.field(MAX_COMPLETION_TOKENS_FIELD, unifiedRequest.maxCompletionTokens()); - } - - builder.field(NUMBER_OF_RETURNED_CHOICES_FIELD, 1); - - if (unifiedRequest.stop() != null && unifiedRequest.stop().isEmpty() == false) { - builder.field(STOP_FIELD, unifiedRequest.stop()); - } - if (unifiedRequest.temperature() != null) { - builder.field(TEMPERATURE_FIELD, unifiedRequest.temperature()); - } - if (unifiedRequest.toolChoice() != null) { - if (unifiedRequest.toolChoice() instanceof UnifiedCompletionRequest.ToolChoiceString) { - builder.field(TOOL_CHOICE_FIELD, ((UnifiedCompletionRequest.ToolChoiceString) unifiedRequest.toolChoice()).value()); - } else if (unifiedRequest.toolChoice() instanceof UnifiedCompletionRequest.ToolChoiceObject) { - builder.startObject(TOOL_CHOICE_FIELD); - { - builder.field(TYPE_FIELD, ((UnifiedCompletionRequest.ToolChoiceObject) unifiedRequest.toolChoice()).type()); - builder.startObject(FUNCTION_FIELD); - { - builder.field( - NAME_FIELD, - ((UnifiedCompletionRequest.ToolChoiceObject) 
unifiedRequest.toolChoice()).function().name() - ); - } - builder.endObject(); - } - builder.endObject(); - } - } - if (unifiedRequest.tools() != null && unifiedRequest.tools().isEmpty() == false) { - builder.startArray(TOOL_FIELD); - for (UnifiedCompletionRequest.Tool t : unifiedRequest.tools()) { - builder.startObject(); - { - builder.field(TYPE_FIELD, t.type()); - builder.startObject(FUNCTION_FIELD); - { - builder.field(DESCRIPTION_FIELD, t.function().description()); - builder.field(NAME_FIELD, t.function().name()); - builder.field(PARAMETERS_FIELD, t.function().parameters()); - if (t.function().strict() != null) { - builder.field(STRICT_FIELD, t.function().strict()); - } - } - builder.endObject(); - } - builder.endObject(); - } - builder.endArray(); - } - if (unifiedRequest.topP() != null) { - builder.field(TOP_P_FIELD, unifiedRequest.topP()); - } if (Strings.isNullOrEmpty(model.getTaskSettings().user()) == false) { builder.field(USER_FIELD, model.getTaskSettings().user()); } - builder.field(STREAM_FIELD, stream); - if (stream) { - builder.startObject(STREAM_OPTIONS_FIELD); - builder.field(INCLUDE_USAGE_FIELD, true); - builder.endObject(); - } builder.endObject(); return builder; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/unified/UnifiedChatCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/unified/UnifiedChatCompletionRequestEntity.java new file mode 100644 index 0000000000000..3ea8e28479ef2 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/unified/UnifiedChatCompletionRequestEntity.java @@ -0,0 +1,178 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.unified; + +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.xcontent.ToXContentFragment; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; + +import java.io.IOException; +import java.util.Objects; + +public class UnifiedChatCompletionRequestEntity implements ToXContentFragment { + + public static final String NAME_FIELD = "name"; + public static final String TOOL_CALL_ID_FIELD = "tool_call_id"; + public static final String TOOL_CALLS_FIELD = "tool_calls"; + public static final String ID_FIELD = "id"; + public static final String FUNCTION_FIELD = "function"; + public static final String ARGUMENTS_FIELD = "arguments"; + public static final String DESCRIPTION_FIELD = "description"; + public static final String PARAMETERS_FIELD = "parameters"; + public static final String STRICT_FIELD = "strict"; + public static final String TOP_P_FIELD = "top_p"; + public static final String STREAM_FIELD = "stream"; + private static final String NUMBER_OF_RETURNED_CHOICES_FIELD = "n"; + public static final String MESSAGES_FIELD = "messages"; + private static final String ROLE_FIELD = "role"; + private static final String CONTENT_FIELD = "content"; + private static final String MAX_COMPLETION_TOKENS_FIELD = "max_completion_tokens"; + private static final String STOP_FIELD = "stop"; + private static final String TEMPERATURE_FIELD = "temperature"; + private static final String TOOL_CHOICE_FIELD = "tool_choice"; + private static final String TOOL_FIELD = "tools"; + private static final String TEXT_FIELD = "text"; + private static final String TYPE_FIELD = "type"; + private static final String STREAM_OPTIONS_FIELD = "stream_options"; + private static final String INCLUDE_USAGE_FIELD = "include_usage"; + + private final UnifiedCompletionRequest unifiedRequest; + private final boolean stream; + + public UnifiedChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput) { + Objects.requireNonNull(unifiedChatInput); + + this.unifiedRequest = unifiedChatInput.getRequest(); + this.stream = unifiedChatInput.stream(); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startArray(MESSAGES_FIELD); + { + for (UnifiedCompletionRequest.Message message : unifiedRequest.messages()) { + builder.startObject(); + { + switch (message.content()) { + case UnifiedCompletionRequest.ContentString contentString -> builder.field(CONTENT_FIELD, contentString.content()); + case UnifiedCompletionRequest.ContentObjects contentObjects -> { + builder.startArray(CONTENT_FIELD); + for (UnifiedCompletionRequest.ContentObject contentObject : contentObjects.contentObjects()) { + builder.startObject(); + builder.field(TEXT_FIELD, contentObject.text()); + builder.field(TYPE_FIELD, contentObject.type()); + builder.endObject(); + } + builder.endArray(); + } + case null -> { + // do nothing because content is optional + } + } + + builder.field(ROLE_FIELD, message.role()); + if (message.name() != null) { + builder.field(NAME_FIELD, message.name()); + } + if (message.toolCallId() != null) { + builder.field(TOOL_CALL_ID_FIELD, message.toolCallId()); + } + if (message.toolCalls() != null) { + builder.startArray(TOOL_CALLS_FIELD); + for (UnifiedCompletionRequest.ToolCall toolCall : message.toolCalls()) { + builder.startObject(); + { + builder.field(ID_FIELD, toolCall.id()); + 
builder.startObject(FUNCTION_FIELD); + { + builder.field(ARGUMENTS_FIELD, toolCall.function().arguments()); + builder.field(NAME_FIELD, toolCall.function().name()); + } + builder.endObject(); + builder.field(TYPE_FIELD, toolCall.type()); + } + builder.endObject(); + } + builder.endArray(); + } + } + builder.endObject(); + } + } + builder.endArray(); + + if (unifiedRequest.maxCompletionTokens() != null) { + builder.field(MAX_COMPLETION_TOKENS_FIELD, unifiedRequest.maxCompletionTokens()); + } + + // Underlying providers expect OpenAI to only return 1 possible choice. + builder.field(NUMBER_OF_RETURNED_CHOICES_FIELD, 1); + + if (unifiedRequest.stop() != null && unifiedRequest.stop().isEmpty() == false) { + builder.field(STOP_FIELD, unifiedRequest.stop()); + } + if (unifiedRequest.temperature() != null) { + builder.field(TEMPERATURE_FIELD, unifiedRequest.temperature()); + } + if (unifiedRequest.toolChoice() != null) { + if (unifiedRequest.toolChoice() instanceof UnifiedCompletionRequest.ToolChoiceString) { + builder.field(TOOL_CHOICE_FIELD, ((UnifiedCompletionRequest.ToolChoiceString) unifiedRequest.toolChoice()).value()); + } else if (unifiedRequest.toolChoice() instanceof UnifiedCompletionRequest.ToolChoiceObject) { + builder.startObject(TOOL_CHOICE_FIELD); + { + builder.field(TYPE_FIELD, ((UnifiedCompletionRequest.ToolChoiceObject) unifiedRequest.toolChoice()).type()); + builder.startObject(FUNCTION_FIELD); + { + builder.field( + NAME_FIELD, + ((UnifiedCompletionRequest.ToolChoiceObject) unifiedRequest.toolChoice()).function().name() + ); + } + builder.endObject(); + } + builder.endObject(); + } + } + boolean usesTools = unifiedRequest.tools() != null && unifiedRequest.tools().isEmpty() == false; + + if (usesTools) { + builder.startArray(TOOL_FIELD); + for (UnifiedCompletionRequest.Tool tool : unifiedRequest.tools()) { + builder.startObject(); + { + builder.field(TYPE_FIELD, tool.type()); + builder.startObject(FUNCTION_FIELD); + { + builder.field(DESCRIPTION_FIELD, tool.function().description()); + builder.field(NAME_FIELD, tool.function().name()); + builder.field(PARAMETERS_FIELD, tool.function().parameters()); + if (tool.function().strict() != null) { + builder.field(STRICT_FIELD, tool.function().strict()); + } + } + builder.endObject(); + } + builder.endObject(); + } + builder.endArray(); + } + if (unifiedRequest.topP() != null) { + builder.field(TOP_P_FIELD, unifiedRequest.topP()); + } + + builder.field(STREAM_FIELD, stream); + if (stream) { + builder.startObject(STREAM_OPTIONS_FIELD); + builder.field(INCLUDE_USAGE_FIELD, true); + builder.endObject(); + } + + return builder; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index 68782488099a1..48416faac6a06 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -31,23 +31,29 @@ import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults; +import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; import 
org.elasticsearch.xpack.inference.external.action.elastic.ElasticInferenceServiceActionCreator; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; +import org.elasticsearch.xpack.inference.external.http.sender.ElasticInferenceServiceUnifiedCompletionRequestManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; import org.elasticsearch.xpack.inference.telemetry.TraceContext; import java.util.EnumSet; import java.util.HashMap; import java.util.List; +import java.util.Locale; import java.util.Map; +import java.util.Set; import static org.elasticsearch.xpack.core.inference.results.ResultUtils.createInvalidChunkedResultException; +import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS; import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; @@ -55,7 +61,6 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; public class ElasticInferenceService extends SenderService { @@ -64,8 +69,8 @@ public class ElasticInferenceService extends SenderService { private final ElasticInferenceServiceComponents elasticInferenceServiceComponents; + private static final EnumSet supportedTaskTypes = EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.COMPLETION); private static final String SERVICE_NAME = "Elastic"; - private static final EnumSet supportedTaskTypes = EnumSet.of(TaskType.SPARSE_EMBEDDING); public ElasticInferenceService( HttpRequestSender.Factory factory, @@ -76,6 +81,11 @@ public ElasticInferenceService( this.elasticInferenceServiceComponents = elasticInferenceServiceComponents; } + @Override + public Set supportedStreamingTasks() { + return COMPLETION_ONLY; + } + @Override protected void doUnifiedCompletionInfer( Model model, @@ -83,7 +93,31 @@ protected void doUnifiedCompletionInfer( TimeValue timeout, ActionListener listener ) { - throwUnsupportedUnifiedCompletionOperation(NAME); + if (model instanceof ElasticInferenceServiceCompletionModel == false) { + listener.onFailure(createInvalidModelException(model)); + return; + } + + // We extract the trace context here as it's sufficient to propagate the trace information of the REST request, + // which handles the request to the inference API overall (including the outgoing request, which is started in a new thread + // generating a different "traceparent" as every task and every REST request creates a new span). 
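+        // The trace context is then propagated onto the outgoing HTTP request headers by TraceContextHandler (see ElasticInferenceServiceUnifiedChatCompletionRequest).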
+ var currentTraceInfo = getCurrentTraceInfo(); + + var completionModel = (ElasticInferenceServiceCompletionModel) model; + var overriddenModel = ElasticInferenceServiceCompletionModel.of(completionModel, inputs.getRequest()); + var errorMessage = constructFailedToSendRequestMessage( + overriddenModel.uri(), + String.format(Locale.ROOT, "%s completions", ELASTIC_INFERENCE_SERVICE_IDENTIFIER) + ); + + var requestManager = ElasticInferenceServiceUnifiedCompletionRequestManager.of( + overriddenModel, + getServiceComponents().threadPool(), + currentTraceInfo + ); + var action = new SenderExecutableAction(getSender(), requestManager, errorMessage); + + action.execute(inputs, timeout, listener); } @Override @@ -95,7 +129,7 @@ protected void doInfer( TimeValue timeout, ActionListener listener ) { - if (model instanceof ElasticInferenceServiceModel == false) { + if (model instanceof ElasticInferenceServiceExecutableActionModel == false) { listener.onFailure(createInvalidModelException(model)); return; } @@ -105,7 +139,7 @@ protected void doInfer( // generating a different "traceparent" as every task and every REST request creates a new span). var currentTraceInfo = getCurrentTraceInfo(); - ElasticInferenceServiceModel elasticInferenceServiceModel = (ElasticInferenceServiceModel) model; + ElasticInferenceServiceExecutableActionModel elasticInferenceServiceModel = (ElasticInferenceServiceExecutableActionModel) model; var actionCreator = new ElasticInferenceServiceActionCreator(getSender(), getServiceComponents(), currentTraceInfo); var action = elasticInferenceServiceModel.accept(actionCreator, taskSettings); @@ -197,6 +231,16 @@ private static ElasticInferenceServiceModel createModel( eisServiceComponents, context ); + case COMPLETION -> new ElasticInferenceServiceCompletionModel( + inferenceEntityId, + taskType, + NAME, + serviceSettings, + taskSettings, + secretSettings, + eisServiceComponents, + context + ); default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); }; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceExecutableActionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceExecutableActionModel.java new file mode 100644 index 0000000000000..3223b8b3e8d91 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceExecutableActionModel.java @@ -0,0 +1,37 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.elastic; + +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.action.elastic.ElasticInferenceServiceActionVisitor; + +import java.util.Map; + +public abstract class ElasticInferenceServiceExecutableActionModel extends ElasticInferenceServiceModel { + + public ElasticInferenceServiceExecutableActionModel( + ModelConfigurations configurations, + ModelSecrets secrets, + ElasticInferenceServiceRateLimitServiceSettings rateLimitServiceSettings, + ElasticInferenceServiceComponents elasticInferenceServiceComponents + ) { + super(configurations, secrets, rateLimitServiceSettings, elasticInferenceServiceComponents); + } + + public ElasticInferenceServiceExecutableActionModel( + ElasticInferenceServiceExecutableActionModel model, + ServiceSettings serviceSettings + ) { + super(model, serviceSettings); + } + + public abstract ExecutableAction accept(ElasticInferenceServiceActionVisitor visitor, Map taskSettings); +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceModel.java index e7809d869fec4..e03cc36e62417 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceModel.java @@ -11,10 +11,7 @@ import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; import org.elasticsearch.inference.ServiceSettings; -import org.elasticsearch.xpack.inference.external.action.ExecutableAction; -import org.elasticsearch.xpack.inference.external.action.elastic.ElasticInferenceServiceActionVisitor; -import java.util.Map; import java.util.Objects; public abstract class ElasticInferenceServiceModel extends Model { @@ -49,7 +46,4 @@ public ElasticInferenceServiceRateLimitServiceSettings rateLimitServiceSettings( public ElasticInferenceServiceComponents elasticInferenceServiceComponents() { return elasticInferenceServiceComponents; } - - public abstract ExecutableAction accept(ElasticInferenceServiceActionVisitor visitor, Map taskSettings); - } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSettings.java index bc2daddc2a346..5146cec1552af 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSettings.java @@ -12,6 +12,10 @@ import java.util.List; +/** + * Encapsulates settings using {@link Setting}. This does not represent service settings that are persisted + * via {@link org.elasticsearch.inference.ServiceSettings}, but rather Elasticsearch settings passed on startup. 
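+ * A typical example is the base URL of the Elastic Inference Service, which is read from the node's Elasticsearch settings at startup.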
+ */ public class ElasticInferenceServiceSettings { @Deprecated diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsModel.java index 54728a92e6254..112be95dac1fd 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsModel.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.inference.services.elastic; +import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.EmptySecretSettings; import org.elasticsearch.inference.EmptyTaskSettings; @@ -15,6 +16,7 @@ import org.elasticsearch.inference.SecretSettings; import org.elasticsearch.inference.TaskSettings; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.action.elastic.ElasticInferenceServiceActionVisitor; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; @@ -27,7 +29,7 @@ import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService.ELASTIC_INFERENCE_SERVICE_IDENTIFIER; -public class ElasticInferenceServiceSparseEmbeddingsModel extends ElasticInferenceServiceModel { +public class ElasticInferenceServiceSparseEmbeddingsModel extends ElasticInferenceServiceExecutableActionModel { private final URI uri; @@ -57,12 +59,7 @@ public ElasticInferenceServiceSparseEmbeddingsModel( ElasticInferenceServiceSparseEmbeddingsServiceSettings serviceSettings ) { super(model, serviceSettings); - - try { - this.uri = createUri(); - } catch (URISyntaxException e) { - throw new RuntimeException(e); - } + this.uri = createUri(); } ElasticInferenceServiceSparseEmbeddingsModel( @@ -80,12 +77,7 @@ public ElasticInferenceServiceSparseEmbeddingsModel( serviceSettings, elasticInferenceServiceComponents ); - - try { - this.uri = createUri(); - } catch (URISyntaxException e) { - throw new RuntimeException(e); - } + this.uri = createUri(); } @Override @@ -102,19 +94,42 @@ public URI uri() { return uri; } - private URI createUri() throws URISyntaxException { + private URI createUri() throws ElasticsearchStatusException { String modelId = getServiceSettings().modelId(); String modelIdUriPath; switch (modelId) { case ElserModels.ELSER_V2_MODEL -> modelIdUriPath = "ELSERv2"; - default -> throw new IllegalArgumentException( - String.format(Locale.ROOT, "Unsupported model for %s [%s]", ELASTIC_INFERENCE_SERVICE_IDENTIFIER, modelId) + default -> throw new ElasticsearchStatusException( + String.format( + Locale.ROOT, + "Unsupported model [%s] for service [%s] and task type [%s]", + modelId, + ELASTIC_INFERENCE_SERVICE_IDENTIFIER, + TaskType.SPARSE_EMBEDDING + ), + RestStatus.BAD_REQUEST ); } - return new URI( - elasticInferenceServiceComponents().elasticInferenceServiceUrl() + "/api/v1/sparse-text-embeddings/" + modelIdUriPath - ); + try { + // TODO, consider transforming the base URL into a URI for better error handling. 
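+            // e.g. <elasticInferenceServiceUrl>/api/v1/sparse-text-embeddings/ELSERv2 for the ELSER v2 model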
+ return new URI( + elasticInferenceServiceComponents().elasticInferenceServiceUrl() + "/api/v1/sparse-text-embeddings/" + modelIdUriPath + ); + } catch (URISyntaxException e) { + throw new ElasticsearchStatusException( + "Failed to create URI for service [" + + this.getConfigurations().getService() + + "] with taskType [" + + this.getTaskType() + + "] with model [" + + this.getServiceSettings().modelId() + + "]: " + + e.getMessage(), + RestStatus.BAD_REQUEST, + e + ); + } } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModel.java new file mode 100644 index 0000000000000..84039cd7cc33c --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModel.java @@ -0,0 +1,125 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.elastic.completion; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.EmptySecretSettings; +import org.elasticsearch.inference.EmptyTaskSettings; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.SecretSettings; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceModel; + +import java.net.URI; +import java.net.URISyntaxException; +import java.util.Map; +import java.util.Objects; + +public class ElasticInferenceServiceCompletionModel extends ElasticInferenceServiceModel { + + public static ElasticInferenceServiceCompletionModel of( + ElasticInferenceServiceCompletionModel model, + UnifiedCompletionRequest request + ) { + var originalModelServiceSettings = model.getServiceSettings(); + var overriddenServiceSettings = new ElasticInferenceServiceCompletionServiceSettings( + Objects.requireNonNullElse(request.model(), originalModelServiceSettings.modelId()), + originalModelServiceSettings.rateLimitSettings() + ); + + return new ElasticInferenceServiceCompletionModel(model, overriddenServiceSettings); + } + + private final URI uri; + + public ElasticInferenceServiceCompletionModel( + String inferenceEntityId, + TaskType taskType, + String service, + Map serviceSettings, + Map taskSettings, + Map secrets, + ElasticInferenceServiceComponents elasticInferenceServiceComponents, + ConfigurationParseContext context + ) { + this( + inferenceEntityId, + taskType, + service, + ElasticInferenceServiceCompletionServiceSettings.fromMap(serviceSettings, context), + EmptyTaskSettings.INSTANCE, + EmptySecretSettings.INSTANCE, + elasticInferenceServiceComponents + ); + } + + public ElasticInferenceServiceCompletionModel( + 
ElasticInferenceServiceCompletionModel model, + ElasticInferenceServiceCompletionServiceSettings serviceSettings + ) { + super(model, serviceSettings); + this.uri = createUri(); + + } + + ElasticInferenceServiceCompletionModel( + String inferenceEntityId, + TaskType taskType, + String service, + ElasticInferenceServiceCompletionServiceSettings serviceSettings, + @Nullable TaskSettings taskSettings, + @Nullable SecretSettings secretSettings, + ElasticInferenceServiceComponents elasticInferenceServiceComponents + ) { + super( + new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings), + new ModelSecrets(secretSettings), + serviceSettings, + elasticInferenceServiceComponents + ); + + this.uri = createUri(); + + } + + @Override + public ElasticInferenceServiceCompletionServiceSettings getServiceSettings() { + return (ElasticInferenceServiceCompletionServiceSettings) super.getServiceSettings(); + } + + public URI uri() { + return uri; + } + + private URI createUri() throws ElasticsearchStatusException { + try { + // TODO, consider transforming the base URL into a URI for better error handling. + return new URI(elasticInferenceServiceComponents().elasticInferenceServiceUrl() + "/api/v1/chat/completions"); + } catch (URISyntaxException e) { + throw new ElasticsearchStatusException( + "Failed to create URI for service [" + + this.getConfigurations().getService() + + "] with taskType [" + + this.getTaskType() + + "]: " + + e.getMessage(), + RestStatus.BAD_REQUEST, + e + ); + } + } + + // TODO create/refactor the Configuration class to be extensible for different task types (i.e completion, sparse embeddings). +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionServiceSettings.java new file mode 100644 index 0000000000000..3c8182a7d41a4 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionServiceSettings.java @@ -0,0 +1,129 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.elastic.completion; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceRateLimitServiceSettings; +import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; + +public class ElasticInferenceServiceCompletionServiceSettings extends FilteredXContentObject + implements + ServiceSettings, + ElasticInferenceServiceRateLimitServiceSettings { + + public static final String NAME = "elastic_inference_service_completion_service_settings"; + + // TODO what value do we put here? + private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(240L); + + public static ElasticInferenceServiceCompletionServiceSettings fromMap(Map map, ConfigurationParseContext context) { + ValidationException validationException = new ValidationException(); + + String modelId = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); + RateLimitSettings rateLimitSettings = RateLimitSettings.of( + map, + DEFAULT_RATE_LIMIT_SETTINGS, + validationException, + ElasticInferenceService.NAME, + context + ); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new ElasticInferenceServiceCompletionServiceSettings(modelId, rateLimitSettings); + } + + private final String modelId; + private final RateLimitSettings rateLimitSettings; + + public ElasticInferenceServiceCompletionServiceSettings(String modelId, RateLimitSettings rateLimitSettings) { + this.modelId = Objects.requireNonNull(modelId); + this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS); + } + + public ElasticInferenceServiceCompletionServiceSettings(StreamInput in) throws IOException { + this.modelId = in.readString(); + this.rateLimitSettings = new RateLimitSettings(in); + } + + @Override + public String getWriteableName() { + return NAME; + } + + public String modelId() { + return modelId; + } + + @Override + public RateLimitSettings rateLimitSettings() { + return rateLimitSettings; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ELASTIC_INFERENCE_SERVICE_UNIFIED_CHAT_COMPLETIONS_INTEGRATION; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + + toXContentFragmentOfExposedFields(builder, params); + + builder.endObject(); + + return builder; + } + + @Override + protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder 
builder, Params params) throws IOException { + builder.field(MODEL_ID, modelId); + rateLimitSettings.toXContent(builder, params); + + return builder; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(modelId); + rateLimitSettings.writeTo(out); + } + + @Override + public boolean equals(Object object) { + if (this == object) return true; + if (object == null || getClass() != object.getClass()) return false; + ElasticInferenceServiceCompletionServiceSettings that = (ElasticInferenceServiceCompletionServiceSettings) object; + return Objects.equals(modelId, that.modelId) && Objects.equals(rateLimitSettings, that.rateLimitSettings); + } + + @Override + public int hashCode() { + return Objects.hash(modelId, rateLimitSettings); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/TraceContextHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/TraceContextHandler.java new file mode 100644 index 0000000000000..92fe214d821db --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/TraceContextHandler.java @@ -0,0 +1,31 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.telemetry; + +import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.tasks.Task; + +public record TraceContextHandler(TraceContext traceContext) { + + public void propagateTraceContext(HttpPost httpPost) { + if (traceContext == null) { + return; + } + + var traceParent = traceContext.traceParent(); + var traceState = traceContext.traceState(); + + if (traceParent != null) { + httpPost.setHeader(Task.TRACE_PARENT_HTTP_HEADER, traceParent); + } + + if (traceState != null) { + httpPost.setHeader(Task.TRACE_STATE, traceState); + } + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/Utils.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/Utils.java index 9395ae222e9ba..e02ac7b8853ad 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/Utils.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/Utils.java @@ -12,6 +12,7 @@ import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.settings.ClusterSettings; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.inference.InferenceServiceExtension; @@ -21,6 +22,9 @@ import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.threadpool.ScalingExecutorBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; import org.elasticsearch.xpack.inference.common.Truncator; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; @@ -33,6 +37,7 @@ import org.elasticsearch.xpack.inference.registry.ModelRegistry; import 
org.hamcrest.Matchers; +import java.io.IOException; import java.util.Collection; import java.util.HashMap; import java.util.List; @@ -47,6 +52,7 @@ import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.nullValue; import static org.junit.Assert.fail; import static org.mockito.Mockito.mock; @@ -252,4 +258,14 @@ public static ActionListener getModelListenerForException(Class except assertThat(e.getMessage(), is(expectedMessage)); }); } + + public static void assertJsonEquals(String actual, String expected) throws IOException { + var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); + try ( + var actualParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, actual); + var expectedParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, expected); + ) { + assertThat(actualParser.mapOrdered(), equalTo(expectedParser.mapOrdered())); + } + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceUnifiedChatCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceUnifiedChatCompletionRequestEntityTests.java new file mode 100644 index 0000000000000..75ff63e1314ac --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceUnifiedChatCompletionRequestEntityTests.java @@ -0,0 +1,71 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.request.elastic; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.json.JsonXContent; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.external.request.openai.OpenAiUnifiedChatCompletionRequestEntity; +import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModel; + +import java.io.IOException; +import java.util.ArrayList; + +import static org.elasticsearch.xpack.inference.Utils.assertJsonEquals; +import static org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModelTests.createChatCompletionModel; + +public class ElasticInferenceServiceUnifiedChatCompletionRequestEntityTests extends ESTestCase { + + private static final String ROLE = "user"; + + public void testModelUserFieldsSerialization() throws IOException { + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("Hello, world!"), + ROLE, + null, + null, + null + ); + var messageList = new ArrayList(); + messageList.add(message); + + var unifiedRequest = UnifiedCompletionRequest.of(messageList); + + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + OpenAiChatCompletionModel model = createChatCompletionModel("test-url", "organizationId", "api-key", "test-endpoint", null); + + OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model); + + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + String jsonString = Strings.toString(builder); + String expectedJson = """ + { + "messages": [ + { + "content": "Hello, world!", + "role": "user" + } + ], + "model": "test-endpoint", + "n": 1, + "stream": true, + "stream_options": { + "include_usage": true + } + } + """; + assertJsonEquals(jsonString, expectedJson); + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntityTests.java index 2037c77a3cf2a..f43b185391697 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntityTests.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.external.request.openai; -import org.elasticsearch.common.Randomness; import org.elasticsearch.common.Strings; import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.test.ESTestCase; @@ -19,186 +18,19 @@ import java.io.IOException; import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; -import java.util.LinkedHashMap; -import java.util.Locale; -import java.util.Map; -import java.util.Random; +import static org.elasticsearch.xpack.inference.Utils.assertJsonEquals; import static 
org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModelTests.createChatCompletionModel; -import static org.hamcrest.Matchers.equalTo; public class OpenAiUnifiedChatCompletionRequestEntityTests extends ESTestCase { - // 1. Basic Serialization - // Test with minimal required fields to ensure basic serialization works. - public void testBasicSerialization() throws IOException { - UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( - new UnifiedCompletionRequest.ContentString("Hello, world!"), - OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, - null, - null, - null - ); - var messageList = new ArrayList(); - messageList.add(message); - UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest(messageList, null, null, null, null, null, null, null); - - UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); - OpenAiChatCompletionModel model = createChatCompletionModel("test-url", "organizationId", "api-key", "test-endpoint", null); - - OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model); - - XContentBuilder builder = JsonXContent.contentBuilder(); - entity.toXContent(builder, ToXContent.EMPTY_PARAMS); - - String jsonString = Strings.toString(builder); - String expectedJson = """ - { - "messages": [ - { - "content": "Hello, world!", - "role": "user" - } - ], - "model": "test-endpoint", - "n": 1, - "stream": true, - "stream_options": { - "include_usage": true - } - } - """; - assertJsonEquals(jsonString, expectedJson); - } + private static final String ROLE = "user"; + private static final String USER = "a_user"; - // 2. Serialization with All Fields - // Test with all possible fields populated to ensure complete serialization. 
- public void testSerializationWithAllFields() throws IOException { - // Create a message with all fields populated + public void testModelUserFieldsSerialization() throws IOException { UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( new UnifiedCompletionRequest.ContentString("Hello, world!"), - OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, - "name", - "tool_call_id", - Collections.singletonList( - new UnifiedCompletionRequest.ToolCall( - "id", - new UnifiedCompletionRequest.ToolCall.FunctionField("arguments", "function_name"), - "type" - ) - ) - ); - - // Create a tool with all fields populated - UnifiedCompletionRequest.Tool tool = new UnifiedCompletionRequest.Tool( - "type", - new UnifiedCompletionRequest.Tool.FunctionField( - "Fetches the weather in the given location", - "get_weather", - createParameters(), - true - ) - ); - var messageList = new ArrayList(); - messageList.add(message); - // Create the unified request with all fields populated - UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest( - messageList, - "model", - 100L, // maxCompletionTokens - Collections.singletonList("stop"), - 0.9f, // temperature - new UnifiedCompletionRequest.ToolChoiceString("tool_choice"), - Collections.singletonList(tool), - 0.8f // topP - ); - - // Create the unified chat input - UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); - - OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null); - - // Create the entity - OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model); - - // Serialize to XContent - XContentBuilder builder = JsonXContent.contentBuilder(); - entity.toXContent(builder, ToXContent.EMPTY_PARAMS); - - // Convert to string and verify - String jsonString = Strings.toString(builder); - String expectedJson = """ - { - "messages": [ - { - "content": "Hello, world!", - "role": "user", - "name": "name", - "tool_call_id": "tool_call_id", - "tool_calls": [ - { - "id": "id", - "function": { - "arguments": "arguments", - "name": "function_name" - }, - "type": "type" - } - ] - } - ], - "model": "model-name", - "max_completion_tokens": 100, - "n": 1, - "stop": ["stop"], - "temperature": 0.9, - "tool_choice": "tool_choice", - "tools": [ - { - "type": "type", - "function": { - "description": "Fetches the weather in the given location", - "name": "get_weather", - "parameters": { - "type": "object", - "properties": { - "location": { - "description": "The location to get the weather for", - "type": "string" - }, - "unit": { - "description": "The unit to return the temperature in", - "type": "string", - "enum": ["F", "C"] - } - }, - "additionalProperties": false, - "required": ["location", "unit"] - }, - "strict": true - } - } - ], - "top_p": 0.8, - "stream": true, - "stream_options": { - "include_usage": true - } - } - """; - assertJsonEquals(jsonString, expectedJson); - - } - - // 3. Serialization with Null Optional Fields - // Test with optional fields set to null to ensure they are correctly omitted from the output. 
- public void testSerializationWithNullOptionalFields() throws IOException { - // Create a message with minimal required fields - UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( - new UnifiedCompletionRequest.ContentString("Hello, world!"), - OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, + ROLE, null, null, null @@ -206,487 +38,18 @@ public void testSerializationWithNullOptionalFields() throws IOException { var messageList = new ArrayList(); messageList.add(message); - // Create the unified request with optional fields set to null - UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest( - messageList, - null, // model - null, // maxCompletionTokens - null, // stop - null, // temperature - null, // toolChoice - null, // tools - null // topP - ); - - // Create the unified chat input - UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); - - OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null); - - // Create the entity - OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model); - - // Serialize to XContent - XContentBuilder builder = JsonXContent.contentBuilder(); - entity.toXContent(builder, ToXContent.EMPTY_PARAMS); - - // Convert to string and verify - String jsonString = Strings.toString(builder); - String expectedJson = """ - { - "messages": [ - { - "content": "Hello, world!", - "role": "user" - } - ], - "model": "model-name", - "n": 1, - "stream": true, - "stream_options": { - "include_usage": true - } - } - """; - assertJsonEquals(jsonString, expectedJson); - } - - // 4. Serialization with Empty Lists - // Test with fields that are lists set to empty lists to ensure they are correctly serialized. 
- public void testSerializationWithEmptyLists() throws IOException { - // Create a message with minimal required fields - UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( - new UnifiedCompletionRequest.ContentString("Hello, world!"), - OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, - null, - null, - Collections.emptyList() // empty toolCalls list - ); - var messageList = new ArrayList(); - messageList.add(message); - // Create the unified request with empty lists - UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest( - messageList, - null, // model - null, // maxCompletionTokens - Collections.emptyList(), // empty stop list - null, // temperature - null, // toolChoice - Collections.emptyList(), // empty tools list - null // topP - ); - - // Create the unified chat input - UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); - - OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null); - - // Create the entity - OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model); - - // Serialize to XContent - XContentBuilder builder = JsonXContent.contentBuilder(); - entity.toXContent(builder, ToXContent.EMPTY_PARAMS); - - // Convert to string and verify - String jsonString = Strings.toString(builder); - String expectedJson = """ - { - "messages": [ - { - "content": "Hello, world!", - "role": "user", - "tool_calls": [] - } - ], - "model": "model-name", - "n": 1, - "stream": true, - "stream_options": { - "include_usage": true - } - } - """; - assertJsonEquals(jsonString, expectedJson); - } - - // 5. Serialization with Nested Objects - // Test with nested objects (e.g., toolCalls, toolChoice, tool) to ensure they are correctly serialized. - public void testSerializationWithNestedObjects() throws IOException { - Random random = Randomness.get(); - - // Generate random values - String randomContent = "Hello, world! 
" + random.nextInt(1000); - String randomName = "name" + random.nextInt(1000); - String randomToolCallId = "tool_call_id" + random.nextInt(1000); - String randomArguments = "arguments" + random.nextInt(1000); - String randomFunctionName = "function_name" + random.nextInt(1000); - String randomType = "type" + random.nextInt(1000); - String randomModel = "model" + random.nextInt(1000); - String randomStop = "stop" + random.nextInt(1000); - float randomTemperature = (float) ((float) Math.round(0.5d + (double) random.nextFloat() * 0.5d * 100000d) / 100000d); - float randomTopP = (float) ((float) Math.round(0.5d + (double) random.nextFloat() * 0.5d * 100000d) / 100000d); - - // Create a message with nested toolCalls - UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( - new UnifiedCompletionRequest.ContentString(randomContent), - OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, - randomName, - randomToolCallId, - Collections.singletonList( - new UnifiedCompletionRequest.ToolCall( - "id", - new UnifiedCompletionRequest.ToolCall.FunctionField(randomArguments, randomFunctionName), - randomType - ) - ) - ); - - // Create a tool with nested function fields - UnifiedCompletionRequest.Tool tool = new UnifiedCompletionRequest.Tool( - randomType, - new UnifiedCompletionRequest.Tool.FunctionField( - "Fetches the weather in the given location", - "get_weather", - createParameters(), - true - ) - ); - var messageList = new ArrayList(); - messageList.add(message); - // Create the unified request with nested objects - UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest( - messageList, - randomModel, - 100L, // maxCompletionTokens - Collections.singletonList(randomStop), - randomTemperature, // temperature - new UnifiedCompletionRequest.ToolChoiceObject( - randomType, - new UnifiedCompletionRequest.ToolChoiceObject.FunctionField(randomFunctionName) - ), - Collections.singletonList(tool), - randomTopP // topP - ); - - // Create the unified chat input - UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); - - OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", randomModel, null); - - // Create the entity - OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model); - - // Serialize to XContent - XContentBuilder builder = JsonXContent.contentBuilder(); - entity.toXContent(builder, ToXContent.EMPTY_PARAMS); - - // Convert to string and verify - String jsonString = Strings.toString(builder); - // Expected JSON should be dynamically generated based on random values - String expectedJson = String.format( - Locale.US, - """ - { - "messages": [ - { - "content": "%s", - "role": "user", - "name": "%s", - "tool_call_id": "%s", - "tool_calls": [ - { - "id": "id", - "function": { - "arguments": "%s", - "name": "%s" - }, - "type": "%s" - } - ] - } - ], - "model": "%s", - "max_completion_tokens": 100, - "n": 1, - "stop": ["%s"], - "temperature": %.5f, - "tool_choice": { - "type": "%s", - "function": { - "name": "%s" - } - }, - "tools": [ - { - "type": "%s", - "function": { - "description": "Fetches the weather in the given location", - "name": "get_weather", - "parameters": { - "type": "object", - "properties": { - "unit": { - "description": "The unit to return the temperature in", - "type": "string", - "enum": ["F", "C"] - }, - "location": { - "description": "The location to get the weather for", - "type": "string" - } - }, - 
"additionalProperties": false, - "required": ["location", "unit"] - }, - "strict": true - } - } - ], - "top_p": %.5f, - "stream": true, - "stream_options": { - "include_usage": true - } - } - """, - randomContent, - randomName, - randomToolCallId, - randomArguments, - randomFunctionName, - randomType, - randomModel, - randomStop, - randomTemperature, - randomType, - randomFunctionName, - randomType, - randomTopP - ); - assertJsonEquals(jsonString, expectedJson); - } - - // 6. Serialization with Different Content Types - // Test with different content types in messages (e.g., ContentString, ContentObjects) to ensure they are correctly serialized. - public void testSerializationWithDifferentContentTypes() throws IOException { - Random random = Randomness.get(); - - // Generate random values for ContentString - String randomContentString = "Hello, world! " + random.nextInt(1000); + var unifiedRequest = UnifiedCompletionRequest.of(messageList); - // Generate random values for ContentObjects - String randomText = "Random text " + random.nextInt(1000); - String randomType = "type" + random.nextInt(1000); - UnifiedCompletionRequest.ContentObject contentObject = new UnifiedCompletionRequest.ContentObject(randomText, randomType); - - var contentObjectsList = new ArrayList(); - contentObjectsList.add(contentObject); - UnifiedCompletionRequest.ContentObjects contentObjects = new UnifiedCompletionRequest.ContentObjects(contentObjectsList); - - // Create messages with different content types - UnifiedCompletionRequest.Message messageWithString = new UnifiedCompletionRequest.Message( - new UnifiedCompletionRequest.ContentString(randomContentString), - OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, - null, - null, - null - ); - - UnifiedCompletionRequest.Message messageWithObjects = new UnifiedCompletionRequest.Message( - contentObjects, - OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, - null, - null, - null - ); - var messageList = new ArrayList(); - messageList.add(messageWithString); - messageList.add(messageWithObjects); - - // Create the unified request with both types of messages - UnifiedCompletionRequest unifiedRequest = UnifiedCompletionRequest.of(messageList); - - // Create the unified chat input UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + OpenAiChatCompletionModel model = createChatCompletionModel("test-url", "organizationId", "api-key", "test-endpoint", USER); - OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null); - - // Create the entity OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model); - // Serialize to XContent XContentBuilder builder = JsonXContent.contentBuilder(); entity.toXContent(builder, ToXContent.EMPTY_PARAMS); - // Convert to string and verify - String jsonString = Strings.toString(builder); - String expectedJson = String.format(Locale.US, """ - { - "messages": [ - { - "content": "%s", - "role": "user" - }, - { - "content": [ - { - "text": "%s", - "type": "%s" - } - ], - "role": "user" - } - ], - "model": "model-name", - "n": 1, - "stream": true, - "stream_options": { - "include_usage": true - } - } - """, randomContentString, randomText, randomType); - assertJsonEquals(jsonString, expectedJson); - } - - // 7. Serialization with Special Characters - // Test with special characters in string fields to ensure they are correctly escaped and serialized. 
- public void testSerializationWithSpecialCharacters() throws IOException { - // Create a message with special characters - UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( - new UnifiedCompletionRequest.ContentString("Hello, world! \n \"Special\" characters: \t \\ /"), - OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, - "name\nwith\nnewlines", - "tool_call_id\twith\ttabs", - Collections.singletonList( - new UnifiedCompletionRequest.ToolCall( - "id\\with\\backslashes", - new UnifiedCompletionRequest.ToolCall.FunctionField("arguments\"with\"quotes", "function_name/with/slashes"), - "type" - ) - ) - ); - var messageList = new ArrayList(); - messageList.add(message); - // Create the unified request - UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest( - messageList, - null, // model - null, // maxCompletionTokens - null, // stop - null, // temperature - null, // toolChoice - null, // tools - null // topP - ); - - // Create the unified chat input - UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); - - OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null); - - // Create the entity - OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model); - - // Serialize to XContent - XContentBuilder builder = JsonXContent.contentBuilder(); - entity.toXContent(builder, ToXContent.EMPTY_PARAMS); - - // Convert to string and verify String jsonString = Strings.toString(builder); String expectedJson = """ - { - "messages": [ - { - "content": "Hello, world! \\n \\"Special\\" characters: \\t \\\\ /", - "role": "user", - "name": "name\\nwith\\nnewlines", - "tool_call_id": "tool_call_id\\twith\\ttabs", - "tool_calls": [ - { - "id": "id\\\\with\\\\backslashes", - "function": { - "arguments": "arguments\\"with\\"quotes", - "name": "function_name/with/slashes" - }, - "type": "type" - } - ] - } - ], - "model": "model-name", - "n": 1, - "stream": true, - "stream_options": { - "include_usage": true - } - } - """; - assertJsonEquals(jsonString, expectedJson); - } - - // 8. Serialization with Boolean Fields - // Test with boolean fields (stream) set to both true and false to ensure they are correctly serialized. 
- public void testSerializationWithBooleanFields() throws IOException { - // Create a message with minimal required fields - UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( - new UnifiedCompletionRequest.ContentString("Hello, world!"), - OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, - null, - null, - null - ); - var messageList = new ArrayList(); - messageList.add(message); - // Create the unified request - UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest( - messageList, - null, // model - null, // maxCompletionTokens - null, // stop - null, // temperature - null, // toolChoice - null, // tools - null // topP - ); - - OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null); - - // Test with stream set to true - UnifiedChatInput unifiedChatInputTrue = new UnifiedChatInput(unifiedRequest, true); - OpenAiUnifiedChatCompletionRequestEntity entityTrue = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInputTrue, model); - - XContentBuilder builderTrue = JsonXContent.contentBuilder(); - entityTrue.toXContent(builderTrue, ToXContent.EMPTY_PARAMS); - - String jsonStringTrue = Strings.toString(builderTrue); - String expectedJsonTrue = """ - { - "messages": [ - { - "content": "Hello, world!", - "role": "user" - } - ], - "model": "model-name", - "n": 1, - "stream": true, - "stream_options": { - "include_usage": true - } - } - """; - assertJsonEquals(expectedJsonTrue, jsonStringTrue); - - // Test with stream set to false - UnifiedChatInput unifiedChatInputFalse = new UnifiedChatInput(unifiedRequest, false); - OpenAiUnifiedChatCompletionRequestEntity entityFalse = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInputFalse, model); - - XContentBuilder builderFalse = JsonXContent.contentBuilder(); - entityFalse.toXContent(builderFalse, ToXContent.EMPTY_PARAMS); - - String jsonStringFalse = Strings.toString(builderFalse); - String expectedJsonFalse = """ { "messages": [ { @@ -694,103 +57,15 @@ public void testSerializationWithBooleanFields() throws IOException { "role": "user" } ], - "model": "model-name", - "n": 1, - "stream": false - } - """; - assertJsonEquals(expectedJsonFalse, jsonStringFalse); - } - - // 9. 
a test without the content field to show that the content field is optional - public void testSerializationWithoutContentField() throws IOException { - UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( - null, - "assistant", - "name\nwith\nnewlines", - "tool_call_id\twith\ttabs", - Collections.singletonList( - new UnifiedCompletionRequest.ToolCall( - "id\\with\\backslashes", - new UnifiedCompletionRequest.ToolCall.FunctionField("arguments\"with\"quotes", "function_name/with/slashes"), - "type" - ) - ) - ); - var messageList = new ArrayList(); - messageList.add(message); - UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest(messageList, null, null, null, null, null, null, null); - - UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); - OpenAiChatCompletionModel model = createChatCompletionModel("test-url", "organizationId", "api-key", "test-endpoint", null); - - OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model); - - XContentBuilder builder = JsonXContent.contentBuilder(); - entity.toXContent(builder, ToXContent.EMPTY_PARAMS); - - String jsonString = Strings.toString(builder); - String expectedJson = """ - { - "messages": [ - { - "role": "assistant", - "name": "name\\nwith\\nnewlines", - "tool_call_id": "tool_call_id\\twith\\ttabs", - "tool_calls": [ - { - "id": "id\\\\with\\\\backslashes", - "function": { - "arguments": "arguments\\"with\\"quotes", - "name": "function_name/with/slashes" - }, - "type": "type" - } - ] - } - ], "model": "test-endpoint", "n": 1, "stream": true, "stream_options": { "include_usage": true - } + }, + "user": "a_user" } """; assertJsonEquals(jsonString, expectedJson); } - - public static Map createParameters() { - Map parameters = new LinkedHashMap<>(); - parameters.put("type", "object"); - - Map properties = new HashMap<>(); - - Map location = new HashMap<>(); - location.put("type", "string"); - location.put("description", "The location to get the weather for"); - properties.put("location", location); - - Map unit = new HashMap<>(); - unit.put("type", "string"); - unit.put("description", "The unit to return the temperature in"); - unit.put("enum", new String[] { "F", "C" }); - properties.put("unit", unit); - - parameters.put("properties", properties); - parameters.put("additionalProperties", false); - parameters.put("required", new String[] { "location", "unit" }); - - return parameters; - } - - private void assertJsonEquals(String actual, String expected) throws IOException { - try ( - var actualParser = createParser(JsonXContent.jsonXContent, actual); - var expectedParser = createParser(JsonXContent.jsonXContent, expected) - ) { - assertThat(actualParser.mapOrdered(), equalTo(expectedParser.mapOrdered())); - } - } - } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/unified/UnifiedChatCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/unified/UnifiedChatCompletionRequestEntityTests.java new file mode 100644 index 0000000000000..0f305866ae988 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/unified/UnifiedChatCompletionRequestEntityTests.java @@ -0,0 +1,720 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.unified; + +import org.elasticsearch.common.Randomness; +import org.elasticsearch.common.Strings; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.json.JsonXContent; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.external.request.openai.OpenAiUnifiedChatCompletionRequestEntity; +import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModel; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.Locale; +import java.util.Map; +import java.util.Random; + +import static org.elasticsearch.xpack.inference.Utils.assertJsonEquals; +import static org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModelTests.createChatCompletionModel; + +public class UnifiedChatCompletionRequestEntityTests extends ESTestCase { + + private static final String ROLE = "user"; + + public void testBasicSerialization() throws IOException { + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("Hello, world!"), + ROLE, + null, + null, + null + ); + var messageList = new ArrayList(); + messageList.add(message); + UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest(messageList, null, null, null, null, null, null, null); + + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + OpenAiChatCompletionModel model = createChatCompletionModel("test-url", "organizationId", "api-key", "test-endpoint", null); + + OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model); + + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + String jsonString = Strings.toString(builder); + String expectedJson = """ + { + "messages": [ + { + "content": "Hello, world!", + "role": "user" + } + ], + "model": "test-endpoint", + "n": 1, + "stream": true, + "stream_options": { + "include_usage": true + } + } + """; + assertJsonEquals(jsonString, expectedJson); + } + + public void testSerializationWithAllFields() throws IOException { + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("Hello, world!"), + ROLE, + "name", + "tool_call_id", + Collections.singletonList( + new UnifiedCompletionRequest.ToolCall( + "id", + new UnifiedCompletionRequest.ToolCall.FunctionField("arguments", "function_name"), + "type" + ) + ) + ); + + UnifiedCompletionRequest.Tool tool = new UnifiedCompletionRequest.Tool( + "type", + new UnifiedCompletionRequest.Tool.FunctionField( + "Fetches the weather in the given location", + "get_weather", + createParameters(), + true + ) + ); + var messageList = new ArrayList(); + messageList.add(message); + UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest( + messageList, + "model", + 100L, // maxCompletionTokens + Collections.singletonList("stop"), + 0.9f, // temperature + new 
UnifiedCompletionRequest.ToolChoiceString("tool_choice"), + Collections.singletonList(tool), + 0.8f // topP + ); + + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + + OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null); + + OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model); + + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + String jsonString = Strings.toString(builder); + String expectedJson = """ + { + "messages": [ + { + "content": "Hello, world!", + "role": "user", + "name": "name", + "tool_call_id": "tool_call_id", + "tool_calls": [ + { + "id": "id", + "function": { + "arguments": "arguments", + "name": "function_name" + }, + "type": "type" + } + ] + } + ], + "model": "model-name", + "max_completion_tokens": 100, + "n": 1, + "stop": ["stop"], + "temperature": 0.9, + "tool_choice": "tool_choice", + "tools": [ + { + "type": "type", + "function": { + "description": "Fetches the weather in the given location", + "name": "get_weather", + "parameters": { + "type": "object", + "properties": { + "location": { + "description": "The location to get the weather for", + "type": "string" + }, + "unit": { + "description": "The unit to return the temperature in", + "type": "string", + "enum": ["F", "C"] + } + }, + "additionalProperties": false, + "required": ["location", "unit"] + }, + "strict": true + } + } + ], + "top_p": 0.8, + "stream": true, + "stream_options": { + "include_usage": true + } + } + """; + assertJsonEquals(jsonString, expectedJson); + + } + + public void testSerializationWithNullOptionalFields() throws IOException { + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("Hello, world!"), + ROLE, + null, + null, + null + ); + var messageList = new ArrayList(); + messageList.add(message); + + UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest( + messageList, + null, // model + null, // maxCompletionTokens + null, // stop + null, // temperature + null, // toolChoice + null, // tools + null // topP + ); + + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + + OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null); + + OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model); + + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + String jsonString = Strings.toString(builder); + String expectedJson = """ + { + "messages": [ + { + "content": "Hello, world!", + "role": "user" + } + ], + "model": "model-name", + "n": 1, + "stream": true, + "stream_options": { + "include_usage": true + } + } + """; + assertJsonEquals(jsonString, expectedJson); + } + + public void testSerializationWithEmptyLists() throws IOException { + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("Hello, world!"), + ROLE, + null, + null, + Collections.emptyList() // empty toolCalls list + ); + var messageList = new ArrayList(); + messageList.add(message); + UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest( + messageList, + null, // model + null, // 
maxCompletionTokens + Collections.emptyList(), // empty stop list + null, // temperature + null, // toolChoice + Collections.emptyList(), // empty tools list + null // topP + ); + + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + + OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null); + + OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model); + + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + String jsonString = Strings.toString(builder); + String expectedJson = """ + { + "messages": [ + { + "content": "Hello, world!", + "role": "user", + "tool_calls": [] + } + ], + "model": "model-name", + "n": 1, + "stream": true, + "stream_options": { + "include_usage": true + } + } + """; + assertJsonEquals(jsonString, expectedJson); + } + + public void testSerializationWithNestedObjects() throws IOException { + Random random = Randomness.get(); + + String randomContent = "Hello, world! " + random.nextInt(1000); + String randomName = "name" + random.nextInt(1000); + String randomToolCallId = "tool_call_id" + random.nextInt(1000); + String randomArguments = "arguments" + random.nextInt(1000); + String randomFunctionName = "function_name" + random.nextInt(1000); + String randomType = "type" + random.nextInt(1000); + String randomModel = "model" + random.nextInt(1000); + String randomStop = "stop" + random.nextInt(1000); + float randomTemperature = (float) ((float) Math.round(0.5d + (double) random.nextFloat() * 0.5d * 100000d) / 100000d); + float randomTopP = (float) ((float) Math.round(0.5d + (double) random.nextFloat() * 0.5d * 100000d) / 100000d); + + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString(randomContent), + ROLE, + randomName, + randomToolCallId, + Collections.singletonList( + new UnifiedCompletionRequest.ToolCall( + "id", + new UnifiedCompletionRequest.ToolCall.FunctionField(randomArguments, randomFunctionName), + randomType + ) + ) + ); + + UnifiedCompletionRequest.Tool tool = new UnifiedCompletionRequest.Tool( + randomType, + new UnifiedCompletionRequest.Tool.FunctionField( + "Fetches the weather in the given location", + "get_weather", + createParameters(), + true + ) + ); + var messageList = new ArrayList(); + messageList.add(message); + UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest( + messageList, + randomModel, + 100L, // maxCompletionTokens + Collections.singletonList(randomStop), + randomTemperature, // temperature + new UnifiedCompletionRequest.ToolChoiceObject( + randomType, + new UnifiedCompletionRequest.ToolChoiceObject.FunctionField(randomFunctionName) + ), + Collections.singletonList(tool), + randomTopP // topP + ); + + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + + OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", randomModel, null); + + OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model); + + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + String jsonString = Strings.toString(builder); + String expectedJson = String.format( + Locale.US, + """ + { + "messages": [ + { + "content": "%s", + "role": "user", + "name": 
"%s", + "tool_call_id": "%s", + "tool_calls": [ + { + "id": "id", + "function": { + "arguments": "%s", + "name": "%s" + }, + "type": "%s" + } + ] + } + ], + "model": "%s", + "max_completion_tokens": 100, + "n": 1, + "stop": ["%s"], + "temperature": %.5f, + "tool_choice": { + "type": "%s", + "function": { + "name": "%s" + } + }, + "tools": [ + { + "type": "%s", + "function": { + "description": "Fetches the weather in the given location", + "name": "get_weather", + "parameters": { + "type": "object", + "properties": { + "unit": { + "description": "The unit to return the temperature in", + "type": "string", + "enum": ["F", "C"] + }, + "location": { + "description": "The location to get the weather for", + "type": "string" + } + }, + "additionalProperties": false, + "required": ["location", "unit"] + }, + "strict": true + } + } + ], + "top_p": %.5f, + "stream": true, + "stream_options": { + "include_usage": true + } + } + """, + randomContent, + randomName, + randomToolCallId, + randomArguments, + randomFunctionName, + randomType, + randomModel, + randomStop, + randomTemperature, + randomType, + randomFunctionName, + randomType, + randomTopP + ); + assertJsonEquals(jsonString, expectedJson); + } + + public void testSerializationWithDifferentContentTypes() throws IOException { + Random random = Randomness.get(); + + String randomContentString = "Hello, world! " + random.nextInt(1000); + + String randomText = "Random text " + random.nextInt(1000); + String randomType = "type" + random.nextInt(1000); + UnifiedCompletionRequest.ContentObject contentObject = new UnifiedCompletionRequest.ContentObject(randomText, randomType); + + var contentObjectsList = new ArrayList(); + contentObjectsList.add(contentObject); + UnifiedCompletionRequest.ContentObjects contentObjects = new UnifiedCompletionRequest.ContentObjects(contentObjectsList); + + UnifiedCompletionRequest.Message messageWithString = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString(randomContentString), + ROLE, + null, + null, + null + ); + + UnifiedCompletionRequest.Message messageWithObjects = new UnifiedCompletionRequest.Message(contentObjects, ROLE, null, null, null); + var messageList = new ArrayList(); + messageList.add(messageWithString); + messageList.add(messageWithObjects); + + UnifiedCompletionRequest unifiedRequest = UnifiedCompletionRequest.of(messageList); + + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + + OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null); + + OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model); + + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + String jsonString = Strings.toString(builder); + String expectedJson = String.format(Locale.US, """ + { + "messages": [ + { + "content": "%s", + "role": "user" + }, + { + "content": [ + { + "text": "%s", + "type": "%s" + } + ], + "role": "user" + } + ], + "model": "model-name", + "n": 1, + "stream": true, + "stream_options": { + "include_usage": true + } + } + """, randomContentString, randomText, randomType); + assertJsonEquals(jsonString, expectedJson); + } + + public void testSerializationWithSpecialCharacters() throws IOException { + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("Hello, world! 
\n \"Special\" characters: \t \\ /"), + ROLE, + "name\nwith\nnewlines", + "tool_call_id\twith\ttabs", + Collections.singletonList( + new UnifiedCompletionRequest.ToolCall( + "id\\with\\backslashes", + new UnifiedCompletionRequest.ToolCall.FunctionField("arguments\"with\"quotes", "function_name/with/slashes"), + "type" + ) + ) + ); + var messageList = new ArrayList(); + messageList.add(message); + UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest( + messageList, + null, // model + null, // maxCompletionTokens + null, // stop + null, // temperature + null, // toolChoice + null, // tools + null // topP + ); + + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + + OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null); + + OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model); + + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + String jsonString = Strings.toString(builder); + String expectedJson = """ + { + "messages": [ + { + "content": "Hello, world! \\n \\"Special\\" characters: \\t \\\\ /", + "role": "user", + "name": "name\\nwith\\nnewlines", + "tool_call_id": "tool_call_id\\twith\\ttabs", + "tool_calls": [ + { + "id": "id\\\\with\\\\backslashes", + "function": { + "arguments": "arguments\\"with\\"quotes", + "name": "function_name/with/slashes" + }, + "type": "type" + } + ] + } + ], + "model": "model-name", + "n": 1, + "stream": true, + "stream_options": { + "include_usage": true + } + } + """; + assertJsonEquals(jsonString, expectedJson); + } + + public void testSerializationWithBooleanFields() throws IOException { + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("Hello, world!"), + ROLE, + null, + null, + null + ); + var messageList = new ArrayList(); + messageList.add(message); + UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest( + messageList, + null, // model + null, // maxCompletionTokens + null, // stop + null, // temperature + null, // toolChoice + null, // tools + null // topP + ); + + OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null); + + UnifiedChatInput unifiedChatInputTrue = new UnifiedChatInput(unifiedRequest, true); + OpenAiUnifiedChatCompletionRequestEntity entityTrue = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInputTrue, model); + + XContentBuilder builderTrue = JsonXContent.contentBuilder(); + entityTrue.toXContent(builderTrue, ToXContent.EMPTY_PARAMS); + + String jsonStringTrue = Strings.toString(builderTrue); + String expectedJsonTrue = """ + { + "messages": [ + { + "content": "Hello, world!", + "role": "user" + } + ], + "model": "model-name", + "n": 1, + "stream": true, + "stream_options": { + "include_usage": true + } + } + """; + assertJsonEquals(expectedJsonTrue, jsonStringTrue); + + UnifiedChatInput unifiedChatInputFalse = new UnifiedChatInput(unifiedRequest, false); + OpenAiUnifiedChatCompletionRequestEntity entityFalse = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInputFalse, model); + + XContentBuilder builderFalse = JsonXContent.contentBuilder(); + entityFalse.toXContent(builderFalse, ToXContent.EMPTY_PARAMS); + + String jsonStringFalse = Strings.toString(builderFalse); + String expectedJsonFalse = """ + { + 
"messages": [ + { + "content": "Hello, world!", + "role": "user" + } + ], + "model": "model-name", + "n": 1, + "stream": false + } + """; + assertJsonEquals(expectedJsonFalse, jsonStringFalse); + } + + public void testSerializationWithoutContentField() throws IOException { + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + null, + "assistant", + "name\nwith\nnewlines", + "tool_call_id\twith\ttabs", + Collections.singletonList( + new UnifiedCompletionRequest.ToolCall( + "id\\with\\backslashes", + new UnifiedCompletionRequest.ToolCall.FunctionField("arguments\"with\"quotes", "function_name/with/slashes"), + "type" + ) + ) + ); + var messageList = new ArrayList(); + messageList.add(message); + UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest(messageList, null, null, null, null, null, null, null); + + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + OpenAiChatCompletionModel model = createChatCompletionModel("test-url", "organizationId", "api-key", "test-endpoint", null); + + OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model); + + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + String jsonString = Strings.toString(builder); + String expectedJson = """ + { + "messages": [ + { + "role": "assistant", + "name": "name\\nwith\\nnewlines", + "tool_call_id": "tool_call_id\\twith\\ttabs", + "tool_calls": [ + { + "id": "id\\\\with\\\\backslashes", + "function": { + "arguments": "arguments\\"with\\"quotes", + "name": "function_name/with/slashes" + }, + "type": "type" + } + ] + } + ], + "model": "test-endpoint", + "n": 1, + "stream": true, + "stream_options": { + "include_usage": true + } + } + """; + assertJsonEquals(jsonString, expectedJson); + } + + public static Map createParameters() { + Map parameters = new LinkedHashMap<>(); + parameters.put("type", "object"); + + Map properties = new HashMap<>(); + + Map location = new HashMap<>(); + location.put("type", "string"); + location.put("description", "The location to get the weather for"); + properties.put("location", location); + + Map unit = new HashMap<>(); + unit.put("type", "string"); + unit.put("description", "The unit to return the temperature in"); + unit.put("enum", new String[] { "F", "C" }); + properties.put("unit", unit); + + parameters.put("properties", properties); + parameters.put("additionalProperties", false); + parameters.put("required", new String[] { "location", "unit" }); + + return parameters; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsServiceSettingsTests.java index 1751e1c3be5e8..dd205b12408ba 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsServiceSettingsTests.java @@ -71,7 +71,7 @@ public void testFromMap_InvalidElserModelId() { assertThat(validationException.getMessage(), containsString(Strings.format("unknown ELSER model id [%s]", invalidModelId))); } - public void testToXContent_WritesAlLFields() 
throws IOException { + public void testToXContent_WritesAllFields() throws IOException { var modelId = ElserModels.ELSER_V1_MODEL; var maxInputTokens = 10; var serviceSettings = new ElasticInferenceServiceSparseEmbeddingsServiceSettings(modelId, maxInputTokens, null); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java index 4b2308b9f1565..8a826e99c3c04 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java @@ -114,22 +114,6 @@ public void testParseRequestConfig_CreatesASparseEmbeddingsModel() throws IOExce } } - public void testParseRequestConfig_ThrowsUnsupportedModelType() throws IOException { - try (var service = createServiceWithMockSender()) { - var failureListener = getModelListenerForException( - ElasticsearchStatusException.class, - "The [elastic] service does not support task type [completion]" - ); - - service.parseRequestConfig( - "id", - TaskType.COMPLETION, - getRequestConfigMap(Map.of(ServiceFields.MODEL_ID, ElserModels.ELSER_V2_MODEL), Map.of(), Map.of()), - failureListener - ); - } - } - public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws IOException { try (var service = createServiceWithMockSender()) { var config = getRequestConfigMap(Map.of(ServiceFields.MODEL_ID, ElserModels.ELSER_V2_MODEL), Map.of(), Map.of()); @@ -498,7 +482,7 @@ public void testGetConfiguration() throws Exception { { "service": "elastic", "name": "Elastic", - "task_types": ["sparse_embedding"], + "task_types": ["sparse_embedding" , "completion"], "configurations": { "rate_limit.requests_per_minute": { "description": "Minimize the number of rate limit errors.", diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModelTests.java new file mode 100644 index 0000000000000..cc1463232e7e5 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModelTests.java @@ -0,0 +1,51 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModelTests.java
new file mode 100644
index 0000000000000..cc1463232e7e5
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModelTests.java
@@ -0,0 +1,51 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.elastic.completion;
+
+import org.elasticsearch.inference.EmptySecretSettings;
+import org.elasticsearch.inference.EmptyTaskSettings;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.inference.UnifiedCompletionRequest;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents;
+import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
+
+import java.util.List;
+
+import static org.hamcrest.Matchers.is;
+
+public class ElasticInferenceServiceCompletionModelTests extends ESTestCase {
+
+    public void testOverridingModelId() {
+        var originalModel = new ElasticInferenceServiceCompletionModel(
+            "id",
+            TaskType.COMPLETION,
+            "elastic",
+            new ElasticInferenceServiceCompletionServiceSettings("model_id", new RateLimitSettings(100)),
+            EmptyTaskSettings.INSTANCE,
+            EmptySecretSettings.INSTANCE,
+            new ElasticInferenceServiceComponents("url")
+        );
+
+        var request = new UnifiedCompletionRequest(
+            List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("message"), "user", null, null, null)),
+            "new_model_id",
+            null,
+            null,
+            null,
+            null,
+            null,
+            null
+        );
+
+        var overriddenModel = ElasticInferenceServiceCompletionModel.of(originalModel, request);
+
+        assertThat(overriddenModel.getServiceSettings().modelId(), is("new_model_id"));
+        assertThat(overriddenModel.getTaskType(), is(TaskType.COMPLETION));
+    }
+}
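testOverridingModelId exercises the case where the unified request names a model. A companion case, sketched below under the assumption that ElasticInferenceServiceCompletionModel.of keeps the original model_id when the request leaves the model unset (behavior not shown in this excerpt), would look roughly like this:

    // Hedged sketch, not part of this diff: assumes ElasticInferenceServiceCompletionModel.of
    // falls back to the original model_id when the request does not specify a model.
    public void testOverridingModelId_KeepsOriginalWhenRequestHasNoModel_Sketch() {
        var originalModel = new ElasticInferenceServiceCompletionModel(
            "id",
            TaskType.COMPLETION,
            "elastic",
            new ElasticInferenceServiceCompletionServiceSettings("model_id", new RateLimitSettings(100)),
            EmptyTaskSettings.INSTANCE,
            EmptySecretSettings.INSTANCE,
            new ElasticInferenceServiceComponents("url")
        );

        var request = new UnifiedCompletionRequest(
            List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("message"), "user", null, null, null)),
            null, // no model override in the request
            null,
            null,
            null,
            null,
            null,
            null
        );

        var overriddenModel = ElasticInferenceServiceCompletionModel.of(originalModel, request);

        assertThat(overriddenModel.getServiceSettings().modelId(), is("model_id"));
    }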
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionServiceSettingsTests.java
new file mode 100644
index 0000000000000..0f6386f670338
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionServiceSettingsTests.java
@@ -0,0 +1,83 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.elastic.completion;
+
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
+import org.elasticsearch.xpack.inference.services.ServiceFields;
+import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
+import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.is;
+
+public class ElasticInferenceServiceCompletionServiceSettingsTests extends AbstractWireSerializingTestCase<
+    ElasticInferenceServiceCompletionServiceSettings> {
+
+    @Override
+    protected Writeable.Reader<ElasticInferenceServiceCompletionServiceSettings> instanceReader() {
+        return ElasticInferenceServiceCompletionServiceSettings::new;
+    }
+
+    @Override
+    protected ElasticInferenceServiceCompletionServiceSettings createTestInstance() {
+        return createRandom();
+    }
+
+    @Override
+    protected ElasticInferenceServiceCompletionServiceSettings mutateInstance(ElasticInferenceServiceCompletionServiceSettings instance)
+        throws IOException {
+        return randomValueOtherThan(instance, ElasticInferenceServiceCompletionServiceSettingsTests::createRandom);
+    }
+
+    public void testFromMap() {
+        var modelId = "model_id";
+
+        var serviceSettings = ElasticInferenceServiceCompletionServiceSettings.fromMap(
+            new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId)),
+            ConfigurationParseContext.REQUEST
+        );
+
+        assertThat(serviceSettings, is(new ElasticInferenceServiceCompletionServiceSettings(modelId, new RateLimitSettings(240L))));
+    }
+
+    public void testFromMap_MissingModelId_ThrowsException() {
+        ValidationException validationException = expectThrows(
+            ValidationException.class,
+            () -> ElasticInferenceServiceCompletionServiceSettings.fromMap(new HashMap<>(Map.of()), ConfigurationParseContext.REQUEST)
+        );
+
+        assertThat(validationException.getMessage(), containsString("does not contain the required setting [model_id]"));
+    }
+
+    public void testToXContent_WritesAllFields() throws IOException {
+        var modelId = "model_id";
+        var serviceSettings = new ElasticInferenceServiceCompletionServiceSettings(modelId, new RateLimitSettings(1000));
+
+        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
+        serviceSettings.toXContent(builder, null);
+        String xContentResult = Strings.toString(builder);
+
+        assertThat(xContentResult, is(Strings.format("""
+            {"model_id":"%s","rate_limit":{"requests_per_minute":1000}}""", modelId)));
+    }
+
+    public static ElasticInferenceServiceCompletionServiceSettings createRandom() {
+        return new ElasticInferenceServiceCompletionServiceSettings(randomAlphaOfLength(4), RateLimitSettingsTests.createRandom());
+    }
+}
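testFromMap above relies on the service's default rate limit of 240 requests per minute when no rate_limit is supplied. For illustration, a hedged sketch of how those default-backed settings would serialize, mirroring testToXContent_WritesAllFields; the rendered output is inferred from the two tests above rather than from the production code.

    // Hedged sketch, not part of this diff: settings parsed without an explicit rate_limit
    // (default of 240 requests per minute, per testFromMap) are assumed to serialize like this.
    public void testToXContent_WithDefaultRateLimit_Sketch() throws IOException {
        var serviceSettings = new ElasticInferenceServiceCompletionServiceSettings("model_id", new RateLimitSettings(240));

        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
        serviceSettings.toXContent(builder, null);

        assertThat(Strings.toString(builder), is("""
            {"model_id":"model_id","rate_limit":{"requests_per_minute":240}}"""));
    }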