Skip to content

Commit

Permalink
Working on converting to a standalone app.
Browse files Browse the repository at this point in the history
  • Loading branch information
jzonthemtn committed Dec 31, 2024
1 parent 9ee0b9b commit 83633bc
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 117 deletions.
5 changes: 5 additions & 0 deletions opensearch-search-quality-evaluation-framework/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@
<artifactId>opensearch-java</artifactId>
<version>2.19.0</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
<version>2.18.2</version>
</dependency>
<dependency>
<groupId>commons-cli</groupId>
<artifactId>commons-cli</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
*/
package org.opensearch.eval.engine;

import com.fasterxml.jackson.databind.node.ObjectNode;
import com.google.gson.Gson;
import org.apache.hc.core5.http.HttpHost;
import org.apache.logging.log4j.LogManager;
Expand All @@ -23,15 +24,18 @@
import org.opensearch.client.opensearch._types.query_dsl.BoolQuery;
import org.opensearch.client.opensearch._types.query_dsl.MatchQuery;
import org.opensearch.client.opensearch._types.query_dsl.Query;
import org.opensearch.client.opensearch._types.query_dsl.WrapperQuery;
import org.opensearch.client.opensearch.core.BulkRequest;
import org.opensearch.client.opensearch.core.BulkResponse;
import org.opensearch.client.opensearch.core.IndexRequest;
import org.opensearch.client.opensearch.core.ScrollRequest;
import org.opensearch.client.opensearch.core.ScrollResponse;
import org.opensearch.client.opensearch.core.SearchRequest;
import org.opensearch.client.opensearch.core.SearchResponse;
import org.opensearch.client.opensearch.core.bulk.BulkOperation;
import org.opensearch.client.opensearch.core.bulk.IndexOperation;
import org.opensearch.client.opensearch.core.search.Hit;
import org.opensearch.client.opensearch.core.search.TrackHits;
import org.opensearch.client.opensearch.indices.CreateIndexRequest;
import org.opensearch.client.opensearch.indices.ExistsRequest;
import org.opensearch.client.transport.OpenSearchTransport;
Expand All @@ -43,6 +47,7 @@
import org.opensearch.eval.model.data.QueryResultMetric;
import org.opensearch.eval.model.data.QuerySet;
import org.opensearch.eval.model.data.RankAggregatedClickThrough;
import org.opensearch.eval.model.ubi.event.UbiEvent;
import org.opensearch.eval.model.ubi.query.UbiQuery;
import org.opensearch.eval.utils.TimeUtils;

Expand All @@ -51,6 +56,7 @@
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Base64;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
Expand All @@ -60,6 +66,7 @@

import static org.opensearch.eval.judgments.clickmodel.coec.CoecClickModel.INDEX_QUERY_DOC_CTR;
import static org.opensearch.eval.judgments.clickmodel.coec.CoecClickModel.INDEX_RANK_AGGREGATED_CTR;
import static org.opensearch.eval.runners.OpenSearchQuerySetRunner.QUERY_PLACEHOLDER;

