From add837afd15e1816e6415387559f51bd04996ef9 Mon Sep 17 00:00:00 2001
From: impadmin
Date: Thu, 15 Oct 2020 17:15:17 +0530
Subject: [PATCH] [CALCITE-4233] In Elasticsearch adapter, support generating disjunction max (dis_max) queries (shlok7296)

close apache/calcite#2218
---
 .../elasticsearch/ElasticsearchFilter.java    | 17 +++++++-
 .../adapter/elasticsearch/QueryBuilders.java  | 37 +++++++++++++++++
 .../ElasticSearchAdapterTest.java             | 41 +++++++++++++++++++
 3 files changed, 94 insertions(+), 1 deletion(-)

diff --git a/elasticsearch/src/main/java/org/apache/calcite/adapter/elasticsearch/ElasticsearchFilter.java b/elasticsearch/src/main/java/org/apache/calcite/adapter/elasticsearch/ElasticsearchFilter.java
index dee9753d15cc..ff6e72ba0bd4 100644
--- a/elasticsearch/src/main/java/org/apache/calcite/adapter/elasticsearch/ElasticsearchFilter.java
+++ b/elasticsearch/src/main/java/org/apache/calcite/adapter/elasticsearch/ElasticsearchFilter.java
@@ -23,7 +23,9 @@
 import org.apache.calcite.rel.RelNode;
 import org.apache.calcite.rel.core.Filter;
 import org.apache.calcite.rel.metadata.RelMetadataQuery;
+import org.apache.calcite.rex.RexCall;
 import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.sql.SqlKind;
 
 import com.fasterxml.jackson.core.JsonGenerator;
 import com.fasterxml.jackson.databind.ObjectMapper;
@@ -31,6 +33,7 @@
 import java.io.IOException;
 import java.io.StringWriter;
 import java.io.UncheckedIOException;
