diff --git a/scripts/run-query-set.json b/scripts/run-query-set.json index a784520..bc6cc4c 100644 --- a/scripts/run-query-set.json +++ b/scripts/run-query-set.json @@ -1,6 +1,6 @@ { - "query_set_id": "6c895d94-3146-4d9b-8858-f632e3fb0201", - "judgments_id": "40181d0d-fd94-4d91-a4b6-d08230bbaa31", + "query_set_id": "11e8be61-72f6-4c78-8253-cee03451f5c0", + "judgments_id": "3a84ca49-2d59-4912-9bc9-e193a5c27587", "index": "ecommerce", "search_pipeline": "hybrid-search-pipeline", "id_field": "asin", diff --git a/src/main/java/org/opensearch/eval/engine/OpenSearchEngine.java b/src/main/java/org/opensearch/eval/engine/OpenSearchEngine.java index a5b6a68..87a4e6c 100644 --- a/src/main/java/org/opensearch/eval/engine/OpenSearchEngine.java +++ b/src/main/java/org/opensearch/eval/engine/OpenSearchEngine.java @@ -149,13 +149,28 @@ public String indexQuerySet(final QuerySet querySet) throws IOException { } @Override - public QuerySet getQuerySet(String querySetId) throws IOException { + public boolean doesQuerySetExist(final String querySetId) throws IOException { final Query query = Query.of(q -> q.term(m -> m.field("_id").value(FieldValue.of(querySetId)))); - final SearchResponse searchResponse = client.search(s -> s.index(Constants.QUERY_SETS_INDEX_NAME).query(query).size(1), QuerySet.class); + final TrackHits trackHits = new TrackHits.Builder().enabled(true).build(); + final SearchResponse searchResponse = client.search(s -> s.index(Constants.QUERY_SETS_INDEX_NAME).trackTotalHits(trackHits).query(query).size(1), QuerySet.class); + + if(searchResponse.hits().total().value() > 0) { + return true; + } else { + return false; + } + + } + + + @Override + public QuerySet getQuerySet(final String querySetId) throws IOException { + + final Query query = Query.of(q -> q.term(m -> m.field("_id").value(FieldValue.of(querySetId)))); - // TODO: Handle the query set not being found. + final SearchResponse searchResponse = client.search(s -> s.index(Constants.QUERY_SETS_INDEX_NAME).query(query).size(1), QuerySet.class); return searchResponse.hits().hits().getFirst().source(); @@ -235,6 +250,19 @@ public Collection getUbiQueries() throws IOException { } + @Override + public long getJudgments(final String index, final String judgmentsSetId) throws IOException { + + final Query query = Query.of(q -> q.term(m -> m.field("judgment_set_id").value(FieldValue.of(judgmentsSetId)))); + + final TrackHits trackHits = new TrackHits.Builder().enabled(true).build(); + final SearchResponse searchResponse = client.search(s -> s.index(Constants.JUDGMENTS_INDEX_NAME).query(query).trackTotalHits(trackHits).size(0), Judgment.class); + + return searchResponse.hits().total().value(); + + } + + @Override public Collection getJudgments(final String index) throws IOException { final Collection judgments = new ArrayList<>(); diff --git a/src/main/java/org/opensearch/eval/engine/SearchEngine.java b/src/main/java/org/opensearch/eval/engine/SearchEngine.java index 3529a76..e892e3a 100644 --- a/src/main/java/org/opensearch/eval/engine/SearchEngine.java +++ b/src/main/java/org/opensearch/eval/engine/SearchEngine.java @@ -29,6 +29,7 @@ public abstract class SearchEngine { public abstract boolean bulkIndex(String index, Map documents) throws IOException; public abstract Collection getJudgments(final String index) throws IOException; + public abstract long getJudgments(final String index, final String judgmentsSetId) throws IOException; public abstract List runQuery(final String index, final String query, final int k, final String userQuery, final String idField, final String pipeline) throws IOException; @@ -43,6 +44,14 @@ public abstract class SearchEngine { */ public abstract QuerySet getQuerySet(String querySetId) throws IOException; + /** + * Determines if a query set exists. + * @param querySetId The ID of the query set to get. + * @return true if the query set exists. + * @throws IOException Thrown upon an error searching. + */ + public abstract boolean doesQuerySetExist(String querySetId) throws IOException; + /** * Get a judgment from the index. * @param judgmentsId The ID of the judgments to find. diff --git a/src/main/java/org/opensearch/eval/model/data/Judgment.java b/src/main/java/org/opensearch/eval/model/data/Judgment.java index ee14fcf..4ba48fd 100644 --- a/src/main/java/org/opensearch/eval/model/data/Judgment.java +++ b/src/main/java/org/opensearch/eval/model/data/Judgment.java @@ -38,7 +38,7 @@ public class Judgment { private String timestamp; public Judgment() { - + // Empty constructor used for deserialization. } /** diff --git a/src/main/java/org/opensearch/eval/runners/OpenSearchQuerySetRunner.java b/src/main/java/org/opensearch/eval/runners/OpenSearchQuerySetRunner.java index 3593544..2b70342 100644 --- a/src/main/java/org/opensearch/eval/runners/OpenSearchQuerySetRunner.java +++ b/src/main/java/org/opensearch/eval/runners/OpenSearchQuerySetRunner.java @@ -10,6 +10,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.eval.Constants; import org.opensearch.eval.engine.SearchEngine; import org.opensearch.eval.metrics.DcgSearchMetric; import org.opensearch.eval.metrics.NdcgSearchMetric; @@ -49,6 +50,18 @@ public OpenSearchQuerySetRunner(final SearchEngine searchEngine) { @Override public QuerySetRunResult run(final RunQuerySetParameters querySetParameters) throws Exception { + // Verify the given query set and judgment set exists prior to trying to run. + if(!searchEngine.doesQuerySetExist(querySetParameters.getQuerySetId())) { + LOGGER.error("The given query set {} does not exist", querySetParameters.getQuerySetId()); + throw new IllegalArgumentException("The given query set " + querySetParameters.getQuerySetId() + " does not exist"); + } + + final long judgmentCount = searchEngine.getJudgments(Constants.JUDGMENTS_INDEX_NAME, querySetParameters.getJudgmentsId()); + if(judgmentCount == 0) { + LOGGER.error("There are no judgments with the judgment set ID {}", querySetParameters.getJudgmentsId()); + throw new IllegalArgumentException("There are no judgments with the judgment set ID " + querySetParameters.getJudgmentsId()); + } + final QuerySet querySet = searchEngine.getQuerySet(querySetParameters.getQuerySetId()); LOGGER.info("Found {} queries in query set {}", querySet.getQuerySetQueries().size(), querySetParameters.getQuerySetId()); @@ -187,8 +200,6 @@ public void save(final QuerySetRunResult result) throws Exception { } - - } }