[Elastic Inference Service] Add ElasticInferenceService Unified ChatCompletions Integration (#118871)
jaybcee authored Jan 8, 2025
1 parent b05ab7a commit 18345c4
Showing 29 changed files with 1,814 additions and 960 deletions.
docs/changelog/118871.yaml: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
pr: 118871
summary: "[Elastic Inference Service] Add ElasticInferenceService Unified ChatCompletions Integration"
area: Inference
type: enhancement
issues: []
@@ -154,6 +154,7 @@ static TransportVersion def(int id) {
public static final TransportVersion JINA_AI_INTEGRATION_ADDED = def(8_819_00_0);
public static final TransportVersion TRACK_INDEX_FAILED_DUE_TO_VERSION_CONFLICT_METRIC = def(8_820_00_0);
public static final TransportVersion REPLACE_FAILURE_STORE_OPTIONS_WITH_SELECTOR_SYNTAX = def(8_821_00_0);
public static final TransportVersion ELASTIC_INFERENCE_SERVICE_UNIFIED_CHAT_COMPLETIONS_INTEGRATION = def(8_822_00_0);

/*
* STOP! READ THIS FIRST! No, really,
@@ -111,9 +111,13 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalFloat(topP);
}

public record Message(Content content, String role, @Nullable String name, @Nullable String toolCallId, List<ToolCall> toolCalls)
implements
Writeable {
public record Message(
Content content,
String role,
@Nullable String name,
@Nullable String toolCallId,
@Nullable List<ToolCall> toolCalls
) implements Writeable {

@SuppressWarnings("unchecked")
static final ConstructingObjectParser<Message, Void> PARSER = new ConstructingObjectParser<>(
@@ -242,7 +242,7 @@ public void testGetServicesWithRerankTaskType() throws IOException {
@SuppressWarnings("unchecked")
public void testGetServicesWithCompletionTaskType() throws IOException {
List<Object> services = getServices(TaskType.COMPLETION);
assertThat(services.size(), equalTo(9));
assertThat(services.size(), equalTo(10));

String[] providers = new String[services.size()];
for (int i = 0; i < services.size(); i++) {
@@ -259,6 +259,7 @@ public void testGetServicesWithCompletionTaskType() throws IOException {
"azureaistudio",
"azureopenai",
"cohere",
"elastic",
"googleaistudio",
"openai",
"streaming_completion_test_service"
@@ -0,0 +1,40 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.elastic;

import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
import org.elasticsearch.xpack.inference.external.openai.OpenAiUnifiedStreamingProcessor;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventParser;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventProcessor;

import java.util.concurrent.Flow;

public class ElasticInferenceServiceUnifiedChatCompletionResponseHandler extends ElasticInferenceServiceResponseHandler {
public ElasticInferenceServiceUnifiedChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) {
super(requestType, parseFunction);
}

@Override
public boolean canHandleStreamingResponses() {
return true;
}

@Override
public InferenceServiceResults parseResult(Request request, Flow.Publisher<HttpResult> flow) {
var serverSentEventProcessor = new ServerSentEventProcessor(new ServerSentEventParser());
var openAiProcessor = new OpenAiUnifiedStreamingProcessor(); // EIS uses the unified API spec

flow.subscribe(serverSentEventProcessor);
serverSentEventProcessor.subscribe(openAiProcessor);
return new StreamingUnifiedChatCompletionResults(openAiProcessor);
}
}
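
As an aside, the two subscribe calls in parseResult chain java.util.concurrent.Flow stages: the raw HTTP byte flow feeds the server-sent-event parser, which feeds the OpenAI-style streaming processor. Below is a minimal, self-contained sketch of that chaining pattern using only JDK classes (none of the Elasticsearch types above); the upper-casing processor is just a stand-in for the real SSE and chat-completion stages.

import java.util.concurrent.Flow;
import java.util.concurrent.SubmissionPublisher;

public class FlowChainSketch {

    // Stand-in for a processing stage: receives items, transforms them, and republishes
    // them to its own subscribers.
    static final class UpperCaseProcessor extends SubmissionPublisher<String> implements Flow.Processor<String, String> {
        private Flow.Subscription subscription;

        @Override
        public void onSubscribe(Flow.Subscription subscription) {
            this.subscription = subscription;
            subscription.request(1);
        }

        @Override
        public void onNext(String item) {
            submit(item.toUpperCase()); // transform and forward downstream
            subscription.request(1);
        }

        @Override
        public void onError(Throwable throwable) {
            closeExceptionally(throwable);
        }

        @Override
        public void onComplete() {
            close();
        }
    }

    public static void main(String[] args) throws Exception {
        var upstream = new SubmissionPublisher<String>();
        var processor = new UpperCaseProcessor();

        upstream.subscribe(processor);                     // analogous to flow.subscribe(serverSentEventProcessor)
        var done = processor.consume(System.out::println); // analogous to serverSentEventProcessor.subscribe(openAiProcessor)

        upstream.submit("data: chunk-1");
        upstream.submit("data: chunk-2");
        upstream.close(); // completion propagates down the chain

        done.get(); // wait for the terminal consumer before exiting
    }
}
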
@@ -0,0 +1,82 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.http.sender;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.external.elastic.ElasticInferenceServiceUnifiedChatCompletionResponseHandler;
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.request.elastic.ElasticInferenceServiceUnifiedChatCompletionRequest;
import org.elasticsearch.xpack.inference.external.response.openai.OpenAiChatCompletionResponseEntity;
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel;
import org.elasticsearch.xpack.inference.telemetry.TraceContext;

import java.util.Objects;
import java.util.function.Supplier;

public class ElasticInferenceServiceUnifiedCompletionRequestManager extends ElasticInferenceServiceRequestManager {

private static final Logger logger = LogManager.getLogger(ElasticInferenceServiceUnifiedCompletionRequestManager.class);

private static final ResponseHandler HANDLER = createCompletionHandler();

public static ElasticInferenceServiceUnifiedCompletionRequestManager of(
ElasticInferenceServiceCompletionModel model,
ThreadPool threadPool,
TraceContext traceContext
) {
return new ElasticInferenceServiceUnifiedCompletionRequestManager(
Objects.requireNonNull(model),
Objects.requireNonNull(threadPool),
Objects.requireNonNull(traceContext)
);
}

private final ElasticInferenceServiceCompletionModel model;
private final TraceContext traceContext;

private ElasticInferenceServiceUnifiedCompletionRequestManager(
ElasticInferenceServiceCompletionModel model,
ThreadPool threadPool,
TraceContext traceContext
) {
super(threadPool, model);
this.model = model;
this.traceContext = traceContext;
}

@Override
public void execute(
InferenceInputs inferenceInputs,
RequestSender requestSender,
Supplier<Boolean> hasRequestCompletedFunction,
ActionListener<InferenceServiceResults> listener
) {

ElasticInferenceServiceUnifiedChatCompletionRequest request = new ElasticInferenceServiceUnifiedChatCompletionRequest(
inferenceInputs.castTo(UnifiedChatInput.class),
model,
traceContext
);

execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
}

private static ResponseHandler createCompletionHandler() {
return new ElasticInferenceServiceUnifiedChatCompletionResponseHandler(
"elastic inference service completion",
// We use OpenAiChatCompletionResponseEntity here as the ElasticInferenceServiceResponseEntity fields are a subset of the OpenAI
// one.
OpenAiChatCompletionResponseEntity::fromResponse
);
}
}
@@ -12,13 +12,13 @@
import org.apache.http.entity.ByteArrayEntity;
import org.apache.http.message.BasicHeader;
import org.elasticsearch.common.Strings;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.common.Truncator;
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsModel;
import org.elasticsearch.xpack.inference.telemetry.TraceContext;
import org.elasticsearch.xpack.inference.telemetry.TraceContextHandler;

import java.net.URI;
import java.nio.charset.StandardCharsets;
@@ -27,13 +27,10 @@
public class ElasticInferenceServiceSparseEmbeddingsRequest implements ElasticInferenceServiceRequest {

private final URI uri;

private final ElasticInferenceServiceSparseEmbeddingsModel model;

private final Truncator.TruncationResult truncationResult;
private final Truncator truncator;

private final TraceContext traceContext;
private final TraceContextHandler traceContextHandler;

public ElasticInferenceServiceSparseEmbeddingsRequest(
Truncator truncator,
@@ -45,7 +42,7 @@ public ElasticInferenceServiceSparseEmbeddingsRequest(
this.truncationResult = truncationResult;
this.model = Objects.requireNonNull(model);
this.uri = model.uri();
this.traceContext = traceContext;
this.traceContextHandler = new TraceContextHandler(traceContext);
}

@Override
@@ -56,15 +53,16 @@ public HttpRequest createHttpRequest() {
ByteArrayEntity byteEntity = new ByteArrayEntity(requestEntity.getBytes(StandardCharsets.UTF_8));
httpPost.setEntity(byteEntity);

if (traceContext != null) {
propagateTraceContext(httpPost);
}

traceContextHandler.propagateTraceContext(httpPost);
httpPost.setHeader(new BasicHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()));

return new HttpRequest(httpPost, getInferenceEntityId());
}

public TraceContext getTraceContext() {
return traceContextHandler.traceContext();
}

@Override
public String getInferenceEntityId() {
return model.getInferenceEntityId();
@@ -75,32 +73,15 @@ public URI getURI() {
return this.uri;
}

public TraceContext getTraceContext() {
return traceContext;
}

@Override
public Request truncate() {
var truncatedInput = truncator.truncate(truncationResult.input());

return new ElasticInferenceServiceSparseEmbeddingsRequest(truncator, truncatedInput, model, traceContext);
return new ElasticInferenceServiceSparseEmbeddingsRequest(truncator, truncatedInput, model, traceContextHandler.traceContext());
}

@Override
public boolean[] getTruncationInfo() {
return truncationResult.truncated().clone();
}

private void propagateTraceContext(HttpPost httpPost) {
var traceParent = traceContext.traceParent();
var traceState = traceContext.traceState();

if (traceParent != null) {
httpPost.setHeader(Task.TRACE_PARENT_HTTP_HEADER, traceParent);
}

if (traceState != null) {
httpPost.setHeader(Task.TRACE_STATE, traceState);
}
}
}
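
The refactor above replaces the request's own null checks and propagateTraceContext method with a shared TraceContextHandler. That handler is not part of this excerpt, so the following is only a sketch of what it could look like, inferred from the removed code and from the traceContext() / propagateTraceContext(httpPost) calls the new version makes; treat the names and structure as assumptions, not the actual implementation added by the PR.

package org.elasticsearch.xpack.inference.telemetry;

import org.apache.http.client.methods.HttpPost;
import org.elasticsearch.tasks.Task;

// Hypothetical sketch of the consolidated trace-propagation helper.
public record TraceContextHandler(TraceContext traceContext) {

    // Copies the trace headers onto the outgoing request, skipping anything unset, so callers
    // no longer need their own null checks around the TraceContext.
    public void propagateTraceContext(HttpPost httpPost) {
        if (traceContext == null) {
            return;
        }

        if (traceContext.traceParent() != null) {
            httpPost.setHeader(Task.TRACE_PARENT_HTTP_HEADER, traceContext.traceParent());
        }

        if (traceContext.traceState() != null) {
            httpPost.setHeader(Task.TRACE_STATE, traceContext.traceState());
        }
    }
}
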
@@ -0,0 +1,85 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.request.elastic;

import org.apache.http.HttpHeaders;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.entity.ByteArrayEntity;
import org.apache.http.message.BasicHeader;
import org.elasticsearch.common.Strings;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel;
import org.elasticsearch.xpack.inference.telemetry.TraceContext;
import org.elasticsearch.xpack.inference.telemetry.TraceContextHandler;

import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.util.Objects;

public class ElasticInferenceServiceUnifiedChatCompletionRequest implements Request {

private final ElasticInferenceServiceCompletionModel model;
private final UnifiedChatInput unifiedChatInput;
private final TraceContextHandler traceContextHandler;

public ElasticInferenceServiceUnifiedChatCompletionRequest(
UnifiedChatInput unifiedChatInput,
ElasticInferenceServiceCompletionModel model,
TraceContext traceContext
) {
this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput);
this.model = Objects.requireNonNull(model);
this.traceContextHandler = new TraceContextHandler(traceContext);
}

@Override
public HttpRequest createHttpRequest() {
var httpPost = new HttpPost(model.uri());
var requestEntity = Strings.toString(
new ElasticInferenceServiceUnifiedChatCompletionRequestEntity(unifiedChatInput, model.getServiceSettings().modelId())
);

ByteArrayEntity byteEntity = new ByteArrayEntity(requestEntity.getBytes(StandardCharsets.UTF_8));
httpPost.setEntity(byteEntity);

traceContextHandler.propagateTraceContext(httpPost);
httpPost.setHeader(new BasicHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()));

return new HttpRequest(httpPost, getInferenceEntityId());
}

@Override
public URI getURI() {
return model.uri();
}

@Override
public Request truncate() {
// No truncation
return this;
}

@Override
public boolean[] getTruncationInfo() {
// No truncation
return null;
}

@Override
public String getInferenceEntityId() {
return model.getInferenceEntityId();
}

@Override
public boolean isStreaming() {
return true;
}
}
@@ -0,0 +1,38 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.request.elastic;

import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.external.unified.UnifiedChatCompletionRequestEntity;

import java.io.IOException;
import java.util.Objects;

public class ElasticInferenceServiceUnifiedChatCompletionRequestEntity implements ToXContentObject {
private static final String MODEL_FIELD = "model";

private final UnifiedChatCompletionRequestEntity unifiedRequestEntity;
private final String modelId;

public ElasticInferenceServiceUnifiedChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, String modelId) {
this.unifiedRequestEntity = new UnifiedChatCompletionRequestEntity(Objects.requireNonNull(unifiedChatInput));
this.modelId = Objects.requireNonNull(modelId);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
unifiedRequestEntity.toXContent(builder, params);
builder.field(MODEL_FIELD, modelId);
builder.endObject();

return builder;
}
}
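
For orientation, the entity above only wraps the shared unified body and appends a top-level model field. Assuming the wrapped UnifiedChatCompletionRequestEntity emits OpenAI-style chat fields such as messages and stream (their exact names are not shown in this diff), the serialized request would look roughly like the string in this small standalone example.

public class ExampleUnifiedRequestBody {
    public static void main(String[] args) {
        // Illustrative only: the inner fields are assumed to follow the OpenAI chat completions
        // shape; the "model" field is the part appended by
        // ElasticInferenceServiceUnifiedChatCompletionRequestEntity itself.
        String exampleBody = """
            {
              "messages": [
                { "role": "user", "content": "Hello" }
              ],
              "stream": true,
              "model": "my-model-id"
            }
            """;
        System.out.println(exampleBody);
    }
}
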
@@ -27,7 +27,7 @@
import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader;
import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUtils.createOrgHeader;

public class OpenAiEmbeddingsRequest implements OpenAiRequest {
public class OpenAiEmbeddingsRequest implements Request {

private final Truncator truncator;
private final OpenAiAccount account;