Skip to content

Commit

Permalink
#44 Working on wiring up PPTSS.
Browse files Browse the repository at this point in the history
  • Loading branch information
jzonthemtn committed Nov 22, 2024
1 parent 0fcfeb7 commit b170242
Show file tree
Hide file tree
Showing 11 changed files with 271 additions and 123 deletions.
3 changes: 0 additions & 3 deletions data/esci/ubi_queries_events.ndjson.bz2

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/bin/bash -e

#QUERY_SET=`curl -s -X POST "http://localhost:9200/_plugins/search_quality_eval/queryset?name=test&description=fake&sampling=pptss" | jq .query_set | tr -d '"'`
curl -s -X POST "http://localhost:9200/_plugins/search_quality_eval/queryset?name=test&description=fake&sampling=none&max_queries=500"
curl -s -X POST "http://localhost:9200/_plugins/search_quality_eval/queryset?name=test&description=fake&sampling=none&query_set_size=500"

#echo ${QUERY_SET}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
#!/bin/bash -e

#QUERY_SET=`curl -s -X POST "http://localhost:9200/_plugins/search_quality_eval/queryset?name=test&description=fake&sampling=pptss" | jq .query_set | tr -d '"'`
curl -s -X POST "http://localhost:9200/_plugins/search_quality_eval/queryset?name=test&description=fake&sampling=pptss&max_queries=500"
curl -s -X POST "http://localhost:9200/_plugins/search_quality_eval/queryset?name=test&description=fake&sampling=pptss&query_set_size=500"

#echo ${QUERY_SET}

#curl -s http://localhost:9200/search_quality_eval_query_sets/_search | jq
#curl -s -X GET http://localhost:9200/search_quality_eval_query_sets/_doc/${QUERY_SET} | jq

# Run the query set now.
#curl -s -X POST "http://localhost:9200/_plugins/search_quality_eval/run?id=${QUERY_SET}" | jq
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
*/
package org.opensearch.eval;

import org.apache.commons.lang3.StringUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.delete.DeleteRequest;
Expand All @@ -24,26 +23,25 @@
import org.opensearch.core.rest.RestStatus;
import org.opensearch.eval.judgments.clickmodel.coec.CoecClickModel;
import org.opensearch.eval.judgments.clickmodel.coec.CoecClickModelParameters;
import org.opensearch.eval.samplers.AllQueriesQuerySampler;
import org.opensearch.eval.samplers.AllQueriesQuerySamplerParameters;
import org.opensearch.eval.samplers.ProbabilityProportionalToSizeAbstractQuerySampler;
import org.opensearch.eval.samplers.ProbabilityProportionalToSizeParameters;
import org.opensearch.eval.samplers.ProbabilityProportionalToSizeQuerySampler;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.jobscheduler.spi.schedule.IntervalSchedule;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.BytesRestResponse;
import org.opensearch.rest.RestRequest;
import org.opensearch.rest.RestResponse;
import org.opensearch.search.SearchHit;
import org.opensearch.search.builder.SearchSourceBuilder;

import java.io.IOException;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.UUID;

