diff --git a/README.md b/README.md index d3f6cde..3c061fd 100644 --- a/README.md +++ b/README.md @@ -57,7 +57,7 @@ curl -s http://localhost:9200/judgments/_search | jq This is a standalone Java application to generate implicit judgments from indexed UBI data. It runs outside OpenSearch and queries the UBI indexes to get the data for calculating the implicit judgments. -To run it, run the `org.opensearch.qef.App` class. This will connect to OpenSearch running on `localhost:9200`. It expects the `ubi_events` and `ubi_queries` indexes to exist and be populated. +To run it, run the `org.opensearch.eval.App` class. This will connect to OpenSearch running on `localhost:9200`. It expects the `ubi_events` and `ubi_queries` indexes to exist and be populated. ## License diff --git a/opensearch-search-quality-evaluation-plugin/src/main/java/org/opensearch/eval/judgments/model/QuerySetQuery.java b/opensearch-search-quality-evaluation-plugin/src/main/java/org/opensearch/eval/judgments/model/QuerySetQuery.java index f958b98..2244df4 100644 --- a/opensearch-search-quality-evaluation-plugin/src/main/java/org/opensearch/eval/judgments/model/QuerySetQuery.java +++ b/opensearch-search-quality-evaluation-plugin/src/main/java/org/opensearch/eval/judgments/model/QuerySetQuery.java @@ -1,3 +1,11 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ package org.opensearch.eval.judgments.model; public class QuerySetQuery { diff --git a/opensearch-search-quality-evaluation/aggs.sh b/opensearch-search-quality-evaluation/aggs.sh new file mode 100755 index 0000000..ade861b --- /dev/null +++ b/opensearch-search-quality-evaluation/aggs.sh @@ -0,0 +1,20 @@ +#!/bin/bash -e + +curl -X GET http://localhost:9200/ubi_events/_search -H "Content-Type: application/json" -d' +{ + "size": 0, + "aggs": { + "By_Action": { + "terms": { + "field": "action_name" + }, + "aggs": { + "By_Position": { + "terms": { + "field": "event_attributes.position.index" + } + } + } + } + } +}' | jq \ No newline at end of file diff --git a/opensearch-search-quality-evaluation/build.gradle b/opensearch-search-quality-evaluation/build.gradle new file mode 100644 index 0000000..3e016c2 --- /dev/null +++ b/opensearch-search-quality-evaluation/build.gradle @@ -0,0 +1,40 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +apply plugin: 'java' +apply plugin: 'java-library-distribution' +apply plugin: 'application' + +description = "opensearch-search-quality-implicit-judgments" + +repositories { + mavenLocal() + mavenCentral() + maven { url "https://aws.oss.sonatype.org/content/repositories/snapshots" } +} + +test { + useJUnitPlatform() +} + +dependencies { + implementation 'org.opensearch.client:opensearch-rest-high-level-client:2.18.0' + implementation 'org.apache.httpcomponents.client5:httpclient5:5.4' + compileOnly "org.apache.logging.log4j:log4j-core:2.24.0" + implementation "org.apache.commons:commons-lang3:3.17.0" + implementation "com.google.code.gson:gson:2.11.0" + + testImplementation "org.mockito:mockito-core:5.14.2" + testImplementation 'org.junit.jupiter:junit-jupiter-api:5.11.3' + testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine:5.11.3' +} + +tasks.withType(Tar){ + duplicatesStrategy = DuplicatesStrategy.EXCLUDE +} + +tasks.withType(Zip){ + duplicatesStrategy = DuplicatesStrategy.EXCLUDE +} \ No newline at end of file diff --git a/opensearch-search-quality-evaluation/coec.png b/opensearch-search-quality-evaluation/coec.png new file mode 100644 index 0000000..65e297a Binary files /dev/null and b/opensearch-search-quality-evaluation/coec.png differ diff --git a/opensearch-search-quality-evaluation/coec_definition.png b/opensearch-search-quality-evaluation/coec_definition.png new file mode 100644 index 0000000..e5ed3f9 Binary files /dev/null and b/opensearch-search-quality-evaluation/coec_definition.png differ diff --git a/opensearch-search-quality-evaluation/queries.txt b/opensearch-search-quality-evaluation/queries.txt new file mode 100644 index 0000000..9ad9c30 --- /dev/null +++ b/opensearch-search-quality-evaluation/queries.txt @@ -0,0 +1,11 @@ +GET ubi_queries/_search + +GET ubi_events/_search + +GET rank_aggregated_ctr/_search + +GET click_through_rates/_search + +GET judgments/_search + +DELETE rank_aggregated_ctr,click_through_rates,judgments diff --git a/opensearch-search-quality-evaluation/src/main/java/org/opensearch/eval/App.java b/opensearch-search-quality-evaluation/src/main/java/org/opensearch/eval/App.java new file mode 100644 index 0000000..40e2621 --- /dev/null +++ b/opensearch-search-quality-evaluation/src/main/java/org/opensearch/eval/App.java @@ -0,0 +1,35 @@ +package org.opensearch.eval; + +import org.apache.http.HttpHost; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.client.RestClient; +import org.opensearch.client.RestClientBuilder; +import org.opensearch.client.RestHighLevelClient; +import org.opensearch.eval.clickmodel.coec.CoecClickModel; +import org.opensearch.eval.clickmodel.coec.CoecClickModelParameters; +import org.opensearch.eval.model.Judgment; + +import java.util.Collection; + +/** + * Entry point for the OpenSearch Evaluation Framework standalone app. + */ +public class App { + + private static final Logger LOGGER = LogManager.getLogger(App.class.getName()); + + public static void main(String[] args) throws Exception { + + final RestClientBuilder builder = RestClient.builder(new HttpHost("localhost", 9200, "http")); + final RestHighLevelClient restHighLevelClient = new RestHighLevelClient(builder); + + final CoecClickModelParameters coecClickModelParameters = new CoecClickModelParameters(false, 20); + final CoecClickModel coecClickModel = new CoecClickModel(coecClickModelParameters); + + final Collection judgments = coecClickModel.calculateJudgments(); + Judgment.showJudgments(judgments); + + } + +} diff --git a/opensearch-search-quality-evaluation/src/main/java/org/opensearch/eval/clickmodel/ClickModel.java b/opensearch-search-quality-evaluation/src/main/java/org/opensearch/eval/clickmodel/ClickModel.java new file mode 100644 index 0000000..9d1a540 --- /dev/null +++ b/opensearch-search-quality-evaluation/src/main/java/org/opensearch/eval/clickmodel/ClickModel.java @@ -0,0 +1,15 @@ +package org.opensearch.eval.clickmodel; + +import org.opensearch.eval.model.Judgment; + +import java.io.IOException; +import java.util.Collection; + +public abstract class ClickModel { + + public static final String INDEX_UBI_EVENTS = "ubi_events"; + public static final String INDEX_UBI_QUERIES = "ubi_queries"; + + public abstract Collection calculateJudgments() throws IOException; + +} diff --git a/opensearch-search-quality-evaluation/src/main/java/org/opensearch/eval/clickmodel/ClickModelParameters.java b/opensearch-search-quality-evaluation/src/main/java/org/opensearch/eval/clickmodel/ClickModelParameters.java new file mode 100644 index 0000000..2e70008 --- /dev/null +++ b/opensearch-search-quality-evaluation/src/main/java/org/opensearch/eval/clickmodel/ClickModelParameters.java @@ -0,0 +1,5 @@ +package org.opensearch.eval.clickmodel; + +public abstract class ClickModelParameters { + +} diff --git a/opensearch-search-quality-evaluation/src/main/java/org/opensearch/eval/clickmodel/coec/CoecClickModel.java b/opensearch-search-quality-evaluation/src/main/java/org/opensearch/eval/clickmodel/coec/CoecClickModel.java new file mode 100644 index 0000000..a11be83 --- /dev/null +++ b/opensearch-search-quality-evaluation/src/main/java/org/opensearch/eval/clickmodel/coec/CoecClickModel.java @@ -0,0 +1,357 @@ +package org.opensearch.eval.clickmodel.coec; + +import com.google.gson.Gson; +import org.apache.commons.lang3.StringUtils; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.SearchScrollRequest; +import org.opensearch.client.RequestOptions; +import org.opensearch.client.Requests; +import org.opensearch.client.RestHighLevelClient; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.WrapperQueryBuilder; +import org.opensearch.eval.clickmodel.ClickModel; +import org.opensearch.eval.engine.opensearch.OpenSearchHelper; +import org.opensearch.eval.model.ClickthroughRate; +import org.opensearch.eval.model.Judgment; +import org.opensearch.eval.model.ubi.event.UbiEvent; +import org.opensearch.eval.util.MathUtils; +import org.opensearch.eval.util.UserQueryHash; +import org.opensearch.search.Scroll; +import org.opensearch.search.SearchHit; +import org.opensearch.search.aggregations.AggregationBuilders; +import org.opensearch.search.aggregations.bucket.terms.Terms; +import org.opensearch.search.aggregations.bucket.terms.TermsAggregationBuilder; +import org.opensearch.search.builder.SearchSourceBuilder; + +import java.io.IOException; +import java.util.Collection; +import java.util.HashMap; +import java.util.LinkedHashSet; +import java.util.LinkedList; +import java.util.Map; +import java.util.Set; + +public class CoecClickModel extends ClickModel { + + // OpenSearch indexes. + public static final String INDEX_RANK_AGGREGATED_CTR = "rank_aggregated_ctr"; + public static final String INDEX_QUERY_DOC_CTR = "click_through_rates"; + public static final String INDEX_JUDGMENT = "judgments"; + + // UBI event names. + public static final String EVENT_CLICK = "click"; + public static final String EVENT_VIEW = "view"; + + private final CoecClickModelParameters parameters; + + private final OpenSearchHelper openSearchHelper; + + private final UserQueryHash userQueryHash = new UserQueryHash(); + private final Gson gson = new Gson(); + private final RestHighLevelClient client; + + private static final Logger LOGGER = LogManager.getLogger(CoecClickModel.class.getName()); + + public CoecClickModel(final CoecClickModelParameters parameters) { + + this.parameters = parameters; + this.openSearchHelper = new OpenSearchHelper(parameters.getRestHighLevelClient()); + this.client = parameters.getRestHighLevelClient(); + + } + + @Override + public Collection calculateJudgments() throws IOException { + + final int maxRank = parameters.getMaxRank(); + + // Calculate and index the rank-aggregated click-through. + final Map rankAggregatedClickThrough = getRankAggregatedClickThrough(); + LOGGER.info("Rank-aggregated clickthrough positions: {}", rankAggregatedClickThrough.size()); + showRankAggregatedClickThrough(rankAggregatedClickThrough); + + // Calculate and index the click-through rate for query/doc pairs. + final Map> clickthroughRates = getClickthroughRate(maxRank); + LOGGER.info("Clickthrough rates for number of queries: {}", clickthroughRates.size()); + showClickthroughRates(clickthroughRates); + + // Generate and index the implicit judgments. + final Collection judgments = calculateCoec(rankAggregatedClickThrough, clickthroughRates); + LOGGER.info("Number of judgments: {}", judgments.size()); + + return judgments; + + } + + public Collection calculateCoec(final Map rankAggregatedClickThrough, + final Map> clickthroughRates) throws IOException { + + // Calculate the COEC. + // Numerator is the total number of clicks received by a query/result pair. + // Denominator is the expected clicks (EC) that an average result would receive after being impressed i times at rank r, + // and CTR is the average CTR for each position in the results page (up to R) computed over all queries and results. + + // Format: query_id, query, document, judgment + final Collection judgments = new LinkedList<>(); + + // Up to Rank R + final int maxRank = 20; + + for(final String userQuery : clickthroughRates.keySet()) { + + // The clickthrough rates for this query. + final Collection ctrs = clickthroughRates.get(userQuery); + + for(final ClickthroughRate ctr : ctrs) { + + double denominatorSum = 0; + + for(int r = 0; r < maxRank; r++) { + + final double meanCtrAtRank = rankAggregatedClickThrough.getOrDefault(r, 0.0); + final int countOfTimesShownAtRank = openSearchHelper.getCountOfQueriesForUserQueryHavingResultInRankR(userQuery, ctr.getObjectId(), r); + +// System.out.println("rank = " + r); +// System.out.println("\tmeanCtrAtRank = " + meanCtrAtRank); +// System.out.println("\tcountOfTimesShownAtRank = " + countOfTimesShownAtRank); + + denominatorSum += (meanCtrAtRank * countOfTimesShownAtRank); + + } + + // Numerator is sum of clicks at all ranks up to the maxRank. + final int totalNumberClicksForQueryResult = ctr.getClicks(); + +// System.out.println("numerator = " + totalNumberClicksForQueryResult); +// System.out.println("denominator = " + denominatorSum); + + // Divide the numerator by the denominator (value). + final double judgment = totalNumberClicksForQueryResult / denominatorSum; + + // Hash the user query to get a query ID. + final int queryId = userQueryHash.getHash(userQuery); + + // Add the judgment to the list. + // TODO: What to do for query ID when the values are per user_query instead? + judgments.add(new Judgment(String.valueOf(queryId), userQuery, ctr.getObjectId(), judgment)); + + } + + } + + if(parameters.isPersist()) { + openSearchHelper.indexJudgments(judgments); + } + + return judgments; + + } + + /** + * Gets the clickthrough rates for each query and its results. + * @param maxRank The maximum rank position to consider. + * @return A map of user_query to the clickthrough rate for each query result. + * @throws IOException Thrown when a problem accessing OpenSearch. + */ + private Map> getClickthroughRate(final int maxRank) throws IOException { + + // For each query: + // - Get each document returned in that query (in the QueryResponse object). + // - Calculate the click-through rate for the document. (clicks/impressions) + + // TODO: Use maxRank in place of the hardcoded 20. + // TODO: Allow for a time period and for a specific application. + + final String query = "{\n" + + " \"bool\": {\n" + + " \"should\": [\n" + + " {\n" + + " \"term\": {\n" + + " \"action_name\": \"click\"\n" + + " }\n" + + " },\n" + + " {\n" + + " \"term\": {\n" + + " \"action_name\": \"view\"\n" + + " }\n" + + " }\n" + + " ],\n" + + " \"must\": [\n" + + " {\n" + + " \"range\": {\n" + + " \"event_attributes.position.index\": {\n" + + " \"lte\": 20\n" + + " }\n" + + " }\n" + + " }\n" + + " ]\n" + + " }\n" + + " }"; + + final BoolQueryBuilder queryBuilder = new BoolQueryBuilder().must(new WrapperQueryBuilder(query)); + final SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(queryBuilder).size(1000); + final Scroll scroll = new Scroll(TimeValue.timeValueMinutes(10L)); + + final SearchRequest searchRequest = Requests + .searchRequest(INDEX_UBI_EVENTS) + .source(searchSourceBuilder) + .scroll(scroll); + + SearchResponse searchResponse = client.search(searchRequest, RequestOptions.DEFAULT); + String scrollId = searchResponse.getScrollId(); + SearchHit[] searchHits = searchResponse.getHits().getHits(); + + final Map> queriesToClickthroughRates = new HashMap<>(); + + while (searchHits != null && searchHits.length > 0) { + + for (final SearchHit hit : searchHits) { + + final UbiEvent ubiEvent = gson.fromJson(hit.getSourceAsString(), UbiEvent.class); + + // We need to the hash of the query_id because two users can both search + // for "computer" and those searches will have different query IDs, but they are the same search. + final String userQuery = openSearchHelper.getUserQuery(ubiEvent.getQueryId()); + // LOGGER.debug("user_query = {}", userQuery); + + // Get the clicks for this queryId from the map, or an empty list if this is a new query. + final Set clickthroughRates = queriesToClickthroughRates.getOrDefault(userQuery, new LinkedHashSet<>()); + + // Get the ClickthroughRate object for the object that was interacted with. + final ClickthroughRate clickthroughRate = clickthroughRates.stream().filter(p -> p.getObjectId().equals(ubiEvent.getEventAttributes().getObject().getObjectId())).findFirst().orElse(new ClickthroughRate(ubiEvent.getEventAttributes().getObject().getObjectId())); + + if (StringUtils.equalsIgnoreCase(ubiEvent.getActionName(), EVENT_CLICK)) { + clickthroughRate.logClick(); + } else { + clickthroughRate.logEvent(); + } + + clickthroughRates.add(clickthroughRate); + queriesToClickthroughRates.put(userQuery, clickthroughRates); + // LOGGER.debug("clickthroughRate = {}", queriesToClickthroughRates.size()); + + } + + final SearchScrollRequest scrollRequest = new SearchScrollRequest(scrollId); + scrollRequest.scroll(scroll); + + searchResponse = client.scroll(scrollRequest, RequestOptions.DEFAULT); + scrollId = searchResponse.getScrollId(); + + searchHits = searchResponse.getHits().getHits(); + + } + + if(parameters.isPersist()) { + openSearchHelper.indexClickthroughRates(queriesToClickthroughRates); + } + + return queriesToClickthroughRates; + + } + + /** + * Calculate the rank-aggregated click through from the UBI events. + * @return A map of positions to clickthrough rates. + * @throws IOException Thrown when a problem accessing OpenSearch. + */ + public Map getRankAggregatedClickThrough() throws IOException { + + final Map rankAggregatedClickThrough = new HashMap<>(); + + // TODO: Allow for a time period and for a specific application. + + final QueryBuilder findRangeNumber = QueryBuilders.rangeQuery("event_attributes.position.index").lte(parameters.getMaxRank()); + final QueryBuilder queryBuilder = new BoolQueryBuilder().must(findRangeNumber); + + final TermsAggregationBuilder positionsAggregator = AggregationBuilders.terms("By_Position").field("event_attributes.position.index").size(parameters.getMaxRank()); + final TermsAggregationBuilder actionNameAggregation = AggregationBuilders.terms("By_Action").field("action_name").subAggregation(positionsAggregator).size(parameters.getMaxRank()); + + final SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.query(queryBuilder); + searchSourceBuilder.aggregation(actionNameAggregation); + searchSourceBuilder.from(0); + searchSourceBuilder.size(100); + + final SearchRequest searchRequest = new SearchRequest(INDEX_UBI_EVENTS).source(searchSourceBuilder); + final SearchResponse searchResponse = client.search(searchRequest, RequestOptions.DEFAULT); + + final Map clickCounts = new HashMap<>(); + final Map viewCounts = new HashMap<>(); + + final Terms actionTerms = searchResponse.getAggregations().get("By_Action"); + final Collection actionBuckets = actionTerms.getBuckets(); + for(final Terms.Bucket actionBucket : actionBuckets) { + + // Handle the "click" bucket. + if(StringUtils.equalsIgnoreCase(actionBucket.getKey().toString(), EVENT_CLICK)) { + + final Terms positionTerms = actionBucket.getAggregations().get("By_Position"); + final Collection positionBuckets = positionTerms.getBuckets(); + + for(final Terms.Bucket positionBucket : positionBuckets) { + clickCounts.put(Integer.valueOf(positionBucket.getKey().toString()), (double) positionBucket.getDocCount()); + } + + } + + // Handle the "view" bucket. + if(StringUtils.equalsIgnoreCase(actionBucket.getKey().toString(), EVENT_VIEW)) { + + final Terms positionTerms = actionBucket.getAggregations().get("By_Position"); + final Collection positionBuckets = positionTerms.getBuckets(); + + for(final Terms.Bucket positionBucket : positionBuckets) { + viewCounts.put(Integer.valueOf(positionBucket.getKey().toString()), (double) positionBucket.getDocCount()); + } + + } + + } + + for(final Integer x : clickCounts.keySet()) { + //System.out.println("Position = " + x + ", Click Count = " + clickCounts.get(x) + ", Event Count = " + viewCounts.get(x)); + rankAggregatedClickThrough.put(x, clickCounts.get(x) / viewCounts.get(x)); + } + + if(parameters.isPersist()) { + openSearchHelper.indexRankAggregatedClickthrough(rankAggregatedClickThrough); + } + + return rankAggregatedClickThrough; + + } + + private void showClickthroughRates(final Map> clickthroughRates) { + + for(final String userQuery : clickthroughRates.keySet()) { + + LOGGER.info("user_query: {}", userQuery); + + for(final ClickthroughRate clickthroughRate : clickthroughRates.get(userQuery)) { + + LOGGER.info("\t - {}", clickthroughRate.toString()); + + } + + } + + } + + private void showRankAggregatedClickThrough(final Map rankAggregatedClickThrough) { + + for(final int position : rankAggregatedClickThrough.keySet()) { + + LOGGER.info("Position: {}, # ctr: {}", position, MathUtils.round(rankAggregatedClickThrough.get(position), parameters.getRoundingDigits())); + + } + + } + +} diff --git a/opensearch-search-quality-evaluation/src/main/java/org/opensearch/eval/clickmodel/coec/CoecClickModelParameters.java b/opensearch-search-quality-evaluation/src/main/java/org/opensearch/eval/clickmodel/coec/CoecClickModelParameters.java new file mode 100644 index 0000000..b8d5b77 --- /dev/null +++ b/opensearch-search-quality-evaluation/src/main/java/org/opensearch/eval/clickmodel/coec/CoecClickModelParameters.java @@ -0,0 +1,50 @@ +package org.opensearch.eval.clickmodel.coec; + +import org.opensearch.client.RestClient; +import org.opensearch.client.RestClientBuilder; +import org.opensearch.client.RestHighLevelClient; +import org.opensearch.eval.clickmodel.ClickModelParameters; + +public class CoecClickModelParameters extends ClickModelParameters { + + private final RestHighLevelClient restHighLevelClient; + private final boolean persist; + private final int maxRank; + private int roundingDigits = 3; + + public CoecClickModelParameters(boolean persist, final int maxRank) { + + final RestClientBuilder builder = RestClient.builder("http://localhost:9200"); + this.restHighLevelClient = new RestHighLevelClient(builder); + + this.persist = persist; + this.maxRank = maxRank; + } + + public CoecClickModelParameters(boolean persist, final int maxRank, final int roundingDigits) { + + final RestClientBuilder builder = RestClient.builder("http://localhost:9200"); + this.restHighLevelClient = new RestHighLevelClient(builder); + + this.persist = persist; + this.maxRank = maxRank; + this.roundingDigits = roundingDigits; + } + + public RestHighLevelClient getRestHighLevelClient() { + return restHighLevelClient; + } + + public boolean isPersist() { + return persist; + } + + public int getMaxRank() { + return maxRank; + } + + public int getRoundingDigits() { + return roundingDigits; + } + +} diff --git a/opensearch-search-quality-evaluation/src/main/java/org/opensearch/eval/engine/opensearch/OpenSearchHelper.java b/opensearch-search-quality-evaluation/src/main/java/org/opensearch/eval/engine/opensearch/OpenSearchHelper.java new file mode 100644 index 0000000..edc8e56 --- /dev/null +++ b/opensearch-search-quality-evaluation/src/main/java/org/opensearch/eval/engine/opensearch/OpenSearchHelper.java @@ -0,0 +1,222 @@ +package org.opensearch.eval.engine.opensearch; + +import com.google.gson.Gson; +import org.opensearch.action.bulk.BulkRequest; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.client.RequestOptions; +import org.opensearch.client.RestHighLevelClient; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.WrapperQueryBuilder; +import org.opensearch.eval.model.ClickthroughRate; +import org.opensearch.eval.model.Judgment; +import org.opensearch.eval.model.ubi.query.UbiQuery; +import org.opensearch.search.SearchHit; +import org.opensearch.search.builder.SearchSourceBuilder; + +import java.io.IOException; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.UUID; + +import static org.opensearch.eval.clickmodel.coec.CoecClickModel.INDEX_JUDGMENT; +import static org.opensearch.eval.clickmodel.coec.CoecClickModel.INDEX_QUERY_DOC_CTR; +import static org.opensearch.eval.clickmodel.coec.CoecClickModel.INDEX_RANK_AGGREGATED_CTR; +import static org.opensearch.eval.clickmodel.coec.CoecClickModel.INDEX_UBI_QUERIES; + +public class OpenSearchHelper { + + private final RestHighLevelClient client; + private final Gson gson = new Gson(); + + // Used to cache the query ID->user_query to avoid unnecessary lookups to OpenSearch. + private static final Map userQueryCache = new HashMap<>(); + + public OpenSearchHelper(final RestHighLevelClient client) { + this.client = client; + } + + /** + * Gets the user query for a given query ID. + * @param queryId The query ID. + * @return The user query. + * @throws IOException Thrown when there is a problem accessing OpenSearch. + */ + public String getUserQuery(final String queryId) throws IOException { + + // If it's in the cache just get it and return it. + if(userQueryCache.containsKey(queryId)) { + return userQueryCache.get(queryId); + } + + // Cache it and return it. + final UbiQuery ubiQuery = getQueryFromQueryId(queryId); + userQueryCache.put(queryId, ubiQuery.getUserQuery()); + + return ubiQuery.getUserQuery(); + + } + + /** + * Gets the query object for a given query ID. + * @param queryId The query ID. + * @return A {@link UbiQuery} object for the given query ID. + */ + public UbiQuery getQueryFromQueryId(final String queryId) throws IOException { + + final String query = "{\"match\": {\"query_id\": \"" + queryId + "\" }}"; + final WrapperQueryBuilder qb = QueryBuilders.wrapperQuery(query); + + // The query_id should be unique anyway, but we are limiting it to a single result anyway. + final SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.query(qb); + searchSourceBuilder.from(0); + searchSourceBuilder.size(1); + + final String[] indexes = {INDEX_UBI_QUERIES}; + + final SearchRequest searchRequest = new SearchRequest(indexes, searchSourceBuilder); + final SearchResponse response = client.search(searchRequest, RequestOptions.DEFAULT); + + // Will only be a single result. + final SearchHit hit = response.getHits().getHits()[0]; + + return gson.fromJson(hit.getSourceAsString(), UbiQuery.class); + + } + + public int getCountOfQueriesForUserQueryHavingResultInRankR(final String userQuery, final String objectId, final int rank) throws IOException { + + int countOfTimesShownAtRank = 0; + + final String query = "{\"match\": {\"user_query\": \"" + userQuery + "\" }}"; + final WrapperQueryBuilder qb = QueryBuilders.wrapperQuery(query); + + final SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.query(qb); + + final String[] indexes = {INDEX_UBI_QUERIES}; + + final SearchRequest searchRequest = new SearchRequest(indexes, searchSourceBuilder); + final SearchResponse response = client.search(searchRequest, RequestOptions.DEFAULT); + + for(final SearchHit searchHit : response.getHits().getHits()) { + + final List queryResponseObjectIds = (List) searchHit.getSourceAsMap().get("query_response_object_ids"); + + if(queryResponseObjectIds.get(rank).equals(objectId)) { + countOfTimesShownAtRank++; + } + + } + + return countOfTimesShownAtRank; + + } + + /** + * Index the rank-aggregated clickthrough values. + * @param rankAggregatedClickThrough A map of position to clickthrough values. + * @throws IOException Thrown when there is a problem accessing OpenSearch. + */ + public void indexRankAggregatedClickthrough(final Map rankAggregatedClickThrough) throws IOException { + + if(!rankAggregatedClickThrough.isEmpty()) { + + // TODO: Split this into multiple bulk insert requests. + + final BulkRequest request = new BulkRequest(); + + for (final int position : rankAggregatedClickThrough.keySet()) { + + final Map jsonMap = new HashMap<>(); + jsonMap.put("position", position); + jsonMap.put("ctr", rankAggregatedClickThrough.get(position)); + + final IndexRequest indexRequest = new IndexRequest(INDEX_RANK_AGGREGATED_CTR).id(UUID.randomUUID().toString()).source(jsonMap); + + request.add(indexRequest); + + } + + client.bulk(request, RequestOptions.DEFAULT); + + } + + } + + /** + * Index the clickthrough rates. + * @param clickthroughRates A map of query IDs to a collection of {@link ClickthroughRate} objects. + * @throws IOException Thrown when there is a problem accessing OpenSearch. + */ + public void indexClickthroughRates(final Map> clickthroughRates) throws IOException { + + if(!clickthroughRates.isEmpty()) { + + // TODO: Split this into multiple bulk insert requests. + + final BulkRequest request = new BulkRequest(); + + for(final String queryId : clickthroughRates.keySet()) { + + for(final ClickthroughRate clickthroughRate : clickthroughRates.get(queryId)) { + + final Map jsonMap = new HashMap<>(); + jsonMap.put("query_id", queryId); + jsonMap.put("clicks", clickthroughRate.getClicks()); + jsonMap.put("events", clickthroughRate.getEvents()); + jsonMap.put("ctr", clickthroughRate.getClickthroughRate()); + + final IndexRequest indexRequest = new IndexRequest(INDEX_QUERY_DOC_CTR).id(UUID.randomUUID().toString()).source(jsonMap); + + request.add(indexRequest); + + } + + } + + client.bulk(request, RequestOptions.DEFAULT); + + } + + } + + /** + * Index the judgments. + * @param judgments A collection of {@link Judgment judgments}. + * @throws IOException Thrown when there is a problem accessing OpenSearch. + */ + public void indexJudgments(final Collection judgments) throws IOException { + + if(!judgments.isEmpty()) { + + // TODO: Split this into multiple bulk insert requests. + + final BulkRequest request = new BulkRequest(); + + for (final Judgment judgment : judgments) { + + final Map jsonMap = new HashMap<>(); + jsonMap.put("query_id", judgment.getQueryId()); + jsonMap.put("query", judgment.getQuery()); + jsonMap.put("document", judgment.getDocument()); + jsonMap.put("judgment", judgment.getJudgment()); + + final IndexRequest indexRequest = new IndexRequest(INDEX_JUDGMENT).id(UUID.randomUUID().toString()).source(jsonMap); + + request.add(indexRequest); + + } + + client.bulk(request, RequestOptions.DEFAULT); + + } + + } + +} \ No newline at end of file diff --git a/opensearch-search-quality-evaluation/src/main/java/org/opensearch/eval/model/ClickthroughRate.java b/opensearch-search-quality-evaluation/src/main/java/org/opensearch/eval/model/ClickthroughRate.java new file mode 100644 index 0000000..a4563a3 --- /dev/null +++ b/opensearch-search-quality-evaluation/src/main/java/org/opensearch/eval/model/ClickthroughRate.java @@ -0,0 +1,69 @@ +package org.opensearch.eval.model; + +import org.apache.commons.lang3.builder.EqualsBuilder; +import org.opensearch.eval.util.MathUtils; + +/** + * A query result and its number of clicks and total events. + */ +public class ClickthroughRate { + + private final String objectId; + private int clicks; + private int events; + + public ClickthroughRate(final String objectId) { + this.objectId = objectId; + this.clicks = 0; + this.events = 0; + } + + public ClickthroughRate(final String objectId, final int clicks, final int events) { + this.objectId = objectId; + this.clicks = clicks; + this.events = events; + } + + @Override + public boolean equals(Object obj) { + return EqualsBuilder.reflectionEquals(this, obj); + } + + @Override + public int hashCode() { + int result = 17; + result = 29 * result + objectId.hashCode(); + return result; + } + + @Override + public String toString() { + return "object_id: " + objectId + ", clicks: " + clicks + ", events: " + events + ", ctr: " + MathUtils.round(getClickthroughRate()); + } + + public void logClick() { + clicks++; + events++; + } + + public void logEvent() { + events++; + } + + public double getClickthroughRate() { + return (double) clicks / events; + } + + public int getClicks() { + return clicks; + } + + public int getEvents() { + return events; + } + + public String getObjectId() { + return objectId; + } + +} diff --git a/opensearch-search-quality-evaluation/src/main/java/org/opensearch/eval/model/Judgment.java b/opensearch-search-quality-evaluation/src/main/java/org/opensearch/eval/model/Judgment.java new file mode 100644 index 0000000..61f642a --- /dev/null +++ b/opensearch-search-quality-evaluation/src/main/java/org/opensearch/eval/model/Judgment.java @@ -0,0 +1,80 @@ +package org.opensearch.eval.model; + +import org.apache.commons.lang3.builder.EqualsBuilder; +import org.apache.commons.lang3.builder.HashCodeBuilder; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.eval.util.MathUtils; + +import java.util.Collection; + +/** + * A judgment of a search result's quality for a given query. + */ +public class Judgment { + + private static final Logger LOGGER = LogManager.getLogger(Judgment.class.getName()); + + private final String queryId; + private final String query; + private final String document; + private final double judgment; + + public Judgment(final String queryId, final String query, final String document, final double judgment) { + this.queryId = queryId; + this.query = query; + this.document = document; + this.judgment = judgment; + } + + public String toJudgmentString() { + return queryId + ", " + query + ", " + document + ", " + MathUtils.round(judgment); + } + + public static void showJudgments(final Collection judgments) { + + LOGGER.info("query_id, query, document, judgment"); + + for(final Judgment judgment : judgments) { + LOGGER.info(judgment.toJudgmentString()); + } + + } + + @Override + public String toString() { + return "query_id: " + queryId + ", query: " + query + ", document: " + document + ", judgment: " + MathUtils.round(judgment); + } + + @Override + public boolean equals(Object obj) { + return EqualsBuilder.reflectionEquals(this, obj); + } + + @Override + public int hashCode() { + return new HashCodeBuilder(17, 37). + append(queryId). + append(query). + append(document). + append(judgment). + toHashCode(); + } + + public String getQueryId() { + return queryId; + } + + public String getQuery() { + return query; + } + + public String getDocument() { + return document; + } + + public double getJudgment() { + return judgment; + } + +} diff --git a/opensearch-search-quality-evaluation/src/main/java/org/opensearch/eval/model/ubi/event/EventAttributes.java b/opensearch-search-quality-evaluation/src/main/java/org/opensearch/eval/model/ubi/event/EventAttributes.java new file mode 100644 index 0000000..33f78b4 --- /dev/null +++ b/opensearch-search-quality-evaluation/src/main/java/org/opensearch/eval/model/ubi/event/EventAttributes.java @@ -0,0 +1,40 @@ +package org.opensearch.eval.model.ubi.event; + +import com.google.gson.annotations.SerializedName; + +public class EventAttributes { + + @SerializedName("object") + private EventObject object; + + @SerializedName("session_id") + private String sessionId; + + @SerializedName("position") + private Position position; + + public EventObject getObject() { + return object; + } + + public void setObject(EventObject object) { + this.object = object; + } + + public String getSessionId() { + return sessionId; + } + + public void setSessionId(String sessionId) { + this.sessionId = sessionId; + } + + public Position getPosition() { + return position; + } + + public void setPosition(Position position) { + this.position = position; + } + +} diff --git a/opensearch-search-quality-evaluation/src/main/java/org/opensearch/eval/model/ubi/event/EventObject.java b/opensearch-search-quality-evaluation/src/main/java/org/opensearch/eval/model/ubi/event/EventObject.java new file mode 100644 index 0000000..c3d3f7d --- /dev/null +++ b/opensearch-search-quality-evaluation/src/main/java/org/opensearch/eval/model/ubi/event/EventObject.java @@ -0,0 +1,29 @@ +package org.opensearch.eval.model.ubi.event; + +import com.google.gson.annotations.SerializedName; + +public class EventObject { + + @SerializedName("object_id_field") + private String objectIdField; + + @SerializedName("object_id") + private String objectId; + + public String getObjectId() { + return objectId; + } + + public void setObjectId(String objectId) { + this.objectId = objectId; + } + + public String getObjectIdField() { + return objectIdField; + } + + public void setObjectIdField(String objectIdField) { + this.objectIdField = objectIdField; + } + +} diff --git a/opensearch-search-quality-evaluation/src/main/java/org/opensearch/eval/model/ubi/event/Position.java b/opensearch-search-quality-evaluation/src/main/java/org/opensearch/eval/model/ubi/event/Position.java new file mode 100644 index 0000000..1f3d60c --- /dev/null +++ b/opensearch-search-quality-evaluation/src/main/java/org/opensearch/eval/model/ubi/event/Position.java @@ -0,0 +1,18 @@ +package org.opensearch.eval.model.ubi.event; + +import com.google.gson.annotations.SerializedName; + +public class Position { + + @SerializedName("index") + private int index; + + public int getIndex() { + return index; + } + + public void setIndex(int index) { + this.index = index; + } + +} diff --git a/opensearch-search-quality-evaluation/src/main/java/org/opensearch/eval/model/ubi/event/UbiEvent.java b/opensearch-search-quality-evaluation/src/main/java/org/opensearch/eval/model/ubi/event/UbiEvent.java new file mode 100644 index 0000000..939fa5b --- /dev/null +++ b/opensearch-search-quality-evaluation/src/main/java/org/opensearch/eval/model/ubi/event/UbiEvent.java @@ -0,0 +1,46 @@ +package org.opensearch.eval.model.ubi.event; + +import com.google.gson.annotations.SerializedName; + +/** + * A UBI event. + */ +public class UbiEvent { + + @SerializedName("action_name") + private String actionName; + + @SerializedName("client_id") + private String clientId; + + @SerializedName("query_id") + private String queryId; + + @SerializedName("event_attributes") + private EventAttributes eventAttributes; + + @Override + public String toString() { + return actionName + ", " + clientId + ", " + queryId + ", " + eventAttributes.getObject() + ", " + eventAttributes.getPosition().getIndex(); + } + + public String getActionName() { + return actionName; + } + + public String getClientId() { + return clientId; + } + + public String getQueryId() { + return queryId; + } + + public EventAttributes getEventAttributes() { + return eventAttributes; + } + + public void setEventAttributes(EventAttributes eventAttributes) { + this.eventAttributes = eventAttributes; + } +} diff --git a/opensearch-search-quality-evaluation/src/main/java/org/opensearch/eval/model/ubi/query/QueryResponse.java b/opensearch-search-quality-evaluation/src/main/java/org/opensearch/eval/model/ubi/query/QueryResponse.java new file mode 100644 index 0000000..c67abd6 --- /dev/null +++ b/opensearch-search-quality-evaluation/src/main/java/org/opensearch/eval/model/ubi/query/QueryResponse.java @@ -0,0 +1,50 @@ +package org.opensearch.eval.model.ubi.query; + +import java.util.List; + +/** + * A query response for a {@link UbiQuery query}. + */ +public class QueryResponse { + + private final String queryId; + private final String queryResponseId; + private final List queryResponseObjectIds; + + /** + * Creates a query response. + * @param queryId The ID of the query. + * @param queryResponseId The ID of the query response. + * @param queryResponseObjectIds A list of IDs for the hits in the query. + */ + public QueryResponse(final String queryId, final String queryResponseId, final List queryResponseObjectIds) { + this.queryId = queryId; + this.queryResponseId = queryResponseId; + this.queryResponseObjectIds = queryResponseObjectIds; + } + + /** + * Gets the query ID. + * @return The query ID. + */ + public String getQueryId() { + return queryId; + } + + /** + * Gets the query response ID. + * @return The query response ID. + */ + public String getQueryResponseId() { + return queryResponseId; + } + + /** + * Gets the list of query response hit IDs. + * @return A list of query response hit IDs. + */ + public List getQueryResponseObjectIds() { + return queryResponseObjectIds; + } + +} diff --git a/opensearch-search-quality-evaluation/src/main/java/org/opensearch/eval/model/ubi/query/UbiQuery.java b/opensearch-search-quality-evaluation/src/main/java/org/opensearch/eval/model/ubi/query/UbiQuery.java new file mode 100644 index 0000000..39405b7 --- /dev/null +++ b/opensearch-search-quality-evaluation/src/main/java/org/opensearch/eval/model/ubi/query/UbiQuery.java @@ -0,0 +1,105 @@ +package org.opensearch.eval.model.ubi.query; + +import com.google.gson.annotations.SerializedName; +import org.apache.commons.lang3.builder.EqualsBuilder; +import org.apache.commons.lang3.builder.HashCodeBuilder; + +import java.util.Map; + +/** + * A UBI query. + */ +public class UbiQuery { + + @SerializedName("timestamp") + private long timestamp; + + @SerializedName("query_id") + private String queryId; + + @SerializedName("client_id") + private String clientId; + + @SerializedName("user_query") + private String userQuery; + + @SerializedName("query") + private String query; + + @SerializedName("query_attributes") + private Map queryAttributes; + + @SerializedName("query_response") + private QueryResponse queryResponse; + + @Override + public boolean equals(Object obj) { + return EqualsBuilder.reflectionEquals(this, obj); + } + + @Override + public int hashCode() { + return new HashCodeBuilder(17, 37). + append(queryId). + append(userQuery). + append(clientId). + toHashCode(); + } + + public long getTimestamp() { + return timestamp; + } + + public void setTimestamp(long timestamp) { + this.timestamp = timestamp; + } + + public String getQueryId() { + return queryId; + } + + public void setQueryId(String queryId) { + this.queryId = queryId; + } + + public void setClientId(String clientId) { + this.clientId = clientId; + } + + public String getClientId() { + return clientId; + } + + public String getUserQuery() { + return userQuery; + } + + public void setUserQuery(String userQuery) { + this.userQuery = userQuery; + } + + public String getQuery() { + return query; + } + + public void setQuery(String query) { + this.query = query; + } + + public Map getQueryAttributes() { + return queryAttributes; + } + + public void setQueryAttributes(Map queryAttributes) { + this.queryAttributes = queryAttributes; + } + + public QueryResponse getQueryResponse() { + return queryResponse; + } + + public void setQueryResponse(QueryResponse queryResponse) { + this.queryResponse = queryResponse; + } + +} diff --git a/opensearch-search-quality-evaluation/src/main/java/org/opensearch/eval/util/MathUtils.java b/opensearch-search-quality-evaluation/src/main/java/org/opensearch/eval/util/MathUtils.java new file mode 100644 index 0000000..c0738f8 --- /dev/null +++ b/opensearch-search-quality-evaluation/src/main/java/org/opensearch/eval/util/MathUtils.java @@ -0,0 +1,18 @@ +package org.opensearch.eval.util; + +public class MathUtils { + + private MathUtils() { + + } + + public static String round(final double value, final int decimalPlaces) { + double factor = Math.pow(10, decimalPlaces); + return String.valueOf(Math.round(value * factor) / factor); + } + + public static String round(final double value) { + return round(value, 3); + } + +} diff --git a/opensearch-search-quality-evaluation/src/main/java/org/opensearch/eval/util/UserQueryHash.java b/opensearch-search-quality-evaluation/src/main/java/org/opensearch/eval/util/UserQueryHash.java new file mode 100644 index 0000000..7dd26ad --- /dev/null +++ b/opensearch-search-quality-evaluation/src/main/java/org/opensearch/eval/util/UserQueryHash.java @@ -0,0 +1,36 @@ +package org.opensearch.eval.util; + +import java.util.HashMap; +import java.util.Map; + +public class UserQueryHash { + + private final Map userQueries; + private int count = 1; + + public UserQueryHash() { + this.userQueries = new HashMap<>(); + } + + public int getHash(String userQuery) { + + final int hash; + + if(userQueries.containsKey(userQuery)) { + + return userQueries.get(userQuery); + + } else { + + userQueries.put(userQuery, count); + hash = count; + count++; + + + } + + return hash; + + } + +} diff --git a/opensearch-search-quality-evaluation/src/main/resources/events-mapping.json b/opensearch-search-quality-evaluation/src/main/resources/events-mapping.json new file mode 100644 index 0000000..cdb1393 --- /dev/null +++ b/opensearch-search-quality-evaluation/src/main/resources/events-mapping.json @@ -0,0 +1,45 @@ +{ + "properties": { + "application": { "type": "keyword", "ignore_above": 256 }, + "action_name": { "type": "keyword", "ignore_above": 100 }, + "client_id": { "type": "keyword", "ignore_above": 100 }, + "message": { "type": "keyword", "ignore_above": 1024 }, + "message_type": { "type": "keyword", "ignore_above": 100 }, + "timestamp": { + "type": "date", + "format":"strict_date_time", + "ignore_malformed": true, + "doc_values": true + }, + "event_attributes": { + "dynamic": true, + "properties": { + "position": { + "properties": { + "ordinal": { "type": "integer" }, + "x": { "type": "integer" }, + "y": { "type": "integer" }, + "page_depth": { "type": "integer" }, + "scroll_depth": { "type": "integer" }, + "trail": { "type": "text", + "fields": { "keyword": { "type": "keyword", "ignore_above": 256 } + } + } + } + }, + "object": { + "properties": { + "internal_id": { "type": "keyword" }, + "object_id": { "type": "keyword", "ignore_above": 256 }, + "object_id_field": { "type": "keyword", "ignore_above": 100 }, + "name": { "type": "keyword", "ignore_above": 256 }, + "description": { "type": "text", + "fields": { "keyword": { "type": "keyword", "ignore_above": 256 } } + }, + "object_detail": { "type": "object" } + } + } + } + } + } +} \ No newline at end of file diff --git a/opensearch-search-quality-evaluation/src/main/resources/log4j2.xml b/opensearch-search-quality-evaluation/src/main/resources/log4j2.xml new file mode 100644 index 0000000..eea5a4c --- /dev/null +++ b/opensearch-search-quality-evaluation/src/main/resources/log4j2.xml @@ -0,0 +1,14 @@ + + + + + + + + + + + + + \ No newline at end of file diff --git a/opensearch-search-quality-evaluation/src/test/java/org/opensearch/eval/clickmodel/coec/CoecClickModelIT.java b/opensearch-search-quality-evaluation/src/test/java/org/opensearch/eval/clickmodel/coec/CoecClickModelIT.java new file mode 100644 index 0000000..1d81474 --- /dev/null +++ b/opensearch-search-quality-evaluation/src/test/java/org/opensearch/eval/clickmodel/coec/CoecClickModelIT.java @@ -0,0 +1,159 @@ +package org.opensearch.eval.clickmodel.coec; + +import org.apache.http.HttpHost; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.opensearch.action.bulk.BulkRequest; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.client.RequestOptions; +import org.opensearch.client.RestClient; +import org.opensearch.client.RestClientBuilder; +import org.opensearch.client.RestHighLevelClient; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.eval.engine.opensearch.OpenSearchHelper; +import org.opensearch.eval.model.Judgment; + +import java.io.IOException; +import java.util.Collection; +import java.util.UUID; + +public class CoecClickModelIT { + + private static final Logger LOGGER = LogManager.getLogger(CoecClickModelIT.class.getName()); + + @Disabled + @Test + public void calculateJudgmentForDoc1() throws IOException { + + final RestClientBuilder builder = RestClient.builder(new HttpHost("localhost", 9200, "http")); + final RestHighLevelClient restHighLevelClient = new RestHighLevelClient(builder); + +// // Remove any existing indexes. +// final boolean exists = restHighLevelClient.indices().exists(new GetIndexRequest("ubi_events"), RequestOptions.DEFAULT); +// if(exists) { +// restHighLevelClient.indices().delete(new DeleteIndexRequest("ubi_events"), RequestOptions.DEFAULT); +// } +// +// // Create the ubi_events index. +// final CreateIndexRequest createIndexRequest = new CreateIndexRequest("ubi_events").mapping(getResourceFile("/events-mapping.json")); +// restHighLevelClient.indices().create(createIndexRequest, RequestOptions.DEFAULT); + + final int numberOfViews = 250; + final int numberOfClicks = 110; + + final BulkRequest bulkRequest = new BulkRequest(); + + // Index the view. + for(int x = 1; x <= numberOfViews; x++) { + + final String event = "{\n" + + " \"action_name\" : \"view\",\n" + + " \"client_id\" : \"" + UUID.randomUUID() + "\",\n" + + " \"query_id\" : \"" + UUID.randomUUID() + "\",\n" + + " \"message_type\" : null,\n" + + " \"message\" : null,\n" + + " \"timestamp\" : 1.7276472197111678E9,\n" + + " \"event_attributes\" : {\n" + + " \"object\" : {\n" + + " \"object_id_field\" : \"primary_ean\",\n" + + " \"object_id\" : \"0731304258193\",\n" + + " \"description\" : \"APC IT Power Distribution Module 3 Pole 5 Wire 32A IEC309 620cm power distribution unit (PDU)\"\n" + + " },\n" + + " \"position\" : {\n" + + " \"index\" : 7\n" + + " },\n" + + " \"session_id\" : \"d4ed2513-aaa9-48c1-bcb9-e936a4e903a9\"\n" + + " }" + + "}"; + + final IndexRequest indexRequest = new IndexRequest(CoecClickModel.INDEX_UBI_EVENTS) + .id(String.valueOf(x)) + .source(event, XContentType.JSON); + + bulkRequest.add(indexRequest); + + } + + // Index the clicks. + for(int x = 1; x <= numberOfClicks; x++) { + +// final String query = "{\n" + +// " \"user_query\" : \"computer\",\n" + +// " \"query_id\" : \"" + UUID.randomUUID() + "\",\n" + +// " \"message_type\" : null,\n" + +// " \"message\" : null,\n" + +// " \"timestamp\" : 1.7276472197111678E9,\n" + +// " \"event_attributes\" : {\n" + +// " \"object\" : {\n" + +// " \"object_id_field\" : \"primary_ean\",\n" + +// " \"object_id\" : \"0731304258193\",\n" + +// " \"description\" : \"APC IT Power Distribution Module 3 Pole 5 Wire 32A IEC309 620cm power distribution unit (PDU)\"\n" + +// " },\n" + +// " \"position\" : {\n" + +// " \"index\" : 7\n" + +// " },\n" + +// " \"session_id\" : \"" + UUID.randomUUID() + "\"\n" + +// " }" + +// "}"; +// +// final IndexRequest queryIndexRequest = new IndexRequest(CoecClickModel.INDEX_UBI_QUERIES) +// .id(String.valueOf(x)) +// .source(query, XContentType.JSON); +// bulkRequest.add(queryIndexRequest); + + final String event = "{\n" + + " \"action_name\" : \"click\",\n" + + " \"client_id\" : \"" + UUID.randomUUID() + "\",\n" + + " \"query_id\" : \"" + UUID.randomUUID() + "\",\n" + + " \"message_type\" : null,\n" + + " \"message\" : null,\n" + + " \"timestamp\" : 1.7276472197111678E9,\n" + + " \"event_attributes\" : {\n" + + " \"object\" : {\n" + + " \"object_id_field\" : \"primary_ean\",\n" + + " \"object_id\" : \"0731304258193\",\n" + + " \"description\" : \"APC IT Power Distribution Module 3 Pole 5 Wire 32A IEC309 620cm power distribution unit (PDU)\"\n" + + " },\n" + + " \"position\" : {\n" + + " \"index\" : 7\n" + + " },\n" + + " \"session_id\" : \"" + UUID.randomUUID() + "\"\n" + + " }" + + "}"; + + final IndexRequest eventIndexRequest = new IndexRequest(CoecClickModel.INDEX_UBI_EVENTS) + .id(String.valueOf(x)) + .source(event, XContentType.JSON); + bulkRequest.add(eventIndexRequest); + + } + + restHighLevelClient.bulk(bulkRequest, RequestOptions.DEFAULT); + + final OpenSearchHelper openSearchHelper = new OpenSearchHelper(restHighLevelClient); + + final CoecClickModelParameters coecClickModelParameters = new CoecClickModelParameters(true, 20); + final CoecClickModel coecClickModel = new CoecClickModel(coecClickModelParameters); + + final Collection judgments = coecClickModel.calculateJudgments(); + Judgment.showJudgments(judgments); + + } +// +// private XContentBuilder getResourceFile(final String fileName) { +// try (InputStream is = CoecClickModelIT.class.getResourceAsStream(fileName)) { +// ByteArrayOutputStream out = new ByteArrayOutputStream(); +// Streams.copy(is.readAllBytes(), out); +// +// final String message = out.toString(StandardCharsets.UTF_8); +// +// return XContentFactory.jsonBuilder().value(message); +// +// } catch (IOException e) { +// throw new IllegalStateException("Unable to get mapping from resource [" + fileName + "]", e); +// } +// } + +} diff --git a/opensearch-search-quality-evaluation/src/test/java/org/opensearch/eval/clickmodel/coec/CoecClickModelTest.java b/opensearch-search-quality-evaluation/src/test/java/org/opensearch/eval/clickmodel/coec/CoecClickModelTest.java new file mode 100644 index 0000000..94f7e60 --- /dev/null +++ b/opensearch-search-quality-evaluation/src/test/java/org/opensearch/eval/clickmodel/coec/CoecClickModelTest.java @@ -0,0 +1,119 @@ +package org.opensearch.eval.clickmodel.coec; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; +import org.opensearch.client.RestHighLevelClient; +import org.opensearch.eval.engine.opensearch.OpenSearchHelper; +import org.opensearch.eval.model.ClickthroughRate; +import org.opensearch.eval.model.Judgment; + +import java.io.IOException; +import java.util.Collection; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; + +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.when; + +public class CoecClickModelTest { + + private static final Logger LOGGER = LogManager.getLogger(CoecClickModelTest.class.getName()); + + @Disabled + @Test + public void calculateJudgmentForDoc1() throws IOException { + + final RestHighLevelClient restHighLevelClient = Mockito.mock(RestHighLevelClient.class); + final OpenSearchHelper openSearchHelper = Mockito.mock(OpenSearchHelper.class); + + when(openSearchHelper.getCountOfQueriesForUserQueryHavingResultInRankR(anyString(), anyString(), anyInt())).thenReturn(250); + + final CoecClickModelParameters parameters = new CoecClickModelParameters(false, 20); + final CoecClickModel coecClickModel = new CoecClickModel(parameters); + + final Map rankAggregatedClickThrough = new HashMap<>(); + rankAggregatedClickThrough.put(1, 0.450); + + final Set ctrs = new HashSet<>(); + ctrs.add(new ClickthroughRate("object_id_1", 110, 250)); + + final Map> clickthroughRates = new HashMap<>(); + clickthroughRates.put("computer", ctrs); + + final Collection< Judgment> judgments = coecClickModel.calculateCoec(rankAggregatedClickThrough, clickthroughRates); + + Judgment.showJudgments(judgments); + + Assertions.assertEquals(1, judgments.size()); + Assertions.assertEquals(0.9777777777777777, judgments.iterator().next().getJudgment(), 0.01); + + } + + @Disabled + @Test + public void calculateJudgmentForDoc2() throws IOException { + + final RestHighLevelClient restHighLevelClient = Mockito.mock(RestHighLevelClient.class); + final OpenSearchHelper openSearchHelper = Mockito.mock(OpenSearchHelper.class); + + when(openSearchHelper.getCountOfQueriesForUserQueryHavingResultInRankR(anyString(), anyString(), anyInt())).thenReturn(124); + + final CoecClickModelParameters parameters = new CoecClickModelParameters(false, 20); + final CoecClickModel coecClickModel = new CoecClickModel(parameters); + + final Map rankAggregatedClickThrough = new HashMap<>(); + rankAggregatedClickThrough.put(2, 0.175); + + final Set ctrs = new HashSet<>(); + ctrs.add(new ClickthroughRate("object_id_2", 31, 124)); + + final Map> clickthroughRates = new HashMap<>(); + clickthroughRates.put("computer", ctrs); + + final Collection< Judgment> judgments = coecClickModel.calculateCoec(rankAggregatedClickThrough, clickthroughRates); + + Judgment.showJudgments(judgments); + + Assertions.assertEquals(1, judgments.size()); + Assertions.assertEquals(1.4285714285714286, judgments.iterator().next().getJudgment(), 0.01); + + } + + @Disabled + @Test + public void calculateJudgmentForDoc3() throws IOException { + + final RestHighLevelClient restHighLevelClient = Mockito.mock(RestHighLevelClient.class); + final OpenSearchHelper openSearchHelper = Mockito.mock(OpenSearchHelper.class); + + when(openSearchHelper.getCountOfQueriesForUserQueryHavingResultInRankR(anyString(), anyString(), anyInt())).thenReturn(240); + + final CoecClickModelParameters parameters = new CoecClickModelParameters(false, 20); + final CoecClickModel coecClickModel = new CoecClickModel(parameters); + + final Map rankAggregatedClickThrough = new HashMap<>(); + rankAggregatedClickThrough.put(3, 0.075); + + final Set ctrs = new HashSet<>(); + ctrs.add(new ClickthroughRate("object_id_3", 30, 240)); + + final Map> clickthroughRates = new HashMap<>(); + clickthroughRates.put("computer", ctrs); + + final Collection< Judgment> judgments = coecClickModel.calculateCoec(rankAggregatedClickThrough, clickthroughRates); + + Judgment.showJudgments(judgments); + + Assertions.assertEquals(1, judgments.size()); + Assertions.assertEquals(1.6666666666666667, judgments.iterator().next().getJudgment(), 0.01); + + } + +} diff --git a/opensearch-search-quality-evaluation/src/test/java/org/opensearch/eval/model/ubi/UbiEventTest.java b/opensearch-search-quality-evaluation/src/test/java/org/opensearch/eval/model/ubi/UbiEventTest.java new file mode 100644 index 0000000..39ffcc0 --- /dev/null +++ b/opensearch-search-quality-evaluation/src/test/java/org/opensearch/eval/model/ubi/UbiEventTest.java @@ -0,0 +1,5 @@ +package org.opensearch.eval.model.ubi; + +public class UbiEventTest { + +} diff --git a/opensearch-search-quality-evaluation/src/test/resources/log4j2.xml b/opensearch-search-quality-evaluation/src/test/resources/log4j2.xml new file mode 100644 index 0000000..eea5a4c --- /dev/null +++ b/opensearch-search-quality-evaluation/src/test/resources/log4j2.xml @@ -0,0 +1,14 @@ + + + + + + + + + + + + + \ No newline at end of file diff --git a/opensearch-search-quality-evaluation/src/test/resources/ubi_event.json b/opensearch-search-quality-evaluation/src/test/resources/ubi_event.json new file mode 100644 index 0000000..d68e4e6 --- /dev/null +++ b/opensearch-search-quality-evaluation/src/test/resources/ubi_event.json @@ -0,0 +1,27 @@ +{ + "_index" : "ubi_events", + "_id" : "520ff171-e6b4-4900-9c63-dba48a9753f7", + "_score" : 1.0, + "_ignored" : [ + "timestamp" + ], + "_source" : { + "action_name" : "click", + "client_id" : "2ce5eece-53cf-4b0c-9c55-bbeb57ad8642", + "query_id" : "efbeb66a-5b6b-48bd-89a9-33b171f95b2b", + "message_type" : null, + "message" : null, + "timestamp" : 1.7276472197111678E9, + "event_attributes" : { + "object" : { + "object_id_field" : "primary_ean", + "object_id" : "0731304258193", + "description" : "APC IT Power Distribution Module 3 Pole 5 Wire 32A IEC309 620cm power distribution unit (PDU)" + }, + "position" : { + "index" : 7 + }, + "session_id" : "d4ed2513-aaa9-48c1-bcb9-e936a4e903a9" + } + } +} \ No newline at end of file