
Commit

Merge pull request #46 from o19s/44-pptss
Adds PPTSS sampling
jzonthemtn authored Nov 23, 2024
2 parents 87e168f + 09ebd5a commit 222fc37
Showing 12 changed files with 415 additions and 72 deletions.
3 changes: 0 additions & 3 deletions data/esci/ubi_queries_events.ndjson.bz2

This file was deleted.

@@ -10,7 +10,7 @@ services:
logger.level: info
OPENSEARCH_INITIAL_ADMIN_PASSWORD: SuperSecretPassword_123
http.max_content_length: 500mb
OPENSEARCH_JAVA_OPTS: "-Xms8192m -Xmx8192m"
OPENSEARCH_JAVA_OPTS: "-Xms8g -Xmx8g"
ulimits:
memlock:
soft: -1
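To confirm the new heap limits took effect, a quick check (not part of this commit, and assuming the cluster is reachable over plain HTTP on localhost:9200 as in the scripts below) is:

# Hypothetical verification command: prints each node's configured maximum heap.
curl -s "http://localhost:9200/_cat/nodes?v&h=name,heap.max"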
@@ -1,7 +1,7 @@
#!/bin/bash -e

#QUERY_SET=`curl -s -X POST "http://localhost:9200/_plugins/search_quality_eval/queryset?name=test&description=fake&sampling=pptss" | jq .query_set | tr -d '"'`
curl -s -X POST "http://localhost:9200/_plugins/search_quality_eval/queryset?name=test&description=fake&sampling=none&max_queries=500"
curl -s -X POST "http://localhost:9200/_plugins/search_quality_eval/queryset?name=test&description=fake&sampling=none&query_set_size=500"

#echo ${QUERY_SET}

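Since the handler responds with {"query_set": "<id>"} (see the Java changes below), the script could also capture the new query set ID for later use. A minimal sketch, assuming jq is installed; this is not part of the commit:

# Hypothetical extension of the script above: keep the returned query set ID.
QUERY_SET=$(curl -s -X POST "http://localhost:9200/_plugins/search_quality_eval/queryset?name=test&description=fake&sampling=none&query_set_size=500" | jq -r '.query_set')
echo "Created query set ${QUERY_SET}"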
@@ -1,11 +1,11 @@
#!/bin/bash -e

#QUERY_SET=`curl -s -X POST "http://localhost:9200/_plugins/search_quality_eval/queryset?name=test&description=fake&sampling=pptss" | jq .query_set | tr -d '"'`
curl -s -X POST "http://localhost:9200/_plugins/search_quality_eval/queryset?name=test&description=fake&sampling=pptss&max_queries=500"
curl -s -X POST "http://localhost:9200/_plugins/search_quality_eval/queryset?name=test&description=fake&sampling=pptss&query_set_size=5000"

#echo ${QUERY_SET}

#curl -s http://localhost:9200/search_quality_eval_query_sets/_search | jq
#curl -s -X GET http://localhost:9200/search_quality_eval_query_sets/_doc/${QUERY_SET} | jq

# Run the query set now.
#curl -s -X POST "http://localhost:9200/_plugins/search_quality_eval/run?id=${QUERY_SET}" | jq
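An end-to-end sketch for the PPTSS path, combining the request above with the commented-out run endpoint (again assuming jq is installed; not part of this commit):

# Hypothetical: create a PPTSS-sampled query set, then run it.
QUERY_SET=$(curl -s -X POST "http://localhost:9200/_plugins/search_quality_eval/queryset?name=test&description=fake&sampling=pptss&query_set_size=5000" | jq -r '.query_set')
curl -s -X POST "http://localhost:9200/_plugins/search_quality_eval/run?id=${QUERY_SET}" | jq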
@@ -8,7 +8,6 @@
*/
package org.opensearch.eval;

import org.apache.commons.lang3.StringUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.delete.DeleteRequest;
@@ -24,24 +23,25 @@
import org.opensearch.core.rest.RestStatus;
import org.opensearch.eval.judgments.clickmodel.coec.CoecClickModel;
import org.opensearch.eval.judgments.clickmodel.coec.CoecClickModelParameters;
import org.opensearch.eval.samplers.AllQueriesQuerySampler;
import org.opensearch.eval.samplers.AllQueriesQuerySamplerParameters;
import org.opensearch.eval.samplers.ProbabilityProportionalToSizeAbstractQuerySampler;
import org.opensearch.eval.samplers.ProbabilityProportionalToSizeParameters;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.jobscheduler.spi.schedule.IntervalSchedule;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.BytesRestResponse;
import org.opensearch.rest.RestRequest;
import org.opensearch.rest.RestResponse;
import org.opensearch.search.SearchHit;
import org.opensearch.search.builder.SearchSourceBuilder;

import java.io.IOException;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.UUID;

public class SearchQualityEvaluationRestHandler extends BaseRestHandler {
@@ -87,47 +87,30 @@ public List<Route> routes() {
protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException {

// Handle managing query sets.
if(StringUtils.equalsIgnoreCase(request.path(), QUERYSET_MANAGEMENT_URL)) {
if(QUERYSET_MANAGEMENT_URL.equalsIgnoreCase(request.path())) {

// Creating a new query set by sampling the UBI queries.
if (request.method().equals(RestRequest.Method.POST)) {

final String name = request.param("name");
final String description = request.param("description");
final String sampling = request.param("sampling", "pptss");
final int maxQueries = Integer.parseInt(request.param("max_queries", "1000"));
final int querySetSize = Integer.parseInt(request.param("query_set_size", "1000"));

// Create a query set by finding all the unique user_query terms.
if (StringUtils.equalsIgnoreCase(sampling, "none")) {
if ("none".equalsIgnoreCase(sampling)) {

// If we are not sampling queries, the query set is built directly from
// the queries in the `ubi_queries` index.

try {

// Get queries from the UBI queries index.
final SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
searchSourceBuilder.query(QueryBuilders.matchAllQuery());
searchSourceBuilder.from(0);
searchSourceBuilder.size(maxQueries);
final AllQueriesQuerySamplerParameters parameters = new AllQueriesQuerySamplerParameters(name, description, sampling, querySetSize);
final AllQueriesQuerySampler sampler = new AllQueriesQuerySampler(client, parameters);

final SearchRequest searchRequest = new SearchRequest(SearchQualityEvaluationPlugin.UBI_QUERIES_INDEX_NAME);
searchRequest.source(searchSourceBuilder);
// Sample and index the queries.
final String querySetId = sampler.sample();

final SearchResponse searchResponse = client.search(searchRequest).get();

LOGGER.info("Found {} user queries from the ubi_queries index.", searchResponse.getHits().getTotalHits().toString());

final Set<String> queries = new HashSet<>();
for(final SearchHit hit : searchResponse.getHits().getHits()) {
final Map<String, Object> fields = hit.getSourceAsMap();
queries.add(fields.get("user_query").toString());
}

LOGGER.info("Found {} user queries from the ubi_queries index.", queries.size());

// Create the query set and return its ID.
final String querySetId = indexQuerySet(client, name, description, sampling, queries);
return restChannel -> restChannel.sendResponse(new BytesRestResponse(RestStatus.OK, "{\"query_set\": \"" + querySetId + "\"}"));

} catch(Exception ex) {
@@ -136,15 +119,18 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli


// Create a query set by using PPTSS sampling.
} else if (StringUtils.equalsIgnoreCase(sampling, "pptss")) {
} else if ("pptss".equalsIgnoreCase(sampling)) {

// TODO: Use the PPTSS sampling method - https://opensourceconnections.com/blog/2022/10/13/how-to-succeed-with-explicit-relevance-evaluation-using-probability-proportional-to-size-sampling/
final Collection<String> queries = List.of("computer", "desk", "table", "battery");
LOGGER.info("Creating query set using PPTSS");

final ProbabilityProportionalToSizeParameters parameters = new ProbabilityProportionalToSizeParameters(name, description, sampling, querySetSize);
final ProbabilityProportionalToSizeAbstractQuerySampler sampler = new ProbabilityProportionalToSizeAbstractQuerySampler(client, parameters);

try {

// Create the query set and return its ID.
final String querySetId = indexQuerySet(client, name, description, sampling, queries);
// Sample and index the queries.
final String querySetId = sampler.sample();

return restChannel -> restChannel.sendResponse(new BytesRestResponse(RestStatus.OK, "{\"query_set\": \"" + querySetId + "\"}"));

} catch(Exception ex) {
@@ -162,7 +148,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
}

// Handle running query sets.
} else if(StringUtils.equalsIgnoreCase(request.path(), QUERYSET_RUN_URL)) {
} else if(QUERYSET_RUN_URL.equalsIgnoreCase(request.path())) {

final String id = request.param("id");

@@ -197,7 +183,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
return restChannel -> restChannel.sendResponse(new BytesRestResponse(RestStatus.OK, "{\"message\": \"Query set " + id + " run initiated.\"}"));

// Handle the on-demand creation of implicit judgments.
} else if(StringUtils.equalsIgnoreCase(request.path(), IMPLICIT_JUDGMENTS_URL)) {
} else if(IMPLICIT_JUDGMENTS_URL.equalsIgnoreCase(request.path())) {

if (request.method().equals(RestRequest.Method.POST)) {

@@ -206,7 +192,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
final int maxRank = Integer.parseInt(request.param("max_rank", "20"));
final long judgments;

if (StringUtils.equalsIgnoreCase(clickModel, "coec")) {
if ("coec".equalsIgnoreCase(clickModel)) {

final CoecClickModelParameters coecClickModelParameters = new CoecClickModelParameters(true, maxRank);
final CoecClickModel coecClickModel = new CoecClickModel(client, coecClickModelParameters);
@@ -249,7 +235,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
}

// Handle the scheduling of creating implicit judgments.
} else if(StringUtils.equalsIgnoreCase(request.path(), SCHEDULING_URL)) {
} else if(SCHEDULING_URL.equalsIgnoreCase(request.path())) {

if (request.method().equals(RestRequest.Method.POST)) {

@@ -270,15 +256,15 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli

// Read the start_time.
final Instant startTime;
if (StringUtils.isEmpty(request.param("start_time"))) {
if (request.param("start_time") == null) {
startTime = Instant.now();
} else {
startTime = Instant.ofEpochMilli(Long.parseLong(request.param("start_time")));
}

// Read the interval.
final int interval;
if (StringUtils.isEmpty(request.param("interval"))) {
if (request.param("interval") == null) {
// Default to every 24 hours.
interval = 1440;
} else {
@@ -355,29 +341,4 @@ public void onFailure(Exception e) {

}

/**
* Index the query set.
*/
private String indexQuerySet(final NodeClient client, final String name, final String description, final String sampling, Collection<String> queries) throws Exception {

final Map<String, Object> querySet = new HashMap<>();
querySet.put("name", name);
querySet.put("description", description);
querySet.put("sampling", sampling);
querySet.put("queries", queries);
querySet.put("created_at", Instant.now().toEpochMilli());

final String querySetId = UUID.randomUUID().toString();

final IndexRequest indexRequest = new IndexRequest().index(SearchQualityEvaluationPlugin.QUERY_SETS_INDEX_NAME)
.id(querySetId)
.source(querySet)
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);

client.index(indexRequest).get();

return querySetId;

}

}
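The ProbabilityProportionalToSizeAbstractQuerySampler wired in above is one of the changed files not shown on this page. As an illustration only, probability-proportional-to-size sampling (described in the blog post linked from the removed TODO) draws each query with probability proportional to how often users issued it. The sketch below uses hypothetical names and omits the OpenSearch wiring; it is not the plugin's implementation.

package org.opensearch.eval.samplers;

import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;

/**
* Sketch only: draws queries with probability proportional to their observed frequency.
* Assumes every query has a positive count.
*/
public class PptssSamplingSketch {

    public static Set<String> sample(final Map<String, Long> queryCounts, final int querySetSize) {

        final long totalCount = queryCounts.values().stream().mapToLong(Long::longValue).sum();
        final List<String> uniqueQueries = new ArrayList<>(queryCounts.keySet());
        final Set<String> sampled = new LinkedHashSet<>();
        final Random random = new Random();

        // Keep drawing until the requested number of distinct queries is reached (or all are taken).
        while (sampled.size() < Math.min(querySetSize, uniqueQueries.size())) {
            long threshold = (long) (random.nextDouble() * totalCount);
            for (final String query : uniqueQueries) {
                threshold -= queryCounts.get(query);
                if (threshold < 0) {
                    sampled.add(query);
                    break;
                }
            }
        }

        return sampled;
    }

}

The actual sampler presumably aggregates user_query counts from the ubi_queries index and stores the sampled set through indexQuerySet, but those details are not visible in this diff.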
@@ -11,7 +11,7 @@
import com.google.gson.annotations.SerializedName;

/**
* A UBI event.
* Creates a representation of a UBI event.
*/
public class UbiEvent {

@@ -27,6 +27,13 @@ public class UbiEvent {
@SerializedName("event_attributes")
private EventAttributes eventAttributes;

/**
* Creates a new representation of a UBI event.
*/
public UbiEvent() {

}

@Override
public String toString() {
return actionName + ", " + clientId + ", " + queryId + ", " + eventAttributes.getObject().toString() + ", " + eventAttributes.getPosition().getIndex();
@@ -0,0 +1,64 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.eval.samplers;

import org.opensearch.action.index.IndexRequest;
import org.opensearch.action.support.WriteRequest;
import org.opensearch.client.node.NodeClient;
import org.opensearch.eval.SearchQualityEvaluationPlugin;

import java.time.Instant;
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
import java.util.UUID;

/**
* A base class for samplers of UBI queries.
*/
public abstract class AbstractQuerySampler {

/**
* Gets the name of the sampler.
* @return The name of the sampler.
*/
abstract String getName();

/**
* Samples the queries and inserts the query set into an index.
* @return A query set ID.
*/
abstract String sample() throws Exception;

/**
* Index the query set.
*/
protected String indexQuerySet(final NodeClient client, final String name, final String description, final String sampling, Collection<String> queries) throws Exception {

final Map<String, Object> querySet = new HashMap<>();
querySet.put("name", name);
querySet.put("description", description);
querySet.put("sampling", sampling);
querySet.put("queries", queries);
querySet.put("created_at", Instant.now().toEpochMilli());

final String querySetId = UUID.randomUUID().toString();

final IndexRequest indexRequest = new IndexRequest().index(SearchQualityEvaluationPlugin.QUERY_SETS_INDEX_NAME)
.id(querySetId)
.source(querySet)
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);

client.index(indexRequest).get();

return querySetId;

}

}
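The AllQueriesQuerySampler referenced by the handler is also among the files not shown here. A sketch of what a concrete subclass plausibly looks like, reconstructed from the inline search code this commit removes from the handler; the class body is an assumption (including that AllQueriesQuerySamplerParameters extends AbstractSamplerParameters), not the actual file:

package org.opensearch.eval.samplers;

import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.client.node.NodeClient;
import org.opensearch.eval.SearchQualityEvaluationPlugin;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.search.SearchHit;
import org.opensearch.search.builder.SearchSourceBuilder;

import java.util.HashSet;
import java.util.Map;
import java.util.Set;

/**
* Sketch only: takes every user_query from the UBI queries index, up to the query set size.
*/
public class AllQueriesQuerySamplerSketch extends AbstractQuerySampler {

    private final NodeClient client;
    private final AllQueriesQuerySamplerParameters parameters;

    public AllQueriesQuerySamplerSketch(final NodeClient client, final AllQueriesQuerySamplerParameters parameters) {
        this.client = client;
        this.parameters = parameters;
    }

    @Override
    String getName() {
        // Assumed identifier for this sampler.
        return "none";
    }

    @Override
    String sample() throws Exception {

        // Mirror the search the handler used to run inline: match all UBI queries.
        final SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
        searchSourceBuilder.query(QueryBuilders.matchAllQuery());
        searchSourceBuilder.from(0);
        searchSourceBuilder.size(parameters.getQuerySetSize());

        final SearchRequest searchRequest = new SearchRequest(SearchQualityEvaluationPlugin.UBI_QUERIES_INDEX_NAME);
        searchRequest.source(searchSourceBuilder);

        final SearchResponse searchResponse = client.search(searchRequest).get();

        // Collect the unique user_query values.
        final Set<String> queries = new HashSet<>();
        for (final SearchHit hit : searchResponse.getHits().getHits()) {
            final Map<String, Object> fields = hit.getSourceAsMap();
            queries.add(fields.get("user_query").toString());
        }

        // Persist the query set and return its ID.
        return indexQuerySet(client, parameters.getName(), parameters.getDescription(), parameters.getSampling(), queries);
    }

}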
@@ -0,0 +1,41 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.eval.samplers;

/**
* Base parameters shared by the query samplers.
*/
public class AbstractSamplerParameters {

private final String name;
private final String description;
private final String sampling;
private final int querySetSize;

public AbstractSamplerParameters(final String name, final String description, final String sampling, final int querySetSize) {
this.name = name;
this.description = description;
this.sampling = sampling;
this.querySetSize = querySetSize;
}

public String getName() {
return name;
}

public String getDescription() {
return description;
}

public String getSampling() {
return sampling;
}

public int getQuerySetSize() {
return querySetSize;
}

}
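The concrete parameter classes the handler instantiates (AllQueriesQuerySamplerParameters and ProbabilityProportionalToSizeParameters) are not shown on this page. Presumably they are thin subclasses along these lines, with the constructor signature taken from how the handler calls them; a sketch under that assumption:

package org.opensearch.eval.samplers;

/**
* Sketch only: parameters for the PPTSS sampler, passed straight through to the base class.
*/
public class ProbabilityProportionalToSizeParameters extends AbstractSamplerParameters {

    public ProbabilityProportionalToSizeParameters(final String name, final String description, final String sampling, final int querySetSize) {
        super(name, description, sampling, querySetSize);
    }

}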