Skip to content

Commit

Permalink
airframe-sql: Fixes #2646 Resolve AllColumns inputs as ResolvedAttrib…
Browse files Browse the repository at this point in the history
…utes (#2649)

- Resolve AllColumns inputs so as not to pull up too many details
- Make explicit that ResolvedAttribute.sourceColumns is given only when the attribute is a direct, unmodified reference to table columns
  • Loading branch information
xerial authored Dec 20, 2022
1 parent 185ad21 commit 7c0616e
Show file tree
Hide file tree
Showing 7 changed files with 175 additions and 113 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ object TypeResolver extends LogSupport {
object resolveAggregationIndexes extends RewriteRule {
def apply(context: AnalyzerContext): PlanRewriter = {
case a @ Aggregate(child, selectItems, groupingKeys, having, _) =>
var changed = false
val resolvedGroupingKeys: List[GroupingKey] = groupingKeys.map {
case k @ GroupingKey(LongLiteral(i, _), _) if i <= selectItems.length =>
// Use a simpler form of attributes
Expand All @@ -92,11 +93,16 @@ object TypeResolver extends LogSupport {
case other =>
other
}
changed = true
GroupingKey(keyItem, k.nodeLocation)
case other =>
other
}
Aggregate(child, selectItems, resolvedGroupingKeys, having, a.nodeLocation)
if (changed) {
Aggregate(child, selectItems, resolvedGroupingKeys, having, a.nodeLocation)
} else {
a
}
}
}

Expand All @@ -108,16 +114,17 @@ object TypeResolver extends LogSupport {
object resolveAggregationKeys extends RewriteRule {
def apply(context: AnalyzerContext): PlanRewriter = {
case a @ Aggregate(child, selectItems, groupingKeys, having, _) =>
val resolvedChild = resolveRelation(context, child)
val inputAttributes = resolvedChild.outputAttributes
val resolvedChild = resolveRelation(context, child)
val childOutputAttributes = resolvedChild.outputAttributes
val resolvedGroupingKeys =
groupingKeys.map(x => {
val e = resolveExpression(context, x.child, inputAttributes, false)
val e = resolveExpression(context, x.child, childOutputAttributes)
GroupingKey(e, e.nodeLocation)
})
val resolvedHaving = having.map {
_.transformUpExpression { case x: Expression =>
resolveExpression(context, x, a.outputAttributes, false)
// HAVING recognizes attributes only from the input relation
resolveExpression(context, x, childOutputAttributes)
}
}
Aggregate(resolvedChild, selectItems, resolvedGroupingKeys, resolvedHaving, a.nodeLocation)
Expand All @@ -144,7 +151,7 @@ object TypeResolver extends LogSupport {
val resolvedChild = resolveRelation(context, child)
val inputAttributes = resolvedChild.outputAttributes
val resolvedSortItems = sortItems.map { sortItem =>
val e = resolveExpression(context, sortItem.sortKey, inputAttributes, false)
val e = resolveExpression(context, sortItem.sortKey, inputAttributes)
sortItem.copy(sortKey = e)
}
s.copy(orderBy = resolvedSortItems)
Expand Down Expand Up @@ -198,7 +205,7 @@ object TypeResolver extends LogSupport {
val resolvedJoin =
Join(joinType, resolveRelation(context, left), resolveRelation(context, right), u, j.nodeLocation)
val resolvedJoinKeys: Seq[Expression] = joinKeys.flatMap { k =>
findMatchInInputAttributes(context, k, resolvedJoin.inputAttributes, false) match {
findMatchInInputAttributes(context, k, resolvedJoin.inputAttributes) match {
case x if x.size < 2 =>
throw SQLErrorCode.ColumnNotFound.newException(
s"join key column: ${k.sqlExpr} is not found",
Expand All @@ -214,7 +221,7 @@ object TypeResolver extends LogSupport {
val resolvedJoin =
Join(joinType, resolveRelation(context, left), resolveRelation(context, right), u, j.nodeLocation)
val resolvedJoinKeys: Seq[Expression] = Seq(leftKey, rightKey).flatMap { k =>
findMatchInInputAttributes(context, k, resolvedJoin.inputAttributes, false) match {
findMatchInInputAttributes(context, k, resolvedJoin.inputAttributes) match {
case Nil =>
throw SQLErrorCode.ColumnNotFound.newException(
s"join key column: ${k.sqlExpr} is not found",
Expand Down Expand Up @@ -243,12 +250,12 @@ object TypeResolver extends LogSupport {
def apply(context: AnalyzerContext): PlanRewriter = {
case filter @ Filter(child, filterExpr, _) =>
filter.transformUpExpressions { case x: Expression =>
resolveExpression(context, x, filter.inputAttributes, true)
resolveExpression(context, x, filter.inputAttributes)
}
case u: Union => u // UNION is resolved later by resolveUnion()
case u: Intersect => u // INTERSECT is resolved later by resolveIntersect()
case r: Relation =>
r.transformUpExpressions { case x: Expression => resolveExpression(context, x, r.inputAttributes, true) }
r.transformUpExpressions { case x: Expression => resolveExpression(context, x, r.inputAttributes) }
}
}

Expand Down Expand Up @@ -279,7 +286,7 @@ object TypeResolver extends LogSupport {
// // TODO check (prefix).* to resolve attributes
// resolvedColumns ++= inputAttributes
case SingleColumn(expr, alias, qualifier, nodeLocation) =>
resolveExpression(context, expr, inputAttributes, true) match {
resolveExpression(context, expr, inputAttributes) match {
case s: SingleColumn =>
resolvedColumns += s
case r: ResolvedAttribute if alias.isEmpty =>
Expand Down Expand Up @@ -324,15 +331,14 @@ object TypeResolver extends LogSupport {
private def findMatchInInputAttributes(
context: AnalyzerContext,
expr: Expression,
inputAttributes: Seq[Attribute],
isSelectItem: Boolean
inputAttributes: Seq[Attribute]
): List[Expression] = {
val resolvedAttributes = inputAttributes.map(resolveAttribute)

def matchedColumn(matched: Seq[Expression], name: String): Expression = {
matched match {
case attrs if attrs.size == 1 => attrs.head
case attrs => MultiColumn(attrs, Some(name), None)
case attrs => MultiColumn(attrs, Some(name), None, None)
}
}

Expand Down Expand Up @@ -373,18 +379,44 @@ object TypeResolver extends LogSupport {
}
}.distinct

/**
 * Wrap an arbitrary expression as a ResolvedAttribute with the given name.
 *
 * Only the data type and node location are taken from `expr`; the third and fourth
 * constructor arguments are left empty (None / Seq.empty).
 * NOTE(review): these are presumably the qualifier and the source columns -- confirm
 * against the ResolvedAttribute definition.
 *
 * @param name the attribute name to expose
 * @param expr the expression whose dataType and nodeLocation are carried over
 * @return a ResolvedAttribute that refers to `expr` only by name and type
 */
def toResolvedAttribute(name: String, expr: Expression): ResolvedAttribute =
  ResolvedAttribute(name, expr.dataType, None, Seq.empty, expr.nodeLocation)

val results = expr match {
case i: Identifier =>
lookup(i.value).map {
case a: Attribute => a
case m: MultiColumn => m
// retain alias for select column
case expr => if (isSelectItem) SingleColumn(expr, Some(i.value), None, expr.nodeLocation) else expr
// No need to resolve Attribute expressions
case a: Attribute => a
case expr =>
// Resolve expr as ResolvedAttribute so as not to pull up too many details
toResolvedAttribute(i.value, expr)
}
case u @ UnresolvedAttribute(name, _) =>
lookup(name)
case a @ AllColumns(_, None, _) =>
List(a.copy(columns = Some(inputAttributes)))
// Resolve the inputs of AllColumns as ResolvedAttributes
// so as not to pull up too many details
val allColumns = resolvedAttributes.map {
case s @ SingleColumn(m: MultiColumn, alias, qualifier, _) =>
// Pull-up MultiColumn to simplify the expression
m.copy(alias = m.alias.orElse(s.alias), qualifier = qualifier)
case m: MultiColumn =>
// MultiColumn is already resolved
m
case r: ResolvedAttribute =>
// This path preserves already resolved column tags
r
case other =>
toResolvedAttribute(other.name, other)
}
List(a.copy(columns = Some(allColumns)))
case _ =>
List(expr)
}
Expand All @@ -398,10 +430,9 @@ object TypeResolver extends LogSupport {
private def resolveExpression(
context: AnalyzerContext,
expr: Expression,
inputAttributes: Seq[Attribute],
isSelectItem: Boolean
inputAttributes: Seq[Attribute]
): Expression = {
findMatchInInputAttributes(context, expr, inputAttributes, isSelectItem) match {
findMatchInInputAttributes(context, expr, inputAttributes) match {
case lst if lst.length > 1 =>
trace(s"${expr} -> ${lst}")
throw SQLErrorCode.SyntaxError.newException(s"${expr.sqlExpr} is ambiguous", expr.nodeLocation)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ sealed trait Expression extends TreeNode[Expression] with Product {
}
}

/**
 * The type description of this expression prefixed with its node type, i.e.,
 * "[SimpleClassName] typeDescription". Any '$' character in the simple class name
 * (e.g., the trailing '$' of a Scala object's class) is removed.
 *
 * @return the node type in brackets followed by typeDescription
 */
def typeDescriptionWithNodeType: String = {
  // String.replaceAll interprets its first argument as a regex, where "$" is the
  // end-of-input anchor and therefore removes nothing. String.replace performs the
  // literal replacement that is intended here.
  val nodeType = this.getClass.getSimpleName.replace("$", "")
  s"[${nodeType}] ${typeDescription}"
}

/**
* Column name without qualifier
* @return
Expand Down Expand Up @@ -310,12 +314,11 @@ object Expression {
override def toString = {
columns match {
case Some(attrs) if attrs.nonEmpty =>
val tables = attrs
.collect { case a: ResolvedAttribute =>
a.sourceColumns.map(_.table)
}.flatten.distinct
s"AllColumns(${tables.map(t => s"${t.name}.*").mkString(", ")})"
case _ => s"AllColumns(${name})"
// Show the detailed node type for the ease of debugging
val inputs = attrs.map(_.typeDescriptionWithNodeType).mkString(", ")
s"AllColumns(${inputs})"
case _ =>
s"AllColumns(${qualifier.map(q => s"${q}.").getOrElse("")}*)"
}
}

Expand Down Expand Up @@ -348,10 +351,8 @@ object Expression {
qualifier: Option[String] = None,
nodeLocation: Option[NodeLocation]
) extends Attribute {
override def name: String = alias.getOrElse(expr.attributeName)
override def dataTypeName: String = {
expr.dataTypeName
}
override def name: String = alias.getOrElse(expr.attributeName)
override def dataType: DataType = expr.dataType

override def children: Seq[Expression] = Seq(expr)
override def toString = s"SingleColumn(${alias.map(a => s"${expr} as ${a}").getOrElse(s"${expr}")})"
Expand Down Expand Up @@ -382,7 +383,7 @@ object Expression {
Seq(this)
} else if (expr.isInstanceOf[MultiColumn]) {
val m = expr.asInstanceOf[MultiColumn]
if (m.name.contains(s"${tableName}.${columnName}")) {
if (m.alias.contains(s"${tableName}.${columnName}")) {
Seq(m)
} else {
m.matched(tableName, columnName)
Expand All @@ -395,7 +396,7 @@ object Expression {
Seq(this)
} else if (expr.isInstanceOf[MultiColumn]) {
val m = expr.asInstanceOf[MultiColumn]
if (m.name.contains(columnName)) {
if (m.alias.contains(columnName)) {
Seq(m)
} else {
m.matched(columnName)
Expand All @@ -405,7 +406,7 @@ object Expression {

private def matchesWithMultiColumn(tableName: String, columnName: String): Boolean = {
expr match {
case MultiColumn(inputs, _, _) =>
case MultiColumn(inputs, _, _, _) =>
inputs.exists {
case r: ResolvedAttribute => r.matchesWith(tableName, columnName)
case s: SingleColumn => s.matchesWith(tableName, columnName)
Expand All @@ -418,7 +419,7 @@ object Expression {

private def matchesWithMultiColumn(columnName: String): Boolean = {
expr match {
case MultiColumn(inputs, name, _) =>
case MultiColumn(inputs, name, _, _) =>
name.contains(columnName) || inputs.exists {
case r: ResolvedAttribute => r.name == columnName
case s: SingleColumn => s.matchesWith(columnName)
Expand All @@ -429,28 +430,39 @@ object Expression {
}
}
}

/**
* A single column merged from multiple input expressions (e.g., union, join)
* @param inputs
* @param alias
* @param nodeLocation
*/
case class MultiColumn(
inputs: Seq[Expression],
name: Option[String],
alias: Option[String],
qualifier: Option[String],
nodeLocation: Option[NodeLocation]
) extends Expression
with LogSupport {
) extends Attribute {
require(inputs.nonEmpty, s"The inputs of MultiColumn should not be empty: ${this}")

override def children: Seq[Expression] = inputs

override def attributeName: String = {
name.getOrElse(inputs.head.attributeName)
override def name: String = {
alias.getOrElse(inputs.head.attributeName)
}

/**
 * Create a copy of this column with the qualifier replaced.
 *
 * @param newQualifier the qualifier to set
 * @return a copy of this MultiColumn whose qualifier is Some(newQualifier)
 */
override def withQualifier(newQualifier: String): Attribute =
  copy(qualifier = Some(newQualifier))

/**
 * The data type of this merged column, taken from the first input expression only.
 * The other inputs are not inspected.
 */
override def dataType: DataType = inputs.head.dataType

override def toString: String = s"MultiColumn(${inputs.mkString(", ")}${name.map(" as " + _).getOrElse("")})"
override def toString: String = s"MultiColumn(${inputs.mkString(", ")}${alias.map(" as " + _).getOrElse("")})"

def matched(tableName: String, columnName: String): Seq[Expression] = {
if (name.contains(s"${tableName}.${columnName}")) {
if (alias.contains(s"${tableName}.${columnName}")) {
Seq(this)
} else {
inputs.collect {
Expand All @@ -462,7 +474,7 @@ object Expression {
}

def matched(columnName: String): Seq[Expression] = {
if (name.contains(columnName)) {
if (alias.contains(columnName)) {
Seq(this)
} else {
inputs.collect {
Expand Down Expand Up @@ -570,8 +582,14 @@ object Expression {
window: Option[Window],
nodeLocation: Option[NodeLocation]
) extends Expression {
// TODO: Resolve the function return type using a function catalog
// override def dataType: DataType = super.dataType
/**
 * The return type of this function call.
 *
 * Only "count" (compared against the lower-cased functionName) has a known type,
 * DataType.LongType. Every other function reports DataType.UnknownType.
 */
override def dataType: DataType = functionName match {
  case "count" =>
    DataType.LongType
  case _ =>
    // TODO: Resolve the function return type using a function catalog
    DataType.UnknownType
}

override def children: Seq[Expression] = args ++ filter.toSeq ++ window.toSeq
def functionName: String = name.toString.toLowerCase(Locale.US)
Expand Down Expand Up @@ -697,7 +715,16 @@ object Expression {
right: Expression,
nodeLocation: Option[NodeLocation]
) extends ArithmeticExpression
with BinaryExpression
with BinaryExpression {
/**
 * The result type of this binary arithmetic expression.
 *
 * When both operands have the same data type, that type is returned.
 * Otherwise the result is DataType.UnknownType.
 */
override def dataType: DataType = left.dataType match {
  case leftType if leftType == right.dataType =>
    leftType
  case _ =>
    // TODO type escalation e.g., (Double) op (Long) -> (Double)
    DataType.UnknownType
}
}
case class ArithmeticUnaryExpr(sign: Sign, child: Expression, nodeLocation: Option[NodeLocation])
extends ArithmeticExpression
with UnaryExpression
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -529,7 +529,7 @@ object LogicalPlan {
}
}
val columns = (0 until values.head.size).map { i =>
SingleColumn(MultiColumn(values.map(_(i)), None, None), None, None, None)
SingleColumn(MultiColumn(values.map(_(i)), None, None, None), None, None, None)
}
columns
}
Expand Down Expand Up @@ -712,7 +712,7 @@ object LogicalPlan {
if (dupAttrs.isEmpty) {
r
} else {
SingleColumn(MultiColumn(Seq(r) ++ dupAttrs, None, None), None, None, None)
SingleColumn(MultiColumn(Seq(r) ++ dupAttrs, None, None, None), None, None, None)
}
case r => r
}
Expand Down Expand Up @@ -767,6 +767,7 @@ object LogicalPlan {
MultiColumn(
relations.map(_.outputAttributes(i)),
Some(output.name),
None,
output.nodeLocation
),
None,
Expand Down Expand Up @@ -806,7 +807,7 @@ object LogicalPlan {
override def outputAttributes: Seq[Attribute] = {
relations.head.outputAttributes.zipWithIndex.map { case (output, i) =>
SingleColumn(
MultiColumn(relations.map(_.outputAttributes(i)), Some(output.name), output.nodeLocation),
MultiColumn(relations.map(_.outputAttributes(i)), Some(output.name), None, output.nodeLocation),
None,
output match {
case r: ResolvedAttribute => r.qualifier
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,14 @@ object LogicalPlanPrinter extends LogSupport {
}
s": ${printAttr(inputAttrs)} => ${printAttr(outputAttrs)}"
}
val prefix = s"${ws}[${m.modelName}]${functionSig}"

val prefix = m match {
case t: TableScan =>
s"${ws}[${m.modelName}] ${t.table.fullName}${functionSig}"
case _ =>
s"${ws}[${m.modelName}]${functionSig}"
}

attr.length match {
case 0 =>
out.println(prefix)
Expand Down
Loading

0 comments on commit 7c0616e

Please sign in to comment.