/**
* Functionality for interacting with OpenSearch.
Expand Down Expand Up @@ -303,25 +310,18 @@ public UbiQuery getQueryFromQueryId(final String queryId) throws Exception {

LOGGER.debug("Getting query from query ID {}", queryId);

final String query = "{\"match\": {\"query_id\": \"" + queryId + "\" }}";
final WrapperQueryBuilder qb = QueryBuilders.wrapperQuery(query);

// The query_id should be unique, but limit the search to a single result just in case.
final SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
searchSourceBuilder.query(qb);
searchSourceBuilder.from(0);
searchSourceBuilder.size(1);

final String[] indexes = {Constants.UBI_QUERIES_INDEX_NAME};
final SearchRequest searchRequest = new SearchRequest.Builder().query(q -> q.match(m -> m.field("query_id").query(FieldValue.of(queryId))))
.index(Constants.UBI_QUERIES_INDEX_NAME)
.from(0)
.size(1)
.build();

final SearchRequest searchRequest = new SearchRequest(indexes, searchSourceBuilder);
final SearchResponse response = client.search(searchRequest).get();
final SearchResponse<UbiQuery> searchResponse = client.search(searchRequest, UbiQuery.class);

// If this does not return a query then we cannot calculate the judgments. Each event should have a query associated with it.
if(response.getHits().getHits() != null & response.getHits().getHits().length > 0) {
if(searchResponse.hits().hits() != null & !searchResponse.hits().hits().isEmpty()) {

final SearchHit hit = response.getHits().getHits()[0];
return gson.fromJson(hit.getSourceAsString(), UbiQuery.class);
return searchResponse.hits().hits().get(0).source();

} else {

Expand All @@ -332,24 +332,65 @@ public UbiQuery getQueryFromQueryId(final String queryId) throws Exception {

}

private Collection<String> getQueryIdsHavingUserQuery(final String userQuery) throws Exception {
@Override
public List<String> runQuery(final String index, final String query, final int k, final String userQuery, final String idField) throws IOException {

// Replace the query placeholder with the user query.
final String parsedQuery = query.replace(QUERY_PLACEHOLDER, userQuery);

final String encodedQuery = Base64.getEncoder().encodeToString(parsedQuery.getBytes(StandardCharsets.UTF_8));

final WrapperQuery wrapperQuery = new WrapperQuery.Builder()
.query(encodedQuery)
.build();

final SearchRequest searchRequest = new SearchRequest.Builder()
.index(index)
.query(q -> q.wrapper(wrapperQuery))
.from(0)
.size(k)
.build();


final String query = "{\"match\": {\"user_query\": \"" + userQuery + "\" }}";
final WrapperQueryBuilder qb = QueryBuilders.wrapperQuery(query);
// TODO: Handle the searchPipeline if it is not null.
// TODO: Only return the idField since that's all we need.

final SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
searchSourceBuilder.query(qb);
final SearchResponse<ObjectNode> searchResponse = client.search(searchRequest, ObjectNode.class);

final String[] indexes = {Constants.UBI_QUERIES_INDEX_NAME};
final List<String> orderedDocumentIds = new ArrayList<>();

final SearchRequest searchRequest = new SearchRequest(indexes, searchSourceBuilder);
final SearchResponse response = client.search(searchRequest).get();
for (int i = 0; i < searchResponse.hits().hits().size(); i++) {

final String documentId;

if ("_id".equals(idField)) {
documentId = searchResponse.hits().hits().get(i).id();
} else {
// TODO: Need to check this field actually exists.
// TODO: Does this work?
documentId = searchResponse.hits().hits().get(i).fields().get(idField).toString();
}

orderedDocumentIds.add(documentId);

}

return orderedDocumentIds;

}

private Collection<String> getQueryIdsHavingUserQuery(final String userQuery) throws Exception {

final SearchRequest searchRequest = new SearchRequest.Builder().query(q -> q.match(m -> m.field("user_query").query(FieldValue.of(userQuery))))
.index(Constants.UBI_QUERIES_INDEX_NAME)
.build();

final SearchResponse<UbiQuery> searchResponse = client.search(searchRequest, UbiQuery.class);

final Collection<String> queryIds = new ArrayList<>();

for(final SearchHit hit : response.getHits().getHits()) {
final String queryId = hit.getSourceAsMap().get("query_id").toString();
queryIds.add(queryId);
for (int i = 0; i < searchResponse.hits().hits().size(); i++) {
queryIds.add(searchResponse.hits().hits().get(i).source().getQueryId());
}

return queryIds;
Expand Down Expand Up @@ -394,22 +435,22 @@ public long getCountOfQueriesForUserQueryHavingResultInRankR(final String userQu
" }\n" +
" }";

final WrapperQueryBuilder qb = QueryBuilders.wrapperQuery(query);
final String encodedQuery = Base64.getEncoder().encodeToString(query.getBytes(StandardCharsets.UTF_8));

final SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
searchSourceBuilder.query(qb);
searchSourceBuilder.trackTotalHits(true);
searchSourceBuilder.size(0);
final WrapperQuery wrapperQuery = new WrapperQuery.Builder()
.query(encodedQuery)
.build();

final String[] indexes = {Constants.UBI_EVENTS_INDEX_NAME};
final SearchRequest searchRequest = new SearchRequest.Builder()
.index(Constants.UBI_EVENTS_INDEX_NAME)
.query(q -> q.wrapper(wrapperQuery))
.size(0)
.trackTotalHits(TrackHits.of(t -> t.enabled(true)))
.build();

final SearchRequest searchRequest = new SearchRequest(indexes, searchSourceBuilder);
final SearchResponse response = client.search(searchRequest).get();
final SearchResponse<UbiEvent> searchResponse = client.search(searchRequest, UbiEvent.class);

// Won't be null as long as trackTotalHits is true.
if(response.getHits().getTotalHits() != null) {
countOfTimesShownAtRank += response.getHits().getTotalHits().value;
}
countOfTimesShownAtRank += searchResponse.hits().total().value();

}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import java.io.IOException;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Set;

Expand All @@ -30,6 +31,8 @@ public abstract class SearchEngine {

public abstract Collection<Judgment> getJudgments(final String index) throws IOException;

public abstract List<String> runQuery(final String index, final String query, final int k, final String userQuery, final String idField) throws IOException;

public abstract String indexQuerySet(QuerySet querySet) throws IOException;
public abstract Collection<UbiQuery> getUbiQueries() throws IOException;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,80 +64,28 @@ public QuerySetRunResult run(final String querySetId, final String judgmentsId,
// Loop over each query in the map and run each one.
for (final String userQuery : queryMap.keySet()) {

// Replace the query placeholder with the user query.
final String parsedQuery = query.replace(QUERY_PLACEHOLDER, userQuery);

// Build the query from the one that was passed in.
final SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();

searchSourceBuilder.query(QueryBuilders.wrapperQuery(parsedQuery));
searchSourceBuilder.from(0);
searchSourceBuilder.size(k);

final String[] includeFields = new String[]{idField};
final String[] excludeFields = new String[]{};
searchSourceBuilder.fetchSource(includeFields, excludeFields);

// LOGGER.info(searchSourceBuilder.toString());

final SearchRequest searchRequest = new SearchRequest(index);
searchRequest.source(searchSourceBuilder);

if (searchPipeline != null) {
searchSourceBuilder.pipeline(searchPipeline);
searchRequest.pipeline(searchPipeline);
}

// This is to keep OpenSearch from rejecting queries.
// TODO: Look at using the Workload Management in 2.18.0.
Thread.sleep(50);

client.search(searchRequest, new ActionListener<>() {

@Override
public void onResponse(final SearchResponse searchResponse) {
final List<String> orderedDocumentIds = searchEngine.runQuery(index, query, k, userQuery, idField);

final List<String> orderedDocumentIds = new ArrayList<>();
try {

for (final SearchHit hit : searchResponse.getHits().getHits()) {
final RelevanceScores relevanceScores = getRelevanceScores(judgmentsId, userQuery, orderedDocumentIds, k);

final String documentId;
// Calculate the metrics for this query.
final SearchMetric dcgSearchMetric = new DcgSearchMetric(k, relevanceScores.getRelevanceScores());
final SearchMetric ndcgSearchmetric = new NdcgSearchMetric(k, relevanceScores.getRelevanceScores());
final SearchMetric precisionSearchMetric = new PrecisionSearchMetric(k, threshold, relevanceScores.getRelevanceScores());

if ("_id".equals(idField)) {
documentId = hit.getId();
} else {
// TODO: Need to check this field actually exists.
documentId = hit.getSourceAsMap().get(idField).toString();
}
final Collection<SearchMetric> searchMetrics = List.of(dcgSearchMetric, ndcgSearchmetric, precisionSearchMetric);

orderedDocumentIds.add(documentId);
queryResults.add(new QueryResult(userQuery, orderedDocumentIds, k, searchMetrics, relevanceScores.getFrogs()));

}

try {

final RelevanceScores relevanceScores = getRelevanceScores(judgmentsId, userQuery, orderedDocumentIds, k);

// Calculate the metrics for this query.
final SearchMetric dcgSearchMetric = new DcgSearchMetric(k, relevanceScores.getRelevanceScores());
final SearchMetric ndcgSearchmetric = new NdcgSearchMetric(k, relevanceScores.getRelevanceScores());
final SearchMetric precisionSearchMetric = new PrecisionSearchMetric(k, threshold, relevanceScores.getRelevanceScores());

final Collection<SearchMetric> searchMetrics = List.of(dcgSearchMetric, ndcgSearchmetric, precisionSearchMetric);

queryResults.add(new QueryResult(userQuery, orderedDocumentIds, k, searchMetrics, relevanceScores.getFrogs()));

} catch (Exception ex) {
LOGGER.error("Unable to get relevance scores for judgments {} and user query {}.", judgmentsId, userQuery, ex);
}

}

@Override
public void onFailure(Exception ex) {
LOGGER.error("Unable to search using query: {}", searchSourceBuilder.toString(), ex);
}
});
} catch (Exception ex) {
LOGGER.error("Unable to get relevance scores for judgments {} and user query {}.", judgmentsId, userQuery, ex);
}

}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@
*/
package org.opensearch.eval.samplers;

