Skip to content

Commit

Permalink
changing MLSearchActionRequest to an instance subclass of SearchActio…
Browse files Browse the repository at this point in the history
…nRequest

Signed-off-by: Dhrubo Saha <[email protected]>
  • Loading branch information
dhrubo-os committed Jan 30, 2025
1 parent cc47b67 commit 38f7b5b
Show file tree
Hide file tree
Showing 17 changed files with 106 additions and 176 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
*/
@Getter
public class MLSearchActionRequest extends SearchRequest {
SearchRequest searchRequest;
String tenantId;

/**
Expand All @@ -36,7 +35,7 @@ public class MLSearchActionRequest extends SearchRequest {
*/
@Builder
public MLSearchActionRequest(SearchRequest searchRequest, String tenantId) {
this.searchRequest = searchRequest;
super(searchRequest);
this.tenantId = tenantId;
}

Expand All @@ -50,6 +49,7 @@ public MLSearchActionRequest(StreamInput input) throws IOException {
super(input);
Version streamInputVersion = input.getVersion();
this.tenantId = streamInputVersion.onOrAfter(VERSION_2_19_0) ? input.readOptionalString() : null;

}

/**
Expand Down Expand Up @@ -80,6 +80,14 @@ public static MLSearchActionRequest fromActionRequest(ActionRequest actionReques
return (MLSearchActionRequest) actionRequest;
}

if (actionRequest instanceof SearchRequest) {
return MLSearchActionRequest
.builder()
.searchRequest((SearchRequest) actionRequest)
.tenantId(null) // No tenant ID in the original request
.build();
}

try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
actionRequest.writeTo(osso);
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,22 +1,18 @@
package org.opensearch.ml.common.transport.search;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotSame;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame;

import java.io.IOException;
import java.io.UncheckedIOException;

import org.junit.Before;
import org.junit.Test;
import org.opensearch.Version;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;

public class MLSearchActionRequestTest {

Expand All @@ -28,120 +24,66 @@ public void setUp() {
}

@Test
public void testConstructorAndGetters() {
MLSearchActionRequest request = MLSearchActionRequest.builder().searchRequest(searchRequest).tenantId("test-tenant").build();
assertEquals("test-index", request.getSearchRequest().indices()[0]);
assertEquals("test-tenant", request.getTenantId());
}
public void testSerializationDeserialization_Version_2_19_0() throws IOException {
// Set up a valid SearchRequest
SearchRequest searchRequest = new SearchRequest("test-index");

@Test
public void testStreamConstructorAndWriteTo() throws IOException {
MLSearchActionRequest request = MLSearchActionRequest.builder().searchRequest(searchRequest).tenantId("test-tenant").build();
BytesStreamOutput out = new BytesStreamOutput();
request.writeTo(out);

MLSearchActionRequest deserializedRequest = new MLSearchActionRequest(out.bytes().streamInput());
assertEquals("test-index", deserializedRequest.getSearchRequest().indices()[0]);
assertEquals("test-tenant", deserializedRequest.getTenantId());
}
// Create the MLSearchActionRequest
MLSearchActionRequest originalRequest = MLSearchActionRequest
.builder()
.searchRequest(searchRequest)
.tenantId("test-tenant")
.build();

@Test
public void testWriteToWithNullSearchRequest() throws IOException {
MLSearchActionRequest request = MLSearchActionRequest.builder().tenantId("test-tenant").build();
BytesStreamOutput out = new BytesStreamOutput();
request.writeTo(out);
out.setVersion(Version.V_2_19_0);
originalRequest.writeTo(out);

MLSearchActionRequest deserializedRequest = new MLSearchActionRequest(out.bytes().streamInput());
assertNull(deserializedRequest.getSearchRequest());
assertEquals("test-tenant", deserializedRequest.getTenantId());
}
StreamInput in = out.bytes().streamInput();
in.setVersion(Version.V_2_19_0);
MLSearchActionRequest deserializedRequest = new MLSearchActionRequest(in);

@Test
public void testFromActionRequestWithMLSearchActionRequest() {
MLSearchActionRequest request = MLSearchActionRequest.builder().searchRequest(searchRequest).tenantId("test-tenant").build();
MLSearchActionRequest result = MLSearchActionRequest.fromActionRequest(request);
assertSame(result, request);
assertEquals("test-tenant", deserializedRequest.getTenantId());
}

@Test
public void testFromActionRequestWithNonMLSearchActionRequest() throws IOException {
MLSearchActionRequest request = MLSearchActionRequest.builder().searchRequest(searchRequest).tenantId("test-tenant").build();
ActionRequest actionRequest = new ActionRequest() {
@Override
public ActionRequestValidationException validate() {
return null;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
request.writeTo(out);
}
};

MLSearchActionRequest result = MLSearchActionRequest.fromActionRequest(actionRequest);
assertNotSame(result, request);
assertEquals(request.getSearchRequest().indices()[0], result.getSearchRequest().indices()[0]);
assertEquals(request.getTenantId(), result.getTenantId());
}
public void testSerializationDeserialization_Version_2_18_0() throws IOException {

@Test(expected = UncheckedIOException.class)
public void testFromActionRequestIOException() {
ActionRequest actionRequest = new ActionRequest() {
@Override
public ActionRequestValidationException validate() {
return null;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
throw new IOException("test");
}
};
MLSearchActionRequest.fromActionRequest(actionRequest);
}

@Test
public void testBackwardCompatibility() throws IOException {
MLSearchActionRequest request = MLSearchActionRequest.builder().searchRequest(searchRequest).tenantId("test-tenant").build();
// Create the MLSearchActionRequest
MLSearchActionRequest originalRequest = MLSearchActionRequest
.builder()
.searchRequest(searchRequest)
.tenantId("test-tenant")
.build();

BytesStreamOutput out = new BytesStreamOutput();
out.setVersion(Version.V_2_18_0); // Older version
request.writeTo(out);
out.setVersion(Version.V_2_18_0);
originalRequest.writeTo(out);

StreamInput in = out.bytes().streamInput();
in.setVersion(Version.V_2_18_0);

MLSearchActionRequest deserializedRequest = new MLSearchActionRequest(in);
assertNull(deserializedRequest.getTenantId()); // Ensure tenantId is ignored

assertNull(deserializedRequest.getTenantId());
}

@Test
public void testFromActionRequestWithValidRequest() {
public void testFromActionRequest_WithMLSearchActionRequest() {
MLSearchActionRequest request = MLSearchActionRequest.builder().searchRequest(searchRequest).tenantId("test-tenant").build();

MLSearchActionRequest result = MLSearchActionRequest.fromActionRequest(request);

assertSame(request, result);
}

@Test
public void testMixedVersionCompatibility() throws IOException {
MLSearchActionRequest originalRequest = MLSearchActionRequest
.builder()
.searchRequest(searchRequest)
.tenantId("test-tenant")
.build();

// Serialize with a newer version
BytesStreamOutput out = new BytesStreamOutput();
out.setVersion(Version.V_2_19_0);
originalRequest.writeTo(out);
public void testFromActionRequest_WithSearchRequest() throws IOException {
SearchRequest simpleRequest = new SearchRequest("test-index");

// Deserialize with an older version
StreamInput in = out.bytes().streamInput();
in.setVersion(Version.V_2_18_0);
MLSearchActionRequest result = MLSearchActionRequest.fromActionRequest(simpleRequest);

MLSearchActionRequest deserializedRequest = new MLSearchActionRequest(in);
assertNull(deserializedRequest.getTenantId()); // tenantId should not exist in older versions
assertNotNull(result);
assertNull(result.getTenantId()); // Since tenantId wasn't in original request
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_DISABLED_MESSAGE;

import org.opensearch.OpenSearchException;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
Expand Down Expand Up @@ -73,14 +72,13 @@ public SearchConversationsTransportAction(

@Override
public void doExecute(Task task, MLSearchActionRequest mlSearchActionRequest, ActionListener<SearchResponse> actionListener) {
SearchRequest request = mlSearchActionRequest.getSearchRequest();
if (!featureIsEnabled) {
actionListener.onFailure(new OpenSearchException(ML_COMMONS_MEMORY_FEATURE_DISABLED_MESSAGE));
return;
} else {
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) {
ActionListener<SearchResponse> internalListener = ActionListener.runBefore(actionListener, context::restore);
cmHandler.searchConversations(request, internalListener);
cmHandler.searchConversations(mlSearchActionRequest, internalListener);
} catch (Exception e) {
log.error("Failed to search memories", e);
actionListener.onFailure(e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,12 @@ public TransportSearchAgentAction(

@Override
protected void doExecute(Task task, MLSearchActionRequest request, ActionListener<SearchResponse> actionListener) {
request.getSearchRequest().indices(CommonValue.ML_AGENT_INDEX);
request.indices(CommonValue.ML_AGENT_INDEX);
String tenantId = request.getTenantId();
if (!TenantAwareHelper.validateTenantId(mlFeatureEnabledSetting, tenantId, actionListener)) {
return;
}
search(request.getSearchRequest(), tenantId, actionListener);
search(request, tenantId, actionListener);
}

private void search(SearchRequest request, String tenantId, ActionListener<SearchResponse> actionListener) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,13 @@ public SearchConnectorTransportAction(

@Override
protected void doExecute(Task task, MLSearchActionRequest request, ActionListener<SearchResponse> actionListener) {
request.getSearchRequest().indices(CommonValue.ML_CONNECTOR_INDEX);
request.indices(CommonValue.ML_CONNECTOR_INDEX);

String tenantId = request.getTenantId();
if (!TenantAwareHelper.validateTenantId(mlFeatureEnabledSetting, tenantId, actionListener)) {
return;
}
search(request.getSearchRequest(), tenantId, actionListener);
search(request, tenantId, actionListener);
}

private void search(SearchRequest request, String tenantId, ActionListener<SearchResponse> actionListener) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import org.opensearch.core.action.ActionListener;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.ml.common.CommonValue;
import org.opensearch.ml.common.transport.model_group.MLModelGroupSearchAction;
import org.opensearch.ml.common.transport.search.MLSearchActionRequest;
import org.opensearch.ml.helper.ModelAccessControlHelper;
Expand Down Expand Up @@ -65,12 +64,11 @@ public SearchModelGroupTransportAction(
protected void doExecute(Task task, MLSearchActionRequest request, ActionListener<SearchResponse> actionListener) {
User user = RestActionUtils.getUserContext(client);
ActionListener<SearchResponse> listener = wrapRestActionListener(actionListener, "Fail to search");
request.getSearchRequest().indices(CommonValue.ML_MODEL_GROUP_INDEX);
String tenantId = request.getTenantId();
if (!TenantAwareHelper.validateTenantId(mlFeatureEnabledSetting, tenantId, actionListener)) {
return;
}
preProcessRoleAndPerformSearch(request.getSearchRequest(), tenantId, user, listener);
preProcessRoleAndPerformSearch(request, tenantId, user, listener);
}

private void preProcessRoleAndPerformSearch(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import org.opensearch.common.inject.Inject;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.action.handler.MLSearchHandler;
import org.opensearch.ml.common.CommonValue;
import org.opensearch.ml.common.transport.model.MLModelSearchAction;
import org.opensearch.ml.common.transport.search.MLSearchActionRequest;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
Expand Down Expand Up @@ -44,11 +43,11 @@ public SearchModelTransportAction(

@Override
protected void doExecute(Task task, MLSearchActionRequest request, ActionListener<SearchResponse> actionListener) {
request.getSearchRequest().indices(CommonValue.ML_MODEL_INDEX);

String tenantId = request.getTenantId();
if (!TenantAwareHelper.validateTenantId(mlFeatureEnabledSetting, tenantId, actionListener)) {
return;
}
mlSearchHandler.search(request.getSearchRequest(), tenantId, actionListener);
mlSearchHandler.search(request, tenantId, actionListener);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD;
import static org.opensearch.ml.utils.RestActionUtils.wrapListenerToHandleSearchIndexNotFound;

import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
Expand Down Expand Up @@ -51,7 +50,6 @@ public SearchTaskTransportAction(
@Override
protected void doExecute(Task task, MLSearchActionRequest mlSearchActionRequest, ActionListener<SearchResponse> actionListener) {
String tenantId = mlSearchActionRequest.getTenantId();
SearchRequest request = mlSearchActionRequest.getSearchRequest();
if (!TenantAwareHelper.validateTenantId(mlFeatureEnabledSetting, tenantId, actionListener)) {
return;
}
Expand All @@ -65,16 +63,16 @@ protected void doExecute(Task task, MLSearchActionRequest mlSearchActionRequest,
BoolQueryBuilder queryBuilder = QueryBuilders.boolQuery();

// Preserve existing query if present
if (request.source().query() != null) {
queryBuilder.must(request.source().query());
if (mlSearchActionRequest.source().query() != null) {
queryBuilder.must(mlSearchActionRequest.source().query());
}
// Add tenancy filter
queryBuilder.filter(QueryBuilders.termQuery(TENANT_ID_FIELD, tenantId)); // Replace 'tenant_id_field' with actual field name

// Update the request's source with the new query
request.source().query(queryBuilder);
mlSearchActionRequest.source().query(queryBuilder);
}
client.search(request, ActionListener.runBefore(wrappedListener, context::restore));
client.search(mlSearchActionRequest, ActionListener.runBefore(wrappedListener, context::restore));
} catch (Exception e) {
log.error(e.getMessage(), e);
actionListener.onFailure(e);
Expand Down
Loading

0 comments on commit 38f7b5b

Please sign in to comment.