Skip to content

Commit

Permalink
Sampling queries until the max is reached.
Browse files Browse the repository at this point in the history
  • Loading branch information
jzonthemtn committed Nov 25, 2024
1 parent 8ae2a45 commit 7a10c6d
Show file tree
Hide file tree
Showing 6 changed files with 545 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
final int querySetSize = Integer.parseInt(request.param("query_set_size", "1000"));

// Create a query set by finding all the unique user_query terms.
if ("none".equalsIgnoreCase(sampling)) {
if (AllQueriesQuerySampler.NAME.equalsIgnoreCase(sampling)) {

// If we are not sampling queries, the query sets should just be indexed
// directly into OpenSearch using the `ubu_queries` index.
Expand All @@ -119,7 +119,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli


// Create a query set by using PPTSS sampling.
} else if ("pptss".equalsIgnoreCase(sampling)) {
} else if (ProbabilityProportionalToSizeAbstractQuerySampler.NAME.equalsIgnoreCase(sampling)) {

LOGGER.info("Creating query set using PPTSS");

Expand Down Expand Up @@ -192,7 +192,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
final int maxRank = Integer.parseInt(request.param("max_rank", "20"));
final long judgments;

if ("coec".equalsIgnoreCase(clickModel)) {
if (CoecClickModel.CLICK_MODEL_NAME.equalsIgnoreCase(clickModel)) {

final CoecClickModelParameters coecClickModelParameters = new CoecClickModelParameters(true, maxRank);
final CoecClickModel coecClickModel = new CoecClickModel(client, coecClickModelParameters);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@

public class CoecClickModel extends ClickModel {

public static final String CLICK_MODEL_NAME = "coec";

// OpenSearch indexes.
public static final String INDEX_RANK_AGGREGATED_CTR = "rank_aggregated_ctr";
public static final String INDEX_QUERY_DOC_CTR = "click_through_rates";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,13 @@
*/
package org.opensearch.eval.samplers;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.index.IndexRequest;
import org.opensearch.action.support.WriteRequest;
import org.opensearch.client.node.NodeClient;
import org.opensearch.eval.SearchQualityEvaluationPlugin;
import org.opensearch.eval.SearchQualityEvaluationRestHandler;

import java.time.Instant;
import java.util.Collection;
Expand All @@ -24,23 +27,27 @@
*/
public abstract class AbstractQuerySampler {

private static final Logger LOGGER = LogManager.getLogger(AbstractQuerySampler.class);

/**
* Gets the name of the sampler.
* @return The name of the sampler.
*/
abstract String getName();
public abstract String getName();

/**
* Samples the queries and inserts the query set into an index.
* @return A query set ID.
*/
abstract String sample() throws Exception;
public abstract String sample() throws Exception;

/**
* Index the query set.
*/
protected String indexQuerySet(final NodeClient client, final String name, final String description, final String sampling, Collection<String> queries) throws Exception {

LOGGER.info("Indexing {} queries for query set {}", queries.size(), name);

final Map<String, Object> querySet = new HashMap<>();
querySet.put("name", name);
querySet.put("description", description);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
*/
public class AllQueriesQuerySampler extends AbstractQuerySampler {

public static final String NAME = "none";

private final NodeClient client;
private final AllQueriesQuerySamplerParameters parameters;

Expand All @@ -39,7 +41,7 @@ public AllQueriesQuerySampler(final NodeClient client, final AllQueriesQuerySamp

@Override
public String getName() {
return "none";
return NAME;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
*/
public class ProbabilityProportionalToSizeAbstractQuerySampler extends AbstractQuerySampler {

public static final String NAME = "pptss";

private static final Logger LOGGER = LogManager.getLogger(ProbabilityProportionalToSizeAbstractQuerySampler.class);

private final NodeClient client;
Expand All @@ -52,7 +54,7 @@ public ProbabilityProportionalToSizeAbstractQuerySampler(final NodeClient client

@Override
public String getName() {
return "pptss";
return NAME;
}

@Override
Expand Down Expand Up @@ -80,7 +82,7 @@ public String sample() throws Exception {

while (searchHits != null && searchHits.length > 0) {

LOGGER.info("search hits size = " + searchHits.length);
LOGGER.info("search hits size = {}", searchHits.length);

for(final SearchHit hit : searchHits) {
final Map<String, Object> fields = hit.getSourceAsMap();
Expand Down Expand Up @@ -128,8 +130,12 @@ public String sample() throws Exception {
final Set<String> querySet = new HashSet<>();
final Set<Double> randomNumbers = new HashSet<>();

// Generate a random number between 0 and 1 for the size of the query set.
for(int count = 0; count < parameters.getQuerySetSize(); count++) {
// Generate random numbers between 0 and 1 to pick queries for the query set.
// Keep picking until the query set has reached the requested maximum size.
// This may require generating more random numbers than the requested size
// because a pick that duplicates a query already in the set does not grow the set.
int count = 1;
while(querySet.size() < parameters.getQuerySetSize() && count < userQueries.size()) {

// Make a random number not yet used.
double random;
Expand All @@ -147,9 +153,11 @@ public String sample() throws Exception {
smallestDelta = delta;
closestQuery = query;
}

}

querySet.add(closestQuery);
count++;

// LOGGER.info("Generated random value: {}; Smallest delta = {}; Closest query = {}", random, smallestDelta, closestQuery);

}
Expand Down
Loading

0 comments on commit 7a10c6d

Please sign in to comment.