-
Notifications
You must be signed in to change notification settings - Fork 25k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Elastic Inference Service] Add ElasticInferenceService Unified ChatC…
…ompletions Integration (#118871)
- Loading branch information
Showing
29 changed files
with
1,814 additions
and
960 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
pr: 118871 | ||
summary: "[Elastic Inference Service] Add ElasticInferenceService Unified ChatCompletions Integration" | ||
area: Inference | ||
type: enhancement | ||
issues: [] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
40 changes: 40 additions & 0 deletions
40
...ference/external/elastic/ElasticInferenceServiceUnifiedChatCompletionResponseHandler.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
/* | ||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
* or more contributor license agreements. Licensed under the Elastic License | ||
* 2.0; you may not use this file except in compliance with the Elastic License | ||
* 2.0. | ||
*/ | ||
|
||
package org.elasticsearch.xpack.inference.external.elastic; | ||
|
||
import org.elasticsearch.inference.InferenceServiceResults; | ||
import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults; | ||
import org.elasticsearch.xpack.inference.external.http.HttpResult; | ||
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; | ||
import org.elasticsearch.xpack.inference.external.openai.OpenAiUnifiedStreamingProcessor; | ||
import org.elasticsearch.xpack.inference.external.request.Request; | ||
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventParser; | ||
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventProcessor; | ||
|
||
import java.util.concurrent.Flow; | ||
|
||
public class ElasticInferenceServiceUnifiedChatCompletionResponseHandler extends ElasticInferenceServiceResponseHandler { | ||
public ElasticInferenceServiceUnifiedChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) { | ||
super(requestType, parseFunction); | ||
} | ||
|
||
@Override | ||
public boolean canHandleStreamingResponses() { | ||
return true; | ||
} | ||
|
||
@Override | ||
public InferenceServiceResults parseResult(Request request, Flow.Publisher<HttpResult> flow) { | ||
var serverSentEventProcessor = new ServerSentEventProcessor(new ServerSentEventParser()); | ||
var openAiProcessor = new OpenAiUnifiedStreamingProcessor(); // EIS uses the unified API spec | ||
|
||
flow.subscribe(serverSentEventProcessor); | ||
serverSentEventProcessor.subscribe(openAiProcessor); | ||
return new StreamingUnifiedChatCompletionResults(openAiProcessor); | ||
} | ||
} |
82 changes: 82 additions & 0 deletions
82
...nference/external/http/sender/ElasticInferenceServiceUnifiedCompletionRequestManager.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
/* | ||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
* or more contributor license agreements. Licensed under the Elastic License | ||
* 2.0; you may not use this file except in compliance with the Elastic License | ||
* 2.0. | ||
*/ | ||
|
||
package org.elasticsearch.xpack.inference.external.http.sender; | ||
|
||
import org.apache.logging.log4j.LogManager; | ||
import org.apache.logging.log4j.Logger; | ||
import org.elasticsearch.action.ActionListener; | ||
import org.elasticsearch.inference.InferenceServiceResults; | ||
import org.elasticsearch.threadpool.ThreadPool; | ||
import org.elasticsearch.xpack.inference.external.elastic.ElasticInferenceServiceUnifiedChatCompletionResponseHandler; | ||
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; | ||
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; | ||
import org.elasticsearch.xpack.inference.external.request.elastic.ElasticInferenceServiceUnifiedChatCompletionRequest; | ||
import org.elasticsearch.xpack.inference.external.response.openai.OpenAiChatCompletionResponseEntity; | ||
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel; | ||
import org.elasticsearch.xpack.inference.telemetry.TraceContext; | ||
|
||
import java.util.Objects; | ||
import java.util.function.Supplier; | ||
|
||
public class ElasticInferenceServiceUnifiedCompletionRequestManager extends ElasticInferenceServiceRequestManager { | ||
|
||
private static final Logger logger = LogManager.getLogger(ElasticInferenceServiceUnifiedCompletionRequestManager.class); | ||
|
||
private static final ResponseHandler HANDLER = createCompletionHandler(); | ||
|
||
public static ElasticInferenceServiceUnifiedCompletionRequestManager of( | ||
ElasticInferenceServiceCompletionModel model, | ||
ThreadPool threadPool, | ||
TraceContext traceContext | ||
) { | ||
return new ElasticInferenceServiceUnifiedCompletionRequestManager( | ||
Objects.requireNonNull(model), | ||
Objects.requireNonNull(threadPool), | ||
Objects.requireNonNull(traceContext) | ||
); | ||
} | ||
|
||
private final ElasticInferenceServiceCompletionModel model; | ||
private final TraceContext traceContext; | ||
|
||
private ElasticInferenceServiceUnifiedCompletionRequestManager( | ||
ElasticInferenceServiceCompletionModel model, | ||
ThreadPool threadPool, | ||
TraceContext traceContext | ||
) { | ||
super(threadPool, model); | ||
this.model = model; | ||
this.traceContext = traceContext; | ||
} | ||
|
||
@Override | ||
public void execute( | ||
InferenceInputs inferenceInputs, | ||
RequestSender requestSender, | ||
Supplier<Boolean> hasRequestCompletedFunction, | ||
ActionListener<InferenceServiceResults> listener | ||
) { | ||
|
||
ElasticInferenceServiceUnifiedChatCompletionRequest request = new ElasticInferenceServiceUnifiedChatCompletionRequest( | ||
inferenceInputs.castTo(UnifiedChatInput.class), | ||
model, | ||
traceContext | ||
); | ||
|
||
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); | ||
} | ||
|
||
private static ResponseHandler createCompletionHandler() { | ||
return new ElasticInferenceServiceUnifiedChatCompletionResponseHandler( | ||
"elastic inference service completion", | ||
// We use OpenAiChatCompletionResponseEntity here as the ElasticInferenceServiceResponseEntity fields are a subset of the OpenAI | ||
// one. | ||
OpenAiChatCompletionResponseEntity::fromResponse | ||
); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
85 changes: 85 additions & 0 deletions
85
...ference/external/request/elastic/ElasticInferenceServiceUnifiedChatCompletionRequest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
/* | ||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
* or more contributor license agreements. Licensed under the Elastic License | ||
* 2.0; you may not use this file except in compliance with the Elastic License | ||
* 2.0. | ||
*/ | ||
|
||
package org.elasticsearch.xpack.inference.external.request.elastic; | ||
|
||
import org.apache.http.HttpHeaders; | ||
import org.apache.http.client.methods.HttpPost; | ||
import org.apache.http.entity.ByteArrayEntity; | ||
import org.apache.http.message.BasicHeader; | ||
import org.elasticsearch.common.Strings; | ||
import org.elasticsearch.xcontent.XContentType; | ||
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; | ||
import org.elasticsearch.xpack.inference.external.request.HttpRequest; | ||
import org.elasticsearch.xpack.inference.external.request.Request; | ||
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel; | ||
import org.elasticsearch.xpack.inference.telemetry.TraceContext; | ||
import org.elasticsearch.xpack.inference.telemetry.TraceContextHandler; | ||
|
||
import java.net.URI; | ||
import java.nio.charset.StandardCharsets; | ||
import java.util.Objects; | ||
|
||
public class ElasticInferenceServiceUnifiedChatCompletionRequest implements Request { | ||
|
||
private final ElasticInferenceServiceCompletionModel model; | ||
private final UnifiedChatInput unifiedChatInput; | ||
private final TraceContextHandler traceContextHandler; | ||
|
||
public ElasticInferenceServiceUnifiedChatCompletionRequest( | ||
UnifiedChatInput unifiedChatInput, | ||
ElasticInferenceServiceCompletionModel model, | ||
TraceContext traceContext | ||
) { | ||
this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput); | ||
this.model = Objects.requireNonNull(model); | ||
this.traceContextHandler = new TraceContextHandler(traceContext); | ||
} | ||
|
||
@Override | ||
public HttpRequest createHttpRequest() { | ||
var httpPost = new HttpPost(model.uri()); | ||
var requestEntity = Strings.toString( | ||
new ElasticInferenceServiceUnifiedChatCompletionRequestEntity(unifiedChatInput, model.getServiceSettings().modelId()) | ||
); | ||
|
||
ByteArrayEntity byteEntity = new ByteArrayEntity(requestEntity.getBytes(StandardCharsets.UTF_8)); | ||
httpPost.setEntity(byteEntity); | ||
|
||
traceContextHandler.propagateTraceContext(httpPost); | ||
httpPost.setHeader(new BasicHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType())); | ||
|
||
return new HttpRequest(httpPost, getInferenceEntityId()); | ||
} | ||
|
||
@Override | ||
public URI getURI() { | ||
return model.uri(); | ||
} | ||
|
||
@Override | ||
public Request truncate() { | ||
// No truncation | ||
return this; | ||
} | ||
|
||
@Override | ||
public boolean[] getTruncationInfo() { | ||
// No truncation | ||
return null; | ||
} | ||
|
||
@Override | ||
public String getInferenceEntityId() { | ||
return model.getInferenceEntityId(); | ||
} | ||
|
||
@Override | ||
public boolean isStreaming() { | ||
return true; | ||
} | ||
} |
38 changes: 38 additions & 0 deletions
38
...e/external/request/elastic/ElasticInferenceServiceUnifiedChatCompletionRequestEntity.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
/* | ||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
* or more contributor license agreements. Licensed under the Elastic License | ||
* 2.0; you may not use this file except in compliance with the Elastic License | ||
* 2.0. | ||
*/ | ||
|
||
package org.elasticsearch.xpack.inference.external.request.elastic; | ||
|
||
import org.elasticsearch.xcontent.ToXContentObject; | ||
import org.elasticsearch.xcontent.XContentBuilder; | ||
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; | ||
import org.elasticsearch.xpack.inference.external.unified.UnifiedChatCompletionRequestEntity; | ||
|
||
import java.io.IOException; | ||
import java.util.Objects; | ||
|
||
public class ElasticInferenceServiceUnifiedChatCompletionRequestEntity implements ToXContentObject { | ||
private static final String MODEL_FIELD = "model"; | ||
|
||
private final UnifiedChatCompletionRequestEntity unifiedRequestEntity; | ||
private final String modelId; | ||
|
||
public ElasticInferenceServiceUnifiedChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, String modelId) { | ||
this.unifiedRequestEntity = new UnifiedChatCompletionRequestEntity(Objects.requireNonNull(unifiedChatInput)); | ||
this.modelId = Objects.requireNonNull(modelId); | ||
} | ||
|
||
@Override | ||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { | ||
builder.startObject(); | ||
unifiedRequestEntity.toXContent(builder, params); | ||
builder.field(MODEL_FIELD, modelId); | ||
builder.endObject(); | ||
|
||
return builder; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.