diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java index 50f83dc7b02eb..52ebaaaf8826c 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java @@ -35,7 +35,6 @@ import java.util.stream.Stream; import static org.hamcrest.Matchers.containsString; -import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalToIgnoringCase; import static org.hamcrest.Matchers.hasSize; @@ -58,7 +57,7 @@ public void testCRUD() throws IOException { } var getAllModels = getAllModels(); - int numModels = 12; + int numModels = 13; assertThat(getAllModels, hasSize(numModels)); var getSparseModels = getModels("_all", TaskType.SPARSE_EMBEDDING); @@ -553,8 +552,8 @@ private static String expectedResult(String input) { } } - public void testGetZeroModels() throws IOException { + public void testGetCompletionModels() throws IOException { var models = getModels("_all", TaskType.COMPLETION); - assertThat(models, empty()); + assertEquals(models.size(), 1); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index 48416faac6a06..e19034644862a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -16,6 +16,8 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInference; +import org.elasticsearch.inference.EmptySecretSettings; +import org.elasticsearch.inference.EmptyTaskSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; @@ -42,6 +44,7 @@ import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel; +import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; import org.elasticsearch.xpack.inference.telemetry.TraceContext; @@ -67,10 +70,14 @@ public class ElasticInferenceService extends SenderService { public static final String NAME = "elastic"; public static final String ELASTIC_INFERENCE_SERVICE_IDENTIFIER = "Elastic Inference Service"; - private final ElasticInferenceServiceComponents elasticInferenceServiceComponents; - private static final EnumSet supportedTaskTypes = EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.COMPLETION); private static final String SERVICE_NAME = "Elastic"; + private static final String DEFAULT_EIS_CHAT_COMPLETION_MODEL_ID_V1 = "rainbow-sprinkles"; + private static final String DEFAULT_EIS_CHAT_COMPLETION_ENDPOINT_ID_V1 = ".eis-alpha-1"; + private static final Set DEFAULT_EIS_ENDPOINT_IDS = Set.of(DEFAULT_EIS_CHAT_COMPLETION_ENDPOINT_ID_V1); + + private final ElasticInferenceServiceComponents elasticInferenceServiceComponents; + private final List defaultEndpoints; public ElasticInferenceService( HttpRequestSender.Factory factory, @@ -79,6 +86,23 @@ public ElasticInferenceService( ) { super(factory, serviceComponents); this.elasticInferenceServiceComponents = elasticInferenceServiceComponents; + this.defaultEndpoints = initDefaultEndpoints(); + } + + private List initDefaultEndpoints() { + return List.of(v1DefaultCompletionModel()); + } + + private ElasticInferenceServiceCompletionModel v1DefaultCompletionModel() { + return new ElasticInferenceServiceCompletionModel( + DEFAULT_EIS_CHAT_COMPLETION_ENDPOINT_ID_V1, + TaskType.COMPLETION, + NAME, + new ElasticInferenceServiceCompletionServiceSettings(DEFAULT_EIS_CHAT_COMPLETION_MODEL_ID_V1, null), + EmptyTaskSettings.INSTANCE, + EmptySecretSettings.INSTANCE, + elasticInferenceServiceComponents + ); } @Override @@ -175,6 +199,17 @@ public void parseRequestConfig( Map config, ActionListener parsedModelListener ) { + if (DEFAULT_EIS_ENDPOINT_IDS.contains(inferenceEntityId)) { + parsedModelListener.onFailure( + new ElasticsearchStatusException( + "[{}] is a reserved inference Id. Cannot create a new inference endpoint with a reserved Id", + RestStatus.BAD_REQUEST, + inferenceEntityId + ) + ); + return; + } + try { Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); @@ -210,6 +245,16 @@ public EnumSet supportedTaskTypes() { return supportedTaskTypes; } + @Override + public List defaultConfigIds() { + return List.of(new DefaultConfigId(DEFAULT_EIS_CHAT_COMPLETION_ENDPOINT_ID_V1, TaskType.COMPLETION, this)); + } + + @Override + public void defaultConfigs(ActionListener> defaultsListener) { + defaultsListener.onResponse(defaultEndpoints); + } + private static ElasticInferenceServiceModel createModel( String inferenceEntityId, TaskType taskType, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSettings.java index 5146cec1552af..3801e0acc8727 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSettings.java @@ -35,7 +35,6 @@ public class ElasticInferenceServiceSettings { public ElasticInferenceServiceSettings(Settings settings) { eisGatewayUrl = EIS_GATEWAY_URL.get(settings); elasticInferenceServiceUrl = ELASTIC_INFERENCE_SERVICE_URL.get(settings); - } public static List> getSettingsDefinitions() { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModel.java index 84039cd7cc33c..d129dcf9fea6b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModel.java @@ -74,7 +74,7 @@ public ElasticInferenceServiceCompletionModel( } - ElasticInferenceServiceCompletionModel( + public ElasticInferenceServiceCompletionModel( String inferenceEntityId, TaskType taskType, String service, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionServiceSettings.java index 3c8182a7d41a4..931ce8109462e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionServiceSettings.java @@ -12,6 +12,7 @@ import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.xcontent.XContentBuilder; @@ -60,7 +61,7 @@ public static ElasticInferenceServiceCompletionServiceSettings fromMap(Map