Skip to content

Commit

Permalink
Refactoring transport action tests to test unified validation code
Browse files Browse the repository at this point in the history
  • Loading branch information
jonathan-buttner committed Dec 3, 2024
1 parent 99d202f commit 41f9bce
Show file tree
Hide file tree
Showing 6 changed files with 504 additions and 346 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.elasticsearch.xpack.inference.telemetry.InferenceStats;
import org.elasticsearch.xpack.inference.telemetry.InferenceTimer;

import java.util.function.Supplier;
import java.util.stream.Collectors;

import static org.elasticsearch.core.Strings.format;
Expand Down Expand Up @@ -75,27 +76,43 @@ protected void doExecute(Task task, Request request, ActionListener<InferenceAct

var getModelListener = ActionListener.wrap((UnparsedModel unparsedModel) -> {
var service = serviceRegistry.getService(unparsedModel.service());
if (service.isEmpty()) {
var e = unknownServiceException(unparsedModel.service(), request.getInferenceEntityId());
try {
validationHelper(service::isEmpty, () -> unknownServiceException(unparsedModel.service(), request.getInferenceEntityId()));
validationHelper(
() -> request.getTaskType().isAnyOrSame(unparsedModel.taskType()) == false,
() -> requestModelTaskTypeMismatchException(request.getTaskType(), unparsedModel.taskType())
);
validationHelper(
() -> isInvalidTaskTypeForInferenceEndpoint(request, unparsedModel),
() -> createInvalidTaskTypeException(request, unparsedModel)
);
} catch (Exception e) {
recordMetrics(unparsedModel, timer, e);
listener.onFailure(e);
return;
}

if (request.getTaskType().isAnyOrSame(unparsedModel.taskType()) == false) {
// not the wildcard task type and not the model task type
var e = incompatibleTaskTypeException(request.getTaskType(), unparsedModel.taskType());
recordMetrics(unparsedModel, timer, e);
listener.onFailure(e);
return;
}

if (isInvalidTaskTypeForInferenceEndpoint(request, unparsedModel)) {
var e = createIncompatibleTaskTypeException(request, unparsedModel);
recordMetrics(unparsedModel, timer, e);
listener.onFailure(e);
return;
}
// if (service.isEmpty()) {
// var e = unknownServiceException(unparsedModel.service(), request.getInferenceEntityId());
// recordMetrics(unparsedModel, timer, e);
// listener.onFailure(e);
// return;
// }

// if (request.getTaskType().isAnyOrSame(unparsedModel.taskType()) == false) {
// // not the wildcard task type and not the model task type
// var e = incompatibleTaskTypeException(request.getTaskType(), unparsedModel.taskType());
// recordMetrics(unparsedModel, timer, e);
// listener.onFailure(e);
// return;
// }

// if (isInvalidTaskTypeForInferenceEndpoint(request, unparsedModel)) {
// var e = createInvalidTaskTypeException(request, unparsedModel);
// recordMetrics(unparsedModel, timer, e);
// listener.onFailure(e);
// return;
// }

var model = service.get()
.parsePersistedConfigWithSecrets(
Expand All @@ -117,9 +134,15 @@ protected void doExecute(Task task, Request request, ActionListener<InferenceAct
modelRegistry.getModelWithSecrets(request.getInferenceEntityId(), getModelListener);
}

private static void validationHelper(Supplier<Boolean> validationFailure, Supplier<ElasticsearchStatusException> exceptionCreator) {
if (validationFailure.get()) {
throw exceptionCreator.get();
}
}

protected abstract boolean isInvalidTaskTypeForInferenceEndpoint(Request request, UnparsedModel unparsedModel);

protected abstract ElasticsearchStatusException createIncompatibleTaskTypeException(Request request, UnparsedModel unparsedModel);
protected abstract ElasticsearchStatusException createInvalidTaskTypeException(Request request, UnparsedModel unparsedModel);

private void recordMetrics(UnparsedModel model, InferenceTimer timer, @Nullable Throwable t) {
try {
Expand Down Expand Up @@ -225,7 +248,7 @@ private static ElasticsearchStatusException unknownServiceException(String servi
return new ElasticsearchStatusException("Unknown service [{}] for model [{}]. ", RestStatus.BAD_REQUEST, service, inferenceId);
}

private static ElasticsearchStatusException incompatibleTaskTypeException(TaskType requested, TaskType expected) {
private static ElasticsearchStatusException requestModelTaskTypeMismatchException(TaskType requested, TaskType expected) {
return new ElasticsearchStatusException(
"Incompatible task_type, the requested type [{}] does not match the model type [{}]",
RestStatus.BAD_REQUEST,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,7 @@ protected boolean isInvalidTaskTypeForInferenceEndpoint(InferenceAction.Request
}

@Override
protected ElasticsearchStatusException createIncompatibleTaskTypeException(
InferenceAction.Request request,
UnparsedModel unparsedModel
) {
protected ElasticsearchStatusException createInvalidTaskTypeException(InferenceAction.Request request, UnparsedModel unparsedModel) {
return null;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ protected boolean isInvalidTaskTypeForInferenceEndpoint(UnifiedCompletionAction.
}

@Override
protected ElasticsearchStatusException createIncompatibleTaskTypeException(
protected ElasticsearchStatusException createInvalidTaskTypeException(
UnifiedCompletionAction.Request request,
UnparsedModel unparsedModel
) {
Expand Down
Loading

0 comments on commit 41f9bce

Please sign in to comment.