Skip to content

Commit

Permalink
Fixing various tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jonathan-buttner committed Dec 3, 2024
1 parent fa415d8 commit f382246
Show file tree
Hide file tree
Showing 24 changed files with 145 additions and 106 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,6 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
new NamedWriteableRegistry.Entry(Content.class, ContentString.NAME, ContentString::new),
new NamedWriteableRegistry.Entry(ToolChoice.class, ToolChoiceObject.NAME, ToolChoiceObject::new),
new NamedWriteableRegistry.Entry(ToolChoice.class, ToolChoiceString.NAME, ToolChoiceString::new)
// new NamedWriteableRegistry.Entry(Stop.class, StopValues.NAME, StopValues::new),
// new NamedWriteableRegistry.Entry(Stop.class, StopString.NAME, StopString::new)
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,26 +12,20 @@

public class ChatCompletionInput extends InferenceInputs {
private final List<String> input;
private final boolean stream;

public ChatCompletionInput(List<String> input) {
this(input, false);
}

public ChatCompletionInput(List<String> input, boolean stream) {
super();
super(stream);
this.input = Objects.requireNonNull(input);
this.stream = stream;
}

public List<String> getInputs() {
return this.input;
}

public boolean stream() {
return stream;
}

public int inputSize() {
return input.size();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,26 +21,20 @@ public static DocumentsOnlyInput of(InferenceInputs inferenceInputs) {
}

private final List<String> input;
private final boolean stream;

public DocumentsOnlyInput(List<String> input) {
this(input, false);
}

public DocumentsOnlyInput(List<String> input, boolean stream) {
super();
super(stream);
this.input = Objects.requireNonNull(input);
this.stream = stream;
}

public List<String> getInputs() {
return this.input;
}

public boolean stream() {
return stream;
}

public int inputSize() {
return input.size();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@
import org.elasticsearch.common.Strings;

public abstract class InferenceInputs {
private final boolean stream;

public InferenceInputs(boolean stream) {
this.stream = stream;
}

public static IllegalArgumentException createUnsupportedTypeException(InferenceInputs inferenceInputs, Class<?> clazz) {
return new IllegalArgumentException(
Strings.format("Unable to convert inference inputs type: [%s] to [%s]", inferenceInputs.getClass(), clazz)
Expand All @@ -24,5 +30,9 @@ public <T> T castTo(Class<T> clazz) {
return clazz.cast(this);
}

public boolean stream() {
return stream;
}

public abstract int inputSize();
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,15 @@ public static QueryAndDocsInputs of(InferenceInputs inferenceInputs) {

private final String query;
private final List<String> chunks;
private final boolean stream;

public QueryAndDocsInputs(String query, List<String> chunks) {
this(query, chunks, false);
}

public QueryAndDocsInputs(String query, List<String> chunks, boolean stream) {
super();
super(stream);
this.query = Objects.requireNonNull(query);
this.chunks = Objects.requireNonNull(chunks);
this.stream = stream;
}

public String getQuery() {
Expand All @@ -43,10 +41,6 @@ public List<String> getChunks() {
return chunks;
}

public boolean stream() {
return stream;
}

public int inputSize() {
return chunks.size();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,10 @@

public class UnifiedChatInput extends InferenceInputs {
private final UnifiedCompletionRequest request;
private final boolean stream;

public UnifiedChatInput(UnifiedCompletionRequest request, boolean stream) {
super(stream);
this.request = Objects.requireNonNull(request);
this.stream = stream;
}

public UnifiedChatInput(ChatCompletionInput completionInput, String roleValue) {
Expand Down Expand Up @@ -47,10 +46,6 @@ public UnifiedCompletionRequest getRequest() {
return request;
}

public boolean stream() {
return stream;
}

public int inputSize() {
return request.messages().size();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ private static InferenceInputs createInput(Model model, List<String> input, @Nul
return switch (model.getTaskType()) {
case COMPLETION -> new ChatCompletionInput(input, stream);
case RERANK -> new QueryAndDocsInputs(query, input, stream);
case TEXT_EMBEDDING -> new DocumentsOnlyInput(input, stream);
case TEXT_EMBEDDING, SPARSE_EMBEDDING -> new DocumentsOnlyInput(input, stream);
default -> throw new ElasticsearchStatusException(
Strings.format("Invalid task type received when determining input type: [%s]", model.getTaskType().toString()),
RestStatus.BAD_REQUEST
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -284,9 +284,8 @@ protected void doInfer(
) {
if (model instanceof GoogleAiStudioCompletionModel completionModel) {
var requestManager = new GoogleAiStudioCompletionRequestManager(completionModel, getServiceComponents().threadPool());
var docsOnly = DocumentsOnlyInput.of(inputs);
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(
completionModel.uri(docsOnly.stream()),
completionModel.uri(inputs.stream()),
"Google AI Studio completion"
);
var action = new SingleInputSenderExecutableAction(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.threadpool.ScalingExecutorBuilder;
import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
import org.elasticsearch.xpack.inference.common.Truncator;
Expand Down Expand Up @@ -160,9 +161,11 @@ public static Model getInvalidModel(String inferenceEntityId, String serviceName
var mockConfigs = mock(ModelConfigurations.class);
when(mockConfigs.getInferenceEntityId()).thenReturn(inferenceEntityId);
when(mockConfigs.getService()).thenReturn(serviceName);
when(mockConfigs.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING);

var mockModel = mock(Model.class);
when(mockModel.getConfigurations()).thenReturn(mockConfigs);
when(mockModel.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING);

return mockModel;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,25 +61,11 @@ public void testOneInputIsValid() {
assertTrue("Test failed to call listener.", testRan.get());
}

public void testInvalidInputType() {
var badInput = mock(InferenceInputs.class);
var actualException = new AtomicReference<Exception>();

executableAction.execute(
badInput,
mock(TimeValue.class),
ActionListener.wrap(shouldNotSucceed -> fail("Test failed."), actualException::set)
);

assertThat(actualException.get(), notNullValue());
assertThat(actualException.get().getMessage(), is("Invalid inference input type"));
assertThat(actualException.get(), instanceOf(ElasticsearchStatusException.class));
assertThat(((ElasticsearchStatusException) actualException.get()).status(), is(RestStatus.INTERNAL_SERVER_ERROR));
}

public void testMoreThanOneInput() {
var badInput = mock(DocumentsOnlyInput.class);
when(badInput.getInputs()).thenReturn(List.of("one", "two"));
var input = List.of("one", "two");
when(badInput.getInputs()).thenReturn(input);
when(badInput.inputSize()).thenReturn(input.size());
var actualException = new AtomicReference<Exception>();

executableAction.execute(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.external.amazonbedrock.AmazonBedrockMockRequestSender;
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
import org.elasticsearch.xpack.inference.services.ServiceComponentsTests;
import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProvider;
Expand Down Expand Up @@ -130,7 +131,7 @@ public void testCompletionRequestAction() throws IOException {
);
var action = creator.create(model, Map.of());
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
var result = listener.actionGet(TIMEOUT);

assertThat(result.asMap(), is(buildExpectationCompletion(List.of("test input string"))));
Expand Down Expand Up @@ -163,7 +164,7 @@ public void testChatCompletionRequestAction_HandlesException() throws IOExceptio
);
var action = creator.create(model, Map.of());
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));

assertThat(sender.sendCount(), is(1));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
import org.elasticsearch.xpack.inference.external.request.anthropic.AnthropicRequestUtils;
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
Expand Down Expand Up @@ -49,6 +49,7 @@
import static org.mockito.Mockito.mock;

public class AnthropicActionCreatorTests extends ESTestCase {

private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
private final MockWebServer webServer = new MockWebServer();
private ThreadPool threadPool;
Expand Down Expand Up @@ -103,7 +104,7 @@ public void testCreate_ChatCompletionModel() throws IOException {
var action = actionCreator.create(model, overriddenTaskSettings);

PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);

var result = listener.actionGet(TIMEOUT);

Expand Down Expand Up @@ -168,7 +169,7 @@ public void testCreate_ChatCompletionModel_FailsFromInvalidResponseFormat() thro
var action = actionCreator.create(model, overriddenTaskSettings);

PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);

var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
assertThat(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
import org.elasticsearch.xpack.inference.external.http.sender.AnthropicCompletionRequestManager;
import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
Expand Down Expand Up @@ -113,7 +113,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException {
var action = createAction(getUrl(webServer), "secret", "model", 1, sender);

PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);

var result = listener.actionGet(TIMEOUT);

Expand Down Expand Up @@ -149,7 +149,7 @@ public void testExecute_ThrowsElasticsearchException() {
var action = createAction(getUrl(webServer), "secret", "model", 1, sender);

PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);

var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));

Expand All @@ -170,7 +170,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled
var action = createAction(getUrl(webServer), "secret", "model", 1, sender);

PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);

var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));

Expand All @@ -187,7 +187,7 @@ public void testExecute_ThrowsException() {
var action = createAction(getUrl(webServer), "secret", "model", 1, sender);

PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);

var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));

Expand Down Expand Up @@ -229,7 +229,7 @@ public void testExecute_ThrowsException_WhenInputIsGreaterThanOne() throws IOExc
var action = createAction(getUrl(webServer), "secret", "model", 1, sender);

PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
action.execute(new DocumentsOnlyInput(List.of("abc", "def")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
action.execute(new ChatCompletionInput(List.of("abc", "def")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);

var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.inference.common.TruncatorTests;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
Expand Down Expand Up @@ -160,7 +161,7 @@ public void testChatCompletionRequestAction() throws IOException {
var action = creator.create(model, Map.of());

PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);

var result = listener.actionGet(TIMEOUT);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
import org.elasticsearch.xpack.inference.external.request.azureopenai.AzureOpenAiUtils;
Expand Down Expand Up @@ -475,7 +476,7 @@ public void testInfer_AzureOpenAiCompletion_WithOverriddenUser() throws IOExcept
var action = actionCreator.create(model, taskSettingsWithUserOverride);

PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
action.execute(new DocumentsOnlyInput(List.of(completionInput)), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
action.execute(new ChatCompletionInput(List.of(completionInput)), InferenceAction.Request.DEFAULT_TIMEOUT, listener);

var result = listener.actionGet(TIMEOUT);

Expand Down Expand Up @@ -531,7 +532,7 @@ public void testInfer_AzureOpenAiCompletionModel_WithoutUser() throws IOExceptio
var action = actionCreator.create(model, requestTaskSettingsWithoutUser);

PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
action.execute(new DocumentsOnlyInput(List.of(completionInput)), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
action.execute(new ChatCompletionInput(List.of(completionInput)), InferenceAction.Request.DEFAULT_TIMEOUT, listener);

var result = listener.actionGet(TIMEOUT);

Expand Down Expand Up @@ -589,7 +590,7 @@ public void testInfer_AzureOpenAiCompletionModel_FailsFromInvalidResponseFormat(
var action = actionCreator.create(model, requestTaskSettingsWithoutUser);

PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
action.execute(new DocumentsOnlyInput(List.of(completionInput)), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
action.execute(new ChatCompletionInput(List.of(completionInput)), InferenceAction.Request.DEFAULT_TIMEOUT, listener);

var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
assertThat(
Expand Down
Loading

0 comments on commit f382246

Please sign in to comment.