diff --git a/opensearch-search-quality-evaluation-plugin/src/main/java/org/opensearch/eval/metrics/DcgSearchMetric.java b/opensearch-search-quality-evaluation-plugin/src/main/java/org/opensearch/eval/metrics/DcgSearchMetric.java
new file mode 100644
index 0000000..6d59d3a
--- /dev/null
+++ b/opensearch-search-quality-evaluation-plugin/src/main/java/org/opensearch/eval/metrics/DcgSearchMetric.java
@@ -0,0 +1,66 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+package org.opensearch.eval.metrics;
+
+import java.util.List;
+
+/**
+ * Calculates the discounted cumulative gain (DCG) of a ranked list of relevance scores at k.
+ */
+public class DcgSearchMetric extends SearchMetric {
+
+    private static final double LOG_2 = Math.log(2);
+
+    protected final List<Double> relevanceScores;
+
+    /**
+     * Creates the metric.
+     * @param k The k used for metrics calculation, i.e. DCG@k.
+     * @param relevanceScores The relevance scores of the returned documents, in ranked order.
+     */
+    public DcgSearchMetric(final int k, final List<Double> relevanceScores) {
+        super(k);
+        this.relevanceScores = relevanceScores;
+    }
+
+    @Override
+    public String getName() {
+        return "dcg_at_" + k;
+    }
+
+    /**
+     * Calculates DCG@k: the sum of <code>relevance / log2(rank + 1)</code> over the first k scores.
+     * @return The DCG of the first k relevance scores, or 0 when there are no scores.
+     */
+    @Override
+    public double calculate() {
+        return calculateDcg(relevanceScores, k);
+    }
+
+    /**
+     * Calculates the DCG of the first k scores of a ranked list.
+     * @param scores The relevance scores in ranked order.
+     * @param k The maximum number of scores to include.
+     * @return The DCG of the first k scores.
+     */
+    protected static double calculateDcg(final List<Double> scores, final int k) {
+
+        // Only the first k results contribute, and the list may hold fewer than k scores.
+        final int limit = Math.min(k, scores.size());
+
+        double dcg = 0.0;
+        for(int i = 0; i < limit; i++) {
+            // i is zero-based, so the rank is i + 1 and the discount is log2(rank + 1) = log2(i + 2).
+            dcg += scores.get(i) / (Math.log(i + 2) / LOG_2);
+        }
+        return dcg;
+
+    }
+
+}
diff --git a/opensearch-search-quality-evaluation-plugin/src/main/java/org/opensearch/eval/metrics/NdcgSearchMetric.java b/opensearch-search-quality-evaluation-plugin/src/main/java/org/opensearch/eval/metrics/NdcgSearchMetric.java
new file mode 100644
index 0000000..b62aed7
--- /dev/null
+++ b/opensearch-search-quality-evaluation-plugin/src/main/java/org/opensearch/eval/metrics/NdcgSearchMetric.java
@@ -0,0 +1,56 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+package org.opensearch.eval.metrics;
+
+import java.util.List;
+
+/**
+ * Calculates the normalized discounted cumulative gain (NDCG) at k: the DCG of the actual
+ * ranking divided by the DCG of the ideal ranking.
+ */
+public class NdcgSearchMetric extends DcgSearchMetric {
+
+    private final List<Double> idealRelevanceScores;
+
+    /**
+     * Creates the metric.
+     * @param k The k used for metrics calculation, i.e. NDCG@k.
+     * @param relevanceScores The relevance scores of the returned documents, in ranked order.
+     * @param idealRelevanceScores The relevance scores of the ideal ranking, used in the order given.
+     */
+    public NdcgSearchMetric(final int k, final List<Double> relevanceScores, final List<Double> idealRelevanceScores) {
+        super(k, relevanceScores);
+        this.idealRelevanceScores = idealRelevanceScores;
+    }
+
+    @Override
+    public String getName() {
+        return "ndcg_at_" + k;
+    }
+
+    /**
+     * Calculates NDCG@k.
+     * @return The DCG@k divided by the ideal DCG@k, or 0 when the ideal DCG@k is 0.
+     */
+    @Override
+    public double calculate() {
+
+        // Use the same DCG calculation and cutoff for the ideal ranking as for the actual ranking.
+        final double idcg = calculateDcg(idealRelevanceScores, k);
+
+        // Avoid division by zero.
+        if(idcg == 0) {
+            return 0;
+        }
+
+        return super.calculate() / idcg;
+
+    }
+
+}
diff --git a/opensearch-search-quality-evaluation-plugin/src/main/java/org/opensearch/eval/metrics/PrecisionSearchMetric.java b/opensearch-search-quality-evaluation-plugin/src/main/java/org/opensearch/eval/metrics/PrecisionSearchMetric.java
new file mode 100644
index 0000000..9ae1e0c
--- /dev/null
+++ b/opensearch-search-quality-evaluation-plugin/src/main/java/org/opensearch/eval/metrics/PrecisionSearchMetric.java
@@ -0,0 +1,48 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+package org.opensearch.eval.metrics;
+
+import java.util.List;
+
+/**
+ * Placeholder for precision at k. The calculation is not implemented yet and
+ * {@link #calculate()} always returns 0.
+ */
+public class PrecisionSearchMetric extends SearchMetric {
+
+    private final List<Double> relevanceScores;
+
+    /**
+     * Creates the metric.
+     * @param k The k used for metrics calculation, i.e. precision@k.
+     * @param relevanceScores The relevance scores of the returned documents, in ranked order.
+     */
+    public PrecisionSearchMetric(final int k, final List<Double> relevanceScores) {
+        super(k);
+        this.relevanceScores = relevanceScores;
+    }
+
+    @Override
+    public String getName() {
+        return "precision_at_" + k;
+    }
+
+    /**
+     * Not implemented yet.
+     * @return Always 0.
+     */
+    @Override
+    public double calculate() {
+
+        // TODO: Implement this.
+        return 0.0;
+
+    }
+
+}
diff --git a/opensearch-search-quality-evaluation-plugin/src/main/java/org/opensearch/eval/metrics/SearchMetric.java b/opensearch-search-quality-evaluation-plugin/src/main/java/org/opensearch/eval/metrics/SearchMetric.java
new file mode 100644
index 0000000..658c202
--- /dev/null
+++ b/opensearch-search-quality-evaluation-plugin/src/main/java/org/opensearch/eval/metrics/SearchMetric.java
@@ -0,0 +1,51 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+package org.opensearch.eval.metrics;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+
+/**
+ * Base class for a search metric that is calculated at a cutoff of k results.
+ */
+public abstract class SearchMetric {
+
+    private static final Logger LOGGER = LogManager.getLogger(SearchMetric.class);
+
+    protected final int k;
+
+    /**
+     * Gets the name of the metric.
+     * @return The name of the metric.
+     */
+    public abstract String getName();
+
+    /**
+     * Calculates the value of the metric.
+     * @return The value of the metric.
+     */
+    public abstract double calculate();
+
+    /**
+     * Creates the metric.
+     * @param k The k used for metrics calculation, i.e. DCG@k.
+     */
+    public SearchMetric(final int k) {
+        this.k = k;
+    }
+
+    /**
+     * Gets the k used for metrics calculation.
+     * @return The k used for metrics calculation.
+     */
+    public int getK() {
+        return k;
+    }
+
+}
diff --git a/opensearch-search-quality-evaluation-plugin/src/main/java/org/opensearch/eval/metrics/SearchMetrics.java b/opensearch-search-quality-evaluation-plugin/src/main/java/org/opensearch/eval/metrics/SearchMetrics.java
deleted file mode 100644
index 5a49d66..0000000
--- a/opensearch-search-quality-evaluation-plugin/src/main/java/org/opensearch/eval/metrics/SearchMetrics.java
+++ /dev/null
@@ -1,149 +0,0 @@
-/*
- * Copyright OpenSearch Contributors
- * SPDX-License-Identifier: Apache-2.0
- *
- * The OpenSearch Contributors require contributions made to
- * this file be licensed under the Apache-2.0 license or a
- * compatible open source license.
- */ -package org.opensearch.eval.metrics; - -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.opensearch.eval.judgments.model.Judgment; -import org.opensearch.eval.runners.QueryResult; - -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.stream.Collectors; - -/** - * Provides the ability to calculate search metrics and stores them. - */ -public class SearchMetrics { - - private static final Logger LOGGER = LogManager.getLogger(SearchMetrics.class.getName()); - - private final int k; - private final double dcg; - private final double ndcg; - private final double precision; - - /** - * Calculate the search metrics for an entire query set. - * @param queryResults A list of {@link QueryResult}. - * @param judgments A list of {@link Judgment judgments} used for metric calculation. - * @param k The k used for metrics calculation, i.e. DCG@k. - */ - public SearchMetrics(final List queryResults, final List judgments, final int k) { - this.k = k; - - // TODO: Calculate the metrics for the whole query set. - this.dcg = 0.0; - this.ndcg = 0.0; - this.precision = 0.0; - } - - /** - * Calculate the search metrics for a single query. - * @param query The user query. - * @param orderedDocumentIds The documents returned for the user query in order. - * @param judgments A list of {@link Judgment judgments} used for metric calculation. - * @param k The k used for metrics calculation, i.e. DCG@k. - */ - public SearchMetrics(final String query, final List orderedDocumentIds, final List judgments, final int k) { - this.k = k; - - // TODO: Calculate the metrics for the single query. - final List scores = getRelevanceScores(query, orderedDocumentIds, judgments, k); - - this.dcg = calculateDCG(scores); - this.ndcg = 0.0; - this.precision = 0.0; - } - - /** - * Gets the metrics as a map for ease of indexing. - * @return A map of the search metrics. 
- */ - public Map getSearchMetricsAsMap() { - - final Map metrics = new HashMap<>(); - metrics.put("dcg_at_" + k, dcg); - metrics.put("ndcg_at_" + k, ndcg); - metrics.put("prec_at_" + k, precision); - - return metrics; - - } - - private List getRelevanceScores(final String query, final List orderedDocumentIds, final List judgments, final int k) { - - // Ordered list of scores. - final List scores = new ArrayList<>(); - - // Go through each document up to k and get the score. - for(int i = 0; i < k; i++) { - - final String documentId = orderedDocumentIds.get(i); - - // Get the score for this document for this query. - final Judgment judgment = Judgment.findJudgment(judgments, query, documentId); - - if(judgment != null) { - scores.add(judgment.getJudgment()); - } - - if(i == orderedDocumentIds.size()) { - // k is greater than the actual length of documents. - break; - } - - } - - String listOfScores = scores.stream().map(Object::toString).collect(Collectors.joining(", ")); - LOGGER.info("Got relevance scores: {}", listOfScores); - - return scores; - - } - - private double calculateDCG(final List relevanceScores) { - double dcg = 0.0; - for(int i = 0; i < relevanceScores.size(); i++) { - double relevance = relevanceScores.get(i); - dcg += relevance / Math.log(i + 2); // Add 2 to avoid log(1) = 0 - } - return dcg; - } - - private double calculateNDCG(final List relevanceScores, final List idealRelevanceScores) { - double dcg = calculateDCG(relevanceScores); - double idcg = calculateDCG(idealRelevanceScores); - - if(idcg == 0) { - return 0; // Avoid division by zero - } - - return dcg / idcg; - } - - public int getK() { - return k; - } - - public double getDcg() { - return dcg; - } - - public double getNdcg() { - return ndcg; - } - - public double getPrecision() { - return precision; - } - -} diff --git a/opensearch-search-quality-evaluation-plugin/src/main/java/org/opensearch/eval/runners/OpenSearchAbstractQuerySetRunner.java 
b/opensearch-search-quality-evaluation-plugin/src/main/java/org/opensearch/eval/runners/OpenSearchAbstractQuerySetRunner.java
index a30888b..b5ffd9b 100644
--- a/opensearch-search-quality-evaluation-plugin/src/main/java/org/opensearch/eval/runners/OpenSearchAbstractQuerySetRunner.java
+++ b/opensearch-search-quality-evaluation-plugin/src/main/java/org/opensearch/eval/runners/OpenSearchAbstractQuerySetRunner.java
@@ -16,10 +16,12 @@
 import org.opensearch.action.search.SearchResponse;
 import org.opensearch.client.Client;
 import org.opensearch.core.action.ActionListener;
-import org.opensearch.core.common.util.CollectionUtils;
 import org.opensearch.eval.SearchQualityEvaluationPlugin;
 import org.opensearch.eval.judgments.model.Judgment;
-import org.opensearch.eval.metrics.SearchMetrics;
+import org.opensearch.eval.metrics.DcgSearchMetric;
+import org.opensearch.eval.metrics.NdcgSearchMetric;
+import org.opensearch.eval.metrics.PrecisionSearchMetric;
+import org.opensearch.eval.metrics.SearchMetric;
 import org.opensearch.index.query.QueryBuilders;
 import org.opensearch.search.SearchHit;
 import org.opensearch.search.builder.SearchSourceBuilder;
@@ -29,6 +31,7 @@
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.stream.Collectors;
 
 import static org.opensearch.eval.SearchQualityEvaluationRestHandler.QUERY_PLACEHOLDER;
 
@@ -110,7 +113,17 @@ public void onResponse(final SearchResponse searchResponse) {
                             }
 
                             // TODO: Use getJudgment() to get the judgment for this document.
-                            queryResults.add(new QueryResult(userQuery, orderedDocumentIds, judgments, k));
+                            // Score against the user query, matching the query recorded in the QueryResult below.
+                            final List<Double> relevanceScores = getRelevanceScores(userQuery, orderedDocumentIds, k);
+
+                            final SearchMetric dcgSearchMetric = new DcgSearchMetric(k, relevanceScores);
+                            // TODO: Add these metrics in, too.
+                            //final SearchMetric ndcgSearchmetric = new NdcgSearchMetric(k, relevanceScores, idealRelevanceScores);
+                            //final SearchMetric precisionSearchMetric = new PrecisionSearchMetric(k, relevanceScores);
+
+                            final List<SearchMetric> searchMetrics = List.of(dcgSearchMetric); // ndcgSearchmetric, precisionSearchMetric);
+
+                            queryResults.add(new QueryResult(userQuery, orderedDocumentIds, k, searchMetrics));
 
                         }
 
@@ -125,7 +138,15 @@ public void onFailure(Exception ex) {
             }
 
             // TODO: Calculate the search metrics for the entire query set given the results and the judgments.
-            final SearchMetrics searchMetrics = new SearchMetrics(queryResults, judgments, k);
+            // Placeholder: no documents are associated with the query set as a whole yet, so the
+            // query-set-level metrics are calculated over an empty list of relevance scores.
+            final List<Double> relevanceScores = List.of();
+            final SearchMetric dcgSearchMetric = new DcgSearchMetric(k, relevanceScores);
+            // TODO: Add these metrics in, too.
+            //final SearchMetric ndcgSearchmetric = new NdcgSearchMetric(k, relevanceScores, idealRelevanceScores);
+            //final SearchMetric precisionSearchMetric = new PrecisionSearchMetric(k, relevanceScores);
+
+            final List<SearchMetric> searchMetrics = List.of(dcgSearchMetric); // ndcgSearchmetric, precisionSearchMetric);
 
             return new QuerySetRunResult(queryResults, searchMetrics);
 
@@ -143,11 +164,15 @@ public void save(final QuerySetRunResult result) throws Exception {
 
         final Map<String, Object> results = new HashMap<>();
         results.put("run_id", result.getRunId());
-        results.put("search_metrics", result.getSearchMetrics().getSearchMetricsAsMap());
         results.put("query_results", result.getQueryResultsAsMap());
 
-        final IndexRequest indexRequest = new IndexRequest(SearchQualityEvaluationPlugin.QUERY_SETS_RUN_RESULTS_INDEX_NAME);
-        indexRequest.source(results);
+        // Calculate and add each metric to the object to index.
+        for(final SearchMetric searchMetric : result.getSearchMetrics()) {
+            results.put(searchMetric.getName(), searchMetric.calculate());
+        }
+
+        final IndexRequest indexRequest = new IndexRequest(SearchQualityEvaluationPlugin.QUERY_SETS_RUN_RESULTS_INDEX_NAME)
+                .source(results);
 
         client.index(indexRequest, new ActionListener<>() {
             @Override
@@ -163,4 +188,38 @@ public void onFailure(Exception ex) {
 
     }
 
+    /**
+     * Gets the relevance scores for the first k documents returned for a query.
+     * @param query The user query.
+     * @param orderedDocumentIds The document IDs returned for the query, in ranked order.
+     * @param k The maximum number of documents to score.
+     * @return The relevance scores in ranked order; at most k scores are returned.
+     */
+    public List<Double> getRelevanceScores(final String query, final List<String> orderedDocumentIds, final int k) {
+
+        // Never read past the end of the results when fewer than k documents were returned.
+        final int limit = Math.min(k, orderedDocumentIds.size());
+
+        // Ordered list of scores.
+        final List<Double> scores = new ArrayList<>();
+
+        // Go through each document up to the limit and get the score.
+        for(int i = 0; i < limit; i++) {
+
+            final String documentId = orderedDocumentIds.get(i);
+
+            // TODO: Find the judgment value for this combination of query and documentId from the index.
+            final double judgment = 0.1;
+
+            scores.add(judgment);
+
+        }
+
+        final String listOfScores = scores.stream().map(Object::toString).collect(Collectors.joining(", "));
+        LOGGER.info("Got relevance scores: {}", listOfScores);
+
+        return scores;
+
+    }
+
 }
diff --git a/opensearch-search-quality-evaluation-plugin/src/main/java/org/opensearch/eval/runners/QueryResult.java b/opensearch-search-quality-evaluation-plugin/src/main/java/org/opensearch/eval/runners/QueryResult.java
index 46d7c13..a001759 100644
--- a/opensearch-search-quality-evaluation-plugin/src/main/java/org/opensearch/eval/runners/QueryResult.java
+++ b/opensearch-search-quality-evaluation-plugin/src/main/java/org/opensearch/eval/runners/QueryResult.java
@@ -8,9 +8,9 @@
  */
 package org.opensearch.eval.runners;
 
-import org.opensearch.eval.judgments.model.Judgment;
-import org.opensearch.eval.metrics.SearchMetrics;
+import org.opensearch.eval.metrics.SearchMetric;
 
+import java.util.Collection;
 import java.util.List;
 
 /**
@@ -20,20 +20,22 @@ public class QueryResult {
 
     private final String query;
     private final List<String> orderedDocumentIds;
-    private final SearchMetrics searchMetrics;
+    private final int k;
+    private final Collection<SearchMetric> searchMetrics;
 
     /**
      * Creates the search results.
      * @param query The query used to generate this result.
      * @param orderedDocumentIds A list of ordered document IDs in the same order as they appeared
      *                           in the query.
-     * @param judgments A list of {@link Judgment judgments} used for metric calculation.
      * @param k The k used for metrics calculation, i.e. DCG@k.
+     * @param searchMetrics The {@link SearchMetric metrics} to calculate for this query.
      */
-    public QueryResult(final String query, final List<String> orderedDocumentIds, final List<Judgment> judgments, final int k) {
+    public QueryResult(final String query, final List<String> orderedDocumentIds, final int k, final Collection<SearchMetric> searchMetrics) {
         this.query = query;
         this.orderedDocumentIds = orderedDocumentIds;
-        this.searchMetrics = new SearchMetrics(query, orderedDocumentIds, judgments, k);
+        this.k = k;
+        this.searchMetrics = searchMetrics;
     }
 
     /**
@@ -52,11 +54,19 @@ public List<String> getOrderedDocumentIds() {
         return orderedDocumentIds;
     }
 
-    /**
-     * Gets the search metrics for this query.
-     * @return The {@link SearchMetrics} for this query.
-     */
-    public SearchMetrics getSearchMetrics() {
+    /**
+     * Gets the k used for metrics calculation.
+     * @return The k used for metrics calculation, i.e. DCG@k.
+     */
+    public int getK() {
+        return k;
+    }
+
+    /**
+     * Gets the search metrics for this query.
+     * @return The {@link SearchMetric metrics} for this query.
+     */
+    public Collection<SearchMetric> getSearchMetrics() {
         return searchMetrics;
     }
 
diff --git a/opensearch-search-quality-evaluation-plugin/src/main/java/org/opensearch/eval/runners/QuerySetRunResult.java b/opensearch-search-quality-evaluation-plugin/src/main/java/org/opensearch/eval/runners/QuerySetRunResult.java
index afc7bcb..0c97e67 100644
--- a/opensearch-search-quality-evaluation-plugin/src/main/java/org/opensearch/eval/runners/QuerySetRunResult.java
+++ b/opensearch-search-quality-evaluation-plugin/src/main/java/org/opensearch/eval/runners/QuerySetRunResult.java
@@ -8,7 +8,7 @@
  */
 package org.opensearch.eval.runners;
 
-import org.opensearch.eval.metrics.SearchMetrics;
+import org.opensearch.eval.metrics.SearchMetric;
 
 import java.util.ArrayList;
 import java.util.Collection;
@@ -24,17 +24,17 @@ public class QuerySetRunResult {
 
     private final String runId;
     private final List<QueryResult> queryResults;
-    private final SearchMetrics searchMetrics;
+    private final Collection<SearchMetric> metrics;
 
     /**
      * Creates a new query set run result. A random UUID is generated as the run ID.
      * @param queryResults A collection of {@link QueryResult} that contains the queries and search results.
-     * @param searchMetrics The {@link SearchMetrics metrics} calculated from the search results.
+     * @param metrics The {@link SearchMetric metrics} calculated from the search results.
      */
-    public QuerySetRunResult(final List<QueryResult> queryResults, final SearchMetrics searchMetrics) {
+    public QuerySetRunResult(final List<QueryResult> queryResults, final Collection<SearchMetric> metrics) {
         this.runId = UUID.randomUUID().toString();
         this.queryResults = queryResults;
-        this.searchMetrics = searchMetrics;
+        this.metrics = metrics;
     }
 
     /**
@@ -46,11 +46,11 @@ public String getRunId() {
     }
 
     /**
-     * Gets the {@link SearchMetrics metrics} calculated from the run.
-     * @return The {@link SearchMetrics metrics} calculated from the run.
+     * Gets the {@link SearchMetric metrics} calculated from the run.
+     * @return The {@link SearchMetric metrics} calculated from the run.
      */
-    public SearchMetrics getSearchMetrics() {
-        return searchMetrics;
+    public Collection<SearchMetric> getSearchMetrics() {
+        return metrics;
     }
 
     /**
@@ -71,7 +71,11 @@ public Collection<Map<String, Object>> getQueryResultsAsMap() {
 
             q.put("query", queryResult.getQuery());
             q.put("document_ids", queryResult.getOrderedDocumentIds());
-            q.put("search_metrics", queryResult.getSearchMetrics().getSearchMetricsAsMap());
+
+            // Calculate and add each metric to the map.
+            for(final SearchMetric searchMetric : queryResult.getSearchMetrics()) {
+                q.put(searchMetric.getName(), searchMetric.calculate());
+            }
 
             qs.add(q);
 