+import java.util.Iterator;
 import java.util.Objects;
 
 /**
@@ -82,7 +85,19 @@ String translateMatch(RexNode condition) throws IOException,
 
     StringWriter writer = new StringWriter();
     JsonGenerator generator = mapper.getFactory().createGenerator(writer);
-    QueryBuilders.constantScoreQuery(PredicateAnalyzer.analyze(condition)).writeJson(generator);
+    boolean disMax = condition.isA(SqlKind.OR);
+    Iterator<RexNode> operands = ((RexCall) condition).getOperands().iterator();
+    while (operands.hasNext() && !disMax) {
+      if (operands.next().isA(SqlKind.OR)) {
+        disMax = true;
+        break;
+      }
+    }
+    if (disMax) {
+      QueryBuilders.disMaxQueryBuilder(PredicateAnalyzer.analyze(condition)).writeJson(generator);
+    } else {
+      QueryBuilders.constantScoreQuery(PredicateAnalyzer.analyze(condition)).writeJson(generator);
+    }
     generator.flush();
     generator.close();
     return "{\"query\" : " + writer.toString() + "}";
diff --git a/elasticsearch/src/main/java/org/apache/calcite/adapter/elasticsearch/QueryBuilders.java b/elasticsearch/src/main/java/org/apache/calcite/adapter/elasticsearch/QueryBuilders.java
index afb4e72aeacb..c7fa1647af6b 100644
--- a/elasticsearch/src/main/java/org/apache/calcite/adapter/elasticsearch/QueryBuilders.java
+++ b/elasticsearch/src/main/java/org/apache/calcite/adapter/elasticsearch/QueryBuilders.java
@@ -186,6 +186,16 @@ static ConstantScoreQueryBuilder constantScoreQuery(QueryBuilder queryBuilder) {
     return new ConstantScoreQueryBuilder(queryBuilder);
   }
 
+  /**
+   * Creates a builder that wraps another query in a disjunction max
+   * ({@code dis_max}) query.
+   *
+   * @param queryBuilder The query to wrap in a dis_max query
+   */
+  static DisMaxQueryBuilder disMaxQueryBuilder(QueryBuilder queryBuilder) {
+    return new DisMaxQueryBuilder(queryBuilder);
+  }
+
   /**
    * A filter to filter only documents where a field exists in them.
    *
@@ -540,6 +550,33 @@ private ConstantScoreQueryBuilder(final QueryBuilder builder) {
     }
   }
 
+  /**
+   * A query that emits the wrapped query as the only entry in the
+   * {@code queries} array of a {@code dis_max} (disjunction max) query.
+   */
+  static class DisMaxQueryBuilder extends QueryBuilder {
+
+    private final QueryBuilder builder;
+
+    private DisMaxQueryBuilder(final QueryBuilder builder) {
+      this.builder = Objects.requireNonNull(builder, "builder");
+    }
+
+    @Override void writeJson(final JsonGenerator generator) throws IOException {
+      generator.writeStartObject();
+      generator.writeFieldName("dis_max");
+      generator.writeStartObject();
+      generator.writeFieldName("queries");
+      generator.writeStartArray();
+      builder.writeJson(generator);
+      generator.writeEndArray();
+      generator.writeEndObject();
+      generator.writeEndObject();
+    }
+  }
+
+
+
   /**
    * A query that matches on all documents.
    *
diff --git a/elasticsearch/src/test/java/org/apache/calcite/adapter/elasticsearch/ElasticSearchAdapterTest.java b/elasticsearch/src/test/java/org/apache/calcite/adapter/elasticsearch/ElasticSearchAdapterTest.java
index 70999aaac931..f42faf3bb52c 100644
--- a/elasticsearch/src/test/java/org/apache/calcite/adapter/elasticsearch/ElasticSearchAdapterTest.java
+++ b/elasticsearch/src/test/java/org/apache/calcite/adapter/elasticsearch/ElasticSearchAdapterTest.java
@@ -295,6 +295,19 @@ private static Consumer sortedResultSetChecker(String column,
         .query("select _MAP['state'] from elastic.zips order by _MAP['city']")
         .returnsCount(ZIPS_SIZE);
 
+    CalciteAssert.that()
+        .with(newConnectionFactory())
+        .query("select * from elastic.zips where _MAP['state'] = 'NY' or "
+            + "_MAP['city'] = 'BROOKLYN'"
+            + " order by _MAP['city']")
+        .queryContains(
+            ElasticsearchChecker.elasticsearchChecker(
+                "query:{'dis_max':{'queries':[{'bool':{'should':"
+                    + "[{'term':{'state':'NY'}},{'term':"
+                    + "{'city':'BROOKLYN'}}]}}]}},'sort':[{'city':'asc'}]",
+                String.format(Locale.ROOT, "size:%s",
+                    ElasticsearchTransport.DEFAULT_FETCH_SIZE)));
+
     CalciteAssert.that()
         .with(newConnectionFactory())
         .query("select _MAP['city'] from elastic.zips where _MAP['state'] = 'NY' "
@@ -421,6 +434,34 @@ private static Consumer sortedResultSetChecker(String column,
         .explainContains(explain);
   }
 
+  @Test void testDismaxQuery() {
+    final String sql = "select * from zips\n"
+        + "where state = 'CA' or pop >= 94000\n" // top-level OR is what selects dis_max
+        + "order by state, pop";
+    final String explain = "PLAN=ElasticsearchToEnumerableConverter\n"
+        + "  ElasticsearchSort(sort0=[$4], sort1=[$3], dir0=[ASC], dir1=[ASC])\n"
+        + "    ElasticsearchProject(city=[CAST(ITEM($0, 'city')):VARCHAR(20)], longitude=[CAST(ITEM(ITEM($0, 'loc'), 0)):FLOAT], latitude=[CAST(ITEM(ITEM($0, 'loc'), 1)):FLOAT], pop=[CAST(ITEM($0, 'pop')):INTEGER], state=[CAST(ITEM($0, 'state')):VARCHAR(2)], id=[CAST(ITEM($0, 'id')):VARCHAR(5)])\n"
+        + "      ElasticsearchFilter(condition=[OR(=(CAST(ITEM($0, 'state')):VARCHAR(2), 'CA'), >=(CAST(ITEM($0, 'pop')):INTEGER, 94000))])\n"
+        + "        ElasticsearchTableScan(table=[[elastic, zips]])\n\n";
+    calciteAssert()
+        .query(sql)
+        .queryContains(
+            ElasticsearchChecker.elasticsearchChecker("'query' : "
+                    + "{'dis_max':{'queries':[{bool:"
+                    + "{should:[{term:{state:'CA'}},"
+                    + "{range:{pop:{gte:94000}}}]}}]}}",
+                "'script_fields': {longitude:{script:'params._source.loc[0]'}, "
+                    + "latitude:{script:'params._source.loc[1]'}, "
+                    + "city:{script: 'params._source.city'}, "
+                    + "pop:{script: 'params._source.pop'}, "
+                    + "state:{script: 'params._source.state'}, "
+                    + "id:{script: 'params._source.id'}}",
+                "sort: [ {state: 'asc'}, {pop: 'asc'}]",
+                String.format(Locale.ROOT, "size:%s",
+                    ElasticsearchTransport.DEFAULT_FETCH_SIZE)))
+        .explainContains(explain);
+  }
+
   @Test void testFilterSortDesc() {
     final String sql = "select * from zips\n"
         + "where pop BETWEEN 95000 AND 100000\n"