public class SearchQualityEvaluationRestHandler extends BaseRestHandler {
Expand Down Expand Up @@ -89,7 +87,7 @@ public List<Route> routes() {
protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException {

// Handle managing query sets.
if(StringUtils.equalsIgnoreCase(request.path(), QUERYSET_MANAGEMENT_URL)) {
if(QUERYSET_MANAGEMENT_URL.equalsIgnoreCase(request.path())) {

// Creating a new query set by sampling the UBI queries.
if (request.method().equals(RestRequest.Method.POST)) {
Expand All @@ -100,36 +98,19 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
final int querySetSize = Integer.parseInt(request.param("query_set_size", "1000"));

// Create a query set by finding all the unique user_query terms.
if (StringUtils.equalsIgnoreCase(sampling, "none")) {
if ("none".equalsIgnoreCase(sampling)) {

// If we are not sampling queries, the query sets should just be directly
// indexed into OpenSearch using the `ubu_queries` index directly.

try {

// Get queries from the UBI queries index.
final SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
searchSourceBuilder.query(QueryBuilders.matchAllQuery());
searchSourceBuilder.from(0);
searchSourceBuilder.size(querySetSize);
final AllQueriesQuerySamplerParameters parameters = new AllQueriesQuerySamplerParameters(name, description, sampling, querySetSize);
final AllQueriesQuerySampler sampler = new AllQueriesQuerySampler(client, parameters);

final SearchRequest searchRequest = new SearchRequest(SearchQualityEvaluationPlugin.UBI_QUERIES_INDEX_NAME);
searchRequest.source(searchSourceBuilder);
// Sample and index the queries.
final String querySetId = sampler.sample();

final SearchResponse searchResponse = client.search(searchRequest).get();

// LOGGER.info("Found {} user queries from the ubi_queries index.", searchResponse.getHits().getTotalHits().toString());

final Set<String> queries = new HashSet<>();
for(final SearchHit hit : searchResponse.getHits().getHits()) {
final Map<String, Object> fields = hit.getSourceAsMap();
queries.add(fields.get("user_query").toString());
}

// LOGGER.info("Found {} user queries from the ubi_queries index.", queries.size());

// Create the query set and return its ID.
final String querySetId = indexQuerySet(client, name, description, sampling, queries);
return restChannel -> restChannel.sendResponse(new BytesRestResponse(RestStatus.OK, "{\"query_set\": \"" + querySetId + "\"}"));

} catch(Exception ex) {
Expand All @@ -138,19 +119,18 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli


// Create a query set by using PPTSS sampling.
} else if (StringUtils.equalsIgnoreCase(sampling, "pptss")) {
} else if ("pptss".equalsIgnoreCase(sampling)) {

final ProbabilityProportionalToSizeParameters parameters = new ProbabilityProportionalToSizeParameters(querySetSize);
final ProbabilityProportionalToSizeQuerySampler sampler = new ProbabilityProportionalToSizeQuerySampler(parameters);
LOGGER.info("Creating query set using PPTSS");

// TODO: Get all queries from the ubi_queries index.

final Collection<String> queries = sampler.sample();
final ProbabilityProportionalToSizeParameters parameters = new ProbabilityProportionalToSizeParameters(name, description, sampling, querySetSize);
final ProbabilityProportionalToSizeAbstractQuerySampler sampler = new ProbabilityProportionalToSizeAbstractQuerySampler(client, parameters);

try {

// Create the query set and return its ID.
final String querySetId = indexQuerySet(client, name, description, sampling, queries);
// Sample and index the queries.
final String querySetId = sampler.sample();

return restChannel -> restChannel.sendResponse(new BytesRestResponse(RestStatus.OK, "{\"query_set\": \"" + querySetId + "\"}"));

} catch(Exception ex) {
Expand All @@ -168,7 +148,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
}

// Handle running query sets.
} else if(StringUtils.equalsIgnoreCase(request.path(), QUERYSET_RUN_URL)) {
} else if(QUERYSET_RUN_URL.equalsIgnoreCase(request.path())) {

final String id = request.param("id");

Expand Down Expand Up @@ -203,7 +183,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
return restChannel -> restChannel.sendResponse(new BytesRestResponse(RestStatus.OK, "{\"message\": \"Query set " + id + " run initiated.\"}"));

// Handle the on-demand creation of implicit judgments.
} else if(StringUtils.equalsIgnoreCase(request.path(), IMPLICIT_JUDGMENTS_URL)) {
} else if(IMPLICIT_JUDGMENTS_URL.equalsIgnoreCase(request.path())) {

if (request.method().equals(RestRequest.Method.POST)) {

Expand All @@ -212,7 +192,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
final int maxRank = Integer.parseInt(request.param("max_rank", "20"));
final long judgments;

if (StringUtils.equalsIgnoreCase(clickModel, "coec")) {
if ("coec".equalsIgnoreCase(clickModel)) {

final CoecClickModelParameters coecClickModelParameters = new CoecClickModelParameters(true, maxRank);
final CoecClickModel coecClickModel = new CoecClickModel(client, coecClickModelParameters);
Expand Down Expand Up @@ -255,7 +235,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
}

// Handle the scheduling of creating implicit judgments.
} else if(StringUtils.equalsIgnoreCase(request.path(), SCHEDULING_URL)) {
} else if(SCHEDULING_URL.equalsIgnoreCase(request.path())) {

if (request.method().equals(RestRequest.Method.POST)) {

Expand All @@ -276,15 +256,15 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli

// Read the start_time.
final Instant startTime;
if (StringUtils.isEmpty(request.param("start_time"))) {
if (request.param("start_time") == null) {
startTime = Instant.now();
} else {
startTime = Instant.ofEpochMilli(Long.parseLong(request.param("start_time")));
}

// Read the interval.
final int interval;
if (StringUtils.isEmpty(request.param("interval"))) {
if (request.param("interval") == null) {
// Default to every 24 hours.
interval = 1440;
} else {
Expand Down Expand Up @@ -361,29 +341,4 @@ public void onFailure(Exception e) {

}

/**
* Index the query set.
*/
private String indexQuerySet(final NodeClient client, final String name, final String description, final String sampling, Collection<String> queries) throws Exception {

final Map<String, Object> querySet = new HashMap<>();
querySet.put("name", name);
querySet.put("description", description);
querySet.put("sampling", sampling);
querySet.put("queries", queries);
querySet.put("created_at", Instant.now().toEpochMilli());

final String querySetId = UUID.randomUUID().toString();

final IndexRequest indexRequest = new IndexRequest().index(SearchQualityEvaluationPlugin.QUERY_SETS_INDEX_NAME)
.id(querySetId)
.source(querySet)
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);

client.index(indexRequest).get();

return querySetId;

}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.eval.samplers;

import org.opensearch.action.index.IndexRequest;
import org.opensearch.action.support.WriteRequest;
import org.opensearch.client.node.NodeClient;
import org.opensearch.eval.SearchQualityEvaluationPlugin;

import java.time.Instant;
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
import java.util.UUID;

/**
* An interface for sampling UBI queries.
*/
public abstract class AbstractQuerySampler {

/**
* Gets the name of the sampler.
* @return The name of the sampler.
*/
abstract String getName();

/**
* Samples the queries and inserts the query set into an index.
* @return A query set ID.
*/
abstract String sample() throws Exception;

/**
* Index the query set.
*/
protected String indexQuerySet(final NodeClient client, final String name, final String description, final String sampling, Collection<String> queries) throws Exception {

final Map<String, Object> querySet = new HashMap<>();
querySet.put("name", name);
querySet.put("description", description);
querySet.put("sampling", sampling);
querySet.put("queries", queries);
querySet.put("created_at", Instant.now().toEpochMilli());

final String querySetId = UUID.randomUUID().toString();

final IndexRequest indexRequest = new IndexRequest().index(SearchQualityEvaluationPlugin.QUERY_SETS_INDEX_NAME)
.id(querySetId)
.source(querySet)
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);

client.index(indexRequest).get();

return querySetId;

}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.eval.samplers;

public class AbstractSamplerParameters {

private final String name;
private final String description;
private final String sampling;
private final int querySetSize;

public AbstractSamplerParameters(final String name, final String description, final String sampling, final int querySetSize) {
this.name = name;
this.description = description;
this.sampling = sampling;
this.querySetSize = querySetSize;
}

public String getName() {
return name;
}

public String getDescription() {
return description;
}

public String getSampling() {
return sampling;
}

public int getQuerySetSize() {
return querySetSize;
}

}
Loading

0 comments on commit b170242

Please sign in to comment.