
[Elastic Inference Service] Add ElasticInferenceService Unified ChatCompletions Integration #118871

Merged Jan 8, 2025 · 67 commits
ccec39b
Starting completion model
jonathan-buttner Dec 6, 2024
467747f
Adding model
jonathan-buttner Dec 6, 2024
69ba46d
initial implementation of request and response handling, manager, and…
maxhniebergall Dec 9, 2024
39e2c27
Working response from openai
jonathan-buttner Dec 9, 2024
7984b69
Update docs/changelog/118301.yaml
jonathan-buttner Dec 9, 2024
be588f4
Fixing comment
jonathan-buttner Dec 10, 2024
38a58f9
Adding some initial tests
jonathan-buttner Dec 11, 2024
cad6f1e
Merge branch 'main' of github.com:elastic/elasticsearch into ml-eis-i…
jonathan-buttner Dec 11, 2024
2e4fb05
Moving tests around
jonathan-buttner Dec 11, 2024
1c0ab90
Merge branch 'main' into ml-eis-integration
jaybcee Dec 17, 2024
1abe7e6
Address some TODOs
jaybcee Dec 18, 2024
8e47f34
Merge branch 'main' into ml-eis-integration-jbc
jaybcee Dec 18, 2024
9840c62
Merge branch 'main' into ml-eis-integration-jbc
jaybcee Dec 18, 2024
cfd7580
Remove a TODO
jaybcee Dec 18, 2024
4fb6930
[CI] Auto commit changes from spotless
elasticsearchmachine Dec 18, 2024
6a3d916
Merge branch 'main' into ml-eis-integration-jbc
jaybcee Dec 18, 2024
2730017
Fix tests
jaybcee Dec 18, 2024
729be3d
Merge branch 'main' into ml-eis-integration-jbc
jaybcee Dec 19, 2024
ab979a1
Fix more tests
jaybcee Dec 19, 2024
2ce281f
Merge branch 'main' into ml-eis-integration-jbc
jaybcee Dec 19, 2024
775c81b
Merge branch 'main' into ml-eis-integration-jbc
jaybcee Dec 19, 2024
c609559
Update docs/changelog/118871.yaml
jaybcee Dec 19, 2024
5704a29
Update docs/changelog/118871.yaml
jaybcee Dec 19, 2024
42c96a2
Delete docs/changelog/118301.yaml
jaybcee Dec 19, 2024
34073db
Update docs/changelog/118871.yaml
jaybcee Dec 19, 2024
b03ee46
Update x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/…
jaybcee Dec 19, 2024
dfa0a05
Rename EISUnifiedChatCompletionResponseHandler
jaybcee Dec 19, 2024
566bacc
Renames to ElasticInferenceServiceUnifiedCompletionRequestManager
jaybcee Dec 19, 2024
3d888b1
Remove eis from response
jaybcee Dec 19, 2024
80db7d3
Renames EISUnifiedChatCompletionRequest
jaybcee Dec 19, 2024
df74a82
Renames and comments
jaybcee Dec 19, 2024
d7dbf61
Adds n=1 hardcode comment
jaybcee Dec 19, 2024
1f96e4a
Fixes
jaybcee Dec 19, 2024
03a139a
Renames tool
jaybcee Dec 19, 2024
e90e693
Updates transport
jaybcee Dec 19, 2024
26f1ac3
propagateTraceContext extraction
jaybcee Dec 19, 2024
1d8d641
format
jaybcee Dec 19, 2024
3399d61
Merge branch 'main' into ml-eis-integration-jbc
jaybcee Dec 19, 2024
fcfbfd3
[CI] Auto commit changes from spotless
elasticsearchmachine Dec 19, 2024
9b5503a
Clean up trace
jaybcee Dec 19, 2024
2a3faa4
[CI] Auto commit changes from spotless
elasticsearchmachine Dec 19, 2024
bd715c1
Merge branch 'main' into ml-eis-integration-jbc
maxhniebergall Dec 23, 2024
67d21f8
finish merge
maxhniebergall Dec 23, 2024
0feba86
Remove OpenAiRequest as it was uneeded
maxhniebergall Dec 23, 2024
93cc995
Address comments
jaybcee Dec 23, 2024
9dd88bf
[CI] Auto commit changes from spotless
elasticsearchmachine Dec 23, 2024
13c0bdf
Merge branch 'main' into ml-eis-integration-jbc
jaybcee Jan 3, 2025
635f734
[CI] Auto commit changes from spotless
elasticsearchmachine Jan 3, 2025
deab545
Merge branch 'main' into ml-eis-integration-jbc
jaybcee Jan 6, 2025
a089586
Merge branch 'main' into ml-eis-integration-jbc
jaybcee Jan 6, 2025
ce4cdf0
Merge branch 'main' into ml-eis-integration-jbc
jaybcee Jan 6, 2025
daaa0e1
Merge branch 'main' into ml-eis-integration-jbc
jaybcee Jan 6, 2025
7fd0883
Update x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/…
jaybcee Jan 7, 2025
f84b1da
Update x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/…
jaybcee Jan 7, 2025
eefcecf
Update x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/…
jaybcee Jan 7, 2025
567c54e
Update x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/…
jaybcee Jan 7, 2025
01b60d6
Update x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/…
jaybcee Jan 7, 2025
503fb4b
[CI] Auto commit changes from spotless
elasticsearchmachine Jan 7, 2025
44f9f8d
Merge branch 'main' into ml-eis-integration-jbc
jaybcee Jan 7, 2025
0419ab9
Address comments
jaybcee Jan 7, 2025
e9afd72
[CI] Auto commit changes from spotless
elasticsearchmachine Jan 7, 2025
f105279
Update changelog
jaybcee Jan 7, 2025
a9dc184
Better error messages
jaybcee Jan 7, 2025
343519d
Merge branch 'main' into ml-eis-integration-jbc
jaybcee Jan 7, 2025
7c8b1b3
Merge branch 'main' into ml-eis-integration-jbc
jaybcee Jan 8, 2025
f48eb06
Merge branch 'main' into ml-eis-integration-jbc
jaybcee Jan 8, 2025
e014357
[CI] Auto commit changes from spotless
elasticsearchmachine Jan 8, 2025
5 changes: 5 additions & 0 deletions docs/changelog/118871.yaml
@@ -0,0 +1,5 @@
pr: 118871
summary: "[Elastic Inference Service] Add ElasticInferenceService Unified ChatCompletions Integration"
area: Inference
type: enhancement
issues: []
@@ -154,6 +154,7 @@ static TransportVersion def(int id) {
public static final TransportVersion JINA_AI_INTEGRATION_ADDED = def(8_819_00_0);
public static final TransportVersion TRACK_INDEX_FAILED_DUE_TO_VERSION_CONFLICT_METRIC = def(8_820_00_0);
public static final TransportVersion REPLACE_FAILURE_STORE_OPTIONS_WITH_SELECTOR_SYNTAX = def(8_821_00_0);
public static final TransportVersion ELASTIC_INFERENCE_SERVICE_UNIFIED_CHAT_COMPLETIONS_INTEGRATION = def(8_822_00_0);

/*
* STOP! READ THIS FIRST! No, really,
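The new constant's id, `8_822_00_0`, follows the same underscore-grouped pattern as its neighbours (`8_819_00_0` through `8_821_00_0`). As a hedged illustration only — the grouping below is inferred from how the literals are formatted, not from the authoritative scheme documented in `TransportVersions`, and `TransportVersionIdSketch` is a hypothetical helper — the digits can be split like this:

```java
// Hypothetical helper splitting an underscore-grouped transport version id
// into its apparent numeric fields. The field names are assumptions based on
// the literal's formatting (lead digit, counter, two-digit group, final digit).
public class TransportVersionIdSketch {
    record Parts(int lead, int counter, int subsidiary, int patch) {}

    static Parts split(int id) {
        int patch = id % 10;                 // trailing _0
        int subsidiary = (id / 10) % 100;    // middle _00_ group
        int counter = (id / 1000) % 1000;    // the incrementing 819..822 group
        int lead = id / 1_000_000;           // leading 8
        return new Parts(lead, counter, subsidiary, patch);
    }

    public static void main(String[] args) {
        System.out.println(split(8_822_00_0)); // Parts[lead=8, counter=822, subsidiary=0, patch=0]
    }
}
```

Read this way, the PR simply claims the next free counter value (822) after the previously registered constants.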
@@ -111,9 +111,13 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalFloat(topP);
}

public record Message(Content content, String role, @Nullable String name, @Nullable String toolCallId, List<ToolCall> toolCalls)
implements
Writeable {
public record Message(
Content content,
String role,
@Nullable String name,
@Nullable String toolCallId,
@Nullable List<ToolCall> toolCalls
) implements Writeable {

@SuppressWarnings("unchecked")
static final ConstructingObjectParser<Message, Void> PARSER = new ConstructingObjectParser<>(
@@ -242,7 +242,7 @@ public void testGetServicesWithRerankTaskType() throws IOException {
@SuppressWarnings("unchecked")
public void testGetServicesWithCompletionTaskType() throws IOException {
List<Object> services = getServices(TaskType.COMPLETION);
assertThat(services.size(), equalTo(9));
assertThat(services.size(), equalTo(10));

String[] providers = new String[services.size()];
for (int i = 0; i < services.size(); i++) {
@@ -259,6 +259,7 @@ public void testGetServicesWithCompletionTaskType() throws IOException {
"azureaistudio",
"azureopenai",
"cohere",
"elastic",
"googleaistudio",
"openai",
"streaming_completion_test_service"
@@ -0,0 +1,40 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.elastic;

import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
import org.elasticsearch.xpack.inference.external.openai.OpenAiUnifiedStreamingProcessor;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventParser;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventProcessor;

import java.util.concurrent.Flow;

public class ElasticInferenceServiceUnifiedChatCompletionResponseHandler extends ElasticInferenceServiceResponseHandler {
public ElasticInferenceServiceUnifiedChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) {
super(requestType, parseFunction);
}

@Override
public boolean canHandleStreamingResponses() {
return true;
}

@Override
public InferenceServiceResults parseResult(Request request, Flow.Publisher<HttpResult> flow) {
var serverSentEventProcessor = new ServerSentEventProcessor(new ServerSentEventParser());
var openAiProcessor = new OpenAiUnifiedStreamingProcessor(); // EIS uses the unified API spec

flow.subscribe(serverSentEventProcessor);
serverSentEventProcessor.subscribe(openAiProcessor);
return new StreamingUnifiedChatCompletionResults(openAiProcessor);
}
}
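The `parseResult` override above wires a two-stage `java.util.concurrent.Flow` pipeline: the raw HTTP result stream feeds the server-sent-event processor, which in turn feeds the OpenAI-format streaming processor. The chaining pattern can be sketched self-contained with a trivial processor standing in for both real stages (`FlowChainDemo` and `TrimProcessor` are illustrative names, not classes from this PR):

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Flow;
import java.util.concurrent.SubmissionPublisher;

// Stand-in processor: re-publishes each chunk trimmed, the way the SSE
// processor re-publishes parsed events to the OpenAI processor.
class TrimProcessor extends SubmissionPublisher<String> implements Flow.Processor<String, String> {
    @Override public void onSubscribe(Flow.Subscription s) { s.request(Long.MAX_VALUE); }
    @Override public void onNext(String item) { submit(item.trim()); }
    @Override public void onError(Throwable t) { closeExceptionally(t); }
    @Override public void onComplete() { close(); }
}

public class FlowChainDemo {
    // Chain source -> processor -> collecting subscriber, mirroring
    // flow.subscribe(sseProcessor); sseProcessor.subscribe(openAiProcessor).
    static List<String> run(List<String> chunks) throws InterruptedException {
        var source = new SubmissionPublisher<String>();
        var processor = new TrimProcessor();
        var received = new ArrayList<String>();
        var done = new CountDownLatch(1);

        source.subscribe(processor);
        processor.subscribe(new Flow.Subscriber<String>() {
            @Override public void onSubscribe(Flow.Subscription s) { s.request(Long.MAX_VALUE); }
            @Override public void onNext(String item) { received.add(item); }
            @Override public void onError(Throwable t) { done.countDown(); }
            @Override public void onComplete() { done.countDown(); }
        });

        chunks.forEach(source::submit);
        source.close();          // completion propagates down the chain
        done.await();
        return received;
    }

    public static void main(String[] args) throws InterruptedException {
        System.out.println(run(List.of("  data: hello  ", "data: [DONE] ")));
    }
}
```

The real handler returns the downstream processor wrapped in `StreamingUnifiedChatCompletionResults`, so consumers subscribe to already-parsed completion chunks rather than raw bytes.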
@@ -0,0 +1,82 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.http.sender;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.external.elastic.ElasticInferenceServiceUnifiedChatCompletionResponseHandler;
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.request.elastic.ElasticInferenceServiceUnifiedChatCompletionRequest;
import org.elasticsearch.xpack.inference.external.response.openai.OpenAiChatCompletionResponseEntity;
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel;
import org.elasticsearch.xpack.inference.telemetry.TraceContext;

import java.util.Objects;
import java.util.function.Supplier;

public class ElasticInferenceServiceUnifiedCompletionRequestManager extends ElasticInferenceServiceRequestManager {

private static final Logger logger = LogManager.getLogger(ElasticInferenceServiceUnifiedCompletionRequestManager.class);

private static final ResponseHandler HANDLER = createCompletionHandler();

public static ElasticInferenceServiceUnifiedCompletionRequestManager of(
ElasticInferenceServiceCompletionModel model,
ThreadPool threadPool,
TraceContext traceContext
) {
return new ElasticInferenceServiceUnifiedCompletionRequestManager(
Objects.requireNonNull(model),
Objects.requireNonNull(threadPool),
Objects.requireNonNull(traceContext)
);
}

private final ElasticInferenceServiceCompletionModel model;
private final TraceContext traceContext;

private ElasticInferenceServiceUnifiedCompletionRequestManager(
ElasticInferenceServiceCompletionModel model,
ThreadPool threadPool,
TraceContext traceContext
) {
super(threadPool, model);
this.model = model;
this.traceContext = traceContext;
}

@Override
public void execute(
InferenceInputs inferenceInputs,
RequestSender requestSender,
Supplier<Boolean> hasRequestCompletedFunction,
ActionListener<InferenceServiceResults> listener
) {

ElasticInferenceServiceUnifiedChatCompletionRequest request = new ElasticInferenceServiceUnifiedChatCompletionRequest(
inferenceInputs.castTo(UnifiedChatInput.class),
model,
traceContext
);

execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
}

private static ResponseHandler createCompletionHandler() {
return new ElasticInferenceServiceUnifiedChatCompletionResponseHandler(
"elastic inference service completion",
Review comment (Member): Is this string with spaces in it correct? Seems like a bit of a weird value. Normally I think our non-error-message strings use underscores instead of spaces. Definitely just a nit though.

Reply (Member Author): The OpenAI one does the same thing. Not sure what's best, but I think they should be consistent. I'll keep this for now.

// We use OpenAiChatCompletionResponseEntity here as the ElasticInferenceServiceResponseEntity fields are a subset of the OpenAI
// one.
OpenAiChatCompletionResponseEntity::fromResponse
);
}
}
@@ -12,13 +12,13 @@
import org.apache.http.entity.ByteArrayEntity;
import org.apache.http.message.BasicHeader;
import org.elasticsearch.common.Strings;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.common.Truncator;
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsModel;
import org.elasticsearch.xpack.inference.telemetry.TraceContext;
import org.elasticsearch.xpack.inference.telemetry.TraceContextHandler;

import java.net.URI;
import java.nio.charset.StandardCharsets;
@@ -27,13 +27,10 @@
public class ElasticInferenceServiceSparseEmbeddingsRequest implements ElasticInferenceServiceRequest {

private final URI uri;

private final ElasticInferenceServiceSparseEmbeddingsModel model;

private final Truncator.TruncationResult truncationResult;
private final Truncator truncator;

private final TraceContext traceContext;
private final TraceContextHandler traceContextHandler;

public ElasticInferenceServiceSparseEmbeddingsRequest(
Truncator truncator,
@@ -45,7 +42,7 @@ public ElasticInferenceServiceSparseEmbeddingsRequest(
this.truncationResult = truncationResult;
this.model = Objects.requireNonNull(model);
this.uri = model.uri();
this.traceContext = traceContext;
this.traceContextHandler = new TraceContextHandler(traceContext);
}

@Override
@@ -56,15 +53,16 @@ public HttpRequest createHttpRequest() {
ByteArrayEntity byteEntity = new ByteArrayEntity(requestEntity.getBytes(StandardCharsets.UTF_8));
httpPost.setEntity(byteEntity);

if (traceContext != null) {
propagateTraceContext(httpPost);
}

traceContextHandler.propagateTraceContext(httpPost);
httpPost.setHeader(new BasicHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()));

return new HttpRequest(httpPost, getInferenceEntityId());
}

public TraceContext getTraceContext() {
return traceContextHandler.traceContext();
}

@Override
public String getInferenceEntityId() {
return model.getInferenceEntityId();
@@ -75,32 +73,15 @@ public URI getURI() {
return this.uri;
}

public TraceContext getTraceContext() {
return traceContext;
}

@Override
public Request truncate() {
var truncatedInput = truncator.truncate(truncationResult.input());

return new ElasticInferenceServiceSparseEmbeddingsRequest(truncator, truncatedInput, model, traceContext);
return new ElasticInferenceServiceSparseEmbeddingsRequest(truncator, truncatedInput, model, traceContextHandler.traceContext());
}

@Override
public boolean[] getTruncationInfo() {
return truncationResult.truncated().clone();
}

private void propagateTraceContext(HttpPost httpPost) {
var traceParent = traceContext.traceParent();
var traceState = traceContext.traceState();

if (traceParent != null) {
httpPost.setHeader(Task.TRACE_PARENT_HTTP_HEADER, traceParent);
}

if (traceState != null) {
httpPost.setHeader(Task.TRACE_STATE, traceState);
}
}
}
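The refactor above moves the per-request null checks from the deleted `propagateTraceContext` method into a shared `TraceContextHandler`. A minimal sketch of that centralization, assuming the header names from the removed code (`traceparent`/`tracestate`, per the W3C Trace Context convention) and using a plain `Map` in place of `HttpPost` — `TraceContextHandlerSketch` is an illustrative class, not the one in this PR:

```java
import java.util.HashMap;
import java.util.Map;

public class TraceContextHandlerSketch {
    record TraceContext(String traceParent, String traceState) {}

    // The handler owns all the null handling each request previously duplicated:
    // a missing context or missing field simply means no header is set.
    record Handler(TraceContext traceContext) {
        void propagate(Map<String, String> headers) {
            if (traceContext == null) {
                return;
            }
            if (traceContext.traceParent() != null) {
                headers.put("traceparent", traceContext.traceParent());
            }
            if (traceContext.traceState() != null) {
                headers.put("tracestate", traceContext.traceState());
            }
        }
    }

    public static void main(String[] args) {
        var headers = new HashMap<String, String>();
        new Handler(new TraceContext("00-abc-def-01", null)).propagate(headers);
        System.out.println(headers); // only traceparent is set
    }
}
```

With the checks centralized, each request class can call `propagateTraceContext` unconditionally, which is exactly the simplification visible in `createHttpRequest` above.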
@@ -0,0 +1,85 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.request.elastic;

import org.apache.http.HttpHeaders;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.entity.ByteArrayEntity;
import org.apache.http.message.BasicHeader;
import org.elasticsearch.common.Strings;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel;
import org.elasticsearch.xpack.inference.telemetry.TraceContext;
import org.elasticsearch.xpack.inference.telemetry.TraceContextHandler;

import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.util.Objects;

public class ElasticInferenceServiceUnifiedChatCompletionRequest implements Request {

private final ElasticInferenceServiceCompletionModel model;
private final UnifiedChatInput unifiedChatInput;
private final TraceContextHandler traceContextHandler;

public ElasticInferenceServiceUnifiedChatCompletionRequest(
UnifiedChatInput unifiedChatInput,
ElasticInferenceServiceCompletionModel model,
TraceContext traceContext
) {
this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput);
this.model = Objects.requireNonNull(model);
this.traceContextHandler = new TraceContextHandler(traceContext);
}

@Override
public HttpRequest createHttpRequest() {
var httpPost = new HttpPost(model.uri());
var requestEntity = Strings.toString(
new ElasticInferenceServiceUnifiedChatCompletionRequestEntity(unifiedChatInput, model.getServiceSettings().modelId())
);

ByteArrayEntity byteEntity = new ByteArrayEntity(requestEntity.getBytes(StandardCharsets.UTF_8));
httpPost.setEntity(byteEntity);

traceContextHandler.propagateTraceContext(httpPost);
httpPost.setHeader(new BasicHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()));

return new HttpRequest(httpPost, getInferenceEntityId());
}

@Override
public URI getURI() {
return model.uri();
}

@Override
public Request truncate() {
// No truncation
return this;
}

@Override
public boolean[] getTruncationInfo() {
// No truncation
return null;
}

@Override
public String getInferenceEntityId() {
return model.getInferenceEntityId();
}

@Override
public boolean isStreaming() {
return true;
}
}
@@ -0,0 +1,38 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.request.elastic;

import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.external.unified.UnifiedChatCompletionRequestEntity;

import java.io.IOException;
import java.util.Objects;

public class ElasticInferenceServiceUnifiedChatCompletionRequestEntity implements ToXContentObject {
private static final String MODEL_FIELD = "model";

private final UnifiedChatCompletionRequestEntity unifiedRequestEntity;
private final String modelId;

public ElasticInferenceServiceUnifiedChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, String modelId) {
this.unifiedRequestEntity = new UnifiedChatCompletionRequestEntity(Objects.requireNonNull(unifiedChatInput));
this.modelId = Objects.requireNonNull(modelId);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
unifiedRequestEntity.toXContent(builder, params);
builder.field(MODEL_FIELD, modelId);
Review comment (Contributor): Reminder for myself that we can probably merge this and the openai class back together since they're sending the same stuff.

builder.endObject();

return builder;
}
}
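The entity above shows a wrapper pattern: the shared `UnifiedChatCompletionRequestEntity` serializes the common fields, and the EIS-specific class adds the `model` field around it. A self-contained sketch of that composition, with a `StringBuilder` standing in for `XContentBuilder` and all names other than `"model"` illustrative:

```java
public class RequestEntitySketch {
    interface Fragment { void write(StringBuilder b); }

    // Stand-in for the shared unified entity: writes the common fields.
    record UnifiedFields(String firstMessage) implements Fragment {
        public void write(StringBuilder b) {
            b.append("\"messages\":[\"").append(firstMessage).append("\"]");
        }
    }

    // Stand-in for the EIS wrapper: delegates, then appends its own field,
    // just as toXContent delegates to unifiedRequestEntity before MODEL_FIELD.
    record EisEntity(Fragment unified, String modelId) implements Fragment {
        public void write(StringBuilder b) {
            b.append('{');
            unified.write(b);                                       // shared fields
            b.append(",\"model\":\"").append(modelId).append('"');  // wrapper-specific field
            b.append('}');
        }
    }

    public static void main(String[] args) {
        var b = new StringBuilder();
        new EisEntity(new UnifiedFields("hi"), "model-a").write(b);
        System.out.println(b); // {"messages":["hi"],"model":"model-a"}
    }
}
```

The Contributor's review note above suggests this wrapper and the OpenAI entity may later be merged, since both emit the same shape.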
@@ -27,7 +27,7 @@
import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader;
import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUtils.createOrgHeader;

public class OpenAiEmbeddingsRequest implements OpenAiRequest {
public class OpenAiEmbeddingsRequest implements Request {

private final Truncator truncator;
private final OpenAiAccount account;