import org.opensearch.eval.Constants;
import org.opensearch.eval.engine.SearchEngine;
import org.opensearch.eval.model.ubi.query.UbiQuery;

import java.util.Collection;
import java.util.HashMap;
import java.util.Map;

Expand Down Expand Up @@ -41,24 +42,13 @@ public String getName() {
@Override
public String sample() throws Exception {

// Get queries from the UBI queries index.
// TODO: This needs to use scroll or something else.
final SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
searchSourceBuilder.query(QueryBuilders.matchAllQuery());
searchSourceBuilder.from(0);
searchSourceBuilder.size(parameters.getQuerySetSize());

final SearchRequest searchRequest = new SearchRequest(Constants.UBI_QUERIES_INDEX_NAME).source(searchSourceBuilder);

// TODO: Don't use .get()
final SearchResponse searchResponse = client.search(searchRequest).get();
final Collection<UbiQuery> ubiQueries = searchEngine.getUbiQueries();

final Map<String, Long> queries = new HashMap<>();

for(final SearchHit hit : searchResponse.getHits().getHits()) {
for(final UbiQuery ubiQuery : ubiQueries) {

final Map<String, Object> fields = hit.getSourceAsMap();
queries.merge(fields.get("user_query").toString(), 1L, Long::sum);
queries.merge(ubiQuery.getUserQuery(), 1L, Long::sum);

// Will be useful for paging once implemented.
if(queries.size() > parameters.getQuerySetSize()) {
Expand Down

0 comments on commit 83633bc

Please sign in to comment.